<a href="https://colab.research.google.com/github/Sai-sakunthala/Assignment-3/blob/main/Vanilla_latin_to_telugu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch wandb pandas tqdm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import wandb
import editdistance
import numpy as np

# Dataset selection: language code and the folder holding its lexicon files.
LANG = 'te'
data_path = f'/content/drive/MyDrive/dakshina_dataset_v1.0/{LANG}/lexicons/'

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

def read_data(filepath, max_len=40):
    """Load (latin, native-script) word pairs from a tab-separated file.

    Every line is stripped and split on tabs. The first field is taken as the
    native-script word and the second as its Latin spelling; further fields
    are ignored. Lines with fewer than two fields are skipped.

    Args:
        filepath: Path of the UTF-8 encoded TSV file.
        max_len: A pair is kept only when both words have at most this many
            characters.

    Returns:
        List of ``(source, target)`` tuples where the source is the Latin
        word and the target is the native-script word.
    """
    def _fits(word):
        return len(word) <= max_len

    pairs = []
    with open(filepath, encoding='utf8') as handle:
        rows = (raw.strip().split('\t') for raw in handle)
        for fields in rows:
            if len(fields) >= 2:
                native, latin = fields[:2]
                # The file stores native first; the model reads Latin and emits native.
                if _fits(latin) and _fits(native):
                    pairs.append((latin, native))
    return pairs

def make_vocab(sequences):
    """Build symbol<->index lookup tables from an iterable of sequences.

    Indices 0, 1 and 2 are reserved for ``<pad>``, ``<sos>`` and ``<eos>``;
    every other symbol receives the next free index in order of first
    appearance.

    Returns:
        ``(vocab, idx2char)``: the symbol->index dict and its inverse.
    """
    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
    for seq in sequences:
        for ch in seq:
            # len(vocab) is always the next unused index; known symbols are left alone.
            vocab.setdefault(ch, len(vocab))
    idx2char = dict((index, symbol) for symbol, index in vocab.items())
    return vocab, idx2char

def encode_word(word, vocab):
    """Encode ``word`` as ``<sos>`` id + one id per character + ``<eos>`` id.

    Args:
        word: The string to encode.
        vocab: Symbol->index dict containing at least ``'<sos>'`` and
            ``'<eos>'``.

    Returns:
        List of integer ids.

    Raises:
        KeyError: If a character is missing from ``vocab`` and the vocabulary
            has no ``'<unk>'`` entry to fall back on. The message names the
            offending character and word, which matters when a dev/test split
            is encoded with a vocabulary built from the training split only.
    """
    unk_idx = vocab.get('<unk>')  # None when the vocabulary defines no unknown token
    ids = [vocab['<sos>']]
    for ch in word:
        idx = vocab.get(ch, unk_idx)
        if idx is None:
            raise KeyError(
                f"character {ch!r} in word {word!r} is not in the vocabulary "
                f"and no '<unk>' entry is defined"
            )
        ids.append(idx)
    ids.append(vocab['<eos>'])
    return ids

def pad_seq(seq, max_len, pad_idx=0):
    """Return a new list: ``seq`` right-padded with ``pad_idx`` up to ``max_len``.

    A sequence that is already ``max_len`` long or longer comes back
    unchanged in content (it is never truncated).
    """
    missing = max_len - len(seq)
    filler = [pad_idx for _ in range(missing)]  # empty when missing <= 0
    return seq + filler

class TransliterationDataset(Dataset):
    """Map-style dataset of encoded (source, target) word pairs.

    All pairs are encoded once at construction time with ``encode_word``.
    Items are padded on access to the longest encoded source / target found
    in *this* dataset, so every item of one dataset has the same shape.
    """

    def __init__(self, pairs, source_vocab, target_vocab):
        """Encode ``pairs`` and record the padding ids and maximum lengths."""
        self.source_pad = source_vocab['<pad>']
        self.target_pad = target_vocab['<pad>']
        self.data = [
            (encode_word(src_word, source_vocab), encode_word(tgt_word, target_vocab))
            for src_word, tgt_word in pairs
        ]
        # Lengths include the <sos>/<eos> markers added by encode_word.
        self.source_max = max(len(src_ids) for src_ids, _ in self.data)
        self.target_max = max(len(tgt_ids) for _, tgt_ids in self.data)

    def __len__(self):
        """Number of encoded pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as two padded tensors ``(source, target)``."""
        src_ids, tgt_ids = self.data[idx]
        padded_src = pad_seq(src_ids, self.source_max, self.source_pad)
        padded_tgt = pad_seq(tgt_ids, self.target_max, self.target_pad)
        return torch.tensor(padded_src), torch.tensor(padded_tgt)

class translit_Encoder(nn.Module):
    """Embeds a batch of source id sequences and runs them through an RNN/GRU/LSTM.

    Only the final recurrent state is returned; the per-step outputs are
    discarded.
    """

    def __init__(self, input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        """
        Args:
            input_dimensions: Size of the source vocabulary.
            emb_dimensions: Embedding width.
            hid_dimensions: Hidden state width of the recurrent layer(s).
            num_layers: Number of stacked recurrent layers.
            dropout: Dropout passed to the recurrent module; forced to 0 when
                there is a single layer.
            cell: ``'rnn'``, ``'gru'`` or ``'lstm'`` (case-insensitive).
        """
        super().__init__()
        self.embedding = nn.Embedding(input_dimensions, emb_dimensions)
        kind = cell.lower()
        recurrent_types = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}
        layer_dropout = dropout if num_layers > 1 else 0
        self.rnn = recurrent_types[kind](
            emb_dimensions,
            hid_dimensions,
            num_layers,
            dropout=layer_dropout,
            batch_first=True,
        )
        self.cell = kind

    def forward(self, source):
        """Return ``(hidden, cell)`` for ``source``; ``cell`` is ``None`` unless the unit is an LSTM."""
        _, state = self.rnn(self.embedding(source))
        if self.cell != 'lstm':
            return state, None
        hidden, cell = state
        return hidden, cell

class translit_Decoder(nn.Module):
    """Single-step decoder: embeds one token per batch row, advances the RNN, and scores the vocabulary."""

    def __init__(self, output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        """
        Args:
            output_dimensions: Size of the target vocabulary (also the width
                of the returned scores).
            emb_dimensions: Embedding width.
            hid_dimensions: Hidden state width of the recurrent layer(s).
            num_layers: Number of stacked recurrent layers.
            dropout: Dropout passed to the recurrent module; forced to 0 when
                there is a single layer.
            cell: ``'rnn'``, ``'gru'`` or ``'lstm'`` (case-insensitive).
        """
        super().__init__()
        self.embedding = nn.Embedding(output_dimensions, emb_dimensions)
        kind = cell.lower()
        recurrent_types = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}
        layer_dropout = dropout if num_layers > 1 else 0
        self.rnn = recurrent_types[kind](
            emb_dimensions,
            hid_dimensions,
            num_layers,
            dropout=layer_dropout,
            batch_first=True,
        )
        self.fc_out = nn.Linear(hid_dimensions, output_dimensions)
        self.cell = kind

    def forward(self, input, hidden, cell=None):
        """Advance the decoder by one time step.

        Args:
            input: 1-D tensor of token ids, one per batch row.
            hidden: Recurrent hidden state from the previous step.
            cell: LSTM cell state; ignored for other unit types.

        Returns:
            ``(prediction, hidden, cell)`` where ``prediction`` holds one row
            of vocabulary scores per batch row and ``cell`` is ``None`` for
            non-LSTM units.
        """
        # Add a length-1 time axis so the batch_first RNN sees a one-step sequence.
        step_emb = self.embedding(input.unsqueeze(1))
        is_lstm = self.cell == 'lstm'
        state_in = (hidden, cell) if is_lstm else hidden
        output, state_out = self.rnn(step_emb, state_in)
        if is_lstm:
            hidden, cell = state_out
        else:
            hidden, cell = state_out, None
        return self.fc_out(output.squeeze(1)), hidden, cell

class Attention(nn.Module):
    """Additive attention: scores each encoder position against one decoder state."""

    def __init__(self, hid_dim):
        """Create the scoring layers for states of width ``hid_dim``."""
        super().__init__()
        self.attn = nn.Linear(hid_dim * 2, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        """Return attention weights of shape [batch, src_len] that sum to 1 per row.

        Args:
            hidden: [batch, hid_dim] decoder state.
            encoder_outputs: [batch, src_len, hid_dim] encoder states.
            mask: Optional [batch, src_len] tensor; positions equal to 0 are
                pushed to a very negative score before the softmax.
        """
        steps = encoder_outputs.size(1)
        # Copy the decoder state next to every source position: [batch, src_len, hid_dim].
        tiled = hidden.unsqueeze(1).repeat(1, steps, 1)
        features = torch.cat((tiled, encoder_outputs), dim=2)
        scores = self.v(torch.tanh(self.attn(features))).squeeze(2)  # [batch, src_len]
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e10)
        return torch.softmax(scores, dim=1)

class translit_Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with optional teacher forcing."""

    def __init__(self, encoder, decoder, device):
        """Store the two sub-modules and the device the output buffer is moved to."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Decode ``target.size(1) - 1`` steps and return the per-step scores.

        Args:
            source: Batch of source id sequences fed to the encoder.
            target: Batch of target id sequences; column 0 is the first
                decoder input and its length fixes the number of steps.
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the model's
                own argmax.

        Returns:
            Tensor [batch, target_len, vocab]; slot 0 along the time axis is
            never written and stays all zeros.
        """
        n_examples, n_steps = source.size(0), target.size(1)
        vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(n_examples, n_steps, vocab_size).to(self.device)
        hidden, cell = self.encoder(source)
        step_input = target[:, 0]

        for t in range(1, n_steps):
            logits, hidden, cell = self.decoder(step_input, hidden, cell)
            outputs[:, t] = logits
            if random.random() < teacher_forcing_ratio:
                step_input = target[:, t]
            else:
                step_input = logits.argmax(1)
        return outputs

def strip_after_eos(seq, eos_idx):
    """Return the tokens that precede the first ``eos_idx`` (the EOS itself is dropped).

    Tensors are converted to plain Python lists first. A sequence without
    ``eos_idx`` is returned whole.
    """
    tokens = seq.cpu().numpy().tolist() if isinstance(seq, torch.Tensor) else seq
    if eos_idx not in tokens:
        return tokens
    cut = tokens.index(eos_idx)
    return tokens[:cut]

def calculate_word_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Fraction of predictions that match their target exactly.

    Args:
        preds: Sequence of predicted token-id sequences.
        targets: Sequence of gold token-id sequences, aligned with ``preds``.
        pad_idx: Token id removed from both sides before comparing.
        eos_idx: When not ``None``, both sequences are first cut at their
            first occurrence of this id via ``strip_after_eos``. The check is
            ``is not None`` so that an EOS id of 0 is honoured too.

    Returns:
        Number of exact matches divided by ``len(preds)`` (0.0 for empty input).
    """
    correct = 0
    for pred, target in zip(preds, targets):
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        correct += int(pred == target)
    return correct / max(len(preds), 1)

def calculate_cer(preds, targets, pad_idx=0, eos_idx=None):
    """Character error rate: summed edit distance over summed target length.

    Args:
        preds: Sequence of predicted token-id sequences.
        targets: Sequence of gold token-id sequences, aligned with ``preds``.
        pad_idx: Token id removed from both sides before comparing.
        eos_idx: When not ``None``, both sequences are first cut at their
            first occurrence of this id via ``strip_after_eos``. The check is
            ``is not None`` so that an EOS id of 0 is honoured too.

    Returns:
        Total ``editdistance.eval`` over all pairs divided by the total target
        length, where every target counts for at least 1 so an empty target
        cannot zero the denominator. ``inf`` when there are no pairs.
    """
    cer, total = 0, 0
    for pred, target in zip(preds, targets):
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        cer += editdistance.eval(pred, target)
        total += max(len(target), 1)
    return cer / total if total > 0 else float('inf')

def calculate_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Position-wise character accuracy over all non-pad target tokens.

    Every non-pad target position counts towards the total. A position is
    correct only when the prediction has a token there and it equals the
    target token, so a prediction that stops early is penalised for the
    characters it failed to produce. Prediction tokens beyond the end of the
    target are ignored.

    Args:
        preds: Sequence of predicted token-id sequences (lists or tensors).
        targets: Sequence of gold token-id sequences, aligned with ``preds``.
        pad_idx: Target positions holding this id are skipped.
        eos_idx: When not ``None``, both sequences are first cut at their
            first occurrence of this id via ``strip_after_eos``.

    Returns:
        ``correct / total`` as a float, or 0.0 when no target token was counted.
    """
    correct = 0
    total = 0
    for pred, target in zip(preds, targets):
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().tolist()
        if isinstance(target, torch.Tensor):
            target = target.cpu().tolist()
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred_len = len(pred)
        # Walk the target, not zip(pred, target): zip would silently drop the
        # target positions a too-short prediction never reached.
        for pos, t_token in enumerate(target):
            if t_token == pad_idx:
                continue
            total += 1
            if pos < pred_len and pred[pos] == t_token:
                correct += 1
    return correct / total if total > 0 else 0.0

def _batch_token_lists(output, target, eos_idx):
    """Turn a batch of scores and gold ids into EOS-trimmed token-id lists.

    Time step 0 is skipped on both sides: the model leaves that output slot
    unwritten and the target holds ``<sos>`` there.
    """
    preds = [strip_after_eos(p, eos_idx) for p in output.argmax(2)[:, 1:].tolist()]
    golds = [strip_after_eos(t, eos_idx) for t in target[:, 1:].tolist()]
    return preds, golds


def _evaluate(model, loader, criterion, pad_idx, eos_idx):
    """Decode ``loader`` without teacher forcing and return averaged metrics.

    Per-batch values are weighted by batch size so that a smaller final batch
    does not count as much as a full one. Returns a dict with the keys
    ``loss``, ``accuracy``, ``char_accuracy`` and ``cer`` (all 0.0 when the
    loader yields nothing).
    """
    model.eval()
    sums = {'loss': 0.0, 'accuracy': 0.0, 'char_accuracy': 0.0, 'cer': 0.0}
    seen = 0
    with torch.no_grad():
        for source, target in loader:
            source, target = source.to(device), target.to(device)
            output = model(source, target, teacher_forcing_ratio=0)
            vocab_size = output.shape[-1]
            loss = criterion(output[:, 1:].reshape(-1, vocab_size), target[:, 1:].reshape(-1))
            preds, golds = _batch_token_lists(output, target, eos_idx)
            weight = source.size(0)
            sums['loss'] += loss.item() * weight
            sums['accuracy'] += calculate_word_accuracy(preds, golds, pad_idx=pad_idx) * weight
            sums['char_accuracy'] += calculate_accuracy(preds, golds, pad_idx=pad_idx) * weight
            # CER is computed on the same trimmed sequences as the accuracies,
            # so <sos>, <eos> and anything emitted after <eos> never count as errors.
            sums['cer'] += calculate_cer(preds, golds, pad_idx=pad_idx) * weight
            seen += weight
    return {name: value / max(seen, 1) for name, value in sums.items()}


def run(config=None):
    """Train, validate and test one seq2seq configuration inside a W&B run.

    Reads hyperparameters from ``wandb.config`` (``embed_dim``,
    ``hidden_dim_config``, ``layers``, ``dropout``, ``lr``, ``cell_type``,
    ``teacher_forcing``, ``batch_size``, ``epochs``), builds vocabularies from
    the training split, trains with Adam, gradient clipping and
    ReduceLROnPlateau, logs per-epoch metrics, and stops early after 10
    epochs without a lower validation loss. The weights with the lowest
    validation loss are written to ``best_model.pt``, uploaded as a W&B
    artifact, reloaded, and scored on the test split.

    Args:
        config: Optional config passed straight to ``wandb.init``.
    """
    with wandb.init(config=config):
        cfg = wandb.config
        cfg.hidden_dim = 2 * cfg.embed_dim if cfg.hidden_dim_config == 'double' else cfg.embed_dim
        sweep_name = f"{cfg.cell_type}_{cfg.embed_dim}e_{cfg.hidden_dim_config}h_{cfg.layers}l_" \
             f"{int(cfg.dropout*100)}d_{int(cfg.teacher_forcing*10)}tf_" \
             f"{str(cfg.lr).replace('.', '')}lr"

        wandb.run.name = sweep_name

        max_len = 30
        checkpoint_path = 'best_model.pt'

        # Load the three splits of the lexicon for the configured language.
        train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=max_len)
        val_pairs   = read_data(data_path + f"{LANG}.translit.sampled.dev.tsv",   max_len=max_len)
        test_pairs  = read_data(data_path + f"{LANG}.translit.sampled.test.tsv",  max_len=max_len)

        # Vocabularies come from the training split only.
        source_vocab, _ = make_vocab([x[0] for x in train_pairs])
        target_vocab, _ = make_vocab([x[1] for x in train_pairs])

        assert source_vocab['<pad>'] == 0 and target_vocab['<pad>'] == 0, "Pad token must be index 0 in both vocabs."
        pad_idx = target_vocab['<pad>']
        eos_idx = target_vocab['<eos>']

        train_translit = TransliterationDataset(train_pairs, source_vocab, target_vocab)
        val_translit   = TransliterationDataset(val_pairs,   source_vocab, target_vocab)
        test_translit  = TransliterationDataset(test_pairs,  source_vocab, target_vocab)

        # Only training drops the ragged last batch; evaluation must see every
        # example (and must not end up with zero batches on a small split).
        train_loader = DataLoader(train_translit, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
        val_loader   = DataLoader(val_translit,   batch_size=cfg.batch_size)
        test_loader  = DataLoader(test_translit,  batch_size=cfg.batch_size)

        encoder = translit_Encoder(len(source_vocab), cfg.embed_dim, cfg.hidden_dim, cfg.layers, cfg.dropout, cfg.cell_type).to(device)
        decoder = translit_Decoder(len(target_vocab), cfg.embed_dim, cfg.hidden_dim, cfg.layers, cfg.dropout, cfg.cell_type).to(device)
        model = translit_Seq2Seq(encoder, decoder, device).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)

        best_val_loss = float('inf')
        patience = 10
        wait = 0
        checkpoint_saved = False  # guards against reusing a file left by an earlier run

        for epoch in range(cfg.epochs):
            model.train()
            total_loss = 0
            total_acc  = 0
            total_char_acc = 0
            for source, target in train_loader:
                source, target = source.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(source, target, cfg.teacher_forcing)
                out_dimensions = output.shape[-1]
                loss = criterion(output[:, 1:].reshape(-1, out_dimensions), target[:, 1:].reshape(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                preds, golds = _batch_token_lists(output, target, eos_idx)
                total_loss += loss.item()
                total_acc  += calculate_word_accuracy(preds, golds, pad_idx=pad_idx)
                total_char_acc += calculate_accuracy(preds, golds, pad_idx=pad_idx)

            # max(..., 1) avoids a ZeroDivisionError when the training split is
            # smaller than one batch and drop_last leaves no batches.
            n_train_batches = max(len(train_loader), 1)
            avg_train_loss = total_loss / n_train_batches
            avg_train_acc  = total_acc / n_train_batches
            avg_train_char_acc = total_char_acc / n_train_batches

            val_metrics = _evaluate(model, val_loader, criterion, pad_idx, eos_idx)
            avg_val_loss = val_metrics['loss']
            avg_val_acc  = val_metrics['accuracy']
            avg_val_cer  = val_metrics['cer']
            avg_val_char_acc = val_metrics['char_accuracy']

            scheduler.step(avg_val_loss)

            wandb.log({
                'train_loss': avg_train_loss,
                'train_accuracy': avg_train_acc,
                'train_char_accuracy': avg_train_char_acc,
                'val_loss': avg_val_loss,
                'val_accuracy': avg_val_acc,
                'val_cer': avg_val_cer,
                'val_char_accuracy': avg_val_char_acc,
                'epoch': epoch + 1
            })
            print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.3f} Acc: {avg_train_acc:.3f} | "
                  f"Val Loss: {avg_val_loss:.3f} Acc: {avg_val_acc:.3f} CER: {avg_val_cer:.3f}")

            # Early stopping: keep the weights with the lowest validation loss.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                wait = 0
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True
            else:
                wait += 1
                if wait >= patience:
                    print("Early stopping triggered.")
                    break

        if not checkpoint_saved:
            # Zero epochs, or the validation loss never improved (e.g. NaN):
            # any best_model.pt on disk belongs to a different run.
            print("No checkpoint was saved in this run; skipping artifact upload and test evaluation.")
            return

        # Upload the best weights, reload them and evaluate on the test split.
        artifact = wandb.Artifact('best_model', type='model')
        artifact.add_file(checkpoint_path)
        wandb.log_artifact(artifact)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

        test_metrics = _evaluate(model, test_loader, criterion, pad_idx, eos_idx)
        wandb.log({
            'test_loss': test_metrics['loss'],
            'test_accuracy': test_metrics['accuracy'],
            'test_char_accuracy': test_metrics['char_accuracy'],
            'test_cer': test_metrics['cer'],
        })
        print(f"Test Loss: {test_metrics['loss']:.3f} Acc: {test_metrics['accuracy']:.3f} "
              f"CER: {test_metrics['cer']:.3f}")

# W&B sweep definition: Bayesian search that maximises the logged
# 'val_char_accuracy' metric.
sweep_config = dict(
    method='bayes',
    metric=dict(name='val_char_accuracy', goal='maximize'),
    parameters=dict(
        # Searched options.
        embed_dim=dict(values=[64, 128, 256]),
        hidden_dim_config=dict(values=['same', 'double']),
        layers=dict(values=[1, 2]),
        dropout=dict(values=[0.2, 0.3]),
        lr=dict(values=[0.001, 0.0005]),
        cell_type=dict(values=['rnn', 'gru', 'lstm']),
        teacher_forcing=dict(values=[0.5, 0.7]),
        # Held constant for every run.
        batch_size=dict(value=64),
        epochs=dict(value=1),
    ),
)

cpu


In [None]:
# Sanity-check the training split: report its size and show a short preview
# rather than printing every pair into the notebook output.
train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=30)
print(f"{len(train_pairs)} training pairs")
print(train_pairs[:10])

[('amkita', 'అంకిత'), ('ankita', 'అంకిత'), ('ankitha', 'అంకిత'), ('ankitam', 'అంకితం'), ('ankitham', 'అంకితం'), ('ankitabaavam', 'అంకితభావం'), ('ankithabhavam', 'అంకితభావం'), ('ankatamicchaadu', 'అంకితమిచ్చాడు'), ('ankitamicchadu', 'అంకితమిచ్చాడు'), ('ankitamichhaadu', 'అంకితమిచ్చాడు'), ('ankithamicchaadu', 'అంకితమిచ్చాడు'), ('ankithamichhaadu', 'అంకితమిచ్చాడు'), ('amkithamichaaru', 'అంకితమిచ్చారు'), ('ankithamichaaru', 'అంకితమిచ్చారు'), ('ankithamicharu', 'అంకితమిచ్చారు'), ('ankithamaina', 'అంకితమైన'), ('ankusham', 'అంకుశం'), ('amkelu', 'అంకెలు'), ('amkeylu', 'అంకెలు'), ('ankelu', 'అంకెలు'), ('anga', 'అంగ'), ('angam', 'అంగం'), ('amgadi', 'అంగడి'), ('angadi', 'అంగడి'), ('angadhudu', 'అంగదుడు'), ('angadudu', 'అంగదుడు'), ('angaranga', 'అంగరంగ'), ('amgarakshakulu', 'అంగరక్షకులు'), ('angarakshakulu', 'అంగరక్షకులు'), ('angavaikalyam', 'అంగవైకల్యం'), ('angaaraka', 'అంగారక'), ('angaraka', 'అంగారక'), ('angaala', 'అంగాల'), ('angeekarinchadu', 'అంగీకరించదు'), ('angikarinchadu', 'అంగీకరించదు'), (

In [None]:
# Create the sweep and run a single agent trial.
wandb.login()
try:
    sweep_id = wandb.sweep(sweep_config, project="dakshina-seq2seq")
    wandb.agent(sweep_id, function=run, count=1)
finally:
    # Always close the W&B run, but let any error surface instead of
    # swallowing it (a bare `except:` would also hide KeyboardInterrupt).
    wandb.finish()

Create sweep with ID: 51uw9lc1
Sweep URL: https://wandb.ai/sai-sakunthala-indian-institute-of-technology-madras/dakshina-seq2seq/sweeps/51uw9lc1


[34m[1mwandb[0m: Agent Starting Run: ql29byye with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	cell_type: lstm
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	embed_dim: 128
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	hidden_dim_config: same
[34m[1mwandb[0m: 	layers: 2
[34m[1mwandb[0m: 	lr: 0.0005
[34m[1mwandb[0m: 	teacher_forcing: 0.5


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


In [None]:
# End the active W&B run, if one is still open.
wandb.finish()