In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
import os
import wandb
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
import pandas as pd

In [4]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_key")

os.environ['WANDB_API_KEY'] = secret_value_0

In [5]:
# DEVICE
DEVICE   = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BASE_DIR = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons' # Change 'te' → desired language

In [6]:
class CharacterEmbedding(nn.Module):
    # Creating an embedding layer that maps input character indices to embedding vectors.
    # input_size: number of unique characters (vocabulary size)
    # embedding_dim: size of each embedding vector
    def __init__(self, input_size, embedding_dim):
        super(CharacterEmbedding, self).__init__()
        self.embedding = nn.Embedding(input_size, embedding_dim)

    # Returns corresponding embedding vectors of shape (batch_size, seq_length, embedding_dim)
    def forward(self, input_seq):
        # input_seq: a tensor of character indices, typically of shape (batch_size, seq_length)
        return self.embedding(input_seq)

In [7]:
# EncoderRNN transforms sequences of token IDs into contextual hidden states
# Supports GRU, LSTM, or vanilla RNN cells
# input_size: number of unique tokens
# hidden_size: size of the RNN hidden state
# embedding_dim: size of token embedding vectors
# num_layers: number of stacked recurrent layers
# cell_type: 'GRU', 'LSTM', or 'RNN'
# dropout_p: dropout probability between RNN layers (only if num_layers > 1)
# bidirectional: whether to run the RNN in both forward and backward directions
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, embedding_dim, num_layers=1,cell_type='GRU', dropout_p=0.1, bidirectional=False):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cell_type = cell_type
        self.bidirectional = bidirectional
        self.directions = 2 if bidirectional else 1
        
        # Embedding layer
        self.embedding = nn.Embedding(input_size, embedding_dim)
        
        # Dropout before the RNN (applied to embeddings)
        self.dropout = nn.Dropout(dropout_p)
        dropout_p = dropout_p if num_layers > 1 else 0
        
        # RNN layer
        if cell_type == 'GRU':
            self.rnn = nn.GRU(embedding_dim, hidden_size, num_layers,dropout=dropout_p,bidirectional=bidirectional, batch_first=True)
        elif cell_type == 'LSTM':
            self.rnn = nn.LSTM(embedding_dim, hidden_size, num_layers,dropout=dropout_p,bidirectional=bidirectional, batch_first=True)
        else:  # Default to RNN
            self.rnn = nn.RNN(embedding_dim, hidden_size, num_layers,dropout=dropout_p,bidirectional=bidirectional, nonlinearity='tanh', batch_first=True)

    # Forward pass through the encoder
    def forward(self, input_seq):
        # Input shape: [batch_size, seq_len]
        batch_size = input_seq.size(0)
        
        # Convert indices to embeddings and apply dropout to embeddings
        embedded = self.embedding(input_seq)  # [batch_size, seq_len, embedding_dim]
        embedded = self.dropout(embedded)
        
        # Pass through RNN
        outputs, hidden = self.rnn(embedded)
        
        return outputs, hidden

In [8]:
# DecoderRNN generates target sequences one token at a time
class DecoderRNN(nn.Module):
    def __init__(self, output_size, hidden_size, embedding_dim, num_layers=1, 
                 cell_type='GRU', dropout_p=0.1):
        
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.cell_type = cell_type
        
        # Embedding layer for target characters
        self.embedding = nn.Embedding(output_size, embedding_dim)
        
        # Dropout applied before RNN
        self.dropout = nn.Dropout(dropout_p)

        dropout_p = dropout_p if num_layers > 1 else 0
        
        # RNN layer
        if cell_type == 'GRU':
            self.rnn = nn.GRU(embedding_dim, hidden_size, num_layers,dropout=dropout_p,batch_first=True)
        elif cell_type == 'LSTM':
            self.rnn = nn.LSTM(embedding_dim, hidden_size, num_layers,dropout=dropout_p,batch_first=True)
        else:  # Default to RNN
            self.rnn = nn.RNN(embedding_dim, hidden_size, num_layers,dropout=dropout_p,nonlinearity='tanh', batch_first=True)
            
        # Output projection layer with dropout
        self.out_dropout = nn.Dropout(dropout_p)
        self.out = nn.Linear(hidden_size, output_size)

    # Forward pass for a single decoding step
    def forward(self, input_char, hidden):
        # Convert input to embeddings and apply dropout
        embedded = self.embedding(input_char)  # [batch_size, 1, embedding_dim]
        embedded = self.dropout(embedded)
        
        # Pass through RNN
        output, hidden = self.rnn(embedded, hidden)
        
        # Apply dropout before prediction layer
        output = self.out_dropout(output)
        output = self.out(output[:, 0, :])
        
        return F.log_softmax(output, dim=1), hidden

In [9]:
#  Perform beam search decoding with the trained seq2seq model.
def beam_search_decode(model, src, sos_idx, eos_idx, max_len=30, beam_width=3, device='gpu'):
    model.eval()
    with torch.no_grad():
        # Encode input
        encoder_outputs, encoder_hidden = model.encoder(src)

        # Prepare initial decoder hidden state
        if model.bidirectional:
            if model.cell_type == 'LSTM':
                h_n, c_n = encoder_hidden
                h_dec = torch.zeros(model.decoder.num_layers, 1, model.decoder.hidden_size).to(device)
                c_dec = torch.zeros(model.decoder.num_layers, 1, model.decoder.hidden_size).to(device)
                for layer in range(model.encoder.num_layers):
                    h_combined = torch.cat((h_n[2*layer], h_n[2*layer+1]), dim=1)
                    c_combined = torch.cat((c_n[2*layer], c_n[2*layer+1]), dim=1)
                    h_dec[layer] = model.hidden_transform(h_combined)
                    c_dec[layer] = model.hidden_transform(c_combined)
                decoder_hidden = (h_dec, c_dec)
            else:
                decoder_hidden = torch.zeros(model.decoder.num_layers, 1, model.decoder.hidden_size).to(device)
                for layer in range(model.encoder.num_layers):
                    h_combined = torch.cat((encoder_hidden[2*layer], encoder_hidden[2*layer+1]), dim=1)
                    decoder_hidden[layer] = model.hidden_transform(h_combined)
        else:
            decoder_hidden = encoder_hidden

        # Beam search initialization
        beams = [([sos_idx], 0.0, decoder_hidden)]  # (sequence, cumulative log-prob, hidden)
        completed = []

        for _ in range(max_len):
            new_beams = []
            for seq, score, hidden in beams:
                if seq[-1] == eos_idx:
                    completed.append((seq, score))
                    continue
                input_char = torch.tensor([[seq[-1]]], device=device)
                output, hidden_new = model.decoder(input_char, hidden)
                log_probs = output.squeeze(0)  # [output_size]
                topk_log_probs, topk_indices = torch.topk(log_probs, beam_width)
                for k in range(beam_width):
                    next_seq = seq + [topk_indices[k].item()]
                    next_score = score + topk_log_probs[k].item()
                    new_beams.append((next_seq, next_score, hidden_new))
            # Keep top beam_width beams
            beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
            if not beams:
                break

        # Add any remaining beams
        completed += [(seq, score) for seq, score, _ in beams if seq[-1] == eos_idx]
        # If none ended with <eos>, just take the best
        if not completed:
            completed = beams

        # Sort by score
        completed = sorted(completed, key=lambda x: x[1], reverse=True)
        return completed

In [10]:
# Seq2Seq implements an Encoder and Decoder for end-to-end sequence-to-sequence modeling
# input_size: size of source vocabulary
# output_size: size of target vocabulary
# embedding_dim: dimension of embeddings in both encoder and decoder
# hidden_size: size of hidden states in encoder and decoder (must match for vanilla seq2seq)
# encoder_layers / decoder_layers: number of stacked RNN layers
# cell_type: 'GRU', 'LSTM', or 'RNN'
# dropout_p: dropout probability for embeddings and RNN layers
# bidirectional_encoder: if True, runs encoder bidirectionally and transforms hidden state

class Seq2Seq(nn.Module):
    def __init__(self, input_size, output_size, embedding_dim=256, hidden_size=256,
                 encoder_layers=1, decoder_layers=1, cell_type='GRU', dropout_p=0.2,
                 bidirectional_encoder=False):
        super(Seq2Seq, self).__init__()
        
        # Create encoder RNN
        self.encoder = EncoderRNN(input_size, hidden_size, embedding_dim,
                                  num_layers=encoder_layers, cell_type=cell_type,
                                  dropout_p=dropout_p, bidirectional=bidirectional_encoder)
        
        self.bidirectional = bidirectional_encoder
        directions = 2 if bidirectional_encoder else 1
        
        # If bidirectional encoder, need a linear layer to transform hidden state
        if bidirectional_encoder:
            self.hidden_transform = nn.Linear(hidden_size * directions, hidden_size)
        
        # Create decoder RNN
        self.decoder = DecoderRNN(output_size, hidden_size, embedding_dim,
                                 num_layers=decoder_layers, cell_type=cell_type,
                                 dropout_p=dropout_p)
        
        self.cell_type = cell_type

    def _match_decoder_layers(self, hidden, batch_size):
        """Ensures hidden state matches decoder layers by trimming or padding."""
        if hidden.size(0) > self.decoder.num_layers:
            return hidden[:self.decoder.num_layers]
        elif hidden.size(0) < self.decoder.num_layers:
            pad = torch.zeros(self.decoder.num_layers - hidden.size(0),
                              batch_size, self.decoder.hidden_size,
                              device=hidden.device)
            return torch.cat([hidden, pad], dim=0)
        else:
            return hidden
        
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.size(0)
        trg_len = trg.size(1)
        output_size = self.decoder.output_size
        
        # Tensor to store decoder outputs (logits or log-probs)
        outputs = torch.zeros(batch_size, trg_len, output_size).to(src.device)
        
        # Encode the source sequence
        encoder_outputs, encoder_hidden = self.encoder(src)
        decoder_hidden = None 
        
        # Prepare initial hidden state for decoder
        if self.bidirectional:
            # Bidirectional encoder returns hidden states with doubled layers * 2 directions
            
            if self.cell_type == 'LSTM':
                # For LSTM, hidden state is a tuple (h_n, c_n)
                h_n, c_n = encoder_hidden
                
                # Initialize decoder hidden states
                h_dec = torch.zeros(self.decoder.num_layers, batch_size, self.decoder.hidden_size).to(src.device)
                c_dec = torch.zeros(self.decoder.num_layers, batch_size, self.decoder.hidden_size).to(src.device)
                
                # For each decoder layer, combine corresponding forward and backward encoder layers
                for layer in range(self.decoder.num_layers):
                    # Clamp to max encoder layers index to avoid index errors if decoder has more layers
                    enc_layer = min(layer, self.encoder.num_layers - 1)
                    
                    # Concatenate forward and backward hidden states from encoder for this layer
                    h_combined = torch.cat((h_n[2 * enc_layer], h_n[2 * enc_layer + 1]), dim=1)
                    c_combined = torch.cat((c_n[2 * enc_layer], c_n[2 * enc_layer + 1]), dim=1)
                    
                    # Transform concatenated states to decoder hidden size
                    h_dec[layer] = self.hidden_transform(h_combined)
                    c_dec[layer] = self.hidden_transform(c_combined)
                
                decoder_hidden = (h_dec, c_dec)
            
            else:
                # For GRU or vanilla RNN (hidden state is a single tensor)
                h_n = encoder_hidden
                h_dec = torch.zeros(self.decoder.num_layers, batch_size, self.decoder.hidden_size).to(src.device)
                
                for layer in range(self.decoder.num_layers):
                    enc_layer = min(layer, self.encoder.num_layers - 1)
                    h_combined = torch.cat((h_n[2 * enc_layer], h_n[2 * enc_layer + 1]), dim=1)
                    h_dec[layer] = self.hidden_transform(h_combined)
                
                decoder_hidden = h_dec
        
        else:
            if self.cell_type == "LSTM":
                h, c = encoder_hidden
                decoder_hidden = (
                    self._match_decoder_layers(h, batch_size),
                    self._match_decoder_layers(c, batch_size)
                )
            else:
                decoder_hidden = self._match_decoder_layers(encoder_hidden, batch_size)
        
        # First input to decoder is <sos> token from target
        input_char = trg[:, 0].unsqueeze(1)  # shape: (batch_size, 1)
        
        # Decode one token at a time
        for t in range(1, trg_len):
            output, decoder_hidden = self.decoder(input_char, decoder_hidden)
            outputs[:, t, :] = output
            
            # Decide if teacher forcing should be used
            teacher_force = random.random() < teacher_forcing_ratio
            
            # Get highest probability token from output
            top1 = output.argmax(1).unsqueeze(1)
            
            # Next input is either true target token or predicted token
            input_char = trg[:, t].unsqueeze(1) if teacher_force else top1
        
        return outputs

In [11]:
class LexiconDataset(Dataset):
    def __init__(self, path, src_vocab=None, tgt_vocab=None, build_vocab=False):
        self.pairs = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                cols = line.strip().split('\t')
                if len(cols) < 2:
                    continue
                tgt, src = cols[0], cols[1]  # Telugu is first column, romanized second
                self.pairs.append((src, tgt)) # src = romanized, tgt = telugu

        if build_vocab:
            self.src_vocab = {'<pad>':0, '<sos>':1, '<eos>':2, '<unk>':3}
            self.tgt_vocab = {'<pad>':0, '<sos>':1, '<eos>':2, '<unk>':3}
            for rom, dev in self.pairs:
                for c in rom:
                    self.src_vocab.setdefault(c, len(self.src_vocab))
                for c in dev:
                    self.tgt_vocab.setdefault(c, len(self.tgt_vocab))
        else:
            assert src_vocab and tgt_vocab, "Must provide vocabs if not building."
            self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        rom, dev = self.pairs[idx]
        src_idxs = [self.src_vocab.get(c, self.src_vocab['<unk>']) for c in rom]
        tgt_idxs = [self.tgt_vocab['<sos>']] + [self.tgt_vocab.get(c, self.tgt_vocab['<unk>']) for c in dev] + [self.tgt_vocab['<eos>']]
        return torch.tensor(src_idxs, dtype=torch.long), torch.tensor(tgt_idxs, dtype=torch.long)

def collate_fn(batch):
    """
    Pads all src/tgt sequences in the batch to the max length.
    Returns:
      padded_src: (batch_size, max_src_len)
      padded_tgt: (batch_size, max_tgt_len)
    """
    srcs, tgts = zip(*batch)
    max_src = max(len(s) for s in srcs)
    max_tgt = max(len(t) for t in tgts)

    padded_src = torch.full((len(batch), max_src), 0, dtype=torch.long)
    padded_tgt = torch.full((len(batch), max_tgt), 0, dtype=torch.long)
    for i, (s, t) in enumerate(zip(srcs, tgts)):
        padded_src[i, :len(s)] = s
        padded_tgt[i, :len(t)] = t

    return padded_src, padded_tgt

def get_dataloaders(base_dir, batch_size, build_vocab=False):
    """
    Returns:
      train_loader, val_loader, test_loader,
      src_vocab_size, tgt_vocab_size, pad_index, src_vocab, tgt_vocab
    """
    train_p = os.path.join(base_dir, 'te.translit.sampled.train.tsv')
    dev_p   = os.path.join(base_dir, 'te.translit.sampled.dev.tsv')
    test_p  = os.path.join(base_dir, 'te.translit.sampled.test.tsv')

    train_ds = LexiconDataset(train_p, build_vocab=build_vocab)
    src_vocab, tgt_vocab = train_ds.src_vocab, train_ds.tgt_vocab
    val_ds   = LexiconDataset(dev_p,  src_vocab, tgt_vocab)
    test_ds  = LexiconDataset(test_p, src_vocab, tgt_vocab)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader  = DataLoader(test_ds,  batch_size=1,            shuffle=False, collate_fn=collate_fn)

    return (train_loader, val_loader, test_loader,
            len(src_vocab), len(tgt_vocab), src_vocab['<pad>'],
            src_vocab, tgt_vocab)

In [12]:
class EarlyStopper:
    """Stops a run if the monitored metric doesn’t improve for `patience` steps."""
    def __init__(self, patience=5, min_delta=1e-4):
        self.patience, self.min_delta = patience, min_delta
        self.counter, self.best = 0, None

    def should_stop(self, current):
        if self.best is None or current > self.best + self.min_delta:
            self.best, self.counter = current, 0
        else:
            self.counter += 1
        return self.counter >= self.patience

In [11]:
CHAR2IDX_SRC = {
    "<pad>": 0,
    "<sos>": 1,
    "<eos>": 2,
    "<unk>": 3,
    **{c: i + 4 for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}
}
IDX2CHAR_SRC = {i: c for c, i in CHAR2IDX_SRC.items()}

# Load data, build vocabs, and create reverse target-char map
train_loader, val_loader, test_loader, src_size, tgt_size, pad_idx, src_vocab, tgt_vocab = get_dataloaders(
    BASE_DIR, batch_size=64, build_vocab=True
)

IDX2CHAR_TGT = {idx: ch for ch, idx in tgt_vocab.items()}  # Map decoder indices back to Telugu chars

# Model, optimizer, loss, early stopping

model = Seq2Seq(
    input_size=src_size,
    output_size=tgt_size,
    embedding_dim=64,
    hidden_size=128,
    encoder_layers=1,
    decoder_layers=3,
    cell_type='GRU',  # or 'GRU' or 'RNN'
    dropout_p=0.3,
    bidirectional_encoder=False
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.NLLLoss(ignore_index=pad_idx)
stopper = EarlyStopper(patience=5)
best_val_acc = 0.0

# Training Loop

for epoch in range(1, 11):
    model.train()
    total_loss = 0.0
    for src, tgt in tqdm(train_loader, desc=f"[Epoch {epoch}] Training", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        optimizer.zero_grad()
        out = model(src, tgt, teacher_forcing_ratio=0.7)
        loss = criterion(out.view(-1, tgt_size), tgt.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation: sequence-level accuracy
    model.eval()
    correct_seqs, total_seqs = 0, 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc=f"[Epoch {epoch}] Validation", leave=False):
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            out = model(src, tgt, teacher_forcing_ratio=0.0)
            preds = out.argmax(dim=2)
            for pred_seq, true_seq in zip(preds, tgt):
                # Remove <sos> and padding tokens for comparison
                pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
                true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
                if torch.equal(pred_tokens, true_tokens):
                    correct_seqs += 1
                total_seqs += 1

    val_acc = correct_seqs / total_seqs
    print(f"[Epoch {epoch}] Loss: {total_loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Optionally save model checkpoint here
    elif stopper.should_stop(val_acc):
        print("Early stopping triggered.")
        break

# Final Test Evaluation
model.eval()
correct_seqs, total_seqs = 0, 0
all_preds, all_trues = [], []

with torch.no_grad():
    for src, tgt in tqdm(test_loader, desc="Final Test Eval", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        out = model(src, tgt, teacher_forcing_ratio=0.0)
        preds = out.argmax(dim=2)
        for pred_seq, true_seq in zip(preds, tgt):
            pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
            true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
            if torch.equal(pred_tokens, true_tokens):
                correct_seqs += 1
            total_seqs += 1

            all_preds.append(pred_tokens)
            all_trues.append(true_tokens)

test_acc = correct_seqs / total_seqs
print(f"\n Final Test Accuracy: {test_acc:.4f}")


# Sample Predictions in Telugu
print("\n Sample Predictions:")
sample_indices = random.sample(range(len(all_preds)), min(10, len(all_preds)))
for idx in sample_indices:
    pred_str = ''.join([IDX2CHAR_TGT[token.item()] for token in all_preds[idx]])
    true_str = ''.join([IDX2CHAR_TGT[token.item()] for token in all_trues[idx]])
    correctness = "true" if pred_str == true_str else "false"
    print(f"  True : {true_str}\n  Pred : {pred_str}  {correctness}\n")

                                                                     

[Epoch 1] Loss: 2245.6370 | Val Acc: 0.0002


                                                                     

[Epoch 2] Loss: 1503.1435 | Val Acc: 0.0707


                                                                     

[Epoch 3] Loss: 1035.8125 | Val Acc: 0.1863


                                                                     

[Epoch 4] Loss: 817.9509 | Val Acc: 0.2708


                                                                     

[Epoch 5] Loss: 702.1377 | Val Acc: 0.3113


                                                                     

[Epoch 6] Loss: 626.4294 | Val Acc: 0.3532


                                                                     

[Epoch 7] Loss: 577.1637 | Val Acc: 0.3774


                                                                     

[Epoch 8] Loss: 536.8485 | Val Acc: 0.4065


                                                                     

[Epoch 9] Loss: 505.8226 | Val Acc: 0.4186


                                                                      

[Epoch 10] Loss: 481.2055 | Val Acc: 0.4209


                                                                     


 Final Test Accuracy: 0.3108

 Sample Predictions:
  True : అవసరము<eos>
  Pred : అవసరము<eos>  true

  True : తెలంగాణకు<eos>
  Pred : తెలంగానకు<eos>  false

  True : పోల్<eos>
  Pred : పోలు<eos>  false

  True : బలహీనత<eos>
  Pred : బలహినత<eos>  false

  True : ఆస్ట్రేలియా<eos>
  Pred : అస్టర్లియా<eos><eos>  false

  True : ఓటర్ల<eos>
  Pred : ఓటర్ల<eos>  true

  True : మీరూ<eos>
  Pred : మీరుల  false

  True : జాతీయోద్యమానికి<eos>
  Pred : జతీయోయాయమిని<eos><eos><eos><eos>  false

  True : కధలుగా<eos>
  Pred : కథలుగా<eos>  false

  True : చెందినదే<eos>
  Pred : చెందినదే<eos>  true





In [13]:
# Define the sweep configuration
sweep_config = {
  'method': 'bayes',
  'metric':     {'name': 'val_accuracy', 'goal': 'maximize'},
  'early_terminate': {
      'type': 'hyperband', 'min_iter': 2, 'max_iter': 8, 's': 2
  },
  'parameters': {
    'embedding_dim': {'values': [16, 32, 64, 256]},
    'hidden_size':    {'values': [16, 32, 64, 256]},
    'encoder_layers': {'values': [1, 2, 3]},
    'decoder_layers': {'values': [1, 2, 3]},
    'cell_type':      {'values': ['RNN','GRU','LSTM']},
    'dropout_p':      {'values': [0.2, 0.3, 0.4]},
    'beam_width':     {'values': [1, 3, 5]},
    'teacher_forcing_ratio' : {'values': [0.0, 0.3, 0.5,0.7,1.0]}
  }
}

In [14]:
def train():
    with wandb.init():
        cfg = wandb.config
        run_name = (
            f"emb{cfg.embedding_dim}_hid{cfg.hidden_size}"
            f"_enc{cfg.encoder_layers}_dec{cfg.decoder_layers}"
            f"_{cfg.cell_type.lower()}_do{int(cfg.dropout_p*100)}"
            f"_beam{cfg.beam_width}_tf{int(cfg.teacher_forcing_ratio*100)}"
        )
        wandb.run.name = run_name  # Update name after init

        # Load data + build vocab
        train_loader, val_loader, test_loader, src_size, tgt_size, pad_idx, src_vocab, tgt_vocab = get_dataloaders(
            BASE_DIR, batch_size=64, build_vocab=True
        )
        IDX2CHAR_TGT = {idx: ch for ch, idx in tgt_vocab.items()}

        # Initialize model
        model = Seq2Seq(
            input_size=src_size,
            output_size=tgt_size,
            embedding_dim=cfg.embedding_dim,
            hidden_size=cfg.hidden_size,
            encoder_layers=cfg.encoder_layers,
            decoder_layers=cfg.decoder_layers,
            cell_type=cfg.cell_type,
            dropout_p=cfg.dropout_p,
            bidirectional_encoder=False
        ).to(DEVICE)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        criterion = nn.NLLLoss(ignore_index=pad_idx)
        stopper = EarlyStopper(patience=5)
        best_val_acc = 0.0

        for epoch in range(1, 11):
            model.train()
            total_loss = 0.0
            for src, tgt in tqdm(train_loader, desc=f"[Epoch {epoch}] Training", leave=False):
                src, tgt = src.to(DEVICE), tgt.to(DEVICE)
                optimizer.zero_grad()
                out = model(src, tgt, teacher_forcing_ratio=cfg.teacher_forcing_ratio)
                loss = criterion(out.view(-1, tgt_size), tgt.view(-1))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            # Validation
            model.eval()
            correct_seqs, total_seqs = 0, 0
            with torch.no_grad():
                for src, tgt in tqdm(val_loader, desc=f"[Epoch {epoch}] Validation", leave=False):
                    src, tgt = src.to(DEVICE), tgt.to(DEVICE)
                    out = model(src, tgt, teacher_forcing_ratio=0.0)
                    preds = out.argmax(dim=2)
                    for pred_seq, true_seq in zip(preds, tgt):
                        pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
                        true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
                        if torch.equal(pred_tokens, true_tokens):
                            correct_seqs += 1
                        total_seqs += 1

            val_acc = correct_seqs / total_seqs
            wandb.log({'epoch': epoch, 'val_accuracy': val_acc, 'train_loss': total_loss})

            print(f"[Epoch {epoch}] Train Loss: {total_loss:.4f} | Val Acc: {val_acc:.4f}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
            elif stopper.should_stop(val_acc):
                print("Early stopping triggered.")
                break

        # Final test evaluation
        model.eval()
        correct_seqs, total_seqs = 0, 0
        with torch.no_grad():
            for src, tgt in tqdm(test_loader, desc="Final Test Eval", leave=False):
                src, tgt = src.to(DEVICE), tgt.to(DEVICE)
                out = model(src, tgt, teacher_forcing_ratio=0.0)
                preds = out.argmax(dim=2)
                for pred_seq, true_seq in zip(preds, tgt):
                    pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
                    true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
                    if torch.equal(pred_tokens, true_tokens):
                        correct_seqs += 1
                    total_seqs += 1

        test_acc = correct_seqs / total_seqs
        print(f"\nFinal Test Accuracy: {test_acc:.4f}")
        wandb.log({'final_test_accuracy': test_acc})

In [None]:
sweep_id = wandb.sweep(sweep_config, project='da6401_assignment3')
wandb.agent(sweep_id, function=train, count=100)

# Best Model 

In [None]:
CHAR2IDX_SRC = {
    "<pad>": 0,
    "<sos>": 1,
    "<eos>": 2,
    "<unk>": 3,
    **{c: i + 4 for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}
}
IDX2CHAR_SRC = {i: c for c, i in CHAR2IDX_SRC.items()}

# Load data, build vocabs, and create reverse target-char map
train_loader, val_loader, test_loader, src_size, tgt_size, pad_idx, src_vocab, tgt_vocab = get_dataloaders(
    BASE_DIR, batch_size=64, build_vocab=True
)

IDX2CHAR_TGT = {idx: ch for ch, idx in tgt_vocab.items()}  # Map decoder indices back to Telugu chars

# Model, optimizer, loss, early stopping

# parameters of best model
best_model = Seq2Seq(
    input_size=src_size,
    output_size=tgt_size,
    embedding_dim=64,
    hidden_size=256,
    encoder_layers=3,
    decoder_layers=3,
    cell_type='LSTM',  # or 'GRU' or 'RNN'
    dropout_p=0.2,
    bidirectional_encoder=False
).to(DEVICE)

optimizer = torch.optim.Adam(best_model.parameters(), lr=1e-3)
criterion = nn.NLLLoss(ignore_index=pad_idx)
stopper = EarlyStopper(patience=5)
best_val_acc = 0.0

# Training Loop

for epoch in range(1, 11):
    best_model.train()
    total_loss = 0.0
    for src, tgt in tqdm(train_loader, desc=f"[Epoch {epoch}] Training", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        optimizer.zero_grad()
        out = best_model(src, tgt, teacher_forcing_ratio=1.0)
        loss = criterion(out.view(-1, tgt_size), tgt.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation: sequence-level accuracy
    best_model.eval()
    correct_seqs, total_seqs = 0, 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc=f"[Epoch {epoch}] Validation", leave=False):
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            out = best_model(src, tgt, teacher_forcing_ratio=0.0)
            preds = out.argmax(dim=2)
            for pred_seq, true_seq in zip(preds, tgt):
                # Remove <sos> and padding tokens for comparison
                pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
                true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
                if torch.equal(pred_tokens, true_tokens):
                    correct_seqs += 1
                total_seqs += 1

    val_acc = correct_seqs / total_seqs
    print(f"[Epoch {epoch}] Loss: {total_loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Optionally save model checkpoint here
    elif stopper.should_stop(val_acc):
        print("Early stopping triggered.")
        break

# Final Test Evaluation
best_model.eval()
correct_seqs, total_seqs = 0, 0
all_preds, all_trues = [], []

with torch.no_grad():
    for src, tgt in tqdm(test_loader, desc="Final Test Eval", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        out = best_model(src, tgt, teacher_forcing_ratio=0.0)
        preds = out.argmax(dim=2)
        for pred_seq, true_seq in zip(preds, tgt):
            pred_tokens = pred_seq[1:][true_seq[1:] != pad_idx]
            true_tokens = true_seq[1:][true_seq[1:] != pad_idx]
            if torch.equal(pred_tokens, true_tokens):
                correct_seqs += 1
            total_seqs += 1

            all_preds.append(pred_tokens)
            all_trues.append(true_tokens)

test_acc = correct_seqs / total_seqs
print(f"\n Final Test Accuracy (Exact Word match): {test_acc:.4f}")


In [16]:
romanized_test_words = []
with open(os.path.join(BASE_DIR, 'te.translit.sampled.test.tsv'), "r", encoding="utf-8") as f:
    for line in f:
        telugu, roman, _ = line.strip().split()
        romanized_test_words.append(roman)

best_model.eval()
samples = []

with torch.no_grad():
    for i, (src, tgt) in enumerate(test_loader):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        out = best_model(src, tgt, teacher_forcing_ratio=0.0)
        preds = out.argmax(dim=2)

        for b in range(src.shape[0]):
            pred = ''.join(
                IDX2CHAR_TGT[idx.item()] 
                for idx in preds[b][1:] 
                if idx.item() not in (pad_idx,)
            )
            true = ''.join(
                IDX2CHAR_TGT[idx.item()] 
                for idx in tgt[b][1:] 
                if idx.item() not in (pad_idx,)
            )

            romanized = romanized_test_words[i * src.shape[0] + b]

            samples.append({
                'Input': romanized,
                'True Telugu'             : true,
                'Predicted Telugu'        : pred
            })

# Sample and show table
subset = random.sample(samples, min(10, len(samples)))
df = pd.DataFrame(subset)
print(df.to_markdown(index=False))

| Input            | True Telugu   | Predicted Telugu   |
|:-----------------|:--------------|:-------------------|
| sikshnhaa        | శిక్షణా<eos>     | సిక్షణా<eos>          |
| rahasyamuga      | రహస్యముగా<eos>   | రాహస్యముగా             |
| kuuragaayalu     | కూరగాయలు<eos>    | కూరగాయలు<eos>         |
| varthapathrikalu | వార్తాపత్రికలు<eos> | వర్తపాత్రికలు<eos>ు      |
| mandalamulo      | మండలములో<eos>    | మండలములో<eos>         |
| akramaalu        | అక్రమాలు<eos>    | ఆక్రమాలు<eos>         |
| vikalangula      | వికలాంగుల<eos>    | వికలంగుల<eos>ు         |
| daariloone       | దారిలోనే<eos>     | దారిలోనే<eos>          |
| bhawanaalalo     | భవనాలలో<eos>    | భావనాలలో              |
| gramddhaalayam   | గ్రంథాలయం<eos>    | గ్రంథాలయాం              |


In [18]:
def highlight_pred(row):
    """
    Returns a list of CSS styles, one per column, 
    coloring the 'Predicted Telugu' cell green if correct else red.
    """
    styles = [''] * len(row)
    # Find the index of the Predicted column
    pred_idx = list(row.index).index('Predicted Telugu')
    if row['Predicted Telugu'] == row['True Telugu']:
        styles[pred_idx] = 'background-color: #c8e6c9; font-weight: bold;'  # light green
    else:
        styles[pred_idx] = 'background-color: #f8d7da; font-weight: bold;'  # light red
    return styles

# Apply to a random sample of 10 rows
subset = df.sample(n=min(10, len(df))).reset_index(drop=True)

styled = (
    subset.style
          .apply(highlight_pred, axis=1)
          .set_table_styles([
              # Center all text
              {'selector': 'td, th',
               'props': [('text-align', 'center'), ('padding', '6px')]},
              # Header style
              {'selector': 'th',
               'props': [('background-color', '#4F81BD'),
                         ('color', 'white'),
                         ('font-weight', 'bold'),
                         ('padding', '8px')]}
          ])
          .set_caption(" Sample Transliteration Predictions (Green = Correct, Red = Wrong) ")
)

# Display in a Jupyter/HTML context
display(styled)

Unnamed: 0,Input,True Telugu,Predicted Telugu
0,bhawanaalalo,భవనాలలో,భావనాలలో
1,akramaalu,అక్రమాలు,ఆక్రమాలు
2,gramddhaalayam,గ్రంథాలయం,గ్రంథాలయాం
3,sikshnhaa,శిక్షణా,సిక్షణా
4,vikalangula,వికలాంగుల,వికలంగులు
5,varthapathrikalu,వార్తాపత్రికలు,వర్తపాత్రికలుు
6,daariloone,దారిలోనే,దారిలోనే
7,mandalamulo,మండలములో,మండలములో
8,kuuragaayalu,కూరగాయలు,కూరగాయలు
9,rahasyamuga,రహస్యముగా,రాహస్యముగా


In [19]:
# Ensure the output folder exists
os.makedirs("predictions_vanilla", exist_ok=True)

df1 = pd.DataFrame(samples)
df1.to_csv("predictions_vanilla/predictions.csv", index=False, encoding="utf-8-sig")

print(" Saved all predictions to predictions_vanilla/predictions.csv")

 Saved all predictions to predictions_vanilla/predictions.csv
