In [16]:
import os
import wandb
import torch
import random
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
import seaborn as sns
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from prettytable import PrettyTable
import matplotlib.font_manager as fm
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

# Font used to render Malayalam characters when the attention heatmaps are plotted.
malayalam_font = fm.FontProperties(fname='/kaggle/input/malayalam-font/Noto_Serif_Malayalam/NotoSerifMalayalam-VariableFont_wght.ttf')

# Silence UserWarning messages for the rest of the notebook.
warnings.filterwarnings("ignore", category=UserWarning)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [17]:
# SECURITY: the W&B API key must never be written into the notebook. Provide it
# through the WANDB_API_KEY environment variable (or the platform's secret
# store) before running this cell; wandb.login() reads it from there and
# otherwise asks for it interactively.
# NOTE(review): the key that used to be hard-coded here has been exposed and
# should be revoked and regenerated.
wandb.login()

True

In [18]:
# Index of the start-of-word marker; the decoder is fed this index at its first step.
# 0 is also the fill value of the zero-initialised id arrays built in get_dataloader.
SOW_token = 0
# Index of the end-of-word marker appended after the last character of every word.
EOW_token = 1

class Lang:
    """Character vocabulary for one language.

    Indices 0 and 1 are pre-assigned (SOW and EOW), so the first character
    that is registered receives index 2.
    """

    def __init__(self, name):
        self.name = name
        # Forward lookup, per-character frequency and reverse lookup tables.
        self.letter2index = {}
        self.letter2count = {}
        self.index2letter = {0: "0", 1: "1"}
        # Next free index; starts at 2 because SOW and EOW are already counted.
        self.n_letters = 2

    def addWord(self, word):
        """Register every character of ``word`` in the vocabulary."""
        for symbol in word:
            self.addLetter(symbol)

    def addLetter(self, ch):
        """Count one occurrence of ``ch``, giving it a new index on first sight."""
        if ch in self.letter2index:
            self.letter2count[ch] += 1
            return

        new_index = self.n_letters
        self.letter2index[ch] = new_index
        self.letter2count[ch] = 1
        self.index2letter[new_index] = ch
        self.n_letters = new_index + 1

# Vocabulary of the input words ('eng') and of the output words ('mal').
input_lang = Lang('eng')
output_lang = Lang('mal')
# Tab-separated lexicon files read without a header row, so columns are numbered.
# get_dataloader uses column 1 as the input word and column 0 as the target word.
x_train = pd.read_csv('/kaggle/input/malayalam/ml/lexicons/ml.translit.sampled.train.tsv', sep='\t', header=None)  # optionally pass nrows=1000 to load a small subset
x_val = pd.read_csv('/kaggle/input/malayalam/ml/lexicons/ml.translit.sampled.dev.tsv', sep='\t', header=None)
x_test = pd.read_csv('/kaggle/input/malayalam/ml/lexicons/ml.translit.sampled.test.tsv', sep='\t', header=None)


In [19]:
# Width of every padded id sequence built in get_dataloader and the number of
# decoding steps the decoder runs.
MAX_LENGTH = 50

def indexesFromWord(lang, word):
    """Translate each character of ``word`` into its index in ``lang.letter2index``.

    Raises:
        KeyError: if a character is missing from the vocabulary.
    """
    lookup = lang.letter2index
    indexes = []
    for ch in word:
        indexes.append(lookup[ch])
    return indexes

def tensorFromWord(lang, word):
    """Encode ``word`` as a long tensor of shape (1, len(word) + 1).

    The character indices are followed by EOW_token and the tensor is created
    on the notebook-wide ``device``.
    """
    token_ids = indexesFromWord(lang, word) + [EOW_token]
    encoded = torch.tensor(token_ids, dtype=torch.long, device=device)
    # Add the leading batch dimension of size one.
    return encoded.unsqueeze(0)

def wordFromTensor(lang, tensor):
    """Decode a sequence of index tensors into a string.

    Iteration stops at the first element equal to 1 (the end-of-word index);
    every earlier index is looked up in ``lang.index2letter``.
    """
    letters = []
    for element in tensor:
        index = element.item()
        if index == 1:  # 1 is the value of EOW_token
            break
        letters.append(lang.index2letter[index])
    return "".join(letters)

def get_dataloader(x, input_lang, output_lang, batch_size):
    """Build a shuffled DataLoader of padded (input ids, target ids, source mask).

    Args:
        x: DataFrame whose column 1 holds the input word and column 0 the
            target word.
        input_lang: vocabulary for the input words; extended in place with
            every character that is encountered.
        output_lang: vocabulary for the target words; extended in place too.
        batch_size: number of word pairs per batch.

    Returns:
        A DataLoader yielding ``(input_ids, target_ids, src_mask)`` batches.
        Both id tensors are long tensors of width MAX_LENGTH padded with 0;
        ``src_mask`` is a bool tensor that is True where ``input_ids`` is
        non-zero.

    Raises:
        ValueError: if a word plus its EOW token does not fit in MAX_LENGTH.
    """
    # Drop malformed rows up front. Skipping them inside the fill loop would
    # leave all-zero rows in the arrays, i.e. bogus examples in the dataset.
    pairs = [
        (inp, tgt)
        for inp, tgt in zip(x[1].values, x[0].values)
        if isinstance(inp, str) and isinstance(tgt, str)
    ]
    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int64)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int64)

    for i, (inp, tgt) in enumerate(pairs):
        input_lang.addWord(inp)
        output_lang.addWord(tgt)
        inp_ids = indexesFromWord(input_lang, inp) + [EOW_token]
        tgt_ids = indexesFromWord(output_lang, tgt) + [EOW_token]
        if len(inp_ids) > MAX_LENGTH or len(tgt_ids) > MAX_LENGTH:
            # Fail with a readable message instead of numpy's broadcast error.
            raise ValueError(
                f"word pair ({inp!r}, {tgt!r}) needs more than "
                f"MAX_LENGTH={MAX_LENGTH} positions including the EOW token")
        input_ids[i, :len(inp_ids)] = inp_ids
        target_ids[i, :len(tgt_ids)] = tgt_ids

    # True on real input positions, False on the zero padding.
    src_mask = torch.as_tensor(input_ids != 0, device=device)

    data = TensorDataset(
        torch.as_tensor(input_ids, dtype=torch.long, device=device),
        torch.as_tensor(target_ids, dtype=torch.long, device=device),
        src_mask,
    )

    return DataLoader(data, sampler=RandomSampler(data), batch_size=batch_size)

In [20]:
class EncoderRNN(nn.Module):
    """Embeds input character indices and runs them through a recurrent layer.

    The recurrent layer class is looked up in ``algorithms`` by
    ``config.cell_type`` and is built with ``batch_first=True``.
    """

    def __init__(self, config, input_size):
        super().__init__()
        embed_dim = config.inp_embed_size
        rnn_class = algorithms[config.cell_type]

        self.embedding = nn.Embedding(input_size, embed_dim)
        self.algo = rnn_class(embed_dim, config.hidden_size, batch_first=True)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input):
        """Return ``(output, hidden)`` of the recurrent layer for ``input``."""
        embedded = self.embedding(input)
        # Dropout is applied to the embeddings before the recurrent layer.
        regularised = self.dropout(embedded)
        output, hidden = self.algo(regularised)
        return output, hidden

In [21]:
class Attention(nn.Module):
    """Additive attention: score = Va(tanh(Wa(query) + Ua(keys)))."""

    def __init__(self, hidden_size):
        super().__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys, mask=None):
        """Return ``(context, weights)`` for ``query`` attending over ``keys``.

        When ``mask`` is given, positions where it equals 0 are filled with
        -1e9 before the softmax so they receive (almost) no weight.
        """
        projected_query = self.Wa(query)
        projected_keys = self.Ua(keys)
        energies = self.Va(torch.tanh(projected_query + projected_keys))  # (B, T, 1)
        scores = energies.squeeze(2).unsqueeze(1)  # (B, 1, T)

        if mask is not None:
            blocked = mask.unsqueeze(1) == 0
            scores = scores.masked_fill(blocked, -1e9)

        weights = F.softmax(scores, dim=-1)  # (B, 1, T)
        context = torch.bmm(weights, keys)   # (B, 1, H)
        return context, weights


class AttnDecoderRNN(nn.Module):
    """Recurrent decoder that attends over the encoder outputs at every step.

    Each step embeds the previous token, computes an attention context from
    the current hidden state, feeds ``[embedding, context]`` to the recurrent
    layer and projects its output onto the output vocabulary.
    """

    def __init__(self, config, output_size):
        super(AttnDecoderRNN, self).__init__()

        self.dropout_p = config.dropout
        hidden_size = config.hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = Attention(hidden_size)
        # Not applied in the forward pass; kept so the module's parameter set
        # (and therefore its state_dict) stays unchanged.
        self.layer_norm = nn.LayerNorm(hidden_size)
        # The recurrent layer consumes the embedding concatenated with the context.
        self.algo = algorithms[config.cell_type](hidden_size + hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(self.dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None, src_mask=None):
        """Decode MAX_LENGTH steps, starting from SOW_token.

        Args:
            encoder_outputs: encoder outputs, batch first.
            encoder_hidden: final encoder state, used as the initial decoder state.
            target_tensor: when given, teacher forcing is used: step ``i + 1``
                is fed ``target_tensor[:, i]``. When None the decoder feeds
                back its own arg-max prediction.
            src_mask: optional mask over the encoder positions; positions
                where it is 0/False are excluded from the attention.

        Returns:
            ``(log_probs, final_hidden, attentions)`` where ``log_probs`` is
            the log-softmax over the vocabulary for every step.
        """
        batch_size = encoder_outputs.size(0)
        # Create the start tokens on the same device as the encoder outputs
        # instead of relying on the notebook-wide device variable.
        decoder_input = torch.full(
            (batch_size, 1), SOW_token, dtype=torch.long, device=encoder_outputs.device)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            # src_mask is forwarded so padded source positions get no attention.
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs, src_mask)
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1)
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

    def forward_step(self, input, hidden, encoder_outputs, src_mask=None):
        """Run one decoding step.

        Args:
            input: indices of the previous tokens.
            hidden: current recurrent state (a tuple for LSTM).
            encoder_outputs: encoder outputs to attend over.
            src_mask: optional attention mask over the encoder positions;
                None (the default) attends over every position.

        Returns:
            ``(output, hidden, attn_weights)``: unnormalised vocabulary
            scores, the new recurrent state and the attention weights.
        """
        embedded = self.dropout(self.embedding(input))

        # Ensure `embedded` has shape [B, 1, H]
        if embedded.dim() == 2:
            embedded = embedded.unsqueeze(1)  # [B, 1, H]
        elif embedded.dim() == 4:
            embedded = embedded.squeeze(1)  # Remove accidental extra dim

        # Unpack hidden state if using LSTM
        if isinstance(hidden, tuple):
            h = hidden[0]  # (num_layers, batch, hidden_size)
        else:
            h = hidden

        query = h.permute(1, 0, 2)  # (batch, num_layers, hidden_size)

        context, attn_weights = self.attention(query, encoder_outputs, src_mask)
        # Ensure `context` has shape [B, 1, H]
        if context.dim() == 2:
            context = context.unsqueeze(1)
        elif context.dim() == 4:
            context = context.squeeze(1)

        rnn_input = torch.cat((embedded, context), dim=2)

        output, hidden = self.algo(rnn_input, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

In [22]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion, batch_size, teacher_forcing = True,
          update_weights = None):
    """Run one pass over ``dataloader`` and return ``(mean loss, accuracy %)``.

    Args:
        dataloader: iterable of ``(input_tensor, target_tensor, src_mask)`` batches.
        encoder: module mapping ``input_tensor`` to ``(outputs, hidden)``.
        decoder: module called as ``decoder(outputs, hidden, target, src_mask)``
            and returning log-probabilities as its first result.
        encoder_optimizer: optimizer for the encoder parameters.
        decoder_optimizer: optimizer for the decoder parameters.
        criterion: loss applied to flattened log-probabilities and labels.
        batch_size: unused; kept so existing calls keep working.
        teacher_forcing: when True the targets are given to the decoder; when
            False the decoder receives ``None`` and feeds back its own output.
        update_weights: whether to back-propagate and step the optimizers.
            ``None`` (default) follows ``teacher_forcing``, because the
            validation/test passes in this notebook are the calls that pass
            ``teacher_forcing=False`` and they must not train the model.

    Returns:
        Tuple of the loss averaged over batches and the percentage of
        sequences whose every position matches the target. ``(0.0, 0.0)`` is
        returned for an empty dataloader.
    """
    if update_weights is None:
        update_weights = teacher_forcing

    total_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0

    # Evaluation passes run in eval mode (dropout off); the previous modes are
    # restored afterwards so callers see no change.
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.train(update_weights)
    decoder.train(update_weights)

    try:
        with torch.set_grad_enabled(update_weights):
            for input_tensor, target_tensor, src_mask in dataloader:
                # Honour the flag: no targets are shown without teacher forcing.
                decoder_target = target_tensor if teacher_forcing else None

                if update_weights:
                    encoder_optimizer.zero_grad()
                    decoder_optimizer.zero_grad()

                encoder_outputs, encoder_hidden = encoder(input_tensor)
                decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, decoder_target, src_mask)

                outputs = decoder_outputs.reshape(-1, decoder_outputs.size(-1))
                labels = target_tensor.reshape(-1)

                loss = criterion(outputs, labels)
                if update_weights:
                    loss.backward()
                    encoder_optimizer.step()
                    decoder_optimizer.step()

                total_loss += loss.item()
                n_batches += 1

                # A sequence is correct only if every position matches.
                predicted = outputs.argmax(dim=1).view_as(target_tensor)
                correct += (predicted == target_tensor).all(dim=1).sum().item()
                total += target_tensor.size(0)
    finally:
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

    if n_batches == 0 or total == 0:
        return 0.0, 0.0
    return total_loss / n_batches, (correct*100) / total

In [23]:
def train(train_dataloader, val_dataloader, test_dataloader, encoder, decoder, n_epochs, config):
    """Train for up to ``n_epochs`` epochs, then report test metrics.

    Both modules get their own Adam optimizer with learning rate ``config.lr``
    and share an NLLLoss criterion. Every pass (train, validation, test) goes
    through ``train_epoch``; the validation and test passes are requested with
    ``teacher_forcing=False``. Train and validation metrics are logged to W&B
    once per epoch.
    """
    criterion = nn.NLLLoss()
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=config.lr)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=config.lr)

    def run_pass(loader, **options):
        # All passes share the same models, optimizers and criterion.
        return train_epoch(loader, encoder, decoder, encoder_optimizer,
                           decoder_optimizer, criterion, config.batch_size, **options)

    for epoch in range(1, n_epochs + 1):
        print("Epoch:",epoch)
        train_loss, train_acc = run_pass(train_dataloader)
        print("Train: accuracy:", train_acc, "loss:", train_loss)

        # Give up from epoch 15 onwards when the train accuracy is still ~0;
        # the epoch is then neither validated nor logged.
        if epoch >= 15 and train_acc < 0.01:
            break

        val_loss, val_acc = run_pass(val_dataloader, teacher_forcing=False)
        print("Validation: accuracy:", val_acc, "Loss:", val_loss)
        wandb.log({'train_accuracy': train_acc,'train_loss': train_loss,'val_accuracy': val_acc,'val_loss': val_loss})

    test_loss, test_acc = run_pass(test_dataloader, teacher_forcing=False)
    print("Test: accuracy:", test_acc, "Loss:", test_loss, "\n")

In [24]:
# Number of epochs passed to train() for every sweep run.
num_epochs = 20

# Maps the `cell_type` hyperparameter to the recurrent layer class used by the
# encoder and the decoder.
algorithms = {'rnn': nn.RNN,'gru': nn.GRU,'lstm': nn.LSTM}

# W&B sweep definition: Bayesian search that maximises the logged
# `val_accuracy` over the discrete value lists below.
sweep_config = {
    'method': 'bayes', 
    'metric': {
      'name': 'val_accuracy',
      'goal': 'maximize'   
    },
    'parameters': {
        # Embedding size of the encoder input characters.
        'inp_embed_size':{
            'values': [64, 128, 256]
        },
        'dropout': {
            'values': [0.2, 0.3]
        },
        # Learning rate of both Adam optimizers.
        'lr': {
            'values': [0.001, 0.0001]
        },
        'hidden_size': {
            'values': [128, 256]
        },
        'batch_size': {
            'values': [64, 128, 256]
        },
        # Key into `algorithms`.
        'cell_type':{
            'values': ['rnn', 'gru', 'lstm']
        }
    }
}

# Registers the sweep in the 'DL_A3' project and keeps its id for wandb.agent.
sweep_id = wandb.sweep(sweep=sweep_config, project='DL_A3')

In [25]:
def test():
    """Entry point of one sweep run: build data and models, then train.

    The hyperparameters come from ``wandb.config``. The three data loaders are
    built in train, validation, test order, which also fills both
    vocabularies before the models are sized from them.
    """
    with wandb.init():
        config = wandb.config

        # Run name formatting
        wandb.run.name = (
            f"{config.cell_type}-do_{config.dropout}-bs_{config.batch_size}-lr_{config.lr}-"
            f"hs_{config.hidden_size}-emb_{config.inp_embed_size}-")

        loaders = [
            get_dataloader(split, input_lang, output_lang, config.batch_size)
            for split in (x_train, x_val, x_test)
        ]
        train_dataloader, val_dataloader, test_dataloader = loaders

        # The vocabularies are complete now, so their sizes fix the model shapes.
        encoder = EncoderRNN(config, input_lang.n_letters).to(device)
        decoder = AttnDecoderRNN(config, output_lang.n_letters).to(device)
        print(input_lang.n_letters, output_lang.n_letters)

        train(train_dataloader, val_dataloader, test_dataloader, encoder, decoder, num_epochs, config)
        
# Start an agent for the sweep; `count=1` limits it to a single call of test().
wandb.agent(sweep_id, function=test, count=1)
wandb.finish()  # author's note: sweep id cojvqj9b

In [28]:
def evaluate_and_save_predictions(encoder, decoder):
    """Greedily decode every test word, save the predictions and plot attention.

    For each row of ``x_test`` (column 1 = input, column 0 = ground truth) the
    word is decoded one character at a time until EOW_token is produced or
    MAX_LENGTH steps have run. All predictions are written to
    ``test_predictions.tsv`` and attention heatmaps are shown for up to nine
    randomly chosen samples.

    Args:
        encoder: encoder module; called on a single encoded word.
        decoder: decoder module providing ``forward_step``.
    """
    results = []
    output_file = 'test_predictions.tsv'

    with torch.no_grad():
        for input_seq, true_output in zip(x_test[1].values, x_test[0].values):
            # Same rule as get_dataloader: rows that are not strings are
            # skipped instead of crashing the encoding step.
            if not isinstance(input_seq, str) or not isinstance(true_output, str):
                continue

            input_tensor = tensorFromWord(input_lang, input_seq)
            encoder_outputs, encoder_hidden = encoder(input_tensor)

            decoder_input = torch.tensor([[SOW_token]], device=device)
            decoder_hidden = encoder_hidden
            decoded_ids = []
            attentions = []

            for _ in range(MAX_LENGTH):
                decoder_output, decoder_hidden, attn_weights = decoder.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                # Greedy choice, kept as a (1, 1) tensor so it can be fed back as-is.
                next_input = decoder_output.argmax(dim=-1)
                next_token = next_input.item()
                if next_token == EOW_token:
                    break
                decoded_ids.append(next_token)
                attentions.append(attn_weights.squeeze(1).cpu().numpy())
                decoder_input = next_input.detach()

            predicted_output = ''.join(output_lang.index2letter.get(idx, '?') for idx in decoded_ids)

            results.append((input_seq, true_output, predicted_output, attentions))

    # Save predictions to TSV
    df = pd.DataFrame([(x, y, z) for x, y, z, _ in results],
                      columns=["Input", "Ground Truth", "Prediction"])
    df.to_csv(output_file, index=False, sep='\t', encoding='utf-8-sig')
    print(f"\n📁 All test predictions saved to: {output_file}")

    # Plot grid of attention heatmaps for 9 random samples
    print("\n🧠 Showing attention heatmaps for 9 random predictions:")
    show_attention_grid(random.sample(results, min(9, len(results))))



def show_attention_grid(samples):
    """Plot attention heatmaps for up to nine samples in a 3x3 grid.

    Args:
        samples: sequence of ``(input_seq, true_output, predicted_output,
            attentions)`` tuples, where ``attentions`` holds one array per
            predicted character. Samples beyond the ninth are ignored.

    Each heatmap has the input characters on the x axis and the predicted
    characters on the y axis. A sample with an empty prediction has no
    attention rows, so its panel shows a placeholder text instead.
    """
    n_rows, n_cols = 3, 3
    samples = list(samples)[:n_rows * n_cols]  # never index past the grid

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15))
    axes = axes.flatten()

    for ax, (input_seq, true_output, predicted_output, attentions) in zip(axes, samples):
        input_chars = list(input_seq)
        output_chars = list(predicted_output)

        if len(attentions) > 0:
            # Stack the per-step weights and drop the singleton axis.
            attn_matrix = np.array(attentions).squeeze(1)

            sns.heatmap(
                attn_matrix[:len(output_chars), :len(input_chars)],
                xticklabels=input_chars,
                yticklabels=output_chars,
                cmap='viridis',
                ax=ax,
                cbar=False
            )
            ax.set_yticklabels(ax.get_yticklabels(), fontproperties=malayalam_font, rotation=0)
        else:
            # Nothing was decoded before EOW, so there is no matrix to draw.
            ax.text(0.5, 0.5, "(empty prediction)", ha='center', va='center',
                    transform=ax.transAxes)
            ax.set_xticks([])
            ax.set_yticks([])

        ax.set_title(f"In: {input_seq}\nGT: {true_output}\nPred: {predicted_output}",
             fontsize=9, fontproperties=malayalam_font)
        ax.set_xlabel('')
        ax.set_ylabel('')

    # Hide any unused subplots
    for unused_ax in axes[len(samples):]:
        fig.delaxes(unused_ax)

    plt.tight_layout()
    plt.show()


In [29]:
# Best Model
# Only a single epoch is run for this configuration.
num_epochs = 1

# Maps the `cell_type` hyperparameter to the recurrent layer class.
algorithms = {'rnn': nn.RNN,'gru': nn.GRU,'lstm': nn.LSTM}

# Sweep with exactly one candidate per hyperparameter, so every run of it uses
# the same fixed configuration.
best_config = {
    'method': 'bayes', 
    'metric': {
      'name': 'val_accuracy',
      'goal': 'maximize'   
    },
    'parameters': {
        'inp_embed_size':{
            'values': [256]
        },
        'dropout': {
            'values': [0.3]
        },
        'lr': {
            'values': [0.001]
        },
        'hidden_size': {
            'values': [256]
        },
        'batch_size': {
            'values': [64]
        },
        'cell_type':{
            'values': ['lstm']
        }
    }
}

# Replaces the earlier `sweep_id` with the id of this single-configuration sweep.
sweep_id = wandb.sweep(sweep=best_config, project='DL_A3')


Create sweep with ID: lv2nh1c7
Sweep URL: https://wandb.ai/alandandoor-iit-madras/DL_A3/sweeps/lv2nh1c7


In [None]:
def test():
    """Entry point of one run: train, then evaluate and save test predictions.

    Redefines the earlier ``test`` of this notebook; in addition to training
    it switches both modules to eval mode and calls
    ``evaluate_and_save_predictions``.
    """
    with wandb.init():
        config = wandb.config

        # Run name formatting
        wandb.run.name = (
            f"{config.cell_type}-do_{config.dropout}-bs_{config.batch_size}-lr_{config.lr}-"
            f"hs_{config.hidden_size}-emb_{config.inp_embed_size}-")

        # Loaders are built in train, validation, test order; this also fills
        # both vocabularies before the models are sized from them.
        loaders = [
            get_dataloader(split, input_lang, output_lang, config.batch_size)
            for split in (x_train, x_val, x_test)
        ]
        train_dataloader, val_dataloader, test_dataloader = loaders

        encoder = EncoderRNN(config, input_lang.n_letters).to(device)
        decoder = AttnDecoderRNN(config, output_lang.n_letters).to(device)
        print(input_lang.n_letters, output_lang.n_letters)

        train(train_dataloader, val_dataloader, test_dataloader, encoder, decoder, num_epochs, config)

        # Disable dropout before producing the final predictions.
        for module in (encoder, decoder):
            module.eval()
        evaluate_and_save_predictions(encoder, decoder)
        
# Start an agent for the sweep; `count=1` limits it to a single call of test().
wandb.agent(sweep_id, function=test, count=1)
wandb.finish()  # author's note: sweep id cojvqj9b