# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import pandas as pd
import time
import wandb

# Authenticate with Weights & Biases. The API key is deliberately NOT stored in
# the source: called without an explicit key, wandb.login() falls back to the
# credentials configured in the environment (e.g. the WANDB_API_KEY variable
# or a previous `wandb login`).
# NOTE(review): the key that used to be hard-coded on this line should be
# treated as leaked and rotated.
wandb.login()
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reserved vocabulary indices for the start-of-sequence and end-of-sequence
# markers; real characters are numbered from 2 upwards (see Language).
SOS_token = 0
EOS_token = 1

class Language:
    """Character vocabulary for one side of a transliteration pair.

    Indices ``SOS_token`` and ``EOS_token`` are pre-assigned to the "<" and
    ">" markers in ``index2char``; every other character receives the next
    free index the first time it is passed to :meth:`add_char`.
    """

    def __init__(self, name):
        self.name = name
        # Only the two marker symbols are known up front, hence n_chars == 2.
        self.index2char = {SOS_token: "<", EOS_token: ">"}
        self.n_chars = 2
        self.char2index = {}
        self.char2count = {}

    def add_char(self, char):
        """Register one occurrence of *char*, assigning an index if it is new."""
        if char in self.char2index:
            # Already known: only the occurrence counter changes.
            self.char2count[char] += 1
            return
        new_index = self.n_chars
        self.char2index[char] = new_index
        self.char2count[char] = 1
        self.index2char[new_index] = char
        self.n_chars = new_index + 1

def load_data(language: str, data_type: str, data_dir: str = "/kaggle/input/akshantar-data/aksharantar_sampled"):
    """Load one split of the transliteration dataset as a list of rows.

    Args:
        language: Language code; used both as the sub-directory name and as
            the file-name prefix.
        data_type: Split name that completes the file name
            (``{language}_{data_type}.csv``).
        data_dir: Root directory of the dataset. Defaults to the Kaggle input
            location this notebook was written for.

    Returns:
        A list with one list per CSV row. Every cell is a ``str``.
    """
    path = f"{data_dir}/{language}/{language}_{data_type}.csv"
    # Read every cell verbatim as text. With pandas' defaults, words such as
    # "nan", "null" or "NA" are parsed as missing values and come back as
    # float NaN, which later breaks per-character iteration over the word.
    df = pd.read_csv(path, header=None, dtype=str, keep_default_na=False)
    return df.values.tolist()

def get_languages(lang: str):
    """Build the source and target vocabularies from the training split.

    The first column of every training row feeds the input vocabulary (named
    'eng'), the second column feeds the output vocabulary (named *lang*).

    Returns:
        ``(input_lang, output_lang, pairs)`` where *pairs* is the list of
        training rows as returned by ``load_data``.
    """
    input_lang, output_lang = Language('eng'), Language(lang)
    pairs = load_data(lang, "train")
    for pair in pairs:
        # Column 0 is the source word, column 1 the target word.
        for vocabulary, word in ((input_lang, pair[0]), (output_lang, pair[1])):
            for char in word:
                vocabulary.add_char(char)
    return input_lang, output_lang, pairs

def get_cell(cell_type: str):
    """Return the torch recurrent layer class named by *cell_type*.

    Supported names are 'LSTM', 'GRU' and 'RNN'; any other name raises
    ``KeyError``.
    """
    supported = dict(LSTM=nn.LSTM, GRU=nn.GRU, RNN=nn.RNN)
    return supported[cell_type]

def get_optimizer(optimizer: str):
    """Return the torch optimizer class named by *optimizer*.

    Supported names are 'SGD' and 'ADAM'; any other name raises ``KeyError``.
    """
    supported = dict(SGD=optim.SGD, ADAM=optim.Adam)
    return supported[optimizer]

class Encoder(nn.Module):
    """Character-level recurrent encoder that consumes one character per call.

    Args:
        input_size: Number of rows in the embedding table (vocabulary size).
        embed_size: Dimensionality of the character embeddings.
        hidden_size: Size of the recurrent hidden state.
        cell_type: 'LSTM', 'GRU' or 'RNN' (resolved through ``get_cell``).
        num_layers: Number of stacked recurrent layers.
        dropout: Dropout probability between stacked recurrent layers.
    """

    def __init__(self, input_size: int, embed_size: int, hidden_size: int, cell_type: str, num_layers: int, dropout: float):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.cell_type = cell_type
        self.embedding = nn.Embedding(input_size, embed_size)
        # PyTorch applies RNN dropout only *between* stacked layers. With a
        # single layer a non-zero value has no effect and merely triggers a
        # UserWarning, so it is only forwarded when layers are stacked.
        rnn_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = get_cell(cell_type)(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=rnn_dropout,
        )

    def forward(self, input, hidden, cell):
        """Encode a single character.

        Args:
            input: Tensor holding one character index; its embedding is
                reshaped to ``(1, 1, embed_size)``.
            hidden: Recurrent hidden state from the previous step.
            cell: LSTM cell state from the previous step. Ignored (and
                returned unchanged) for non-LSTM cells.

        Returns:
            ``(output, hidden, cell)`` for this step.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        if self.cell_type == "LSTM":
            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        else:
            output, hidden = self.rnn(embedded, hidden)
        return output, hidden, cell

    def init_hidden(self):
        """Return an all-zero state of shape ``(num_layers, 1, hidden_size)``."""
        return torch.zeros(self.num_layers, 1, self.hidden_size, device=device)

class AttentionDecoder(nn.Module):
    """Recurrent decoder with attention over a fixed number of encoder positions.

    Args:
        output_size: Size of the target vocabulary (embedding rows and width
            of the output projection).
        embed_size: Dimensionality of the character embeddings.
        hidden_size: Size of the recurrent hidden state.
        cell_type: 'LSTM', 'GRU' or 'RNN' (resolved through ``get_cell``).
        num_layers: Number of stacked recurrent layers.
        dropout: Dropout probability between stacked recurrent layers.
        max_length: Number of encoder positions the attention layer scores.
            ``encoder_outputs`` passed to :meth:`forward` must have exactly
            this many rows. Defaults to 50, the value that was previously
            hard-coded.
    """

    def __init__(self, output_size: int, embed_size: int, hidden_size: int, cell_type: str, num_layers: int, dropout: float, max_length: int = 50):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.cell_type = cell_type
        self.max_length = max_length
        self.embedding = nn.Embedding(output_size, embed_size)
        # One attention score per encoder position.
        self.attn = nn.Linear(hidden_size + embed_size, max_length)
        self.attn_combine = nn.Linear(hidden_size + embed_size, hidden_size)
        # PyTorch applies RNN dropout only *between* stacked layers. With a
        # single layer a non-zero value has no effect and merely triggers a
        # UserWarning, so it is only forwarded when layers are stacked.
        rnn_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = get_cell(cell_type)(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=rnn_dropout,
        )
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode one output step.

        Args:
            input: Tensor holding the previous output character index.
            hidden: Recurrent hidden state from the previous step.
            cell: LSTM cell state from the previous step. Ignored (and
                returned unchanged) for non-LSTM cells.
            encoder_outputs: Encoder outputs, one row per input position,
                with ``max_length`` rows in total.

        Returns:
            ``(output, hidden, cell, attn_weights)`` where *output* holds
            log-probabilities over the target vocabulary.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        # Score every encoder position from the current embedding and the
        # first layer's hidden state, then normalise into attention weights.
        attn_weights = torch.softmax(self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        # Attention-weighted sum of the encoder outputs.
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))
        # Merge the embedding with the attended context to form the RNN input.
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        if self.cell_type == "LSTM":
            output, (hidden, cell) = self.rnn(output, (hidden, cell))
        else:
            output, hidden = self.rnn(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden, cell, attn_weights

    def init_hidden(self):
        """Return an all-zero state of shape ``(num_layers, 1, hidden_size)``."""
        return torch.zeros(self.num_layers, 1, self.hidden_size, device=device)

def indexes_from_word(lang: Language, word: str):
    """Translate *word* into the list of vocabulary indices of its characters.

    Raises ``KeyError`` for a character that is absent from ``lang.char2index``.
    """
    lookup = lang.char2index
    return [lookup[symbol] for symbol in word]

def tensor_from_word(lang: Language, word: str):
    """Encode *word* as a column tensor of character indices ending in EOS.

    Returns:
        A ``torch.long`` tensor of shape ``(len(word) + 1, 1)`` on ``device``.
    """
    token_ids = [*indexes_from_word(lang, word), EOS_token]
    flat = torch.tensor(token_ids, dtype=torch.long, device=device)
    return flat.view(-1, 1)

def tensors_from_pair(input_lang: Language, output_lang: Language, pair: list[str]):
    """Convert a (source word, target word) row into a pair of index tensors.

    ``pair[0]`` is encoded with *input_lang* and ``pair[1]`` with *output_lang*.
    """
    return (
        tensor_from_word(input_lang, pair[0]),
        tensor_from_word(output_lang, pair[1]),
    )

def train_single(model, input_tensor, target_tensor):
    """Run one optimisation step on a single (input, target) tensor pair.

    The input is encoded character by character, then decoded either with
    teacher forcing (the ground-truth character is fed back) or free-running
    (the decoder's own best guess is fed back, stopping once it emits EOS).
    Which mode is used is drawn at random per call according to
    ``model.teacher_forcing_ratio``.

    Returns:
        The summed loss of this pair divided by the target length.
    """
    encoder, decoder = model.encoder, model.decoder
    hidden = encoder.init_hidden()
    cell = encoder.init_hidden()
    model.encoder_optimizer.zero_grad()
    model.decoder_optimizer.zero_grad()

    target_length = target_tensor.size(0)
    # Fixed-size buffer; rows past the input length stay zero.
    encoder_outputs = torch.zeros(model.max_length, encoder.hidden_size, device=device)
    for position, char_index in enumerate(input_tensor):
        step_output, hidden, cell = encoder(char_index, hidden, cell)
        encoder_outputs[position] = step_output[0, 0]

    # The decoder starts from SOS and inherits the encoder's final state.
    decoder_input = torch.tensor([[SOS_token]], device=device)
    use_teacher_forcing = random.random() < model.teacher_forcing_ratio

    loss = 0
    for step in range(target_length):
        log_probs, hidden, cell, _ = decoder(decoder_input, hidden, cell, encoder_outputs)
        loss += model.criterion(log_probs, target_tensor[step])
        if use_teacher_forcing:
            # Feed the ground-truth character as the next input.
            decoder_input = target_tensor[step]
            continue
        # Free-running: feed back the most likely character, detached so the
        # gradient does not flow through the sampled input.
        _, best_index = log_probs.topk(1)
        decoder_input = best_index.squeeze().detach()
        if decoder_input.item() == EOS_token:
            break

    loss.backward()
    model.encoder_optimizer.step()
    model.decoder_optimizer.step()
    return loss.item() / target_length

def train(model, iters=-1):
    """Train *model* for one pass over (a prefix of) its shuffled training pairs.

    Args:
        model: Object exposing ``training_pairs``, ``PRINT_EVERY`` and
            ``PLOT_EVERY`` plus everything ``train_single`` needs.
        iters: Number of pairs to train on; ``-1`` (default) means all of
            them. Values larger than the dataset are clamped to its size.

    Returns:
        A list with the average loss of every completed window of
        ``PLOT_EVERY`` iterations (empty if no window was completed).
    """
    start_time = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0
    random.shuffle(model.training_pairs)
    n_pairs = len(model.training_pairs)
    iters = n_pairs if iters == -1 else min(iters, n_pairs)
    # Steps are 1-based so the "every N" checks fire after N pairs; the upper
    # bound is inclusive so that the last requested pair is trained on too.
    for step in range(1, iters + 1):
        input_tensor, target_tensor = model.training_pairs[step - 1]
        loss = train_single(model, input_tensor, target_tensor)
        print_loss_total += loss
        plot_loss_total += loss
        if step % model.PRINT_EVERY == 0:
            print_loss_avg = print_loss_total / model.PRINT_EVERY
            print_loss_total = 0
            current_time = time.time()
            print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, step, current_time - start_time))
        if step % model.PLOT_EVERY == 0:
            plot_loss_avg = plot_loss_total / model.PLOT_EVERY
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
    return plot_losses

def evaluate(model, word):
    """Greedily transliterate *word* with *model* and return the decoded string.

    The encoder and decoder are switched to evaluation mode for the duration
    of the call so that dropout inside the recurrent layers is disabled and
    predictions are deterministic; their previous modes are restored
    afterwards so that training can continue unaffected.

    Decoding stops at the first EOS prediction or after ``model.max_length``
    steps. A character of *word* that is missing from ``model.input_lang``
    raises ``KeyError``.
    """
    encoder, decoder = model.encoder, model.decoder
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()
    try:
        with torch.no_grad():
            input_tensor = tensor_from_word(model.input_lang, word)
            input_length = input_tensor.size()[0]
            encoder_hidden = encoder.init_hidden()
            encoder_cell = encoder.init_hidden()
            # Fixed-size buffer; rows past the input length stay zero.
            encoder_outputs = torch.zeros(model.max_length, encoder.hidden_size, device=device)
            for ei in range(input_length):
                encoder_output, encoder_hidden, encoder_cell = encoder(input_tensor[ei], encoder_hidden, encoder_cell)
                encoder_outputs[ei] = encoder_output[0, 0]

            # The decoder starts from SOS and inherits the encoder's final state.
            decoder_input = torch.tensor([[SOS_token]], device=device)
            decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
            decoded_chars = []
            for _ in range(model.max_length):
                decoder_output, decoder_hidden, decoder_cell, _attn = decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
                _, topi = decoder_output.topk(1)
                predicted = topi.item()
                if predicted == EOS_token:
                    break
                decoded_chars.append(model.output_lang.index2char[predicted])
                # Feed the prediction back as the next decoder input.
                decoder_input = topi.squeeze().detach()
            return "".join(decoded_chars)
    finally:
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

def test_validate(model, type:str):
    """Return the exact-match accuracy of *model* on the split named *type*.

    Every row of the split is transliterated with ``evaluate``; a row counts
    as correct only when the prediction equals its second column exactly.
    """
    pairs = load_data(model.lang, type)
    correct = sum(1 for pair in pairs if evaluate(model, pair[0]) == pair[1])
    return correct / len(pairs)

class Translator:
    """Bundle of data, encoder/decoder, optimizers and loss for one language.

    Args:
        lang: Target language code; selects the dataset files to load.
        params: Mapping with the keys ``embed_size``, ``hidden_size``,
            ``cell_type``, ``num_layers``, ``dropout``, ``optimizer``,
            ``learning_rate``, ``teacher_forcing_ratio`` and ``max_length``.
            ``max_length`` is the number of encoder positions buffered during
            training and evaluation and the cap on decoded length.
    """

    def __init__(self, lang: str, params: dict):
        self.lang = lang
        self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
        self.input_size = self.input_lang.n_chars
        self.output_size = self.output_lang.n_chars
        # Pre-tensorise the whole training set once.
        self.training_pairs = [tensors_from_pair(self.input_lang, self.output_lang, pair) for pair in self.pairs]
        self.encoder = Encoder(
            input_size=self.input_size,
            embed_size=params["embed_size"],
            hidden_size=params["hidden_size"],
            cell_type=params["cell_type"],
            num_layers=params["num_layers"],
            dropout=params["dropout"],
        ).to(device)
        self.decoder = AttentionDecoder(
            output_size=self.output_size,
            embed_size=params["embed_size"],
            hidden_size=params["hidden_size"],
            cell_type=params["cell_type"],
            num_layers=params["num_layers"],
            dropout=params["dropout"],
        ).to(device)
        optimizer_class = get_optimizer(params["optimizer"])
        self.encoder_optimizer = optimizer_class(self.encoder.parameters(), lr=params["learning_rate"])
        self.decoder_optimizer = optimizer_class(self.decoder.parameters(), lr=params["learning_rate"])
        self.criterion = nn.NLLLoss()
        self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
        self.max_length = params["max_length"]
        # Reporting cadence, in training iterations.
        self.PRINT_EVERY = 40000
        self.PLOT_EVERY = 40000

    def train(self, iters=-1):
        """Train for one pass over (a prefix of) the shuffled training pairs.

        Args:
            iters: Number of pairs to train on; ``-1`` (default) means all of
                them. Values larger than the dataset are clamped to its size.

        Returns:
            A list with the average loss of every completed window of
            ``PLOT_EVERY`` iterations (empty if no window was completed).
        """
        start_time = time.time()
        plot_losses = []
        print_loss_total = 0
        plot_loss_total = 0
        random.shuffle(self.training_pairs)
        n_pairs = len(self.training_pairs)
        iters = n_pairs if iters == -1 else min(iters, n_pairs)
        # Steps are 1-based so the "every N" checks fire after N pairs; the
        # upper bound is inclusive so the last requested pair is trained on.
        for step in range(1, iters + 1):
            input_tensor, target_tensor = self.training_pairs[step - 1]
            loss = train_single(self, input_tensor, target_tensor)
            print_loss_total += loss
            plot_loss_total += loss
            if step % self.PRINT_EVERY == 0:
                print_loss_avg = print_loss_total / self.PRINT_EVERY
                print_loss_total = 0
                current_time = time.time()
                print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, step, current_time - start_time))
            if step % self.PLOT_EVERY == 0:
                plot_loss_avg = plot_loss_total / self.PLOT_EVERY
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0
        return plot_losses

    def evaluate(self, word):
        """Transliterate *word*; thin wrapper around the module-level ``evaluate``."""
        return evaluate(self, word)

    def test_validate(self, type:str):
        """Return exact-match accuracy on the split named *type*."""
        return test_validate(self, type)

def train_sweep():
    """Entry point for one W&B sweep run: train, validate, log and test.

    Trains a Hindi ``Translator`` for up to 5 epochs using the
    hyper-parameters in ``wandb.config``. Training stops early when the
    validation accuracy is below 0.0001 or drops under 90% of the previous
    epoch's value (never after the first epoch). Metrics are logged on every
    third epoch and on the last epoch that runs, whether that is the final
    scheduled epoch or an early-stopped one. The test accuracy is logged at
    the end.
    """
    run = wandb.init()
    config = wandb.config
    run.name = f"embed_size: {config.embed_size} | hidden_size: {config.hidden_size} | cell_type: {config.cell_type} | num_layers: {config.num_layers} | dropout: {config.dropout} | learning_rate: {config.learning_rate} | optimizer: {config.optimizer} | teacher_forcing_ratio: {config.teacher_forcing_ratio} | max_length: {config.max_length}"
    model = Translator('hin', config)
    epochs = 5
    old_validation_accuracy = 0
    for epoch in range(epochs):
        print("Epoch: {}".format(epoch + 1))
        plot_losses = model.train()
        # model.train() only reports completed PLOT_EVERY windows, so the list
        # is empty when the dataset is smaller than one window.
        training_loss = sum(plot_losses) / len(plot_losses) if plot_losses else None
        validation_accuracy = model.test_validate('valid')
        print("Validation Accuracy: {:.4f}".format(validation_accuracy))

        stop_early = epoch > 0 and (
            validation_accuracy < 0.0001
            or validation_accuracy < 0.9 * old_validation_accuracy
        )
        # `epoch` runs from 0 to epochs - 1, so the last scheduled epoch is
        # epochs - 1; an early stop also makes the current epoch the last one.
        is_last_epoch = stop_early or epoch == epochs - 1
        if epoch % 3 == 0 or is_last_epoch:
            metrics = {"epoch": epoch + 1}
            if training_loss is not None:
                metrics["training_loss"] = training_loss
            metrics["validation_accuracy"] = validation_accuracy
            wandb.log(metrics)
        if stop_early:
            break
        old_validation_accuracy = validation_accuracy
    test_accuracy = model.test_validate('test')
    print("Test Accuracy: {:.4f}".format(test_accuracy))
    wandb.log({
        "test_accuracy": test_accuracy
    })
    run.finish()

# Bayesian hyper-parameter search that maximises the logged validation accuracy.
sweep_configuration = {
    "method": "bayes",
    "name": "with attention sweep 1",
    "metric": {"name": "validation_accuracy", "goal": "maximize"},
    "parameters": {
        # Searched hyper-parameters.
        "embed_size": {"values": [16, 32, 64]},
        "hidden_size": {"values": [128, 256, 512]},
        "cell_type": {"values": ["RNN", "LSTM", "GRU"]},
        "num_layers": {"values": [1, 2, 3]},
        "dropout": {"values": [0, 0.1, 0.2]},
        # Hyper-parameters held fixed for every run.
        "learning_rate": {"value": 0.001},
        "optimizer": {"value": "ADAM"},
        "teacher_forcing_ratio": {"value": 0.5},
        "max_length": {"value": 50},
    },
}

# Register the sweep with W&B, then start an agent that calls train_sweep
# once per sampled configuration.
wandb_id = wandb.sweep(sweep_configuration, project="CS6910_Assignment_3")
wandb.agent(wandb_id, train_sweep)
