# Attention Seq2Seq Report Tester Notebook
This notebook is used to take the best model from the sweep retrain the model using appropriate callbacks and then predict on the test set and save it and also create some visualizations if required. Without much details lets get into the assignment.

In [1]:
# Importing the necessary libraries #
# Importing the necessary libraries needed
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms.functional as Fn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm
import wandb
wandb.login(key = "5ef7c4bbfa350a2ffd3c198cb9289f544e3a0910")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/joel/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mae21b105[0m ([33mRough[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Data preparation
# Loading the dataset
df_train = pd.read_csv('/home/joel/DA6401_DL/DA6401_A03/ta_lexicons/ta.translit.sampled.train.tsv', sep='\t',  header=None, names=["native","latin","count"])
df_test = pd.read_csv('/home/joel/DA6401_DL/DA6401_A03/ta_lexicons/ta.translit.sampled.test.tsv', sep='\t',  header=None, names=["native","latin","count"])
df_val = pd.read_csv('/home/joel/DA6401_DL/DA6401_A03/ta_lexicons/ta.translit.sampled.dev.tsv', sep='\t',  header=None, names=["native","latin","count"])

In [None]:
# Preparing the dataset for the model to fit #
class Dataset_Tamil(Dataset):
    def __init__(self, dataframe, build_vocab=True, input_token_index=None, output_token_index=None,
                 max_enc_seq_len=0, max_dec_seq_len=0):
        
        # Input variables
        self.input_df = dataframe
        self.input_words = []
        self.output_words = []
        # Characters of the language
        self.input_characters = set()
        self.output_characters = set()

        # Iterating thorough the rows
        for _, row in self.input_df.iterrows():
            input_word = str(row["latin"])
            output_word = "\t" + str(row["native"]) + "\n"
            self.input_words.append(input_word)
            self.output_words.append(output_word)
        
        if build_vocab:
            self.build_vocab()
        else:
            # Token index for sequence building
            self.input_token_index = input_token_index
            self.output_token_index = output_token_index
            # Heuristics lengths for the encoder decoder
            self.max_enc_seq_len = max_enc_seq_len
            self.max_dec_seq_len = max_dec_seq_len

        # Finding the encoder/decoder tokens 
        self.total_encoder_tokens = len(self.input_token_index)
        self.total_decoder_tokens = len(self.output_token_index)

    def build_vocab(self):
        # Building the vocabulary
        self.input_characters = sorted(set(" ".join(self.input_words)))
        self.output_characters = sorted(set(" ".join(self.output_words)))
        # Adding the padding character if not present
        if " " not in self.input_characters:
            self.input_characters.append(" ")
        if " " not in self.output_characters:
            self.output_characters.append(" ")

        # Fitting/Finding the necessary values from training data
        self.input_token_index = {char: i for i, char in enumerate(self.input_characters)}
        self.output_token_index = {char: i for i, char in enumerate(self.output_characters)}

        self.input_token_index_reversed = {i: char for i, char in enumerate(self.input_characters)}
        self.output_token_index_reversed = {i: char for i, char in enumerate(self.output_characters)}

        self.max_enc_seq_len = max(len(txt) for txt in self.input_words)
        self.max_dec_seq_len = max(len(txt) for txt in self.output_words)

    def __len__(self):
        return len(self.input_words)
    
    def __getitem__(self, index):
        input_word = self.input_words[index]
        output_word = self.output_words[index]

        # Finding the input for each stages of the network
        encoder_input = np.zeros((self.max_enc_seq_len, self.total_encoder_tokens), dtype=np.float32)
        decoder_input = np.zeros((self.max_dec_seq_len, self.total_decoder_tokens), dtype=np.float32)
        decoder_output = np.zeros((self.max_dec_seq_len, self.total_decoder_tokens), dtype=np.float32)

        for t, char in enumerate(input_word):
            if char in self.input_token_index:
                encoder_input[t, self.input_token_index[char]] = 1.0
        for t in range(len(input_word), self.max_enc_seq_len):
            encoder_input[t, self.input_token_index[" "]] = 1.0

        for t, char in enumerate(output_word):
            if char in self.output_token_index:
                decoder_input[t, self.output_token_index[char]] = 1.0
                if t > 0:
                    decoder_output[t - 1, self.output_token_index[char]] = 1.0
        # Fill remaining positions with space character
        for t in range(len(output_word), self.max_dec_seq_len):
            decoder_input[t, self.output_token_index[" "]] = 1.0

        for t in range(len(output_word) - 1, self.max_dec_seq_len):
            decoder_output[t, self.output_token_index[" "]] = 1.0

        return (
            torch.from_numpy(encoder_input),
            torch.from_numpy(decoder_input),
            torch.from_numpy(decoder_output)
        )

In [None]:
# Model classes definitions #
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.3, cell_type="RNN", num_layers=1, bi_directional=False):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type.upper()
        self.dropout = dropout
        self.num_layers = num_layers

        if self.cell_type == 'LSTM':
            self.enc = nn.LSTM(input_size, hidden_size, batch_first=True, dropout=self.dropout, num_layers=self.num_layers, bidirectional=bi_directional)
        elif self.cell_type == 'GRU':
            self.enc = nn.GRU(input_size, hidden_size, batch_first=True, dropout=self.dropout, num_layers=self.num_layers, bidirectional=bi_directional)
        else:
            self.enc = nn.RNN(input_size, hidden_size, batch_first=True, dropout=self.dropout, num_layers=self.num_layers, bidirectional=bi_directional)

    def forward(self, x):
        if self.cell_type == "LSTM":
            hidden, (hn, cn) = self.enc(x)
            return hidden, (hn, cn)
        else:
            hidden, out = self.enc(x)
            return hidden, out
        
class Attention_Mechanism(nn.Module):
    def __init__(self, hidden_dim, device="cpu"):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        # Creating the matrices for attention calculation
        self.V_att = nn.Parameter(torch.randn(size=(self.hidden_dim, 1), device=device)*0.1)
        self.U_att = nn.Parameter(torch.randn(size=(self.hidden_dim, self.hidden_dim), device=device)*0.1)
        self.W_att = nn.Parameter(torch.randn(size=(self.hidden_dim, self.hidden_dim), device=device)*0.1)

    def forward(self, st_1, c_j, mask):
        # Compute the attention scores and softmax
        """
        st_1 : input of size (bx1xd)
        c_j : input of size (bxLxd)
        """
        #print(st_1.shape, c_j.shape)
        inside = self.tanh(torch.matmul(c_j, self.W_att) + torch.matmul(st_1, self.U_att))
        #print(inside.shape)
        scores = torch.matmul(inside, self.V_att).squeeze(2)
        #print(scores.shape)
        scores[mask] = -torch.inf

        attention = self.softmax(scores)
        return attention
    
class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.3, cell_type='RNN', num_layers=1):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.cell_type = cell_type.upper()
        self.dropout = dropout
        self.num_layers = num_layers

        if self.cell_type == 'LSTM':
            self.dec = nn.LSTM(input_size, hidden_size, batch_first=True, dropout=self.dropout, num_layers=self.num_layers)
        elif self.cell_type == 'GRU':
            self.dec = nn.GRU(input_size, hidden_size, batch_first=True, dropout=self.dropout, num_layers=self.num_layers)
        else:
            self.dec = nn.RNN(input_size, hidden_size, batch_first=True, dropout=self.dropout, num_layers=self.num_layers)

    def forward(self, x, states):
        if states == None:
            hidden, out = self.dec(x)
            return hidden, out
        elif type(states) == tuple:
            hidden, (hn, cn) = self.dec(x, states)
            return hidden, (hn, cn)
        else:
            hidden, out = self.dec(x, states)
            return hidden, out
        
class Attention_Seq2Seq(nn.Module):
    def __init__(self, input_token_index, output_token_index, max_dec_seq_len, embedding_dim,hidden_size_enc, bi_directional=False,
            nature="train", enc_cell="LSTM", dec_cell="LSTM", num_layers=1,dropout=0.2, device="cpu"):
        super().__init__()

        self.input_index_token = input_token_index
        self.output_index_token = output_token_index
        self.max_dec_seq_len = max_dec_seq_len
        self.nature = nature
        self.enc_cell_type = enc_cell.upper()
        self.dec_cell_type = dec_cell.upper()
        self.num_layers= num_layers
        self.bi_directional = bi_directional
        self.hidden_size_enc = hidden_size_enc
        self.hidden_size_dec = (1 + int(self.bi_directional == True))*hidden_size_enc
        self.embedding = nn.Linear(in_features=len(self.input_index_token), out_features=embedding_dim)
        self.embedding_act = nn.Tanh()
        self.encoder = Encoder(input_size=embedding_dim, hidden_size=hidden_size_enc, dropout=dropout, cell_type=enc_cell, num_layers=num_layers, bi_directional=self.bi_directional).to(device)
        self.attention = Attention_Mechanism(hidden_dim=self.hidden_size_dec)
        self.decoder = Decoder(input_size=len(self.output_index_token)+self.hidden_size_dec, hidden_size=self.hidden_size_dec, dropout=dropout, cell_type=dec_cell, num_layers=num_layers).to(device)
        self.device = device
        self.loss_fn = nn.CrossEntropyLoss()
        self.fc = nn.Linear(in_features=self.hidden_size_dec, out_features=len(output_token_index))

    def forward(self, batch):
        ENC_IN, DEC_IN, DEC_OUT = batch
        ENC_IN = ENC_IN.to(self.device)
        DEC_IN = DEC_IN.to(self.device)

        batch_size = ENC_IN.size(0)
        input_embedding = self.embedding_act(self.embedding(ENC_IN))
        mask_ = torch.argmax(ENC_IN, 2) == 2
        hidden_enc, states_enc = self.encoder(input_embedding)

        # Final matrix
        final_out = torch.zeros(batch_size, self.max_dec_seq_len, len(self.output_index_token), device=self.device)

        # Initial decoder input (with start token)
        in_ = DEC_IN[:, 0:1, :].clone()
        for t in range(self.max_dec_seq_len):
            if t==0:
                out_step, states_dec = self.decoder(torch.cat((in_, hidden_enc[:,-1,:].unsqueeze(1)), dim=2), None)  
            else:
                # input for next input
                in_ = DEC_IN[:, t, :].unsqueeze(1).clone()
                att_scores = self.attention(out_step, hidden_enc, mask_)

                in_ = torch.cat((in_, torch.bmm(att_scores.unsqueeze(1), hidden_enc)), dim=2)
                # Output
                out_step, states_dec = self.decoder(in_, states_dec) 

            logits_step = self.fc(out_step.squeeze(1))         
            final_out[:, t, :] = logits_step
   
        return final_out
    
    def predict_greedy(self, batch):
        ENC_IN, DEC_IN, DEC_OUT = batch
        ENC_IN = ENC_IN.to(self.device)
        DEC_IN = DEC_IN.to(self.device)

        batch_size = ENC_IN.size(0)
        input_embedding = self.embedding_act(self.embedding(ENC_IN))
        mask_ = torch.argmax(ENC_IN, 2) == 2
        hidden_enc, states_enc = self.encoder(input_embedding)

        # Final matrix
        final_out = torch.zeros(batch_size, self.max_dec_seq_len, len(self.output_index_token), device=self.device)

        # Initial decoder input (with start token)
        in_ = torch.zeros(batch_size, 1, len(self.output_index_token), device=self.device)
        in_[:, 0, 0] = 1.0

        for t in range(self.max_dec_seq_len):
            if t==0:
                out_step, states_dec = self.decoder(torch.cat((in_, hidden_enc[:,-1,:].unsqueeze(1)), dim=2), None)  
            else:
                out_step, states_dec = self.decoder(in_, states_dec)  

            logits_step = self.fc(out_step.squeeze(1))            
            final_out[:, t, :] = logits_step

            # Greedy argmax for next input
            top1 = torch.argmax(logits_step, dim=1)              
            in_ = torch.zeros(batch_size, 1, len(self.output_index_token), device=self.device)
            in_[torch.arange(batch_size), 0, top1] = 1.0
            att_scores = self.attention(out_step, hidden_enc, mask_)

            in_ = torch.cat((in_, torch.bmm(att_scores.unsqueeze(1), hidden_enc)), dim=2)
        return final_out

In [None]:
# Fucntion for validation of the model # 
def validate_seq2seq(model, val_loader, device, val_type = "greedy", beam_width=None):
    model.eval()
    total_loss = 0.0
    correct_chars = 0
    total_chars = 0
    correct_words = 0
    total_words = 0
    loss_fn = nn.CrossEntropyLoss(ignore_index=2)

    with torch.no_grad():
        tqdm_progress = tqdm(val_loader, desc="Predicting...")
        for batch in tqdm_progress:
            ENC_IN, DEC_IN, DEC_OUT = batch
            ENC_IN = ENC_IN.to(device)
            DEC_IN = DEC_IN.to(device)
            DEC_OUT = DEC_OUT.to(device)

            # Forward pass
            decoder_output = model(batch)

            # Compute loss
            vocab_size = decoder_output.size(-1)
            decoder_output = decoder_output.view(-1, vocab_size)
            decoder_target_indices = DEC_OUT.argmax(dim=-1).view(-1)

            loss = loss_fn(decoder_output, decoder_target_indices)
            total_loss += loss.item()

            # Character-wise accuracy
            if val_type == "greedy":
                decoder_output = model.predict_greedy(batch)
            else:
                decoder_output = model.predict_beam_search(batch, beam_width=beam_width)

            #print(decoder_output.shape)
            pred_tokens = decoder_output.argmax(dim=2)
            true_tokens = DEC_OUT.argmax(dim=2)
            #print(pred_tokens.shape)
            #print(true_tokens.shape)
            
            mask = true_tokens != 2  # Ignore PAD tokens
            correct_chars += (pred_tokens[mask] == true_tokens[mask]).sum().item()
            total_chars += mask.sum().item()

            mask = true_tokens != 2  # Ignore PAD tokens
            #print(mask.shape)
            total_words += decoder_output.shape[0]
            #print(pred_tokens[mask].shape)
            chk_words = (mask.int() - (pred_tokens == true_tokens).int())
            chk_words[mask == False] = 0
            correct_words += (chk_words.sum(dim = 1) == 0).sum().item()

    avg_loss = total_loss / len(val_loader)
    accuracy = correct_chars / total_chars if total_chars > 0 else 0.0
    word_acc = correct_words / total_words if total_words > 0 else 0.0
    return avg_loss, accuracy, word_acc

In [None]:
# Trainloop
def train_seq2seq(model, train_loader, val_loader, optimizer, num_epochs, device, beam_sizes = [3,5], run=None):
    loss_fn = nn.CrossEntropyLoss(ignore_index=2)  # 2 is the padding index
    max_val_char_acc = 0
    max_val_word_acc = 0
    print("Training of the model has started...")
    counter = 0
    patience = 7
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        tqdm_loader = tqdm(train_loader, desc=f"Epoch : {epoch + 1} ", ncols=100)

        for batch in tqdm_loader:
            ENC_IN, DEC_IN, DEC_OUT = batch
            ENC_IN = ENC_IN.to(device)
            DEC_IN = DEC_IN.to(device)
            DEC_OUT = DEC_OUT.to(device)
            # Move to device
            decoder_output = model(batch)

            # Reshape for loss
            decoder_output = decoder_output.view(-1, decoder_output.size(-1))
            decoder_target_indices = DEC_OUT.argmax(dim=-1).view(-1)

            loss = loss_fn(decoder_output, decoder_target_indices)
            
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            tqdm_loader.set_postfix({"Train Loss": loss.item()})

        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_loss:.4f}")

        val_loss, val_acc, val_word_acc = validate_seq2seq(model, val_loader, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val Word Acc: {val_word_acc:.4f}")

        if run is not None:
            run.log({"train_loss_epoch" : avg_loss, "val_loss_epoch" : val_loss, "val_char_acc" : val_acc, "val_word_acc" : val_word_acc})

        if val_word_acc > max_val_word_acc:
            max_val_char_acc = val_acc
            max_val_word_acc = val_word_acc
            torch.save(model.state_dict(),"Attention_weights.pth")
            counter = 0
        else:
            counter += 1

        if counter > patience:
            break

    if run is not None:
        run.summary["max_val_char_acc"] = max_val_char_acc
        run.summary["max_val_word_acc"] = max_val_word_acc


In [7]:
torch.cuda.empty_cache()
config = {
        "learning_rate" : 0.001,
        "dropout_rnn" : 0.2, 
        "batch_size" :  256,
        "epochs" : 30,
        "embedding_dim" : 256,
        "num_layers" : 1,
        "hidden_size_enc" : 128,
        "enc_cell_type" : "GRU",
        "dec_cell_type" : "RNN",
        "bi_directional" : True,
    }

# Loading the datasets and dataloaders
train_dataset = Dataset_Tamil(df_train)
val_dataset = Dataset_Tamil(df_val, build_vocab=False, input_token_index=train_dataset.input_token_index, 
                            output_token_index=train_dataset.output_token_index, max_enc_seq_len=train_dataset.max_enc_seq_len,
                            max_dec_seq_len=train_dataset.max_dec_seq_len)
test_dataset = Dataset_Tamil(df_test, build_vocab=False, input_token_index=train_dataset.input_token_index, 
                            output_token_index=train_dataset.output_token_index, max_enc_seq_len=train_dataset.max_enc_seq_len,
                            max_dec_seq_len=train_dataset.max_dec_seq_len)

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Attention_Seq2Seq(input_token_index=train_dataset.input_token_index, output_token_index=train_dataset.output_token_index, max_dec_seq_len=train_dataset.max_dec_seq_len,
                embedding_dim=config["embedding_dim"], hidden_size_enc=config["hidden_size_enc"], bi_directional=config["bi_directional"], enc_cell=config["enc_cell_type"], dec_cell=config["dec_cell_type"], 
                num_layers=config["num_layers"], dropout=config["dropout_rnn"], device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

#train_seq2seq(model, train_loader, val_loader, optimizer, num_epochs=config["epochs"], device=device)



In [8]:
model_best = Attention_Seq2Seq(input_token_index=train_dataset.input_token_index, output_token_index=train_dataset.output_token_index, max_dec_seq_len=train_dataset.max_dec_seq_len,
                embedding_dim=config["embedding_dim"], hidden_size_enc=config["hidden_size_enc"], bi_directional=config["bi_directional"], enc_cell=config["enc_cell_type"], dec_cell=config["dec_cell_type"], 
                num_layers=config["num_layers"], dropout=config["dropout_rnn"], device=device).to(device)

model_best.load_state_dict(torch.load("/home/joel/DA6401_DL/DA6401_A03/Attention_weights.pth", weights_only=True))

<All keys matched successfully>

In [None]:
run = wandb.init(entity="A3_DA6401_DL", project="Attention_RNN", name="Attention weights final chk", config=config)

In [None]:
# Attention wwights
import plotly.graph_objects as go
def log_attention_heatmap(attention, input_tokens, output_tokens, step, name="Attention Heatmap", i=0):
    fig = go.Figure(data=go.Heatmap(
        z=attention,
        colorscale='Blues',
        zmin=0,
        zmax=1
    ))

    fig.update_xaxes(tickmode='array', tickvals=list(range(len(input_tokens))), ticktext=input_tokens, constrain='domain')
    fig.update_yaxes(tickmode='array', tickvals=list(range(len(output_tokens))), ticktext=output_tokens, scaleanchor="x", scaleratio=1, constrain='domain')

    fig.update_layout(xaxis_title="Input Tokens", yaxis_title="Output Tokens", autosize=True, margin=dict(l=50, r=50, t=50, b=50), width=500, height=500, 
        title=dict(text=f'{name} [in : {"".join(input_tokens)}] [out : {"".join(output_tokens)}]', x=0.5, font=dict(size=12)))
    fig.show()

    return fig

In [None]:
TOTAL = [20, 30, 430, 2040, 2555, 3495, 4295, 5875, 6275, 4141, 3254] # Handpicked random images

In [None]:
# Test data attention weights finder
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
i = 0
X = 0
FIGS = []
for batch in test_loader:
    INPUT_TOKEN = []
    OUTPUT_TOKEN = []
    ATT_WEIGHTS = np.zeros((train_dataset.max_dec_seq_len, train_dataset.max_enc_seq_len))
    if i == TOTAL[X]:
        i += 1
        ENC_IN, DEC_IN, DEC_OUT = batch
        ENC_IN = ENC_IN.to(model_best.device)
        DEC_IN = DEC_IN.to(model_best.device)

        batch_size = ENC_IN.size(0)
        input_embedding = model_best.embedding_act(model_best.embedding(ENC_IN))
        mask_ = torch.argmax(ENC_IN, 2) == 2
        hidden_enc, states_enc = model_best.encoder(input_embedding)

        # Final matrix
        final_out = torch.zeros(batch_size, model_best.max_dec_seq_len, len(model_best.output_index_token), device=model_best.device)

        # Initial decoder input (with start token)
        in_ = torch.zeros(batch_size, 1, len(model_best.output_index_token), device=model_best.device)
        in_[:, 0, 0] = 1.0

        for t in range(model_best.max_dec_seq_len):
            if t==0:
                out_step, states_dec = model_best.decoder(torch.cat((in_, hidden_enc[:,-1,:].unsqueeze(1)), dim=2), None) 
            else:
                out_step, states_dec = model_best.decoder(in_, states_dec)

            logits_step = model_best.fc(out_step.squeeze(1))   
            final_out[:, t, :] = logits_step

            # Greedy argmax for next input
            top1 = torch.argmax(logits_step, dim=1)          
            in_ = torch.zeros(batch_size, 1, len(model_best.output_index_token), device=model_best.device)
            in_[torch.arange(batch_size), 0, top1] = 1.0
            att_scores = model_best.attention(out_step, hidden_enc, mask_)

            in_ = torch.cat((in_, torch.bmm(att_scores.unsqueeze(1), hidden_enc)), dim=2)

            ATT_WEIGHTS[t, :] = att_scores.detach().cpu().numpy()
    else:
        i += 1
        continue

    input_word_vec = ENC_IN[0].argmax(1)
    output_pred_vec = final_out[0].argmax(1)
    
    for jx in range(train_dataset.max_dec_seq_len):
        char = train_dataset.output_token_index_reversed[output_pred_vec[jx].item()]
        if char == "\n":
            break
        OUTPUT_TOKEN.append(char)      
    
    for jx in range(train_dataset.max_enc_seq_len):
        char = train_dataset.input_token_index_reversed[input_word_vec[jx].item()]
        if char == " ":
            break
        INPUT_TOKEN.append(char) 
    
    FINAL_ATT = ATT_WEIGHTS[:len(OUTPUT_TOKEN), :len(INPUT_TOKEN)]
    TOTAL_ATT = ATT_WEIGHTS.copy()
    
    FINAL_ATT.shape
    FIG = log_attention_heatmap(FINAL_ATT, INPUT_TOKEN, OUTPUT_TOKEN, step=1, name="Attention Weights Heatmap", i=i)
    FIGS.append(FIG)

    break

print(INPUT_TOKEN)
print(OUTPUT_TOKEN)

In [None]:
run.finish()

In [None]:
# Test data for connectivity
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
i = 0
X = 1
FIGS = []
for batch in test_loader:
    INPUT_TOKEN = []
    OUTPUT_TOKEN = []
    GRADS_final = np.zeros((train_dataset.max_dec_seq_len, train_dataset.max_enc_seq_len))
    if i == TOTAL[X]:
        i += 1
        ENC_IN, DEC_IN, DEC_OUT = batch
        ENC_IN = ENC_IN.to(model_best.device)
        DEC_IN = DEC_IN.to(model_best.device)

        batch_size = ENC_IN.size(0)
        input_embedding = model_best.embedding_act(model_best.embedding(ENC_IN))
        mask_ = torch.argmax(ENC_IN, 2) == 2
        hidden_enc, states_enc = model_best.encoder(input_embedding)

        # Final matrix
        final_out = torch.zeros(batch_size, model_best.max_dec_seq_len, len(model_best.output_index_token), device=model_best.device)

        # Initial decoder input (with start token)
        in_ = torch.zeros(batch_size, 1, len(model_best.output_index_token), device=model_best.device)
        in_[:, 0, 0] = 1.0

        for t in range(model_best.max_dec_seq_len):
            if t==0:
                out_step, states_dec = model_best.decoder(torch.cat((in_, hidden_enc[:,-1,:].unsqueeze(1)), dim=2), None) 
            else:
                out_step, states_dec = model_best.decoder(in_, states_dec)  

            logits_step = model_best.fc(out_step.squeeze(1))          
            final_out[:, t, :] = logits_step

            grads_ = torch.autograd.grad(outputs=logits_step.sum(), inputs=input_embedding, retain_graph=True)[0]
            squeezed_grads = grads_.squeeze(0).norm(dim=1).unsqueeze(0).detach().cpu().numpy()
            GRADS_final[t, :] = squeezed_grads

            # Greedy argmax for next input
            top1 = torch.argmax(logits_step, dim=1)              
            in_ = torch.zeros(batch_size, 1, len(model_best.output_index_token), device=model_best.device)
            in_[torch.arange(batch_size), 0, top1] = 1.0
            att_scores = model_best.attention(out_step, hidden_enc, mask_)

            in_ = torch.cat((in_, torch.bmm(att_scores.unsqueeze(1), hidden_enc)), dim=2)

    else:
        i += 1
        continue

    input_word_vec = ENC_IN[0].argmax(1)
    output_pred_vec = final_out[0].argmax(1)
    
    for jx in range(train_dataset.max_dec_seq_len):
        char = train_dataset.output_token_index_reversed[output_pred_vec[jx].item()]
        if char == "\n":
            break
        OUTPUT_TOKEN.append(char)      
    
    for jx in range(train_dataset.max_enc_seq_len):
        char = train_dataset.input_token_index_reversed[input_word_vec[jx].item()]
        if char == " ":
            break
        INPUT_TOKEN.append(char) 

    break

print(INPUT_TOKEN)
print(OUTPUT_TOKEN)
print(GRADS_final.shape)

['a', 'k', 'a', 'l', 'v', 'a', 'a', 'i', 'v', 'i', 'l']
['அ', 'க', 'ழ', '்', 'வ', 'ா', 'ய', '்', 'வ', 'ி', 'ல', '்']
(28, 30)


In [None]:
# Normalization 
from sklearn.preprocessing import MinMaxScaler
def get_gradient_norms(grad_list, word_in, word_out):
    grad_list = grad_list[:len(word_in), :len(word_out)].copy()
    norms_per_step = np.zeros_like(grad_list)
    for i in range(grad_list.shape[0]):
        scaled = MinMaxScaler().fit_transform(grad_list[i].reshape(-1, 1)).flatten()
        scaled_norm = scaled / scaled.sum()
        norms_per_step[i] = scaled_norm

    return norms_per_step

In [89]:
get_gradient_norms(GRADS_final, INPUT_TOKEN, OUTPUT_TOKEN)

array([[0.7531229 , 0.16131324, 0.02733328, 0.00931772, 0.00210579,
        0.        , 0.00227567, 0.03703563, 0.00749577],
       [0.06023386, 0.04957183, 0.29597427, 0.13271008, 0.15445384,
        0.15883172, 0.10249771, 0.04572667, 0.        ],
       [0.06979467, 0.03745478, 0.15508828, 0.21106213, 0.23021964,
        0.20188429, 0.05587937, 0.03861683, 0.        ],
       [0.06388663, 0.0694043 , 0.10765719, 0.22707103, 0.20969904,
        0.20485047, 0.06111545, 0.0563159 , 0.        ],
       [0.09472936, 0.06320268, 0.12440794, 0.14267857, 0.18454592,
        0.20027521, 0.10644718, 0.08371314, 0.        ],
       [0.0311924 , 0.02207705, 0.03607198, 0.09640151, 0.14684419,
        0.1460138 , 0.22866576, 0.29273331, 0.        ],
       [0.05429193, 0.        , 0.03109674, 0.10606389, 0.17904135,
        0.13401896, 0.19229774, 0.28943444, 0.01375496]])

In [None]:
import plotly.graph_objects as go
import numpy as np

def plot_connectivity_stacked(GRADS_final, INPUT_TOKEN, OUTPUT_TOKEN):
    max_dec_len, max_enc_len = GRADS_final.shape

    # Ensure input and output tokens are padded
    input_chars = INPUT_TOKEN + [''] * (max_enc_len - len(INPUT_TOKEN))
    output_labels = OUTPUT_TOKEN + [''] * (max_dec_len - len(OUTPUT_TOKEN))

    # Create full matrix for plotting (each row = output token step)
    full_matrix = GRADS_final[:len(output_labels), :len(input_chars)]

    # Repeat input_chars for each row to place inside boxes
    repeated_text = np.tile(input_chars, (len(output_labels), 1))

    # Create y-axis labels using output characters
    y_labels = [f'{i}: {char}' if char else f'{i}' for i, char in enumerate(output_labels)]

    fig = go.Figure(
        data=go.Heatmap(z=full_matrix, x=list(range(len(input_chars))), y=y_labels, text=repeated_text, texttemplate="%{text}", textfont={"size": 14, "color": "black"},
            colorscale='Blues', zmid=0, zmin=GRADS_final.min(), zmax=GRADS_final.max(), showscale=True)
    )

    fig.update_layout(
        title=dict(text=f'Gradient Connectivity Matrix [in : {"".join(INPUT_TOKEN)}] [out : {"".join(OUTPUT_TOKEN)}]', x=0.5, font=dict(size=16)),
        xaxis=dict( title="Input Characters", tickmode='array', tickvals=list(range(len(input_chars))), ticktext=input_chars, showgrid=False, scaleanchor='y'),
        yaxis=dict(title="Output Characters", tickmode='array', tickvals=list(range(len(output_labels))), ticktext=y_labels, autorange='reversed',showgrid=False,),
        width=50 * len(input_chars) + 200,
        height=40 * len(output_labels) + 200,
        margin=dict(t=40, l=60, r=40, b=40)
    )

    fig.show()

    return fig


In [91]:
fig = plot_connectivity_stacked(get_gradient_norms(GRADS_final, INPUT_TOKEN, OUTPUT_TOKEN), INPUT_TOKEN, OUTPUT_TOKEN)