# Assignment 03 - AE21B105
This notebook is the base version for testing the code and for running the hyperparameter sweeps later on. Once everything is finalized and working, it will be transferred to a script with command-line arguments. Let's begin!

In [1]:
# Importing the necessary libraries needed
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
import lightning as L
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader, Subset, Dataset
from lightning.pytorch.loggers import WandbLogger
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm

torch.set_printoptions(linewidth=50)
np.set_printoptions(linewidth=50)

# Seed the RNGs used by weight initialisation, DataLoader shuffling and dropout
# so that "Restart & Run All" reproduces the logged metrics.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Data preparation
# Directory holding the Dakshina Tamil lexicon TSVs. Set the TA_LEXICON_DIR
# environment variable to point elsewhere instead of editing this path.
DATA_DIR = os.environ.get("TA_LEXICON_DIR", "/home/joel/DA6401_DL/DA6401_A03/ta_lexicons")
LEXICON_COLUMNS = ["native", "latin", "count"]


def load_lexicon(split):
    """Read ``ta.translit.sampled.<split>.tsv`` from DATA_DIR into a DataFrame.

    ``keep_default_na=False`` keeps words such as "null", "nan" or "NA" as
    literal strings. With pandas' default NA parsing they would become NaN and
    reach the dataset as the text "nan".
    """
    path = os.path.join(DATA_DIR, f"ta.translit.sampled.{split}.tsv")
    return pd.read_csv(path, sep="\t", header=None, names=LEXICON_COLUMNS, keep_default_na=False)


# Loading the dataset
df_train = load_lexicon("train")
df_test = load_lexicon("test")
df_val = load_lexicon("dev")

# Show first few rows
print(df_train.head())

      native    latin  count
0     ஃபியட்     fiat      2
1     ஃபியட்   phiyat      1
2     ஃபியட்    piyat      1
3  ஃபிரான்ஸ்  firaans      1
4  ஃபிரான்ஸ்   france      2


In [3]:
# Building the dataset for the Seq2Seq model
class Dataset_Tamil(Dataset):
    """One-hot character dataset for Latin -> Tamil transliteration.

    Each row of ``dataframe`` provides a ``"latin"`` (source) and a ``"native"``
    (target) column. Each target is wrapped as ``"\\t" + word + "\\n"``: the tab
    is the start token and the newline is the end token. The space character is
    used as padding.

    ``__getitem__`` returns three float32 tensors:
        encoder_input  (max_enc_seq_len, total_encoder_tokens)
        decoder_input  (max_dec_seq_len, total_decoder_tokens): target incl. start token
        decoder_output (max_dec_seq_len, total_decoder_tokens): decoder_input shifted
                       left by one step (the next-character targets)

    A character missing from the vocabulary produces an all-zero row. Words
    longer than the maximum lengths are truncated. This can happen for val/test,
    whose maxima come from the training set.

    Args:
        dataframe: source DataFrame.
        build_vocab: derive vocabularies and max lengths from ``dataframe``.
            When False, ``input_token_index`` and ``output_token_index`` (and
            normally ``max_enc_seq_len`` / ``max_dec_seq_len``) must be given,
            usually taken from the training dataset.

    Raises:
        ValueError: ``build_vocab`` is False and a token index is missing.
    """

    def __init__(self, dataframe, build_vocab=True, input_token_index=None, output_token_index=None,
                 max_enc_seq_len=0, max_dec_seq_len=0):

        # Input variables
        self.input_df = dataframe
        # Source words as-is; target words wrapped in start ("\t") / end ("\n") tokens.
        self.input_words = [str(word) for word in dataframe["latin"]]
        self.output_words = ["\t" + str(word) + "\n" for word in dataframe["native"]]
        # Characters of the language (filled by build_vocab)
        self.input_characters = set()
        self.output_characters = set()

        if build_vocab:
            self.build_vocab()
        else:
            if input_token_index is None or output_token_index is None:
                raise ValueError(
                    "input_token_index and output_token_index are required when build_vocab=False"
                )
            # Token index for sequence building
            self.input_token_index = input_token_index
            self.output_token_index = output_token_index
            # Heuristics lengths for the encoder decoder
            self.max_enc_seq_len = max_enc_seq_len
            self.max_dec_seq_len = max_dec_seq_len

        # Vocabulary sizes = one-hot widths
        self.total_encoder_tokens = len(self.input_token_index)
        self.total_decoder_tokens = len(self.output_token_index)

    def build_vocab(self):
        """Fit sorted character vocabularies and max sequence lengths on this data."""
        # Joining with " " guarantees the padding character is present when there
        # are at least two words; the explicit checks cover the single-word case.
        self.input_characters = sorted(set(" ".join(self.input_words)))
        self.output_characters = sorted(set(" ".join(self.output_words)))
        if " " not in self.input_characters:
            self.input_characters.append(" ")
        if " " not in self.output_characters:
            self.output_characters.append(" ")

        self.input_token_index = {char: i for i, char in enumerate(self.input_characters)}
        self.output_token_index = {char: i for i, char in enumerate(self.output_characters)}

        self.max_enc_seq_len = max(len(txt) for txt in self.input_words)
        self.max_dec_seq_len = max(len(txt) for txt in self.output_words)

    def __len__(self):
        return len(self.input_words)

    def __getitem__(self, index):
        input_word = self.input_words[index]
        output_word = self.output_words[index]

        encoder_input = np.zeros((self.max_enc_seq_len, self.total_encoder_tokens), dtype=np.float32)
        decoder_input = np.zeros((self.max_dec_seq_len, self.total_decoder_tokens), dtype=np.float32)
        decoder_output = np.zeros((self.max_dec_seq_len, self.total_decoder_tokens), dtype=np.float32)

        enc_pad = self.input_token_index[" "]
        dec_pad = self.output_token_index[" "]

        # Encoder: one-hot source characters (truncated to the max length), then padding.
        for t, char in enumerate(input_word[:self.max_enc_seq_len]):
            if char in self.input_token_index:
                encoder_input[t, self.input_token_index[char]] = 1.0
        encoder_input[len(input_word):, enc_pad] = 1.0

        # Decoder input (teacher forcing): the wrapped target, starting with "\t".
        for t, char in enumerate(output_word[:self.max_dec_seq_len]):
            if char in self.output_token_index:
                decoder_input[t, self.output_token_index[char]] = 1.0
        decoder_input[len(output_word):, dec_pad] = 1.0

        # Decoder target: position t holds output_word[t + 1]. Padding starts one
        # step earlier than in decoder_input (len(output_word) >= 2, so the index
        # below is never negative).
        for t, char in enumerate(output_word[1:self.max_dec_seq_len + 1]):
            if char in self.output_token_index:
                decoder_output[t, self.output_token_index[char]] = 1.0
        decoder_output[len(output_word) - 1:, dec_pad] = 1.0

        return (
            torch.from_numpy(encoder_input),
            torch.from_numpy(decoder_input),
            torch.from_numpy(decoder_output)
        )

In [4]:
# Loading the datasets and dataloaders
BATCH_SIZE = 64

train_dataset = Dataset_Tamil(df_train)
# Val/test reuse the training vocabulary and max lengths so that all splits
# produce one-hot tensors of identical shape.
train_vocab = dict(
    input_token_index=train_dataset.input_token_index,
    output_token_index=train_dataset.output_token_index,
    max_enc_seq_len=train_dataset.max_enc_seq_len,
    max_dec_seq_len=train_dataset.max_dec_seq_len,
)
val_dataset = Dataset_Tamil(df_val, build_vocab=False, **train_vocab)
test_dataset = Dataset_Tamil(df_test, build_vocab=False, **train_vocab)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Shuffling never changes evaluation metrics; a fixed order keeps inspected batches reproducible.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
class Encoder(nn.Module):
    """Batch-first recurrent encoder (LSTM, GRU or vanilla RNN).

    Args:
        input_size: feature size of each time step.
        hidden_size: recurrent state size.
        dropout: inter-layer dropout. PyTorch applies it only between stacked
            layers, so it is passed as 0 when ``num_layers == 1``.
        cell_type: "LSTM", "GRU" or anything else for "RNN" (case-insensitive).
        num_layers: number of stacked recurrent layers.
    """

    def __init__(self, input_size, hidden_size, dropout=0.3, cell_type="RNN", num_layers=1):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type.upper()
        self.dropout = dropout
        self.num_layers = num_layers

        # Unknown cell names fall back to a vanilla RNN.
        rnn_cls = {"LSTM": nn.LSTM, "GRU": nn.GRU}.get(self.cell_type, nn.RNN)
        # A non-zero dropout with one layer has no effect and makes PyTorch warn.
        self.enc = rnn_cls(input_size, hidden_size, batch_first=True,
                           dropout=self.dropout if self.num_layers > 1 else 0.0,
                           num_layers=self.num_layers)

    def forward(self, x):
        """Encode ``x`` of shape (batch, seq, input_size).

        Returns ``(outputs, states)``: ``states`` is ``(h_n, c_n)`` for an LSTM
        and ``h_n`` for a GRU/RNN, exactly as the wrapped module returns them.
        """
        hidden, states = self.enc(x)
        return hidden, states
        

class Decoder(nn.Module):
    """Batch-first recurrent decoder (LSTM, GRU or vanilla RNN).

    Args:
        input_size: feature size of each time step (the target one-hot width).
        hidden_size: recurrent state size.
        dropout: inter-layer dropout. PyTorch applies it only between stacked
            layers, so it is passed as 0 when ``num_layers == 1``.
        cell_type: "LSTM", "GRU" or anything else for "RNN" (case-insensitive).
        num_layers: number of stacked recurrent layers.
    """

    def __init__(self, input_size, hidden_size, dropout=0.3, cell_type='RNN', num_layers=1):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.cell_type = cell_type.upper()
        self.dropout = dropout
        self.num_layers = num_layers

        # Unknown cell names fall back to a vanilla RNN.
        rnn_cls = {"LSTM": nn.LSTM, "GRU": nn.GRU}.get(self.cell_type, nn.RNN)
        # A non-zero dropout with one layer has no effect and makes PyTorch warn.
        self.dec = rnn_cls(input_size, hidden_size, batch_first=True,
                           dropout=self.dropout if self.num_layers > 1 else 0.0,
                           num_layers=self.num_layers)

    def forward(self, x, states):
        """Run the decoder on ``x`` (batch, seq, input_size) from ``states``.

        ``states`` is an ``(h, c)`` tuple for an LSTM and a single tensor for a
        GRU/RNN. Returns ``(outputs, new_states)`` in the same format.
        """
        hidden, new_states = self.dec(x, states)
        return hidden, new_states
        

class Seq2Seq(nn.Module):
    """Character-level encoder-decoder for transliteration.

    The source one-hots are projected by ``embedding`` (Linear + tanh) and then
    encoded by ``Encoder``. The encoder's final recurrent state initialises
    ``Decoder``, and ``fc`` maps decoder outputs to vocabulary logits.

    Args:
        input_token_index / output_token_index: char -> index maps; their
            lengths are the source / target one-hot widths.
        max_dec_seq_len: number of steps produced by the inference decoders.
        embedding_dim: width of the source projection.
        hidden_size_enc / hidden_size_dec: recurrent widths. They must be equal
            because the encoder state is handed to the decoder unchanged.
        nature: stored only; not used by this class.
        enc_cell / dec_cell: "LSTM", "GRU" or "RNN".
        num_layers, dropout: shared by encoder and decoder.
        device: device that inputs are moved to inside forward/predict methods.

    Raises:
        ValueError: ``hidden_size_enc != hidden_size_dec``.
    """

    def __init__(self, input_token_index, output_token_index, max_dec_seq_len, embedding_dim,hidden_size_enc, hidden_size_dec, nature="train", enc_cell="LSTM", dec_cell="LSTM", num_layers=1, dropout=0.2, device="cpu"):
        super(Seq2Seq, self).__init__()
        if hidden_size_enc != hidden_size_dec:
            raise ValueError(
                f"hidden_size_enc ({hidden_size_enc}) must equal hidden_size_dec ({hidden_size_dec}): "
                "the encoder's final state initialises the decoder"
            )
        self.input_index_token = input_token_index
        self.output_index_token = output_token_index
        self.max_dec_seq_len = max_dec_seq_len
        self.nature = nature
        self.enc_cell_type = enc_cell.upper()
        self.dec_cell_type = dec_cell.upper()
        self.num_layers = num_layers
        self.embedding = nn.Linear(in_features=len(self.input_index_token), out_features=embedding_dim)
        self.embedding_act = nn.Tanh()
        self.encoder = Encoder(input_size=embedding_dim, hidden_size=hidden_size_enc, dropout=dropout, cell_type=enc_cell, num_layers=num_layers).to(device)
        self.decoder = Decoder(input_size=len(self.output_index_token), hidden_size=hidden_size_dec, dropout=dropout, cell_type=dec_cell, num_layers=num_layers).to(device)
        self.device = device
        self.loss_fn = nn.CrossEntropyLoss()
        self.fc = nn.Linear(in_features=hidden_size_dec, out_features=len(output_token_index))

    def _start_index(self):
        """Index of the start token "\\t" (0 if absent, the previous hard-coded value)."""
        return self.output_index_token.get("\t", 0)

    def _init_decoder_states(self, states_enc, batch_size):
        """Convert the encoder's final state into the decoder's initial state.

        LSTM decoder: an LSTM encoder's ``(h, c)`` is used as-is. A GRU/RNN
        encoder's state tensor becomes the cell state ``c``, paired with a zero
        hidden state.

        GRU/RNN decoder: an LSTM encoder contributes its cell state ``c``. A
        GRU/RNN encoder's state tensor is used directly. Previously this last
        case left the decoder state unassigned.
        """
        if self.dec_cell_type == "LSTM":
            if isinstance(states_enc, tuple):
                return states_enc
            h = torch.zeros(self.num_layers, batch_size, self.decoder.hidden_size, device=self.device)
            return (h, states_enc)
        if isinstance(states_enc, tuple):
            return states_enc[1]
        return states_enc

    def _encode(self, batch):
        """Embed and encode the source side of ``batch``.

        ``batch`` is either the ``(ENC_IN, DEC_IN, DEC_OUT)`` triple from the
        DataLoader or the ``ENC_IN`` tensor alone.
        Returns ``(batch_size, initial decoder state)``.
        """
        enc_in = batch if torch.is_tensor(batch) else batch[0]
        enc_in = enc_in.to(self.device)
        batch_size = enc_in.size(0)
        input_embedding = self.embedding_act(self.embedding(enc_in))
        _, states_enc = self.encoder(input_embedding)
        return batch_size, self._init_decoder_states(states_enc, batch_size)

    @staticmethod
    def _select_sample(states, index):
        """Slice one batch element out of a recurrent state, keeping a batch dim of 1."""
        # States are laid out (num_layers, batch, hidden), so the batch is dim 1.
        if isinstance(states, tuple):
            return tuple(s[:, index:index + 1].contiguous() for s in states)
        return states[:, index:index + 1].contiguous()

    def forward(self, batch):
        """Teacher-forced decoding of a ``(ENC_IN, DEC_IN, DEC_OUT)`` batch.

        Returns logits of shape (B, T, V), where T is DEC_IN's sequence length.
        """
        _, states_dec = self._encode(batch)
        dec_in = batch[1].to(self.device)
        decoder_outputs, _ = self.decoder(dec_in, states_dec)  # (B, T, H)
        return self.fc(decoder_outputs)                        # (B, T, V)

    def predict_greedy(self, batch):
        """Greedy decoding for ``max_dec_seq_len`` steps.

        The model's own argmax prediction is fed back as the next input.
        ``batch`` may be the DataLoader triple or ``ENC_IN`` alone.
        Returns the per-step logits, shape (B, max_dec_seq_len, V).
        """
        batch_size, states_dec = self._encode(batch)
        vocab_size = len(self.output_index_token)
        rows = torch.arange(batch_size, device=self.device)

        final_out = torch.zeros(batch_size, self.max_dec_seq_len, vocab_size, device=self.device)

        # Initial decoder input: the start token
        in_ = torch.zeros(batch_size, 1, vocab_size, device=self.device)
        in_[:, 0, self._start_index()] = 1.0

        for t in range(self.max_dec_seq_len):
            out_step, states_dec = self.decoder(in_, states_dec)  # (B, 1, H)
            logits_step = self.fc(out_step.squeeze(1))            # (B, V)
            final_out[:, t, :] = logits_step

            # Feed the argmax back as a one-hot input for the next step
            top1 = logits_step.argmax(dim=1)                      # (B,)
            in_ = torch.zeros(batch_size, 1, vocab_size, device=self.device)
            in_[rows, 0, top1] = 1.0

        return final_out

    def predict_beam_search(self, batch, beam_width = 3):
        """Beam-search decoding, one source word at a time.

        Keeps the ``beam_width`` highest-scoring hypotheses (sum of token
        log-probabilities) for ``max_dec_seq_len`` steps. Like
        ``predict_greedy``, hypotheses are not stopped at the end token.

        Returns a (B, max_dec_seq_len, V) tensor that is one-hot on the best
        hypothesis, so ``.argmax(dim=2)`` gives the prediction just as for
        ``predict_greedy``. With ``beam_width=1`` this matches greedy decoding,
        up to argmax ties.

        Raises:
            ValueError: ``beam_width < 1``.
        """
        if beam_width < 1:
            raise ValueError("beam_width must be >= 1")
        batch_size, states_dec = self._encode(batch)
        vocab_size = len(self.output_index_token)
        width = min(beam_width, vocab_size)
        start_idx = self._start_index()

        final_out = torch.zeros(batch_size, self.max_dec_seq_len, vocab_size, device=self.device)

        for b in range(batch_size):
            # Each hypothesis: (cumulative log-prob, emitted token ids, recurrent state)
            beams = [(0.0, [], self._select_sample(states_dec, b))]
            for _ in range(self.max_dec_seq_len):
                candidates = []
                for score, tokens, state in beams:
                    in_ = torch.zeros(1, 1, vocab_size, device=self.device)
                    in_[0, 0, tokens[-1] if tokens else start_idx] = 1.0
                    out_step, new_state = self.decoder(in_, state)                    # (1, 1, H)
                    log_probs = torch.log_softmax(self.fc(out_step[:, -1]), dim=-1)[0]  # (V,)
                    top_lp, top_idx = log_probs.topk(width)
                    for lp, idx in zip(top_lp.tolist(), top_idx.tolist()):
                        candidates.append((score + lp, tokens + [idx], new_state))
                candidates.sort(key=lambda cand: cand[0], reverse=True)
                beams = candidates[:width]

            best_tokens = beams[0][1]
            if best_tokens:
                steps = torch.arange(len(best_tokens), device=self.device)
                final_out[b, steps, torch.tensor(best_tokens, device=self.device)] = 1.0

        return final_out




 

In [70]:
def train_seq2seq(model, train_loader, val_loader, optimizer, num_epochs, device):
    """Train ``model`` with teacher forcing and validate after every epoch.

    Padding positions (the " " token of ``model.output_index_token``, index 2
    if unavailable) are excluded from the loss.

    Returns:
        list of per-epoch dicts with keys ``epoch``, ``train_loss``,
        ``val_loss``, ``val_acc`` and ``val_word_acc``.
    """
    pad_idx = getattr(model, "output_index_token", {}).get(" ", 2)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    history = []

    print("Training of the model has started...")

    for epoch in range(num_epochs):
        # validate_seq2seq() switches to eval mode, so switch back every epoch.
        model.train()
        epoch_loss = 0.0
        tqdm_loader = tqdm(train_loader, desc=f"Epoch : {epoch + 1} ", ncols=100)

        for ENC_IN, DEC_IN, DEC_OUT in tqdm_loader:
            # Move to device once and give the moved tensors to the model.
            ENC_IN, DEC_IN, DEC_OUT = ENC_IN.to(device), DEC_IN.to(device), DEC_OUT.to(device)
            logits = model((ENC_IN, DEC_IN, DEC_OUT))  # (B, T, V)

            # Flatten to (B*T, V) logits vs (B*T,) class indices
            logits_flat = logits.reshape(-1, logits.size(-1))
            target_indices = DEC_OUT.argmax(dim=-1).reshape(-1)
            loss = loss_fn(logits_flat, target_indices)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            tqdm_loader.set_postfix({"Train Loss": loss.item()})

        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_loss:.4f}")

        val_loss, val_acc, val_word_acc = validate_seq2seq(model, val_loader, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val Word Acc: {val_word_acc:.4f}")

        history.append({
            "epoch": epoch + 1,
            "train_loss": avg_loss,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "val_word_acc": val_word_acc,
        })

    return history

def validate_seq2seq(model, val_loader, device, pad_idx=None):
    """Evaluate ``model`` on ``val_loader``.

    Loss uses teacher forcing (``model(batch)``). Accuracies use greedy
    decoding (``model.predict_greedy(batch)``). Padding positions are ignored
    throughout.

    Args:
        pad_idx: padding class index. Defaults to the " " entry of
            ``model.output_index_token``, or 2 if that is unavailable.

    Returns:
        (average loss, character accuracy, word accuracy). A word counts as
        correct only if every non-padding position matches.
    """
    if pad_idx is None:
        pad_idx = getattr(model, "output_index_token", {}).get(" ", 2)
    model.eval()
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    total_loss = 0.0
    correct_chars = total_chars = 0
    correct_words = total_words = 0

    with torch.no_grad():
        for ENC_IN, DEC_IN, DEC_OUT in val_loader:
            batch = (ENC_IN.to(device), DEC_IN.to(device), DEC_OUT.to(device))
            true_tokens = batch[2].argmax(dim=-1)  # (B, T)

            # Teacher-forced loss
            logits = model(batch)
            total_loss += loss_fn(logits.reshape(-1, logits.size(-1)), true_tokens.reshape(-1)).item()

            # Greedy-decoding accuracies
            pred_tokens = model.predict_greedy(batch).argmax(dim=-1)  # (B, T)
            mask = true_tokens != pad_idx
            hits = pred_tokens == true_tokens
            correct_chars += (hits & mask).sum().item()
            total_chars += mask.sum().item()
            correct_words += (hits | ~mask).all(dim=1).sum().item()
            total_words += true_tokens.size(0)

    avg_loss = total_loss / len(val_loader)
    accuracy = correct_chars / total_chars if total_chars > 0 else 0.0
    word_acc = correct_words / total_words if total_words > 0 else 0.0
    return avg_loss, accuracy, word_acc


In [67]:
# Quick sanity check of the validation loop. It needs `model` and `device`,
# which are created in the training cell further down. On a fresh
# "Restart & Run All" it is skipped instead of failing with NameError.
if "model" in globals() and "device" in globals():
    print(validate_seq2seq(model=model, val_loader=val_loader, device=device))

tensor([22, 16, 48, 15, 38,  1,  1, 48,  1,  1,
         1, 48,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1])
tensor([22, 16, 48, 15, 38,  1,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2])
tensor([15, 30, 44, 15, 48,  1,  1,  1,  1, 48,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        40,  1, 16, 48,  1, 15, 31, 38])
tensor([15, 30, 44, 15, 48,  1,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2])
tensor([30, 37, 25, 26, 48,  1,  1,  1,  1,  1,
         1,  1, 48,  1, 15, 31, 38, 24, 48,  1,
         1,  1,  1,  1,  1,  1,  1,  1])
tensor([30, 37, 25, 26, 48,  1,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2])
tensor([30, 38, 16, 48, 15, 26, 48,  1,  1,  1,
        48,  1,  1,  1, 48,  1,  1,  1,  1,  1,
        48,  1,  1,  1,  1,  1,  1,  1])
tensor([30, 38, 16, 48, 15, 26, 48,  1,  

KeyboardInterrupt: 

In [71]:
# Build an LSTM seq2seq model (2 layers, 64 hidden units, 50-dim source projection) and train it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = Seq2Seq(
    train_dataset.input_token_index,
    train_dataset.output_token_index,
    train_dataset.max_dec_seq_len,
    embedding_dim=50,
    hidden_size_enc=64,
    hidden_size_dec=64,
    num_layers=2,
    device=device,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_seq2seq(model, train_loader, val_loader, optimizer, num_epochs=20, device=device)

Training of the model has started...


Epoch : 1 : 100%|██████████████████████████████| 1066/1066 [00:32<00:00, 33.19it/s, Train Loss=2.07]


Epoch [1/20] | Train Loss: 2.4201
Epoch [1/20] | Val Loss: 2.1087 | Val Acc: 0.1522 | Val Word Acc: 0.0000


Epoch : 2 : 100%|██████████████████████████████| 1066/1066 [00:32<00:00, 32.84it/s, Train Loss=1.87]


Epoch [2/20] | Train Loss: 1.9526
Epoch [2/20] | Val Loss: 1.8085 | Val Acc: 0.1832 | Val Word Acc: 0.0003


Epoch : 3 : 100%|███████████████████████████████| 1066/1066 [00:31<00:00, 34.09it/s, Train Loss=1.4]


Epoch [3/20] | Train Loss: 1.6340
Epoch [3/20] | Val Loss: 1.4321 | Val Acc: 0.2827 | Val Word Acc: 0.0044


Epoch : 4 : 100%|██████████████████████████████| 1066/1066 [00:29<00:00, 36.23it/s, Train Loss=1.03]


Epoch [4/20] | Train Loss: 1.2928
Epoch [4/20] | Val Loss: 1.1689 | Val Acc: 0.3533 | Val Word Acc: 0.0199


Epoch : 5 : 100%|█████████████████████████████| 1066/1066 [00:30<00:00, 34.73it/s, Train Loss=0.984]


Epoch [5/20] | Train Loss: 1.0744
Epoch [5/20] | Val Loss: 0.9880 | Val Acc: 0.4191 | Val Word Acc: 0.0523


Epoch : 6 : 100%|█████████████████████████████| 1066/1066 [00:29<00:00, 35.62it/s, Train Loss=0.911]


Epoch [6/20] | Train Loss: 0.9132
Epoch [6/20] | Val Loss: 0.8397 | Val Acc: 0.4840 | Val Word Acc: 0.0961


Epoch : 7 : 100%|█████████████████████████████| 1066/1066 [00:27<00:00, 38.67it/s, Train Loss=0.723]


Epoch [7/20] | Train Loss: 0.7896
Epoch [7/20] | Val Loss: 0.7300 | Val Acc: 0.5509 | Val Word Acc: 0.1704


Epoch : 8 : 100%|█████████████████████████████| 1066/1066 [00:27<00:00, 38.73it/s, Train Loss=0.565]


Epoch [8/20] | Train Loss: 0.6964
Epoch [8/20] | Val Loss: 0.6606 | Val Acc: 0.5894 | Val Word Acc: 0.2096


Epoch : 9 : 100%|█████████████████████████████| 1066/1066 [00:27<00:00, 38.47it/s, Train Loss=0.435]


Epoch [9/20] | Train Loss: 0.6254
Epoch [9/20] | Val Loss: 0.6088 | Val Acc: 0.6267 | Val Word Acc: 0.2533


Epoch : 10 : 100%|████████████████████████████| 1066/1066 [00:29<00:00, 35.64it/s, Train Loss=0.705]


Epoch [10/20] | Train Loss: 0.5712
Epoch [10/20] | Val Loss: 0.5695 | Val Acc: 0.6517 | Val Word Acc: 0.2802


Epoch : 11 : 100%|████████████████████████████| 1066/1066 [00:28<00:00, 37.95it/s, Train Loss=0.564]


Epoch [11/20] | Train Loss: 0.5253
Epoch [11/20] | Val Loss: 0.5353 | Val Acc: 0.6770 | Val Word Acc: 0.3117


Epoch : 12 : 100%|████████████████████████████| 1066/1066 [00:25<00:00, 41.30it/s, Train Loss=0.564]


Epoch [12/20] | Train Loss: 0.4879
Epoch [12/20] | Val Loss: 0.5253 | Val Acc: 0.6852 | Val Word Acc: 0.3246


Epoch : 13 : 100%|████████████████████████████| 1066/1066 [00:25<00:00, 41.72it/s, Train Loss=0.421]


Epoch [13/20] | Train Loss: 0.4565
Epoch [13/20] | Val Loss: 0.4828 | Val Acc: 0.7157 | Val Word Acc: 0.3766


Epoch : 14 : 100%|████████████████████████████| 1066/1066 [00:25<00:00, 41.11it/s, Train Loss=0.445]


Epoch [14/20] | Train Loss: 0.4289
Epoch [14/20] | Val Loss: 0.4657 | Val Acc: 0.7255 | Val Word Acc: 0.3825


Epoch : 15 : 100%|████████████████████████████| 1066/1066 [00:25<00:00, 41.48it/s, Train Loss=0.362]


Epoch [15/20] | Train Loss: 0.4081
Epoch [15/20] | Val Loss: 0.4570 | Val Acc: 0.7270 | Val Word Acc: 0.3930


Epoch : 16 : 100%|████████████████████████████| 1066/1066 [00:25<00:00, 41.59it/s, Train Loss=0.345]


Epoch [16/20] | Train Loss: 0.3888
Epoch [16/20] | Val Loss: 0.4553 | Val Acc: 0.7285 | Val Word Acc: 0.3931


Epoch : 17 : 100%|████████████████████████████| 1066/1066 [00:25<00:00, 41.49it/s, Train Loss=0.252]


Epoch [17/20] | Train Loss: 0.3721
Epoch [17/20] | Val Loss: 0.4401 | Val Acc: 0.7385 | Val Word Acc: 0.4179


Epoch : 18 : 100%|████████████████████████████| 1066/1066 [00:25<00:00, 41.48it/s, Train Loss=0.332]


Epoch [18/20] | Train Loss: 0.3584
Epoch [18/20] | Val Loss: 0.4259 | Val Acc: 0.7504 | Val Word Acc: 0.4318


Epoch : 19 : 100%|████████████████████████████| 1066/1066 [00:25<00:00, 41.37it/s, Train Loss=0.273]


Epoch [19/20] | Train Loss: 0.3444
Epoch [19/20] | Val Loss: 0.4168 | Val Acc: 0.7545 | Val Word Acc: 0.4355


Epoch : 20 : 100%|█████████████████████████████| 1066/1066 [00:25<00:00, 41.32it/s, Train Loss=0.35]


Epoch [20/20] | Train Loss: 0.3327
Epoch [20/20] | Val Loss: 0.4215 | Val Acc: 0.7562 | Val Word Acc: 0.4456


In [None]:
# Take the first validation batch and look at the target indices of sample 2.
for batch in val_loader:
    ENC_IN, DEC_IN, DEC_OUT = batch
    break

# Print whole sequences on one line (the imports cell set a narrow linewidth of 50).
torch.set_printoptions(linewidth=1000, threshold=10000)
DEC_OUT[2].argmax(dim=1)

In [None]:
# Greedy prediction for the same batch, to compare sample 2 with its target above.
# Inference only: use eval mode (no dropout) and skip building an autograd graph.
model.eval()
with torch.no_grad():
    DEC_CHK = model.predict_greedy(batch)
DEC_CHK[2].argmax(1)

In [None]:
def train_seq2seq(model, train_loader, val_loader, optimizer, num_epochs, device):
    """Earlier variant of the training loop: plain cross-entropy over every position.

    NOTE(review): this cell redefines ``train_seq2seq`` and ``validate_seq2seq``.
    On Run All it shadows the versions from the In [70] cell. Consider deleting
    this cell once that version is final.

    Expects ``validate_seq2seq(model, val_loader, device)`` to return
    ``(val_loss, val_acc)``, as the variant defined next to it does.
    """
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        # validate_seq2seq() leaves the model in eval mode, so re-enable training
        # mode (dropout) at the start of every epoch, not just once.
        model.train()
        epoch_loss = 0.0
        tqdm_loader = tqdm(train_loader, desc=f"Epoch : {epoch + 1} ")
        for (encoder_input, decoder_input, decoder_target) in tqdm_loader:
            # Move data to the appropriate device
            batch = (encoder_input.to(device), decoder_input.to(device), decoder_target.to(device))

            # Full forward pass (embedding -> encoder -> teacher-forced decoder -> fc),
            # so the loss is computed on vocabulary logits.
            logits = model(batch)

            # Reshape output and target for loss calculation
            vocab_size = logits.size(-1)
            logits_flat = logits.reshape(-1, vocab_size)                  # (batch_size * seq_len, vocab)
            target_flat = batch[2].reshape(-1, vocab_size).argmax(dim=1)  # class indices

            loss = loss_fn(logits_flat, target_flat)

            # Backpropagation and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            tqdm_loader.set_postfix({
            "Train loss (batch)" : loss.item(),
            })

        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}] | Loss: {avg_loss:.4f}")

        val_loss, val_acc = validate_seq2seq(model, val_loader, device)
        print(f"Epoch [{epoch + 1}/{num_epochs}] | Val Loss: {val_loss:.4f} | Val Acc : {val_acc:.4f}")


def validate_seq2seq(model, val_loader, device):
    """Earlier validation variant: returns ``(avg loss, character accuracy)``.

    Loss is plain cross-entropy over every position with teacher forcing.
    Character accuracy uses greedy decoding and ignores padding positions (the
    " " token of ``model.output_index_token``, index 2 if unavailable).

    NOTE(review): this redefines the three-value ``validate_seq2seq`` from the
    In [70] cell. The In [70] ``train_seq2seq`` unpacks three values, so it
    breaks if this cell runs after it.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    pad_idx = getattr(model, "output_index_token", {}).get(" ", 2)
    total_loss = 0.0
    total_chars = 0
    correct_chars = 0

    with torch.no_grad():
        for encoder_input, decoder_input, decoder_target in val_loader:
            batch = (encoder_input.to(device), decoder_input.to(device), decoder_target.to(device))
            decoder_target = batch[2]
            vocab_size = decoder_target.size(-1)

            # ---- 1. Loss calculation (teacher forcing) ----
            # The dataset's decoder_input is already the start token followed by
            # the target shifted right, so it is used directly.
            outputs = model(batch)
            outputs_flat = outputs.reshape(-1, vocab_size)
            targets_flat = decoder_target.reshape(-1, vocab_size).argmax(dim=1)
            total_loss += loss_fn(outputs_flat, targets_flat).item()

            # ---- 2. Character accuracy with greedy decoding ----
            pred_max = model.predict_greedy(batch).argmax(dim=2)
            val_max = decoder_target.argmax(dim=2)
            mask = val_max != pad_idx  # the same mask must select both sides
            correct_chars += (pred_max[mask] == val_max[mask]).sum().item()
            total_chars += mask.sum().item()

    avg_loss = total_loss / len(val_loader)
    accuracy = correct_chars / total_chars if total_chars > 0 else 0.0
    print(f"Validation Loss: {avg_loss:.4f}, Character Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pass sizes and device by keyword. The old positional call
# (..., max_dec_seq_len, 64, device) put `device` into hidden_size_enc and left
# hidden_size_dec unset, which raised a TypeError.
# NOTE(review): 64 was presumably meant as the hidden size; confirm.
model = Seq2Seq(
    train_dataset.input_token_index,
    train_dataset.output_token_index,
    train_dataset.max_dec_seq_len,
    embedding_dim=64,
    hidden_size_enc=64,
    hidden_size_dec=64,
    device=device,
).to(device)


validate_seq2seq(model, val_loader, device)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pass sizes and device by keyword. The old positional call (..., 512, device)
# put 512 into max_dec_seq_len and `device` into embedding_dim, and omitted both
# hidden sizes, which raised a TypeError.
# NOTE(review): 512 was presumably meant as the hidden size; confirm.
model = Seq2Seq(
    train_dataset.input_token_index,
    train_dataset.output_token_index,
    train_dataset.max_dec_seq_len,
    embedding_dim=50,
    hidden_size_enc=512,
    hidden_size_dec=512,
    device=device,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_seq2seq(model, train_loader, val_loader, optimizer, num_epochs=10, device=device)


In [None]:
# Inspect the model's target-side character -> index map (the dict passed as
# output_token_index when the model was built).
model.output_index_token

In [None]:
# Shape of one training sample's decoder-input one-hot tensor:
# (max_dec_seq_len, total_decoder_tokens), as built in Dataset_Tamil.__getitem__.
train_dataset[120][1].shape

In [None]:
# Scratch check of how Dataset_Tamil.build_vocab derives a sorted character set
# from a list of words (the " " join separator ends up in the set too).
str1 = "123abc"
str2 = "cde456"
list1 = [str1, str2]

sorted(set(" ".join(list1)))

In [None]:
# Tiny two-row frame for trying out the start/end-token wrapping below.
df = pd.DataFrame({"A": ["hello", "hel"], "B": [5, 3]})
df

In [None]:
def f(str):
    """Wrap a word in a leading tab (start token) and a trailing newline (end token).

    This mirrors how Dataset_Tamil wraps target words.
    NOTE(review): the parameter name shadows the builtin ``str`` inside this
    function. It is kept for call compatibility.
    """
    return "\t" + str + "\n"

# Build a new frame instead of overwriting df["A"], so re-running this cell
# does not wrap the words a second time.
df_wrapped = df.assign(A=df["A"].apply(f))
df_wrapped