In [None]:
import os, gc, math, time, json, random
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Fix the Python and PyTorch RNG seeds so repeated runs start from the same state.
random.seed(42)
torch.manual_seed(42)

# Use the Apple-silicon (MPS) backend when PyTorch reports it, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("device:", device)

# ETL: Loading the dataset

For this assignment, the language I chose is Nepali. The dataset is taken from Hugging Face: https://huggingface.co/datasets/Helsinki-NLP/opus-100/viewer/en-ne

In [None]:
from datasets import load_dataset

# Language codes for the translation direction (source -> target).
SRC_LANG = "en"
TARG_LANG = "ne"

# The dataset configuration name is built from the two language codes.
# Note: this assignment replaces any existing `dataset` variable in the kernel.
dataset = load_dataset("Helsinki-NLP/opus-100", f"{SRC_LANG}-{TARG_LANG}")

print(dataset)
print("sizes:", {split: len(rows) for split, rows in dataset.items()})

In [None]:
# Dataset sizes (safe slicing): each split is capped at these many examples.
N_TRAIN = 20000
N_VALID = 2000
N_TEST  = 2000


def _take_pairs(split, limit):
    """Return the first `limit` rows of `split` with flat per-language columns.

    OPUS-100 stores each sentence pair under a nested ``translation`` dict
    (``{"translation": {"en": ..., "ne": ...}}``), whereas the rest of this
    notebook indexes rows directly by language code (``row[SRC_LANG]``,
    ``train_ds[SRC_LANG]``). When the nested column is present it is replaced
    by one top-level text column per language; a split that is already flat
    is returned unchanged apart from the slicing.

    Args:
        split: a Hugging Face ``Dataset`` split.
        limit: maximum number of rows to keep (clamped to the split size).

    Returns:
        The sliced ``Dataset`` exposing ``SRC_LANG`` and ``TARG_LANG`` columns.
    """
    subset = split.select(range(min(limit, len(split))))
    if "translation" in subset.column_names:
        # Slice first, then flatten, so only the kept rows are rewritten.
        subset = subset.map(
            lambda ex: {
                SRC_LANG: ex["translation"][SRC_LANG],
                TARG_LANG: ex["translation"][TARG_LANG],
            },
            remove_columns=["translation"],
        )
    return subset


train_ds = _take_pairs(dataset["train"], N_TRAIN)
valid_ds = _take_pairs(dataset["validation"], N_VALID)
test_ds  = _take_pairs(dataset["test"], N_TEST)

print("train/valid/test sizes:", len(train_ds), len(valid_ds), len(test_ds))

# Tokenizing

In [None]:
import spacy
from nepalitokenizers import WordPiece

# Blank English pipeline: tokenizer only, no statistical components loaded.
spacy_en = spacy.blank("en")
# WordPiece tokenizer for Nepali text.
wp_ne = WordPiece()


def tok_en(text: str):
    """Lower-case and strip `text`, then return its spaCy token strings."""
    normalised = text.lower().strip()
    return [token.text for token in spacy_en(normalised)]


def tok_ne(text: str):
    """Return the WordPiece tokens produced for the stripped Nepali `text`."""
    return wp_ne.encode(text.strip()).tokens


# Quick check: show the first few tokens of one sentence per language.
print(tok_en("The clipboard could not be signed.")[:10])
print(tok_ne("क्लिपबोर्ड साइन गर्न सकिएन ।")[:10])

# Text numericalisation and batch collation with padding for DataLoader

In [None]:
from collections import Counter
from torchtext.vocab import Vocab

# Reserved tokens handed to the vocabulary as `specials`.
special_symbols = ["<unk>", "<pad>", "<bos>", "<eos>"]


def build_vocab_from_tokenizer(ds, tokenizer_fn):
    """Build a vocabulary from every token that `tokenizer_fn` yields over `ds`.

    Args:
        ds: iterable of raw sentences.
        tokenizer_fn: callable mapping one sentence to a list of tokens.

    Returns:
        A `Vocab` built from the token counts with `special_symbols` as
        specials and `unk_index` set to the id of "<unk>".
    """
    token_counts = Counter(
        token for sentence in ds for token in tokenizer_fn(sentence)
    )
    vocab = Vocab(token_counts, specials=special_symbols)
    vocab.unk_index = vocab.stoi["<unk>"]
    return vocab

# Vocabularies are fitted on the TRAIN split only (source first, then target).
vocab_transform = {
    SRC_LANG: build_vocab_from_tokenizer(train_ds[SRC_LANG], tok_en),
    TARG_LANG: build_vocab_from_tokenizer(train_ds[TARG_LANG], tok_ne),
}

# Special-token ids as seen by the target vocabulary.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = (
    vocab_transform[TARG_LANG].stoi[symbol]
    for symbol in ("<unk>", "<pad>", "<bos>", "<eos>")
)

# Padding ids for each side; the target side reuses PAD_IDX.
SRC_PAD_IDX = vocab_transform[SRC_LANG].stoi["<pad>"]
TRG_PAD_IDX = PAD_IDX

print("vocab sizes:", len(vocab_transform[SRC_LANG]), len(vocab_transform[TARG_LANG]))
print("UNK, PAD, BOS, EOS:", UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX)

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Longest id sequence (including <bos>/<eos>) and number of pairs per batch.
MAX_LEN = 96
BATCH_SIZE = 4


def numericalize(vocab, tokens):
    """Map each token to its id in `vocab.stoi`, using the "<unk>" id for misses."""
    ids = []
    for token in tokens:
        ids.append(vocab.stoi.get(token, vocab.stoi["<unk>"]))
    return ids

def text_to_ids(text, lang):
    """Tokenize `text`, convert it to ids and wrap it in <bos>/<eos>.

    `lang == SRC_LANG` selects the English tokenizer and source vocabulary;
    any other value selects the Nepali tokenizer and target vocabulary.
    The id sequence is cut to `MAX_LEN - 2` before the two boundary ids are
    added, so the result never exceeds `MAX_LEN` entries.
    """
    if lang == SRC_LANG:
        tokenize = tok_en
        vocab = vocab_transform[SRC_LANG]
    else:
        tokenize = tok_ne
        vocab = vocab_transform[TARG_LANG]

    body = numericalize(vocab, tokenize(text))[: MAX_LEN - 2]

    if lang == SRC_LANG:
        bos_id, eos_id = vocab.stoi["<bos>"], vocab.stoi["<eos>"]
    else:
        bos_id, eos_id = BOS_IDX, EOS_IDX
    return [bos_id] + body + [eos_id]

def collate_batch(batch):
    """Turn a list of sentence-pair rows into padded id tensors.

    Returns:
        (src, src_len, trg) where `src` and `trg` are batch-first long tensors
        padded with `SRC_PAD_IDX` / `TRG_PAD_IDX`, and `src_len` holds the
        unpadded source lengths as int64.
    """
    src_seqs = [
        torch.tensor(text_to_ids(row[SRC_LANG], SRC_LANG), dtype=torch.long)
        for row in batch
    ]
    trg_seqs = [
        torch.tensor(text_to_ids(row[TARG_LANG], TARG_LANG), dtype=torch.long)
        for row in batch
    ]
    src_lengths = torch.tensor([seq.size(0) for seq in src_seqs], dtype=torch.int64)

    src_padded = pad_sequence(src_seqs, batch_first=True, padding_value=SRC_PAD_IDX)
    trg_padded = pad_sequence(trg_seqs, batch_first=True, padding_value=TRG_PAD_IDX)
    return src_padded, src_lengths, trg_padded

# Only the training loader is shuffled; validation and test keep dataset order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

# Number of batches per loader; used later to average the epoch loss.
train_loader_length = len(train_loader)
val_loader_length   = len(valid_loader)
test_loader_length  = len(test_loader)


def _unk_percent(ids, unk_idx, pad_idx):
    """Percentage of non-padding positions in `ids` that equal `unk_idx`.

    Padding positions are excluded from the denominator; otherwise the rate
    would shrink with the amount of padding in the batch rather than reflect
    vocabulary coverage. Returns 0.0 for a batch with no real tokens.
    """
    real = ids != pad_idx
    n_real = int(real.sum().item())
    if n_real == 0:
        return 0.0
    n_unk = int(((ids == unk_idx) & real).sum().item())
    return 100.0 * n_unk / n_real


# Sanity UNK rates on one training batch (reported as true percentages).
src, _, trg = next(iter(train_loader))
print("SRC UNK %:", _unk_percent(src, vocab_transform[SRC_LANG].stoi["<unk>"], SRC_PAD_IDX))
print("TRG UNK %:", _unk_percent(trg, UNK_IDX, TRG_PAD_IDX))
print("batch shapes:", src.shape, trg.shape)

In [None]:
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward block: hid_dim -> pf_dim -> hid_dim.

    A ReLU follows the first projection and dropout is applied to the
    activated hidden values before the second projection.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(in_features=hid_dim, out_features=pf_dim)
        self.fc_2 = nn.Linear(in_features=pf_dim, out_features=hid_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the block to the last dimension of `x` and return the result."""
        hidden = torch.relu(self.fc_1(x))
        hidden = self.dropout(hidden)
        return self.fc_2(hidden)

class MultiHeadAttentionLayer(nn.Module):
    """Multi-head attention with a selectable scoring function.

    Supported `atten_type` values (q and k are per-head query/key vectors):

    * ``"general"``        - scaled dot product:  q . k / sqrt(head_dim)
    * ``"multiplicative"`` - learned bilinear:    (W q) . k / sqrt(head_dim)
    * ``"additive"``       - v^T tanh(Wq q + Wk k)

    The parameters for every scoring mode are always created, so the
    sub-module layout (and therefore the state_dict keys) is the same
    whichever mode is selected.

    Args:
        hid_dim: model width; must be divisible by `n_heads`.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        atten_type: one of the scoring modes listed above.
        device: device on which the scaling constant is created.
    """

    def __init__(self, hid_dim, n_heads, dropout, atten_type, device):
        super().__init__()
        assert hid_dim % n_heads == 0

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.atten_type = atten_type
        self.device = device

        # Input projections for queries, keys, values and the output projection.
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)

        # multiplicative: learned matrix applied to the queries before the dot product
        self.W = nn.Linear(self.head_dim, self.head_dim, bias=False)

        # additive
        self.Wq = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.Wk = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.v  = nn.Linear(self.head_dim, 1, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float, device=device))

    def forward(self, query, key, value, mask=None):
        """Attend from `query` to `key`/`value`.

        Args:
            query, key, value: tensors whose first dimension is the batch and
                whose last dimension is `hid_dim`.
            mask: optional tensor broadcastable to [B, H, Lq, Lk]; positions
                where it equals 0 are excluded from attention.

        Returns:
            (output, attention): output is [B, Lq, hid_dim]; attention holds
            the softmax weights with shape [B, H, Lq, Lk].

        Raises:
            ValueError: if `atten_type` is not a supported scoring mode.
        """
        batch_size = query.shape[0]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # Split the model width into heads: [B, L, hid_dim] -> [B, H, L, head_dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        if self.atten_type == "general":
            energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        elif self.atten_type == "multiplicative":
            # Same as "general" but the queries first pass through the learned W.
            energy = torch.matmul(self.W(Q), K.permute(0, 1, 3, 2)) / self.scale

        elif self.atten_type == "additive":
            # e_ij = v^T tanh(Wq q_i + Wk k_j)
            # The tanh needs a [B,H,Lq,Lk,D] intermediate, so memory grows with
            # Lq * Lk; keep sequence length and batch size small for this mode.
            Qe = self.Wq(Q).unsqueeze(3)   # [B,H,Lq,1,D]
            Ke = self.Wk(K).unsqueeze(2)   # [B,H,1,Lk,D]
            e  = torch.tanh(Qe + Ke)       # [B,H,Lq,Lk,D]
            energy = self.v(e).squeeze(-1) # [B,H,Lq,Lk]

        else:
            raise ValueError(f"Unknown atten_type: {self.atten_type}")

        if mask is not None:
            # A large negative score drives the softmax weight to ~0.
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(self.dropout(attention), V)  # [B,H,Lq,D]

        # Merge the heads back: [B,H,Lq,D] -> [B,Lq,hid_dim]
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)

        return x, attention

# Encoder and Decoder

### Seq2Seq

Our `trg_sub_mask` will look something like this (for a target with 5 tokens):

$$\begin{matrix}
1 & 0 & 0 & 0 & 0\\
1 & 1 & 0 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0\\
1 & 1 & 1 & 1 & 1\\
\end{matrix}$$

The "subsequent" mask is then logically ANDed with the padding mask; this combines the two masks, ensuring that neither the subsequent tokens nor the padding tokens can be attended to. For example, if the last two tokens were `<pad>` tokens, the mask would look like:

$$\begin{matrix}
1 & 0 & 0 & 0 & 0\\
1 & 1 & 0 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
\end{matrix}$$

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is followed by dropout, a residual connection and a
    LayerNorm (normalisation is applied after the residual sum).
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, atten_type, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(
            hid_dim, n_heads, dropout, atten_type, device
        )
        self.feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Return `src` after the attention and feed-forward sub-layers."""
        # Self-attention sub-layer: query, key and value are all `src`.
        attended, _ = self.self_attention(src, src, src, src_mask)
        after_attn = self.self_attn_layer_norm(src + self.dropout(attended))

        # Position-wise feed-forward sub-layer.
        transformed = self.feedforward(after_attn)
        return self.ff_layer_norm(after_attn + self.dropout(transformed))

class Encoder(nn.Module):
    """Stack of `n_layers` EncoderLayer blocks over token + position embeddings.

    Position embeddings are learned, with a table of `max_length` entries, so
    inputs must not be longer than `max_length` tokens.
    """

    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, atten_type, device, max_length=MAX_LEN):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        blocks = [
            EncoderLayer(hid_dim, n_heads, pf_dim, dropout, atten_type, device)
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.dropout = nn.Dropout(dropout)
        # Token embeddings are multiplied by sqrt(hid_dim) before positions are added.
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, src, src_mask):
        """Encode a [batch, src_len] tensor of token ids."""
        batch_size, src_len = src.shape
        # Position ids 0..src_len-1, one row per batch element.
        positions = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)

        token_part = self.tok_embedding(src) * self.scale
        hidden = self.dropout(token_part + self.pos_embedding(positions))

        for block in self.layers:
            hidden = block(hidden, src_mask)

        return hidden

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder attention, feed-forward.

    Each sub-layer is followed by dropout, a residual connection and a LayerNorm.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, atten_type, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)

        self.self_attention = MultiHeadAttentionLayer(
            hid_dim, n_heads, dropout, atten_type, device
        )
        self.encoder_attention = MultiHeadAttentionLayer(
            hid_dim, n_heads, dropout, atten_type, device
        )
        self.feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Return the updated target states and the encoder-attention weights."""
        # Target attends to itself under `trg_mask`.
        self_out, _ = self.self_attention(trg, trg, trg, trg_mask)
        state = self.self_attn_layer_norm(trg + self.dropout(self_out))

        # Target queries attend to the encoder output under `src_mask`.
        cross_out, attention = self.encoder_attention(state, enc_src, enc_src, src_mask)
        state = self.enc_attn_layer_norm(state + self.dropout(cross_out))

        ff_out = self.feedforward(state)
        state = self.ff_layer_norm(state + self.dropout(ff_out))

        return state, attention

class Decoder(nn.Module):
    """Stack of `n_layers` DecoderLayer blocks followed by a vocabulary projection.

    Position embeddings are learned, with a table of `max_length` entries, so
    target prefixes must not be longer than `max_length` tokens.
    """

    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, atten_type, device, max_length=MAX_LEN):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        blocks = [
            DecoderLayer(hid_dim, n_heads, pf_dim, dropout, atten_type, device)
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        # Token embeddings are multiplied by sqrt(hid_dim) before positions are added.
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode a [batch, trg_len] tensor of token ids against `enc_src`.

        Returns:
            (output, attention): per-position vocabulary scores and the
            encoder-attention weights of the last decoder layer.
        """
        batch_size, trg_len = trg.shape
        # Position ids 0..trg_len-1, one row per batch element.
        positions = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)

        token_part = self.tok_embedding(trg) * self.scale
        hidden = self.dropout(token_part + self.pos_embedding(positions))

        for block in self.layers:
            hidden, attention = block(hidden, enc_src, trg_mask, src_mask)

        return self.fc_out(hidden), attention

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder wrapper that builds the attention masks from token ids."""

    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Boolean [B, 1, 1, src_len] mask that is True at non-padding positions."""
        not_pad = src != self.src_pad_idx
        return not_pad[:, None, None, :]

    def make_trg_mask(self, trg):
        """Boolean [B, 1, trg_len, trg_len] mask combining padding and causality.

        Entry (i, j) is True when key position j is not padding and j <= i.
        """
        not_pad = (trg != self.trg_pad_idx)[:, None, None, :]
        seq_len = trg.shape[1]
        # Lower-triangular matrix: position i may look at positions 0..i only.
        causal = torch.ones((seq_len, seq_len), device=self.device).tril().bool()
        return not_pad & causal[None, None, :, :]

    def forward(self, src, trg):
        """Run encoder then decoder; returns the decoder's (output, attention)."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        memory = self.encoder(src, src_mask)
        return self.decoder(trg, memory, trg_mask, src_mask)

# Training

In [None]:
def train(model, loader, optimizer, criterion, clip, loader_length):
    """Run one training epoch and return the loss summed over batches / `loader_length`.

    The decoder is fed the target without its last token and is scored against
    the target without its first token (teacher forcing). Gradients are clipped
    to a maximum norm of `clip` before every optimizer step.
    """
    model.train()
    running_loss = 0

    for src, _src_len, trg in loader:
        src, trg = src.to(device), trg.to(device)
        decoder_input, gold = trg[:, :-1], trg[:, 1:]

        optimizer.zero_grad()

        logits, _ = model(src, decoder_input)

        # Flatten batch and time so the criterion sees one row per target token.
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_gold = gold.reshape(-1)

        loss = criterion(flat_logits, flat_gold)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.item()

    return running_loss / loader_length

def evaluate(model, loader, criterion, loader_length):
    """Return the loss summed over the batches of `loader` / `loader_length`.

    The model is put in eval mode and no gradients are tracked. Inputs and
    targets are shifted exactly as in `train`.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for src, _src_len, trg in loader:
            src, trg = src.to(device), trg.to(device)
            decoder_input, gold = trg[:, :-1], trg[:, 1:]

            logits, _ = model(src, decoder_input)

            # Flatten batch and time so the criterion sees one row per target token.
            flat_logits = logits.reshape(-1, logits.shape[-1])
            flat_gold = gold.reshape(-1)

            running_loss += criterion(flat_logits, flat_gold).item()

    return running_loss / loader_length

In [None]:
# Vocabulary sizes: these fix the embedding-table sizes and the width of the
# decoder's output projection.
input_dim  = len(vocab_transform[SRC_LANG])
output_dim = len(vocab_transform[TARG_LANG])

# Smaller model for stability
# HID_DIM must be divisible by the head counts (asserted in MultiHeadAttentionLayer).
HID_DIM = 128
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
# Inner width of the position-wise feed-forward blocks.
ENC_PF_DIM = 256
DEC_PF_DIM = 256
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

LR = 8e-4  # learning rate passed to Adam
EPOCHS = 5
# Maximum gradient norm passed to clip_grad_norm_ in train().
clip = 1.0

# One model is built and trained per entry; every name must be a scoring mode
# that MultiHeadAttentionLayer.forward accepts.
ATTEN_TYPES = ["general", "multiplicative", "additive"]

def initialize_weights(m):
    """Xavier-uniform initialise `m.weight` when it has more than one dimension.

    Modules without a `weight`, with `weight` set to None, or with a
    one-dimensional weight (e.g. LayerNorm) are left untouched.
    """
    weight = getattr(m, "weight", None)
    if weight is None or weight.dim() <= 1:
        return
    nn.init.xavier_uniform_(weight.data)

# Target padding positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

# Per-attention-type bookkeeping, keyed by the names in ATTEN_TYPES.
models = {}
optimizers = {}
histories = {name: {"train": [], "valid": []} for name in ATTEN_TYPES}
best_valid = dict.fromkeys(ATTEN_TYPES, float("inf"))

# Build every model (and its optimizer) up front, in ATTEN_TYPES order.
for atten_type in ATTEN_TYPES:
    print("\n===== BUILD:", atten_type, "=====")

    enc = Encoder(
        input_dim, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT,
        atten_type, device, max_length=MAX_LEN,
    )
    dec = Decoder(
        output_dim, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT,
        atten_type, device, max_length=MAX_LEN,
    )
    model = Seq2SeqTransformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
    # Re-initialise every weight matrix with Xavier-uniform values.
    model.apply(initialize_weights)

    optimizer = optim.Adam(model.parameters(), lr=LR)

    models[atten_type], optimizers[atten_type] = model, optimizer

    print("model ready")

# Train the models one after another; each keeps its own optimizer and history.
for atten_type in ATTEN_TYPES:
    model, optimizer = models[atten_type], optimizers[atten_type]
    history = histories[atten_type]

    print("\n===== TRAIN:", atten_type, "=====")
    for epoch in range(EPOCHS):
        t0 = time.time()
        tr_loss = train(model, train_loader, optimizer, criterion, clip, train_loader_length)
        va_loss = evaluate(model, valid_loader, criterion, val_loader_length)
        t1 = time.time()

        history["train"].append(tr_loss)
        history["valid"].append(va_loss)

        # Checkpoint only when the validation loss improves on the best so far.
        if va_loss < best_valid[atten_type]:
            best_valid[atten_type] = va_loss
            checkpoint = {
                "atten_type": atten_type,
                "state_dict": model.state_dict(),
                "input_dim": input_dim,
                "output_dim": output_dim,
            }
            torch.save(checkpoint, f"model_{atten_type}.pt")

        print(f"{atten_type} | epoch {epoch+1:02d} | train {tr_loss:.3f} ppl {math.exp(tr_loss):.2f} | valid {va_loss:.3f} ppl {math.exp(va_loss):.2f} | {t1-t0:.1f}s")

    # Release cached MPS memory between runs; the model itself stays in `models`
    # so the evaluation and plotting cells can still use it.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    gc.collect()

print("saved checkpoints:", [f"model_{k}.pt" for k in ATTEN_TYPES])

In [None]:
# Required table: final-epoch loss and perplexity (exp of the loss) per attention type.
print("Attention | Train Loss | Train PPL | Valid Loss | Valid PPL")
rows = []
for k in ATTEN_TYPES:
    tr, va = histories[k]["train"][-1], histories[k]["valid"][-1]
    tr_ppl, va_ppl = math.exp(tr), math.exp(va)
    rows.append((k, tr, tr_ppl, va, va_ppl))
    print(f"{k:>12} | {tr:9.3f} | {tr_ppl:9.2f} | {va:9.3f} | {va_ppl:9.2f}")

# Graphs: one figure per attention type with its train and validation curves.
for k in ATTEN_TYPES:
    fig, ax = plt.subplots()
    for split in ("train", "valid"):
        ax.plot(histories[k][split], label=split)
    ax.set_title(f"Loss curves: {k}")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.legend()
    plt.show()

| Attentions | Training Loss | Training PPL | Validation Loss | Validation PPL |
|----------|----------|----------|----------|----------|
| General Attention    | 2.141     | 8.51     | 3.158     | 23.53     |
| Additive Attention    | 2.112     | 8.27     | 3.149    | 23.32    |

### Explanation of the Results Table (Loss and Perplexity)

This table compares the performance of the **General Attention** and **Additive Attention** models using **training loss**, **training perplexity (PPL)**, **validation loss**, and **validation perplexity**. Both models show very similar behaviour, but **Additive Attention performs slightly better overall**.

On the training set, **Additive Attention** achieves a lower loss (**2.112**) and lower PPL (**8.27**) compared to **General Attention** (loss **2.141**, PPL **8.51**). This suggests that the Additive Attention model fits the training data marginally better and learns stronger token-level patterns during optimisation.

On the validation set, Additive Attention again produces slightly lower loss (**3.149**) and lower PPL (**23.32**) than General Attention (loss **3.158**, PPL **23.53**). Since lower validation loss and PPL indicate better prediction on unseen data, this shows that Additive Attention generalises a little better.

A noticeable difference between training and validation values exists for both models, indicating mild overfitting. However, the consistent improvement across both metrics suggests that **Additive Attention provides a more flexible alignment mechanism** for source-to-target mapping in this English-to-Nepali translation task.


# Evaluation and Verification

In [None]:
def decode_ids_to_text(ids):
    """Convert target-vocabulary ids to a readable string.

    Out-of-range ids and the four special tokens are skipped. Remaining tokens
    are joined with spaces and every " ##" is removed, which glues WordPiece
    continuation pieces onto the preceding piece. Returns "<no_output>" when
    nothing is left.
    """
    itos = vocab_transform[TARG_LANG].itos
    specials = {"<bos>", "<eos>", "<pad>", "<unk>"}

    valid_ids = (int(i) for i in ids if 0 <= int(i) < len(itos))
    pieces = [itos[i] for i in valid_ids if itos[i] not in specials]

    text = " ".join(pieces).replace(" ##", "").strip()
    return text if text != "" else "<no_output>"

def translate(model, en_text, max_len=40, min_len=3):
    """Greedy-decode a Nepali translation of `en_text`.

    Args:
        model: a trained model exposing `encoder` and `decoder`.
        en_text: English source sentence.
        max_len: maximum number of tokens to generate. It is capped at the
            size of the decoder's positional-embedding table (or `MAX_LEN`
            when the decoder has no `pos_embedding`), because a longer prefix
            would index past that table.
        min_len: <eos> is suppressed while the running sequence (including
            <bos>) is shorter than this.

    Returns:
        (text, attn, src_ids, trg_ids): the decoded string, the encoder-
        attention weights from the final decoding step (None if no step ran),
        the source ids, and the generated ids starting with <bos>.
    """
    model.eval()

    src_ids = text_to_ids(en_text, SRC_LANG)
    src_tensor = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
    src_mask = (src_tensor != SRC_PAD_IDX).unsqueeze(1).unsqueeze(2)

    # The decoder input grows by one token per step, so the number of steps
    # must not exceed the number of learned positions.
    pos_table = getattr(model.decoder, "pos_embedding", None)
    max_positions = pos_table.num_embeddings if pos_table is not None else MAX_LEN
    n_steps = min(max_len, max_positions)

    trg_ids = [BOS_IDX]
    attn = None  # stays None when n_steps <= 0, so the return below is always valid

    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

        for _ in range(n_steps):
            trg_tensor = torch.tensor(trg_ids, dtype=torch.long).unsqueeze(0).to(device)
            cur_len = trg_tensor.size(1)
            # Causal mask only: a single un-padded sequence needs no padding mask.
            trg_mask = torch.tril(torch.ones((cur_len, cur_len), device=device)).bool().unsqueeze(0).unsqueeze(1)

            out, attn = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

            # Scores for the token that follows the current prefix.
            logits = out[:, -1, :].squeeze(0)

            # block UNK always
            logits[UNK_IDX] = -1e10

            # prevent EOS too early
            if len(trg_ids) < min_len:
                logits[EOS_IDX] = -1e10

            next_id = int(torch.argmax(logits).item())
            trg_ids.append(next_id)

            if next_id == EOS_IDX:
                break

    return decode_ids_to_text(trg_ids), attn, src_ids, trg_ids

# Show a few samples: translate the first five training sentences with every model.
samples = [train_ds[i][SRC_LANG] for i in range(5)]
for s in samples:
    print("EN:", s)
    for k in ATTEN_TYPES:
        ne = translate(models[k], s)[0]
        print(k, "NE:", ne)
    print("-"*60)

### Attention Maps

In [None]:
def attention_heatmap(attn, src_ids, trg_ids, title):
    """Plot head-averaged attention weights as a target-by-source heatmap.

    Args:
        attn: attention tensor of shape [1, heads, trg_len, src_len], or None.
        src_ids: source token ids (column labels).
        trg_ids: target token ids (row labels). When this list is exactly one
            longer than the number of attention rows - as returned by
            `translate`, where the first id is the start token that was fed in
            and every row predicts the *next* id - the first id is dropped so
            each row is labelled with the token it produced.
        title: figure title.
    """
    if attn is None:
        print("no attention")
        return
    # [1, heads, trg_len, src_len] -> average heads
    A = attn.squeeze(0).mean(0).detach().cpu().numpy()

    src_tokens = [vocab_transform[SRC_LANG].itos[i] for i in src_ids]

    row_ids = list(trg_ids)
    if len(row_ids) == A.shape[0] + 1:
        # Align labels with rows: row t was computed while predicting id t+1.
        row_ids = row_ids[1:]
    trg_tokens = [vocab_transform[TARG_LANG].itos[i] for i in row_ids]

    # limit for readability, and never label more rows/columns than exist
    max_src = min(len(src_tokens), A.shape[1], 20)
    max_trg = min(len(trg_tokens), A.shape[0], 20)
    A = A[:max_trg, :max_src]

    plt.figure(figsize=(10, 6))
    plt.imshow(A, aspect="auto")
    plt.colorbar()
    plt.xticks(range(max_src), src_tokens[:max_src], rotation=45, ha="right")
    plt.yticks(range(max_trg), trg_tokens[:max_trg])
    plt.title(title)
    plt.tight_layout()
    plt.show()

# Translate the first training sentence with every model and plot its attention.
example = train_ds[0][SRC_LANG]
print("Example EN:", example)

for k in ATTEN_TYPES:
    ne, attn, src_ids, trg_ids = translate(models[k], example)
    print(k, "NE:", ne)
    heatmap_title = f"Attention map: {k}"
    attention_heatmap(attn, src_ids, trg_ids, heatmap_title)

### Analysis of Results

Based on the discussion in the previous sections, **Additive Attention** produced the best overall performance among the three attention mechanisms. It achieved the **lowest loss and perplexity** during evaluation, indicating more accurate next token prediction compared to the General and Multiplicative variants. Therefore, the Additive Attention model was selected for deployment in the web application.

Despite this relative improvement, the translation quality across all three models was still weak. This may be due to **insufficient training epochs**, which can limit learning, or **overfitting**, where the model performs well on the training data but does not generalise effectively to unseen inputs.


### User Interface and Model Integration

For this assignment, the user interface was developed using **Dash**. The entire interface, along with the required model integration, is implemented within the `app.py` file. The UI is intentionally kept simple and consists of a text input field for the user query, a **Translate** button, basic input validation, and an output section to display the translation result. Screenshots demonstrating the interface and its functionality are included in the `README.md` file inside the A3 folder.

The trained model is integrated into the interface through a clear and structured pipeline. First, the saved vocabulary is loaded, and the trained model is initialised using its stored parameters. Among the three attention mechanisms implemented, the **Additive Attention** model was selected for deployment as it showed the best overall performance during evaluation. Once the user provides an input sentence, the text is tokenised and numericalised using the loaded vocabulary, after which tensors are created and passed to the model. The model then predicts the output by selecting the token with the highest probability at each decoding step, and the final result is displayed to the user.

### User Interaction Flow

The interaction flow of the application is as follows:

- The user enters an English sentence into the input field  
- The user clicks the **Translate** button  
- The translated Nepali output is displayed on the screen  
