In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
import wandb

# Prefer the GPU when PyTorch reports one as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class InputEncoder(nn.Module):
    """Embeds source token ids and encodes them with a recurrent network.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector.
        hidden_dim: Total hidden width per layer. When ``bidirectional`` is
            true each direction receives ``hidden_dim // 2`` units.
        num_layers: Number of stacked recurrent layers.
        rnn_type: ``'RNN'``, ``'LSTM'`` or ``'GRU'``; any other value raises
            ``KeyError``.
        dropout: Dropout probability applied between stacked layers.
        bidirectional: Whether the recurrent network reads in both directions.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, rnn_type='LSTM',
                 dropout=0.2, bidirectional=False):
        super().__init__()
        # Index 0 is reserved for padding, so its embedding stays at zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.is_bidirectional = bidirectional
        self.rnn_type = rnn_type
        self.num_directions = 2 if bidirectional else 1
        self.hidden_dim = hidden_dim

        rnn_cls = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[rnn_type]
        self.rnn = rnn_cls(
            input_size=embed_dim,
            # Split the width across directions so the concatenated state is
            # hidden_dim wide again (for even hidden_dim).
            hidden_size=hidden_dim // self.num_directions,
            num_layers=num_layers,
            # PyTorch applies dropout only between stacked layers and emits a
            # UserWarning when a non-zero value is given for a single layer.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=bidirectional
        )

    def forward(self, x):
        """Encode a batch-first tensor of token ids and return the final state.

        Returns whatever the underlying RNN reports as its final state: an
        ``(h_n, c_n)`` tuple for LSTM, a single ``h_n`` tensor for RNN/GRU.
        The per-step outputs are discarded.
        """
        embedded = self.embedding(x)
        _, hidden = self.rnn(embedded)
        return hidden


class OutputDecoder(nn.Module):
    """Single-step recurrent decoder that maps one token per example to logits.

    Args:
        vocab_size: Size of the target vocabulary (embedding rows and logits).
        embed_dim: Width of each embedding vector.
        hidden_dim: Hidden width of every recurrent layer.
        num_layers: Number of stacked recurrent layers.
        rnn_type: ``'RNN'``, ``'LSTM'`` or ``'GRU'``; any other value raises
            ``KeyError``.
        dropout: Dropout probability applied between stacked layers.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, rnn_type='LSTM',
                 dropout=0.2):
        super().__init__()
        # Index 0 is reserved for padding, so its embedding stays at zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn_type = rnn_type

        rnn_cls = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[rnn_type]
        self.rnn = rnn_cls(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            # PyTorch applies dropout only between stacked layers and emits a
            # UserWarning when a non-zero value is given for a single layer.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True
        )
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_token, hidden_state):
        """Advance the decoder by one time step.

        Args:
            input_token: 1-D tensor with one token id per batch element.
            hidden_state: Recurrent state to start from (a tensor, or an
                ``(h, c)`` tuple for LSTM).

        Returns:
            ``(logits, hidden)`` where ``logits`` is ``(B, V)`` and ``hidden``
            is the updated recurrent state.
        """
        # Add a length-1 time axis so the batch-first RNN sees (B, 1, E).
        embedded = self.embedding(input_token.unsqueeze(1))
        rnn_output, hidden = self.rnn(embedded, hidden_state)
        logits = self.fc_out(rnn_output.squeeze(1))  # (B, V)
        return logits, hidden


class TransliterationModel(nn.Module):
    """Encoder-decoder sequence model that decodes one target token at a time.

    The encoder's final recurrent state seeds the decoder. Encoder and decoder
    may differ in depth and the encoder may be bidirectional; the state is
    reshaped to match the decoder before decoding starts.

    Args:
        input_vocab_size: Size of the source vocabulary.
        output_vocab_size: Size of the target vocabulary.
        embed_dim: Embedding width used by both encoder and decoder.
        hidden_dim: Hidden width of the decoder (and total encoder width).
        enc_layers: Number of encoder layers.
        dec_layers: Number of decoder layers.
        rnn_type: ``'RNN'``, ``'LSTM'`` or ``'GRU'``.
        dropout: Dropout probability forwarded to encoder and decoder.
        bidirectional: Whether the encoder is bidirectional.
    """

    def __init__(self, input_vocab_size, output_vocab_size, embed_dim, hidden_dim,
                 enc_layers, dec_layers, rnn_type='LSTM', dropout=0.2, bidirectional=False):
        super().__init__()
        self.encoder = InputEncoder(input_vocab_size, embed_dim, hidden_dim,
                                    enc_layers, rnn_type, dropout, bidirectional)
        self.decoder = OutputDecoder(output_vocab_size, embed_dim, hidden_dim,
                                     dec_layers, rnn_type, dropout)
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.hidden_dim = hidden_dim
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers

    @staticmethod
    def _merge_bidir_states(state):
        """Join each layer's forward and backward state along the feature axis.

        PyTorch stacks bidirectional states as (layer0_fwd, layer0_bwd,
        layer1_fwd, ...), so even rows are forward and odd rows are backward.
        """
        return torch.cat([state[::2], state[1::2]], dim=2)

    @staticmethod
    def _fit_layers(state, target_layers):
        """Resize the layer axis (dim 0) of ``state`` to ``target_layers``."""
        depth = state.size(0)
        if depth == target_layers:
            return state
        if depth > target_layers:
            # Encoder is deeper than the decoder: keep the top-most layers.
            # The previous zero-padding code asked for a negative-sized tensor
            # here and crashed.
            return state[-target_layers:].contiguous()
        # Encoder is shallower: the extra decoder layers start from zeros.
        pad = state.new_zeros(target_layers - depth, *state.shape[1:])
        return torch.cat([state, pad], dim=0)

    def _init_decoder_state(self, enc_hidden):
        """Convert the encoder's final state into the decoder's initial state."""
        def convert(state):
            if self.bidirectional:
                state = self._merge_bidir_states(state)
            return self._fit_layers(state, self.dec_layers)

        if self.rnn_type == 'LSTM':
            h, c = enc_hidden
            return convert(h), convert(c)
        return convert(enc_hidden)

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """Decode ``tgt.shape[1]`` steps and return the per-step logits.

        Args:
            src: Batch-first tensor of source token ids.
            tgt: Batch-first tensor of target token ids; column 0 is fed to the
                decoder as the first input.
            teacher_forcing_ratio: Probability, drawn once per time step for
                the whole batch, of feeding the ground-truth token instead of
                the decoder's own argmax prediction.

        Returns:
            Tensor of shape ``(batch, tgt_len, vocab)``. Position 0 is never
            written and stays all zeros.
        """
        batch_size, tgt_len = tgt.shape
        vocab_size = self.decoder.fc_out.out_features
        outputs = torch.zeros(batch_size, tgt_len, vocab_size, device=src.device)

        dec_hidden = self._init_decoder_state(self.encoder(src))

        dec_input = tgt[:, 0]  # Start token
        for t in range(1, tgt_len):
            output, dec_hidden = self.decoder(dec_input, dec_hidden)
            outputs[:, t] = output
            top1 = output.argmax(1)
            teacher_force = random.random() < teacher_forcing_ratio
            dec_input = tgt[:, t] if teacher_force else top1

        return outputs

def read_pairs(file_path):
    """Load ``(second column, first column)`` pairs from a tab-separated file.

    The file is read as UTF-8 and whitespace around the whole text is stripped
    before it is split into lines. Lines without a tab are skipped; columns
    after the second are ignored.
    """
    with open(file_path, encoding='utf-8') as handle:
        rows = handle.read().strip().split('\n')

    pairs = []
    for row in rows:
        if '\t' not in row:
            continue
        fields = row.split('\t')
        # Swap the order: the file's second column comes first in the pair.
        pairs.append((fields[1], fields[0]))
    return pairs

def build_vocab_and_prepare_batch(seqs, device):
    """Build character vocabularies from ``seqs`` and return a batching closure.

    Args:
        seqs: Iterable of pairs; element 0 feeds the source vocabulary and
            element 1 the target vocabulary.
        device: Device the batch tensors are moved to.

    Returns:
        ``(src_vocab, idx2src, tgt_vocab, idx2tgt, create_batch,
        unique_chars_latin, unique_chars_dev)``. ``create_batch(pairs)``
        encodes pairs into padded, batch-first id tensors ``(src, tgt)``:
        source rows end with ``<eos>``, target rows are wrapped in
        ``<sos>`` ... ``<eos>``.
    """
    special_tokens = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
    first_char_id = len(special_tokens)

    # Collect the distinct characters on each side, in sorted order.
    unique_chars_latin = sorted({ch for seq in seqs for ch in seq[0]})
    unique_chars_dev = sorted({ch for seq in seqs for ch in seq[1]})

    def make_vocab(chars):
        # Characters are numbered after the special tokens, which keep ids 0-3.
        table = dict(zip(chars, range(first_char_id, first_char_id + len(chars))))
        table.update(special_tokens)
        return table

    src_vocab = make_vocab(unique_chars_latin)
    tgt_vocab = make_vocab(unique_chars_dev)

    idx2src = dict(zip(src_vocab.values(), src_vocab.keys()))
    idx2tgt = dict(zip(tgt_vocab.values(), tgt_vocab.keys()))

    def encode_text(seq, vocab):
        unk_id = vocab['<unk>']
        return [vocab.get(ch, unk_id) for ch in seq]

    def encode_source(text):
        return torch.tensor(encode_text(text, src_vocab) + [src_vocab['<eos>']])

    def encode_target(text):
        ids = [tgt_vocab['<sos>']] + encode_text(text, tgt_vocab) + [tgt_vocab['<eos>']]
        return torch.tensor(ids)

    def create_batch(pairs):
        src_rows = [encode_source(x) for x, _ in pairs]
        tgt_rows = [encode_target(y) for _, y in pairs]
        src = pad_sequence(src_rows, batch_first=True, padding_value=src_vocab['<pad>'])
        tgt = pad_sequence(tgt_rows, batch_first=True, padding_value=tgt_vocab['<pad>'])
        return src.to(device), tgt.to(device)

    return src_vocab, idx2src, tgt_vocab, idx2tgt, create_batch, unique_chars_latin, unique_chars_dev

def compute_word_level_accuracy(preds, targets, vocab):
    """Return the percentage of rows whose predicted word equals the target word.

    Each row is cut at its first ``<eos>`` before comparison, so tokens emitted
    after the end of the word cannot turn a correct prediction into a miss, and
    a misplaced ``<eos>`` cannot be hidden by simply filtering it out.
    ``<pad>`` ids remaining before the cut are dropped.

    Args:
        preds: 2-D tensor of predicted token ids (anything with ``.tolist()``).
        targets: 2-D tensor of reference token ids, same layout as ``preds``.
        vocab: Mapping that contains the ``'<eos>'`` and ``'<pad>'`` ids.

    Returns:
        Accuracy in percent (0-100). An empty batch yields ``0.0``.
    """
    eos, pad = vocab['<eos>'], vocab['<pad>']

    def extract_word(ids):
        # Everything from the first <eos> onward is not part of the word.
        if eos in ids:
            ids = ids[:ids.index(eos)]
        return [tok for tok in ids if tok != pad]

    pred_rows = preds.tolist()
    target_rows = targets.tolist()
    if not pred_rows:
        return 0.0

    correct = sum(
        extract_word(p) == extract_word(t) for p, t in zip(pred_rows, target_rows)
    )
    return correct / len(pred_rows) * 100

def _run_epoch(model, pairs, create_batch, criterion, tgt_vocab, batch_size,
               teacher_forcing_prob, optimizer=None):
    """Run one pass over ``pairs``; train when ``optimizer`` is given, else evaluate.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the average of the
        per-batch losses over the batches actually processed and ``accuracy``
        is the word-level accuracy in percent over all examples (each batch is
        weighted by its size, so a short final batch is not over-counted).
        Both are ``0.0`` when ``pairs`` is empty.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, weighted_acc_sum, num_batches = 0.0, 0.0, 0

    with torch.set_grad_enabled(training):
        for start in range(0, len(pairs), batch_size):
            batch = pairs[start:start + batch_size]
            src, tgt = create_batch(batch)

            if training:
                optimizer.zero_grad()
            outputs = model(src, tgt, teacher_forcing_prob)

            # Position 0 is the start token and is never predicted, so both
            # the loss and the accuracy skip it.
            loss = criterion(outputs[:, 1:].reshape(-1, outputs.size(-1)), tgt[:, 1:].reshape(-1))
            if training:
                loss.backward()
                optimizer.step()

            preds = outputs.argmax(-1)
            acc = compute_word_level_accuracy(preds[:, 1:], tgt[:, 1:], tgt_vocab)

            loss_sum += loss.item()
            weighted_acc_sum += acc * len(batch)
            num_batches += 1

    if num_batches == 0:
        return 0.0, 0.0
    return loss_sum / num_batches, weighted_acc_sum / len(pairs)


def run_training():
    """Train one model using the hyperparameters of the active wandb run.

    Reads every hyperparameter from ``wandb.config``, trains on the Dakshina
    Hindi train lexicon, evaluates on the dev lexicon after each epoch, logs
    the metrics to wandb and prints a one-line summary per epoch.

    Returns:
        The trained ``TransliterationModel``.

    Raises:
        ValueError: If ``wandb.config.optimizer`` names an unsupported optimizer.
    """
    # Initialize wandb config
    wandb.init()
    cfg = wandb.config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load and prepare data
    train_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.train.tsv"
    dev_path = "dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.dev.tsv"
    train_set = read_pairs(train_path)
    dev_set = read_pairs(dev_path)

    # Vocabularies come from the training split only; unseen dev characters map to <unk>.
    src_vocab, idx2src, tgt_vocab, idx2tgt, create_batch, _, _ = build_vocab_and_prepare_batch(train_set, device)

    # Initialize model, optimizer, criterion
    model = TransliterationModel(
        len(src_vocab), len(tgt_vocab), cfg.embedding_size, cfg.hidden_size,
        cfg.enc_layers, cfg.dec_layers, cfg.rnn_type, cfg.dropout_rate,
        cfg.is_bidirectional
    ).to(device)

    # Honor the swept `optimizer` choice; fall back to Adam when it is absent.
    optimizer_classes = {'adam': optim.Adam, 'nadam': optim.NAdam}
    optimizer_name = str(cfg.get('optimizer', 'adam')).lower()
    if optimizer_name not in optimizer_classes:
        raise ValueError(
            f"Unsupported optimizer {optimizer_name!r}; expected one of {sorted(optimizer_classes)}"
        )
    optimizer = optimizer_classes[optimizer_name](model.parameters(), lr=cfg.learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab['<pad>'])

    # Training loop
    for epoch in range(cfg.epochs):
        random.shuffle(train_set)
        avg_train_loss, avg_train_acc = _run_epoch(
            model, train_set, create_batch, criterion, tgt_vocab,
            cfg.batch_size, cfg.teacher_forcing_prob, optimizer=optimizer,
        )

        # Validation: no teacher forcing, no gradient updates.
        avg_dev_loss, avg_dev_acc = _run_epoch(
            model, dev_set, create_batch, criterion, tgt_vocab,
            cfg.batch_size, 0,
        )

        # Logging
        wandb.log({
            "Epoch": epoch + 1,
            "Train Loss": avg_train_loss,
            "Train Accuracy": avg_train_acc,
            "Validation Loss": avg_dev_loss,
            "Validation Accuracy": avg_dev_acc,
        })

        print(f"Epoch {epoch+1}/{cfg.epochs} | Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.2f}% | Val Loss: {avg_dev_loss:.4f}, Val Acc: {avg_dev_acc:.2f}%")

    wandb.finish()
    return model


In [None]:
# Bayesian hyperparameter search that maximizes the validation accuracy
# logged by run_training. Every hyperparameter is a discrete choice, so the
# search space is written as name -> candidate list and wrapped into the
# {'values': [...]} form wandb expects.
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'Validation Accuracy', 'goal': 'maximize'},
    'parameters': {
        param_name: {'values': candidates}
        for param_name, candidates in {
            'embedding_size': [32, 64, 128, 256],
            'hidden_size': [64, 128, 256],
            'enc_layers': [1, 2, 3],
            'dec_layers': [1, 2, 3],
            'rnn_type': ['GRU', 'LSTM', 'RNN'],
            'dropout_rate': [0.2, 0.3],
            'batch_size': [32, 64],
            'epochs': [5, 10],
            'is_bidirectional': [False, True],
            'learning_rate': [0.001, 0.002],
            'optimizer': ['adam', 'nadam'],
            'teacher_forcing_prob': [0.2, 0.5, 0.7],
        }.items()
    },
}

# Register the sweep, then let a local agent execute up to 50 runs of it.
sweep_id = wandb.sweep(sweep_config, project="dakshina_transliteration")
wandb.agent(sweep_id, function=run_training, count=50)


wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Create sweep with ID: q84vgxva
Sweep URL: https://wandb.ai/harshtrivs-indian-institute-of-technology-madras/dakshina_transliteration/sweeps/q84vgxva


wandb: Agent Starting Run: hzvyxt6c with config:
wandb: 	batch_size: 32
wandb: 	dec_layers: 2
wandb: 	dropout_rate: 0.2
wandb: 	embedding_size: 128
wandb: 	enc_layers: 1
wandb: 	epochs: 10
wandb: 	hidden_size: 64
wandb: 	is_bidirectional: False
wandb: 	learning_rate: 0.001
wandb: 	optimizer: nadam
wandb: 	rnn_type: GRU
wandb: 	teacher_forcing_prob: 0.7
wandb: Currently logged in as: harshtrivs (harshtrivs-indian-institute-of-technology-madras) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin




Epoch 1/10 | Train Loss: 2.6848, Train Acc: 0.06% | Val Loss: 2.6699, Val Acc: 0.30%
Epoch 2/10 | Train Loss: 1.8966, Train Acc: 1.65% | Val Loss: 2.2014, Val Acc: 4.24%
Epoch 3/10 | Train Loss: 1.5230, Train Acc: 5.39% | Val Loss: 1.9162, Val Acc: 8.95%
