In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
import os
import wandb
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
import pandas as pd
import json

In [3]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_key")

os.environ['WANDB_API_KEY'] = secret_value_0

In [4]:
DEVICE   = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BASE_DIR = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons' # Change 'te' → desired language

In [5]:
class CharacterEmbedding(nn.Module):
    # Creating an embedding layer that maps input character indices to embedding vectors.
    # input_size: number of unique characters (vocabulary size)
    # embedding_dim: size of each embedding vector
    def __init__(self, input_size, embedding_dim):
        super(CharacterEmbedding, self).__init__()
        self.embedding = nn.Embedding(input_size, embedding_dim)

    # Returns corresponding embedding vectors of shape (batch_size, seq_length, embedding_dim)
    def forward(self, input_seq):
        # input_seq: a tensor of character indices, typically of shape (batch_size, seq_length)
        return self.embedding(input_seq)

In [6]:
# EncoderRNN transforms sequences of token IDs into contextual hidden states
# Supports GRU, LSTM, or vanilla RNN cells
# input_size: number of unique tokens
# hidden_size: size of the RNN hidden state
# embedding_dim: size of token embedding vectors
# num_layers: number of stacked recurrent layers
# cell_type: 'GRU', 'LSTM', or 'RNN'
# dropout_p: dropout probability between RNN layers (only if num_layers > 1)
# bidirectional: whether to run the RNN in both forward and backward directions
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, embedding_dim, num_layers=1,cell_type='GRU', dropout_p=0.1, bidirectional=False):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cell_type = cell_type
        self.bidirectional = bidirectional
        self.directions = 2 if bidirectional else 1
        
        # Embedding layer
        self.embedding = nn.Embedding(input_size, embedding_dim)
        
        # Dropout before the RNN (applied to embeddings)
        self.dropout = nn.Dropout(dropout_p)
        dropout_p = dropout_p if num_layers > 1 else 0
        
        # RNN layer
        if cell_type == 'GRU':
            self.rnn = nn.GRU(embedding_dim, hidden_size, num_layers,dropout=dropout_p,bidirectional=bidirectional, batch_first=True)
        elif cell_type == 'LSTM':
            self.rnn = nn.LSTM(embedding_dim, hidden_size, num_layers,dropout=dropout_p,bidirectional=bidirectional, batch_first=True)
        else:  # Default to RNN
            self.rnn = nn.RNN(embedding_dim, hidden_size, num_layers,dropout=dropout_p,bidirectional=bidirectional, nonlinearity='tanh', batch_first=True)

    # Forward pass through the encoder
    def forward(self, input_seq):
        # Input shape: [batch_size, seq_len]
        batch_size = input_seq.size(0)
        
        # Convert indices to embeddings and apply dropout to embeddings
        embedded = self.embedding(input_seq)  # [batch_size, seq_len, embedding_dim]
        embedded = self.dropout(embedded)
        
        # Pass through RNN
        outputs, hidden = self.rnn(embedded)
        
        return outputs, hidden

In [7]:
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))
        nn.init.uniform_(self.v, -0.1, 0.1)

    def forward(self, hidden, encoder_outputs):
        # hidden: [batch_size, hidden_size]
        # encoder_outputs: [batch_size, seq_len, hidden_size]

        batch_size = encoder_outputs.size(0)
        seq_len = encoder_outputs.size(1)

        # Repeat hidden state seq_len times to concat with encoder outputs
        hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)  # [batch_size, seq_len, hidden_size]

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))  # [batch_size, seq_len, hidden_size]
        energy = energy.transpose(1, 2)  # [batch_size, hidden_size, seq_len]

        v = self.v.repeat(batch_size, 1).unsqueeze(1)  # [batch_size, 1, hidden_size]

        energy = torch.bmm(v, energy).squeeze(1)  # [batch_size, seq_len]

        attn_weights = F.softmax(energy, dim=1)  # [batch_size, seq_len]

        return attn_weights

In [8]:
class DecoderRNNWithAttention(nn.Module):
    def __init__(self, output_size, hidden_size, embedding_dim, num_layers=1, 
                 cell_type='GRU', dropout_p=0.1):
        
        super(DecoderRNNWithAttention, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.cell_type = cell_type
        
        # Embedding layer for target characters
        self.embedding = nn.Embedding(output_size, embedding_dim)
        
        # Dropout applied before RNN
        self.dropout = nn.Dropout(dropout_p)

        dropout_p = dropout_p if num_layers > 1 else 0
        
        # RNN input size is embedding_dim + hidden_size (due to attention context)
        rnn_input_size = embedding_dim + hidden_size
        
        # RNN layer
        if cell_type == 'GRU':
            self.rnn = nn.GRU(rnn_input_size, hidden_size, num_layers, dropout=dropout_p, batch_first=True)
        elif cell_type == 'LSTM':
            self.rnn = nn.LSTM(rnn_input_size, hidden_size, num_layers, dropout=dropout_p, batch_first=True)
        else:
            self.rnn = nn.RNN(rnn_input_size, hidden_size, num_layers, dropout=dropout_p, nonlinearity='tanh', batch_first=True)
            
        # Attention mechanism
        self.attention = Attention(hidden_size)
        
        # Output layer with dropout
        self.out_dropout = nn.Dropout(dropout_p)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input_char, hidden, encoder_outputs):
        # input_char: [batch_size] (current input token indices)
        # hidden: (h_n, c_n) for LSTM or h_n for GRU/RNN
        # encoder_outputs: [batch_size, seq_len, hidden_size]
        
        batch_size = input_char.size(0)
        
        # Embed input character and apply dropout
        embedded = self.embedding(input_char.squeeze(1)).unsqueeze(1)  # [batch_size, 1, embedding_dim]
        embedded = self.dropout(embedded)
        
        # Get the last hidden state from hidden (handle LSTM tuple)
        if self.cell_type == 'LSTM':
            last_hidden = hidden[0][-1]  # [batch_size, hidden_size]
        else:
            last_hidden = hidden[-1]     # [batch_size, hidden_size]
        
        # Calculate attention weights and context vector
        attn_weights = self.attention(last_hidden, encoder_outputs)  # [batch_size, seq_len]
        
        attn_weights = attn_weights.unsqueeze(1)  # [batch_size, 1, seq_len]
        
        # Compute context vector as weighted sum of encoder outputs
        context = torch.bmm(attn_weights, encoder_outputs)  # [batch_size, 1, hidden_size]
        
        # Concatenate embedded input and context vector
        
        rnn_input = torch.cat((embedded, context), dim=2)  # [batch_size, 1, embedding_dim + hidden_size]
        
        # Pass through RNN
        output, hidden = self.rnn(rnn_input, hidden)
        
        # Output shape: [batch_size, 1, hidden_size]
        output = self.out_dropout(output)
        output = self.out(output.squeeze(1))  # [batch_size, output_size]
        
        return F.log_softmax(output, dim=1), hidden, attn_weights.squeeze(1)

In [9]:
def beam_search_decode(model, src, sos_idx, eos_idx, max_len=30, beam_width=3, device='cuda'):
    model.eval()
    with torch.no_grad():
        # Encode input
        encoder_outputs, encoder_hidden = model.encoder(src)

        # Prepare initial decoder hidden state
        if model.bidirectional:
            if model.cell_type == 'LSTM':
                h_n, c_n = encoder_hidden
                h_dec = torch.zeros(model.decoder.num_layers, 1, model.decoder.hidden_size, device=device)
                c_dec = torch.zeros(model.decoder.num_layers, 1, model.decoder.hidden_size, device=device)
                for layer in range(model.encoder.num_layers):
                    h_combined = torch.cat((h_n[2*layer], h_n[2*layer+1]), dim=1)
                    c_combined = torch.cat((c_n[2*layer], c_n[2*layer+1]), dim=1)
                    h_dec[layer] = model.hidden_transform(h_combined)
                    c_dec[layer] = model.hidden_transform(c_combined)
                decoder_hidden = (h_dec, c_dec)
            else:
                decoder_hidden = torch.zeros(model.decoder.num_layers, 1, model.decoder.hidden_size, device=device)
                for layer in range(model.encoder.num_layers):
                    h_combined = torch.cat((encoder_hidden[2*layer], encoder_hidden[2*layer+1]), dim=1)
                    decoder_hidden[layer] = model.hidden_transform(h_combined)
        else:
            decoder_hidden = encoder_hidden

        # Beam search initialization
        beams = [([sos_idx], 0.0, decoder_hidden)]  # (sequence, cumulative log-prob, hidden)
        completed = []

        for _ in range(max_len):
            new_beams = []
            for seq, score, hidden in beams:
                if seq[-1] == eos_idx:
                    completed.append((seq, score))
                    continue
                input_char = torch.tensor([seq[-1]], device=device)

                # Decoder forward with attention requires encoder_outputs
                output, hidden_new, _attn = model.decoder(input_char, hidden, encoder_outputs)

                log_probs = output  # Already log_softmax, shape [batch_size=1, output_size]
                topk_log_probs, topk_indices = torch.topk(log_probs.squeeze(0), beam_width)
                for k in range(beam_width):
                    next_seq = seq + [topk_indices[k].item()]
                    next_score = score + topk_log_probs[k].item()

                    # Make sure to detach hidden state to avoid graph buildup
                    if model.cell_type == 'LSTM':
                        h_new = tuple(h.detach() for h in hidden_new)
                        new_beams.append((next_seq, next_score, h_new))
                    else:
                        new_beams.append((next_seq, next_score, hidden_new.detach()))

            # Keep top beam_width beams
            beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
            if not beams:
                break

        # Add any remaining beams ending with <eos>
        completed += [(seq, score) for seq, score, _ in beams if seq[-1] == eos_idx]

        # If none ended with <eos>, just take the best
        if not completed:
            completed = beams

        # Sort by score
        completed = sorted(completed, key=lambda x: x[1], reverse=True)
        return completed


In [10]:
# Seq2Seq implements an Encoder and Decoder for end-to-end sequence-to-sequence modeling
# input_size: size of source vocabulary
# output_size: size of target vocabulary
# embedding_dim: dimension of embeddings in both encoder and decoder
# hidden_size: size of hidden states in encoder and decoder (must match for vanilla seq2seq)
# encoder_layers / decoder_layers: number of stacked RNN layers
# cell_type: 'GRU', 'LSTM', or 'RNN'
# dropout_p: dropout probability for embeddings and RNN layers
# bidirectional_encoder: if True, runs encoder bidirectionally and transforms hidden state

class Seq2Seq(nn.Module):
    def __init__(self, input_size, output_size, embedding_dim=256, hidden_size=256,
                 encoder_layers=1, decoder_layers=1, cell_type='GRU', dropout_p=0.2,
                 bidirectional_encoder=False):
        super(Seq2Seq, self).__init__()
        
        self.encoder = EncoderRNN(input_size, hidden_size, embedding_dim,
                                  num_layers=encoder_layers, cell_type=cell_type,
                                  dropout_p=dropout_p, bidirectional=bidirectional_encoder)
        
        self.bidirectional = bidirectional_encoder
        directions = 2 if bidirectional_encoder else 1
        
        if bidirectional_encoder:
            self.hidden_transform = nn.Linear(hidden_size * directions, hidden_size)
        
        self.decoder = DecoderRNNWithAttention(output_size, hidden_size, embedding_dim,
                                               num_layers=decoder_layers, cell_type=cell_type,
                                               dropout_p=dropout_p)
        
        self.cell_type = cell_type

    def _match_decoder_layers(self, hidden, batch_size):
        if hidden.size(0) > self.decoder.num_layers:
            return hidden[:self.decoder.num_layers]
        elif hidden.size(0) < self.decoder.num_layers:
            pad = torch.zeros(self.decoder.num_layers - hidden.size(0),
                              batch_size, self.decoder.hidden_size,
                              device=hidden.device)
            return torch.cat([hidden, pad], dim=0)
        else:
            return hidden

    def forward(self, src, trg, teacher_forcing_ratio=0.5, return_attention=False):
        batch_size = src.size(0)
        trg_len = trg.size(1)
        output_size = self.decoder.output_size
        
        outputs = torch.zeros(batch_size, trg_len, output_size).to(src.device)
        all_attentions = [] if return_attention else None

        encoder_outputs, encoder_hidden = self.encoder(src)
        decoder_hidden = None 
        
        if self.bidirectional:
            if self.cell_type == 'LSTM':
                h_n, c_n = encoder_hidden
                h_dec = torch.zeros(self.decoder.num_layers, batch_size, self.decoder.hidden_size).to(src.device)
                c_dec = torch.zeros(self.decoder.num_layers, batch_size, self.decoder.hidden_size).to(src.device)
                
                for layer in range(self.decoder.num_layers):
                    enc_layer = min(layer, self.encoder.num_layers - 1)
                    h_combined = torch.cat((h_n[2 * enc_layer], h_n[2 * enc_layer + 1]), dim=1)
                    c_combined = torch.cat((c_n[2 * enc_layer], c_n[2 * enc_layer + 1]), dim=1)
                    h_dec[layer] = self.hidden_transform(h_combined)
                    c_dec[layer] = self.hidden_transform(c_combined)
                
                decoder_hidden = (h_dec, c_dec)
            
            else:
                h_n = encoder_hidden
                h_dec = torch.zeros(self.decoder.num_layers, batch_size, self.decoder.hidden_size).to(src.device)
                for layer in range(self.decoder.num_layers):
                    enc_layer = min(layer, self.encoder.num_layers - 1)
                    h_combined = torch.cat((h_n[2 * enc_layer], h_n[2 * enc_layer + 1]), dim=1)
                    h_dec[layer] = self.hidden_transform(h_combined)
                decoder_hidden = h_dec
        
        else:
            if self.cell_type == "LSTM":
                h, c = encoder_hidden
                decoder_hidden = (
                    self._match_decoder_layers(h, batch_size),
                    self._match_decoder_layers(c, batch_size)
                )
            else:
                decoder_hidden = self._match_decoder_layers(encoder_hidden, batch_size)
        
        input_char = trg[:, 0].unsqueeze(1)

        for t in range(1, trg_len):
            output, decoder_hidden, attn_weights = self.decoder(input_char, decoder_hidden, encoder_outputs)
            outputs[:, t, :] = output

            if return_attention:
                all_attentions.append(attn_weights.unsqueeze(1))  # [batch, 1, src_len]

            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1).unsqueeze(1)
            input_char = trg[:, t].unsqueeze(1) if teacher_force else top1

        if return_attention:
            # Shape: [batch_size, trg_len-1, src_len]
            all_attentions = torch.cat(all_attentions, dim=1)
            return outputs, all_attentions

        return outputs

In [11]:
class LexiconDataset(Dataset):
    def __init__(self, path, src_vocab=None, tgt_vocab=None, build_vocab=False):
        self.pairs = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                cols = line.strip().split('\t')
                if len(cols) < 2:
                    continue
                tgt, src = cols[0], cols[1]  # Telugu is first column, romanized second
                self.pairs.append((src, tgt)) # src = romanized, tgt = telugu

        if build_vocab:
            self.src_vocab = {'<pad>':0, '<sos>':1, '<eos>':2, '<unk>':3}
            self.tgt_vocab = {'<pad>':0, '<sos>':1, '<eos>':2, '<unk>':3}
            for rom, dev in self.pairs:
                for c in rom:
                    self.src_vocab.setdefault(c, len(self.src_vocab))
                for c in dev:
                    self.tgt_vocab.setdefault(c, len(self.tgt_vocab))
        else:
            assert src_vocab and tgt_vocab, "Must provide vocabs if not building."
            self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        rom, dev = self.pairs[idx]
        src_idxs = [self.src_vocab.get(c, self.src_vocab['<unk>']) for c in rom]
        tgt_idxs = [self.tgt_vocab['<sos>']] + [self.tgt_vocab.get(c, self.tgt_vocab['<unk>']) for c in dev] + [self.tgt_vocab['<eos>']]
        return torch.tensor(src_idxs, dtype=torch.long), torch.tensor(tgt_idxs, dtype=torch.long)

def collate_fn(batch):
    """
    Pads all src/tgt sequences in the batch to the max length.
    Returns:
      padded_src: (batch_size, max_src_len)
      padded_tgt: (batch_size, max_tgt_len)
    """
    srcs, tgts = zip(*batch)
    max_src = max(len(s) for s in srcs)
    max_tgt = max(len(t) for t in tgts)

    padded_src = torch.full((len(batch), max_src), 0, dtype=torch.long)
    padded_tgt = torch.full((len(batch), max_tgt), 0, dtype=torch.long)
    for i, (s, t) in enumerate(zip(srcs, tgts)):
        padded_src[i, :len(s)] = s
        padded_tgt[i, :len(t)] = t

    return padded_src, padded_tgt

def get_dataloaders(base_dir, batch_size, build_vocab=False):
    """
    Returns:
      train_loader, val_loader, test_loader,
      src_vocab_size, tgt_vocab_size, pad_index, src_vocab, tgt_vocab
    """
    train_p = os.path.join(base_dir, 'te.translit.sampled.train.tsv')
    dev_p   = os.path.join(base_dir, 'te.translit.sampled.dev.tsv')
    test_p  = os.path.join(base_dir, 'te.translit.sampled.test.tsv')

    train_ds = LexiconDataset(train_p, build_vocab=build_vocab)
    src_vocab, tgt_vocab = train_ds.src_vocab, train_ds.tgt_vocab
    val_ds   = LexiconDataset(dev_p,  src_vocab, tgt_vocab)
    test_ds  = LexiconDataset(test_p, src_vocab, tgt_vocab)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader  = DataLoader(test_ds,  batch_size=1,            shuffle=False, collate_fn=collate_fn)

    return (train_loader, val_loader, test_loader,
            len(src_vocab), len(tgt_vocab), src_vocab['<pad>'],
            src_vocab, tgt_vocab)

In [12]:
class EarlyStopper:
    """Stops a run if the monitored metric doesn’t improve for `patience` steps."""
    def __init__(self, patience=5, min_delta=1e-4):
        self.patience, self.min_delta = patience, min_delta
        self.counter, self.best = 0, None

    def should_stop(self, current):
        if self.best is None or current > self.best + self.min_delta:
            self.best, self.counter = current, 0
        else:
            self.counter += 1
        return self.counter >= self.patience

In [13]:
CHAR2IDX_SRC = {
    "<pad>": 0,
    "<sos>": 1,
    "<eos>": 2,
    "<unk>": 3,
    **{c: i + 4 for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}
}
IDX2CHAR_SRC = {i: c for c, i in CHAR2IDX_SRC.items()}

# Load data, build vocabs, and create reverse target-char map
train_loader, val_loader, test_loader, src_size, tgt_size, pad_idx, src_vocab, tgt_vocab = get_dataloaders(
    BASE_DIR, batch_size=64, build_vocab=True
)

IDX2CHAR_TGT = {idx: ch for ch, idx in tgt_vocab.items()}  # Map decoder indices back to Telugu chars

# Model, optimizer, loss, early stopping

model = Seq2Seq(
    input_size=src_size,
    output_size=tgt_size,
    embedding_dim=64,
    hidden_size=128,
    encoder_layers=3,
    decoder_layers=3,
    cell_type='GRU',  # or 'GRU' or 'RNN'
    dropout_p=0.3,
    bidirectional_encoder=False
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.NLLLoss(ignore_index=pad_idx)
stopper = EarlyStopper(patience=5)
best_val_acc = 0.0

# Training Loop

for epoch in range(1, 11):
    model.train()
    total_loss = 0.0
    for src, tgt in tqdm(train_loader, desc=f"[Epoch {epoch}] Training", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        optimizer.zero_grad()
        out = model(src, tgt, teacher_forcing_ratio=0.7)
        loss = criterion(out.view(-1, tgt_size), tgt.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation: sequence-level accuracy
    model.eval()
    correct_seqs, total_seqs = 0, 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc=f"[Epoch {epoch}] Validation", leave=False):
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            out = model(src, tgt, teacher_forcing_ratio=0.0)
            preds = out.argmax(dim=2)
            for pred_seq, true_seq in zip(preds, tgt):
                # Remove <sos> and padding tokens for comparison
                pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
                true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
                if torch.equal(pred_tokens, true_tokens):
                    correct_seqs += 1
                total_seqs += 1

    val_acc = correct_seqs / total_seqs
    print(f"[Epoch {epoch}] Loss: {total_loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Optionally save model checkpoint here
    elif stopper.should_stop(val_acc):
        print("Early stopping triggered.")
        break

# Final Test Evaluation
model.eval()
correct_seqs, total_seqs = 0, 0
all_preds, all_trues = [], []

with torch.no_grad():
    for src, tgt in tqdm(test_loader, desc="Final Test Eval", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        out = model(src, tgt, teacher_forcing_ratio=0.0)
        preds = out.argmax(dim=2)
        for pred_seq, true_seq in zip(preds, tgt):
            pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
            true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
            if torch.equal(pred_tokens, true_tokens):
                correct_seqs += 1
            total_seqs += 1

            all_preds.append(pred_tokens)
            all_trues.append(true_tokens)

test_acc = correct_seqs / total_seqs
print(f"\n Final Test Accuracy: {test_acc:.4f}")


# Sample Predictions in Telugu
print("\n Sample Predictions:")
sample_indices = random.sample(range(len(all_preds)), min(10, len(all_preds)))
for idx in sample_indices:
    pred_str = ''.join([IDX2CHAR_TGT[token.item()] for token in all_preds[idx]])
    true_str = ''.join([IDX2CHAR_TGT[token.item()] for token in all_trues[idx]])
    correctness = "true" if pred_str == true_str else "false"
    print(f"  True : {true_str}\n  Pred : {pred_str}  {correctness}\n")

                                                                     

[Epoch 1] Loss: 1802.5624 | Val Acc: 0.0947


                                                                     

[Epoch 2] Loss: 846.5610 | Val Acc: 0.2847


                                                                     

[Epoch 3] Loss: 610.7603 | Val Acc: 0.3718


                                                                     

[Epoch 4] Loss: 507.3234 | Val Acc: 0.4181


                                                                     

[Epoch 5] Loss: 449.8098 | Val Acc: 0.4552


                                                                     

[Epoch 6] Loss: 408.2083 | Val Acc: 0.4774


                                                                     

[Epoch 7] Loss: 379.0115 | Val Acc: 0.4894


                                                                     

[Epoch 8] Loss: 354.4138 | Val Acc: 0.5112


                                                                     

[Epoch 9] Loss: 339.5995 | Val Acc: 0.5036


                                                                      

[Epoch 10] Loss: 319.2981 | Val Acc: 0.5163


                                                                     


 Final Test Accuracy: 0.3598

 Sample Predictions:
  True : తైత్తిరీయ<eos>
  Pred : తైట్రీరీయా  false

  True : లీలావతి<eos>
  Pred : లేవతి<eos><eos><eos>  false

  True : కలదీ<eos>
  Pred : కలాది  false

  True : జిగురు<eos>
  Pred : జిగురు<eos>  true

  True : అన్నింటినీ<eos>
  Pred : అన్నింటిని<eos>  false

  True : ఛ<eos>
  Pred : <eos><eos>  false

  True : రామాయణంలో<eos>
  Pred : రామాయణంలో<eos>  true

  True : స్వరములు<eos>
  Pred : స్వరములు<eos>  true

  True : వ్యవహరించే<eos>
  Pred : వ్యవహరించే<eos>  true

  True : దిగబడిన<eos>
  Pred : దిగబడిన<eos>  true





In [13]:
# Define the sweep configuration
sweep_config = {
  'method': 'bayes',
  'metric':     {'name': 'val_accuracy', 'goal': 'maximize'},
  'early_terminate': {
      'type': 'hyperband', 'min_iter': 2, 'max_iter': 8, 's': 2
  },
  'parameters': {
    'embedding_dim': {'values': [16, 32, 64, 256]},
    'hidden_size':    {'values': [16, 32, 64, 256]},
    'encoder_layers': {'values': [1, 2, 3]},
    'decoder_layers': {'values': [1, 2, 3]},
    'cell_type':      {'values': ['RNN','GRU','LSTM']},
    'dropout_p':      {'values': [0.2, 0.3, 0.4]},
    'beam_width':     {'values': [1, 3, 5]},
    'teacher_forcing_ratio' : {'values': [0.0, 0.3, 0.5,0.7,1.0]}
  }
}

In [14]:
def train():
    with wandb.init():
        cfg = wandb.config
        run_name = (
            f"emb{cfg.embedding_dim}_hid{cfg.hidden_size}"
            f"_enc{cfg.encoder_layers}_dec{cfg.decoder_layers}"
            f"_{cfg.cell_type.lower()}_do{int(cfg.dropout_p*100)}"
            f"_beam{cfg.beam_width}_tf{int(cfg.teacher_forcing_ratio*100)}"
        )
        wandb.run.name = run_name  # Update name after init

        # Load data + build vocab
        train_loader, val_loader, test_loader, src_size, tgt_size, pad_idx, src_vocab, tgt_vocab = get_dataloaders(
            BASE_DIR, batch_size=64, build_vocab=True
        )
        IDX2CHAR_TGT = {idx: ch for ch, idx in tgt_vocab.items()}

        # Initialize model
        model = Seq2Seq(
            input_size=src_size,
            output_size=tgt_size,
            embedding_dim=cfg.embedding_dim,
            hidden_size=cfg.hidden_size,
            encoder_layers=cfg.encoder_layers,
            decoder_layers=cfg.decoder_layers,
            cell_type=cfg.cell_type,
            dropout_p=cfg.dropout_p,
            bidirectional_encoder=False
        ).to(DEVICE)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        criterion = nn.NLLLoss(ignore_index=pad_idx)
        stopper = EarlyStopper(patience=5)
        best_val_acc = 0.0

        for epoch in range(1, 11):
            model.train()
            total_loss = 0.0
            for src, tgt in tqdm(train_loader, desc=f"[Epoch {epoch}] Training", leave=False):
                src, tgt = src.to(DEVICE), tgt.to(DEVICE)
                optimizer.zero_grad()
                out = model(src, tgt, teacher_forcing_ratio=cfg.teacher_forcing_ratio)
                loss = criterion(out.view(-1, tgt_size), tgt.view(-1))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            # Validation
            model.eval()
            correct_seqs, total_seqs = 0, 0
            with torch.no_grad():
                for src, tgt in tqdm(val_loader, desc=f"[Epoch {epoch}] Validation", leave=False):
                    src, tgt = src.to(DEVICE), tgt.to(DEVICE)
                    out = model(src, tgt, teacher_forcing_ratio=0.0)
                    preds = out.argmax(dim=2)
                    for pred_seq, true_seq in zip(preds, tgt):
                        pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
                        true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
                        if torch.equal(pred_tokens, true_tokens):
                            correct_seqs += 1
                        total_seqs += 1

            val_acc = correct_seqs / total_seqs
            wandb.log({'epoch': epoch, 'val_accuracy': val_acc, 'train_loss': total_loss})

            print(f"[Epoch {epoch}] Train Loss: {total_loss:.4f} | Val Acc: {val_acc:.4f}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
            elif stopper.should_stop(val_acc):
                print("Early stopping triggered.")
                break

        # Final test evaluation
        model.eval()
        correct_seqs, total_seqs = 0, 0
        with torch.no_grad():
            for src, tgt in tqdm(test_loader, desc="Final Test Eval", leave=False):
                src, tgt = src.to(DEVICE), tgt.to(DEVICE)
                out = model(src, tgt, teacher_forcing_ratio=0.0)
                preds = out.argmax(dim=2)
                for pred_seq, true_seq in zip(preds, tgt):
                    pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
                    true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
                    if torch.equal(pred_tokens, true_tokens):
                        correct_seqs += 1
                    total_seqs += 1

        test_acc = correct_seqs / total_seqs
        print(f"\nFinal Test Accuracy: {test_acc:.4f}")
        wandb.log({'final_test_accuracy': test_acc})

In [None]:
sweep_id = wandb.sweep(sweep_config, project='da6401_assignment3')
wandb.agent(sweep_id, function=train, count=100)

In [None]:
CHAR2IDX_SRC = {
    "<pad>": 0,
    "<sos>": 1,
    "<eos>": 2,
    "<unk>": 3,
    **{c: i + 4 for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}
}
IDX2CHAR_SRC = {i: c for c, i in CHAR2IDX_SRC.items()}

# Load data, build vocabs, and create reverse target-char map
train_loader, val_loader, test_loader, src_size, tgt_size, pad_idx, src_vocab, tgt_vocab = get_dataloaders(
    BASE_DIR, batch_size=64, build_vocab=True
)

IDX2CHAR_TGT = {idx: ch for ch, idx in tgt_vocab.items()}  # Map decoder indices back to Telugu chars

# Model, optimizer, loss, early stopping

# parameters of best model
best_model = Seq2Seq(
    input_size=src_size,
    output_size=tgt_size,
    embedding_dim=64,
    hidden_size=256,
    encoder_layers=3,
    decoder_layers=1,
    cell_type='LSTM',  # or 'GRU' or 'RNN'
    dropout_p=0.4,
    bidirectional_encoder=False
).to(DEVICE)

optimizer = torch.optim.Adam(best_model.parameters(), lr=1e-3)
criterion = nn.NLLLoss(ignore_index=pad_idx)
stopper = EarlyStopper(patience=5)
best_val_acc = 0.0

# Training Loop

for epoch in range(1, 11):
    best_model.train()
    total_loss = 0.0
    for src, tgt in tqdm(train_loader, desc=f"[Epoch {epoch}] Training", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        optimizer.zero_grad()
        out = best_model(src, tgt, teacher_forcing_ratio=1.0)
        loss = criterion(out.view(-1, tgt_size), tgt.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation: sequence-level accuracy
    best_model.eval()
    correct_seqs, total_seqs = 0, 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc=f"[Epoch {epoch}] Validation", leave=False):
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            out = best_model(src, tgt, teacher_forcing_ratio=0.0)
            preds = out.argmax(dim=2)
            for pred_seq, true_seq in zip(preds, tgt):
                # Remove <sos> and padding tokens for comparison
                pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
                true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
                if torch.equal(pred_tokens, true_tokens):
                    correct_seqs += 1
                total_seqs += 1

    val_acc = correct_seqs / total_seqs
    print(f"[Epoch {epoch}] Loss: {total_loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Optionally save model checkpoint here
    elif stopper.should_stop(val_acc):
        print("Early stopping triggered.")
        break

# Final Test Evaluation
best_model.eval()
correct_seqs, total_seqs = 0, 0
all_preds, all_trues = [], []

with torch.no_grad():
    for src, tgt in tqdm(test_loader, desc="Final Test Eval", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        out = best_model(src, tgt, teacher_forcing_ratio=0.0)
        preds = out.argmax(dim=2)
        for pred_seq, true_seq in zip(preds, tgt):
            pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
            true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
            if torch.equal(pred_tokens, true_tokens):
                correct_seqs += 1
            total_seqs += 1

            all_preds.append(pred_tokens)
            all_trues.append(true_tokens)

test_acc = correct_seqs / total_seqs
print(f"\n Final Test Accuracy (Exact Word match): {test_acc:.4f}")

In [16]:
romanized_test_words = []
with open(os.path.join(BASE_DIR, 'te.translit.sampled.test.tsv'), "r", encoding="utf-8") as f:
    for line in f:
        telugu, roman, _ = line.strip().split()
        romanized_test_words.append(roman)

best_model.eval()
samples = []

with torch.no_grad():
    for i, (src, tgt) in enumerate(test_loader):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        out = best_model(src, tgt, teacher_forcing_ratio=0.0)
        preds = out.argmax(dim=2)

        for b in range(src.shape[0]):
            pred = ''.join(
                IDX2CHAR_TGT[idx.item()] 
                for idx in preds[b][1:] 
                if idx.item() not in (pad_idx,)
            )
            true = ''.join(
                IDX2CHAR_TGT[idx.item()] 
                for idx in tgt[b][1:] 
                if idx.item() not in (pad_idx,)
            )

            romanized = romanized_test_words[i * src.shape[0] + b]

            samples.append({
                'Input': romanized,
                'True Telugu'             : true,
                'Predicted Telugu'        : pred
            })

# Sample and show table
subset = random.sample(samples, min(10, len(samples)))
df = pd.DataFrame(subset)
print(df.to_markdown(index=False))

| Input            | True Telugu   | Predicted Telugu   |
|:-----------------|:--------------|:-------------------|
| gorinka          | గోరింక<eos>      | గొరింక్                |
| palakadaaniki    | పలకడానికి<eos>   | పలకడానికి<eos>        |
| lagnastha        | లగ్నస్థ<eos>    | లగ్నస్థ<eos>         |
| jaanapadha       | జానపద<eos>     | జానపద్               |
| poenichchadu     | పోనిచ్చాడు<eos>    | పోనిచ్చాడు<eos>         |
| uttharayanam     | ఉత్తరాయణము<eos>  | ఉత్తరాయణం<eos><eos>   |
| kemerala         | కెమెరాల<eos>     | కేమలరా<eos><eos>     |
| sambhaashanalalo | సంభాషణలలో<eos>   | సంభాషణలలో<eos>        |
| caryalatoo       | చర్యలతో<eos>    | చర్యలతో<eos>         |
| thaati           | తాటి<eos>       | తాతి<eos>            |


In [17]:
def highlight_pred(row):
    """
    Returns a list of CSS styles, one per column, 
    coloring the 'Predicted Telugu' cell green if correct else red.
    """
    styles = [''] * len(row)
    # Find the index of the Predicted column
    pred_idx = list(row.index).index('Predicted Telugu')
    if row['Predicted Telugu'] == row['True Telugu']:
        styles[pred_idx] = 'background-color: #c8e6c9; font-weight: bold;'  # light green
    else:
        styles[pred_idx] = 'background-color: #f8d7da; font-weight: bold;'  # light red
    return styles

# Apply to a random sample of 10 rows
subset = df.sample(n=min(10, len(df))).reset_index(drop=True)

styled = (
    subset.style
          .apply(highlight_pred, axis=1)
          .set_table_styles([
              # Center all text
              {'selector': 'td, th',
               'props': [('text-align', 'center'), ('padding', '6px')]},
              # Header style
              {'selector': 'th',
               'props': [('background-color', '#4F81BD'),
                         ('color', 'white'),
                         ('font-weight', 'bold'),
                         ('padding', '8px')]}
          ])
          .set_caption(" Sample Transliteration Predictions (Green = Correct, Red = Wrong) ")
)

# Display in a Jupyter/HTML context
display(styled)

Unnamed: 0,Input,True Telugu,Predicted Telugu
0,sambhaashanalalo,సంభాషణలలో,సంభాషణలలో
1,poenichchadu,పోనిచ్చాడు,పోనిచ్చాడు
2,jaanapadha,జానపద,జానపద్
3,uttharayanam,ఉత్తరాయణము,ఉత్తరాయణం
4,gorinka,గోరింక,గొరింక్
5,thaati,తాటి,తాతి
6,kemerala,కెమెరాల,కేమలరా
7,lagnastha,లగ్నస్థ,లగ్నస్థ
8,palakadaaniki,పలకడానికి,పలకడానికి
9,caryalatoo,చర్యలతో,చర్యలతో


In [20]:
# Ensure the output folder exists
os.makedirs("predictions_attention", exist_ok=True)

df1 = pd.DataFrame(samples)
df1.to_csv("predictions_attention/predictions_attention.csv", index=False, encoding="utf-8-sig")

print(" Saved all predictions to predictions_attention/predictions_attention.csv")

 Saved all predictions to predictions_attention/predictions_attention.csv


In [46]:
wandb.init(project="da6401_assignment3")

In [53]:
from matplotlib import rcParams

# Set the font family to Noto Sans Telugu
rcParams['font.family'] = ['Noto Sans', 'Noto Sans Telugu', 'sans-serif']

In [54]:
# Ensure output directory exists
os.makedirs("predictions_vanilla/attention_maps", exist_ok=True)

# Put model in eval mode
best_model.eval()

wandb_images = []
num_samples = 12
samples_collected = 0

with torch.no_grad():
    for src, tgt in test_loader:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        out, attn_weights = best_model(src, tgt, return_attention=True)

        for b in range(src.size(0)):
            # Decode source and target tokens
            src_tokens = [IDX2CHAR_SRC[i.item()] for i in src[b] if i.item() != pad_idx]
            tgt_tokens = [IDX2CHAR_TGT[i.item()] for i in tgt[b][1:] if i.item() != pad_idx]

            # Get predicted tokens
            pred_tokens = out.argmax(dim=2)[b][1:len(tgt_tokens)+1]
            pred_chars = [IDX2CHAR_TGT[i.item()] for i in pred_tokens]

            # Attention weights: [pred_len, src_len]
            attn = attn_weights[b][:len(pred_chars), :len(src_tokens)]

            # Plot heatmap
            fig, ax = plt.subplots(figsize=(6, 4))
            sns.heatmap(attn.cpu().numpy(), 
                        xticklabels=src_tokens, 
                        yticklabels=pred_chars, 
                        cmap='viridis', 
                        cbar=False, 
                        annot=False, 
                        linewidths=0.5, 
                        ax=ax)
            ax.set_xlabel("Input")
            ax.set_ylabel("Predicted Output (Telugu)")
            ax.set_title(f"Attention Heatmap {samples_collected + 1}")
            plt.tight_layout()

            # Save figure
            path = f"predictions_vanilla/attention_maps/sample_{samples_collected + 1}.png"
            fig.savefig(path)
            plt.close(fig)

            # Log to wandb
            wandb_images.append(wandb.Image(path, caption=f"Sample {samples_collected + 1}"))

            samples_collected += 1
            if samples_collected >= num_samples:
                break
        if samples_collected >= num_samples:
            break

# Optionally log all images to wandb
wandb.log({"attention_maps": wandb_images})


In [55]:
wandb.finish()

In [75]:
wandb.init(project="da6401_assignment3")

In [76]:
best_model.eval()
all_samples = []

# First, collect all model outputs
with torch.no_grad():
    for i, (src, tgt) in enumerate(test_loader):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        out, attn_weights = best_model(src, tgt, return_attention=True)
        batch_size = src.size(0)

        for b in range(batch_size):
            romanized = romanized_test_words[i * batch_size + b]
            src_tokens = list(romanized)
            tgt_tokens = [IDX2CHAR_TGT[idx.item()] for idx in tgt[b][1:] if idx.item() != pad_idx]
            pred_tokens = out.argmax(dim=2)[b][1:len(tgt_tokens)+1]
            pred_chars = [IDX2CHAR_TGT[idx.item()] for idx in pred_tokens]

            attn = attn_weights[b][:len(pred_chars), :len(src_tokens)].cpu().numpy().tolist()

            all_samples.append((src_tokens, pred_chars, attn))

# Select 10 random samples
random_samples = random.sample(all_samples, 10)

# Build HTML blocks
html_blocks = []
for sample_count, (src_tokens, pred_chars, attn) in enumerate(random_samples):
    input_tokens_js = json.dumps(src_tokens, ensure_ascii=False)
    output_tokens_js = json.dumps(pred_chars, ensure_ascii=False)
    attention_js = json.dumps(attn)

    html_block = f"""
    <div style="margin-bottom: 50px;">
      <h2>Sample {sample_count + 1}</h2>
      <div><strong>Input (English):</strong></div>
      <div id="input-tokens-{sample_count}"></div>
      <div><strong>Predicted Output (Telugu):</strong></div>
      <div id="output-tokens-{sample_count}"></div>
      <script>
        const inputTokens_{sample_count} = {input_tokens_js};
        const outputTokens_{sample_count} = {output_tokens_js};
        const attention_{sample_count} = {attention_js};

        const inputDiv_{sample_count} = d3.select("#input-tokens-{sample_count}");
        const outputDiv_{sample_count} = d3.select("#output-tokens-{sample_count}");

        inputTokens_{sample_count}.forEach((token, i) => {{
          inputDiv_{sample_count}.append("span")
            .attr("class", "token input")
            .attr("id", "input-{sample_count}-" + i)
            .text(token);
        }});

        outputTokens_{sample_count}.forEach((token, i) => {{
          outputDiv_{sample_count}.append("span")
            .attr("class", "token output")
            .text(token)
            .on("mouseover", () => {{
              d3.selectAll(".token.input").style("background-color", "#fff");
              attention_{sample_count}[i].forEach((score, j) => {{
                const color = d3.interpolateOranges(score);
                d3.select("#input-{sample_count}-" + j).style("background-color", color);
              }});
            }})
            .on("mouseout", () => {{
              d3.selectAll(".token.input").style("background-color", "#fff");
            }});
        }});
      </script>
    </div>
    """
    html_blocks.append(html_block)

# Full HTML document
full_html = f"""
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <title>Attention Visualizations</title>
  <script src="https://d3js.org/d3.v7.min.js"></script>
  <style>
    body {{ font-family: Arial, sans-serif; margin: 30px; }}
    .token {{
      display: inline-block;
      padding: 8px 12px;
      margin: 3px;
      border-radius: 5px;
      border: 1px solid #ccc;
      font-size: 20px;
      cursor: pointer;
      user-select: none;
      transition: background-color 0.3s;
    }}
  </style>
</head>
<body>
  <h1>Random Attention Visualizations (10 Samples)</h1>
  {''.join(html_blocks)}
</body>
</html>
"""

# Log to WandB
wandb.log({"attention_visualizations_random_10": wandb.Html(full_html)})

In [77]:
wandb.finish()