<a href="https://colab.research.google.com/github/Sai-sakunthala/Assignment-3/blob/main/Assignment_3_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install torch wandb pandas tqdm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import wandb
import editdistance
import numpy as np
import os
from tqdm import tqdm
import pandas as pd

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
LANG = 'te'
data_path = f'/content/drive/MyDrive/dakshina_dataset_v1.0/{LANG}/lexicons/'

def read_data(filepath, max_len=40):
    pairs = []
    with open(filepath, encoding='utf8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 2:
                continue
            source, target = parts[0], parts[1]
            if len(source) <= max_len and len(target) <= max_len:
                pairs.append((source, target))
    return pairs

def make_vocab(sequences):
    vocab = {'<pad>':0, '<sos>':1, '<eos>':2}
    idx = 3
    for seq in sequences:
        for ch in seq:
            if ch not in vocab:
                vocab[ch] = idx
                idx += 1
    idx2char = {i:c for c,i in vocab.items()}
    return vocab, idx2char

def encode_word(word, vocab):
    return [vocab['<sos>']] + [vocab[ch] for ch in word] + [vocab['<eos>']]

def pad_seq(seq, max_len, pad_idx=0):
    return seq + [pad_idx] * (max_len - len(seq))

class TransliterationDataset(Dataset):
    def __init__(self, pairs, source_vocab, target_vocab):
        self.source_pad = source_vocab['<pad>']
        self.target_pad = target_vocab['<pad>']
        self.data = []
        for source, target in pairs:
            source_t = encode_word(source, source_vocab)
            target_t = encode_word(target, target_vocab)
            self.data.append((source_t, target_t))
        self.source_max = max(len(x[0]) for x in self.data)
        self.target_max = max(len(x[1]) for x in self.data)

    def __len__(self): return len(self.data)

    def __getitem__(self, idx):
        source, target = self.data[idx]
        source = pad_seq(source, self.source_max, self.source_pad)
        target = pad_seq(target, self.target_max, self.target_pad)
        return torch.tensor(source), torch.tensor(target)

class translit_Encoder(nn.Module):
    def __init__(self, input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()
        self.embedding = nn.Embedding(input_dimensions, emb_dimensions)
        rnn_cls = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}[cell.lower()]
        self.rnn = rnn_cls(emb_dimensions, hid_dimensions, num_layers, dropout=dropout if num_layers > 1 else 0, batch_first=True)
        self.cell = cell.lower()

    def forward(self, source):
        embedded = self.embedding(source)
        if self.cell == 'lstm':
            outputs, (hidden, cell) = self.rnn(embedded)
            return hidden, cell
        else:
            outputs, hidden = self.rnn(embedded)
            return hidden, None

class translit_Decoder(nn.Module):
    def __init__(self, output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()
        self.embedding = nn.Embedding(output_dimensions, emb_dimensions)
        rnn_cls = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}[cell.lower()]
        self.rnn = rnn_cls(emb_dimensions, hid_dimensions, num_layers, dropout=dropout if num_layers > 1 else 0, batch_first=True)
        self.fc_out = nn.Linear(hid_dimensions, output_dimensions)
        self.cell = cell.lower()

    def forward(self, input, hidden, cell=None):
        input = input.unsqueeze(1)
        embedded = self.embedding(input)
        if self.cell == 'lstm':
            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        else:
            output, hidden = self.rnn(embedded, hidden)
            cell = None
        prediction = self.fc_out(output.squeeze(1))
        return prediction, hidden, cell

class translit_Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        batch_size = source.size(0)
        target_len = target.size(1)
        output_dimensions = self.decoder.fc_out.out_features

        outputs = torch.zeros(batch_size, target_len, output_dimensions).to(self.device)
        hidden, cell = self.encoder(source)
        input = target[:, 0]

        for t in range(1, target_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[:, t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = target[:, t] if teacher_force else top1
        return outputs

def strip_after_eos(seq, eos_idx):
    """Strip sequence after EOS token, handling both lists and numpy arrays"""
    if isinstance(seq, np.ndarray):
        eos_positions = np.where(seq == eos_idx)[0]
        if len(eos_positions) > 0:
            return seq[:eos_positions[0] + 1]
        return seq
    else:  # handle lists
        if eos_idx in seq:
            return seq[:seq.index(eos_idx) + 1]
        return seq

def calculate_word_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    correct = 0
    total = 0
    for pred, target in zip(preds, targets):
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        if pred == target:
            correct += 1
        total += 1
    return correct / total if total > 0 else 0

def calculate_accuracy(preds, targets, pad_idx=0):
    total = 0
    correct = 0
    for p, t in zip(preds, targets):
        for pi, ti in zip(p, t):
            if ti == pad_idx:
                continue
            if pi == ti:
                correct += 1
            total += 1
    return correct / total if total > 0 else 0

def calculate_cer(preds, targets, pad_idx=0):
    cer, total = 0, 0
    for pred, target in zip(preds, targets):
        # Remove PAD tokens for evaluation.
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        cer += editdistance.eval(pred, target)
        total += len(target)
    return cer / total if total > 0 else 0

cpu


In [3]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msai-sakunthala[0m ([33msai-sakunthala-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

##multi reference

In [9]:
# Add this helper first to load all references
from collections import defaultdict
import unicodedata

run = wandb.init(project="dakshina-seq2seq_2", entity="sai-sakunthala-indian-institute-of-technology-madras", name="evaluate_test_multi")
artifact = run.use_artifact('best_model:v6', type='model')
artifact_dir = artifact.download()

# Read data and create vocabularies
test_pairs = read_data(data_path + f"{LANG}.translit.sampled.test.tsv", max_len=30)
train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=30)
source_vocab, idx2char_src = make_vocab([x[0] for x in train_pairs])
target_vocab, idx2char_tgt = make_vocab([x[1] for x in train_pairs])

# Initialize model
encoder = translit_Encoder(len(source_vocab), 256, 256*2, 2, 0.3, 'lstm').to(device)
decoder = translit_Decoder(len(target_vocab), 256, 256*2, 2, 0.3, 'lstm').to(device)
model = translit_Seq2Seq(encoder, decoder, device).to(device)

# Load model weights
state_dict = torch.load(f"{artifact_dir}/best_model.pt", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Create test dataset and loader
test_translit = TransliterationDataset(test_pairs, source_vocab, target_vocab)
test_loader = DataLoader(test_translit, batch_size=64, shuffle=False, drop_last=True)

all_src, all_preds, all_tgts = [], [], []
correct = 0
total = 0

def predict(model, src, max_len=30):
    """Greedy decoding implementation"""
    encoder_hidden, encoder_cell = model.encoder(src)

    # First input is SOS token
    input = torch.tensor([target_vocab['<sos>']] * src.size(0)).to(device)
    outputs = []

    for t in range(max_len):
        output, encoder_hidden, encoder_cell = model.decoder(input, encoder_hidden, encoder_cell)
        input = output.argmax(1)  # Greedy decoding
        outputs.append(input)

        # Stop if all sequences predicted EOS
        if (input == target_vocab['<eos>']).all():
            break

    return torch.stack(outputs, dim=1)

def normalize(text):
    return unicodedata.normalize('NFC', text)

def read_test_refs(filepath, max_len=30):
    ref_dict = defaultdict(list)
    with open(filepath, encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            # Accept lines with 2 or 3 parts (ignore the third if present)
            if len(parts) < 2:
                continue
            src, tgt = parts[0], parts[1]
            if len(src) <= max_len and len(tgt) <= max_len:
                ref_dict[normalize(src)].append(normalize(tgt))
    return ref_dict

# Use the reference loader to get valid outputs
ref_dict = read_test_refs(data_path + f"{LANG}.translit.sampled.test.tsv", max_len=30)

# Create a function to convert int sequences to string

def tokens_to_string(token_seq, idx2char, vocab):
    tokens = []
    for tok in token_seq:
        if tok in [vocab['<pad>'], vocab['<sos>'], vocab['<eos>']]:
            continue
        tokens.append(idx2char[tok])
    return ''.join(tokens)

# Replace the accuracy calculation loop
correct = 0
all_src, all_preds, all_refs = [], [], []

with torch.no_grad():
    for src, tgt in tqdm(test_loader):
        src, tgt = src.to(device), tgt.to(device)
        preds = predict(model, src)

        src_np = src.cpu().numpy()
        preds_np = preds.cpu().numpy()

        for i in range(len(src_np)):
            src_word = normalize(tokens_to_string(src_np[i], idx2char_src, source_vocab))
            pred_word = normalize(tokens_to_string(preds_np[i], idx2char_tgt, target_vocab))
            valid_refs = ref_dict.get(src_word, [])

            all_src.append(src_np[i])
            all_preds.append(preds_np[i])
            all_refs.append(valid_refs)
            if pred_word in valid_refs:
                correct += 1

accuracy = correct / len(all_src)
print(f"Test Accuracy (Multi-Ref): {accuracy:.4f}")
wandb.log({"Test Accuracy (Multi-Ref)": accuracy})

# Update table logging to handle multiple references

def log_sample_predictions_table_wandb(sources, preds, targets, idx2char_src, idx2char_tgt, num_samples=10):
    table = wandb.Table(columns=["Source", "Prediction", "Valid Reference(s)", "Correct?"])
    sample_indices = random.sample(range(len(sources)), min(num_samples, len(sources)))

    for i in sample_indices:
        src_word = normalize(tokens_to_string(sources[i], idx2char_src, source_vocab))
        pred_word = normalize(tokens_to_string(preds[i], idx2char_tgt, target_vocab))
        valid_refs = ref_dict.get(src_word, [])
        ref_display = ', '.join(valid_refs)

        # Determine correctness
        is_correct = pred_word in valid_refs
        status = "🟩 **Correct**" if is_correct else "🟥 **Incorrect**"

        table.add_data(src_word, pred_word, ref_display, status)

    wandb.log({"Test Sample Predictions (Color-Coded)": table})

log_sample_predictions_table_wandb(all_src, all_preds, ref_dict, idx2char_src, idx2char_tgt)

output_dir = "predictions_vanilla"
os.makedirs(output_dir, exist_ok=True)

# Save updated predictions
with open(os.path.join(output_dir, "test_predictions.txt"), "w", encoding="utf-8") as f:
    for s, p in zip(all_src, all_preds):
        src_word = normalize(tokens_to_string(s, idx2char_src, source_vocab))
        pred_word = normalize(tokens_to_string(p, idx2char_tgt, target_vocab))
        refs = ', '.join(ref_dict.get(src_word, []))
        f.write(f"{src_word}\t{pred_word}\t{refs}\n")
wandb.save(os.path.join(output_dir, "test_predictions.txt"))
wandb.finish()

[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 89/89 [00:45<00:00,  1.96it/s]

Test Accuracy (Multi-Ref): 0.7049





0,1
Test Accuracy (Multi-Ref),▁

0,1
Test Accuracy (Multi-Ref),0.70488


0,1
Test Accuracy (Multi-Ref),▁

0,1
Test Accuracy (Multi-Ref),0.70488


## no multi reference

In [8]:
run = wandb.init(project="dakshina-seq2seq_2", entity="sai-sakunthala-indian-institute-of-technology-madras", name="evaluate_test")
artifact = run.use_artifact('best_model:v6', type='model')
artifact_dir = artifact.download()

# Read data and create vocabularies
test_pairs = read_data(data_path + f"{LANG}.translit.sampled.test.tsv", max_len=30)
train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=30)
source_vocab, idx2char_src = make_vocab([x[0] for x in train_pairs])
target_vocab, idx2char_tgt = make_vocab([x[1] for x in train_pairs])

# Initialize model
encoder = translit_Encoder(len(source_vocab), 256, 256*2, 2, 0.3, 'lstm').to(device)
decoder = translit_Decoder(len(target_vocab), 256, 256*2, 2, 0.3, 'lstm').to(device)
model = translit_Seq2Seq(encoder, decoder, device).to(device)

# Load model weights
state_dict = torch.load(f"{artifact_dir}/best_model.pt", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Create test dataset and loader
test_translit = TransliterationDataset(test_pairs, source_vocab, target_vocab)
test_loader = DataLoader(test_translit, batch_size=64, shuffle=False, drop_last=True)

all_src, all_preds, all_tgts = [], [], []
correct = 0
total = 0

def predict(model, src, max_len=30):
    """Greedy decoding implementation"""
    encoder_hidden, encoder_cell = model.encoder(src)

    # First input is SOS token
    input = torch.tensor([target_vocab['<sos>']] * src.size(0)).to(device)
    outputs = []

    for t in range(max_len):
        output, encoder_hidden, encoder_cell = model.decoder(input, encoder_hidden, encoder_cell)
        input = output.argmax(1)  # Greedy decoding
        outputs.append(input)

        # Stop if all sequences predicted EOS
        if (input == target_vocab['<eos>']).all():
            break

    return torch.stack(outputs, dim=1)

with torch.no_grad():
    for src, tgt in tqdm(test_loader):
        src, tgt = src.to(device), tgt.to(device)
        preds = predict(model, src)

        # Convert to numpy arrays for processing
        src_np = src.cpu().numpy()
        preds_np = preds.cpu().numpy()
        tgt_np = tgt.cpu().numpy()

        for i in range(len(src_np)):
            # Get source, prediction and target sequences
            s = src_np[i]
            p = preds_np[i]
            t = tgt_np[i]

            # Store original sequences
            all_src.append(s)
            all_preds.append(p)
            all_tgts.append(t)

            # Process prediction: remove padding and everything after EOS
            p_processed = []
            for token in p:
                if token == target_vocab['<eos>']:
                    break
                if token not in [target_vocab['<pad>'], target_vocab['<sos>']]:
                    p_processed.append(token)

            # Process target: remove padding and everything after EOS
            t_processed = []
            for token in t:
                if token == target_vocab['<eos>']:
                    break
                if token not in [target_vocab['<pad>'], target_vocab['<sos>']]:
                    t_processed.append(token)

            # Compare the processed sequences
            if p_processed == t_processed:
                correct += 1
            total += 1

accuracy = correct / total if total > 0 else 0
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Correct: {correct}, Total: {total}")
wandb.log({"Test Accuracy": accuracy})

def log_sample_predictions_table_wandb(sources, preds, targets, idx2char_src, idx2char_tgt, num_samples=10):
    table = wandb.Table(columns=["Source", "Prediction", "Reference"])

    # Pick random indices without replacement
    sample_indices = random.sample(range(len(sources)), min(num_samples, len(sources)))

    for i in sample_indices:
        src_word = ''.join([idx2char_src[idx] for idx in sources[i] if idx not in [source_vocab['<pad>'], source_vocab['<sos>'], source_vocab['<eos>']]])
        pred_word = ''.join([idx2char_tgt[idx] for idx in preds[i] if idx not in [target_vocab['<pad>'], target_vocab['<sos>'], target_vocab['<eos>']]])
        ref_word = ''.join([idx2char_tgt[idx] for idx in targets[i] if idx not in [target_vocab['<pad>'], target_vocab['<sos>'], target_vocab['<eos>']]])
        table.add_data(src_word, pred_word, ref_word)

    wandb.log({"Test Sample Predictions Table": table})

log_sample_predictions_table_wandb(all_src, all_preds, all_tgts, idx2char_src, idx2char_tgt)

output_dir = "predictions_vanilla"
os.makedirs(output_dir, exist_ok=True)

with open(os.path.join(output_dir, "test_predictions.txt"), "w", encoding="utf-8") as f:
    for s, p, t in zip(all_src, all_preds, all_tgts):
        src_word = ''.join([idx2char_src[idx] for idx in s if idx not in [source_vocab['<pad>'], source_vocab['<sos>'], source_vocab['<eos>']]])
        pred_word = ''.join([idx2char_tgt[idx] for idx in p if idx not in [target_vocab['<pad>'], target_vocab['<sos>'], target_vocab['<eos>']]])
        ref_word = ''.join([idx2char_tgt[idx] for idx in t if idx not in [target_vocab['<pad>'], target_vocab['<sos>'], target_vocab['<eos>']]])
        f.write(f"{src_word}\t{pred_word}\t{ref_word}\n")

print(f"Saved full predictions to: {output_dir}/test_predictions.txt")
wandb.save(os.path.join(output_dir, "test_predictions.txt"))
wandb.finish()

0,1
Test Accuracy (Multi-Ref),▁

0,1
Test Accuracy (Multi-Ref),0.70488


[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 89/89 [00:45<00:00,  1.97it/s]


Test Accuracy: 0.3278
Correct: 1867, Total: 5696
Saved full predictions to: predictions_vanilla/test_predictions.txt


0,1
Test Accuracy,▁

0,1
Test Accuracy,0.32777
