In [25]:
# Importing the necessary libraries needed
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms.functional as Fn
import torch.nn.functional as F
import lightning as L
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader, Subset, Dataset
from lightning.pytorch.loggers import WandbLogger
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm

# Limit the line width used when printing torch tensors and NumPy arrays to 50 characters.
torch.set_printoptions(linewidth=50)
np.set_printoptions(linewidth=50)

In [26]:
# Data preparation
# Loading the dataset
# Directory holding the Tamil transliteration lexicon TSV files. It can be overridden
# through the TA_LEXICON_DIR environment variable so the notebook is not tied to one machine.
LEXICON_DIR = os.environ.get("TA_LEXICON_DIR", "/home/joel/DA6401_DL/DA6401_A03/ta_lexicons")
LEXICON_COLUMNS = ["native", "latin", "count"]


def load_lexicon(split):
    """Read one split ("train", "test" or "dev") of the transliteration lexicon.

    The files are header-less TSVs with the columns native / latin / count.
    Default NA parsing is disabled because otherwise pandas converts legitimate
    words such as "null", "NA" or "nan" into missing values.
    """
    path = os.path.join(LEXICON_DIR, f"ta.translit.sampled.{split}.tsv")
    return pd.read_csv(
        path,
        sep='\t',
        header=None,
        names=LEXICON_COLUMNS,
        dtype={"native": str, "latin": str},
        keep_default_na=False,
    )


df_train = load_lexicon("train")
df_test = load_lexicon("test")
df_val = load_lexicon("dev")


# Show first few rows
print(df_train.head())

      native    latin  count
0     ஃபியட்     fiat      2
1     ஃபியட்   phiyat      1
2     ஃபியட்    piyat      1
3  ஃபிரான்ஸ்  firaans      1
4  ஃபிரான்ஸ்   france      2


In [27]:
# Building the dataset for the Seq2Seq model
class Dataset_Tamil(Dataset):
    """Character-level one-hot dataset for Latin -> native transliteration.

    Each item is a triple of float32 tensors:
      * encoder_input  (max_enc_seq_len, total_encoder_tokens): one-hot Latin characters
      * decoder_input  (max_dec_seq_len, total_decoder_tokens): "\t" + native word + "\n"
      * decoder_output (max_dec_seq_len, total_decoder_tokens): decoder_input shifted left by one

    The space character is used as padding in both vocabularies. Characters
    that are missing from a vocabulary are encoded as an all-zero row.

    Args:
        dataframe: frame with at least the columns "latin" and "native".
        build_vocab: when True the vocabularies and maximum lengths are fitted
            on ``dataframe``; when False they must be passed in (e.g. from the
            training dataset).
        input_token_index / output_token_index: char -> index mappings, used
            only when ``build_vocab`` is False.
        max_enc_seq_len / max_dec_seq_len: fixed sequence lengths, used only
            when ``build_vocab`` is False.

    Raises:
        ValueError: if ``build_vocab`` is False and a token index is missing.
    """

    PAD_CHAR = " "

    def __init__(self, dataframe, build_vocab=True, input_token_index=None, output_token_index=None,
                 max_enc_seq_len=0, max_dec_seq_len=0):

        # Input variables
        self.input_df = dataframe
        # Reading the columns directly avoids the per-row overhead of iterrows()
        self.input_words = [str(word) for word in dataframe["latin"]]
        # "\t" marks the start and "\n" the end of every target sequence
        self.output_words = ["\t" + str(word) + "\n" for word in dataframe["native"]]

        if build_vocab:
            self.build_vocab()
        else:
            if input_token_index is None or output_token_index is None:
                raise ValueError(
                    "input_token_index and output_token_index are required when build_vocab=False"
                )
            # Token index for sequence building
            self.input_token_index = input_token_index
            self.output_token_index = output_token_index
            # Characters of the language, in index order
            self.input_characters = sorted(input_token_index, key=input_token_index.get)
            self.output_characters = sorted(output_token_index, key=output_token_index.get)
            # Sequence lengths inherited from the dataset the vocabulary was fitted on
            self.max_enc_seq_len = max_enc_seq_len
            self.max_dec_seq_len = max_dec_seq_len

        # Finding the encoder/decoder tokens
        self.total_encoder_tokens = len(self.input_token_index)
        self.total_decoder_tokens = len(self.output_token_index)

    def build_vocab(self):
        """Fit the character vocabularies and the maximum sequence lengths on this dataset."""
        self.input_characters = sorted(set(" ".join(self.input_words)))
        self.output_characters = sorted(set(" ".join(self.output_words)))
        # Adding the padding character if not present
        if self.PAD_CHAR not in self.input_characters:
            self.input_characters.append(self.PAD_CHAR)
        if self.PAD_CHAR not in self.output_characters:
            self.output_characters.append(self.PAD_CHAR)

        self.input_token_index = {char: i for i, char in enumerate(self.input_characters)}
        self.output_token_index = {char: i for i, char in enumerate(self.output_characters)}

        # default=0 keeps an empty dataframe from raising inside max()
        self.max_enc_seq_len = max((len(txt) for txt in self.input_words), default=0)
        self.max_dec_seq_len = max((len(txt) for txt in self.output_words), default=0)

    def __len__(self):
        return len(self.input_words)

    def __getitem__(self, index):
        """Return (encoder_input, decoder_input, decoder_output) one-hot tensors for one pair.

        Words longer than the fitted maximum lengths (possible for validation/test
        data encoded with the training lengths) are truncated instead of raising
        an IndexError.
        """
        input_word = self.input_words[index]
        output_word = self.output_words[index]
        enc_len, dec_len = self.max_enc_seq_len, self.max_dec_seq_len
        enc_pad = self.input_token_index[self.PAD_CHAR]
        dec_pad = self.output_token_index[self.PAD_CHAR]

        encoder_input = np.zeros((enc_len, self.total_encoder_tokens), dtype=np.float32)
        decoder_input = np.zeros((dec_len, self.total_decoder_tokens), dtype=np.float32)
        decoder_output = np.zeros((dec_len, self.total_decoder_tokens), dtype=np.float32)

        for t, char in enumerate(input_word[:enc_len]):
            token = self.input_token_index.get(char)
            if token is not None:
                encoder_input[t, token] = 1.0
        # Pad every position after the word (empty slice when the word fills the buffer)
        encoder_input[len(input_word):, enc_pad] = 1.0

        # One character beyond dec_len is still visited so that the last target
        # position of a truncated word receives its true next character.
        for t, char in enumerate(output_word[:dec_len + 1]):
            token = self.output_token_index.get(char)
            if token is None:
                continue
            if t < dec_len:
                decoder_input[t, token] = 1.0
            if t > 0:
                # The target is the decoder input shifted one step to the left
                decoder_output[t - 1, token] = 1.0
        decoder_input[len(output_word):, dec_pad] = 1.0
        # The target sequence is one step shorter, so its padding starts one step earlier
        decoder_output[len(output_word) - 1:, dec_pad] = 1.0

        return (
            torch.from_numpy(encoder_input),
            torch.from_numpy(decoder_input),
            torch.from_numpy(decoder_output)
        )

In [28]:
# Loading the datasets and dataloaders
BATCH_SIZE = 64

train_dataset = Dataset_Tamil(df_train)

# Validation and test data are encoded with the vocabulary and sequence lengths
# fitted on the training set so that all splits share the same token indices.
shared_vocab = dict(
    build_vocab=False,
    input_token_index=train_dataset.input_token_index,
    output_token_index=train_dataset.output_token_index,
    max_enc_seq_len=train_dataset.max_enc_seq_len,
    max_dec_seq_len=train_dataset.max_dec_seq_len,
)
val_dataset = Dataset_Tamil(df_val, **shared_vocab)
test_dataset = Dataset_Tamil(df_test, **shared_vocab)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation loaders are not shuffled: order does not affect the metrics and a
# fixed order keeps per-example outputs comparable between runs.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Model classes definitions #
class Encoder(nn.Module):
    """Recurrent encoder wrapping nn.LSTM, nn.GRU or nn.RNN (batch_first=True).

    Args:
        input_size: feature size of every input time step.
        hidden_size: hidden size per direction.
        dropout: dropout between stacked recurrent layers.
        cell_type: "LSTM", "GRU" or "RNN" (case-insensitive); any other value
            falls back to nn.RNN.
        num_layers: number of stacked recurrent layers.
        bi_directional: whether to run the recurrence in both directions.
    """

    def __init__(self, input_size, hidden_size, dropout=0.3, cell_type="RNN", num_layers=1, bi_directional=False):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type.upper()
        self.dropout = dropout
        self.num_layers = num_layers

        # PyTorch applies dropout only between stacked layers; with a single layer
        # a non-zero value does nothing except emit a UserWarning, so it is zeroed.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        rnn_classes = {"LSTM": nn.LSTM, "GRU": nn.GRU}
        rnn_cls = rnn_classes.get(self.cell_type, nn.RNN)
        self.enc = rnn_cls(input_size, hidden_size, batch_first=True, dropout=inter_layer_dropout,
                           num_layers=num_layers, bidirectional=bi_directional)

    def forward(self, x):
        """Encode ``x`` of shape (batch, seq_len, input_size).

        Returns:
            (outputs, state): the two values returned by the wrapped recurrent
            module; ``state`` is an ``(h_n, c_n)`` tuple for LSTM and ``h_n``
            otherwise.
        """
        return self.enc(x)
        
class Attention_Mechanism(nn.Module):
    """Additive attention: score_j = V^T tanh(W c_j + U s_{t-1}).

    Args:
        hidden_dim: feature size shared by the decoder state and the encoder outputs.
        device: device on which the attention parameters are created.
    """

    def __init__(self, hidden_dim, device="cpu"):
        super().__init__()

        self.hidden_dim = hidden_dim
        # Normalise over the source positions (dim 1 of the (B, L) score matrix)
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        # Creating the matrices for attention calculation
        self.V_att = nn.Parameter(torch.randn(size=(self.hidden_dim, 1), device=device)*0.1)
        self.U_att = nn.Parameter(torch.randn(size=(self.hidden_dim, self.hidden_dim), device=device)*0.1)
        self.W_att = nn.Parameter(torch.randn(size=(self.hidden_dim, self.hidden_dim), device=device)*0.1)

    def forward(self, st_1, c_j, mask=None):
        """Compute attention weights over the encoder outputs.

        Args:
            st_1: previous decoder output, shape (B, 1, d).
            c_j: encoder outputs, shape (B, L, d).
            mask: optional boolean tensor of shape (B, L); True marks source
                positions that must receive zero attention (e.g. padding).

        Returns:
            Tensor of shape (B, L) whose rows sum to 1.
        """
        # (B, L, d) + (B, 1, d) broadcasts the decoder state over all L source positions
        inside = self.tanh(torch.matmul(c_j, self.W_att) + torch.matmul(st_1, self.U_att))
        scores = torch.matmul(inside, self.V_att).squeeze(-1)  # (B, L, 1) -> (B, L)

        if mask is not None:
            mask = mask.to(device=scores.device, dtype=torch.bool)
            # A fully masked row would turn into NaN after the softmax, so such
            # rows are left unmasked.
            mask = mask & ~mask.all(dim=1, keepdim=True)
            scores = scores.masked_fill(mask, float("-inf"))

        return self.softmax(scores)
    
class Decoder(nn.Module):
    """Unidirectional recurrent decoder wrapping nn.LSTM, nn.GRU or nn.RNN (batch_first=True).

    Args:
        input_size: feature size of every decoder input step.
        hidden_size: hidden size of the recurrent layers.
        dropout: dropout between stacked recurrent layers.
        cell_type: "LSTM", "GRU" or "RNN" (case-insensitive); any other value
            falls back to nn.RNN.
        num_layers: number of stacked recurrent layers.
    """

    def __init__(self, input_size, hidden_size, dropout=0.3, cell_type='RNN', num_layers=1):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.cell_type = cell_type.upper()
        self.dropout = dropout
        self.num_layers = num_layers

        # PyTorch applies dropout only between stacked layers; with a single layer
        # a non-zero value does nothing except emit a UserWarning, so it is zeroed.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        rnn_classes = {"LSTM": nn.LSTM, "GRU": nn.GRU}
        rnn_cls = rnn_classes.get(self.cell_type, nn.RNN)
        self.dec = rnn_cls(input_size, hidden_size, batch_first=True, dropout=inter_layer_dropout,
                           num_layers=num_layers)

    def forward(self, x, states=None):
        """Run the decoder on ``x`` of shape (batch, steps, input_size).

        Args:
            x: decoder inputs.
            states: previous recurrent state (an ``(h, c)`` tuple for LSTM, a
                tensor otherwise). ``None`` lets the wrapped module start from
                its default zero state.

        Returns:
            (outputs, state): the two values returned by the wrapped recurrent module.
        """
        return self.dec(x, states)
        
class Attention_Seq2Seq(nn.Module):
    """Encoder-decoder transliteration model with additive attention.

    At every step the decoder receives the one-hot token concatenated with a
    context vector. The first context is the encoder output at the last time
    step; later contexts are attention-weighted sums of the encoder outputs,
    queried with the previous decoder output.

    Args:
        input_token_index: dict mapping source characters to indices; the
            space character, if present, is treated as padding and masked out
            of the attention.
        output_token_index: dict mapping target characters to indices; "\t" is
            used as the start token (index 0 when it is absent).
        max_dec_seq_len: number of decoding steps produced per example.
        embedding_dim: size of the linear embedding of the one-hot inputs.
        hidden_size_enc: encoder hidden size per direction.
        bi_directional: whether the encoder is bidirectional; the decoder
            hidden size is doubled accordingly.
        nature: stored on the instance; not used by this class.
        enc_cell / dec_cell: recurrent cell types for encoder and decoder.
        num_layers: number of recurrent layers in encoder and decoder.
        dropout: dropout passed to encoder and decoder.
        device: device the whole model is moved to on construction.
    """

    def __init__(self, input_token_index, output_token_index, max_dec_seq_len, embedding_dim,hidden_size_enc, bi_directional=False,
            nature="train", enc_cell="LSTM", dec_cell="LSTM", num_layers=1,dropout=0.2, device="cpu"):
        super().__init__()

        self.input_index_token = input_token_index
        self.output_index_token = output_token_index
        self.max_dec_seq_len = max_dec_seq_len
        self.nature = nature
        self.enc_cell_type = enc_cell.upper()
        self.dec_cell_type = dec_cell.upper()
        self.num_layers= num_layers
        self.bi_directional = bi_directional
        self.hidden_size_enc = hidden_size_enc
        # A bidirectional encoder emits 2 * hidden_size_enc features per time step
        self.hidden_size_dec = (2 if bi_directional else 1) * hidden_size_enc
        # Index of the padding character in the source vocabulary (None: no masking)
        self.enc_pad_index = input_token_index.get(" ")
        # Index of the start-of-sequence character in the target vocabulary
        self.start_index = output_token_index.get("\t", 0)

        self.embedding = nn.Linear(in_features=len(self.input_index_token), out_features=embedding_dim)
        self.embedding_act = nn.Tanh()
        self.encoder = Encoder(input_size=embedding_dim, hidden_size=hidden_size_enc, dropout=dropout, cell_type=enc_cell, num_layers=num_layers, bi_directional=self.bi_directional)
        self.attention = Attention_Mechanism(hidden_dim=self.hidden_size_dec)
        # The decoder input is the one-hot token concatenated with the context vector
        self.decoder = Decoder(input_size=len(self.output_index_token)+self.hidden_size_dec, hidden_size=self.hidden_size_dec, dropout=dropout, cell_type=dec_cell, num_layers=num_layers)
        self.device = device
        self.loss_fn = nn.CrossEntropyLoss()
        self.fc = nn.Linear(in_features=self.hidden_size_dec, out_features=len(output_token_index))
        # Move every sub-module (not only encoder/decoder) to the requested device
        self.to(device)

    def _encoder_pad_mask(self, enc_in):
        """Return a (B, L) bool mask that is True at padded source positions, or None."""
        if self.enc_pad_index is None:
            return None
        return enc_in[:, :, self.enc_pad_index] > 0.5

    def _one_hot_step(self, indices):
        """Turn token indices of shape (B,) into a (B, 1, V) one-hot decoder input."""
        batch_size = indices.size(0)
        step = torch.zeros(batch_size, 1, len(self.output_index_token), device=indices.device)
        step[torch.arange(batch_size, device=indices.device), 0, indices] = 1.0
        return step

    def _decode(self, enc_in, dec_in=None):
        """Shared decoding loop.

        Args:
            enc_in: one-hot source batch, shape (B, L, V_in).
            dec_in: one-hot teacher-forcing inputs with max_dec_seq_len steps,
                shape (B, T, V_out); None selects greedy decoding.

        Returns:
            Logits of shape (B, max_dec_seq_len, V_out).
        """
        batch_size = enc_in.size(0)
        hidden_enc, _ = self.encoder(self.embedding_act(self.embedding(enc_in)))  # (B, L, H)
        pad_mask = self._encoder_pad_mask(enc_in)

        if dec_in is not None:
            token = dec_in[:, 0:1, :]
        else:
            start = torch.full((batch_size,), self.start_index, dtype=torch.long, device=enc_in.device)
            token = self._one_hot_step(start)

        # Initial context: encoder output at the last time step.
        # NOTE(review): for padded inputs this is a padding position - confirm this is intended.
        context = hidden_enc[:, -1:, :]
        # The decoder starts from its default zero state; encoder and decoder
        # state shapes differ when the encoder is bidirectional.
        states_dec = None

        logits = []
        for t in range(self.max_dec_seq_len):
            out_step, states_dec = self.decoder(torch.cat([token, context], dim=2), states_dec)  # (B, 1, H)
            logits_step = self.fc(out_step.squeeze(1))  # (B, V)
            logits.append(logits_step)

            if t + 1 == self.max_dec_seq_len:
                break  # no further input is needed after the last step

            if dec_in is not None:
                token = dec_in[:, t + 1:t + 2, :]  # teacher forcing
            else:
                token = self._one_hot_step(torch.argmax(logits_step, dim=1))  # greedy choice

            att_scores = self.attention(out_step, hidden_enc, pad_mask)  # (B, L)
            context = torch.bmm(att_scores.unsqueeze(1), hidden_enc)  # (B, 1, H)

        if not logits:
            return torch.zeros(batch_size, 0, len(self.output_index_token), device=enc_in.device)
        return torch.stack(logits, dim=1)

    def forward(self, batch):
        """Teacher-forced pass over a batch of (ENC_IN, DEC_IN, DEC_OUT).

        Returns:
            Logits of shape (B, max_dec_seq_len, V_out).
        """
        # Use the device the parameters currently live on, in case the model was moved
        device = self.fc.weight.device
        ENC_IN = batch[0].to(device)
        DEC_IN = batch[1].to(device)
        return self._decode(ENC_IN, DEC_IN)

    def predict_greedy(self, batch):
        """Greedy decoding: each step is fed the argmax of the previous step's logits.

        Only the encoder input (first element of ``batch``) is used.

        Returns:
            Logits of shape (B, max_dec_seq_len, V_out).
        """
        device = self.fc.weight.device
        ENC_IN = batch[0].to(device)
        return self._decode(ENC_IN)
    

