# In [None]:  (notebook cell marker from the export, commented out so the file parses as Python)
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import itertools
import os
from pathlib import Path
import wandb
from tabulate import tabulate
import matplotlib.pyplot as plt
import seaborn as sns

# Seed PyTorch, NumPy and the stdlib RNG with the same value so runs are repeatable.
_SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(_SEED)

# Character-level dataset class
class TransliterationDataset(Dataset):
    """Character-level parallel dataset read from a headerless tab-separated file.

    Column 1 of the file is used as the source string and column 0 as the
    target string. A separate character vocabulary is built for each side,
    with ``<pad>``/``<sos>``/``<eos>`` fixed at indices 0/1/2.

    Args:
        data_file: path (or buffer) accepted by ``pandas.read_csv``.
        max_len: stored on the instance as ``self.max_len``; it is not applied
            anywhere inside this class.
    """

    def __init__(self, data_file, max_len=500):
        # dtype=str + na_filter=False keep every cell as literal text. With
        # pandas' defaults, entries such as "null", "NA" or an empty field are
        # parsed as missing values and would come back as the string "nan"
        # after astype(str), silently corrupting those words.
        self.df = pd.read_csv(data_file, sep='\t', header=None, dtype=str, na_filter=False)
        self.src_texts = self.df[1].astype(str).tolist()
        self.tgt_texts = self.df[0].astype(str).tolist()
        self.max_len = max_len
        self.src_vocab = self.build_vocab(self.src_texts)
        self.tgt_vocab = self.build_vocab(self.tgt_texts)
        self.src_vocab_size = len(self.src_vocab)
        self.tgt_vocab_size = len(self.tgt_vocab)
        self.src_pad_idx = self.src_vocab['<pad>']
        # Source-side <sos>/<eos> indices are exposed for symmetry with the
        # target side; __getitem__ does not add them to source sequences.
        self.src_sos_idx = self.src_vocab['<sos>']
        self.src_eos_idx = self.src_vocab['<eos>']
        self.tgt_pad_idx = self.tgt_vocab['<pad>']
        self.tgt_sos_idx = self.tgt_vocab['<sos>']
        self.tgt_eos_idx = self.tgt_vocab['<eos>']
        self.tgt_inv_vocab = {i: char for char, i in self.tgt_vocab.items()}
        self.src_inv_vocab = {i: char for char, i in self.src_vocab.items()}

    def build_vocab(self, texts):
        """Return a ``char -> index`` dict for every character seen in ``texts``.

        Indices 0, 1, 2 are reserved for ``<pad>``, ``<sos>``, ``<eos>``; the
        remaining characters are numbered from 3 in sorted order.
        """
        vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
        seen_chars = sorted({char for text in texts for char in text})
        for offset, char in enumerate(seen_chars):
            vocab[char] = offset + 3
        return vocab

    def __len__(self):
        """Number of (source, target) pairs."""
        return len(self.src_texts)

    def __getitem__(self, idx):
        """Return ``(src_tensor, tgt_tensor, src_text, tgt_text)`` for row ``idx``.

        Characters missing from the vocabulary are dropped. The target tensor
        is wrapped in ``<sos>`` ... ``<eos>``; the source tensor is not.
        """
        src = self.src_texts[idx]
        tgt = self.tgt_texts[idx]
        src_indices = [self.src_vocab[char] for char in src if char in self.src_vocab]
        tgt_body = [self.tgt_vocab[char] for char in tgt if char in self.tgt_vocab]
        tgt_indices = [self.tgt_sos_idx] + tgt_body + [self.tgt_eos_idx]
        # An explicit dtype keeps empty sequences as integer tensors
        # (torch.tensor([]) would otherwise default to float32).
        return (
            torch.tensor(src_indices, dtype=torch.long),
            torch.tensor(tgt_indices, dtype=torch.long),
            src,
            tgt,
        )

# # Attention layer (Bahdanau)
# class Attention(nn.Module):
#     def __init__(self, hid_dim):
#         super().__init__()
#         self.attn = nn.Linear(hid_dim * 2, hid_dim)
#         self.v = nn.Parameter(torch.rand(hid_dim))
#         nn.init.normal_(self.v, 0, 0.1)
#         nn.init.xavier_uniform_(self.attn.weight)
    
#     def forward(self, hidden, encoder_outputs):
#         src_len = encoder_outputs.shape[0]
#         hidden = hidden.unsqueeze(0).repeat(src_len, 1, 1)
#         energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
#         energy = energy.permute(1, 0, 2)
#         v = self.v.repeat(encoder_outputs.size(1), 1).unsqueeze(1)
#         attention = torch.bmm(energy, v).squeeze(2)
#         return torch.softmax(attention, dim=1)
class Attention(nn.Module):
    """Additive attention: scores each encoder step against a decoder state.

    The score for a source position is ``v . tanh(W [hidden; encoder_output])``
    and the scores are soft-maxed over the source positions.
    """

    def __init__(self, hid_dim):
        super().__init__()
        self.attn = nn.Linear(2 * hid_dim, hid_dim)
        self.v = nn.Parameter(torch.rand(hid_dim))
        nn.init.normal_(self.v, mean=0, std=0.1)
        nn.init.xavier_uniform_(self.attn.weight)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape [batch_size, src_len].

        ``hidden`` is indexed as [batch_size, hid_dim] and ``encoder_outputs``
        as [src_len, batch_size, hid_dim].
        """
        n_steps, n_batch = encoder_outputs.shape[0], encoder_outputs.shape[1]
        # Pair the same decoder state with every source position.
        tiled_hidden = hidden.unsqueeze(0).expand(n_steps, -1, -1)
        paired = torch.cat((tiled_hidden, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn(paired)).transpose(0, 1)  # [batch_size, src_len, hid_dim]
        # One copy of v per batch element, as a column vector for bmm.
        v_column = self.v.view(1, -1, 1).expand(n_batch, -1, -1)  # [batch_size, hid_dim, 1]
        scores = torch.bmm(energy, v_column).squeeze(2)  # [batch_size, src_len]
        return scores.softmax(dim=1)
# Encoder
class Encoder(nn.Module):
    """Embeds source indices and feeds them through an RNN / LSTM / GRU stack.

    Args:
        input_dim: number of rows in the embedding table.
        emb_dim: embedding size.
        hid_dim: recurrent hidden size.
        n_layers: number of stacked recurrent layers.
        cell_type: 'RNN', 'LSTM' or 'GRU' (any other value raises KeyError).
        dropout: probability used on the embeddings and, when n_layers > 1,
            passed to the recurrent module.
        bidirectional: forwarded to the recurrent module.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, cell_type, dropout, bidirectional):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.cell_type = cell_type
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        recurrent_cells = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}
        build_rnn = recurrent_cells[cell_type]
        # Recurrent dropout is only passed on for multi-layer stacks.
        stack_dropout = dropout if n_layers > 1 else 0
        self.rnn = build_rnn(
            emb_dim,
            hid_dim,
            num_layers=n_layers,
            dropout=stack_dropout,
            bidirectional=bidirectional,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode ``src``.

        Returns ``(outputs, hidden, cell)`` for an LSTM and ``(outputs, hidden)``
        for the other cell types.
        """
        embedded = self.dropout(self.embedding(src))
        rnn_result = self.rnn(embedded)
        if self.cell_type != 'LSTM':
            outputs, hidden = rnn_result
            return outputs, hidden
        outputs, (hidden, cell) = rnn_result
        return outputs, hidden, cell

# Decoder with Attention
class Decoder(nn.Module):
    """Single-step decoder that attends over the encoder outputs.

    Args:
        output_dim: size of the target vocabulary (embedding rows and logits).
        emb_dim: embedding size.
        hid_dim: recurrent hidden size.
        n_layers: number of stacked recurrent layers.
        cell_type: 'RNN', 'LSTM' or 'GRU' (any other value raises KeyError).
        dropout: probability used on the embeddings and, when n_layers > 1,
            passed to the recurrent module.
        attention: callable invoked as ``attention(hidden[-1], encoder_outputs)``.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, cell_type, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.cell_type = cell_type
        self.hid_dim = hid_dim
        self.attention = attention
        cell_factory = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[cell_type]
        stack_dropout = dropout if n_layers > 1 else 0
        # The recurrent input is the token embedding concatenated with the context vector.
        self.rnn = cell_factory(hid_dim + emb_dim, hid_dim, num_layers=n_layers, dropout=stack_dropout)
        # The output layer sees [rnn output; context; embedding].
        self.fc_out = nn.Linear(hid_dim * 2 + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode one step.

        Returns ``(prediction, hidden, cell, attention_weights)``. ``cell`` is
        only updated for LSTM cells; otherwise it is returned as given.
        """
        token_emb = self.dropout(self.embedding(input.unsqueeze(0)))
        # Attention is computed from the top layer's hidden state.
        attn_weights = self.attention(hidden[-1], encoder_outputs)
        batch_major_outputs = encoder_outputs.permute(1, 0, 2)
        # Weighted sum of encoder outputs, moved back to a leading step axis.
        context = torch.bmm(attn_weights.unsqueeze(1), batch_major_outputs).permute(1, 0, 2)
        step_input = torch.cat((token_emb, context), dim=2)
        if self.cell_type == 'LSTM':
            rnn_out, (hidden, cell) = self.rnn(step_input, (hidden, cell))
        else:
            rnn_out, hidden = self.rnn(step_input, hidden)
        features = torch.cat(
            (rnn_out.squeeze(0), context.squeeze(0), token_emb.squeeze(0)),
            dim=1,
        )
        prediction = self.fc_out(features)
        return prediction, hidden, cell, attn_weights

# Seq2Seq with Attention
class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper providing teacher-forced training and greedy decoding.

    Args:
        encoder: module exposing ``cell_type`` and returning
            ``(outputs, hidden, cell)`` for 'LSTM' or ``(outputs, hidden)`` otherwise.
        decoder: module exposing ``output_dim`` and called as
            ``decoder(token, hidden, cell, encoder_outputs)``.
        device: device on which the result buffers are allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def _encode(self, src):
        """Run the encoder and normalise its result to ``(outputs, hidden, cell)``."""
        if self.encoder.cell_type == 'LSTM':
            encoder_outputs, hidden, cell = self.encoder(src)
            return encoder_outputs, hidden, cell
        encoder_outputs, hidden = self.encoder(src)
        return encoder_outputs, hidden, None

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg.shape[0] - 1`` steps, feeding ``trg[0]`` as the first input.

        At every step the gold token is fed back with probability
        ``teacher_forcing_ratio`` and the arg-max prediction otherwise.

        Returns:
            ``(outputs, attention_weights)``; row 0 of both tensors is left at
            zero because no prediction is made for the first target position.
        """
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=self.device)
        attention_weights = torch.zeros(trg_len, batch_size, src.shape[0], device=self.device)
        encoder_outputs, hidden, cell = self._encode(src)
        decoder_input = trg[0, :]
        for t in range(1, trg_len):
            output, hidden, cell, attn = self.decoder(decoder_input, hidden, cell, encoder_outputs)
            outputs[t] = output
            attention_weights[t] = attn
            teacher_force = random.random() < teacher_forcing_ratio
            decoder_input = trg[t] if teacher_force else output.argmax(1)
        return outputs, attention_weights

    def predict(self, src, max_len=50, sos_idx=1, eos_idx=2, pad_idx=0):
        """Greedy-decode ``src`` for at most ``max_len`` positions.

        Puts the model in eval mode and runs without gradient tracking. Once a
        sequence has produced ``eos_idx`` its remaining positions are filled
        with ``pad_idx`` rather than whatever the decoder keeps emitting, and
        decoding stops as soon as every sequence in the batch has finished.

        Args:
            src: source indices, [src_len, batch]; a 1-D tensor is treated as
                a single sequence.
            max_len: maximum number of output positions including ``<sos>``.
            sos_idx: start token written at position 0 and fed first.
            eos_idx: token that marks a sequence as finished.
            pad_idx: filler written after a sequence has finished.

        Returns:
            ``(outputs, attention_weights)`` of shapes ``[T, batch]`` and
            ``[T, batch, src_len]`` with ``T <= max_len``.
        """
        self.eval()
        src = src.to(self.device)
        if src.dim() == 1:
            # Single un-batched sequence: add the batch axis the encoder expects.
            src = src.unsqueeze(1)
        batch_size = src.shape[1]
        with torch.no_grad():
            outputs = torch.full((max_len, batch_size), pad_idx, dtype=torch.long, device=self.device)
            attention_weights = torch.zeros(max_len, batch_size, src.shape[0], device=self.device)
            outputs[0] = sos_idx
            encoder_outputs, hidden, cell = self._encode(src)
            decoder_input = torch.full((batch_size,), sos_idx, dtype=torch.long, device=self.device)
            finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
            for t in range(1, max_len):
                output, hidden, cell, attn = self.decoder(decoder_input, hidden, cell, encoder_outputs)
                attention_weights[t] = attn
                # Sequences that already ended only ever receive padding.
                top1 = output.argmax(1).masked_fill(finished, pad_idx)
                outputs[t] = top1
                decoder_input = top1
                finished = finished | (top1 == eos_idx)
                if bool(finished.all()):
                    outputs = outputs[:t + 1]
                    attention_weights = attention_weights[:t + 1]
                    break
        return outputs, attention_weights

# Accuracy metrics
def _indices_to_chars(seq, dataset):
    """Map target indices to characters, stopping at the first <eos> and skipping <sos>/<pad>."""
    chars = []
    for idx in seq:
        idx = int(idx)
        if idx == dataset.tgt_eos_idx:
            # Anything after the end marker is not part of the word.
            break
        if idx == dataset.tgt_pad_idx or idx == dataset.tgt_sos_idx:
            continue
        chars.append(dataset.tgt_inv_vocab.get(idx, '<unk>'))
    return chars


def calculate_accuracies(model, iterator, device, dataset, max_len=50):
    """Greedy-decode every batch and score it against the reference targets.

    Args:
        model: object with ``eval()`` and
            ``predict(src, max_len, sos_idx, eos_idx) -> (indices, attention)``.
        iterator: yields ``(src, trg, src_text, tgt_text)`` batches with the
            batch on dimension 1 of ``src``/``trg``.
        device: device the batch tensors are moved to.
        dataset: provides ``tgt_inv_vocab`` and the target ``sos``/``eos``/``pad`` indices.
        max_len: decoding limit forwarded to ``model.predict``.

    Returns:
        ``(char_acc, word_acc)``. Character accuracy is position-wise matches
        divided by the number of reference characters; word accuracy is the
        fraction of exact matches. Both are 0 when nothing was scored.
    """
    model.eval()
    char_correct = 0
    char_total = 0
    word_correct = 0
    word_total = 0
    with torch.no_grad():
        for src, trg, src_text, tgt_text in iterator:
            src, trg = src.to(device), trg.to(device)
            output, _ = model.predict(src, max_len, dataset.tgt_sos_idx, dataset.tgt_eos_idx)
            for i in range(src.shape[1]):
                # Cutting the prediction at its first <eos> keeps tokens that a
                # batched decoder emitted after the word ended out of the score.
                pred_chars = _indices_to_chars(output[:, i].tolist(), dataset)
                trg_chars = _indices_to_chars(trg[:, i].tolist(), dataset)
                char_correct += sum(p == t for p, t in zip(pred_chars, trg_chars))
                char_total += len(trg_chars)
                word_correct += int(pred_chars == trg_chars)
                word_total += 1
    char_acc = char_correct / char_total if char_total > 0 else 0
    word_acc = word_correct / word_total if word_total > 0 else 0
    return char_acc, word_acc

# Training function
def train(model, iterator, optimizer, criterion, clip, device, dataset, teacher_forcing_ratio=None):
    """Run one training epoch and then score the model on the same iterator.

    Args:
        model: called as ``model(src, trg[, teacher_forcing_ratio])`` and
            expected to return ``(outputs, attention)``.
        iterator: yields ``(src, trg, src_text, tgt_text)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied to flattened logits and flattened targets.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.
        device: device the batch tensors are moved to.
        dataset: forwarded to ``calculate_accuracies``.
        teacher_forcing_ratio: optional ratio forwarded to the model. When
            left as ``None`` the model is called without it, so the model's
            own default applies (the previous behaviour).

    Returns:
        ``(mean_batch_loss, char_acc, word_acc)``.
    """
    model.train()
    epoch_loss = 0
    for src, trg, _, _ in iterator:
        src, trg = src.to(device), trg.to(device)
        optimizer.zero_grad()
        if teacher_forcing_ratio is None:
            output, _ = model(src, trg)
        else:
            output, _ = model(src, trg, teacher_forcing_ratio)
        output_dim = output.shape[-1]
        # Position 0 is skipped on both sides: the loss covers outputs[1:] vs trg[1:].
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    # Accuracy is measured with a separate greedy-decoding pass over the same batches.
    char_acc, word_acc = calculate_accuracies(model, iterator, device, dataset)
    return epoch_loss / len(iterator), char_acc, word_acc

# Evaluation function
def evaluate(model, iterator, criterion, device, dataset):
    """Compute the mean loss (teacher forcing off) plus char/word accuracy.

    Returns ``(mean_batch_loss, char_acc, word_acc)``; the accuracies come
    from ``calculate_accuracies`` run over the same iterator.
    """
    model.eval()
    running_loss = 0
    with torch.no_grad():
        for src_batch, trg_batch, _src_text, _tgt_text in iterator:
            src_batch = src_batch.to(device)
            trg_batch = trg_batch.to(device)
            # Third argument 0 => the decoder always consumes its own predictions.
            logits, _ = model(src_batch, trg_batch, 0)
            n_classes = logits.shape[-1]
            flat_logits = logits[1:].view(-1, n_classes)
            flat_targets = trg_batch[1:].view(-1)
            running_loss += criterion(flat_logits, flat_targets).item()
    char_acc, word_acc = calculate_accuracies(model, iterator, device, dataset)
    mean_loss = running_loss / len(iterator)
    return mean_loss, char_acc, word_acc

# Custom collate function
def collate_fn(batch):
    """Pad a list of ``(src_tensor, tgt_tensor, src_text, tgt_text)`` items.

    Returns ``(src_padded, tgt_padded, src_texts, tgt_texts)`` where the
    tensors are padded with 0 and have the sequence axis first.
    """
    src_seqs, tgt_seqs, src_strings, tgt_strings = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    padded = [pad(seqs, batch_first=False, padding_value=0) for seqs in (src_seqs, tgt_seqs)]
    return padded[0], padded[1], list(src_strings), list(tgt_strings)

# Generate attention heatmaps
def plot_attention_heatmaps(samples, dataset, save_dir='predictions_attention'):
    """Draw up to nine attention heatmaps on a 3x3 grid, save them and log the image to W&B.

    Args:
        samples: list of dicts with the keys 'Latin Input',
            'Devanagari Predicted', 'Correct', 'Src Indices', 'Pred Indices'
            and 'Attention Weights' (a tensor indexed [output step, source position]).
        dataset: provides the inverse vocabularies and special-token indices.
        save_dir: directory that receives ``attention_heatmaps.png``.
    """
    os.makedirs(save_dir, exist_ok=True)
    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    axes = axes.flatten()
    # Not every dataset object defines a source-side <eos> index, so look it up defensively.
    src_skip = {dataset.src_pad_idx, getattr(dataset, 'src_eos_idx', None)}
    tgt_skip = {dataset.tgt_pad_idx, dataset.tgt_sos_idx}
    shown = samples[:9]
    for i, sample in enumerate(shown):
        src_text = sample['Latin Input']
        pred_text = sample['Devanagari Predicted']
        attn = sample['Attention Weights'].detach().cpu().numpy()

        # Keep the column index of every plotted source character so labels
        # and attention columns stay aligned.
        cols, src_chars = [], []
        for j, idx in enumerate(sample['Src Indices']):
            idx = int(idx)
            if idx in src_skip:
                continue
            cols.append(j)
            src_chars.append(dataset.src_inv_vocab.get(idx, '<unk>'))

        # Likewise keep the row (decoding step) of every plotted output
        # character: row t of the attention matrix belongs to the token at
        # position t, so the <sos> row must not be plotted against the first
        # character. Nothing after the first <eos> is shown.
        rows, tgt_chars = [], []
        for t, idx in enumerate(sample['Pred Indices']):
            idx = int(idx)
            if idx == dataset.tgt_eos_idx:
                break
            if idx in tgt_skip:
                continue
            rows.append(t)
            tgt_chars.append(dataset.tgt_inv_vocab.get(idx, '<unk>'))

        if not rows or not cols:
            # Nothing to draw (e.g. empty prediction); leave the cell blank.
            axes[i].axis('off')
            continue

        sns.heatmap(attn[rows][:, cols], ax=axes[i], cmap='viridis',
                    xticklabels=src_chars, yticklabels=tgt_chars, cbar=False)
        axes[i].set_title(f'Input: {src_text}\nPred: {pred_text}\nCorrect: {"✅" if sample["Correct"] else "❌"}')
        axes[i].set_xlabel('Source (Latin)')
        axes[i].set_ylabel('Target (Devanagari)')
    for j in range(len(shown), 9):
        axes[j].axis('off')
    fig.tight_layout()
    image_path = os.path.join(save_dir, 'attention_heatmaps.png')
    fig.savefig(image_path)
    plt.close(fig)
    wandb.log({'attention_heatmaps': wandb.Image(image_path)})

# Training function for WandB sweep
def _adopt_vocab(dataset, reference):
    """Make `dataset` encode and decode with the vocabularies of `reference`."""
    for attr in ('src_vocab', 'tgt_vocab', 'src_vocab_size', 'tgt_vocab_size',
                 'src_pad_idx', 'tgt_pad_idx', 'tgt_sos_idx', 'tgt_eos_idx',
                 'src_inv_vocab', 'tgt_inv_vocab'):
        setattr(dataset, attr, getattr(reference, attr))


def _decode_prediction(pred_seq, dataset):
    """Join predicted target indices into a string, stopping at the first <eos>."""
    chars = []
    for idx in pred_seq:
        idx = int(idx)
        if idx == dataset.tgt_eos_idx:
            break
        if idx in (dataset.tgt_pad_idx, dataset.tgt_sos_idx):
            continue
        chars.append(dataset.tgt_inv_vocab.get(idx, '<unk>'))
    return ''.join(chars)


def train_sweep():
    """Run one W&B sweep trial: train, validate, test, and log predictions/heatmaps.

    Hyper-parameters are read from ``wandb.config``. The best checkpoint by
    validation word accuracy is saved as ``best_model_<run id>.pt`` and
    reloaded for the test evaluation. Test predictions are written to
    ``predictions_attention/test_predictions.csv``.

    NOTE(review): this function never reads ``hparams.teacher_forcing`` or
    ``hparams.beam_size``; ``train()`` is called without a teacher-forcing
    argument.
    """
    wandb.init()
    hparams = wandb.config

    base_path = '/home/user/Downloads/dakshina_dataset_v1.0/hi/lexicons'
    train_path = f'{base_path}/hi.translit.sampled.train.tsv'
    dev_path = f'{base_path}/hi.translit.sampled.dev.tsv'
    test_path = f'{base_path}/hi.translit.sampled.test.tsv'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset = TransliterationDataset(train_path)
    dev_dataset = TransliterationDataset(dev_path)
    test_dataset = TransliterationDataset(test_path)
    # The model's embedding and output layers are sized and indexed by the
    # training vocabulary. Each dataset builds its own vocabulary from its own
    # file, so dev/test must be switched to the training vocabulary; otherwise
    # the same index can denote different characters at evaluation time.
    _adopt_vocab(dev_dataset, train_dataset)
    _adopt_vocab(test_dataset, train_dataset)

    train_loader = DataLoader(train_dataset, batch_size=hparams.batch_size, shuffle=True,
                              collate_fn=collate_fn)
    dev_loader = DataLoader(dev_dataset, batch_size=hparams.batch_size,
                            collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=hparams.batch_size,
                             collate_fn=collate_fn)

    # Debug: Print sample batch
    for src, trg, src_text, tgt_text in train_loader:
        print(f"Sample batch: src shape={src.shape}, trg shape={trg.shape}, src_text={src_text[:2]}, tgt_text={tgt_text[:2]}")
        break

    attention = Attention(hparams.hid_dim)
    encoder = Encoder(
        input_dim=train_dataset.src_vocab_size,
        emb_dim=hparams.emb_dim,
        hid_dim=hparams.hid_dim,
        n_layers=hparams.enc_layers,
        cell_type=hparams.cell_type,
        dropout=hparams.dropout,
        bidirectional=hparams.bidirectional
    )
    decoder = Decoder(
        output_dim=train_dataset.tgt_vocab_size,
        emb_dim=hparams.emb_dim,
        hid_dim=hparams.hid_dim,
        n_layers=hparams.dec_layers,
        cell_type=hparams.cell_type,
        dropout=hparams.dropout,
        attention=attention
    )
    model = Seq2Seq(encoder, decoder, device).to(device)

    optimizer_class = {'Adam': optim.Adam, 'RMSprop': optim.RMSprop, 'AdamW': optim.AdamW}[hparams.optimizer]
    optimizer = optimizer_class(model.parameters(), lr=hparams.learning_rate, weight_decay=hparams.weight_decay)
    criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.tgt_pad_idx)

    n_epochs = 25
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise a run whose word accuracy stays at 0 would have no
    # file to reload after training.
    best_valid_word_acc = float('-inf')
    best_model_path = f'best_model_{wandb.run.id}.pt'

    try:
        for epoch in range(n_epochs):
            train_loss, train_char_acc, train_word_acc = train(model, train_loader, optimizer,
                                                               criterion, hparams.grad_clip, device, train_dataset)
            valid_loss, valid_char_acc, valid_word_acc = evaluate(model, dev_loader, criterion, device, dev_dataset)
            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_char_acc': train_char_acc,
                'train_word_acc': train_word_acc,
                'valid_loss': valid_loss,
                'valid_char_acc': valid_char_acc,
                'valid_word_acc': valid_word_acc
            })
            print(f'Epoch: {epoch+1:02}')
            print(f'\tTrain Loss: {train_loss:.3f} | Char Acc: {train_char_acc:.3f} | Word Acc: {train_word_acc:.3f}')
            print(f'\tVal. Loss: {valid_loss:.3f} | Char Acc: {valid_char_acc:.3f} | Word Acc: {valid_word_acc:.3f}')
            if valid_word_acc > best_valid_word_acc:
                best_valid_word_acc = valid_word_acc
                torch.save(model.state_dict(), best_model_path)
                artifact = wandb.Artifact(f'model_{wandb.run.id}', type='model')
                artifact.add_file(best_model_path)
                wandb.log_artifact(artifact)

        # map_location keeps the reload working whichever device wrote the file.
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_loss, test_char_acc, test_word_acc = evaluate(model, test_loader, criterion, device, test_dataset)
        wandb.log({
            'test_loss': test_loss,
            'test_char_acc': test_char_acc,
            'test_word_acc': test_word_acc
        })
        print(f'\nTest Results:')
        print(f'\tTest Loss: {test_loss:.3f}')
        print(f'\tTest Char Accuracy: {test_char_acc:.3f}')
        print(f'\tTest Word Accuracy: {test_word_acc:.3f}')

        model.eval()
        all_predictions = []
        heatmap_samples = []
        with torch.no_grad():
            for src, trg, src_text, tgt_text in test_loader:
                src = src.to(device)
                output, attn_weights = model.predict(src, max_len=50,
                                                     sos_idx=test_dataset.tgt_sos_idx,
                                                     eos_idx=test_dataset.tgt_eos_idx)
                for i in range(src.shape[1]):
                    pred_seq = output[:, i].cpu().numpy()
                    # Stop at the first <eos>: in a batch, shorter words keep
                    # receiving decoder output until the longest one finishes.
                    pred_word = _decode_prediction(pred_seq, test_dataset)
                    is_correct = pred_word == tgt_text[i]
                    all_predictions.append({
                        'Latin Input': src_text[i],
                        'Devanagari Target': tgt_text[i],
                        'Devanagari Predicted': pred_word,
                        'Correct': is_correct
                    })
                    if len(heatmap_samples) < 10:
                        heatmap_samples.append({
                            'Latin Input': src_text[i],
                            'Devanagari Target': tgt_text[i],
                            'Devanagari Predicted': pred_word,
                            'Correct': is_correct,
                            'Src Indices': src[:, i].cpu().numpy(),
                            'Pred Indices': pred_seq,
                            'Attention Weights': attn_weights[:, i, :]
                        })

        os.makedirs('predictions_attention', exist_ok=True)
        predictions_df = pd.DataFrame(all_predictions)
        predictions_df.to_csv('predictions_attention/test_predictions.csv', index=False)

        plot_attention_heatmaps(heatmap_samples, test_dataset)

        headers = ['#', 'Latin Input', 'Devanagari Target', 'Devanagari Predicted', 'Correct']
        table_data = [[i+1, s['Latin Input'], s['Devanagari Target'], s['Devanagari Predicted'],
                       '✅' if s['Correct'] else '❌'] for i, s in enumerate(all_predictions[:5])]
        wandb.log({'sample_predictions': wandb.Table(columns=headers, data=table_data)})

    except KeyboardInterrupt:
        # Keep whatever has been learned so far before leaving.
        print(f"Training interrupted. Saving current model state...")
        torch.save(model.state_dict(), f'last_model_{wandb.run.id}.pt')
        artifact = wandb.Artifact(f'last_model_{wandb.run.id}', type='model')
        artifact.add_file(f'last_model_{wandb.run.id}.pt')
        wandb.log_artifact(artifact)
        wandb.finish()
        exit(0)

    wandb.finish()

# Sweep configuration
# Candidate values for each hyper-parameter, in the order they appear in the sweep.
# NOTE(review): train_sweep as written in this file does not read
# 'beam_size' or 'teacher_forcing' from the config.
_SWEEP_GRID = {
    'emb_dim': [128],
    'hid_dim': [128, 256],
    'enc_layers': [1],
    'dec_layers': [1],
    'cell_type': ['RNN', 'GRU', 'LSTM'],
    'dropout': [0.2, 0.3],
    'beam_size': [1, 3, 5],
    'learning_rate': [5e-4, 1e-3, 5e-3],
    'batch_size': [8, 16],
    'teacher_forcing': [0.3, 0.5, 0.7],
    'optimizer': ['Adam', 'RMSprop', 'AdamW'],
    'grad_clip': [5],
    'weight_decay': [1e-6],
    'bidirectional': [False],
}

# Grid sweep that maximises the logged 'valid_word_acc' metric.
sweep_config = {
    'method': 'grid',
    'metric': {'name': 'valid_word_acc', 'goal': 'maximize'},
    'parameters': {name: {'values': choices} for name, choices in _SWEEP_GRID.items()},
}

# Main execution
def main():
    """Create the W&B sweep from ``sweep_config`` and run ``train_sweep`` through an agent (count=50)."""
    project_name = "transliteration-seq2seq-attention"
    sweep_handle = wandb.sweep(sweep_config, project=project_name)
    wandb.agent(sweep_handle, function=train_sweep, count=50)


# Only start the sweep when the file is executed as a script.
if __name__ == "__main__":
    main()