<a href="https://colab.research.google.com/github/Sai-sakunthala/Assignment-3/blob/main/attention_latin_to_telugu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch wandb pandas tqdm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import wandb
import editdistance
import numpy as np

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Language code used to build the lexicon directory and file names below.
LANG = 'te'
# Directory holding the `{LANG}.translit.sampled.*.tsv` lexicon files.
data_path = f'/content/drive/MyDrive/dakshina_dataset_v1.0/{LANG}/lexicons/'

def read_data(filepath, max_len=40):
    """Read a tab-separated lexicon file into (source, target) string pairs.

    Each line is stripped and split on tabs. The first field is treated as the
    native-script word and the second as its Latin spelling; any further
    fields are ignored. The pair is returned as ``(latin, native)`` so that
    Latin is the source side and the native script is the target side.

    Args:
        filepath: Path of the UTF-8 encoded TSV file.
        max_len: Pairs are kept only when both strings have at most this many
            characters.

    Returns:
        A list of ``(source, target)`` tuples in file order. Lines with fewer
        than two tab-separated fields are skipped.
    """
    pairs = []
    with open(filepath, encoding='utf8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split('\t')
            if len(fields) < 2:
                continue
            native, latin = fields[0], fields[1]
            # Drop pairs where either side is longer than the limit.
            if len(latin) > max_len or len(native) > max_len:
                continue
            pairs.append((latin, native))
    return pairs

def make_vocab(sequences):
    """Build a symbol-to-index vocabulary and its inverse.

    Indices 0, 1 and 2 are reserved for ``<pad>``, ``<sos>`` and ``<eos>``.
    Every other symbol receives the next free index in order of first
    appearance while iterating over ``sequences``.

    Args:
        sequences: Iterable of iterables of symbols (e.g. a list of strings).

    Returns:
        ``(vocab, idx2char)`` where ``vocab`` maps symbol -> index and
        ``idx2char`` maps index -> symbol.
    """
    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
    for seq in sequences:
        for ch in seq:
            # len(vocab) is always the next unused index because indices are dense.
            vocab.setdefault(ch, len(vocab))
    idx2char = dict((index, symbol) for symbol, index in vocab.items())
    return vocab, idx2char

def encode_word(word, vocab):
    """Convert a word to a list of indices wrapped in ``<sos>`` ... ``<eos>``.

    Args:
        word: Iterable of symbols to encode.
        vocab: Mapping from symbol to index; must contain ``<sos>``, ``<eos>``
            and every symbol of ``word``.

    Returns:
        ``[sos] + indices of word + [eos]`` as a new list.

    Raises:
        KeyError: If a symbol of ``word`` is not present in ``vocab``.
    """
    ids = [vocab['<sos>']]
    ids.extend(vocab[ch] for ch in word)
    ids.append(vocab['<eos>'])
    return ids

def pad_seq(seq, max_len, pad_idx=0):
    """Return ``seq`` right-padded with ``pad_idx`` up to ``max_len`` items.

    A sequence that is already ``max_len`` long or longer is returned with no
    padding added (it is not truncated).

    Args:
        seq: List of indices.
        max_len: Desired minimum length of the result.
        pad_idx: Value used for the padding positions.

    Returns:
        A new list consisting of ``seq`` followed by the padding.
    """
    missing = max_len - len(seq)
    # A non-positive count yields an empty padding list.
    return seq + [pad_idx] * missing

class TransliterationDataset(Dataset):
    """Dataset of index-encoded (source, target) word pairs padded to a fixed length.

    Every pair is encoded once at construction time with ``encode_word``.
    ``__getitem__`` pads the source and target to the longest encoded source
    and target found in *this* dataset, so all items have the same shape.
    """

    def __init__(self, pairs, source_vocab, target_vocab):
        """Encode all pairs and record the padding lengths.

        Args:
            pairs: Iterable of ``(source, target)`` strings.
            source_vocab: Symbol -> index mapping for the source side; must
                contain ``<pad>``.
            target_vocab: Symbol -> index mapping for the target side; must
                contain ``<pad>``.
        """
        self.source_pad = source_vocab['<pad>']
        self.target_pad = target_vocab['<pad>']
        self.data = [
            (encode_word(source, source_vocab), encode_word(target, target_vocab))
            for source, target in pairs
        ]
        # default=0 keeps construction working for an empty split instead of
        # failing with "max() arg is an empty sequence".
        self.source_max = max((len(src) for src, _ in self.data), default=0)
        self.target_max = max((len(tgt) for _, tgt in self.data), default=0)

    def __len__(self):
        """Number of encoded pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the padded ``(source, target)`` tensors for item ``idx``."""
        source, target = self.data[idx]
        source = pad_seq(source, self.source_max, self.source_pad)
        target = pad_seq(target, self.target_max, self.target_pad)
        return torch.tensor(source), torch.tensor(target)

import torch
import torch.nn as nn
import random

# Updated Attention class with robust shape handling
class Attention(nn.Module):
    """Additive attention that scores encoder states against a decoder state.

    The score of each source position is ``v . tanh(W [hidden; enc_state])``;
    the scores are softmax-normalised over the source axis and used to form a
    weighted sum of the encoder outputs.
    """

    def __init__(self, hid_dimensions):
        """Create the projection ``W`` and the scoring vector ``v``.

        Args:
            hid_dimensions: Size of both the hidden state and each encoder output.
        """
        super().__init__()
        self.attn = nn.Linear(hid_dimensions * 2, hid_dimensions)
        self.v = nn.Parameter(torch.rand(hid_dimensions))
        # Re-draw v uniformly from [-1/sqrt(h), 1/sqrt(h)].
        bound = 1. / (hid_dimensions ** 0.5)
        self.v.data.uniform_(-bound, bound)
        self.hid_dimensions = hid_dimensions

    def forward(self, hidden, encoder_outputs):
        """Compute the attention context vector.

        Args:
            hidden: ``(batch_size, hid_dimensions)`` or
                ``(num_layers, batch_size, hid_dimensions)``; for the 3-D form
                only the last layer is used.
            encoder_outputs: ``(batch_size, src_len, hid_dimensions)``.

        Returns:
            Context tensor of shape ``(batch_size, hid_dimensions)``.

        Raises:
            ValueError: If ``hidden`` is neither 2-D nor 3-D.
        """
        src_len = encoder_outputs.size(1)

        rank = hidden.dim()
        if rank not in (2, 3):
            raise ValueError(f"Expected hidden to be 2D or 3D, got shape {hidden.shape}")
        query = hidden[-1] if rank == 3 else hidden  # (batch_size, hid_dimensions)

        # Pair the query with every source position.
        tiled = query.unsqueeze(1).repeat(1, src_len, 1)  # (batch_size, src_len, hid_dimensions)
        joined = torch.cat((tiled, encoder_outputs), dim=2)

        projected = torch.tanh(self.attn(joined))  # (batch_size, src_len, hid_dimensions)
        scores = torch.matmul(projected, self.v)  # (batch_size, src_len)

        weights = torch.softmax(scores, dim=1).unsqueeze(2)  # (batch_size, src_len, 1)

        # Weighted sum over the source axis.
        return (weights * encoder_outputs).sum(dim=1)

# Updated translit_Encoder
class translit_Encoder(nn.Module):
    """Embeds the source indices and runs them through an RNN/GRU/LSTM."""

    def __init__(self, input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        """Build the encoder.

        Args:
            input_dimensions: Source vocabulary size.
            emb_dimensions: Embedding size.
            hid_dimensions: Recurrent hidden size.
            num_layers: Number of stacked recurrent layers.
            dropout: Dropout probability, applied to the embeddings and, when
                ``num_layers > 1``, between recurrent layers.
            cell: One of ``'rnn'``, ``'gru'``, ``'lstm'`` (case-insensitive).

        Raises:
            KeyError: If ``cell`` is not one of the supported names.
        """
        super().__init__()
        self.embedding = nn.Embedding(input_dimensions, emb_dimensions)
        rnn_cls = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}[cell.lower()]
        self.rnn = rnn_cls(emb_dimensions, hid_dimensions, num_layers, dropout=dropout if num_layers > 1 else 0, batch_first=True)
        # forward() does not use this module. It is kept so that the module's
        # state_dict keys and parameter-initialisation order stay as they were.
        self.attention = Attention(hid_dimensions)
        self.cell = cell.lower()
        self.dropout = nn.Dropout(dropout)

    def forward(self, source):
        """Encode a batch of source sequences.

        Args:
            source: Index tensor of shape ``(batch_size, src_len)``.

        Returns:
            ``(outputs, hidden, cell)`` as produced by the recurrent layer.
            ``outputs`` is ``(batch_size, src_len, hid_dimensions)``; ``cell``
            is ``None`` unless the recurrent layer is an LSTM.
        """
        embedded = self.dropout(self.embedding(source))  # (batch_size, src_len, emb_dimensions)

        if self.cell == 'lstm':
            outputs, (hidden, cell) = self.rnn(embedded)
        else:
            outputs, hidden = self.rnn(embedded)
            cell = None

        # The previous version also computed an attention context here and
        # threw it away; that dead computation has been removed.
        return outputs, hidden, cell

# Updated translit_Decoder
class translit_Decoder(nn.Module):
    """Single-step decoder: recurrent update, attention over the encoder outputs, projection."""

    def __init__(self, output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        """Build the decoder.

        Args:
            output_dimensions: Target vocabulary size.
            emb_dimensions: Embedding size.
            hid_dimensions: Recurrent hidden size.
            num_layers: Number of stacked recurrent layers.
            dropout: Dropout probability, applied to the embeddings and, when
                ``num_layers > 1``, between recurrent layers.
            cell: One of ``'rnn'``, ``'gru'``, ``'lstm'`` (case-insensitive).
        """
        super().__init__()
        cell_name = cell.lower()
        if num_layers > 1:
            inter_layer_dropout = dropout
        else:
            inter_layer_dropout = 0

        self.embedding = nn.Embedding(output_dimensions, emb_dimensions)
        self.attention = Attention(hid_dimensions)
        rnn_types = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}
        self.rnn = rnn_types[cell_name](
            emb_dimensions,
            hid_dimensions,
            num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )
        # The projection sees the recurrent output concatenated with the context.
        self.fc_out = nn.Linear(hid_dimensions * 2, output_dimensions)
        self.cell = cell_name
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode one target position.

        Args:
            input: Previous target indices, shape ``(batch_size,)``.
            hidden: ``(num_layers, batch_size, hid_dimensions)``.
            cell: LSTM cell state with the same shape as ``hidden``; ignored
                for non-LSTM cells.
            encoder_outputs: ``(batch_size, src_len, hid_dimensions)``.

        Returns:
            ``(prediction, hidden, cell)`` where ``prediction`` holds the
            unnormalised scores ``(batch_size, output_dimensions)`` and
            ``cell`` is ``None`` for non-LSTM cells.
        """
        step_tokens = input.unsqueeze(1)  # (batch_size, 1)
        embedded = self.dropout(self.embedding(step_tokens))  # (batch_size, 1, emb_dimensions)

        if self.cell == 'lstm':
            step_out, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        else:
            step_out, hidden = self.rnn(embedded, hidden)
            cell = None

        # Attend with the freshly updated hidden state.
        context = self.attention(hidden, encoder_outputs)  # (batch_size, hid_dimensions)

        features = torch.cat((step_out.squeeze(1), context), dim=1)  # (batch_size, hid_dimensions * 2)
        prediction = self.fc_out(features)  # (batch_size, output_dimensions)

        return prediction, hidden, cell

# Updated translit_Seq2Seq
class translit_Seq2Seq(nn.Module):
    """Couples an encoder and a step-wise decoder with optional teacher forcing."""

    def __init__(self, encoder, decoder, device):
        """Store the sub-modules and the device used for the output buffer."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Decode ``target.size(1) - 1`` steps for a batch.

        Args:
            source: Source indices, ``(batch_size, src_len)``.
            target: Target indices, ``(batch_size, target_len)``; column 0 is
                fed as the first decoder input.
            teacher_forcing_ratio: Probability, drawn once per time step for
                the whole batch, of feeding the gold token instead of the
                decoder's own argmax as the next input.

        Returns:
            Tensor ``(batch_size, target_len, vocab_size)`` of decoder scores.
            Position 0 along the time axis is never written and stays zero.
        """
        n_batch = source.size(0)
        n_steps = target.size(1)
        vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(n_batch, n_steps, vocab_size, device=self.device)

        encoder_outputs, hidden, cell = self.encoder(source)

        step_input = target[:, 0]
        for t in range(1, n_steps):
            logits, hidden, cell = self.decoder(step_input, hidden, cell, encoder_outputs)
            outputs[:, t] = logits
            use_gold = random.random() < teacher_forcing_ratio
            greedy = logits.argmax(1)
            if use_gold:
                step_input = target[:, t]
            else:
                step_input = greedy

        return outputs

def strip_after_eos(seq, eos_idx):
    """Return the part of ``seq`` that precedes the first ``eos_idx``.

    Tensors are first converted to a plain Python list. The EOS token itself
    is excluded from the result. When ``eos_idx`` does not occur, the whole
    sequence is returned.
    """
    if isinstance(seq, torch.Tensor):
        seq = seq.cpu().numpy().tolist()
    try:
        cut = seq.index(eos_idx)
    except ValueError:
        # No EOS present: nothing to cut.
        return seq
    return seq[:cut]

def calculate_word_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Fraction of predictions that match their target exactly.

    Args:
        preds: Sequence of predicted index sequences.
        targets: Sequence of reference index sequences, aligned with ``preds``.
        pad_idx: Index removed from both sequences before comparison.
        eos_idx: When not ``None``, both sequences are first truncated before
            their first occurrence of this index. The check is ``is not None``
            so that an EOS index of 0 is honoured.

    Returns:
        ``matches / max(len(preds), 1)``, i.e. 0.0 for empty input.
    """
    strip = eos_idx is not None
    correct = 0
    for pred, target in zip(preds, targets):
        if strip:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        correct += int(pred == target)
    return correct / max(len(preds), 1)

def calculate_cer(preds, targets, pad_idx=0, eos_idx=None):
    """Character error rate: summed edit distance over summed target length.

    Args:
        preds: Sequence of predicted index sequences.
        targets: Sequence of reference index sequences, aligned with ``preds``.
        pad_idx: Index removed from both sequences before scoring.
        eos_idx: When not ``None``, both sequences are first truncated before
            their first occurrence of this index. The check is ``is not None``
            so that an EOS index of 0 is honoured.

    Returns:
        Total edit distance divided by the total target length, where every
        target counts as at least one symbol; ``float('inf')`` when there are
        no pairs to score.
    """
    strip = eos_idx is not None
    cer, total = 0, 0
    for pred, target in zip(preds, targets):
        if strip:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        cer += editdistance.eval(pred, target)
        # Count empty targets as length 1 so they cannot zero the denominator.
        total += max(len(target), 1)
    return cer / total if total > 0 else float('inf')

def calculate_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Position-wise token accuracy over all non-pad target positions.

    Predictions and targets are compared index by index up to the length of
    the shorter of the two; positions whose target is ``pad_idx`` are skipped.

    Args:
        preds: Sequence of predicted index sequences (lists or tensors).
        targets: Sequence of reference index sequences (lists or tensors).
        pad_idx: Target value marking positions to ignore.
        eos_idx: When not ``None``, both sequences are truncated before their
            first occurrence of this index prior to comparison.

    Returns:
        ``matching positions / compared positions``, or 0.0 when nothing was
        compared.
    """
    hits = 0
    compared = 0
    for pred_seq, target_seq in zip(preds, targets):
        if isinstance(pred_seq, torch.Tensor):
            pred_seq = pred_seq.cpu().tolist()
        if isinstance(target_seq, torch.Tensor):
            target_seq = target_seq.cpu().tolist()
        if eos_idx is not None:
            pred_seq = strip_after_eos(pred_seq, eos_idx)
            target_seq = strip_after_eos(target_seq, eos_idx)
        scored = [(p, t) for p, t in zip(pred_seq, target_seq) if t != pad_idx]
        compared += len(scored)
        hits += sum(1 for p, t in scored if p == t)
    if compared > 0:
        return hits / compared
    return 0.0

def run(config=None):
    """Train and validate one attention seq2seq configuration inside a W&B run.

    Reads the hyper-parameters from ``wandb.config``, loads the train/dev/test
    lexicons, builds the vocabularies from the training pairs, trains for up to
    ``cfg.epochs`` epochs with gradient clipping, a plateau LR scheduler and
    early stopping on validation loss, logs per-epoch metrics to W&B, and
    finally uploads and reloads the best checkpoint (``best_model.pt``).

    Args:
        config: Optional configuration forwarded to ``wandb.init``.

    Notes:
        * Per-epoch metrics are means of per-batch values.
        * The test loader is constructed but not evaluated in this function.
    """
    with wandb.init(config=config):
        cfg = wandb.config
        cfg.hidden_dim = 2 * cfg.embed_dim if cfg.hidden_dim_config == 'double' else cfg.embed_dim
        sweep_name = f"{cfg.cell_type}_{cfg.embed_dim}e_{cfg.hidden_dim_config}h_{cfg.layers}l_" \
             f"{int(cfg.dropout*100)}d_{int(cfg.teacher_forcing*10)}tf_" \
             f"{str(cfg.lr).replace('.', '')}lr_attention"

        wandb.run.name = sweep_name

        max_len = 30

        # Load data (paths follow the `{LANG}.translit.sampled.<split>.tsv` naming).
        train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=max_len)
        val_pairs   = read_data(data_path + f"{LANG}.translit.sampled.dev.tsv",   max_len=max_len)
        test_pairs  = read_data(data_path + f"{LANG}.translit.sampled.test.tsv",  max_len=max_len)

        # Vocabularies are built from the training split only.
        source_vocab, _ = make_vocab([x[0] for x in train_pairs])
        target_vocab, _ = make_vocab([x[1] for x in train_pairs])

        assert source_vocab['<pad>'] == 0 and target_vocab['<pad>'] == 0, "Pad token must be index 0 in both vocabs."

        # Loop-invariant special-token indices for the target side.
        pad_idx = target_vocab['<pad>']
        eos_idx = target_vocab['<eos>']

        train_translit = TransliterationDataset(train_pairs, source_vocab, target_vocab)
        val_translit   = TransliterationDataset(val_pairs,   source_vocab, target_vocab)
        test_translit  = TransliterationDataset(test_pairs,  source_vocab, target_vocab)

        # Use drop_last=True to ensure consistent batch sizes.
        train_drop_last = DataLoader(train_translit, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
        val_drop_last   = DataLoader(val_translit,   batch_size=cfg.batch_size, drop_last=True)
        test_drop_last  = DataLoader(test_translit,  batch_size=cfg.batch_size, drop_last=True)

        encoder = translit_Encoder(len(source_vocab), cfg.embed_dim, cfg.hidden_dim, cfg.layers, cfg.dropout, cfg.cell_type).to(device)
        decoder = translit_Decoder(len(target_vocab), cfg.embed_dim, cfg.hidden_dim, cfg.layers, cfg.dropout, cfg.cell_type).to(device)
        model = translit_Seq2Seq(encoder, decoder, device).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)

        best_val_loss = float('inf')
        patience = 15
        wait = 0

        for epoch in range(cfg.epochs):
            model.train()
            total_loss = 0
            total_acc  = 0
            total_char_acc = 0
            for source, target in train_drop_last:
                source, target = source.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(source, target, cfg.teacher_forcing)
                out_dimensions = output.shape[-1]
                # Time step 0 is the <sos> slot and is excluded from the loss.
                loss = criterion(output[:, 1:].reshape(-1, out_dimensions), target[:, 1:].reshape(-1))

                raw_preds = output.argmax(2)[:, 1:].tolist()
                raw_targets = target[:, 1:].tolist()
                preds = [strip_after_eos(p, eos_idx) for p in raw_preds]
                targets = [strip_after_eos(t, eos_idx) for t in raw_targets]
                acc = calculate_word_accuracy(preds, targets, pad_idx=pad_idx)
                char_acc = calculate_accuracy(preds, targets, pad_idx=pad_idx)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                total_loss += loss.item()
                total_acc  += acc
                total_char_acc += char_acc

            avg_train_loss = total_loss / len(train_drop_last)
            avg_train_acc  = total_acc / len(train_drop_last)
            avg_train_char_acc = total_char_acc / len(train_drop_last)

            model.eval()
            val_loss = 0
            val_acc  = 0
            val_cer  = 0
            total_char_acc = 0
            with torch.no_grad():
                for source, target in val_drop_last:
                    source, target = source.to(device), target.to(device)
                    # No teacher forcing: the decoder consumes its own predictions.
                    output = model(source, target, teacher_forcing_ratio=0)
                    out_dimensions = output.shape[-1]
                    loss = criterion(output[:, 1:].reshape(-1, out_dimensions), target[:, 1:].reshape(-1))

                    raw_preds = output.argmax(2)[:, 1:].tolist()
                    raw_targets = target[:, 1:].tolist()
                    preds = [strip_after_eos(p, eos_idx) for p in raw_preds]
                    targets = [strip_after_eos(t, eos_idx) for t in raw_targets]
                    acc = calculate_word_accuracy(preds, targets, pad_idx=pad_idx)
                    char_acc = calculate_accuracy(preds, targets, pad_idx=pad_idx)
                    val_loss += loss.item()
                    val_acc  += acc
                    # CER is scored on the same <sos>-free, EOS-truncated
                    # sequences as the accuracies. Scoring the raw sequences
                    # counted <sos>/<eos> in the reference and everything the
                    # decoder emitted after <eos> as errors.
                    val_cer += calculate_cer(preds, targets, pad_idx=pad_idx)
                    total_char_acc += char_acc

            avg_val_loss = val_loss / len(val_drop_last)
            avg_val_acc  = val_acc / len(val_drop_last)
            avg_val_cer  = val_cer / len(val_drop_last)
            avg_val_char_acc = total_char_acc / len(val_drop_last)

            scheduler.step(avg_val_loss)

            wandb.log({
                'train_loss': avg_train_loss,
                'train_accuracy': avg_train_acc,
                'train_char_accuracy': avg_train_char_acc,
                'val_loss': avg_val_loss,
                'val_accuracy': avg_val_acc,
                'val_cer': avg_val_cer,
                'val_char_accuracy': avg_val_char_acc,
                'epoch': epoch + 1
            })
            print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.3f} Acc: {avg_train_acc:.3f} | "
                  f"Val Loss: {avg_val_loss:.3f} Acc: {avg_val_acc:.3f} CER: {avg_val_cer:.3f}")

            # Early Stopping Check
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                wait = 0
                torch.save(model.state_dict(), 'best_model.pt')
            else:
                wait += 1
                if wait >= patience:
                    print("Early stopping triggered.")
                    break

        # After training, upload the best checkpoint as a W&B artifact and
        # restore its weights into the model.
        artifact = wandb.Artifact('best_model', type='model')
        artifact.add_file('best_model.pt')
        wandb.log_artifact(artifact)
        # map_location lets the checkpoint load on whichever device is active now.
        model.load_state_dict(torch.load('best_model.pt', map_location=device))

# W&B sweep definition: Bayesian search that maximises the logged
# `val_char_accuracy`. Every hyper-parameter currently has a single candidate,
# so the sweep reproduces one fixed configuration.
sweep_config = dict(
    method='bayes',
    metric=dict(name='val_char_accuracy', goal='maximize'),
    parameters={
        # Model shape
        'embed_dim': dict(values=[256]),
        'hidden_dim_config': dict(values=['double']),
        'layers': dict(values=[2]),
        # Regularisation and optimisation
        'dropout': dict(values=[0.3]),
        'lr': dict(values=[0.001]),
        'cell_type': dict(values=['lstm']),
        'teacher_forcing': dict(values=[0.5]),
        # Fixed settings
        'batch_size': dict(value=64),
        'epochs': dict(value=20),
    },
)

cuda


In [None]:
import traceback

import wandb

wandb.login()
try:
    # Register the sweep and execute a single run of it with `run`.
    sweep_id = wandb.sweep(sweep_config, project="dakshina-seq2seq-3")
    wandb.agent(sweep_id, function=run, count=1)
except Exception:
    # Keep the notebook running, but show what went wrong instead of
    # silently discarding the error.
    traceback.print_exc()
finally:
    # Always close any run that is still open, on success and on failure alike.
    wandb.finish()

Create sweep with ID: attugua7
Sweep URL: https://wandb.ai/sai-sakunthala-indian-institute-of-technology-madras/dakshina-seq2seq-3/sweeps/attugua7


[34m[1mwandb[0m: Agent Starting Run: q4uu2h8m with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	cell_type: lstm
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	embed_dim: 256
[34m[1mwandb[0m: 	epochs: 20
[34m[1mwandb[0m: 	hidden_dim_config: double
[34m[1mwandb[0m: 	layers: 2
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	teacher_forcing: 0.5


Epoch 1 | Train Loss: 0.840 Acc: 0.333 | Val Loss: 0.717 Acc: 0.481 CER: 1.451
Epoch 2 | Train Loss: 0.329 Acc: 0.560 | Val Loss: 0.622 Acc: 0.535 CER: 1.434
Epoch 3 | Train Loss: 0.260 Acc: 0.636 | Val Loss: 0.612 Acc: 0.562 CER: 1.427
Epoch 4 | Train Loss: 0.218 Acc: 0.688 | Val Loss: 0.629 Acc: 0.578 CER: 1.426
Epoch 5 | Train Loss: 0.186 Acc: 0.729 | Val Loss: 0.636 Acc: 0.586 CER: 1.423
Epoch 6 | Train Loss: 0.160 Acc: 0.760 | Val Loss: 0.659 Acc: 0.581 CER: 1.422
Epoch 7 | Train Loss: 0.106 Acc: 0.833 | Val Loss: 0.669 Acc: 0.610 CER: 1.416
Epoch 8 | Train Loss: 0.086 Acc: 0.862 | Val Loss: 0.652 Acc: 0.615 CER: 1.418
Epoch 9 | Train Loss: 0.076 Acc: 0.877 | Val Loss: 0.694 Acc: 0.610 CER: 1.415
Epoch 10 | Train Loss: 0.055 Acc: 0.908 | Val Loss: 0.690 Acc: 0.614 CER: 1.416
Epoch 11 | Train Loss: 0.046 Acc: 0.921 | Val Loss: 0.715 Acc: 0.613 CER: 1.416
Epoch 12 | Train Loss: 0.042 Acc: 0.927 | Val Loss: 0.748 Acc: 0.609 CER: 1.417
Epoch 13 | Train Loss: 0.033 Acc: 0.943 | Val Los

0,1
epoch,▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇██
train_accuracy,▁▄▄▅▅▆▇▇▇▇████████
train_char_accuracy,▁▅▆▆▇▇▇▇██████████
train_loss,█▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▅▆▆▆████████████
val_cer,█▅▃▃▃▂▁▂▁▁▁▁▁▁▁▁▁▁
val_char_accuracy,▁▄▆▆▆▆▇█▇██▇██████
val_loss,▅▁▁▂▂▃▃▂▄▄▅▆▇▆▇█▇█

0,1
epoch,18.0
train_accuracy,0.95947
train_char_accuracy,0.99425
train_loss,0.022
val_accuracy,0.61523
val_cer,1.41637
val_char_accuracy,0.87985
val_loss,0.80257
