In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from collections import defaultdict
import os

class Seq2Seq(nn.Module):
    """Recurrent encoder-decoder for character-level transliteration.

    Source and target tokens use separate embedding tables. The encoder RNN's
    final hidden state initialises the decoder RNN, which is run over the
    whole decoder input in a single call (teacher forcing). A linear layer
    maps every decoder output step to target-vocabulary logits.

    Args:
        input_vocab_size: number of source token ids.
        target_vocab_size: number of target token ids (width of the logits).
        embedding_dim: size of both embedding tables.
        hidden_dim: hidden size of the encoder and decoder RNNs.
        rnn_type: one of ``'LSTM'``, ``'GRU'`` or ``'RNN'``.
        num_layers: number of stacked layers in each RNN.
        dropout: rate for the dropout applied to embeddings; also used as the
            inter-layer RNN dropout when ``num_layers > 1``.
    """

    def __init__(self, input_vocab_size, target_vocab_size,
                 embedding_dim=64, hidden_dim=128, rnn_type='LSTM',
                 num_layers=1, dropout=0.2):
        super().__init__()
        self.embedding_dim, self.hidden_dim = embedding_dim, hidden_dim
        self.rnn_type, self.num_layers = rnn_type, num_layers

        # Submodules are created in a fixed order so parameter initialisation
        # draws from the RNG in the same sequence every time.
        self.encoder_embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.decoder_embedding = nn.Embedding(target_vocab_size, embedding_dim)

        cell_cls = dict(LSTM=nn.LSTM, GRU=nn.GRU, RNN=nn.RNN)[rnn_type]
        # PyTorch's RNN dropout only acts between stacked layers, so a
        # single-layer RNN gets 0.
        inter_layer_dropout = 0 if num_layers <= 1 else dropout
        rnn_kwargs = dict(batch_first=True, dropout=inter_layer_dropout)
        self.encoder_rnn = cell_cls(embedding_dim, hidden_dim, num_layers, **rnn_kwargs)
        self.decoder_rnn = cell_cls(embedding_dim, hidden_dim, num_layers, **rnn_kwargs)

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, target_vocab_size)

    def forward(self, src, tgt):
        """Return logits of shape ``(batch, tgt_len, target_vocab_size)``.

        ``src`` and ``tgt`` are batch-first index tensors; ``tgt`` is the
        decoder input sequence (e.g. the gold target shifted right).
        """
        _, encoder_state = self.encoder_rnn(self.dropout(self.encoder_embedding(src)))
        decoded, _ = self.decoder_rnn(self.dropout(self.decoder_embedding(tgt)), encoder_state)
        return self.fc(decoded)


class TransliterationDataset(Dataset):
    """Character-index pairs read from a tab-separated transliteration lexicon.

    Each line is split on tabs; column 1 (romanized) is the model input and
    column 0 (native script) is the target. Characters missing from the
    respective vocabulary are dropped. Targets are wrapped as
    ``<s> ... </s>``. Lines with fewer than two columns (e.g. blank lines)
    are skipped.

    Args:
        file_path: path to the UTF-8 TSV file.
        input_char2idx: source character -> index mapping.
        target_char2idx: target token -> index mapping; must contain
            ``'<s>'`` and ``'</s>'``.
        max_len: maximum number of indices kept per source / target sequence
            (should be >= 2 so a target can hold ``<s>`` and ``</s>``).
    """

    def __init__(self, file_path, input_char2idx, target_char2idx, max_len=30):
        sos_idx = target_char2idx['<s>']
        eos_idx = target_char2idx['</s>']
        self.pairs = []
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 2:
                    # Blank or malformed line: skip instead of raising IndexError.
                    continue
                x, y = parts[1], parts[0]
                x_idx = [input_char2idx[c] for c in x if c in input_char2idx]
                y_idx = ([sos_idx]
                         + [target_char2idx[c] for c in y if c in target_char2idx]
                         + [eos_idx])
                if len(y_idx) > max_len:
                    # Plain slicing would cut off '</s>'; keep the end marker so
                    # long words still teach the decoder when to stop.
                    y_idx = y_idx[:max_len - 1] + [eos_idx]
                self.pairs.append((x_idx[:max_len], y_idx))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(src, tgt)`` as 1-D ``torch.long`` tensors."""
        x, y = self.pairs[idx]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
    
def collate_fn(batch):
    """Pad a list of ``(src, tgt)`` 1-D index tensors into two batch-first tensors.

    Sources and targets are right-padded independently with index 0 up to the
    longest sequence of their side in this batch.
    """
    def pad(seqs):
        return nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)

    sources = [pair[0] for pair in batch]
    targets = [pair[1] for pair in batch]
    return pad(sources), pad(targets)


def build_vocab(data_path):
    """Build source and target character vocabularies from a lexicon TSV.

    Each line is split on tabs; column 1 (romanized) feeds the input
    vocabulary and column 0 (native script) the target vocabulary. Lines
    with fewer than two columns (e.g. blank lines) are skipped.

    Index layout:
      * input:  ``'<pad>'`` -> 0, characters -> 1..N in sorted order.
      * target: ``'<pad>'`` -> 0, ``'<s>'`` -> 1, ``'</s>'`` -> 2,
        characters -> 3..M in sorted order.

    Index 0 must be ``'<pad>'`` because this notebook pads with 0 and
    ignores index 0 in the loss and accuracy masks. Sorting the special
    tokens together with the characters would put ``'</s>'`` (which sorts
    before ``'<pad>'``) at index 0, so end-of-sequence targets would be
    treated as padding and never learned.

    Returns:
        ``(input_char2idx, target_char2idx, target_idx2char)``.
    """
    specials = ['<pad>', '<s>', '</s>']
    input_chars, target_chars = set(), set()
    with open(data_path, encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 2:
                continue
            input_chars.update(parts[1])
            target_chars.update(parts[0])

    input_char2idx = {'<pad>': 0}
    input_char2idx.update({c: i for i, c in enumerate(sorted(input_chars), start=1)})

    target_char2idx = {tok: i for i, tok in enumerate(specials)}
    target_char2idx.update(
        {c: i for i, c in enumerate(sorted(target_chars - set(specials)), start=len(specials))}
    )
    target_idx2char = {i: c for c, i in target_char2idx.items()}
    return input_char2idx, target_char2idx, target_idx2char


def calculate_char_accuracy(preds, targets, pad_idx=0):
    """Token-level accuracy of arg-max predictions, ignoring padding.

    Args:
        preds: logits whose last dimension is the vocabulary; arg-maxed here.
        targets: index tensor shaped like ``preds`` without its last dimension.
        pad_idx: target index excluded from scoring (default 0, the padding
            index used elsewhere in this notebook).

    Returns:
        0-d float tensor: correct non-pad positions / non-pad positions.
        Returns 0 (instead of NaN) when ``targets`` has no non-pad positions.
    """
    predicted = preds.argmax(dim=-1)
    mask = targets != pad_idx
    correct = (predicted == targets) & mask
    # clamp avoids 0/0 = NaN on an all-padding batch.
    return correct.sum().float() / mask.sum().clamp(min=1).float()

def calculate_word_accuracy(preds, targets, target_idx2char=None, pad_idx=0, eos_idx=None):
    """Strict sequence-level accuracy of arg-max predictions.

    A row counts as correct only if the prediction equals the target at every
    scored position. Scored positions are the target positions that are not
    ``pad_idx``; when ``eos_idx`` is given, scoring also stops after the first
    ``eos_idx`` in each target row (that EOS itself is scored).

    Args:
        preds: logits whose last dimension is the vocabulary; arg-maxed here.
        targets: index tensor laid out as (batch, seq); the reduction over
            ``dim=1`` assumes this layout.
        target_idx2char: optional index -> token map. When provided, each
            fully-correct row is printed as a debug line (can be very verbose
            inside a training loop). When omitted, nothing is printed.
        pad_idx: padding index excluded from scoring.
        eos_idx: optional end-of-sequence index.

    Returns:
        Python float: fraction of rows that are fully correct.
    """
    predicted = preds.argmax(dim=-1)
    mask = targets != pad_idx

    if eos_idx is not None:
        is_eos = (targets == eos_idx).long()
        # Count of EOS tokens strictly before each position; anything after
        # the first EOS (count >= 1) is excluded from scoring.
        eos_before = is_eos.cumsum(dim=1) - is_eos
        mask = mask & (eos_before == 0)

    correct_words = ((predicted == targets) | ~mask).all(dim=1)

    if target_idx2char is not None:
        def decode(ids):
            # Stop at the first EOS and skip padding tokens.
            chars = []
            for tok in ids:
                if eos_idx is not None and tok == eos_idx:
                    break
                if tok != pad_idx:
                    chars.append(target_idx2char[tok])
            return ''.join(chars)

        # Debug output: one line per fully-correct row. .tolist() moves each
        # row to Python once instead of calling .item() per token.
        for row in correct_words.nonzero(as_tuple=True)[0].tolist():
            pred_word = decode(predicted[row].tolist())
            tgt_word = decode(targets[row].tolist())
            print(f"Correct: pred='{pred_word}' tgt='{tgt_word}'")

    return correct_words.float().mean().item()



def predict(model, word, input_char2idx, target_idx2char, target_char2idx, device, max_len=30):
    """Greedily transliterate one word with a trained ``Seq2Seq``-style model.

    The word is encoded the way ``TransliterationDataset`` encodes inputs:
    characters missing from ``input_char2idx`` are dropped and the result is
    truncated to ``max_len``. No padding is appended. A single sequence needs
    none, and padding to a fixed ``max_len`` would run the encoder over
    trailing pad steps that do not match training, where ``collate_fn`` pads
    only up to the longest word in each batch. A word with no known
    characters is encoded as a single ``<pad>`` token so the encoder still
    receives a non-empty sequence.

    Decoding starts from ``<s>`` and feeds back the arg-max token each step
    until ``</s>`` is produced or ``max_len`` tokens have been emitted.

    Side effect: leaves ``model`` in eval mode.

    Returns:
        The predicted target string, without ``<s>`` / ``</s>``.
    """
    model.eval()
    with torch.no_grad():
        src = [input_char2idx[c] for c in word if c in input_char2idx][:max_len]
        if not src:
            src = [input_char2idx['<pad>']]
        src_tensor = torch.tensor([src], dtype=torch.long, device=device)

        # Encode once; the final hidden state seeds the decoder.
        _, hidden = model.encoder_rnn(model.dropout(model.encoder_embedding(src_tensor)))

        eos_idx = target_char2idx['</s>']
        token = target_char2idx['<s>']
        output_chars = []
        for _ in range(max_len):
            step_input = torch.tensor([[token]], dtype=torch.long, device=device)
            decoder_out, hidden = model.decoder_rnn(
                model.dropout(model.decoder_embedding(step_input)), hidden)
            token = model.fc(decoder_out[:, -1, :]).argmax(dim=-1).item()
            if token == eos_idx:
                break
            output_chars.append(target_idx2char[token])

        return ''.join(output_chars)



In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch's RNG so weight init, dropout masks and DataLoader shuffling are
# reproducible on Restart & Run All.
torch.manual_seed(42)

# Path to your dataset
train_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.train.tsv"

# Build vocabulary
input_char2idx, target_char2idx, target_idx2char = build_vocab(train_path)

# Create datasets and dataloaders
train_dataset = TransliterationDataset(train_path, input_char2idx, target_char2idx)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

# Initialize model
model = Seq2Seq(
    input_vocab_size=len(input_char2idx),
    target_vocab_size=len(target_char2idx),
    embedding_dim=128,
    hidden_dim=256,
    rnn_type='LSTM',
    num_layers=2,
    dropout=0.3
).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding (collate_fn pads with 0)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the LR once the epoch's mean training loss has not improved for more
# than `patience` epochs. `verbose` is omitted: PyTorch has deprecated it for
# LR schedulers.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)

# Training loop
num_epochs = 5
eos_idx = target_char2idx.get('</s>', None)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_char_acc = 0.0
    total_word_acc = 0.0
    batch_count = 0

    for src, tgt in train_loader:
        src, tgt = src.to(device), tgt.to(device)

        # Forward pass with teacher forcing: the decoder reads tgt[:-1] and
        # is trained to predict tgt[1:].
        outputs = model(src, tgt[:, :-1])
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt[:, 1:].reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Metrics use teacher-forced outputs, so they can differ from
        # free-running decoding with `predict`. `.item()` keeps the running
        # sums as plain Python floats instead of device tensors.
        char_acc = calculate_char_accuracy(outputs, tgt[:, 1:]).item()
        word_acc = calculate_word_accuracy(outputs, tgt[:, 1:],
                                           eos_idx=eos_idx, target_idx2char=target_idx2char)

        total_loss += loss.item()
        total_char_acc += char_acc
        total_word_acc += word_acc
        batch_count += 1

    # Log metrics
    avg_loss = total_loss / batch_count
    avg_char_acc = total_char_acc / batch_count
    avg_word_acc = total_word_acc / batch_count

    # The plateau scheduler only adjusts the LR when stepped; do it once per
    # epoch on the mean training loss.
    scheduler.step(avg_loss)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {avg_loss:.4f} | Char Acc: {avg_char_acc:.4f} | Word Acc: {avg_word_acc:.4f}")

# Testing
test_words = ["ankganit", "ankur", "angbhang", "anguliyon", "anguthe"]
print("\nTesting the model:")
for word in test_words:
    pred = predict(model, word, input_char2idx, target_idx2char, target_char2idx, device)
    print(f"{word} -> {pred}")

Epoch 1/5
Train Loss: 2.0101 | Char Acc: 0.4389 | Word Acc: 0.0438
Epoch 2/5
Train Loss: 0.7791 | Char Acc: 0.7538 | Word Acc: 0.2270
Epoch 3/5
Train Loss: 0.5643 | Char Acc: 0.8187 | Word Acc: 0.3322
Epoch 4/5
Train Loss: 0.4618 | Char Acc: 0.8506 | Word Acc: 0.3983
Epoch 5/5
Train Loss: 0.3975 | Char Acc: 0.8709 | Word Acc: 0.4491

Testing the model:
ankganit -> अंकनगिताएंगायोंत्रीयामीवार्णीत
ankur -> अंकुर्करीनोंडीयरूईएंटीआईवीजीआर
angbhang -> अंगभंग्वांतोंगीयारीपोंड़ाईटीआई
anguliyon -> अंगुलियोंकोंशीयरोंचीयरुपीयरीवा
anguthe -> अंगूठेंटेयाईवीज़ीयोंपीयारीवींत


In [None]:
def train(config=None):
    """Run one W&B sweep trial: train a Seq2Seq model and log dev metrics per epoch.

    Reads ``embedding_dim``, ``hidden_dim``, ``rnn_type``, ``num_layers`` and
    ``dropout`` from ``wandb.config``. The run trains for 5 epochs with Adam
    (lr 1e-3) and gradient-norm clipping at 1.0. After each epoch it logs the
    mean train loss plus character and word accuracy on the dev split.

    NOTE: dev accuracies are teacher-forced (the gold prefix is fed to the
    decoder), so they can differ from the free-running accuracy of ``predict``.
    """
    with wandb.init(config=config):
        config = wandb.config
        # Name the run after its hyperparameters so sweep runs are identifiable.
        wandb.run.name = (
            f"embedding_dim_{config.embedding_dim}_"
            f"hidden_dim_{config.hidden_dim}_"
            f"rnn_type_{config.rnn_type}_"
            f"num_layers_{config.num_layers}_"
            f"dropout_{config.dropout}"
        )

        train_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.train.tsv"
        dev_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.dev.tsv"

        # The vocabulary comes from the training split only; dev characters
        # outside it are dropped when the dev dataset is encoded.
        input_char2idx, target_char2idx, target_idx2char = build_vocab(train_path)
        eos_idx = target_char2idx.get('</s>', None)
        train_dataset = TransliterationDataset(train_path, input_char2idx, target_char2idx)
        dev_dataset = TransliterationDataset(dev_path, input_char2idx, target_char2idx)

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
        dev_loader = DataLoader(dev_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = Seq2Seq(
            input_vocab_size=len(input_char2idx),
            target_vocab_size=len(target_char2idx),
            embedding_dim=config.embedding_dim,
            hidden_dim=config.hidden_dim,
            rnn_type=config.rnn_type,
            num_layers=config.num_layers,
            dropout=config.dropout
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss(ignore_index=0)

        for epoch in range(5):
            model.train()
            total_loss = 0.0
            for src, tgt in train_loader:
                src, tgt = src.to(device).long(), tgt.to(device).long()
                # Teacher forcing: decoder reads tgt[:-1] and predicts tgt[1:].
                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:]

                logits = model(src, tgt_input)
                loss = criterion(logits.reshape(-1, logits.shape[-1]), tgt_output.reshape(-1))
                optimizer.zero_grad()
                loss.backward()
                # Same clipping as the single-run training cell.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                total_loss += loss.item()

            # Evaluate
            model.eval()
            char_accs = []
            word_accs = []
            with torch.no_grad():
                for src, tgt in dev_loader:
                    src, tgt = src.to(device).long(), tgt.to(device).long()
                    logits = model(src, tgt[:, :-1])
                    char_accs.append(calculate_char_accuracy(logits, tgt[:, 1:]).item())
                    # Same arguments as the training-cell call (target_idx2char,
                    # eos_idx) so dev word accuracy is scored identically.
                    word_accs.append(calculate_word_accuracy(
                        logits, tgt[:, 1:],
                        target_idx2char=target_idx2char, eos_idx=eos_idx))

            # max(..., 1) guards against an empty loader.
            n_dev = max(len(char_accs), 1)
            wandb.log({
                'epoch': epoch + 1,
                'train_loss': total_loss / max(len(train_loader), 1),
                'val_char_accuracy': sum(char_accs) / n_dev,
                'val_word_accuracy': sum(word_accs) / n_dev,
            })

# Bayesian search over model size and recurrent cell type, maximising the
# dev word accuracy that `train` logs under 'val_word_accuracy'.
search_space = {
    'embedding_dim': [32, 64, 128],
    'hidden_dim': [64, 128, 256],
    'rnn_type': ['LSTM', 'GRU', 'RNN'],
    'num_layers': [1, 2],
    'dropout': [0.2, 0.3],
}

sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_word_accuracy', 'goal': 'maximize'},
    'parameters': {name: {'values': choices} for name, choices in search_space.items()},
}

# Register the sweep, then let one agent run 50 trials of `train`.
sweep_id = wandb.sweep(sweep_config, project="dakshina-transliteration-DA6401")
wandb.agent(sweep_id, function=train, count=50)