In [None]:
pip install torch wandb pandas tqdm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import wandb
import editdistance
import numpy as np
import os
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Language code used to build the lexicon folder path and the lexicon file names.
LANG = 'te'
data_path = f'/content/drive/MyDrive/dakshina_dataset_v1.0/{LANG}/lexicons/'

def read_data(filepath, max_len=40):
    """Load (source, target) word pairs from a tab-separated lexicon file.

    Every line is stripped and split on tabs. Lines with fewer than two fields
    are ignored. The second field becomes the source word and the first field
    the target word (the file lists the native-script word first and its Latin
    spelling second, so the pair is swapped to Latin -> native). Any further
    fields are ignored.

    Args:
        filepath: Path of the UTF-8 encoded TSV file.
        max_len: Pairs are kept only when both words have at most this many
            characters.

    Returns:
        A list of (source, target) string tuples in file order.
    """
    pairs = []
    with open(filepath, encoding='utf8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split('\t')
            if len(fields) >= 2:
                # Column order in the file is (native, latin); swap it.
                target, source = fields[0], fields[1]
                if max(len(source), len(target)) <= max_len:
                    pairs.append((source, target))
    return pairs

def make_vocab(sequences):
    """Build a symbol <-> index vocabulary from an iterable of sequences.

    Indices 0, 1 and 2 are reserved for '<pad>', '<sos>' and '<eos>'. Every
    other symbol receives the next free index in order of first appearance.

    Args:
        sequences: Iterable of iterables of symbols (e.g. a list of words).

    Returns:
        (vocab, idx2char): the symbol -> index dict and its inverse.
    """
    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
    for sequence in sequences:
        for symbol in sequence:
            # setdefault only inserts unseen symbols; len(vocab) is the next free index.
            vocab.setdefault(symbol, len(vocab))
    idx2char = {index: symbol for symbol, index in vocab.items()}
    return vocab, idx2char

def encode_word(word, vocab):
    """Encode a word as a list of vocabulary indices wrapped in <sos>/<eos>.

    Args:
        word: Iterable of symbols (normally a string).
        vocab: Symbol -> index mapping containing '<sos>' and '<eos>'. If it
            also contains '<unk>', symbols missing from the vocabulary are
            mapped to that index.

    Returns:
        [sos] + one index per symbol + [eos].

    Raises:
        KeyError: If a symbol is not in ``vocab`` and ``vocab`` has no
            '<unk>' entry. The message names the symbol and the word so the
            offending data line can be found.
    """
    unk_idx = vocab.get('<unk>')
    encoded = [vocab['<sos>']]
    for ch in word:
        index = vocab.get(ch, unk_idx)
        if index is None:
            raise KeyError(
                f"character {ch!r} of word {word!r} is not in the vocabulary "
                f"and the vocabulary has no '<unk>' entry"
            )
        encoded.append(index)
    encoded.append(vocab['<eos>'])
    return encoded

def pad_seq(seq, max_len, pad_idx=0):
    """Return a new list: ``seq`` right-padded with ``pad_idx`` up to ``max_len``.

    A sequence that is already ``max_len`` long or longer comes back with its
    contents unchanged (it is never truncated).
    """
    missing = max_len - len(seq)
    filler = [pad_idx] * missing  # empty when missing <= 0
    return seq + filler

class TransliterationDataset(Dataset):
    """Dataset of encoded (source, target) word pairs padded to a fixed length.

    Every word is encoded once at construction time with ``encode_word``.
    Items are padded on access to the longest encoded source / target in this
    dataset, so all items of one dataset share the same tensor shapes.
    """

    def __init__(self, pairs, source_vocab, target_vocab):
        """Encode ``pairs`` (an iterable of (source, target) words) with the given vocabularies."""
        self.source_pad = source_vocab['<pad>']
        self.target_pad = target_vocab['<pad>']
        self.data = [
            (encode_word(src_word, source_vocab), encode_word(tgt_word, target_vocab))
            for src_word, tgt_word in pairs
        ]
        # Padding lengths are the longest encoded sequences (these include <sos>/<eos>).
        self.source_max = max(len(src_ids) for src_ids, _ in self.data)
        self.target_max = max(len(tgt_ids) for _, tgt_ids in self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the padded (source, target) index tensors for item ``idx``."""
        src_ids, tgt_ids = self.data[idx]
        padded_src = pad_seq(src_ids, self.source_max, self.source_pad)
        padded_tgt = pad_seq(tgt_ids, self.target_max, self.target_pad)
        return torch.tensor(padded_src), torch.tensor(padded_tgt)

class Attention(nn.Module):
    """Additive attention of one decoder state over the encoder outputs.

    The score of a source position is ``v . tanh(W [hidden; encoder_output])``;
    scores are softmax-normalised over the source positions.
    """

    def __init__(self, hid_dimensions):
        super().__init__()
        self.attn = nn.Linear(hid_dimensions * 2, hid_dimensions)
        self.v = nn.Parameter(torch.rand(hid_dimensions))
        # Re-draw v uniformly from [-1/sqrt(hid), 1/sqrt(hid)].
        bound = 1. / (hid_dimensions ** 0.5)
        self.v.data.uniform_(-bound, bound)
        self.hid_dimensions = hid_dimensions

    def forward(self, hidden, encoder_outputs):
        """Compute the context vector and attention weights.

        Args:
            hidden: (batch, hid) tensor, or (num_layers, batch, hid), in which
                case the last layer is used.
            encoder_outputs: (batch, src_len, hid) tensor.

        Returns:
            (context, weights) with shapes (batch, hid) and (batch, src_len).

        Raises:
            ValueError: If ``hidden`` is neither 2-D nor 3-D.
        """
        src_len = encoder_outputs.size(1)

        if hidden.dim() not in (2, 3):
            raise ValueError(f"Expected hidden to be 2D or 3D, got shape {hidden.shape}")
        if hidden.dim() == 3:
            hidden = hidden[-1]  # keep only the top layer's state

        # Pair the decoder state with every encoder position.
        query = hidden.unsqueeze(1).repeat(1, src_len, 1)
        features = torch.cat((query, encoder_outputs), dim=2)

        scores = torch.matmul(torch.tanh(self.attn(features)), self.v)  # (batch, src_len)
        weights = torch.softmax(scores, dim=1)

        # Weighted sum of the encoder outputs over the source axis.
        context = (weights.unsqueeze(2) * encoder_outputs).sum(dim=1)
        return context, weights

# Updated translit_Decoder to return attention weights
class translit_Decoder(nn.Module):
    """Single-step attention decoder built on an RNN, GRU or LSTM.

    Each call consumes one token per batch element, advances the recurrent
    state, attends over the encoder outputs with the updated hidden state and
    predicts scores for the next token.
    """

    def __init__(self, output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()
        cell_name = cell.lower()
        self.embedding = nn.Embedding(output_dimensions, emb_dimensions)
        self.attention = Attention(hid_dimensions)
        recurrent_types = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}
        # Dropout between recurrent layers is only passed when there is more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.rnn = recurrent_types[cell_name](
            emb_dimensions, hid_dimensions, num_layers,
            dropout=inter_layer_dropout, batch_first=True,
        )
        self.fc_out = nn.Linear(hid_dimensions * 2, output_dimensions)
        self.cell = cell_name
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Run one decoding step.

        Args:
            input: (batch,) tensor of current token indices.
            hidden: Recurrent hidden state passed to the RNN.
            cell: LSTM cell state; ignored for 'rnn' and 'gru'.
            encoder_outputs: (batch, src_len, hid) encoder outputs.

        Returns:
            (prediction, hidden, cell, attn_weights): vocabulary scores of
            shape (batch, output_dimensions), the updated states (``cell`` is
            None for non-LSTM cells) and the (batch, src_len) attention weights.
        """
        step_embedding = self.dropout(self.embedding(input.unsqueeze(1)))  # (batch, 1, emb)

        if self.cell == 'lstm':
            output, (hidden, cell) = self.rnn(step_embedding, (hidden, cell))
        else:
            output, hidden = self.rnn(step_embedding, hidden)
            cell = None

        # Attention is queried with the state produced by this step.
        context, attn_weights = self.attention(hidden, encoder_outputs)

        features = torch.cat((output.squeeze(1), context), dim=1)  # (batch, hid * 2)
        prediction = self.fc_out(features)
        return prediction, hidden, cell, attn_weights

class translit_Encoder(nn.Module):
    """Embeds a source sequence and encodes it with an RNN, GRU or LSTM."""

    def __init__(self, input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        super().__init__()
        self.embedding = nn.Embedding(input_dimensions, emb_dimensions)
        rnn_cls = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}[cell.lower()]
        self.rnn = rnn_cls(emb_dimensions, hid_dimensions, num_layers, dropout=dropout if num_layers > 1 else 0, batch_first=True)
        # Not used by forward(). Kept as a registered submodule so state dicts
        # saved from this class (which contain `attention.*` entries) still load.
        self.attention = Attention(hid_dimensions)
        self.cell = cell.lower()
        self.dropout = nn.Dropout(dropout)

    def forward(self, source):
        """Encode a batch of source index sequences.

        Args:
            source: (batch, src_len) tensor of token indices.

        Returns:
            (outputs, hidden, cell): the per-position RNN outputs, the final
            hidden state and the final LSTM cell state (None for 'rnn'/'gru').
        """
        embedded = self.dropout(self.embedding(source))
        if self.cell == 'lstm':
            outputs, (hidden, cell) = self.rnn(embedded)
        else:
            outputs, hidden = self.rnn(embedded)
            cell = None
        # The previous version also ran self.attention here and discarded the
        # result; that call never affected the returned values, so it is gone.
        return outputs, hidden, cell

class translit_Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with teacher forcing."""

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Decode ``target.size(1) - 1`` steps, starting from ``target[:, 0]``.

        Args:
            source: (batch, src_len) source token indices.
            target: (batch, tgt_len) target token indices; column 0 is fed as
                the first decoder input.
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding the ground-truth token instead of the
                model's own argmax prediction.

        Returns:
            (outputs, attn_weights_all) of shapes (batch, tgt_len, vocab) and
            (batch, tgt_len, src_len). Index 0 along the time axis is never
            written and stays zero.
        """
        batch_size, target_len = source.size(0), target.size(1)
        vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(batch_size, target_len, vocab_size).to(self.device)
        attn_weights_all = torch.zeros(batch_size, target_len, source.size(1)).to(self.device)

        encoder_outputs, hidden, cell = self.encoder(source)

        decoder_input = target[:, 0]
        for step in range(1, target_len):
            logits, hidden, cell, step_attn = self.decoder(decoder_input, hidden, cell, encoder_outputs)
            outputs[:, step] = logits
            attn_weights_all[:, step] = step_attn

            # One random draw per step decides the next input for the whole batch.
            use_teacher = random.random() < teacher_forcing_ratio
            decoder_input = target[:, step] if use_teacher else logits.argmax(1)

        return outputs, attn_weights_all

def strip_after_eos(seq, eos_idx):
    """Return the part of ``seq`` that precedes the first ``eos_idx``.

    Tensors and NumPy arrays are converted to plain Python lists first, so the
    result can be compared with ``==`` regardless of the input container.
    Lists and other sequences providing ``.index`` and slicing are used as is.

    Args:
        seq: Sequence of token indices (list, tuple, torch.Tensor or
            numpy.ndarray).
        eos_idx: Index of the end-of-sequence token.

    Returns:
        The tokens before the first EOS (EOS itself excluded), or the whole
        sequence when it contains no EOS.
    """
    if isinstance(seq, torch.Tensor):
        # detach() so tensors that require grad can be converted too.
        seq = seq.detach().cpu().tolist()
    elif isinstance(seq, np.ndarray):
        # ndarray has no .index(); the evaluation loop stores ndarrays.
        seq = seq.tolist()
    if eos_idx in seq:
        return seq[:seq.index(eos_idx)]  # Exclude EOS for fair comparison
    return seq

def calculate_word_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Fraction of predictions that match their target exactly.

    Each sequence is optionally cut at its first EOS and has padding tokens
    removed before the comparison.

    Args:
        preds: Sequence of predicted token sequences.
        targets: Sequence of reference token sequences, aligned with ``preds``.
        pad_idx: Padding index removed from both sides.
        eos_idx: End-of-sequence index; ``None`` disables EOS truncation. Any
            integer, including 0, enables it.

    Returns:
        Number of exact matches divided by ``len(preds)`` (0.0 for no
        predictions).
    """
    correct = 0
    for pred, target in zip(preds, targets):
        # `is not None` rather than truthiness, so an EOS index of 0 is honoured.
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        correct += int(pred == target)
    return correct / max(len(preds), 1)

def calculate_cer(preds, targets, pad_idx=0, eos_idx=None):
    """Character error rate: total edit distance divided by total target length.

    Each sequence is optionally cut at its first EOS and has padding tokens
    removed before the edit distance is computed with ``editdistance.eval``.

    Args:
        preds: Sequence of predicted token sequences.
        targets: Sequence of reference token sequences, aligned with ``preds``.
        pad_idx: Padding index removed from both sides.
        eos_idx: End-of-sequence index; ``None`` disables EOS truncation. Any
            integer, including 0, enables it.

    Returns:
        Summed edit distance over summed target length, where an empty target
        counts as length 1; ``float('inf')`` when there are no pairs.
    """
    cer, total = 0, 0
    for pred, target in zip(preds, targets):
        # `is not None` rather than truthiness, so an EOS index of 0 is honoured.
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        cer += editdistance.eval(pred, target)
        total += max(len(target), 1)
    return cer / total if total > 0 else float('inf')

def calculate_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Token-level accuracy over all non-padding target positions.

    A target position counts as correct when the prediction has the same token
    at the same position. Target positions the prediction never reaches
    (prediction shorter than the target) count as errors; previously they were
    dropped from the denominator, which inflated the score of short
    predictions. Prediction tokens beyond the target length are ignored.

    Args:
        preds: Sequence of predicted token sequences (lists or tensors).
        targets: Sequence of reference token sequences, aligned with ``preds``.
        pad_idx: Target positions holding this index are skipped.
        eos_idx: End-of-sequence index; ``None`` disables EOS truncation.

    Returns:
        correct / total over the counted positions, or 0.0 if none is counted.
    """
    correct = 0
    total = 0
    for pred, target in zip(preds, targets):
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().tolist()
        if isinstance(target, torch.Tensor):
            target = target.cpu().tolist()
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred_len = len(pred)
        for position, t_token in enumerate(target):
            if t_token == pad_idx:
                continue
            # A missing prediction at this position is simply a miss.
            if position < pred_len and pred[position] == t_token:
                correct += 1
            total += 1
    return correct / total if total > 0 else 0.0

cpu


In [None]:
# Authenticate this session with Weights & Biases before any run is created.
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msai-sakunthala[0m ([33msai-sakunthala-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
!pip install Pillow

# List /content so the uploaded font files can be checked by eye.
!ls /content

# Path of the Telugu font this cell checks for; adjust it if the file was uploaded elsewhere.
# NOTE(review): create_heatmap_image loads '/content/LohitTeluguRegular.ttf', not this
# path, so this check does not cover the font the heatmaps use — confirm which is intended.
font_path = '/content/NotoSansTelugu-VariableFont.ttf'  # Adjust if uploaded to a different directory
import os
if os.path.exists(font_path):
    print(f"Font file found at {font_path}")
else:
    # Font is missing: open the Colab upload dialog so it can be provided.
    print(f"Font file not found at {font_path}. Please upload it.")
    from google.colab import files
    uploaded = files.upload()

artifacts  heatmap_example_1.png   NotoSansTelugu-VariableFont.ttf  sample_data
drive	   LohitTeluguRegular.ttf  predictions_vanilla		    wandb
Font file found at /content/NotoSansTelugu-VariableFont.ttf


In [None]:
# Minimal test for create_heatmap_image
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import wandb
import os

def _heatmap_labels(tokens, idx2char, limit):
    """Return (positions, labels) of the real characters in ``tokens``.

    Scanning stops at the first '<eos>' or at position ``limit``. '<pad>',
    '<sos>' and ids missing from ``idx2char`` are skipped. The positions are
    kept so the matching attention rows/columns can be selected.
    """
    positions, labels = [], []
    for position, token in enumerate(tokens):
        if position >= limit:
            break
        char = idx2char.get(int(token))
        if char == '<eos>':
            break
        if char is None or char in ('<pad>', '<sos>'):
            continue
        positions.append(position)
        labels.append(char)
    return positions, labels


def create_heatmap_image(src_tokens, pred_tokens, attn_weights, idx, idx2char_src, idx2char_tgt):
    """Create a single attention heatmap image for WandB using PIL.

    Row ``t`` of ``attn_weights`` is treated as the attention used for
    ``pred_tokens[t]`` and column ``s`` as the weight on ``src_tokens[s]``.
    Only rows/columns of real characters are drawn: special tokens are removed
    and everything from the first '<eos>' onwards is dropped, with the
    attention matrix indexed by the surviving positions so labels and cells
    stay aligned.

    Args:
        src_tokens: 1-D sequence of source token ids.
        pred_tokens: 1-D sequence of predicted token ids.
        attn_weights: 2-D array-like indexed [target step, source position].
        idx: Zero-based example number, used in the title, caption and log lines.
        idx2char_src: Source id -> character mapping (must include the special tokens).
        idx2char_tgt: Target id -> character mapping (must include the special tokens).

    Returns:
        A ``wandb.Image`` wrapping the rendered heatmap.

    Raises:
        FileNotFoundError: If the Telugu font file does not exist.
        Exception: If the Telugu font file exists but cannot be loaded.
    """
    # Define image dimensions
    cell_size = 50  # Size of each heatmap cell in pixels
    label_width = 150  # Extra vertical room reserved below the heatmap
    margin = 50  # Margin around the heatmap
    title_height = 50  # Space for title
    xlabel_height = 50  # Space for x-axis label
    ylabel_width = 100  # Space for y-axis label

    attn_weights = np.asarray(attn_weights)

    # Keep the position of every retained token so attention can be indexed by it.
    src_positions, src_labels = _heatmap_labels(src_tokens, idx2char_src, attn_weights.shape[1])
    pred_positions, pred_labels = _heatmap_labels(pred_tokens, idx2char_tgt, attn_weights.shape[0])

    # Debug labels with Unicode code points
    print(f"Example {idx+1} - Source tokens: {src_tokens}")
    print(f"Example {idx+1} - Source labels: {src_labels}")
    print(f"Example {idx+1} - Pred tokens: {pred_tokens}")
    print(f"Example {idx+1} - Pred labels: {pred_labels}")
    print(f"Example {idx+1} - Pred labels (Unicode): {[f'U+{ord(c):04X}' for c in pred_labels if c != '?']}")
    print(f"Example {idx+1} - Attention weights shape: {attn_weights.shape}")

    # Replace characters outside the Telugu Unicode block with '?'
    pred_labels = [char if 0x0C00 <= ord(char) <= 0x0C7F else '?' for char in pred_labels]
    print(f"Example {idx+1} - Filtered pred_labels: {pred_labels}")

    # Select exactly the rows/columns that belong to the retained labels.
    attn_weights = attn_weights[np.ix_(pred_positions, src_positions)]
    print(f"Example {idx+1} - Truncated attention weights shape: {attn_weights.shape}")

    # Calculate image size
    heatmap_width = len(src_labels) * cell_size
    heatmap_height = len(pred_labels) * cell_size
    img_width = heatmap_width + label_width + ylabel_width + 2 * margin
    img_height = heatmap_height + label_width + title_height + xlabel_height + 2 * margin

    # Create a new image with white background
    image = Image.new('RGB', (img_width, img_height), 'white')
    draw = ImageDraw.Draw(image)

    # Load Telugu font
    font_path = '/content/LohitTeluguRegular.ttf'
    if not os.path.exists(font_path):
        raise FileNotFoundError(f"Telugu font not found at {font_path}.")
    try:
        telugu_font = ImageFont.truetype(font_path, size=20)
    except Exception as e:
        raise Exception(f"Failed to load font {font_path}: {e}") from e

    # Font for Latin text: FreeSans if it can be opened, otherwise PIL's default font.
    try:
        latin_font = ImageFont.truetype('/usr/share/fonts/truetype/freefont/FreeSans.ttf', size=20)
    except OSError:
        latin_font = ImageFont.load_default()

    # Draw title
    title = f'Example {idx+1}'
    draw.text((margin + ylabel_width, margin), title, font=latin_font, fill='black')

    # Draw x-axis labels (Source Tokens - Latin)
    for i, label in enumerate(src_labels):
        x = margin + ylabel_width + i * cell_size + cell_size // 2
        y = margin + title_height + heatmap_height + 10
        draw.text((x, y), label, font=latin_font, fill='black', anchor='mm')

    # Draw y-axis labels (Target Tokens - Telugu)
    for i, label in enumerate(pred_labels):
        x = margin + ylabel_width - 10
        y = margin + title_height + i * cell_size + cell_size // 2
        draw.text((x, y), label, font=telugu_font, fill='black', anchor='rm')

    # Draw x-axis title
    draw.text((margin + ylabel_width + heatmap_width // 2, margin + title_height + heatmap_height + xlabel_height - 10),
              'Source Tokens (Latin)', font=latin_font, fill='black', anchor='mm')

    # Draw y-axis title. ImageDraw.text has no rotation option, so the text is
    # rendered on a transparent scratch image, rotated, and pasted centred on
    # the intended position.
    y_title = 'Target Tokens (Telugu)'
    left, top, right, bottom = draw.textbbox((0, 0), y_title, font=telugu_font)
    title_img = Image.new('RGBA', (right - left + 4, bottom - top + 4), (255, 255, 255, 0))
    ImageDraw.Draw(title_img).text((2 - left, 2 - top), y_title, font=telugu_font, fill='black')
    title_img = title_img.rotate(90, expand=True)
    center_x = margin + ylabel_width // 2
    center_y = margin + title_height + heatmap_height // 2
    image.paste(title_img, (center_x - title_img.width // 2, center_y - title_img.height // 2), title_img)

    # Draw heatmap
    for i in range(len(pred_labels)):
        for j in range(len(src_labels)):
            # Clamp to [0, 1] so the colour channels stay within 0..255.
            weight = min(max(float(attn_weights[i, j]), 0.0), 1.0)
            # Linear green (weight 0) to red (weight 1) ramp.
            color = (int(255 * weight), int(255 * (1 - weight)), 0)
            x0 = margin + ylabel_width + j * cell_size
            y0 = margin + title_height + i * cell_size
            draw.rectangle([x0, y0, x0 + cell_size, y0 + cell_size], fill=color)

    # Save image for verification
    if idx == 0:
        image.save(f"/content/heatmap_example_{idx+1}.png")
        print(f"Saved sample heatmap to /content/heatmap_example_{idx+1}.png")

    # Convert to WandB image
    wandb_image = wandb.Image(image, caption=f"Attention Heatmap Example {idx+1}")
    return wandb_image

In [None]:
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import wandb
import os

# Start a W&B run for this evaluation and download the trained checkpoint artifact.
run = wandb.init(project="dakshina-seq2seq", entity="sai-sakunthala-indian-institute-of-technology-madras", name="evaluate_test")
artifact = run.use_artifact('sai-sakunthala-indian-institute-of-technology-madras/dakshina-seq2seq-3/best_model:v52', type='model')
artifact_dir = artifact.download()

# Read data and create vocabularies (built from the training split only)
test_pairs = read_data(data_path + f"{LANG}.translit.sampled.test.tsv", max_len=30)
train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=30)
source_vocab, idx2char_src = make_vocab([x[0] for x in train_pairs])
target_vocab, idx2char_tgt = make_vocab([x[1] for x in train_pairs])

# Model parameters (must match training)
input_dimensions = len(source_vocab)
output_dimensions = len(target_vocab)
emb_dimensions = 128
hid_dimensions = 128 * 2
num_layers = 2
dropout = 0.2
cell = 'lstm'
batch_size = 64
max_len = 30

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model
encoder = translit_Encoder(input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell).to(device)
decoder = translit_Decoder(output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell).to(device)
model = translit_Seq2Seq(encoder, decoder, device).to(device)

# Load model weights
state_dict = torch.load(f"{artifact_dir}/best_model.pt", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Create test dataset and loader.
# drop_last must be False for evaluation: with True, the final partial batch
# (up to batch_size - 1 test words) is silently left out of every metric.
test_translit = TransliterationDataset(test_pairs, source_vocab, target_vocab)
test_loader = DataLoader(test_translit, batch_size=batch_size, shuffle=False, drop_last=False)

# Accumulators filled by the evaluation loop.
all_src, all_preds, all_tgts, all_attn_weights = [], [], [], []
correct = 0
total = 0
selected_examples = []

def predict(model, src, max_len=30):
    """Greedy decoding that also collects the attention weights of every step.

    Decoding starts from '<sos>' (looked up in the module-level
    ``target_vocab``) and stops after ``max_len`` steps, or earlier once every
    sequence in the batch emits '<eos>' in the same step. Sequences that finish
    earlier keep being decoded, so callers must cut each prediction at its
    first '<eos>'.

    Args:
        model: Object exposing ``encoder(src)`` and
            ``decoder(input, hidden, cell, encoder_outputs)``.
        src: (batch, src_len) tensor of source token indices.
        max_len: Maximum number of decoding steps.

    Returns:
        (outputs, attn_weights_all): predicted token ids of shape
        (batch, steps) and attention weights of shape (batch, steps, src_len),
        where ``steps <= max_len`` is the number of steps actually run.
    """
    sos_idx = target_vocab['<sos>']
    eos_idx = target_vocab.get('<eos>', -1)

    encoder_outputs, hidden, cell = model.encoder(src)
    # Create the start tokens on the same device as the input batch instead of
    # relying on a module-level `device` that might differ from it.
    decoder_input = torch.full((src.size(0),), sos_idx, dtype=torch.long, device=src.device)
    outputs = []
    attn_weights_list = []

    for _ in range(max_len):
        output, hidden, cell, attn_weights = model.decoder(decoder_input, hidden, cell, encoder_outputs)
        decoder_input = output.argmax(1)  # Greedy decoding
        outputs.append(decoder_input)
        attn_weights_list.append(attn_weights)  # Collect attention weights

        # Stop if all sequences predicted EOS
        if (decoder_input == eos_idx).all():
            break

    outputs = torch.stack(outputs, dim=1)  # (batch_size, steps)
    attn_weights_all = torch.stack(attn_weights_list, dim=1)  # (batch_size, steps, src_len)
    return outputs, attn_weights_all

def plot_attention_heatmap_grid(sources, predictions, attn_weights_list, idx2char_src, idx2char_tgt, num_plots=12):
    """Log attention heatmaps as a three-column WandB table.

    One heatmap is rendered for each of the first ``min(num_plots,
    len(sources))`` examples. The table gets as many rows as are needed to
    hold them (4 rows for the default 12 plots); the last row is filled up
    with ``None``. Previously the table always had exactly 4 rows, which
    dropped heatmaps beyond the twelfth and produced empty rows for small
    inputs.

    Args:
        sources: Source token sequences.
        predictions: Predicted token sequences, aligned with ``sources``.
        attn_weights_list: Per-example attention matrices, aligned with ``sources``.
        idx2char_src: Source id -> character mapping.
        idx2char_tgt: Target id -> character mapping.
        num_plots: Maximum number of heatmaps to render.
    """
    n_cols = 3
    table = wandb.Table(columns=['Heatmap 1', 'Heatmap 2', 'Heatmap 3'])

    # Generate heatmaps for each example
    heatmap_images = [
        create_heatmap_image(sources[i], predictions[i], attn_weights_list[i], i, idx2char_src, idx2char_tgt)
        for i in range(min(num_plots, len(sources)))
    ]

    # Fill the table row by row, padding the final row with None.
    for start in range(0, len(heatmap_images), n_cols):
        row_data = heatmap_images[start:start + n_cols]
        row_data += [None] * (n_cols - len(row_data))
        table.add_data(*row_data)

    wandb.log({"Attention Heatmap Grid (4x3)": table})

def log_attention_heatmaps_individually(sources, predictions, attn_weights_list, idx2char_src, idx2char_tgt, num_plots=10):
    """Render and log one attention heatmap per example, each under its own key.

    The first ``min(num_plots, len(sources))`` examples are rendered with
    ``create_heatmap_image`` and logged to WandB as "Attention Heatmap <n>",
    numbered from 1.
    """
    n_examples = min(num_plots, len(sources))
    for example_no in range(n_examples):
        rendered = create_heatmap_image(
            sources[example_no],
            predictions[example_no],
            attn_weights_list[example_no],
            example_no,
            idx2char_src,
            idx2char_tgt,
        )
        log_key = f"Attention Heatmap {example_no+1}"
        wandb.log({log_key: wandb.Image(rendered)})

# Target-side special token ids, looked up once instead of once per token.
_tgt_eos_idx = target_vocab.get('<eos>', -1)
_tgt_skip_idx = {target_vocab.get('<pad>', -1), target_vocab.get('<sos>', -1)}


def _tokens_before_eos(tokens):
    """Return the tokens preceding the first <eos>, without <pad>/<sos>."""
    kept = []
    for token in tokens:
        if token == _tgt_eos_idx:
            break
        if token not in _tgt_skip_idx:
            kept.append(token)
    return kept


with torch.no_grad():
    for src, tgt in tqdm(test_loader):
        src, tgt = src.to(device), tgt.to(device)
        preds, attn_weights_batch = predict(model, src)

        # Convert to numpy arrays for processing
        src_np = src.cpu().numpy()
        preds_np = preds.cpu().numpy()
        tgt_np = tgt.cpu().numpy()
        attn_weights_np = attn_weights_batch.cpu().numpy()  # (batch_size, steps, src_len)

        for s, p, t, attn in zip(src_np, preds_np, tgt_np, attn_weights_np):
            # Store sequences and attention weights for all examples
            all_src.append(s)
            all_preds.append(p)
            all_tgts.append(t)
            all_attn_weights.append(attn)

            # Collect up to 12 examples for heatmaps
            if len(selected_examples) < 12:
                selected_examples.append((s, p, t, attn))

            # A word is correct when prediction and reference agree on every
            # token up to (excluding) the first EOS.
            if _tokens_before_eos(p) == _tokens_before_eos(t):
                correct += 1
            total += 1

# Plot and log attention heatmaps for up to 12 examples. Any non-empty
# selection is logged; the old `>= 10` guard silently skipped small test sets.
if selected_examples:
    sources, predictions, _, attn_weights_list = zip(*selected_examples[:12])
    log_attention_heatmaps_individually(sources, predictions, attn_weights_list, idx2char_src, idx2char_tgt, num_plots=12)

# Calculate accuracy
accuracy = correct / total if total > 0 else 0
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Correct: {correct}, Total: {total}")
wandb.log({"Test Accuracy": accuracy})

# Log sample predictions table with at least 7 correct labels
def log_table_wandb(sources, preds, targets, idx2char_src, idx2char_tgt, num_samples=10, min_correct=7):
    """Log a random sample of decoded predictions to WandB as a table.

    Token sequences are decoded to strings by dropping '<pad>'/'<sos>' and
    stopping at the first '<eos>'. Stopping at EOS matters for predictions:
    the decoder keeps emitting tokens after EOS, and including them made
    correct words show up as incorrect. Ids missing from the mapping are
    rendered as '?'.

    Up to ``min_correct`` correctly predicted words are sampled; the remaining
    slots (up to ``num_samples`` rows in total) are filled with incorrectly
    predicted words. Sampling uses the global ``random`` state.

    Args:
        sources: Source token sequences.
        preds: Predicted token sequences, aligned with ``sources``.
        targets: Reference token sequences, aligned with ``sources``.
        idx2char_src: Source id -> character mapping (including special tokens).
        idx2char_tgt: Target id -> character mapping (including special tokens).
        num_samples: Maximum number of table rows.
        min_correct: Maximum number of correct examples sampled first.
    """
    def decode(tokens, idx2char):
        chars = []
        for token in tokens:
            char = idx2char.get(int(token), '?')
            if char == '<eos>':
                break
            if char in ('<pad>', '<sos>'):
                continue
            chars.append(char)
        return ''.join(chars)

    table = wandb.Table(columns=["Source", "Prediction", "Reference", "Status"])

    # Collect correct and incorrect examples
    correct_indices = []
    incorrect_indices = []
    for i in range(len(sources)):
        src_word = decode(sources[i], idx2char_src)
        pred_word = decode(preds[i], idx2char_tgt)
        ref_word = decode(targets[i], idx2char_tgt)
        bucket = correct_indices if pred_word == ref_word else incorrect_indices
        bucket.append((i, src_word, pred_word, ref_word))

    # Sample up to min_correct correct examples
    num_correct = min(len(correct_indices), min_correct)
    selected_correct = random.sample(correct_indices, num_correct) if correct_indices else []

    # Fill remaining slots with incorrect samples, up to num_samples
    remaining_slots = num_samples - len(selected_correct)
    selected_incorrect = random.sample(incorrect_indices, min(remaining_slots, len(incorrect_indices))) if incorrect_indices and remaining_slots > 0 else []

    # Combine and shuffle selected samples
    selected_samples = selected_correct + selected_incorrect
    random.shuffle(selected_samples)

    # Add to table
    for i, src_word, pred_word, ref_word in selected_samples:
        is_correct = (pred_word == ref_word)
        status = "🟩 **Correct**" if is_correct else "🟥 **Incorrect**"
        table.add_data(src_word, pred_word, ref_word, status)

    wandb.log({"Test Sample Predictions (Color-Coded)": table})
    print(f"Logging table: correct={len(correct_indices)}, incorrect={len(incorrect_indices)}")

log_table_wandb(all_src, all_preds, all_tgts, idx2char_src, idx2char_tgt, num_samples=10, min_correct=7)


def _decode_until_eos(tokens, idx2char):
    """Decode token ids to a string, skipping <pad>/<sos> and stopping at the first <eos>.

    Ids missing from ``idx2char`` are written as '?'.
    """
    chars = []
    for token in tokens:
        char = idx2char.get(int(token), '?')
        if char == '<eos>':
            break
        if char in ('<pad>', '<sos>'):
            continue
        chars.append(char)
    return ''.join(chars)


# Save predictions to file
output_dir = "predictions_vanilla"
os.makedirs(output_dir, exist_ok=True)

# One line per test word: source <TAB> prediction <TAB> reference. Predictions
# are cut at the first EOS; the decoder keeps emitting tokens after it, and
# those used to be written into the prediction column.
with open(os.path.join(output_dir, "test_predictions.txt"), "w", encoding="utf-8") as f:
    for s, p, t in zip(all_src, all_preds, all_tgts):
        src_word = _decode_until_eos(s, idx2char_src)
        pred_word = _decode_until_eos(p, idx2char_tgt)
        ref_word = _decode_until_eos(t, idx2char_tgt)
        f.write(f"{src_word}\t{pred_word}\t{ref_word}\n")

print(f"Saved full predictions to: {output_dir}/test_predictions.txt")
wandb.save(os.path.join(output_dir, "test_predictions.txt"))
wandb.finish()

[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 89/89 [00:50<00:00,  1.77it/s]


Example 1 - Source tokens: [ 1  3  4  5  3  4 18 21  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0]
Example 1 - Source labels: ['a', 'm', 'k', 'a', 'm', 'l', 'o']
Example 1 - Pred tokens: [ 3  4  5  4 21 29  2  2  2  2  2  2  2  2 29  2  2  2 29  2  2  2 29  2
  2  2 29  2  2  2]
Example 1 - Pred labels: ['అ', 'ం', 'క', 'ం', 'ల', 'ో', 'ో', 'ో', 'ో', 'ో']
Example 1 - Pred labels (Unicode): ['U+0C05', 'U+0C02', 'U+0C15', 'U+0C02', 'U+0C32', 'U+0C4B', 'U+0C4B', 'U+0C4B', 'U+0C4B', 'U+0C4B']
Example 1 - Attention weights shape: (30, 25)
Example 1 - Filtered pred_labels: ['అ', 'ం', 'క', 'ం', 'ల', 'ో', 'ో', 'ో', 'ో', 'ో']
Example 1 - Truncated attention weights shape: (10, 7)
Saved sample heatmap to /content/heatmap_example_1.png
Example 2 - Source tokens: [ 1  3  8  5  3  4 18 21  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0]
Example 2 - Source labels: ['a', 'n', 'k', 'a', 'm', 'l', 'o']
Example 2 - Pred tokens: [ 3  4  5  4 21 29  2  2  2  2  2  2  2  2 29  2  2  2 29  2 29  2  

0,1
Test Accuracy,▁

0,1
Test Accuracy,0.58585
