In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import itertools
import os
from pathlib import Path
import wandb
from tabulate import tabulate
import uuid

# Seed PyTorch, NumPy and Python's `random` with the same fixed value (42)
# so that weight initialisation, shuffling and teacher-forcing draws repeat
# from run to run.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(42)
del _seed_rng

# Character-level dataset class
class TransliterationDataset(Dataset):
    """Character-level transliteration pairs read from a headerless TSV file.

    Column 1 of the file supplies the source strings and column 0 the target
    strings. Each item is returned as index tensors plus the raw strings.

    Args:
        data_file: Path (or file-like object) of a tab-separated file without
            a header row.
        max_len: Stored on the instance as ``max_len``; not used to truncate
            items here.
        vocab: Optional ``(src_vocab, tgt_vocab)`` pair of char->index dicts.
            When omitted, both vocabularies are built from this file.
            Each vocabulary must contain ``'<pad>'``; the target vocabulary
            must also contain ``'<sos>'`` and ``'<eos>'``.
    """

    def __init__(self, data_file, max_len=500, vocab=None):
        # dtype=str + keep_default_na=False: with pandas' defaults, fields such
        # as "null", "NA" or "nan" are parsed as missing values and the string
        # conversion below would silently turn every one of them into "nan".
        self.df = pd.read_csv(
            data_file,
            sep='\t',
            header=None,
            dtype=str,
            keep_default_na=False,
        )
        # astype(str) is kept as a safety net for short rows, whose missing
        # trailing columns are still filled with NaN.
        self.src_texts = self.df[1].astype(str).tolist()
        self.tgt_texts = self.df[0].astype(str).tolist()
        self.max_len = max_len
        if vocab is None:
            self.src_vocab = self.build_vocab(self.src_texts)
            self.tgt_vocab = self.build_vocab(self.tgt_texts)
        else:
            self.src_vocab, self.tgt_vocab = vocab
        self.src_vocab_size = len(self.src_vocab)
        self.tgt_vocab_size = len(self.tgt_vocab)
        self.src_pad_idx = self.src_vocab['<pad>']
        self.tgt_pad_idx = self.tgt_vocab['<pad>']
        self.tgt_sos_idx = self.tgt_vocab['<sos>']
        self.tgt_eos_idx = self.tgt_vocab['<eos>']
        # Reverse lookup used to turn predicted indices back into characters.
        self.tgt_inv_vocab = {i: char for char, i in self.tgt_vocab.items()}

    def build_vocab(self, texts):
        """Return a char->index dict for every character seen in ``texts``.

        Indices 0, 1 and 2 are reserved for ``<pad>``, ``<sos>`` and
        ``<eos>``; the remaining characters are numbered from 3 in sorted
        order, so the mapping is deterministic for a given character set.
        """
        chars = set()
        for text in texts:
            chars.update(text)
        vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
        vocab.update({char: i + 3 for i, char in enumerate(sorted(chars))})
        return vocab

    def __len__(self):
        """Return the number of (source, target) pairs."""
        return len(self.src_texts)

    def __getitem__(self, idx):
        """Return ``(src_indices, tgt_indices, src_text, tgt_text)`` for one pair.

        The target indices are wrapped in ``<sos>`` ... ``<eos>``; the source
        indices are not. Characters missing from a vocabulary fall back to
        that vocabulary's ``<pad>`` index.
        """
        src = self.src_texts[idx]
        tgt = self.tgt_texts[idx]
        src_unknown = self.src_vocab['<pad>']
        tgt_unknown = self.tgt_vocab['<pad>']
        src_indices = [self.src_vocab.get(char, src_unknown) for char in src]
        tgt_indices = (
            [self.tgt_sos_idx]
            + [self.tgt_vocab.get(char, tgt_unknown) for char in tgt]
            + [self.tgt_eos_idx]
        )
        # Explicit dtype: torch.tensor([]) would otherwise be float32, which
        # is not a valid index tensor for an empty source string.
        return (
            torch.tensor(src_indices, dtype=torch.long),
            torch.tensor(tgt_indices, dtype=torch.long),
            src,
            tgt,
        )

# Encoder class
class Encoder(nn.Module):
    """Embeds a source index sequence and returns the RNN's final state.

    ``cell_type`` selects ``nn.RNN``, ``nn.LSTM`` or ``nn.GRU`` (any other
    value raises ``KeyError``). Inter-layer RNN dropout is only enabled when
    ``n_layers > 1``; the embedding output always goes through ``dropout``.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, cell_type, dropout, bidirectional=False):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.cell_type = cell_type
        self.bidirectional = bidirectional
        self.embedding = nn.Embedding(input_dim, emb_dim)
        recurrent_cls = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[cell_type]
        inter_layer_dropout = dropout if n_layers > 1 else 0
        self.rnn = recurrent_cls(
            emb_dim,
            hid_dim,
            num_layers=n_layers,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
        )
        self.dropout = nn.Dropout(dropout)

    def _merge_directions(self, state):
        """Concatenate forward/backward states along the feature axis.

        For a bidirectional encoder the state is viewed as
        ``[n_layers, 2, batch, hid_dim]`` and returned as
        ``[n_layers, batch, hid_dim * 2]``; otherwise it is returned as is.
        """
        if not self.bidirectional:
            return state
        per_direction = state.view(self.n_layers, 2, -1, self.hid_dim)
        forward_state = per_direction[:, 0]
        backward_state = per_direction[:, 1]
        return torch.cat((forward_state, backward_state), dim=-1)

    def forward(self, src):
        """Return ``(hidden, cell)`` for LSTM, or ``hidden`` alone otherwise.

        The per-step RNN outputs are discarded; only the final state is used.
        """
        embedded = self.dropout(self.embedding(src))
        _, final_state = self.rnn(embedded)
        if self.cell_type == 'LSTM':
            hidden, cell = final_state
            return self._merge_directions(hidden), self._merge_directions(cell)
        return self._merge_directions(final_state)

# Decoder class
class Decoder(nn.Module):
    """Unidirectional RNN decoder that advances one time step per call.

    ``cell_type`` selects ``nn.RNN``, ``nn.LSTM`` or ``nn.GRU`` (any other
    value raises ``KeyError``). The ``bidirectional`` argument is accepted
    but not used by this class.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, cell_type, dropout, bidirectional=False):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.cell_type = cell_type
        self.embedding = nn.Embedding(output_dim, emb_dim)
        recurrent_cls = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[cell_type]
        # Inter-layer dropout is only meaningful with more than one layer.
        inter_layer_dropout = dropout if n_layers > 1 else 0
        self.rnn = recurrent_cls(emb_dim, hid_dim, num_layers=n_layers, dropout=inter_layer_dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell=None):
        """Run a single decoding step.

        ``input`` gets a leading length-1 time axis before embedding.
        Returns ``(prediction, hidden, cell)``, where ``prediction`` is the
        output of ``fc_out`` for this step. For non-LSTM cells ``cell`` is
        passed through unchanged.
        """
        step_tokens = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(step_tokens))
        if self.cell_type != 'LSTM':
            output, hidden = self.rnn(embedded, hidden)
            return self.fc_out(output.squeeze(0)), hidden, cell
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        return self.fc_out(output.squeeze(0)), hidden, cell

# Seq2Seq class
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper with teacher-forced training and greedy decoding.

    The encoder's final state is handed to the decoder as its initial state,
    so the two must be shape-compatible. When the encoder is bidirectional,
    its concatenated ``hid_dim * 2`` state is first projected down to
    ``decoder.hid_dim`` by a linear layer (one for the hidden state, plus one
    for the cell state when the encoder is an LSTM).
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.bidirectional = encoder.bidirectional
        if self.bidirectional:
            self.fc_hidden = nn.Linear(encoder.hid_dim * 2, decoder.hid_dim)
            if encoder.cell_type == 'LSTM':
                self.fc_cell = nn.Linear(encoder.hid_dim * 2, decoder.hid_dim)

    def _encode(self, src):
        """Run the encoder and return ``(hidden, cell)`` ready for the decoder.

        ``cell`` is ``None`` for non-LSTM encoders. Bidirectional states are
        passed through the projection layers created in ``__init__``.
        """
        if self.encoder.cell_type == 'LSTM':
            hidden, cell = self.encoder(src)
            if self.bidirectional:
                hidden = self.fc_hidden(hidden)
                cell = self.fc_cell(cell)
            return hidden, cell
        hidden = self.encoder(src)
        if self.bidirectional:
            hidden = self.fc_hidden(hidden)
        return hidden, None

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg`` step by step and return the per-step scores.

        ``src`` is indexed as ``[src_len, batch]`` and ``trg`` as
        ``[trg_len, batch]``. The result has shape
        ``[trg_len, batch, decoder.output_dim]``; row 0 is left as zeros
        because decoding starts from ``trg[0]``. At every step the next input
        is the ground-truth token with probability ``teacher_forcing_ratio``
        and the decoder's own argmax otherwise (one draw per step, shared by
        the whole batch).
        """
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        outputs = torch.zeros(trg_len, batch_size, self.decoder.output_dim, device=self.device)
        hidden, cell = self._encode(src)
        decoder_input = trg[0, :]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(decoder_input, hidden, cell)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            decoder_input = trg[t] if teacher_force else output.argmax(1)
        return outputs

    def predict(self, src, max_len=50, sos_idx=1, eos_idx=2):
        """Greedily decode ``src`` and return token indices ``[steps, batch]``.

        Puts the module in eval mode and runs without gradient tracking.
        Row 0 of the result is ``sos_idx``. Each sequence is tracked
        individually: once it has produced ``eos_idx`` every later position
        is filled with ``eos_idx`` as well, and decoding stops as soon as all
        sequences have finished (or after ``max_len`` rows). A 1-D ``src`` is
        treated as a single sequence.
        """
        self.eval()
        src = src.to(self.device)
        if src.dim() == 1:
            # Single unbatched sequence: add the batch axis the decoder expects.
            src = src.unsqueeze(1)
        batch_size = src.shape[1]
        outputs = torch.zeros(max_len, batch_size, dtype=torch.long, device=self.device)
        outputs[0] = sos_idx
        finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
        with torch.no_grad():
            hidden, cell = self._encode(src)
            decoder_input = torch.full((batch_size,), sos_idx, dtype=torch.long, device=self.device)
            for t in range(1, max_len):
                output, hidden, cell = self.decoder(decoder_input, hidden, cell)
                # Sequences that already ended keep emitting EOS instead of
                # whatever the decoder produces after its end-of-sequence.
                top1 = output.argmax(1).masked_fill(finished, eos_idx)
                outputs[t] = top1
                finished = finished | (top1 == eos_idx)
                decoder_input = top1
                if bool(finished.all()):
                    outputs = outputs[:t + 1]
                    break
        return outputs

# Accuracy metrics
def _chars_until_eos(indices, dataset):
    """Map indices to target characters, stopping at the first <eos>.

    <sos> tokens are dropped. Everything from the first <eos> onwards is
    ignored, which removes batch padding from targets and any tokens a
    decoder keeps emitting after it has ended a prediction.
    """
    chars = []
    for idx in indices:
        if idx == dataset.tgt_eos_idx:
            break
        if idx != dataset.tgt_sos_idx:
            chars.append(dataset.tgt_inv_vocab.get(idx, '<unk>'))
    return chars


def calculate_accuracies(model, iterator, device, dataset):
    """Return ``(char_acc, word_acc)`` of greedy predictions over ``iterator``.

    Args:
        model: Object providing ``eval()`` and
            ``predict(src, max_len=..., sos_idx=..., eos_idx=...)`` returning
            indices laid out as ``[steps, batch]``.
        iterator: Yields ``(src, trg, src_text, tgt_text)`` batches with
            ``trg`` laid out as ``[trg_len, batch]``.
        device: Device the batch tensors are moved to.
        dataset: Supplies ``max_len``, ``tgt_sos_idx``, ``tgt_eos_idx`` and
            ``tgt_inv_vocab``.

    Character accuracy compares prediction and target position by position
    and divides by the number of target characters, so missing characters
    count as errors. Word accuracy requires an exact string match. Both are
    0 when there is nothing to score.
    """
    model.eval()
    char_correct = 0
    char_total = 0
    word_correct = 0
    word_total = 0
    with torch.no_grad():
        for src, trg, _src_text, _tgt_text in iterator:
            src, trg = src.to(device), trg.to(device)
            output = model.predict(
                src,
                max_len=dataset.max_len,
                sos_idx=dataset.tgt_sos_idx,
                eos_idx=dataset.tgt_eos_idx,
            )
            # Treat an unbatched sequence as a batch of one.
            if trg.dim() == 1:
                trg = trg.unsqueeze(1)
            if output.dim() == 1:
                output = output.unsqueeze(1)
            # Transpose so that each row is one sequence of plain Python ints.
            for pred_seq, trg_seq in zip(output.t().tolist(), trg.t().tolist()):
                pred_chars = _chars_until_eos(pred_seq, dataset)
                trg_chars = _chars_until_eos(trg_seq, dataset)
                char_correct += sum(p == t for p, t in zip(pred_chars, trg_chars))
                char_total += len(trg_chars)
                word_correct += int(''.join(pred_chars) == ''.join(trg_chars))
                word_total += 1
    char_acc = char_correct / char_total if char_total > 0 else 0
    word_acc = word_correct / word_total if word_total > 0 else 0
    return char_acc, word_acc

# Training function
def train(model, iterator, optimizer, criterion, clip, device, dataset, epoch, max_epochs=20):
    """Run one optimisation pass over ``iterator`` and score the model on it.

    The teacher-forcing ratio starts at 1.0 for epoch 0 and decreases
    linearly with ``epoch / max_epochs``, never dropping below 0.5. The first
    time step of both the model output and the target is skipped when the
    loss is computed. Gradients are clipped to a total norm of ``clip``.

    Returns:
        ``(mean_batch_loss, char_acc, word_acc)``; the accuracies come from
        ``calculate_accuracies`` run over the same iterator after the pass.
    """
    model.train()
    decay = (epoch / max_epochs) * 0.5
    teacher_forcing_ratio = max(0.5, 1.0 - decay)
    batch_losses = []
    for src, trg, _src_text, _tgt_text in iterator:
        src = src.to(device)
        trg = trg.to(device)
        optimizer.zero_grad()
        logits = model(src, trg, teacher_forcing_ratio)
        vocab_size = logits.shape[-1]
        # Drop step 0 and flatten the remaining steps and batch items
        # into one axis for the criterion.
        flat_logits = logits[1:].view(-1, vocab_size)
        flat_targets = trg[1:].view(-1)
        loss = criterion(flat_logits, flat_targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        batch_losses.append(loss.item())
    char_acc, word_acc = calculate_accuracies(model, iterator, device, dataset)
    mean_loss = sum(batch_losses) / len(iterator)
    return mean_loss, char_acc, word_acc

# Evaluation function
def evaluate(model, iterator, criterion, device, dataset):
    """Compute the mean loss and the accuracies of ``model`` on ``iterator``.

    The model is run in eval mode without gradient tracking and with a
    teacher-forcing ratio of 0, so every decoder input after the first is
    the model's own previous prediction. The first time step is skipped
    when the loss is computed.

    Returns:
        ``(mean_batch_loss, char_acc, word_acc)``; the accuracies come from
        ``calculate_accuracies`` run over the same iterator.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for src, trg, _src_text, _tgt_text in iterator:
            src = src.to(device)
            trg = trg.to(device)
            logits = model(src, trg, 0)
            vocab_size = logits.shape[-1]
            flat_logits = logits[1:].view(-1, vocab_size)
            flat_targets = trg[1:].view(-1)
            batch_losses.append(criterion(flat_logits, flat_targets).item())
    char_acc, word_acc = calculate_accuracies(model, iterator, device, dataset)
    mean_loss = sum(batch_losses) / len(iterator)
    return mean_loss, char_acc, word_acc

# Custom collate function
def collate_fn(batch):
    """Merge ``(src_tensor, tgt_tensor, src_text, tgt_text)`` items into a batch.

    Both tensor columns are padded with 0 to the longest sequence in the
    batch and stacked time-major (``[max_len, batch]``). The raw strings are
    returned as lists in the same order.
    """
    src_tensors, tgt_tensors, src_texts, tgt_texts = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    padded = [
        pad(column, batch_first=False, padding_value=0)
        for column in (src_tensors, tgt_tensors)
    ]
    return padded[0], padded[1], list(src_texts), list(tgt_texts)

# Training function for WandB sweep
def _indices_to_word(indices, dataset):
    """Join predicted indices into a string, stopping at the first <eos>.

    <sos> tokens are dropped; anything emitted after the first <eos> is not
    part of the prediction and is ignored.
    """
    chars = []
    for idx in indices:
        if idx == dataset.tgt_eos_idx:
            break
        if idx != dataset.tgt_sos_idx:
            chars.append(dataset.tgt_inv_vocab.get(idx, '<unk>'))
    return ''.join(chars)


def _collect_sample_predictions(model, loader, dataset, device, n_batches=5, max_len=50):
    """Greedy-decode the first ``n_batches`` batches and return one row per word."""
    model.eval()
    samples = []
    with torch.no_grad():
        for src, _trg, src_text, tgt_text in itertools.islice(loader, n_batches):
            src = src.to(device)
            output = model.predict(src, max_len=max_len, sos_idx=dataset.tgt_sos_idx,
                                   eos_idx=dataset.tgt_eos_idx)
            for i, (latin_word, tgt_word) in enumerate(zip(src_text, tgt_text)):
                pred_word = _indices_to_word(output[:, i].tolist(), dataset)
                samples.append({
                    'Latin Input': latin_word,
                    'Devanagari Target': tgt_word,
                    'Devanagari Predicted': pred_word,
                    'Correct': pred_word == tgt_word
                })
    return samples


def train_sweep(base_path='/home/user/Downloads/dakshina_dataset_v1.0/hi/lexicons'):
    """Train and evaluate one sweep configuration, logging everything to W&B.

    Reads the hyperparameters from ``wandb.config``, trains for a fixed 50
    epochs, checkpoints the model with the best validation word accuracy,
    then reloads that checkpoint to log test metrics and a table of sample
    predictions. On ``KeyboardInterrupt`` the current weights are saved and
    logged as an artifact before the process exits with status 0.

    Args:
        base_path: Directory containing the ``hi.translit.sampled.*.tsv``
            train/dev/test files. The default keeps the previous hard-coded
            location, so the function can still be called without arguments.
    """
    wandb.init()
    hparams = wandb.config

    train_path = f'{base_path}/hi.translit.sampled.train.tsv'
    dev_path = f'{base_path}/hi.translit.sampled.dev.tsv'
    test_path = f'{base_path}/hi.translit.sampled.test.tsv'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build shared vocabulary from all three splits. The texts are taken from
    # the dataset objects themselves so that the vocabulary always sees
    # exactly the strings the datasets will later encode.
    raw_splits = [TransliterationDataset(path) for path in (train_path, dev_path, test_path)]
    all_src_texts = [text for split in raw_splits for text in split.src_texts]
    all_tgt_texts = [text for split in raw_splits for text in split.tgt_texts]
    vocab = (raw_splits[0].build_vocab(all_src_texts), raw_splits[0].build_vocab(all_tgt_texts))

    # Greedy decoding during accuracy computation is bounded by the dataset's
    # max_len. No target is longer than the longest one seen here, so
    # <sos> + longest target + <eos> is enough and avoids decoding up to the
    # default of 500 steps per batch.
    max_decode_len = max((len(text) for text in all_tgt_texts), default=0) + 2

    # Initialize datasets
    train_dataset = TransliterationDataset(train_path, max_len=max_decode_len, vocab=vocab)
    dev_dataset = TransliterationDataset(dev_path, max_len=max_decode_len, vocab=vocab)
    test_dataset = TransliterationDataset(test_path, max_len=max_decode_len, vocab=vocab)

    train_loader = DataLoader(train_dataset, batch_size=hparams.batch_size, shuffle=True,
                              collate_fn=collate_fn)
    dev_loader = DataLoader(dev_dataset, batch_size=hparams.batch_size,
                            collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=hparams.batch_size,
                             collate_fn=collate_fn)

    # Honour the swept `bidirectional` value instead of silently training a
    # unidirectional encoder; configurations that do not define it keep the
    # previous behaviour (False).
    bidirectional = bool(hparams.get('bidirectional', False))

    encoder = Encoder(
        input_dim=train_dataset.src_vocab_size,
        emb_dim=hparams.emb_dim,
        hid_dim=hparams.hid_dim,
        n_layers=hparams.enc_layers,
        cell_type=hparams.cell_type,
        dropout=hparams.dropout,
        bidirectional=bidirectional
    )
    decoder = Decoder(
        output_dim=train_dataset.tgt_vocab_size,
        emb_dim=hparams.emb_dim,
        hid_dim=hparams.hid_dim,
        n_layers=hparams.dec_layers,
        cell_type=hparams.cell_type,
        dropout=hparams.dropout,
        bidirectional=False
    )
    model = Seq2Seq(encoder, decoder, device).to(device)

    optimizer_class = {'Adam': optim.Adam, 'RMSprop': optim.RMSprop, 'AdamW': optim.AdamW}[hparams.optimizer]
    optimizer = optimizer_class(model.parameters(), lr=hparams.learning_rate, weight_decay=hparams.weight_decay)
    criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.tgt_pad_idx)

    n_epochs = 50
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise a run whose word accuracy stays at 0 would fail
    # when the "best" checkpoint is loaded after training.
    best_valid_word_acc = float('-inf')
    best_model_path = f'best_model_{wandb.run.id}.pt'

    try:
        for epoch in range(n_epochs):
            train_loss, train_char_acc, train_word_acc = train(
                model, train_loader, optimizer, criterion, hparams.grad_clip,
                device, train_dataset, epoch, n_epochs)
            valid_loss, valid_char_acc, valid_word_acc = evaluate(
                model, dev_loader, criterion, device, dev_dataset)
            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_char_acc': train_char_acc,
                'train_word_acc': train_word_acc,
                'valid_loss': valid_loss,
                'valid_char_acc': valid_char_acc,
                'valid_word_acc': valid_word_acc
            })
            print(f'Epoch: {epoch+1:02}')
            print(f'\tTrain Loss: {train_loss:.3f} | Char Acc: {train_char_acc:.3f} | Word Acc: {train_word_acc:.3f}')
            print(f'\tVal. Loss: {valid_loss:.3f} | Char Acc: {valid_char_acc:.3f} | Word Acc: {valid_word_acc:.3f}')
            if valid_word_acc > best_valid_word_acc:
                best_valid_word_acc = valid_word_acc
                torch.save(model.state_dict(), best_model_path)
                artifact = wandb.Artifact(f'model_{wandb.run.id}', type='model')
                artifact.add_file(best_model_path)
                wandb.log_artifact(artifact)

        # map_location keeps the reload working regardless of which device
        # the checkpoint tensors were saved from.
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_loss, test_char_acc, test_word_acc = evaluate(model, test_loader, criterion, device, test_dataset)
        wandb.log({
            'test_loss': test_loss,
            'test_char_acc': test_char_acc,
            'test_word_acc': test_word_acc
        })

        samples = _collect_sample_predictions(model, test_loader, test_dataset, device)

        headers = ['#', 'Latin Input', 'Devanagari Target', 'Devanagari Predicted', 'Correct']
        table_data = [[i+1, s['Latin Input'], s['Devanagari Target'], s['Devanagari Predicted'],
                       '✅' if s['Correct'] else '❌'] for i, s in enumerate(samples)]
        wandb.log({'sample_predictions': wandb.Table(columns=headers, data=table_data)})

    except KeyboardInterrupt:
        print(f"Training interrupted. Saving current model state...")
        model_path = f'last_model_{wandb.run.id}.pt'
        torch.save(model.state_dict(), model_path)
        artifact = wandb.Artifact(f'last_model_{wandb.run.id}', type='model')
        artifact.add_file(model_path)
        wandb.log_artifact(artifact)
        wandb.finish()
        # Same effect as exit(0), without depending on the `site` builtin.
        raise SystemExit(0)

    wandb.finish()
# Sweep configuration
# Grid search maximising `valid_word_acc`. Each entry below lists the values
# tried for one hyperparameter; the multi-valued entries (3 cell types,
# 2 dropouts, 2 learning rates, 2 batch sizes, 2 optimizers) give
# 48 combinations in total.
_SWEEP_GRID = (
    ('emb_dim', [128]),
    ('hid_dim', [256]),
    ('enc_layers', [1]),
    ('dec_layers', [1]),
    ('cell_type', ['GRU', 'LSTM', 'RNN']),
    ('dropout', [0.2, 0.3]),
    ('learning_rate', [1e-3, 5e-4]),
    ('batch_size', [16, 32]),
    ('optimizer', ['Adam', 'AdamW']),
    ('grad_clip', [5.0]),
    ('weight_decay', [1e-6]),
    ('bidirectional', [True]),
)
sweep_config = {
    'method': 'grid',
    'metric': {'name': 'valid_word_acc', 'goal': 'maximize'},
    'parameters': {name: {'values': options} for name, options in _SWEEP_GRID},
}

# Main execution
def main():
    """Register the sweep with W&B and run a local agent for up to 100 runs."""
    project_name = "transliteration-seq2seq"
    sweep_id = wandb.sweep(sweep_config, project=project_name)
    wandb.agent(sweep_id, function=train_sweep, count=100)


# Only start the sweep when executed as a script, not when imported.
if __name__ == "__main__":
    main()