In [None]:
pip install torch wandb pandas tqdm editdistance

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import wandb
import editdistance
import numpy as np
import os
from tqdm import tqdm
import pandas as pd

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Language code used to build the lexicon directory path below.
LANG = 'te'
data_path = f'/content/drive/MyDrive/dakshina_dataset_v1.0/{LANG}/lexicons/'

def read_data(filepath, max_len=40):
    """Load (source, target) transliteration pairs from a tab-separated lexicon.

    Each usable line holds the native-script word in the first column and its
    Latin romanization in the second; any further columns are ignored.  The
    model transliterates Latin -> native script, so the Latin word is returned
    as the source and the native-script word as the target.

    Args:
        filepath: Path of the UTF-8 encoded TSV file.
        max_len: Maximum number of characters allowed for both the source and
            the target; longer pairs are dropped.

    Returns:
        A list of ``(latin_word, native_word)`` tuples in file order.
    """
    pairs = []
    with open(filepath, encoding='utf8') as f:
        for line in f:
            # Strip only the line terminator.  Stripping all whitespace would
            # swallow a leading tab and shift every column to the left when
            # the first field is empty.
            parts = line.rstrip('\r\n').split('\t')
            if len(parts) < 2:
                continue

            native, latin = parts[0].strip(), parts[1].strip()
            # A pair with an empty side carries no training/evaluation signal.
            if not native or not latin:
                continue

            source, target = latin, native
            if len(source) <= max_len and len(target) <= max_len:
                pairs.append((source, target))

    return pairs

def make_vocab(sequences):
    """Build character <-> index lookup tables for a collection of words.

    Indices 0, 1 and 2 are reserved for ``<pad>``, ``<sos>`` and ``<eos>``;
    every other character receives the next free index in order of first
    appearance.

    Args:
        sequences: Iterable of words (each an iterable of characters).

    Returns:
        ``(vocab, idx2char)`` where ``vocab`` maps character -> index and
        ``idx2char`` is the inverse mapping.
    """
    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}

    for word in sequences:
        for symbol in word:
            # The next free index always equals the current vocabulary size.
            vocab.setdefault(symbol, len(vocab))

    idx2char = dict(zip(vocab.values(), vocab.keys()))
    return vocab, idx2char

def encode_word(word, vocab):
    """Turn a word into token indices wrapped in ``<sos>`` ... ``<eos>``.

    Args:
        word: Iterable of characters to encode.
        vocab: Mapping from character (and special token) to index.

    Returns:
        A new list ``[<sos>, idx(c1), ..., idx(cn), <eos>]``.

    Raises:
        KeyError: If a character of ``word`` is missing from ``vocab``.
    """
    token_ids = [vocab['<sos>']]
    token_ids.extend(vocab[symbol] for symbol in word)
    token_ids.append(vocab['<eos>'])
    return token_ids

def pad_seq(seq, max_len, pad_idx=0):
    """Return a copy of ``seq`` right-padded with ``pad_idx`` up to ``max_len``.

    A sequence that is already ``max_len`` long or longer is returned
    (as a new list) without truncation.
    """
    missing = max_len - len(seq)
    # Multiplying a list by a non-positive count yields an empty list.
    filler = [pad_idx] * missing
    return seq + filler

class TransliterationDataset(Dataset):
    """Dataset of index-encoded (source, target) word pairs.

    All pairs are encoded once at construction time.  Every item is padded to
    the longest source / target sequence found in the whole dataset, so all
    items share the same tensor lengths.
    """

    def __init__(self, pairs, source_vocab, target_vocab):
        # Padding indices used when items are fetched.
        self.source_pad = source_vocab['<pad>']
        self.target_pad = target_vocab['<pad>']

        # Encode every word pair (with <sos>/<eos> markers) up front.
        self.data = [
            (encode_word(src_word, source_vocab), encode_word(tgt_word, target_vocab))
            for src_word, tgt_word in pairs
        ]

        # Longest encoded source / target, used as the padded length.
        self.source_max = max(len(src_ids) for src_ids, _ in self.data)
        self.target_max = max(len(tgt_ids) for _, tgt_ids in self.data)

    def __len__(self):
        # Number of encoded pairs.
        return len(self.data)

    def __getitem__(self, idx):
        # Pad the stored pair to the dataset-wide maxima and convert to tensors.
        src_ids, tgt_ids = self.data[idx]
        padded_src = pad_seq(src_ids, self.source_max, self.source_pad)
        padded_tgt = pad_seq(tgt_ids, self.target_max, self.target_pad)
        return torch.tensor(padded_src), torch.tensor(padded_tgt)

class translit_Encoder(nn.Module):
    """Embeds a source token sequence and runs it through a stacked RNN.

    ``cell`` selects the recurrent layer type ('rnn', 'gru' or 'lstm',
    case-insensitive).  ``forward`` returns only the final recurrent state,
    not the per-step outputs.
    """

    # Supported recurrent layer types, keyed by lower-cased cell name.
    _CELL_TYPES = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}

    def __init__(self, input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()
        kind = cell.lower()

        # Token index -> dense vector.
        self.embedding = nn.Embedding(input_dimensions, emb_dimensions)

        # Dropout is only passed on when the RNN has more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.rnn = self._CELL_TYPES[kind](
            emb_dimensions,
            hid_dimensions,
            num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )

        self.cell = kind

    def forward(self, source):
        """Return ``(hidden, cell)``; ``cell`` is ``None`` for non-LSTM cells."""
        _, state = self.rnn(self.embedding(source))

        # Only an LSTM carries a separate cell state next to the hidden state.
        if self.cell != 'lstm':
            return state, None
        hidden, cell = state
        return hidden, cell

class translit_Decoder(nn.Module):
    """Recurrent decoder that predicts one target token per call.

    ``cell`` selects the recurrent layer type ('rnn', 'gru' or 'lstm',
    case-insensitive).  Each ``forward`` call consumes one token per batch
    row together with the previous recurrent state.
    """

    # Supported recurrent layer types, keyed by lower-cased cell name.
    _CELL_TYPES = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}

    def __init__(self, output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()
        kind = cell.lower()

        # Target token index -> dense vector.
        self.embedding = nn.Embedding(output_dimensions, emb_dimensions)

        # Dropout is only passed on when the RNN has more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.rnn = self._CELL_TYPES[kind](
            emb_dimensions,
            hid_dimensions,
            num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )

        # Projects the RNN output onto vocabulary scores.
        self.fc_out = nn.Linear(hid_dimensions, output_dimensions)

        self.cell = kind

    def forward(self, input, hidden, cell=None):
        """Run one decoding step.

        Returns ``(prediction, hidden, cell)`` where ``prediction`` holds the
        vocabulary scores of this step and ``cell`` is ``None`` for non-LSTM
        cells.
        """
        # Insert a length-1 time axis so the batch-first RNN sees one step.
        step = self.embedding(input.unsqueeze(1))

        is_lstm = self.cell == 'lstm'
        previous_state = (hidden, cell) if is_lstm else hidden
        output, next_state = self.rnn(step, previous_state)

        if is_lstm:
            hidden, cell = next_state
        else:
            hidden, cell = next_state, None

        # Drop the time axis again before projecting to the vocabulary.
        return self.fc_out(output.squeeze(1)), hidden, cell

class Attention(nn.Module):
    """Additive attention: scores every encoder output against one hidden state."""

    def __init__(self, hid_dim):
        super().__init__()
        # Scores are computed from the concatenation [hidden ; encoder_output].
        self.attn = nn.Linear(hid_dim * 2, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        """Return attention weights normalised over dimension 1.

        ``mask`` (optional) marks positions to ignore with zeros; those
        positions receive a score of -1e10 before the softmax.
        """
        steps = encoder_outputs.size(1)

        # Broadcast the hidden state along the source-length axis.
        query = hidden.unsqueeze(1).expand(-1, steps, -1)

        features = torch.cat((query, encoder_outputs), dim=2)
        scores = self.v(torch.tanh(self.attn(features))).squeeze(2)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e10)

        return torch.softmax(scores, dim=1)

class translit_Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with teacher forcing."""

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Return per-step vocabulary scores of shape (batch, target_len, vocab).

        Position 0 of the result is left at zero because decoding starts from
        ``target[:, 0]``.  At every step one random draw decides, for the
        whole batch, whether the next decoder input is the ground-truth token
        or the decoder's own argmax prediction.
        """
        steps = target.size(1)
        vocab_size = self.decoder.fc_out.out_features
        scores = torch.zeros(source.size(0), steps, vocab_size).to(self.device)

        # The encoder's final state initialises the decoder.
        hidden, cell = self.encoder(source)

        step_input = target[:, 0]
        for t in range(1, steps):
            step_scores, hidden, cell = self.decoder(step_input, hidden, cell)
            scores[:, t] = step_scores

            if random.random() < teacher_forcing_ratio:
                step_input = target[:, t]
            else:
                step_input = step_scores.argmax(1)

        return scores



def strip_after_eos(seq, eos_idx):
    """Return the tokens of ``seq`` that precede the first ``eos_idx``.

    Tensors and NumPy arrays are converted to plain Python lists first, so the
    result of such inputs is always a list.  A sequence that contains no
    ``eos_idx`` is returned in full.

    Args:
        seq: Token ids as a list/tuple, ``torch.Tensor`` or NumPy array.
        eos_idx: Index of the end-of-sequence token.

    Returns:
        The (possibly shortened) sequence without the EOS token itself.
    """
    if isinstance(seq, torch.Tensor):
        # detach() keeps this usable on tensors that still require grad.
        seq = seq.detach().cpu().tolist()
    elif hasattr(seq, 'tolist'):
        # NumPy arrays have no .index(); work on a plain list instead.
        seq = seq.tolist()

    try:
        return seq[:seq.index(eos_idx)]
    except ValueError:
        # No EOS present: nothing to trim.
        return seq

def calculate_word_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Fraction of predictions that match their target exactly.

    Before comparison every sequence is cut at its first ``eos_idx`` (when one
    is given) and all ``pad_idx`` tokens are removed.

    Args:
        preds: Sequence of predicted token-id sequences.
        targets: Sequence of reference token-id sequences.
        pad_idx: Padding index to ignore.
        eos_idx: End-of-sequence index, or ``None`` to skip EOS trimming.

    Returns:
        ``matches / len(preds)``; ``0.0`` when ``preds`` is empty.
    """
    # Compare against None explicitly: an EOS index of 0 is a valid index and
    # must not be mistaken for "no EOS given".
    trim_at_eos = eos_idx is not None

    correct = 0
    for pred, target in zip(preds, targets):
        if trim_at_eos:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        correct += int(pred == target)
    return correct / max(len(preds), 1)


def calculate_cer(preds, targets, pad_idx=0, eos_idx=None):
    """Character error rate: total edit distance / total target length.

    Before comparison every sequence is cut at its first ``eos_idx`` (when one
    is given) and all ``pad_idx`` tokens are removed.  An empty target counts
    as length 1 in the denominator.

    Args:
        preds: Sequence of predicted token-id sequences.
        targets: Sequence of reference token-id sequences.
        pad_idx: Padding index to ignore.
        eos_idx: End-of-sequence index, or ``None`` to skip EOS trimming.

    Returns:
        The error rate, or ``float('inf')`` when there is nothing to compare.
    """
    # Compare against None explicitly: an EOS index of 0 is a valid index and
    # must not be mistaken for "no EOS given".
    trim_at_eos = eos_idx is not None

    cer = 0
    total = 0
    for pred, target in zip(preds, targets):
        if trim_at_eos:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        cer += editdistance.eval(pred, target)
        total += max(len(target), 1)
    return cer / total if total > 0 else float('inf')


def calculate_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Token-level accuracy over all non-padding target positions.

    Every non-pad target token counts towards the denominator.  A target
    position the prediction never reached (prediction shorter than target)
    counts as an error rather than being silently skipped, so a truncated
    prediction cannot score 100%.  Prediction tokens beyond the target length
    are ignored.

    Args:
        preds: Sequence of predicted token-id sequences (lists or tensors).
        targets: Sequence of reference token-id sequences (lists or tensors).
        pad_idx: Padding index; target positions holding it are ignored.
        eos_idx: End-of-sequence index, or ``None`` to skip EOS trimming.

    Returns:
        ``correct / total`` or ``0.0`` when there is no target token.
    """
    correct = 0
    total = 0
    for pred, target in zip(preds, targets):
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().tolist()
        if isinstance(target, torch.Tensor):
            target = target.cpu().tolist()
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)

        pred = list(pred)
        pred_len = len(pred)
        for pos, t_token in enumerate(target):
            if t_token == pad_idx:
                continue
            total += 1
            # Positions past the end of the prediction are misses.
            if pos < pred_len and pred[pos] == t_token:
                correct += 1
    return correct / total if total > 0 else 0.0

In [None]:
# Log in to Weights & Biases before any run is created in the cells below.
wandb.login()

In [None]:
# Start a W&B run for the evaluation and fetch the stored model checkpoint.
run = wandb.init(project="dakshina-seq2seq-3", entity="sai-sakunthala-indian-institute-of-technology-madras", name="evaluate_test")
artifact = run.use_artifact('best_model:v5', type='model')
artifact_dir = artifact.download()

# Read data and create vocabularies.  The vocabularies are rebuilt from the
# training split; presumably this reproduces the token indices used when the
# checkpoint was trained -- verify against the training notebook.
test_pairs = read_data(data_path + f"{LANG}.translit.sampled.test.tsv", max_len=30)
train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=30)
source_vocab, idx2char_src = make_vocab([x[0] for x in train_pairs])
target_vocab, idx2char_tgt = make_vocab([x[1] for x in train_pairs])

# Initialize model.  load_state_dict below is strict, so these sizes have to
# match the shapes stored in the checkpoint.
encoder = translit_Encoder(len(source_vocab), 128, 128*2, 2, 0.2, 'lstm').to(device)
decoder = translit_Decoder(len(target_vocab), 128, 128*2, 2, 0.2, 'lstm').to(device)
model = translit_Seq2Seq(encoder, decoder, device).to(device)

# Load model weights and switch to inference mode (disables dropout).
state_dict = torch.load(f"{artifact_dir}/best_model.pt", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Create test dataset and loader.  drop_last must stay False here: dropping
# the final partial batch would silently exclude test words from the
# reported accuracy and from the saved predictions.
test_translit = TransliterationDataset(test_pairs, source_vocab, target_vocab)
test_loader = DataLoader(test_translit, batch_size=64, shuffle=False, drop_last=False)

# Accumulators filled by the evaluation loop.
all_src, all_preds, all_tgts = [], [], []
correct = 0
total = 0

def predict(model, src, max_len=30, vocab=None):
    """Greedy decoding of a batch of source sequences.

    Every row is decoded token by token, always feeding the decoder its own
    argmax prediction.  Once a row has produced ``<eos>`` all of its later
    positions are filled with ``<pad>``, so the output never contains tokens
    generated after the end of a word.  Decoding stops as soon as every row
    has produced ``<eos>`` (not only when they do so on the same step) or
    after ``max_len`` steps.

    The function does not change the gradient mode; callers that only want
    predictions should wrap the call in ``torch.no_grad()``.

    Args:
        model: Object exposing ``encoder(src)`` -> ``(hidden, cell)`` and
            ``decoder(tokens, hidden, cell)`` -> ``(scores, hidden, cell)``.
        src: Batch of source token ids; the batch size is ``src.size(0)``.
        max_len: Maximum number of decoding steps.
        vocab: Target vocabulary providing ``<sos>``, ``<eos>`` and ``<pad>``;
            defaults to the module-level ``target_vocab``.

    Returns:
        Long tensor of shape ``(batch, steps)`` with ``steps <= max_len``,
        located on ``src.device``.
    """
    vocab = target_vocab if vocab is None else vocab
    sos_idx, eos_idx, pad_idx = vocab['<sos>'], vocab['<eos>'], vocab['<pad>']
    batch_size = src.size(0)

    hidden, cell = model.encoder(src)

    # First decoder input is <sos> for every row.
    token = torch.full((batch_size,), sos_idx, dtype=torch.long, device=src.device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)
    steps = []

    for _ in range(max_len):
        scores, hidden, cell = model.decoder(token, hidden, cell)
        token = scores.argmax(1)

        # Rows that already ended only emit padding from now on.
        steps.append(token.masked_fill(finished, pad_idx))

        finished = finished | (token == eos_idx)
        if finished.all():
            break

    if not steps:
        # torch.stack rejects an empty list (max_len <= 0).
        return torch.empty((batch_size, 0), dtype=torch.long, device=src.device)
    return torch.stack(steps, dim=1)

# Ids of the special target tokens, looked up once for the whole evaluation.
_tgt_eos = target_vocab['<eos>']
_tgt_skip = [target_vocab['<pad>'], target_vocab['<sos>']]


def _content_tokens(sequence):
    """Return the tokens before the first <eos>, without <pad> and <sos>."""
    kept = []
    for token in sequence:
        if token == _tgt_eos:
            break
        if token not in _tgt_skip:
            kept.append(token)
    return kept


with torch.no_grad():
    for src, tgt in tqdm(test_loader):
        src, tgt = src.to(device), tgt.to(device)
        preds = predict(model, src)

        # Walk the batch row by row on the CPU side.
        batch_rows = zip(src.cpu().numpy(), preds.cpu().numpy(), tgt.cpu().numpy())
        for s, p, t in batch_rows:
            # Keep the raw (unprocessed) sequences for logging and saving.
            all_src.append(s)
            all_preds.append(p)
            all_tgts.append(t)

            # A word counts as correct when its cleaned tokens match exactly.
            if _content_tokens(p) == _content_tokens(t):
                correct += 1
            total += 1

accuracy = correct / total if total > 0 else 0
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Correct: {correct}, Total: {total}")
wandb.log({"Test Accuracy": accuracy})

def _tokens_to_word(token_ids, idx2char):
    """Decode token ids up to the first <eos>, skipping <pad> and <sos>."""
    chars = []
    for idx in token_ids:
        symbol = idx2char[idx]
        if symbol == '<eos>':
            # Anything after the end marker is not part of the word.
            break
        if symbol not in ('<pad>', '<sos>'):
            chars.append(symbol)
    return ''.join(chars)


def log_sample_predictions_table_wandb(sources, preds, targets, idx2char_src, idx2char_tgt, num_samples=10):
    """Log a random sample of decoded predictions to W&B as a table.

    Each sequence is decoded with the given index -> character mapping and
    cut at its first ``<eos>``, so the status column follows the same rule as
    the evaluation loop's accuracy.  Special tokens are recognised through
    ``idx2char_src`` / ``idx2char_tgt``; no module-level vocabulary is read.

    Args:
        sources: Source token-id sequences.
        preds: Predicted target token-id sequences.
        targets: Reference target token-id sequences.
        idx2char_src: Index -> character mapping of the source vocabulary.
        idx2char_tgt: Index -> character mapping of the target vocabulary.
        num_samples: Maximum number of rows to log.
    """
    table = wandb.Table(columns=["Source", "Prediction", "Reference", "status"])

    # Pick random indices without replacement.
    sample_size = min(num_samples, len(sources))
    sample_indices = random.sample(range(len(sources)), sample_size)

    for i in sample_indices:
        src_word = _tokens_to_word(sources[i], idx2char_src)
        pred_word = _tokens_to_word(preds[i], idx2char_tgt)
        ref_word = _tokens_to_word(targets[i], idx2char_tgt)

        is_correct = (pred_word == ref_word)
        status = "🟩 **Correct**" if is_correct else "🟥 **Incorrect**"

        table.add_data(src_word, pred_word, ref_word, status)

    wandb.log({"Test Sample Predictions (Color-Coded)": table})

log_sample_predictions_table_wandb(all_src, all_preds, all_tgts, idx2char_src, idx2char_tgt)

output_dir = "predictions_vanilla"
os.makedirs(output_dir, exist_ok=True)
predictions_path = os.path.join(output_dir, "test_predictions.txt")


def _ids_to_text(token_ids, idx2char, vocab):
    """Decode token ids up to the first <eos>, skipping <pad> and <sos>."""
    eos_idx = vocab['<eos>']
    skipped = (vocab['<pad>'], vocab['<sos>'])
    chars = []
    for idx in token_ids:
        if idx == eos_idx:
            # Tokens generated after the end marker are not part of the word;
            # keeping them would make the file disagree with the accuracy.
            break
        if idx not in skipped:
            chars.append(idx2char[idx])
    return ''.join(chars)


# One line per test word: source <TAB> prediction <TAB> reference.
with open(predictions_path, "w", encoding="utf-8") as f:
    for s, p, t in zip(all_src, all_preds, all_tgts):
        src_word = _ids_to_text(s, idx2char_src, source_vocab)
        pred_word = _ids_to_text(p, idx2char_tgt, target_vocab)
        ref_word = _ids_to_text(t, idx2char_tgt, target_vocab)
        f.write(f"{src_word}\t{pred_word}\t{ref_word}\n")

print(f"Saved full predictions to: {output_dir}/test_predictions.txt")
# Upload the predictions file with the run, then close the run.
wandb.save(predictions_path)
wandb.finish()