# In [7]:
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import random
from torch.nn.utils.rnn import pad_sequence

# Log in to W&B. The API key is deliberately NOT hard-coded: wandb.login()
# picks it up from the WANDB_API_KEY environment variable (e.g. a Kaggle
# secret) or an existing ~/.netrc login. A key committed to source is a
# leaked credential -- the previously embedded key should be rotated.
wandb.login()

# ---------- Model Components ----------
class InputEncoder(nn.Module):
    """Embed token ids and run them through an RNN/LSTM/GRU encoder.

    ``forward`` returns only the RNN's final hidden state: a tensor for
    ``'RNN'``/``'GRU'`` and an ``(h, c)`` tuple for ``'LSTM'``. The caller
    uses it to initialise the decoder.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, layers, rnn_type='LSTM', dropout_rate=0.2, is_bidirectional=False):
        super().__init__()
        # Index 0 is the padding token; its embedding row is kept at zero.
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        rnn_class = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[rnn_type]
        # RNN dropout is only applied between stacked layers; with a single
        # layer PyTorch warns and ignores it, so pass 0.0 in that case.
        self.rnn = rnn_class(embedding_size, hidden_size, layers,
                             dropout=dropout_rate if layers > 1 else 0.0,
                             batch_first=True, bidirectional=is_bidirectional)
        self.rnn_type = rnn_type
        self.is_bidirectional = is_bidirectional

    def forward(self, x):
        """Encode a batch and return the RNN's final hidden state.

        Args:
            x: long tensor of token ids laid out ``(batch, time)``
                (the RNN is ``batch_first``). Assumed right-padded with id 0.

        Returns:
            The final hidden state (``h`` or ``(h, c)`` for LSTM), each of
            shape ``(layers * num_directions, batch, hidden_size)``.
        """
        embedded = self.embedding(x)
        # Pack so each row stops at its last real token. Without this, the
        # final state of shorter rows would also absorb the trailing padding
        # steps. Packing needs CPU lengths and rejects zero-length rows.
        lengths = (x != 0).sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False)
        # With enforce_sorted=False the returned hidden state is already
        # restored to the original batch order.
        _, hidden = self.rnn(packed)
        return hidden

class OutputDecoder(nn.Module):
    """One-step RNN decoder: consumes one token per row and emits next-token logits."""

    def __init__(self, vocab_size, embedding_size, hidden_size, layers, rnn_type='LSTM', dropout_rate=0.2, is_bidirectional=False):
        super().__init__()
        # Index 0 is the padding token; its embedding row is kept at zero.
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        rnn_class = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[rnn_type]
        # RNN dropout is only applied between stacked layers; with a single
        # layer PyTorch warns and ignores it, so pass 0.0 in that case.
        # NOTE: forward() always feeds a length-1 sequence, so with
        # is_bidirectional=True the backward direction only sees the current
        # token.
        self.rnn = rnn_class(embedding_size, hidden_size, layers,
                             dropout=dropout_rate if layers > 1 else 0.0,
                             batch_first=True, bidirectional=is_bidirectional)
        # A bidirectional RNN concatenates both directions' outputs.
        self.output_layer = nn.Linear(hidden_size * (2 if is_bidirectional else 1), vocab_size)
        self.rnn_type = rnn_type
        self.is_bidirectional = is_bidirectional

    def forward(self, token, hidden):
        """Advance the decoder by one time step.

        Args:
            token: long tensor of token ids, expected shape ``(batch,)``;
                it is turned into a length-1 sequence internally.
            hidden: previous RNN state (``h`` or ``(h, c)`` for LSTM) in the
                layout this RNN expects.

        Returns:
            ``(logits, hidden)``, where ``logits`` is ``(batch, vocab_size)``
            and ``hidden`` is the updated RNN state.
        """
        token = token.unsqueeze(1)                # (batch,) -> (batch, 1)
        embedded = self.embedding(token)
        output, hidden = self.rnn(embedded, hidden)
        output = self.output_layer(output.squeeze(1))  # drop the time axis
        return output, hidden

class TransliterationModel(nn.Module):
    """Seq2seq transliterator: an ``InputEncoder`` whose final state seeds an ``OutputDecoder``.

    The decoder is always unidirectional: it generates one token at a time,
    so a backward direction would have nothing to read. When
    ``is_bidirectional`` is set, only the encoder is bidirectional. Its
    forward and backward final states are concatenated per layer, so the
    decoder's hidden size is doubled to match.
    """

    def __init__(self, input_vocab_size, output_vocab_size, embedding_size, hidden_size, enc_layers, dec_layers,
                 rnn_type='LSTM', dropout_rate=0.2, is_bidirectional=False):
        super().__init__()
        self.encoder = InputEncoder(input_vocab_size, embedding_size, hidden_size, enc_layers, rnn_type, dropout_rate, is_bidirectional)
        num_directions = 2 if is_bidirectional else 1
        self.decoder = OutputDecoder(output_vocab_size, embedding_size, hidden_size * num_directions, dec_layers,
                                     rnn_type, dropout_rate, False)
        self.rnn_type = rnn_type

    def _bridge_state(self, state):
        """Convert one encoder state tensor into the decoder's ``(dec_layers, batch, hidden)`` layout.

        A bidirectional encoder state ``(enc_layers * 2, B, H)`` is merged
        into ``(enc_layers, B, 2H)`` by concatenating each layer's forward
        and backward halves. The layer count is then matched to the decoder:
        - if the encoder has at least as many layers, the lowest
          ``dec_layers`` layers are kept;
        - otherwise the top encoder layer is repeated to fill the gap.
        """
        if self.encoder.is_bidirectional:
            enc_layers = self.encoder.rnn.num_layers
            _, batch, hidden = state.shape
            # PyTorch lays h_n out as (num_layers, num_directions, ...) flattened.
            state = state.reshape(enc_layers, 2, batch, hidden)
            state = torch.cat([state[:, 0], state[:, 1]], dim=2)
        dec_layers = self.decoder.rnn.num_layers
        if state.size(0) >= dec_layers:
            return state[:dec_layers].contiguous()
        filler = state[-1:].expand(dec_layers - state.size(0), -1, -1)
        return torch.cat([state, filler], dim=0)

    def forward(self, source, target, teacher_forcing_prob=0.5):
        """Decode ``target`` step by step, conditioned on ``source``.

        Args:
            source: source token ids, ``(batch, src_len)``.
            target: target token ids, ``(batch, tgt_len)``. Column 0 is the
                first decoder input.
            teacher_forcing_prob: per-step probability of feeding the gold
                token instead of the model's own argmax prediction. Use 0.0
                for free-running evaluation.

        Returns:
            Logits of shape ``(batch, tgt_len, vocab)``. Position 0 is never
            predicted and stays all-zero.
        """
        batch_size, target_len = target.size()
        output_vocab_size = self.decoder.output_layer.out_features
        predictions = torch.zeros(batch_size, target_len, output_vocab_size, device=source.device)

        encoder_hidden = self.encoder(source)
        if self.rnn_type == 'LSTM':
            h, c = encoder_hidden
            decoder_hidden = (self._bridge_state(h), self._bridge_state(c))
        else:
            decoder_hidden = self._bridge_state(encoder_hidden)

        decoder_input = target[:, 0]
        for t in range(1, target_len):
            output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            predictions[:, t] = output
            top1 = output.argmax(1)
            # One coin flip per step, shared by the whole batch.
            decoder_input = target[:, t] if random.random() < teacher_forcing_prob else top1

        return predictions

def build_vocab_and_prepare_batch(seqs, device):
    """Build character vocabularies and a batch-collation function.

    Args:
        seqs: either a list of ``(source, target)`` string pairs, or a list
            of plain strings.
            - Pairs: separate source and target vocabularies are built from
              the characters on each side.
            - Plain strings: both vocabularies are the same character set.
        device: torch device that collated batches are moved to.

    Returns:
        ``((src_vocab, idx2src), (tgt_vocab, idx2tgt), create_batch)``.
        - Each vocab maps a character to an index. Indices 0-3 are reserved
          for ``<pad>``, ``<sos>``, ``<eos>`` and ``<unk>``; characters are
          numbered from 4 in sorted order.
        - The ``idx2*`` dicts invert the vocabs.
        - ``create_batch(pairs)`` encodes ``(source, target)`` string pairs
          into right-padded long tensors ``(src, tgt)`` on ``device``.
          Sources get a trailing ``<eos>``; targets are wrapped in
          ``<sos>`` ... ``<eos>``. Unknown characters map to ``<unk>``.
    """
    special_tokens = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
    seqs = list(seqs)

    def make_vocab(texts):
        # Iterate each text character by character; special tokens win on clashes.
        chars = sorted({ch for text in texts for ch in text})
        vocab = {ch: idx + len(special_tokens) for idx, ch in enumerate(chars)}
        vocab.update(special_tokens)
        return vocab

    # Iterating a (source, target) tuple would yield whole words, not
    # characters, so each side of a pair is vocabularised separately.
    is_pairs = bool(seqs) and all(isinstance(s, (tuple, list)) and len(s) == 2 for s in seqs)
    if is_pairs:
        src_vocab = make_vocab(src for src, _ in seqs)
        tgt_vocab = make_vocab(tgt for _, tgt in seqs)
    else:
        src_vocab = make_vocab(seqs)
        tgt_vocab = make_vocab(seqs)

    idx2src = {idx: ch for ch, idx in src_vocab.items()}
    idx2tgt = {idx: ch for ch, idx in tgt_vocab.items()}

    def encode_text(seq, vocab):
        return [vocab.get(ch, vocab['<unk>']) for ch in seq]

    def create_batch(pairs):
        src = [torch.tensor(encode_text(x, src_vocab) + [src_vocab['<eos>']]) for x, _ in pairs]
        tgt = [torch.tensor([tgt_vocab['<sos>']] + encode_text(y, tgt_vocab) + [tgt_vocab['<eos>']]) for _, y in pairs]
        src = pad_sequence(src, batch_first=True, padding_value=src_vocab['<pad>'])
        tgt = pad_sequence(tgt, batch_first=True, padding_value=tgt_vocab['<pad>'])
        return src.to(device), tgt.to(device)

    return (src_vocab, idx2src), (tgt_vocab, idx2tgt), create_batch

def read_pairs(file_path):
    """Load tab-separated lexicon lines as ``(second_field, first_field)`` pairs.

    The whole file is stripped of leading/trailing whitespace and split on
    newlines. Lines without a tab are skipped. For each remaining line the
    second tab-separated column becomes the first element of the pair and
    the first column the second element. Any further columns are ignored.
    """
    with open(file_path, encoding='utf-8') as handle:
        raw_lines = handle.read().strip().split('\n')

    pairs = []
    for raw in raw_lines:
        if '\t' not in raw:
            continue
        fields = raw.split('\t')
        pairs.append((fields[1], fields[0]))
    return pairs

# ---------- Training and Evaluation ----------
def _make_optimizer(name, params, lr):
    """Build the torch optimizer selected by the ``optimizer`` config value."""
    factories = {
        'adam': optim.Adam,
        'nadam': optim.NAdam,
        'rmsprop': optim.RMSprop,
        'sgd': optim.SGD,
    }
    try:
        factory = factories[str(name).lower()]
    except KeyError:
        raise ValueError(f"Unsupported optimizer {name!r}; expected one of {sorted(factories)}") from None
    return factory(params, lr=lr)


def _run_epoch(model, pairs, create_batch, criterion, batch_size, pad_idx,
               teacher_forcing_prob, optimizer=None):
    """Make one pass over ``pairs``; train if ``optimizer`` is given, else just evaluate.

    Returns ``(mean_batch_loss, token_accuracy_percent)``.
    - Loss is averaged over batches.
    - Accuracy is pooled over all non-pad target positions from index 1 on.
      Position 0 holds ``<sos>``, which the model never predicts.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, num_batches = 0.0, 0
    correct_tokens, total_tokens = 0, 0
    with torch.set_grad_enabled(training):
        for start in range(0, len(pairs), batch_size):
            src, tgt = create_batch(pairs[start:start + batch_size])
            outputs = model(src, tgt, teacher_forcing_prob)
            logits, gold = outputs[:, 1:], tgt[:, 1:]
            loss = criterion(logits.reshape(-1, logits.size(-1)), gold.reshape(-1))
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            mask = gold != pad_idx
            correct_tokens += ((logits.argmax(-1) == gold) & mask).sum().item()
            total_tokens += mask.sum().item()
            total_loss += loss.item()
            num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = 100.0 * correct_tokens / total_tokens if total_tokens else 0.0
    return mean_loss, accuracy


def run_training(
    train_path="/kaggle/input/dakshina-data/dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.train.tsv",
    dev_path="/kaggle/input/dakshina-data/dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.dev.tsv",
):
    """Train a ``TransliterationModel`` and log per-epoch metrics to W&B.

    Hyper-parameters come from ``wandb.config``; the dict below supplies the
    defaults, which a sweep may override.

    Each epoch:
    - trains with the configured teacher-forcing probability;
    - evaluates on the dev set free-running (no teacher forcing), so
      validation metrics reflect real decoding;
    - logs one W&B record and prints a summary line.

    The W&B run is always finished on exit, including on error.
    """
    wandb.init(config={
        "embedding_size": 128,
        "hidden_size": 256,
        "enc_layers": 2,
        "dec_layers": 2,
        "rnn_type": "LSTM",
        "dropout_rate": 0.2,
        "epochs": 10,
        "batch_size": 64,
        "is_bidirectional": False,
        "learning_rate": 0.001,
        "optimizer": "adam",
        "teacher_forcing_prob": 0.5
    })
    try:
        cfg = wandb.config
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        train_set = read_pairs(train_path)
        dev_set = read_pairs(dev_path)
        (src_vocab, _), (tgt_vocab, _), create_batch = build_vocab_and_prepare_batch(train_set, device)

        model = TransliterationModel(len(src_vocab), len(tgt_vocab), cfg.embedding_size, cfg.hidden_size,
                                     cfg.enc_layers, cfg.dec_layers, cfg.rnn_type, cfg.dropout_rate,
                                     cfg.is_bidirectional).to(device)
        optimizer = _make_optimizer(cfg.optimizer, model.parameters(), cfg.learning_rate)
        # Loss and accuracy are computed on target tokens, so use the target pad id.
        pad_idx = tgt_vocab['<pad>']
        criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

        for epoch in range(cfg.epochs):
            random.shuffle(train_set)
            avg_train_loss, avg_train_acc = _run_epoch(
                model, train_set, create_batch, criterion, cfg.batch_size, pad_idx,
                cfg.teacher_forcing_prob, optimizer)
            avg_dev_loss, avg_dev_acc = _run_epoch(
                model, dev_set, create_batch, criterion, cfg.batch_size, pad_idx, 0.0)

            wandb.log({
                "Train Loss": avg_train_loss,
                "Train Accuracy": avg_train_acc,
                "Validation Loss": avg_dev_loss,
                "Validation Accuracy": avg_dev_acc,
                "Epoch": epoch + 1
            })
            print(f"Epoch {epoch + 1}/{cfg.epochs} | Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.2f}% | Val Loss: {avg_dev_loss:.4f}, Val Acc: {avg_dev_acc:.2f}%")
    finally:
        wandb.finish()

# NOTE(review): this executes when the cell runs, not after training.
# run_training is only defined in this cell, never called, so this call
# cannot close the run that run_training opens with wandb.init. Presumably
# it closes a run left open by an earlier cell -- confirm intent.
wandb.finish()



