<a href="https://colab.research.google.com/github/SaiRajesh228/DA6401_Assignment3/blob/main/DA6401A3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Single-cell version of Sequence-to-Sequence Transliteration Project
import os
import subprocess
import sys

# Create project directory
# NOTE: lines starting with `!` or `%` are IPython magics, so this cell only runs
# inside a notebook kernel (the paths and the Drive mount below target Google Colab).
!mkdir -p /content/transliteration_project
%cd /content/transliteration_project

# Install dependencies
!pip install -q pandas numpy torch matplotlib wandb tqdm

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Initialize wandb (optional)
# The import alone does not authenticate; the login call below is left disabled.
import wandb
# wandb.login()  # Uncomment if you want to use wandb

# Write Core_Utils.py
with open('Core_Utils.py', 'w') as f:
    f.write('''
import numpy as np
import random
import pandas as pd
import gc
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

## Seed Python's random module, NumPy and PyTorch (CPU and CUDA) with one shared value
## and ask cuDNN for deterministic algorithms - presumably so that runs are repeatable.
seed = 23
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

class LanguageProcessor:
    """
    Builds character vocabularies for one dataset split and converts words to and from id sequences.

    The file language_directory + mode + ".txt" is parsed as tab separated text without a header.
    Column 0 is used as the source language and column 1 as the target language.
    """
    def __init__(self,language_directory,target_lang_name,mode="train",meta_tokens=True):
        """
        Default Constructor for this class.
        Params:
            language_directory : ex : "aksharantar_sampled/tel/" (the file name is appended as-is, so keep the trailing slash).
            target_lang_name : name of the target language; it is only stored on the instance here.
            mode : "train" or "test" or "valid", accordingly the appropriate dataset is read.
            meta_tokens : If true creates the first three tokens of the dictionary as <start>,<end>,<pad>.
        """
        self.meta_tokens = meta_tokens
        self.language_directory = language_directory
        self.target_lang_name = target_lang_name
        self.mode = mode
        ## column indices of the two languages inside the data file
        self.source_lang = 0
        self.target_lang = 1
        self.source_char2id,self.source_id2char = self.build_char_vocab(self.source_lang)
        self.target_char2id,self.target_id2char = self.build_char_vocab(self.target_lang)

    def build_char_vocab(self,lang_id,max_len=None):
        """
        Method to create a vocabulary of characters in language corresponding to lang_id.
        Params:
            lang_id : column index of the language (self.source_lang or self.target_lang).
            max_len : currently unused; kept so existing callers keep working.
        Returns:
            (char2id_dict, id2char_dict) : the two directions of the character <-> id mapping.
        Side effects:
            Re-reads the split file and stores its first two columns as a numpy array in self.data.
            When meta_tokens is set, also stores start_token_id, end_token_id and pad_token_id.
        """
        # Everything is read as strings and empty fields stay empty strings instead of NaN.
        df = pd.read_csv(self.language_directory+self.mode+".txt",
                         header=None,
                         sep='\\t',
                         dtype=str,
                         na_filter=False)

        # Only keep the first two columns for source and target languages
        if df.shape[1] > 2:
            df = df.iloc[:, 0:2]

        self.data = df.to_numpy()

        ## collect every distinct character of the requested column, in sorted order
        unique_lang_chars = sorted({
            char
            for word in df[lang_id].to_numpy()
            if word is not None and word != ''
            for char in str(word)
        })

        if self.meta_tokens:
            char2id_dict = {'<start>':0,'<end>':1,'<pad>': 2}
            id2char_dict = {0:'<start>',1:'<end>',2:'<pad>'}
            self.start_token_id = char2id_dict['<start>']
            self.end_token_id = char2id_dict['<end>']
            self.pad_token_id = char2id_dict['<pad>']
        else:
            char2id_dict = {}
            id2char_dict = {}

        ## real characters are numbered right after the meta tokens (if any)
        for char_id, char in enumerate(unique_lang_chars, start=len(char2id_dict)):
            char2id_dict[char] = char_id
            id2char_dict[char_id] = char

        del df
        del unique_lang_chars
        gc.collect()

        return char2id_dict,id2char_dict

    def encode_word(self,word,lang_id,padding=False,append_eos = False):
        """
        Method to encode characters of a given word.
        Params:
            word : the word to encode; it is converted to str and lower-cased before the lookup.
            lang_id : self.source_lang selects the source vocabulary, anything else the target one.
            padding : currently unused; padding is applied later when batches are collated.
            append_eos : If true the id of <end> is appended to the encoding.
        Returns:
            numpy array with one id per character (plus <end> when requested).
        Raises:
            KeyError : if a character (or <end>, when requested) is missing from the vocabulary.
        """
        if lang_id == self.source_lang:
            char2id_dict = self.source_char2id
        else:
            char2id_dict = self.target_char2id

        # Ensure we're working with a string
        word = str(word).lower()

        word_encoding = [char2id_dict[char] for char in word]

        if append_eos:
            word_encoding.append(char2id_dict['<end>'])

        return np.array(word_encoding)

    def decode_word(self,code_word,lang_id):
        """
        Method to decode an encoded word.
        Params:
            code_word : iterable of character ids.
            lang_id : self.source_lang selects the source vocabulary, anything else the target one.
        Returns:
            numpy array of characters, cut at the first <end> or <pad> id when meta tokens are in use.
        """
        word = []

        if lang_id == self.source_lang:
            id2char_dict = self.source_id2char
            char2id_dict = self.source_char2id
        else:
            id2char_dict = self.target_id2char
            char2id_dict = self.target_char2id

        ## The stop ids only exist when the vocabulary was built with meta tokens. Looking them up
        ## unconditionally raised a KeyError for meta_tokens=False.
        if self.meta_tokens:
            end_id = char2id_dict['<end>']
            pad_id = char2id_dict['<pad>']

        for i in code_word:
            ## if we reached <end> (or the padding), then stop decoding
            if self.meta_tokens and (i == end_id or i == pad_id):
                break
            word.append(id2char_dict[i])

        return np.array(word)

class WordDataset(Dataset):
    """
    Torch Dataset serving the (source word, target word) rows of a language processor as id tensors.
    """
    def __init__(self, language_processor,append_eos=True,device='cpu'):
        """
        Params:
            language_processor : object exposing data, encode_word, the language ids and the meta token ids.
            append_eos : forwarded to encode_word for both words of a pair.
            device : device every returned tensor is moved to.
        """
        self.lp = language_processor
        self.data = language_processor.data
        self.device = device
        self.append_eos = append_eos
        ## mirror the meta token ids so that loaders can reach them through the dataset
        self.start_token_id = language_processor.start_token_id
        self.end_token_id = language_processor.end_token_id
        self.pad_token_id = language_processor.pad_token_id

    def __len__(self):
        """Number of word pairs held by the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded (source, target) tensors of the pair stored at position idx."""
        source_word, target_word = self.data[idx]
        word_and_language = (
            (source_word, self.lp.source_lang),
            (target_word, self.lp.target_lang),
        )
        source_tensor, target_tensor = (
            torch.tensor(
                self.lp.encode_word(word, language, padding=False, append_eos=self.append_eos)
            ).to(self.device)
            for word, language in word_and_language
        )
        return source_tensor, target_tensor

def collate_fn(batch,pad_token_id,device):
    """
    Collate a list of (source, target) tensor pairs into padded batch tensors.
    Params:
        batch : iterable of (source tensor, target tensor) pairs.
        pad_token_id : value used to fill every sequence up to the longest one of the batch.
        device : device the two length tensors are moved to (the padded tensors are not moved here).
    Returns:
        padded sources, padded targets, source lengths, target lengths.
    """
    source_seqs, target_seqs = zip(*batch)

    def pad_to_longest(sequences):
        ## batch comes first, shorter sequences are filled with pad_token_id on the right
        return pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)

    def lengths_of(sequences):
        return torch.LongTensor(list(map(len, sequences))).to(device)

    return (
        pad_to_longest(source_seqs),
        pad_to_longest(target_seqs),
        lengths_of(source_seqs),
        lengths_of(target_seqs),
    )
''')

# Write Encoder_Decoder_Architecture.py
with open('Encoder_Decoder_Architecture.py', 'w') as f:
    f.write('''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

## Seed NumPy, Python's random module and PyTorch (CPU and CUDA) with one shared value
## and ask cuDNN for deterministic algorithms - presumably so that runs are repeatable.
seed = 23
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

class BahdanauAttention(nn.Module):
    """
    The class to implement Additive attention aka Bahdanau Attention.
    """
    def __init__(self, hidden_size,D,expected_dim,batch_size):
        """
        Params:
            hidden_size : size of one hidden state vector.
            D : multiplier of hidden_size giving the feature size of the encoder contexts.
            expected_dim : multiplier of hidden_size giving the size of the flattened decoder hidden state.
            batch_size : kept on the instance for backward compatibility; forward no longer depends on it.
        """
        super(BahdanauAttention, self).__init__()
        self.U_att = nn.Linear(hidden_size*expected_dim, hidden_size)
        self.W_att = nn.Linear(hidden_size*D, hidden_size)
        self.V_att = nn.Linear(hidden_size, 1)
        self.batch_size = batch_size

    def forward(self, decoder_prev_hidden, encoder_contexts):
        """
        Params:
            decoder_prev_hidden : previous decoder hidden state with the batch on dim 0; everything
                                  else is flattened into one vector of hidden_size*expected_dim values.
            encoder_contexts : encoder outputs, batch on dim 0, one row of hidden_size*D features per step.
        Returns:
            context : attention weighted sum of the encoder contexts, one row per batch entry.
            weights : softmax-normalised attention weights over the encoder steps.
        """
        ## Take the batch size from the data itself. The size given at construction does not hold
        ## for a smaller final batch or for single word evaluation, where the reshape used to fail.
        batch_size = encoder_contexts.size(0)
        query = decoder_prev_hidden.reshape(batch_size,1,-1)
        energies = self.V_att(torch.tanh(self.U_att(query) + self.W_att(encoder_contexts)))
        ## drop the trailing singleton and add one for the single decoder step
        scores = energies.squeeze(2).unsqueeze(1)
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, encoder_contexts)
        return context, weights

class Encoder(nn.Module):
    """
    The class that implements the encoder using Recurrent Units RNN/LSTM/GRU.
    """
    def __init__(self, source_vocab_size,hidden_size,embedding_size,rnn_type = "GRU",padding_idx = None ,dropout=0.1,num_layers = 1,bidirectional = False):
        """
        Params:
            source_vocab_size : number of entries of the embedding table.
            hidden_size : hidden size of the recurrent unit.
            embedding_size : accepted but NOT used - the embedding width is set to hidden_size.
            rnn_type : "GRU", "RNN" or "LSTM".
            padding_idx : forwarded to nn.Embedding.
            dropout : applied to the embeddings, and between recurrent layers when num_layers > 1.
            num_layers : number of stacked recurrent layers.
            bidirectional : whether the recurrent unit reads the input in both directions.
        Raises:
            ValueError : if rnn_type is not one of the supported names.
        """
        super(Encoder, self).__init__()
        rnn_classes = {"GRU": nn.GRU, "RNN": nn.RNN, "LSTM": nn.LSTM}
        if rnn_type not in rnn_classes:
            ## fail here instead of leaving self.rnn undefined until the first forward pass
            raise ValueError(f"Unsupported rnn_type: {rnn_type}")

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding_size = self.hidden_size
        self.rnn_type = rnn_type
        self.D = 2 if bidirectional else 1 ##the number of directions in which the input is viewed.
        ## only pass dropout to the recurrent unit when there is more than one layer
        self.rnn_dropout = dropout if self.num_layers>1 else 0
        self.embedding = nn.Embedding(source_vocab_size, self.embedding_size,padding_idx = padding_idx)
        self.rnn = rnn_classes[rnn_type](self.embedding_size, self.hidden_size, batch_first=True,num_layers = num_layers,bidirectional = bidirectional,dropout=self.rnn_dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input,hidden=None,cell=None):
        """
        Params:
            input : tensor of token ids, batch first.
            hidden, cell : initial LSTM state; when either is missing the recurrent unit is called
                           without an explicit state. Both are ignored for GRU / RNN.
        Returns:
            output, hidden, cell : the recurrent outputs, the final hidden state and the final
                                   cell state (None unless rnn_type is "LSTM").
        """
        input_embedding = self.dropout(self.embedding(input))
        if self.rnn_type == "LSTM":
            ## a (None, None) tuple is not a valid LSTM state, so pass no state at all in that case
            initial_state = None if hidden is None or cell is None else (hidden, cell)
            output, (hidden, cell) = self.rnn(input_embedding, initial_state)
        else:
            output, hidden = self.rnn(input_embedding)
            cell = None
        return output, hidden, cell

class Decoder(nn.Module):
    """
    The class to implement Decoder in the encoder-decoder architecture.
    """
    def __init__(self, hidden_size,embedding_size,target_vocab_size,rnn_type,batch_size,use_attention = True,padding_idx = None,num_layers = 1,bidirectional = False,dropout=0,device = "cpu"):
        """
        Params:
            hidden_size : hidden size of the recurrent unit.
            embedding_size : accepted but NOT used - the embedding width is derived from hidden_size.
            target_vocab_size : number of output classes and of embedding rows.
            rnn_type : "GRU", "RNN" or "LSTM".
            batch_size : only forwarded to the attention module.
            use_attention : If true an additive attention context is concatenated to every input embedding.
            padding_idx : forwarded to nn.Embedding.
            num_layers, bidirectional, dropout : configuration of the recurrent unit.
            device : device on which the first decoder input is created.
        Raises:
            ValueError : if rnn_type is not one of the supported names.
        """
        super(Decoder, self).__init__()
        rnn_classes = {"GRU": nn.GRU, "RNN": nn.RNN, "LSTM": nn.LSTM}
        if rnn_type not in rnn_classes:
            ## fail here instead of leaving self.rnn undefined until the first forward pass
            raise ValueError(f"Unsupported rnn_type: {rnn_type}")

        self.num_layers = num_layers
        self.rnn_type = rnn_type
        self.device = device
        self.hidden_size = hidden_size
        self.embedding_size = self.hidden_size
        self.use_attention = use_attention
        self.D = 2 if bidirectional else 1
        ## In h0 (the input to the decoder) first dimension expected is number of directions X number of layers
        self.expected_h0_dim1 = self.D*self.num_layers
        ##create an embedding layer, and ignore padding index
        ## with attention the embedding is widened by the number of directions
        factor = self.D if self.use_attention else 1
        self.embedding = nn.Embedding(target_vocab_size, self.embedding_size*factor,padding_idx = padding_idx)
        if self.use_attention:
            self.attention = BahdanauAttention(self.hidden_size,self.D,self.expected_h0_dim1,batch_size)
            ## embedding and context vector are concatenated before entering the recurrent unit
            recurrent_unit_input_dim = 2*self.D*self.hidden_size
        else:
            recurrent_unit_input_dim = self.embedding_size
        ## only pass dropout to the recurrent unit when there is more than one layer
        self.rnn_dropout = dropout if self.num_layers>1 else 0
        self.rnn = rnn_classes[rnn_type](recurrent_unit_input_dim, self.hidden_size, batch_first=True,num_layers = num_layers,bidirectional = bidirectional,dropout=self.rnn_dropout)
        ## Passing the hidden state through a fully connected layer and then applying softmax
        self.output_layer = nn.Linear(self.hidden_size*self.D, target_vocab_size)

    def forward(self, encoder_hidden_contexts, encoder_last_hidden,encoder_cell,target_tensor,eval_mode = False,teacher_forcing_ratio=0):
        """
        Decode step by step, starting from token id 0 and the final encoder state.
        Params:
            encoder_hidden_contexts : encoder outputs (batch first), used for the batch size and by the attention.
            encoder_last_hidden : final encoder hidden state, used as the first decoder hidden state.
            encoder_cell : final encoder cell state (only used for LSTM).
            target_tensor : target ids (batch first); gives the number of steps and the teacher forcing
                            inputs. May be None in eval_mode.
            eval_mode : If true a single word is decoded for a fixed 30 steps.
            teacher_forcing_ratio : probability, per step, of feeding the target token as the next input.
        Returns:
            decoder_outputs : log-probabilities of all steps concatenated along dim 1.
            decoder_hidden : hidden state after the last step.
            attentions : list with the attention weights of every step, or None without attention.
        """
        ## eval mode is for looking at a specific word that is predicted to compare with the correct word.
        if eval_mode:
            batch_size = 1
            max_word_len = 30 ## an arbitrary number, larger than the expected word length.
        else:
            batch_size = encoder_hidden_contexts.size(0)
            max_word_len = target_tensor.size(1)

        decoder_outputs = []
        attentions = [] if self.use_attention else None

        ## At the first time step < SOS > token (which has an id 0, is fed as an input to the decoder).
        decoder_input = torch.zeros((batch_size, 1), dtype=torch.long, device=self.device)
        decoder_hidden = encoder_last_hidden ## in the first time step of the decoder, the output of the encoder is the input.
        decoder_cell = encoder_cell ## the cell state, which is initially same as that of encoder, (applies to LSTM unit only)

        for step in range(max_word_len):
            if eval_mode:
                decoder_input = decoder_input.view(1,-1)

            embedding = self.embedding(decoder_input)

            ## repeat the hidden state when its first dimension is not layers X directions
            if decoder_hidden.shape[0] != self.expected_h0_dim1:
                reshaped_hidden = decoder_hidden.repeat(self.expected_h0_dim1,1,1)
            else:
                reshaped_hidden = decoder_hidden

            if self.use_attention:
                ## the attention part.
                decoder_prev_hidden = reshaped_hidden.permute(1, 0, 2)
                context_vector, attention_weights = self.attention(decoder_prev_hidden, encoder_hidden_contexts)
                tmp_input = torch.cat((embedding, context_vector), dim=2)
            else:
                ## introducing non-linearity through ReLU activation
                tmp_input = F.relu(embedding)

            if self.rnn_type == "LSTM":
                tmp_output, (decoder_hidden, decoder_cell) = self.rnn(tmp_input, (reshaped_hidden, decoder_cell))
            else:
                tmp_output, decoder_hidden = self.rnn(tmp_input, reshaped_hidden)

            ## Keep all three dimensions. The former squeeze(0) removed the batch dimension whenever
            ## the batch held a single word, which broke the dim=2 lookup and the concatenation below.
            decoder_output = self.output_layer(tmp_output)

            ## randomly sample a number in (0,1) and if the number is less than the teacher forcing ratio
            ## apply teacher forcing at the current step
            apply_teacher_forcing = random.random() < teacher_forcing_ratio

            if (target_tensor is not None) and (apply_teacher_forcing):
                ## Teacher forcing: feed the token of the target at the current step as the next input.
                decoder_input = target_tensor[:, step].unsqueeze(1)
            else:
                ##greedily pick predictions, i.e pick the character corresponding to the highest score
                _,preds = torch.max(decoder_output,dim=2)
                decoder_input = preds.detach()

            decoder_outputs.append(decoder_output)
            if self.use_attention:
                attentions.append(attention_weights)

        ## concatenate the predictions across all the timesteps into a single tensor
        ## found in literature that log_softmax does better than softmax, hence going with that.
        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)

        ## the idea is to have a common API for both attention and normal decoder, achieving ease of use.
        return decoder_outputs, decoder_hidden,attentions
''')

# Write Machine_Translator.py
with open('Machine_Translator.py', 'w') as f:
    f.write('''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
from torch import optim
import wandb

from Encoder_Decoder_Architecture import *
from tqdm import tqdm

## Seed NumPy, Python's random module and PyTorch (CPU and CUDA) with one shared value
## and ask cuDNN for deterministic algorithms - presumably so that runs are repeatable.
seed = 23
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

class MachineTranslator:
    """
    The class that instantiates the encoder-decoder architecture.

    It owns an Encoder and a Decoder and offers training (train / train_epoch) and
    evaluation (compute_accuracy) helpers that take the two networks as arguments.
    """
    def __init__(self,source_vocab_size,target_vocab_size,hidden_size,embedding_size,rnn_type,batch_size,pad_token_id,dropout,num_layers,bidirectional,use_attention,device):
        """
        Build the encoder and the decoder from one shared configuration and move both to device.
        """
        self.device = device
        self.encoder = Encoder(source_vocab_size = source_vocab_size, hidden_size = hidden_size,embedding_size=embedding_size,rnn_type = rnn_type,padding_idx=pad_token_id,num_layers=num_layers,bidirectional=bidirectional,dropout=dropout).to(self.device)
        self.decoder = Decoder(hidden_size = hidden_size,embedding_size=embedding_size, target_vocab_size = target_vocab_size,batch_size = batch_size,rnn_type = rnn_type,use_attention = use_attention, padding_idx = pad_token_id,num_layers = num_layers,bidirectional = bidirectional,dropout=dropout,device=self.device).to(self.device)

    @staticmethod
    def _initial_encoder_state(encoder, batch_size, device):
        """
        Return the (hidden, cell) pair handed to the encoder: zero tensors for an LSTM, (None, None) otherwise.
        """
        if encoder.rnn_type != "LSTM":
            return None, None
        state_shape = (encoder.num_layers*encoder.D, batch_size, encoder.hidden_size)
        return torch.zeros(*state_shape, device=device), torch.zeros(*state_shape, device=device)

    def train_epoch(self,train_loader, encoder, decoder, encoder_optim,decoder_optim, loss_criterion,teacher_forcing_ratio,ignore_padding=True,device='cpu'):
        """
        Method to train the encoder-decoder model for 1 epoch.
        Params:
            train_loader : yields (inputs, targets, input lengths, target lengths) batches; its dataset
                           must expose pad_token_id when ignore_padding is set.
            encoder, decoder : the networks to optimise.
            encoder_optim, decoder_optim : one optimiser per network.
            loss_criterion : called with the flattened decoder outputs and the flattened targets.
            teacher_forcing_ratio : forwarded to the decoder.
            ignore_padding : If true padded positions do not influence the word accuracy.
            device : device on which the initial LSTM state is created.
        Returns:
            (mean loss per batch rounded to 4 digits, word level accuracy in percent rounded to 2 digits)
        """
        tot_correct_word_preds = 0
        tot_words = 0
        epoch_loss = 0

        for data in tqdm(train_loader):
            input_tensor, target_tensor,_,_ = data
            encoder_optim.zero_grad()
            decoder_optim.zero_grad()
            batch_size = input_tensor.shape[0]

            encoder_hidden, encoder_cell = self._initial_encoder_state(encoder, batch_size, device)
            encoder_hidden_contexts, encoder_last_hidden, encoder_cell = encoder(input_tensor,encoder_hidden,encoder_cell)
            decoder_outputs, _, _ = decoder(encoder_hidden_contexts, encoder_last_hidden,encoder_cell, target_tensor=target_tensor,teacher_forcing_ratio = teacher_forcing_ratio)

            multi_step_preds = torch.argmax(decoder_outputs,dim=2)
            multi_step_pred_correctness = (multi_step_preds ==  target_tensor)

            ## True on every padded position. Comparing with the pad id directly also handles words
            ## without any padding; locating the padding with argmax returned 0 for those words and
            ## thereby masked them completely, so they were always counted as correct.
            if ignore_padding:
                mask = (target_tensor == train_loader.dataset.pad_token_id)
            else:
                mask = torch.zeros_like(target_tensor, dtype=torch.bool)

            ##doing a logical OR with the mask makes sure that the padding tokens do not affect the correctness of the word
            tot_correct_word_preds += (torch.all(torch.logical_or(multi_step_pred_correctness,mask),dim=1).int().sum()).item()
            ## counted in both modes, so ignore_padding=False no longer ends in a division by zero
            tot_words += multi_step_preds.shape[0]

            loss = loss_criterion(
                decoder_outputs.view(-1, decoder_outputs.size(-1)),
                target_tensor.view(-1)
            )
            loss.backward()
            encoder_optim.step()
            decoder_optim.step()
            epoch_loss += loss.item()

        epoch_loss = round(epoch_loss / len(train_loader),4)
        epoch_accuracy = round(tot_correct_word_preds*100/tot_words,2)
        return epoch_loss,epoch_accuracy

    def train(self,train_loader,valid_loader, encoder, decoder, epochs,padding_idx,optimiser = "adam",loss="crossentropy",weight_decay=0, lr=0.001,teacher_forcing_ratio = 0,device='cpu',wandb_logging = False):
        """
        The method to train the encoder-decoder model.
        Params:
            train_loader, valid_loader : loaders used for training and for the per-epoch validation.
            encoder, decoder : the networks to optimise.
            epochs : number of passes over train_loader.
            padding_idx : target id ignored by the loss.
            optimiser : "adam", "nadam" or "rmsprop" (case insensitive).
            loss : "crossentropy" (case insensitive).
            weight_decay, lr : forwarded to both optimisers.
            teacher_forcing_ratio : forwarded to the decoder during training.
            device : device for the loss module and the initial LSTM states.
            wandb_logging : If true the metrics of every epoch are sent to wandb.
        Raises:
            ValueError : for an unsupported optimiser or loss name.
        """
        ## specify the optimiser
        optimiser_classes = {"adam": optim.Adam, "nadam": optim.NAdam, "rmsprop": optim.RMSprop}
        optimiser_name = optimiser.lower()
        if optimiser_name not in optimiser_classes:
            raise ValueError(f"Unsupported optimiser: {optimiser}")
        optimiser_class = optimiser_classes[optimiser_name]
        encoder_optimizer = optimiser_class(encoder.parameters(), lr=lr,weight_decay=weight_decay)
        decoder_optimizer = optimiser_class(decoder.parameters(), lr=lr,weight_decay=weight_decay)

        ## Specify the loss criteria
        if loss.lower() != "crossentropy":
            raise ValueError(f"Unsupported loss: {loss}")
        loss_criterion = nn.CrossEntropyLoss(ignore_index = padding_idx).to(device)

        for epoch in tqdm(range(epochs)):
            ## Train for 1 epoch.
            train_loss,train_accuracy = self.train_epoch(train_loader, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_criterion,teacher_forcing_ratio,device=device)
            ## compute validation accuracy.
            val_loss,_,val_accuracy = self.compute_accuracy(valid_loader,encoder,decoder,loss_criterion,ignore_padding=True,device=device)
            print(f"Epoch {epoch+1}\\t Train Loss : {train_loss}\\t Train Acc : {train_accuracy}% \\t Val Loss : {val_loss}\\t Val Acc : {val_accuracy}%")
            if wandb_logging:
                wandb.log({'epoch': epoch+1,'train loss': train_loss, 'train accuracy': train_accuracy, 'Validation loss': val_loss, 'Validation accuracy': val_accuracy})

    def compute_accuracy(self,dataloader,encoder,decoder,criterion,ignore_padding = True,device='cpu',save_results=False,filename=""):
        """
        Method to compute the accuracy using the model (encoder-decoder) using dataloader.
        Params:
            dataloader : yields (inputs, targets, input lengths, target lengths) batches; its dataset must
                         expose pad_token_id (when ignore_padding is set) and lp (when save_results is set).
            encoder, decoder : the networks to evaluate; they are switched to eval mode for the pass and
                               switched back to train mode if both were training before.
            criterion : loss called with the flattened decoder outputs and the flattened targets.
            ignore_padding : If true padded positions count neither as correct nor as wrong.
            device : device on which the initial LSTM state is created.
            save_results : If true one row per word is written to filename + ".csv".
            filename : path of the csv file without its extension.
        Returns:
            (mean loss per batch rounded to 4 digits, character level accuracy, word level accuracy),
            both accuracies in percent rounded to 2 digits.
        """
        tot_chars = 0
        tot_words = 0
        tot_correct_char_preds = 0
        tot_correct_word_preds = 0
        loss = 0
        rows = []

        with torch.no_grad():
            ## remember the mode so that the model can be put back into train mode afterwards
            was_training = encoder.training and decoder.training
            encoder.eval()
            decoder.eval()

            for data in dataloader:
                input_tensor, target_tensor,_,_ = data
                batch_size = input_tensor.shape[0]

                encoder_hidden, encoder_cell = self._initial_encoder_state(encoder, batch_size, device)
                encoder_hidden_contexts, encoder_last_hidden, encoder_cell = encoder(input_tensor,encoder_hidden,encoder_cell)
                ## even though we are passing target tensor, the teacher forcing ratio is 0, so no teacher forcing
                decoder_outputs, _, _ = decoder(encoder_hidden_contexts, encoder_last_hidden,encoder_cell, target_tensor = target_tensor,teacher_forcing_ratio = 0)
                loss += criterion(decoder_outputs.view(-1, decoder_outputs.size(-1)), target_tensor.view(-1)).item()

                ## For a batch, for each position find the most probable output character.
                multi_step_preds = torch.argmax(decoder_outputs,dim=2)
                multi_step_pred_correctness = (multi_step_preds ==  target_tensor)

                ## True on every padded position. Comparing with the pad id directly also handles words
                ## without any padding; locating the padding with argmax returned 0 for those words and
                ## thereby masked them completely, so they were always counted as correct.
                if ignore_padding:
                    mask = (target_tensor == dataloader.dataset.pad_token_id)
                else:
                    mask = torch.zeros_like(target_tensor, dtype=torch.bool)
                ##complement of the mask, marking the characters that count
                complement_mask = torch.logical_not(mask)

                ##doing a logical OR with the mask makes sure that the padding tokens do not affect the correctness of the word
                word_preds_correctness = torch.all(torch.logical_or(multi_step_pred_correctness,mask),dim=1).int()
                tot_correct_word_preds += word_preds_correctness.sum().item()
                tot_words += multi_step_preds.shape[0]

                tot_correct_char_preds += (torch.logical_and(multi_step_pred_correctness,complement_mask).int().sum()).item()
                tot_chars += complement_mask.int().sum().item()

                if save_results:
                    lp = dataloader.dataset.lp
                    for i in range(multi_step_preds.shape[0]):
                        rows.append([lp.decode_word(input_tensor[i].cpu().numpy(),lang_id=0),lp.decode_word(target_tensor[i].cpu().numpy(),lang_id=1),lp.decode_word(multi_step_preds[i].cpu().numpy(),lang_id=1),word_preds_correctness[i].cpu().item()])

            char_lvl_accuracy = round(tot_correct_char_preds*100/tot_chars,2)
            word_lvl_accuracy = round(tot_correct_word_preds*100/tot_words,2)
            ## every summand is the loss of one batch, so average over batches exactly like train_epoch
            ## does; dividing by the number of samples made the two losses incomparable.
            loss /= len(dataloader)

            if save_results:
                df = pd.DataFrame(data=rows, columns=["Source Word","Target Word","Predicted Word","Is Prediction Correct"])
                df.to_csv(filename+".csv",index=False)

            if was_training:
                encoder.train()
                decoder.train()

            return round(loss,4),char_lvl_accuracy,word_lvl_accuracy
''')

# Write visualize_attention.py
with open('visualize_attention.py', 'w') as f:
    f.write('''
import torch
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from Core_Utils import *
from Encoder_Decoder_Architecture import *
from Machine_Translator import *
from functools import partial
from torch.utils.data import DataLoader

def visualize_attention(model, input_word, target_word, input_lang, target_lang, pad_token_id, device, ax=None):
    """
    Visualizes attention weights for a given input and target word pair.

    Args:
        model: object exposing .encoder and .decoder; both are switched to eval mode and left there
        input_word: source word to transliterate
        target_word: reference transliteration, only shown in the title
        input_lang: language processor used to encode the input word
        target_lang: language processor used to decode the prediction
        pad_token_id: id of the padding token, excluded from the column labels
        device: device the input tensor and the initial LSTM state are created on
        ax: matplotlib axes to draw into; a new figure is created when omitted

    Returns:
        The axes holding the heatmap (rows: predicted characters, columns: input characters)
    """
    # Create figure if not provided
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 8))

    # Encode input word
    input_seq = input_lang.encode_word(input_word, input_lang.source_lang, padding=False, append_eos=True)
    input_tensor = torch.tensor(input_seq).unsqueeze(0).to(device)

    # Set model to evaluation mode
    model.encoder.eval()
    model.decoder.eval()

    with torch.no_grad():
        batch_size = 1

        if model.encoder.rnn_type == "LSTM":
            encoder_hidden = torch.zeros(model.encoder.num_layers*model.encoder.D, batch_size,
                                         model.encoder.hidden_size, device=device)
            encoder_cell = torch.zeros(model.encoder.num_layers*model.encoder.D, batch_size,
                                        model.encoder.hidden_size, device=device)
        else:
            encoder_hidden = None
            encoder_cell = None

        # Run encoder
        encoder_outputs, encoder_hidden, encoder_cell = model.encoder(input_tensor, encoder_hidden, encoder_cell)

        # Run decoder with teacher forcing disabled
        decoder_outputs, _, attention_weights = model.decoder(
            encoder_outputs, encoder_hidden, encoder_cell, target_tensor=None,
            eval_mode=True, teacher_forcing_ratio=0
        )

        # Get predicted output
        _, predicted_indices = torch.max(decoder_outputs, dim=2)
        predicted_word = target_lang.decode_word(predicted_indices[0].cpu().numpy(), target_lang.target_lang)
        predicted_word = ''.join(predicted_word)

        # Stack the per-step attention weights into one (output steps, input steps) matrix
        attention = torch.cat(list(attention_weights), 0).squeeze(1).cpu().numpy()

    # Get input characters (remove <end> token)
    input_chars = [input_lang.source_id2char[i] for i in input_seq if i != pad_token_id and i != input_lang.end_token_id]

    # Get output characters from prediction
    output_chars = list(predicted_word)

    # Trim attention matrix to actual sequence length
    attention = attention[:len(output_chars), :len(input_chars)]

    # Create heatmap
    ax.matshow(attention, cmap='viridis')

    # Pin one tick to every cell before labelling it. Assigning labels without fixed tick
    # locations depends on where the locator happens to put its first tick.
    ax.set_xticks(np.arange(len(input_chars)))
    ax.set_xticklabels(input_chars, rotation=90)
    ax.set_yticks(np.arange(len(output_chars)))
    ax.set_yticklabels(output_chars)

    # Add title with input, prediction and target
    ax.set_title(f'Input: {input_word}\\nPrediction: {predicted_word}\\nTarget: {target_word}')

    return ax

def visualize_examples(model, test_loader, num_examples=9):
    """
    Visualize attention weights for multiple examples

    Args:
        model: object exposing .encoder, .decoder and .device
        test_loader: loader whose first batch supplies the examples; its dataset must expose lp
        num_examples: upper bound on the number of heatmaps (capped by the size of the batch)

    Returns:
        The figure, which is also saved as attention_heatmaps.png and shown
    """
    # Get a batch of examples
    batch = next(iter(test_loader))
    input_tensors, target_tensors = batch[0], batch[1]

    # The first batch may hold fewer words than requested
    num_drawn = min(num_examples, len(input_tensors))

    # Create a grid of figures (at least one row, always returned as a 2-D array)
    rows = max(1, int(np.ceil(num_examples / 3)))
    fig, axes = plt.subplots(rows, 3, figsize=(18, 6*rows), squeeze=False)
    axes = axes.flatten()

    lp = test_loader.dataset.lp
    pad_token_id = lp.pad_token_id

    for i in range(num_drawn):
        # Get input and target words
        input_seq = input_tensors[i].cpu().numpy()
        target_seq = target_tensors[i].cpu().numpy()

        input_word = ''.join(lp.decode_word(input_seq, lp.source_lang))
        target_word = ''.join(lp.decode_word(target_seq, lp.target_lang))

        # Visualize attention for this example
        visualize_attention(model, input_word, target_word, lp, lp, pad_token_id, model.device, ax=axes[i])

    # Hide every subplot that received no heatmap, including those left over by a short batch
    for unused_ax in axes[num_drawn:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.savefig('attention_heatmaps.png')
    plt.show()
    return fig
''')

# Write trainer.py
with open('trainer.py', 'w') as f:
    f.write('''
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import random
from functools import partial
import argparse
import wandb
from tqdm import tqdm

from Core_Utils import *
from Encoder_Decoder_Architecture import *
from Machine_Translator import *

# Seed every random number generator this script touches with the same fixed
# value so that repeated runs start from the same random state.
seed = 23
random.seed(seed)  # Python's built-in RNG
np.random.seed(seed)  # NumPy's global RNG
torch.manual_seed(seed)  # PyTorch default generator
torch.cuda.manual_seed(seed)  # PyTorch generator for the current CUDA device
# Ask cuDNN to pick deterministic algorithms
torch.backends.cudnn.deterministic = True

def _build_eval_loader(train_lp, lang_dir, target_lang, mode, batch_size, collate, device, meta_tokens=True):
    """
    Build the DataLoader for an evaluation split ("test" or "valid").

    The split's LanguageProcessor has its four vocabulary maps replaced by the
    ones from ``train_lp`` before the dataset is created.
    """
    lp = LanguageProcessor(language_directory=lang_dir, target_lang_name=target_lang, mode=mode, meta_tokens=meta_tokens)
    lp.source_char2id = train_lp.source_char2id
    lp.source_id2char = train_lp.source_id2char
    lp.target_char2id = train_lp.target_char2id
    lp.target_id2char = train_lp.target_id2char

    dataset = WordDataset(lp, device=device)
    return DataLoader(dataset, batch_size=batch_size, collate_fn=collate, shuffle=True)


def train_model(config, model_type="vanilla", wandb_log=False):
    """
    Main function to train a seq2seq model.

    Builds the train/test/valid loaders for the "tel" data under
    "data_folder/", constructs a MachineTranslator, trains it, then evaluates
    it on the test split and prints the resulting loss and accuracies.

    Args:
        config: Dictionary with model hyperparameters. Keys read here:
            'batch_size', 'hidden_size', 'epochs', 'optimiser',
            'weight_decay', 'lr', 'num_layers', 'rnn_type', 'bidirectional',
            'teacher_forcing_ratio', 'dropout' and, optionally, 'device'
            (missing or None selects CUDA when available, otherwise CPU).
        model_type: "attention" enables attention; any other value builds the
            model without it.
        wandb_log: Whether to log metrics to wandb

    Returns:
        Trained model
    """
    batch_size = config['batch_size']
    target_lang = "tel"
    base_dir = "data_folder/"

    # Setup device
    requested_device = config.get('device')
    if requested_device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(requested_device)

    print(f"Using device: {device}")

    # Setup data
    use_meta_tokens = True
    lang_dir = base_dir + target_lang + "/"

    # Create train loader; the training split defines the vocabularies and the pad id
    train_lp = LanguageProcessor(language_directory=lang_dir, target_lang_name=target_lang, mode="train", meta_tokens=use_meta_tokens)
    pad_token_id = train_lp.source_char2id['<pad>']

    collate_fn_ptr = partial(collate_fn, pad_token_id=pad_token_id, device=device)

    train_dataset = WordDataset(train_lp, device=device)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn_ptr, shuffle=True)

    # Create test and validation loaders (in that order) on the training vocabularies
    test_loader = _build_eval_loader(
        train_lp, lang_dir, target_lang, "test", batch_size, collate_fn_ptr, device, meta_tokens=use_meta_tokens
    )
    valid_loader = _build_eval_loader(
        train_lp, lang_dir, target_lang, "valid", batch_size, collate_fn_ptr, device, meta_tokens=use_meta_tokens
    )

    # Get vocabulary sizes
    source_vocab_size = len(train_lp.source_char2id)
    target_vocab_size = len(train_lp.target_char2id)

    # Extract hyperparameters
    hidden_size = config['hidden_size']
    # NOTE(review): config['embedding_size'] is not read; the embedding size is
    # tied to hidden_size here. Confirm whether that is intentional.
    embedding_size = hidden_size
    epochs = config['epochs']
    optimiser = config['optimiser']
    weight_decay = config['weight_decay']
    lr = config['lr']
    num_layers = config['num_layers']
    rnn_type = config['rnn_type'].upper()
    bidirectional = config['bidirectional']
    teacher_forcing_ratio = config['teacher_forcing_ratio']
    dropout = config['dropout']

    # Set use_attention based on model type
    use_attention = model_type == "attention"

    # Create model
    model = MachineTranslator(
        source_vocab_size=source_vocab_size,
        target_vocab_size=target_vocab_size,
        hidden_size=hidden_size,
        embedding_size=embedding_size,
        rnn_type=rnn_type,
        batch_size=batch_size,
        pad_token_id=pad_token_id,
        dropout=dropout,
        num_layers=num_layers,
        bidirectional=bidirectional,
        use_attention=use_attention,
        device=device
    )

    # Train model
    model.train(
        train_loader=train_loader,
        valid_loader=valid_loader,
        encoder=model.encoder,
        decoder=model.decoder,
        epochs=epochs,
        padding_idx=pad_token_id,
        optimiser=optimiser,
        weight_decay=weight_decay,
        lr=lr,
        teacher_forcing_ratio=teacher_forcing_ratio,
        device=device,
        wandb_logging=wandb_log
    )

    # Evaluate on test set; padding positions are excluded from the loss
    loss_criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id).to(device)
    test_loss, char_acc, test_accuracy = model.compute_accuracy(
        test_loader,
        model.encoder,
        model.decoder,
        loss_criterion,
        ignore_padding=True,
        device=device,
        save_results=True,
        filename=f"predictions_{model_type}"
    )
    print(f"Testing Loss: {test_loss}\\tChar Accuracy: {char_acc}%\\tWord Accuracy: {test_accuracy}%")

    return model

def parse_bool_flag(value):
    """
    Convert a command-line token into a bool for argparse.

    ``type=bool`` cannot be used for this: bool("False") is True because any
    non-empty string is truthy, so "-bid False" would still enable the flag.

    Args:
        value: The raw token (or an already-parsed bool).

    Returns:
        True for true/t/yes/y/1, False for false/f/no/n/0 (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Command-line entry point: parse hyperparameters, optionally start a
    # wandb run, then train and evaluate the selected model.
    parser = argparse.ArgumentParser()

    # Model type
    parser.add_argument("--model", type=str, default="vanilla", choices=["vanilla", "attention"],
                       help="Model type: vanilla seq2seq or seq2seq with attention")

    # Hyperparameters
    parser.add_argument("-b", "--batch_size", type=int, default=64,
                       help="Batch size used to train neural network.")
    # parse_bool_flag (not bool) so that "-bid False" really yields False
    parser.add_argument("-bid", "--bidirectional", type=parse_bool_flag, default=True, choices=[True, False],
                       help="If True, input would be seen in both directions.")
    parser.add_argument("-dpt", "--dropout", type=float, default=0.2,
                       help="The dropout probability.")
    parser.add_argument("-es", "--embedding_size", type=int, default=256,
                       help="The input embedding dimension.")
    parser.add_argument("-e", "--epochs", type=int, default=15,
                       help="Number of epochs to train.")
    parser.add_argument("-hs", "--hidden_size", type=int, default=256,
                       help="The dimension of the hidden state.")
    parser.add_argument("-lr", "--learning_rate", type=float, default=3e-4,
                       help="Learning rate used to optimize model parameters.")
    parser.add_argument("-nl", "--num_layers", type=int, default=2,
                       help="Number of Recurrence Layers.")
    parser.add_argument("-o", "--optimizer", type=str, default="nadam", choices=["rmsprop", "adam", "nadam"],
                       help="Optimizer used to minimize the loss.")
    parser.add_argument("-rt", "--rnn_type", type=str, default="lstm", choices=["lstm", "gru", "rnn"],
                       help="The type of recurrent cell to be used.")
    parser.add_argument("-tf", "--teacher_forcing", type=float, default=0.4,
                       help="The Teacher Forcing Ratio.")
    parser.add_argument("-w_d", "--weight_decay", type=float, default=1e-5,
                       help="Weight decay used by optimizers.")
    parser.add_argument("-d", "--device", type=str, default=None,
                       help="The device on which the training happens.")
    parser.add_argument("--wandb", action="store_true",
                       help="Log metrics to Weights & Biases")

    args = parser.parse_args()

    # Map CLI argument names onto the config keys train_model reads
    config = {
        'batch_size': args.batch_size,
        'bidirectional': args.bidirectional,
        'dropout': args.dropout,
        'embedding_size': args.embedding_size,
        'epochs': args.epochs,
        'hidden_size': args.hidden_size,
        'lr': args.learning_rate,
        'num_layers': args.num_layers,
        'optimiser': args.optimizer,
        'rnn_type': args.rnn_type,
        'teacher_forcing_ratio': args.teacher_forcing,
        'weight_decay': args.weight_decay,
        'device': args.device
    }

    if args.wandb:
        # Initialize wandb
        wandb.init(
            project=f"cs6910_assignment3_{args.model}",
            config=config
        )

    model = train_model(config, model_type=args.model, wandb_log=args.wandb)
''')

# Create data directory (trainer.py builds its data path as "data_folder/" + "tel" + "/")
!mkdir -p data_folder/tel

# Copy data files (adjust paths to match your Google Drive data location)
# IPython expands {data_base_path} inside the shell commands below.
# NOTE(review): the train/valid/test .txt names presumably match what
# LanguageProcessor opens for each mode - confirm against Core_Utils.py.
data_base_path = "/content/drive/MyDrive/dakshina_dataset_v1.0/te/lexicons/"
!cp {data_base_path}te.translit.sampled.train.tsv data_folder/tel/train.txt
!cp {data_base_path}te.translit.sampled.dev.tsv data_folder/tel/valid.txt
!cp {data_base_path}te.translit.sampled.test.tsv data_folder/tel/test.txt

# List the project directory so the generated files can be checked
print("Setup complete! Files created:")
!ls -la

# Train a model: the vanilla run below is active; comment it out and uncomment the attention line to train that variant instead.
!python trainer.py --model vanilla -b 64 -bid True -dpt 0.2 -es 128 -e 15 -hs 256 -nl 2 -o nadam -rt lstm -tf 0.4 -w_d 1e-5
# !python trainer.py --model attention -b 64 -bid True -dpt 0.0 -es 256 -e 15 -hs 512 -nl 2 -o rmsprop -rt gru -tf 0.4 -w_d 1e-5

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 44% 405/915 [00:28<00:32, 15.81it/s][A
 44% 407/915 [00:28<00:32, 15.69it/s][A
 45% 409/915 [00:28<00:31, 15.89it/s][A
 45% 411/915 [00:29<00:32, 15.45it/s][A
 45% 413/915 [00:29<00:32, 15.61it/s][A
 45% 415/915 [00:29<00:31, 15.69it/s][A
 46% 417/915 [00:29<00:32, 15.47it/s][A
 46% 419/915 [00:29<00:32, 15.08it/s][A
 46% 421/915 [00:29<00:32, 15.24it/s][A
 46% 423/915 [00:29<00:32, 15.37it/s][A
 46% 425/915 [00:29<00:30, 15.90it/s][A
 47% 427/915 [00:30<00:33, 14.49it/s][A
 47% 429/915 [00:30<00:32, 14.86it/s][A
 47% 431/915 [00:30<00:31, 15.52it/s][A
 47% 433/915 [00:30<00:31, 15.27it/s][A
 48% 435/915 [00:30<00:30, 15.81it/s][A
 48% 437/915 [00:30<00:30, 15.76it/s][A
 48% 439/915 [00:30<00:29, 16.15it/s][A
 48% 441/915 [00:31<00:29, 15.86it/s][A
 48% 443/915 [00:31<00:30, 15.24it/s][A
 49% 445/915 [00:31<00:30, 15.19it/s][A
 49% 447/915 [00:31<00:30, 15.42it/s][A
 49% 449/915 [00:31<00:30, 15.52i