In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import wandb
from torch.utils.data import DataLoader, Dataset


In [1]:
#This cell contains necessary code for dataset preprocessing and at the I print few examples for looking how the dataset looks like
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

trainPth = "/mnt/e_disk/DA6401_Assignment3/dataset/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.train.tsv"
devPth   = "/mnt/e_disk/DA6401_Assignment3/dataset/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.dev.tsv"
testPth = "/mnt/e_disk/DA6401_Assignment3/dataset/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.test.tsv"
def get_vocab(paths):
    chars = set()
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                native, roman, _ = line.strip().split("\t")
                chars.update(native)
                chars.update(roman)
    return chars

def get_char2idx(char_set):
    chars = ["<pad>", "<sos>", "<eos>", "<unk>"] + sorted(char_set)
    return {ch: i for i, ch in enumerate(chars)}, chars



char_set = get_vocab([trainPth, devPth])
roman2idx, idx2roman = get_char2idx(set(c for c in char_set if c.isascii()))
dev2idx, idx2dev = get_char2idx(set(c for c in char_set if not c.isascii()))

class TranslitDataset(Dataset):
    def __init__(self, path, src_c2i, tgt_c2i, max_len=32):
        self.data = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                native, roman, _ = line.strip().split("\t")
                self.data.append((roman, native))
        self.src_c2i = src_c2i
        self.tgt_c2i = tgt_c2i
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        roman, native = self.data[i]
        src = [self.src_c2i.get(c, self.src_c2i["<unk>"]) for c in roman[:self.max_len]]
        tgt = [self.tgt_c2i["<sos>"]] + \
              [self.tgt_c2i.get(c, self.tgt_c2i["<unk>"]) for c in native[:self.max_len - 1]] + \
              [self.tgt_c2i["<eos>"]]
        return torch.tensor(src), torch.tensor(tgt)

def pad_batch(batch):
    src, tgt = zip(*batch)
    src = pad_sequence(src, batch_first=True, padding_value=roman2idx["<pad>"])
    tgt = pad_sequence(tgt, batch_first=True, padding_value=dev2idx["<pad>"])
    return src, tgt

train_ds = TranslitDataset(trainPth, roman2idx, dev2idx, max_len=32)
dev_ds   = TranslitDataset(devPth, roman2idx, dev2idx, max_len=32)
test_ds   = TranslitDataset(testPth, roman2idx, dev2idx, max_len=32)

print("Train set")
for i in range(5):
    src, tgt = train_ds[i]
    roman = ''.join([idx2roman[idx] for idx in src])
    native = ''.join([idx2dev[idx] for idx in tgt[1:-1]])  # skip <sos> and <eos>
    print(f"{i+1}. Roman: {roman:20s}  →  Native: {native}")
print("Dev set")
for i in range(5):
    src, tgt = dev_ds[i]
    roman = ''.join([idx2roman[idx] for idx in src])
    native = ''.join([idx2dev[idx] for idx in tgt[1:-1]])
    print(f"{i+1}. Roman: {roman:20s}  →  Native: {native}")


Train set
1. Roman: fiat                  →  Native: ஃபியட்
2. Roman: phiyat                →  Native: ஃபியட்
3. Roman: piyat                 →  Native: ஃபியட்
4. Roman: firaans               →  Native: ஃபிரான்ஸ்
5. Roman: france                →  Native: ஃபிரான்ஸ்
Dev set
1. Roman: fire                  →  Native: ஃபயர்
2. Roman: phayar                →  Native: ஃபயர்
3. Roman: baar                  →  Native: ஃபார்
4. Roman: bar                   →  Native: ஃபார்
5. Roman: far                   →  Native: ஃபார்


In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import wandb
from torch.utils.data import DataLoader, Dataset

# Attention Module
class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.size(1)
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)
        return torch.softmax(attention, dim=1)

# Decoder
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hidden_dim, num_layers, cell_type, dropout):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.embedding_dropout = nn.Dropout(dropout)
        self.rnn = getattr(nn, cell_type.upper())(emb_dim + hidden_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.fc_out = nn.Linear(hidden_dim * 2, output_dim)
        self.attention = Attention(hidden_dim)
        self.output_dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        if input.dim() == 1:
            input = input.unsqueeze(1)
        embedded = self.embedding_dropout(self.embedding(input))

        if isinstance(hidden, tuple):
            last_hidden = hidden[0][-1]
        else:
            last_hidden = hidden[-1]

        attn_weights = self.attention(last_hidden, encoder_outputs)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)

        rnn_input = torch.cat((embedded, context), dim=2)
        output, hidden = self.rnn(rnn_input, hidden)
        output = self.output_dropout(output)

        prediction = self.fc_out(torch.cat((output.squeeze(1), context.squeeze(1)), dim=1))
        return prediction, hidden, attn_weights

# Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, num_layers, cell_type, dropout):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.embedding_dropout = nn.Dropout(dropout)
        self.rnn = getattr(nn, cell_type.upper())(emb_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)

    def forward(self, src):
        embedded = self.embedding_dropout(self.embedding(src))
        outputs, hidden = self.rnn(embedded)
        return outputs, hidden

# Seq2Seq Wrapper
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size, trg_len = trg.shape
        output_dim = self.decoder.fc_out.out_features
        outputs = torch.zeros(batch_size, trg_len, output_dim).to(self.device)

        encoder_outputs, hidden = self.encoder(src)
        input = trg[:, 0]

        for t in range(1, trg_len):
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
            outputs[:, t] = output
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            input = trg[:, t] if teacher_force else output.argmax(1)

        return outputs

def trim_eos(seq, eos_idx):
    seq = seq.tolist()
    if eos_idx in seq:
        return seq[:seq.index(eos_idx)]
    return seq

# Modified evaluate_with_attention
def evaluate_with_attention(model, input_seq, idx2roman, idx2dev, eos_idx, device):
    model.eval()
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(input_seq.unsqueeze(0))  # [1, src_len]
        input_token = input_seq[0].view(1)
        decoded = []
        attentions = []

        for _ in range(32):  # max_len
            output, hidden, attn = model.decoder(input_token, hidden, encoder_outputs)
            top1 = output.argmax(1)
            if top1.item() == eos_idx:
                break
            decoded.append(top1.item())
            attentions.append(attn.squeeze(0).cpu())
            input_token = top1.unsqueeze(0)

    input_tokens = [idx2roman[idx.item()] for idx in input_seq]
    output_tokens = [idx2dev[idx] for idx in decoded]

    if attentions:
        attention_tensor = torch.stack(attentions)  # [output_len, input_len]
    else:
        attention_tensor = torch.empty(0)

    return output_tokens, input_tokens, attention_tensor


def calc_word_accuracy(model, loader, eos_idx, device):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for src, tgt in loader:
            src, tgt = src.to(device), tgt.to(device)
            for i in range(src.size(0)):
                pred_seq, _, _ = evaluate_with_attention(model, src[i], idx2roman, idx2dev, eos_idx, device)
                true_seq = trim_eos(tgt[i][1:], eos_idx)
                pred_idx = [dev2idx[c] for c in pred_seq if c in dev2idx]
                if pred_idx == true_seq:
                    correct += 1
                total += 1
    return correct / total if total > 0 else 0

def train_model(config=None):
    with wandb.init(config=config):
        config = wandb.config
        wandb.run.name = f"cell_{config.cell_type}/hid_{config.hidden_dim}/emb_{config.emb_dim}/lay_{config.num_layers}/lr_{config.lr}"

        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True, collate_fn=pad_batch)
        dev_loader   = DataLoader(dev_ds, batch_size=config.batch_size, shuffle=False, collate_fn=pad_batch)

        vocab_size_input = len(roman2idx)
        vocab_size_output = len(dev2idx)

        encoder = Encoder(vocab_size_input, config.emb_dim, config.hidden_dim, config.num_layers, config.cell_type, config.dropout)
        decoder = Decoder(vocab_size_output, config.emb_dim, config.hidden_dim, config.num_layers, config.cell_type, config.dropout)
        model = Seq2Seq(encoder, decoder, device).to(device)

        optimizer = opt.Adam(model.parameters(), lr=config.lr)
        criterion = nn.CrossEntropyLoss(ignore_index=dev2idx["<pad>"])

        best_val_loss = float('inf')
        save_path = os.path.join(wandb.run.dir, 'best_model.pth')
        eos_idx = dev2idx["<eos>"]

        for epoch in range(config.epochs):
            model.train()
            train_loss = 0
            token_correct = 0
            token_total = 0

            for src, tgt in train_loader:
                src, tgt = src.to(device), tgt.to(device)
                optimizer.zero_grad()
                output = model(src, tgt, teacher_forcing_ratio=0.5)

                output_flat = output[:, 1:].reshape(-1, vocab_size_output)
                tgt_flat = tgt[:, 1:].reshape(-1)
                loss = criterion(output_flat, tgt_flat)
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * src.size(0)

                preds = output_flat.argmax(1)
                mask = tgt_flat != dev2idx["<pad>"]
                token_correct += ((preds == tgt_flat) & mask).sum().item()
                token_total += mask.sum().item()

            train_loss /= len(train_ds)
            train_token_acc = token_correct / token_total if token_total > 0 else 0

            val_loss = 0
            val_token_correct = 0
            val_token_total = 0

            model.eval()
            with torch.no_grad():
                for src, tgt in dev_loader:
                    src, tgt = src.to(device), tgt.to(device)
                    output = model(src, tgt, teacher_forcing_ratio=0.0)

                    output_flat = output[:, 1:].reshape(-1, vocab_size_output)
                    tgt_flat = tgt[:, 1:].reshape(-1)
                    loss = criterion(output_flat, tgt_flat)
                    val_loss += loss.item() * src.size(0)

                    preds = output_flat.argmax(1)
                    mask = tgt_flat != dev2idx["<pad>"]
                    val_token_correct += ((preds == tgt_flat) & mask).sum().item()
                    val_token_total += mask.sum().item()

            val_loss /= len(dev_ds)
            val_token_acc = val_token_correct / val_token_total if val_token_total > 0 else 0
            val_word_acc = calc_word_accuracy(model, dev_loader, eos_idx, device)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), save_path)
                artifact = wandb.Artifact('best-model', type='model')
                artifact.add_file(save_path)
                wandb.log_artifact(artifact)

            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_token_accuracy': train_token_acc,
                'val_token_accuracy': val_token_acc,
                'val_word_accuracy': val_word_acc
            })

# Sweep config
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_loss', 'goal': 'minimize'},
    'parameters': {
        'epochs': {'values': [10, 15]},
        'emb_dim': {'values': [64, 128]},
        'hidden_dim': {'values': [128, 256]},
        'num_layers': {'values': [1, 2]},
        'cell_type': {'values': ['RNN', 'GRU', 'LSTM']},
        'lr': {'values': [1e-3, 1e-4]},
        'batch_size': {'values': [32, 64]},
        'dropout': {'values': [0.2, 0.3]},
        'beam_size': {'values': [1, 3, 5]}
    }
}


In [3]:

# wandb sweep
sweep_id = wandb.sweep(sweep_config, project='Assignment3_attention')
wandb.agent(sweep_id, function=train_model, count=5)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Create sweep with ID: ex5t9y6l
Sweep URL: https://wandb.ai/navaneeth001/Assignment3_attention/sweeps/ex5t9y6l


[34m[1mwandb[0m: Agent Starting Run: 0ouaa0ll with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell_type: GRU
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	emb_dim: 64
[34m[1mwandb[0m: 	epochs: 15
[34m[1mwandb[0m: 	hidden_dim: 128
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	num_layers: 2
[34m[1mwandb[0m: Currently logged in as: [33mnavaneeth001[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train_loss,█▄▃▂▂▂▂▁▁▁▁▁▁▁▁
train_token_accuracy,▁▅▆▇▇▇█████████
val_loss,█▄▃▂▂▂▂▂▁▁▁▁▁▁▁
val_token_accuracy,▁▅▇▇▇▇█████████
val_word_accuracy,▁▄▅▆▆▇▇▇▇██████

0,1
epoch,15.0
train_loss,0.44512
train_token_accuracy,0.88396
val_loss,0.69121
val_token_accuracy,0.84968
val_word_accuracy,0.50696


[34m[1mwandb[0m: Agent Starting Run: cwozmg1s with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	beam_size: 3
[34m[1mwandb[0m: 	cell_type: LSTM
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	emb_dim: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	hidden_dim: 128
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	num_layers: 1


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▆▄▃▂▂▁▁▁▁
train_token_accuracy,▁▃▅▆▇▇████
val_loss,█▆▄▃▂▂▁▁▁▁
val_token_accuracy,▁▃▅▆▇▇████
val_word_accuracy,▁▁▃▅▅▆▇▇██

0,1
epoch,10.0
train_loss,0.60541
train_token_accuracy,0.84388
val_loss,0.81119
val_token_accuracy,0.81658
val_word_accuracy,0.40032


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: j5dpo5ji with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	beam_size: 1
[34m[1mwandb[0m: 	cell_type: RNN
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	emb_dim: 128
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	hidden_dim: 128
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	num_layers: 1


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▃▂▂▂▁▁▁▁▁
train_token_accuracy,▁▆▇▇▇█████
val_loss,█▄▄▄▁▁▃▂▃▃
val_token_accuracy,▁▄▆▆██▇▇█▇
val_word_accuracy,▁▄▇▅▆██▆▆▆

0,1
epoch,10.0
train_loss,0.4025
train_token_accuracy,0.89496
val_loss,0.75927
val_token_accuracy,0.8386
val_word_accuracy,0.41277


[34m[1mwandb[0m: Agent Starting Run: 687e9gyx with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	beam_size: 1
[34m[1mwandb[0m: 	cell_type: RNN
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	emb_dim: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	hidden_dim: 128
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	num_layers: 1


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,█▅▃▂▂▁▁▁▁▁
train_token_accuracy,▁▄▆▇▇█████
val_loss,█▅▄▃▂▂▂▁▁▁
val_token_accuracy,▁▄▅▆▇█▇███
val_word_accuracy,▁▂▃▅▆▇▇▇██

0,1
epoch,10.0
train_loss,0.64609
train_token_accuracy,0.83026
val_loss,0.81307
val_token_accuracy,0.81546
val_word_accuracy,0.41643


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: y2slk2ty with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	beam_size: 5
[34m[1mwandb[0m: 	cell_type: GRU
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	emb_dim: 64
[34m[1mwandb[0m: 	epochs: 15
[34m[1mwandb[0m: 	hidden_dim: 256
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	num_layers: 2


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train_loss,█▃▃▂▂▂▂▁▁▁▁▁▁▁▁
train_token_accuracy,▁▆▇▇▇▇▇████████
val_loss,█▄▄▂▂▂▂▁▂▁▁▁▁▁▁
val_token_accuracy,▁▅▅▇▇▇██▇██████
val_word_accuracy,▁▄▄▆▆▇▇▇▇▇█████

0,1
epoch,15.0
train_loss,0.31577
train_token_accuracy,0.9176
val_loss,0.6495
val_token_accuracy,0.86141
val_word_accuracy,0.57068


In [9]:
import os
import csv
import torch
import wandb
from wandb import Api
from torch.utils.data import DataLoader
from tqdm import tqdm

# === Load best model config from sweep ===
ENTITY     = 'navaneeth001'
PROJECT    = 'Assignment3_attention'
SWEEP_ID   = 'ex5t9y6l'
ARTIFACT_REF = 'navaneeth001/Assignment3_attention/best-model:v67'
OUTPUT_DIR = '/mnt/e_disk/DA6401_Assignment3/predictions_attention'
CSV_PATH   = os.path.join(OUTPUT_DIR, 'test_predictions.csv')

os.makedirs(OUTPUT_DIR, exist_ok=True)
pred_rows = [["Input", "Target", "Prediction"]]

api      = Api()
sweep    = api.sweep(f"{ENTITY}/{PROJECT}/{SWEEP_ID}")
runs     = sweep.runs
best_run = max(runs, key=lambda r: r.summary.get('val_word_accuracy', float('inf')))
cfg      = best_run.config

eval_run = wandb.init(
    project=PROJECT,
    entity=ENTITY,
    job_type='evaluation'
)
artifact     = eval_run.use_artifact(ARTIFACT_REF, type='model')
download_dir = artifact.download()

model_path   = '/mnt/e_disk/DA6401_Assignment3/artifacts/best-model:v67/best_model.pth'
print(f"Loaded model artifact to: {model_path}")

# === Load model ===
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def evaluate_with_attention(model, input_seq, idx2roman, idx2dev, eos_idx, device):
    model.eval()
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(input_seq.unsqueeze(0))
        input_token = input_seq[0].view(1)
        decoded = []
        attentions = []

        for _ in range(32):  # max_len
            output, hidden, attn = model.decoder(input_token, hidden, encoder_outputs)
            top1 = output.argmax(1)
            if top1.item() == eos_idx:
                break
            decoded.append(top1.item())
            attentions.append(attn.squeeze(0).cpu())
            input_token = top1.unsqueeze(0)

    input_tokens = [idx2roman[idx.item()] for idx in input_seq]
    output_tokens = [idx2dev[idx] for idx in decoded]

    if attentions:
        attention_tensor = torch.stack(attentions)
    else:
        attention_tensor = torch.empty(0)

    return output_tokens, input_tokens, attention_tensor

def test_model(model_path, cfg, device, pred_rows):
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, collate_fn=pad_batch)

    vocab_size_input = len(roman2idx)
    vocab_size_output = len(dev2idx)

    encoder = Encoder(
        vocab_size_input,
        64,
        256,
        2,
        "GRU",
        0.3
    )
    decoder = Decoder(
        vocab_size_output,
        64,
        256,
        2,
        "GRU",
        0.3
    )

    model = Seq2Seq(encoder, decoder, device).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    eos_idx = dev2idx["<eos>"]
    total_token_correct = 0
    total_token_count = 0
    total_word_correct = 0
    total_word_count = 0

    with torch.no_grad():
        for src, tgt in test_loader:
            src, tgt = src.to(device), tgt.to(device)

            for i in range(src.size(0)):
                input_seq = src[i]
                true_seq = trim_eos(tgt[i][1:], eos_idx)

                pred_tokens, input_tokens, attn_weights = evaluate_with_attention(
                    model, input_seq, idx2roman, idx2dev, eos_idx, device
                )
                pred_idx = [dev2idx[token] for token in pred_tokens if token in dev2idx]

                # Token Accuracy
                match_length = min(len(pred_idx), len(true_seq))
                for j in range(match_length):
                    if pred_idx[j] == true_seq[j]:
                        total_token_correct += 1
                total_token_count += len(true_seq)

                # Word Accuracy
                if pred_idx == true_seq:
                    total_word_correct += 1
                total_word_count += 1

                # Save row to CSV
                pred_rows.append([
                    "".join(input_tokens),
                    "".join([idx2dev[int(i)] for i in true_seq]),
                    "".join(pred_tokens)
                ])

    token_accuracy = total_token_correct / total_token_count if total_token_count > 0 else 0
    word_accuracy = total_word_correct / total_word_count if total_word_count > 0 else 0

    print(f"\nTest Token Accuracy: {token_accuracy:.4f}")
    print(f"Test Word Accuracy:  {word_accuracy:.4f}")

    wandb.log({
        'test_token_accuracy': token_accuracy,
        'test_word_accuracy': word_accuracy
    })

# === Run test and collect predictions ===
test_model(model_path, cfg, device, pred_rows)

# === Save predictions to CSV ===
with open(CSV_PATH, 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerows(pred_rows)

print(f"Saved predictions to {CSV_PATH}")

eval_run.finish()


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Loaded model artifact to: /mnt/e_disk/DA6401_Assignment3/artifacts/best-model:v67/best_model.pth

Test Token Accuracy: 0.8431
Test Word Accuracy:  0.5479
Saved predictions to /mnt/e_disk/DA6401_Assignment3/predictions_attention/test_predictions.csv


0,1
test_token_accuracy,▁
test_word_accuracy,▁

0,1
test_token_accuracy,0.84314
test_word_accuracy,0.54793


# Question 6 (20 Marks)

This a challenge question and most of you will find it hard. 

I like the visualisation in the figure captioned "Connectivity" in this [article](https://distill.pub/2019/memorization-in-rnns/#appendix-autocomplete). Make a similar visualisation for your model. Please look at this [blog](https://medium.com/data-science/visualising-lstm-activations-in-keras-b50206da96ff) for some starter code. The goal is to figure out the following: When the model is decoding the $i$-th character in the output which is the input character that it is looking at?

Have fun!

In [None]:
import torch
import ipywidgets as widgets
import plotly.graph_objects as go
from IPython.display import display
import plotly.io as pio
import wandb

# Set Plotly renderer suitable for notebooks
pio.renderers.default = "iframe"

# Initialize Weights & Biases (use your project name & reinit=True to allow multiple runs)
wandb.init(project="attention-visualization", name="attention-heatmaps", reinit=True)

# ---------- Attention heatmap plotting ----------
def plot_attention_heatmap_plotly(attn_weights, input_chars, output_chars, log_to_wandb=False, step=None):
    fig = go.Figure(data=go.Heatmap(
        z=attn_weights,
        x=input_chars,
        y=output_chars,
        colorscale='Viridis',
        hoverongaps=False,
        hovertemplate='<b>Output:</b> %{y}<br><b>Input:</b> %{x}<br><b>Attention:</b> %{z:.3f}<extra></extra>',
        colorbar=dict(title='Attention Weight')
    ))

    fig.update_layout(
        title=f'Attention Heatmap (Step {step})' if step is not None else 'Attention Heatmap',
        xaxis_title='Input (Romanized)',
        yaxis_title='Output (Native)',
        font=dict(family="Noto Sans", size=14),
        autosize=True,
        margin=dict(l=50, r=50, t=50, b=50)
    )

    if log_to_wandb:
        wandb.log({f"attention_heatmap_{step or 0}": wandb.Plotly(fig)})

    fig.show()
    return fig


# ---------- Interactive attention visualizer ----------
def visualize_attention(model, dataset, idx2src, idx2tgt, eos_idx, device):
    options = [f"{i}: {''.join([idx2src[idx.item()] for idx in dataset[i][0]])}" for i in range(len(dataset))]
    dropdown = widgets.Dropdown(options=options[:100], description="Example:", layout=widgets.Layout(width='80%'))

    def on_change(change):
        if change['type'] == 'change' and change['name'] == 'value':
            i = int(change['new'].split(":")[0])
            src, tgt = dataset[i]
            src = src.to(device)
            tgt = tgt.to(device)

            pred_chars, input_chars, attn_tensor = evaluate_with_attention(
                model, src, idx2src, idx2tgt, eos_idx, device
            )

            if attn_tensor.nelement() == 0:
                print(f"⚠️ Empty attention weights for example {i}. Possibly predicted <eos> immediately.")
                print(f"Input:  {''.join(input_chars)}")
                print(f"Target: {''.join([idx2tgt[idx.item()] for idx in tgt[1:-1]])}")
                print(f"Pred:   {''.join(pred_chars)}")
                return

            attn_matrix = attn_tensor.cpu().numpy()  # To NumPy for plotly
            plot_attention_heatmap_plotly(attn_matrix, input_chars, pred_chars, log_to_wandb=True, step=i)

            print("\n📝 Sequence Info:")
            print(f"Input:  {''.join(input_chars)}")
            print(f"Target: {''.join([idx2tgt[idx.item()] for idx in tgt[1:-1]])}")
            print(f"Pred:   {''.join(pred_chars)}")

    dropdown.observe(on_change)
    display(dropdown)


# -------------- Usage example -----------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Make sure these are consistent with your training config:
encoder = Encoder(len(roman2idx), 128, 256, 1, "LSTM", 0.2)
decoder = Decoder(len(dev2idx), 128, 256, 1, "LSTM", 0.2)
model = Seq2Seq(encoder, decoder, device).to(device)

# Load your trained weights (replace path with your model's path)
model.load_state_dict(torch.load("path_to_best_model.pth", map_location=device))

eos_idx = dev2idx["<eos>"]

# Call the interactive visualizer on your test dataset
visualize_attention(model, test_ds, idx2roman, idx2dev, eos_idx, device)


Dropdown(description='Example:', layout=Layout(width='80%'), options=('0: tensor([ 9,  4,  4, 21, 16])', '1: t…

Best model configs obtained from sweeps

- batch_size - 32
- beam_size - 5
- cell_type - "LSTM"
- dropout - 0.2
- emb_dim - 128
- epochs - 10
- hidden_dim - 256
- lr - 0.0001
- num_layers - 1

In [10]:
vocab_size_input = len(roman2idx)
vocab_size_output = len(dev2idx)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = Encoder(vocab_size_input, 128, 256, 1, "LSTM", 0.2)
decoder = Decoder(vocab_size_output, 128, 256, 1, "LSTM", 0.2)
model = Seq2Seq(encoder, decoder, device).to(device)


In [11]:

eos_idx = dev2idx["<eos>"]
visualize_attention(model, test_ds, idx2roman, idx2dev, eos_idx, device)

Dropdown(description='Example:', layout=Layout(width='80%'), options=('0: tensor([ 9,  4,  4, 21, 16])', '1: t…