In [None]:
pip install torch wandb pandas tqdm

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import wandb
import editdistance
import numpy as np
import os
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# Language code used to build the lexicon directory and the file names read
# later in the notebook.
LANG = 'te'
# Directory holding the lexicon TSV files for LANG (a mounted Google Drive path).
data_path = f'/content/drive/MyDrive/dakshina_dataset_v1.0/{LANG}/lexicons/'

def read_data(filepath, max_len=40):
    """Load (source, target) transliteration pairs from a tab-separated file.

    The first column of every usable line is taken as the native-script word
    and the second column as its Latin spelling; further columns are ignored.
    Pairs are returned as (latin, native) because the model maps Latin input
    to the native script.

    Args:
        filepath: Path of the UTF-8 encoded TSV file.
        max_len: Maximum number of characters allowed on either side of a pair.

    Returns:
        List of (source, target) string tuples in file order.
    """
    def _columns(handle):
        # Yield the first two tab-separated fields of each well-formed line.
        for raw_line in handle:
            fields = raw_line.strip().split('\t')
            if len(fields) >= 2:
                yield fields[0], fields[1]

    with open(filepath, encoding='utf8') as handle:
        return [
            (latin, native)
            for native, latin in _columns(handle)
            if len(latin) <= max_len and len(native) <= max_len
        ]

def make_vocab(sequences):
    """Build character-to-index and index-to-character tables.

    Indices 0, 1 and 2 are reserved for '<pad>', '<sos>' and '<eos>'; every
    other symbol receives the next free index in order of first appearance.

    Args:
        sequences: Iterable of iterables of symbols (e.g. a list of words).

    Returns:
        Tuple (vocab, idx2char) where idx2char is the inverse of vocab.
    """
    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
    for word in sequences:
        for symbol in word:
            # Indices are dense, so the next free one equals the current size.
            vocab.setdefault(symbol, len(vocab))

    idx2char = dict(zip(vocab.values(), vocab.keys()))
    return vocab, idx2char

def encode_word(word, vocab):
    """Convert a word into token indices framed by <sos> and <eos>.

    Args:
        word: Iterable of symbols (normally a string).
        vocab: Mapping from symbol to index; must contain '<sos>' and '<eos>'.

    Returns:
        List ``[<sos>, idx(ch_1), ..., idx(ch_n), <eos>]``.

    Raises:
        KeyError: If a symbol is missing from ``vocab`` and ``vocab`` has no
            '<unk>' entry to fall back on.
    """
    sos_idx = vocab['<sos>']
    eos_idx = vocab['<eos>']
    # Optional fallback: only used when the vocabulary defines '<unk>'.
    unk_idx = vocab.get('<unk>')

    body = []
    for ch in word:
        token = vocab.get(ch, unk_idx)
        if token is None:
            # Keep raising KeyError (as a plain lookup would), but say which
            # character and word caused it.
            raise KeyError(
                f"character {ch!r} in word {word!r} is not in the vocabulary "
                f"and the vocabulary has no '<unk>' entry"
            )
        body.append(token)
    return [sos_idx, *body, eos_idx]

def pad_seq(seq, max_len, pad_idx=0):
    """Return ``seq`` right-padded with ``pad_idx`` up to ``max_len`` items.

    A sequence that is already ``max_len`` long or longer is returned with
    nothing appended (it is not truncated).
    """
    missing = max_len - len(seq)
    filler = [pad_idx] * missing  # empty when missing <= 0
    return seq + filler

class TransliterationDataset(Dataset):
    """Dataset of index-encoded (source, target) word pairs.

    Every pair is encoded once at construction time with ``encode_word``.
    Items are padded on access to the longest encoded source / target found
    in the whole dataset, so all returned tensors share a fixed length.
    """

    def __init__(self, pairs, source_vocab, target_vocab):
        """Encode ``pairs`` and record the padding indices and maximum lengths.

        Args:
            pairs: Iterable of (source_word, target_word) tuples.
            source_vocab: Symbol-to-index mapping for the source side.
            target_vocab: Symbol-to-index mapping for the target side.
        """
        self.source_pad = source_vocab['<pad>']
        self.target_pad = target_vocab['<pad>']
        self.data = [
            (encode_word(src_word, source_vocab), encode_word(tgt_word, target_vocab))
            for src_word, tgt_word in pairs
        ]

        # Longest encoded sequence on each side (includes <sos> and <eos>).
        self.source_max = max(len(src_ids) for src_ids, _ in self.data)
        self.target_max = max(len(tgt_ids) for _, tgt_ids in self.data)

    def __len__(self):
        """Return the number of encoded pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the padded (source, target) tensors of pair ``idx``."""
        src_ids, tgt_ids = self.data[idx]
        padded_src = pad_seq(src_ids, self.source_max, self.source_pad)
        padded_tgt = pad_seq(tgt_ids, self.target_max, self.target_pad)
        return torch.tensor(padded_src), torch.tensor(padded_tgt)

class Attention(nn.Module):
    """Scores each encoder position against a decoder state.

    The decoder state is concatenated with every encoder output, passed
    through a linear layer and ``tanh``, and projected to a scalar with the
    learnable vector ``v``. A softmax over source positions turns the scores
    into weights, which are used to average the encoder outputs.
    """

    def __init__(self, hid_dimensions):
        """Create the scoring layer and the projection vector.

        Args:
            hid_dimensions: Size of both the decoder state and the encoder
                outputs.
        """
        super().__init__()
        self.hid_dimensions = hid_dimensions

        # Maps [decoder state ; encoder output] to a hid_dimensions vector.
        self.attn = nn.Linear(hid_dimensions * 2, hid_dimensions)

        # Projection vector, re-initialised uniformly in [-1/sqrt(h), 1/sqrt(h)].
        self.v = nn.Parameter(torch.rand(hid_dimensions))
        bound = 1. / (hid_dimensions ** 0.5)
        self.v.data.uniform_(-bound, bound)

    def forward(self, hidden, encoder_outputs):
        """Return the context vector and the attention weights.

        Args:
            hidden: Decoder state, either 2-D or 3-D; for a 3-D tensor the
                last slice along the first axis is used.
            encoder_outputs: Tensor indexed as (batch, source position, hidden).

        Returns:
            Tuple ``(context, attn_weights)``: the weighted sum of encoder
            outputs per batch item, and the weights over source positions.

        Raises:
            ValueError: If ``hidden`` is neither 2-D nor 3-D.
        """
        src_len = encoder_outputs.size(1)

        if hidden.dim() == 3:
            hidden = hidden[-1]
        elif hidden.dim() != 2:
            raise ValueError(f"Expected hidden to be 2D or 3D, got shape {hidden.shape}")

        # Pair the same decoder state with every source position.
        query = hidden.unsqueeze(1).expand(-1, src_len, -1)
        features = torch.cat((query, encoder_outputs), dim=2)

        scores = torch.matmul(torch.tanh(self.attn(features)), self.v)
        attn_weights = torch.softmax(scores, dim=1)

        context = (attn_weights.unsqueeze(2) * encoder_outputs).sum(dim=1)
        return context, attn_weights

class translit_Decoder(nn.Module):
    """Single-step recurrent decoder that attends over the encoder outputs.

    Each call embeds one token per batch item, advances the recurrent layer
    by one step, computes an attention context from the new hidden state and
    predicts vocabulary logits from the step output joined with the context.
    """

    def __init__(self, output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        """Build the decoder layers.

        Args:
            output_dimensions: Target vocabulary size.
            emb_dimensions: Embedding size.
            hid_dimensions: Hidden size of the recurrent layer.
            num_layers: Number of stacked recurrent layers.
            dropout: Dropout probability for the embeddings and, when there
                is more than one layer, between recurrent layers.
            cell: One of 'rnn', 'gru' or 'lstm' (case-insensitive).
        """
        super().__init__()
        cell_name = cell.lower()
        rnn_types = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}
        # Inter-layer dropout is only passed when there are stacked layers.
        inter_layer_dropout = dropout if num_layers > 1 else 0

        self.embedding = nn.Embedding(output_dimensions, emb_dimensions)
        self.attention = Attention(hid_dimensions)
        self.rnn = rnn_types[cell_name](
            emb_dimensions,
            hid_dimensions,
            num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )

        # Input is the step output concatenated with the attention context.
        self.fc_out = nn.Linear(hid_dimensions * 2, output_dimensions)
        self.cell = cell_name
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode one step.

        Args:
            input: 1-D tensor with one token index per batch item.
            hidden: Recurrent hidden state.
            cell: LSTM cell state; ignored for other cell types.
            encoder_outputs: Encoder outputs to attend over.

        Returns:
            Tuple ``(prediction, hidden, cell, attn_weights)``; ``cell`` is
            ``None`` for non-LSTM cells.
        """
        # Give the token a length-1 time axis before embedding it.
        step_embedding = self.dropout(self.embedding(input.unsqueeze(1)))

        if self.cell == 'lstm':
            output, (hidden, cell) = self.rnn(step_embedding, (hidden, cell))
        else:
            output, hidden = self.rnn(step_embedding, hidden)
            cell = None

        context, attn_weights = self.attention(hidden, encoder_outputs)

        step_features = torch.cat((output.squeeze(1), context), dim=1)
        prediction = self.fc_out(step_features)
        return prediction, hidden, cell, attn_weights


class translit_Encoder(nn.Module):
    """Recurrent encoder that embeds a source sequence and runs it through an RNN."""

    def __init__(self, input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell='lstm'):
        """Build the encoder layers.

        Args:
            input_dimensions: Source vocabulary size.
            emb_dimensions: Embedding size.
            hid_dimensions: Hidden size of the recurrent layer.
            num_layers: Number of stacked recurrent layers.
            dropout: Dropout probability for the embeddings and, when there
                is more than one layer, between recurrent layers.
            cell: One of 'rnn', 'gru' or 'lstm' (case-insensitive).
        """
        super().__init__()

        # Embedding layer to convert input indices into dense vectors
        self.embedding = nn.Embedding(input_dimensions, emb_dimensions)

        # Choose RNN type based on cell argument
        rnn_cls = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}[cell.lower()]

        # RNN layer to process the embedded input sequence
        self.rnn = rnn_cls(
            emb_dimensions, hid_dimensions, num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        # This submodule is not used by forward(). It is kept so the module's
        # parameter names stay the same; a strict load_state_dict() of a
        # checkpoint saved from this architecture expects its keys.
        self.attention = Attention(hid_dimensions)
        self.cell = cell.lower()

        # Dropout layer for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, source):
        """Encode a batch of source index sequences.

        Args:
            source: Tensor of source token indices.

        Returns:
            Tuple ``(outputs, hidden, cell)``: the per-position RNN outputs,
            the final hidden state, and the final cell state (``None`` for
            non-LSTM cells).
        """
        # Convert input token indices into embeddings and apply dropout
        embedded = self.dropout(self.embedding(source))

        # Pass embedded input through RNN
        if self.cell == 'lstm':
            outputs, (hidden, cell) = self.rnn(embedded)
        else:
            outputs, hidden = self.rnn(embedded)
            cell = None

        # The attention context that used to be computed here was never
        # returned or stored, so it only cost time; the decoder computes its
        # own attention at every step.
        return outputs, hidden, cell

class translit_Seq2Seq(nn.Module):
    """Wires an encoder and a step-wise decoder into one trainable model."""

    def __init__(self, encoder, decoder, device):
        """Store the two sub-networks and the device used for output buffers."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Decode ``target.size(1) - 1`` steps, optionally with teacher forcing.

        Args:
            source: Batch of source index sequences (batch first).
            target: Batch of target index sequences; column 0 is used as the
                first decoder input.
            teacher_forcing_ratio: Probability, drawn once per time step for
                the whole batch, of feeding the ground-truth token instead of
                the model's own prediction as the next input.

        Returns:
            Tuple ``(outputs, attn_weights_all)``. Both are indexed
            (batch, target step, ...); step 0 is left as zeros because
            decoding starts at step 1.
        """
        batch_size, target_len = source.size(0), target.size(1)
        vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(batch_size, target_len, vocab_size, device=self.device)
        attn_weights_all = torch.zeros(batch_size, target_len, source.size(1), device=self.device)

        encoder_outputs, hidden, cell = self.encoder(source)

        input = target[:, 0]
        for t in range(1, target_len):
            output, hidden, cell, attn_weights = self.decoder(input, hidden, cell, encoder_outputs)
            outputs[:, t] = output
            attn_weights_all[:, t] = attn_weights

            # One random draw per step decides the next input for the batch.
            if random.random() < teacher_forcing_ratio:
                input = target[:, t]
            else:
                input = output.argmax(1)

        return outputs, attn_weights_all


def strip_after_eos(seq, eos_idx):
    """Return the part of ``seq`` that precedes the first ``eos_idx``.

    Args:
        seq: Token sequence. Torch tensors and objects exposing ``tolist()``
            (such as NumPy arrays) are converted to Python lists first;
            other sequences are used as they are.
        eos_idx: Token value marking the end of the sequence.

    Returns:
        The tokens before the first ``eos_idx``, or the whole (possibly
        converted) sequence when ``eos_idx`` does not occur in it.
    """
    if isinstance(seq, torch.Tensor):
        # detach() so tensors that require grad can be converted as well.
        seq = seq.detach().cpu().tolist()
    elif hasattr(seq, 'tolist'):
        # NumPy arrays have no .index(); work on a plain list instead.
        seq = seq.tolist()

    try:
        return seq[:seq.index(eos_idx)]
    except ValueError:
        # eos_idx is not present: nothing to trim.
        return seq

def calculate_word_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Return the fraction of predictions that match their target exactly.

    Args:
        preds: Sequence of predicted token sequences.
        targets: Sequence of reference token sequences, aligned with ``preds``.
        pad_idx: Token value removed from both sides before comparing.
        eos_idx: When not ``None``, both sides are cut at the first
            occurrence of this token (index 0 is a valid value).

    Returns:
        Number of exact matches divided by ``len(preds)``; 0.0 when
        ``preds`` is empty.
    """
    correct = 0
    for pred, target in zip(preds, targets):
        # Compare against None explicitly: a truthiness test would silently
        # skip trimming when the <eos> index happens to be 0.
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        correct += int(pred == target)
    return correct / max(len(preds), 1)


def calculate_cer(preds, targets, pad_idx=0, eos_idx=None):
    """Return the character error rate of ``preds`` against ``targets``.

    The rate is the summed edit distance between each cleaned prediction and
    its cleaned target, divided by the summed target lengths (an empty target
    counts as length 1).

    Args:
        preds: Sequence of predicted token sequences.
        targets: Sequence of reference token sequences, aligned with ``preds``.
        pad_idx: Token value removed from both sides before comparing.
        eos_idx: When not ``None``, both sides are cut at the first
            occurrence of this token (index 0 is a valid value).

    Returns:
        The error rate as a float, or ``float('inf')`` when there is nothing
        to compare.
    """
    cer = 0
    total = 0
    for pred, target in zip(preds, targets):
        # Compare against None explicitly: a truthiness test would silently
        # skip trimming when the <eos> index happens to be 0.
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)
        pred = [p for p in pred if p != pad_idx]
        target = [t for t in target if t != pad_idx]
        cer += editdistance.eval(pred, target)
        # Count at least one character so empty targets still contribute.
        total += max(len(target), 1)
    return cer / total if total > 0 else float('inf')


def calculate_accuracy(preds, targets, pad_idx=0, eos_idx=None):
    """Return the token-level accuracy of ``preds`` against ``targets``.

    Every non-padding target token counts once. A target position is correct
    only when the prediction has the same token at that position, so target
    tokens beyond the end of a shorter prediction count as errors. Prediction
    tokens beyond the end of the target are ignored.

    Args:
        preds: Sequence of predicted token sequences (lists or tensors).
        targets: Sequence of reference token sequences, aligned with ``preds``.
        pad_idx: Target token value that is skipped.
        eos_idx: When not ``None``, both sides are cut at the first
            occurrence of this token before comparing.

    Returns:
        Correct tokens divided by counted target tokens; 0.0 when no target
        token was counted.
    """
    correct = 0
    total = 0
    for pred, target in zip(preds, targets):
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().tolist()
        if isinstance(target, torch.Tensor):
            target = target.cpu().tolist()
        if eos_idx is not None:
            pred = strip_after_eos(pred, eos_idx)
            target = strip_after_eos(target, eos_idx)

        pred_len = len(pred)
        # Walk the target, not zip(pred, target): zipping would drop target
        # tokens the prediction never reached and inflate the score.
        for pos, t_token in enumerate(target):
            if t_token == pad_idx:
                continue
            total += 1
            if pos < pred_len and pred[pos] == t_token:
                correct += 1
    return correct / total if total > 0 else 0.0

In [None]:
# Authenticate with Weights & Biases before any run is created in later cells.
wandb.login()

In [None]:
# Pillow provides the PIL imports used by the heatmap renderer in a later cell.
!pip install Pillow

# Verify font file
!ls /content

# Font path
# NOTE(review): create_heatmap_image in a later cell references
# '/content/LohitTeluguRegular.ttf', which is not the path checked here.
# Confirm which font file is intended so the check and the renderer agree.
font_path = '/content/NotoSansTelugu-VariableFont.ttf'
import os
if os.path.exists(font_path):
    print(f"Font file found at {font_path}")
else:
    # Ask for an upload through the Colab file picker when the font is missing.
    print(f"Font file not found at {font_path}. Please upload it.")
    from google.colab import files
    uploaded = files.upload()

In [None]:
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import wandb
import os

def _draw_rotated_text(image, center, text, font, fill='black'):
    """Paste ``text`` rotated by 90 degrees, centred on ``center``, onto ``image``."""
    left, top, right, bottom = ImageDraw.Draw(image).textbbox((0, 0), text, font=font)
    pad = 2
    layer = Image.new('RGBA', (right - left + 2 * pad, bottom - top + 2 * pad), (255, 255, 255, 0))
    ImageDraw.Draw(layer).text((pad - left, pad - top), text, font=font, fill=fill)
    rotated = layer.rotate(90, expand=True)
    origin = (center[0] - rotated.width // 2, center[1] - rotated.height // 2)
    # The rotated layer doubles as its own alpha mask, so only glyphs are pasted.
    image.paste(rotated, origin, rotated)


def create_heatmap_image(src_tokens, pred_tokens, attn_weights, idx, idx2char_src, idx2char_tgt,
                         telugu_font_path=None):
    """Render one attention heatmap with PIL and wrap it as a ``wandb.Image``.

    Rows are predicted characters and columns are source characters. Special
    tokens ('<pad>', '<sos>', '<eos>') are left out, and the attention matrix
    is indexed by the positions of the kept tokens so that every cell lines
    up with its labels. Predicted tokens after the first '<eos>' are ignored.

    Args:
        src_tokens: Source token indices of one example.
        pred_tokens: Predicted token indices of the same example.
        attn_weights: 2-D array indexed (prediction step, source position).
        idx: Zero-based example number, used in the title, the debug output
            and the file name of the saved sample.
        idx2char_src: Index-to-character map of the source vocabulary.
        idx2char_tgt: Index-to-character map of the target vocabulary.
        telugu_font_path: Optional path of a TrueType font with Telugu
            glyphs. When omitted, the known locations under /content are
            tried in order.

    Returns:
        A ``wandb.Image`` holding the rendered heatmap.

    Raises:
        FileNotFoundError: If no Telugu font file can be found.
        RuntimeError: If the font file exists but cannot be loaded.
    """
    # Layout constants, in pixels
    cell_size = 50
    label_width = 150
    margin = 50
    title_height = 50
    xlabel_height = 50
    ylabel_width = 100

    # Special-token indices are taken from the maps that were passed in, so
    # the function does not rely on notebook globals.
    special_names = ('<pad>', '<sos>', '<eos>')
    src_special = {i for i, ch in idx2char_src.items() if ch in special_names}
    tgt_special = {i for i, ch in idx2char_tgt.items() if ch in special_names}
    tgt_eos = next((i for i, ch in idx2char_tgt.items() if ch == '<eos>'), None)

    src_ids = [int(tok) for tok in src_tokens]
    pred_ids = [int(tok) for tok in pred_tokens]

    # Remember the POSITION of every kept token: the source starts with
    # <sos>, so labels do not correspond to the first columns of the matrix.
    src_positions = [j for j, tok in enumerate(src_ids)
                     if tok in idx2char_src and tok not in src_special]
    pred_positions = []
    for i, tok in enumerate(pred_ids):
        if tgt_eos is not None and tok == tgt_eos:
            break  # anything decoded after <eos> is not part of the word
        if tok in idx2char_tgt and tok not in tgt_special:
            pred_positions.append(i)

    attn = np.asarray(attn_weights)
    print(f"Example {idx+1} - Source tokens: {src_tokens}")
    print(f"Example {idx+1} - Pred tokens: {pred_tokens}")
    print(f"Example {idx+1} - Attention weights shape: {attn.shape}")

    # Drop positions the attention matrix does not cover, then select the
    # rows and columns that belong to the kept tokens.
    pred_positions = [i for i in pred_positions if i < attn.shape[0]]
    src_positions = [j for j in src_positions if j < attn.shape[1]]
    attn = attn[np.ix_(pred_positions, src_positions)]

    src_labels = [idx2char_src[src_ids[j]] for j in src_positions]
    pred_labels = [idx2char_tgt[pred_ids[i]] for i in pred_positions]
    print(f"Example {idx+1} - Source labels: {src_labels}")
    print(f"Example {idx+1} - Pred labels: {pred_labels}")
    print(f"Example {idx+1} - Pred labels (Unicode): {[f'U+{ord(c):04X}' for c in pred_labels if len(c) == 1 and c != '?']}")

    # Replace anything outside the Telugu Unicode block with '?' so the
    # Telugu font is never asked for glyphs of another script.
    pred_labels = [
        char if len(char) == 1 and 0x0C00 <= ord(char) <= 0x0C7F else '?'
        for char in pred_labels
    ]
    print(f"Example {idx+1} - Filtered pred_labels: {pred_labels}")
    print(f"Example {idx+1} - Truncated attention weights shape: {attn.shape}")

    # Calculate image size
    heatmap_width = len(src_labels) * cell_size
    heatmap_height = len(pred_labels) * cell_size
    img_width = heatmap_width + label_width + ylabel_width + 2 * margin
    img_height = heatmap_height + label_width + title_height + xlabel_height + 2 * margin

    image = Image.new('RGB', (img_width, img_height), 'white')
    draw = ImageDraw.Draw(image)

    # Load the Telugu font: an explicit path wins, otherwise try the known
    # locations in order.
    if telugu_font_path:
        font_candidates = [telugu_font_path]
    else:
        font_candidates = [
            '/content/LohitTeluguRegular.ttf',
            '/content/NotoSansTelugu-VariableFont.ttf',
        ]
    font_path = next((path for path in font_candidates if os.path.exists(path)), None)
    if font_path is None:
        raise FileNotFoundError(f"Telugu font not found at {' or '.join(font_candidates)}.")
    try:
        telugu_font = ImageFont.truetype(font_path, size=20)
    except OSError as e:
        raise RuntimeError(f"Failed to load font {font_path}: {e}") from e

    # Font for Latin text: FreeSans when installed, PIL's default otherwise.
    try:
        latin_font = ImageFont.truetype('/usr/share/fonts/truetype/freefont/FreeSans.ttf', size=20)
    except OSError:
        latin_font = ImageFont.load_default()

    # Draw title
    title = f'Example {idx+1}'
    draw.text((margin + ylabel_width, margin), title, font=latin_font, fill='black')

    # Draw x-axis labels (Source Tokens - Latin)
    for i, label in enumerate(src_labels):
        x = margin + ylabel_width + i * cell_size + cell_size // 2
        y = margin + title_height + heatmap_height + 10
        draw.text((x, y), label, font=latin_font, fill='black', anchor='mm')

    # Draw y-axis labels (Target Tokens - Telugu)
    for i, label in enumerate(pred_labels):
        x = margin + ylabel_width - 10
        y = margin + title_height + i * cell_size + cell_size // 2
        draw.text((x, y), label, font=telugu_font, fill='black', anchor='rm')

    # Draw x-axis title
    draw.text((margin + ylabel_width + heatmap_width // 2, margin + title_height + heatmap_height + xlabel_height - 10),
              'Source Tokens (Latin)', font=latin_font, fill='black', anchor='mm')

    # Draw y-axis title. ImageDraw.text has no rotation option, so the text
    # is rendered on its own layer, rotated and pasted.
    _draw_rotated_text(
        image,
        (margin + ylabel_width // 2, margin + title_height + heatmap_height // 2),
        'Target Tokens (Telugu)', telugu_font, fill='black',
    )

    # Draw heatmap: green for weight 0 shading to red for weight 1.
    for row, attn_row in enumerate(attn):
        for col, weight in enumerate(attn_row):
            level = min(max(float(weight), 0.0), 1.0)
            color = (int(255 * level), int(255 * (1 - level)), 0)
            x0 = margin + ylabel_width + col * cell_size
            y0 = margin + title_height + row * cell_size
            draw.rectangle([x0, y0, x0 + cell_size, y0 + cell_size], fill=color)

    # Save the first image to disk so it can be inspected outside W&B.
    if idx == 0:
        image.save(f"/content/heatmap_example_{idx+1}.png")
        print(f"Saved sample heatmap to /content/heatmap_example_{idx+1}.png")

    # Convert to WandB image
    wandb_image = wandb.Image(image, caption=f"Attention Heatmap Example {idx+1}")
    return wandb_image

In [None]:
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import wandb
import os

# Start a W&B run for this evaluation and download the model artifact.
# NOTE(review): the run is created in project "dakshina-seq2seq" while the
# artifact is fetched from "dakshina-seq2seq-3" — confirm this is intended.
run = wandb.init(project="dakshina-seq2seq", entity="sai-sakunthala-indian-institute-of-technology-madras", name="evaluate_test")
artifact = run.use_artifact('sai-sakunthala-indian-institute-of-technology-madras/dakshina-seq2seq-3/best_model:v52', type='model')
artifact_dir = artifact.download()

# Read data and create vocabularies
# Both vocabularies are built from the TRAINING pairs only, so the indices
# depend on the training file's contents, order and max_len filter.
test_pairs = read_data(data_path + f"{LANG}.translit.sampled.test.tsv", max_len=30)
train_pairs = read_data(data_path + f"{LANG}.translit.sampled.train.tsv", max_len=30)
source_vocab, idx2char_src = make_vocab([x[0] for x in train_pairs])
target_vocab, idx2char_tgt = make_vocab([x[1] for x in train_pairs])

# Model parameters (must match training)
input_dimensions = len(source_vocab)
output_dimensions = len(target_vocab)
emb_dimensions = 128
hid_dimensions = 128 * 2
num_layers = 2
dropout = 0.2
cell = 'lstm'
batch_size = 64
max_len = 30

# Device
# Rebinds the notebook-level `device` (previously a string) to a torch.device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model
encoder = translit_Encoder(input_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell).to(device)
decoder = translit_Decoder(output_dimensions, emb_dimensions, hid_dimensions, num_layers, dropout, cell).to(device)
model = translit_Seq2Seq(encoder, decoder, device).to(device)

# Load model weights
# NOTE(review): torch.load unpickles the checkpoint file; only load
# artifacts from a trusted source.
state_dict = torch.load(f"{artifact_dir}/best_model.pt", map_location=device)
model.load_state_dict(state_dict)
# Switch to evaluation mode so the dropout layers are disabled.
model.eval()

# Create test dataset and loader
test_translit = TransliterationDataset(test_pairs, source_vocab, target_vocab)
# drop_last must be False for evaluation: with True, the final partial batch
# is discarded and those examples are missing from the accuracy, the W&B
# table and the predictions file.
test_loader = DataLoader(test_translit, batch_size=batch_size, shuffle=False, drop_last=False)

# Per-example results collected by the evaluation loop
all_src, all_preds, all_tgts, all_attn_weights = [], [], [], []
correct = 0  # predictions that match the reference exactly
total = 0    # examples evaluated
selected_examples = []  # first few (source, prediction, target, attention) tuples for heatmaps

def predict(model, src, max_len=30):
    """Greedily decode a batch and collect the attention weights of each step.

    Decoding starts from '<sos>' and feeds the arg-max token back in at every
    step. It stops after ``max_len`` steps, or earlier once every sequence in
    the batch emitted '<eos>' at the same step.

    Args:
        model: Object exposing ``encoder`` and ``decoder`` callables.
        src: Batch of source index sequences.
        max_len: Maximum number of decoding steps.

    Returns:
        Tuple ``(outputs, attn_weights_all)``: the predicted token indices
        stacked as (batch, steps) and the attention weights stacked as
        (batch, steps, ...).
    """
    encoder_outputs, hidden, cell_state = model.encoder(src)

    eos_id = target_vocab.get('<eos>', -1)
    input = torch.tensor([target_vocab['<sos>']] * src.size(0)).to(device)

    step_tokens = []
    step_attention = []
    for _ in range(max_len):
        logits, hidden, cell_state, attn_weights = model.decoder(
            input, hidden, cell_state, encoder_outputs
        )
        input = logits.argmax(1)
        step_tokens.append(input)
        step_attention.append(attn_weights)

        # Nothing left to decode once the whole batch produced <eos>.
        if (input == eos_id).all():
            break

    outputs = torch.stack(step_tokens, dim=1)
    attn_weights_all = torch.stack(step_attention, dim=1)
    return outputs, attn_weights_all

def log_attention_heatmaps_individually(sources, predictions, attn_weights_list, idx2char_src, idx2char_tgt, num_plots=10):
    """Render and log one attention heatmap per example to W&B.

    At most ``num_plots`` examples are processed, or fewer when ``sources``
    is shorter. Each heatmap is logged under its own key,
    "Attention Heatmap <n>", with n starting at 1.
    """
    plot_count = min(num_plots, len(sources))
    for example_idx in range(plot_count):
        rendered = create_heatmap_image(
            sources[example_idx],
            predictions[example_idx],
            attn_weights_list[example_idx],
            example_idx,
            idx2char_src,
            idx2char_tgt,
        )
        log_key = f"Attention Heatmap {example_idx+1}"
        wandb.log({log_key: wandb.Image(rendered)})

# Evaluate the model on the test loader without tracking gradients. Every
# example's tokens and attention weights are collected, and exact-match
# correctness is counted in `correct` / `total`.
with torch.no_grad():
    for src, tgt in tqdm(test_loader):
        src, tgt = src.to(device), tgt.to(device)
        preds, attn_weights_batch = predict(model, src)

        # Convert to numpy arrays for processing
        src_np = src.cpu().numpy()
        preds_np = preds.cpu().numpy()
        tgt_np = tgt.cpu().numpy()
        attn_weights_np = attn_weights_batch.cpu().numpy()  # (batch_size, decoded steps, src_len)

        for i in range(len(src_np)):
            # Get source, prediction, target, and attention weights
            s = src_np[i]
            p = preds_np[i]
            t = tgt_np[i]
            attn = attn_weights_np[i]

            # Store sequences and attention weights for all examples
            all_src.append(s)
            all_preds.append(p)
            all_tgts.append(t)
            all_attn_weights.append(attn)

            # Collect up to 12 examples for heatmaps
            if len(selected_examples) < 12:
                selected_examples.append((s, p, t, attn))

            # Process prediction: stop at the first EOS and skip padding / SOS
            p_processed = []
            for token in p:
                if token == target_vocab.get('<eos>', -1):
                    break
                if token not in [target_vocab.get('<pad>', -1), target_vocab.get('<sos>', -1)]:
                    p_processed.append(token)

            # Process target: stop at the first EOS and skip padding / SOS
            t_processed = []
            for token in t:
                if token == target_vocab.get('<eos>', -1):
                    break
                if token not in [target_vocab.get('<pad>', -1), target_vocab.get('<sos>', -1)]:
                    t_processed.append(token)

            # Compare the processed sequences (exact word match)
            if p_processed == t_processed:
                correct += 1
            total += 1

# Plot and log attention heatmaps for up to 12 examples. Any non-empty
# selection is logged; requiring exactly 12 would skip the heatmaps entirely
# when the test set is smaller than that.
if selected_examples:
    sources, predictions, _, attn_weights_list = zip(*selected_examples[:12])
    log_attention_heatmaps_individually(sources, predictions, attn_weights_list, idx2char_src, idx2char_tgt, num_plots=12)

# Calculate accuracy (exact word matches over evaluated examples)
accuracy = correct / total if total > 0 else 0
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Correct: {correct}, Total: {total}")
wandb.log({"Test Accuracy": accuracy})


def log_table_wandb(sources, preds, targets, idx2char_src, idx2char_tgt, num_samples=10, min_correct=7):
    """Log a W&B table with randomly chosen test predictions.

    Each row shows the source word, the predicted word, the reference word
    and whether the two match. Words are decoded up to the first '<eos>',
    with '<pad>' and '<sos>' skipped, so tokens emitted after the end of a
    prediction do not turn a correct word into an "Incorrect" row.

    Args:
        sources: Source token sequences.
        preds: Predicted token sequences, aligned with ``sources``.
        targets: Reference token sequences, aligned with ``sources``.
        idx2char_src: Index-to-character map of the source vocabulary.
        idx2char_tgt: Index-to-character map of the target vocabulary.
        num_samples: Number of rows to draw (capped at the number of examples).
        min_correct: Accepted for backward compatibility; currently unused.
    """
    def _decode(tokens, idx2char):
        # Turn token indices into a string, stopping at the first <eos>.
        chars = []
        for token in tokens:
            ch = idx2char.get(int(token), '?')
            if ch == '<eos>':
                break
            if ch in ('<pad>', '<sos>'):
                continue
            chars.append(ch)
        return ''.join(chars)

    # Create W&B table with relevant column headers
    table = wandb.Table(columns=["Source", "Prediction", "Reference", "Status"])

    total_samples = len(sources)
    indices = random.sample(range(total_samples), min(num_samples, total_samples))

    for i in indices:
        src_word = _decode(sources[i], idx2char_src)
        pred_word = _decode(preds[i], idx2char_tgt)
        ref_word = _decode(targets[i], idx2char_tgt)

        # Determine if prediction matches the reference
        is_correct = pred_word == ref_word
        status = "🟩 **Correct**" if is_correct else "🟥 **Incorrect**"

        table.add_data(src_word, pred_word, ref_word, status)

    # Log the table to Weights & Biases
    wandb.log({"Test Sample Predictions (Random)": table})
    print(f"Logged {len(indices)} random predictions to W&B")

# Log a table of random sample predictions
log_table_wandb(all_src, all_preds, all_tgts, idx2char_src, idx2char_tgt, num_samples=10, min_correct=7)


def _tokens_to_word(tokens, idx2char):
    """Decode token indices to a string, stopping at '<eos>' and skipping '<pad>'/'<sos>'."""
    chars = []
    for token in tokens:
        ch = idx2char.get(int(token), '?')
        if ch == '<eos>':
            # Tokens after the end marker are not part of the word.
            break
        if ch in ('<pad>', '<sos>'):
            continue
        chars.append(ch)
    return ''.join(chars)


# Save predictions to file as tab-separated source / prediction / reference
output_dir = "predictions_attention"
os.makedirs(output_dir, exist_ok=True)

with open(os.path.join(output_dir, "test_predictions.txt"), "w", encoding="utf-8") as f:
    for s, p, t in zip(all_src, all_preds, all_tgts):
        src_word = _tokens_to_word(s, idx2char_src)
        pred_word = _tokens_to_word(p, idx2char_tgt)
        ref_word = _tokens_to_word(t, idx2char_tgt)
        f.write(f"{src_word}\t{pred_word}\t{ref_word}\n")

print(f"Saved full predictions to: {output_dir}/test_predictions.txt")
# Upload the file with the run, then close the run.
wandb.save(os.path.join(output_dir, "test_predictions.txt"))
wandb.finish()