In [1]:
import torch, torchtext
import torch.nn as nn
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt 
from datasets import load_dataset
from numpy.random import default_rng

import random, math, time

In [2]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

Using device: mps


In [3]:
import os
for f in ["model_general.pt", "model_multiplicative.pt"]:
    print(f, "exists?", os.path.exists(f))


model_general.pt exists? True
model_multiplicative.pt exists? True


In [4]:
# 1. Force a total reset by reloading the dataset
from datasets import load_dataset
dataset = load_dataset("opus100", "en-ne")

# 2. Use a completely new function name to avoid cache matching
def repair_dataset_columns(batch):
    # Explicitly pull the strings from the nested dictionary
    return {
        'en': batch['translation']['en'],
        'ne': batch['translation']['ne']
    }

# 3. Map with CACHE DISABLED
dataset = dataset.map(
    repair_dataset_columns, 
    remove_columns=['translation'],
    load_from_cache_file=False  # <--- THIS IS THE KEY
)

# 4. VERIFY - If this is still "en", there is a deeper issue
print("--- FINAL VERIFICATION ---")
example = dataset['train'][0]
print(f"EN data: {example['en']}")
print(f"NE data: {example['ne']}")

if example['en'] == "en" or len(example['en']) < 3:
    print("CRITICAL ERROR: Data is still corrupted. Restart your Notebook Kernel.")
else:
    print("SUCCESS: Data is repaired.")

Map: 100%|##########| 2000/2000 [00:00<?, ? examples/s]

Map: 100%|##########| 406381/406381 [00:00<?, ? examples/s]

Map: 100%|##########| 2000/2000 [00:00<?, ? examples/s]

--- FINAL VERIFICATION ---
EN data: _Inv
NE data: Inv
SUCCESS: Data is repaired.


In [5]:
# ===== Dataset sizes (safe slicing) =====
N_TRAIN = 50000
N_VALID = 5000
N_TEST  = 5000

train_ds = dataset["train"].select(range(min(N_TRAIN, len(dataset["train"]))))
valid_ds = dataset["validation"].select(range(min(N_VALID, len(dataset["validation"]))))
test_ds  = dataset["test"].select(range(min(N_TEST, len(dataset["test"]))))

print("train/valid/test sizes:", len(train_ds), len(valid_ds), len(test_ds))


train/valid/test sizes: 50000 2000 2000


In [6]:
# Place-holders
token_transform = {}
vocab_transform = {}

SRC_LANG= 'en'
TARG_LANG = 'ne'

In [7]:
from torchtext.data.utils import get_tokenizer
from nepalitokenizers import WordPiece

In [8]:
token_transform["en"] = get_tokenizer('spacy', language='en_core_web_sm')
token_transform["ne"] = WordPiece()

In [9]:

def get_data_token(batch, lang):
    text = batch[lang]
    tok = token_transform[lang]

    if lang == "en":
        # spacy tokenizer returns tokens (strings) or token objects
        raw = tok(text.lower().strip())
        tokens = [getattr(t, "text", t) for t in raw]
    else:
        enc = tok.encode(text.strip())
        tokens = enc.tokens if hasattr(enc, "tokens") else enc

    # remove obvious control tokens if present
    remove = {"[CLS]", "[SEP]", "en", "ne", "<s>", "</s>"}
    tokens = [t for t in tokens if t not in remove]

    return {lang: tokens}


In [10]:

# Tokenize the limited datasets
tokenized_train = train_ds.map(get_data_token, fn_kwargs={"lang": SRC_LANG})
tokenized_train = tokenized_train.map(get_data_token, fn_kwargs={"lang": TARG_LANG})

tokenized_valid = valid_ds.map(get_data_token, fn_kwargs={"lang": SRC_LANG})
tokenized_valid = tokenized_valid.map(get_data_token, fn_kwargs={"lang": TARG_LANG})

tokenized_test  = test_ds.map(get_data_token, fn_kwargs={"lang": SRC_LANG})
tokenized_test  = tokenized_test.map(get_data_token, fn_kwargs={"lang": TARG_LANG})

# quick peek
print("sample EN tokens:", tokenized_train[0][SRC_LANG][:10])
print("sample NE tokens:", tokenized_train[0][TARG_LANG][:10])


Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

sample EN tokens: ['_', 'inv']
sample NE tokens: ['inv']


In [11]:

from torchtext.vocab import Vocab
from collections import Counter

special_symbols = ["<unk>", "<pad>", "<bos>", "<eos>"]

vocab_transform = {}

for ln in [SRC_LANG, TARG_LANG]:
    counter = Counter()
    for tokens in tokenized_train[ln]:
        counter.update(tokens)
    v = Vocab(counter, specials=special_symbols)
    v.unk_index = v.stoi["<unk>"]
    vocab_transform[ln] = v

UNK_IDX = vocab_transform[TARG_LANG].stoi["<unk>"]
PAD_IDX = vocab_transform[TARG_LANG].stoi["<pad>"]
BOS_IDX = vocab_transform[TARG_LANG].stoi["<bos>"]
EOS_IDX = vocab_transform[TARG_LANG].stoi["<eos>"]

SRC_PAD_IDX = vocab_transform[SRC_LANG].stoi["<pad>"]
TRG_PAD_IDX = PAD_IDX

print("vocab sizes:", len(vocab_transform[SRC_LANG]), len(vocab_transform[TARG_LANG]))
print("UNK, PAD, BOS, EOS:", UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX)


vocab sizes: 15889 9209
UNK, PAD, BOS, EOS: 0 1 2 3


In [12]:
import torch 
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def numericalize(vocab, tokens):
    return [vocab.stoi.get(t, vocab.stoi["<unk>"]) for t in tokens]

def text_to_ids(text, lang):
    # uses the already-defined tokenizers for consistency
    if lang == "en":
        raw = token_transform["en"](text.lower().strip())
        toks = [getattr(t, "text", t) for t in raw]
    else:
        enc = token_transform["ne"].encode(text.strip())
        toks = enc.tokens if hasattr(enc, "tokens") else enc

    remove = {"[CLS]", "[SEP]", "en", "ne", "<s>", "</s>"}
    toks = [t for t in toks if t not in remove]

    ids = numericalize(vocab_transform[lang], toks)
    return [BOS_IDX] + ids + [EOS_IDX]

def collate_batch(batch):
    src_batch, src_len_batch, trg_batch = [], [], []
    for item in batch:
        src_text = item[SRC_LANG]
        trg_text = item[TARG_LANG]

        src_ids = text_to_ids(src_text, SRC_LANG)
        trg_ids = text_to_ids(trg_text, TARG_LANG)

        src_batch.append(torch.tensor(src_ids, dtype=torch.long))
        trg_batch.append(torch.tensor(trg_ids, dtype=torch.long))
        src_len_batch.append(len(src_ids))

    src_batch = pad_sequence(src_batch, padding_value=SRC_PAD_IDX, batch_first=True)
    trg_batch = pad_sequence(trg_batch, padding_value=TRG_PAD_IDX, batch_first=True)
    return src_batch, torch.tensor(src_len_batch, dtype=torch.int64), trg_batch

batch_size = 8
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

train_loader_length = len(train_loader)
val_loader_length   = len(valid_loader)
test_loader_length  = len(test_loader)

# sanity: unk rates
src, _, trg = next(iter(train_loader))
print("SRC UNK %:", (src == vocab_transform[SRC_LANG].stoi["<unk>"]).float().mean().item())
print("TRG UNK %:", (trg == UNK_IDX).float().mean().item())


SRC UNK %: 0.0
TRG UNK %: 0.0


In [13]:
ATTEN_TYPES = ["additive"]
print("ATTEN_TYPES now:", ATTEN_TYPES)


ATTEN_TYPES now: ['additive']


In [15]:
import torch
import torch.nn as nn

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, atten_type, device):
        super().__init__()
        assert hid_dim % n_heads == 0

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.atten_type = atten_type
        self.device = device

        # projections
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)

        # multiplicative attention parameter
        self.W = nn.Linear(self.head_dim, self.head_dim, bias=False)

        # additive attention parameters
        self.Wq = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.Wk = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.v  = nn.Linear(self.head_dim, 1, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float, device=device))

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # [B, H, L, D]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        if self.atten_type == "general":
            energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        elif self.atten_type == "multiplicative":
            Qw = self.W(Q)  # [B, H, Lq, D]
            energy = torch.matmul(Qw, K.permute(0, 1, 3, 2)) / self.scale

        elif self.atten_type == "additive":
            # expand for pairwise (Lq x Lk)
            Qe = self.Wq(Q).unsqueeze(3)  # [B,H,Lq,1,D]
            Ke = self.Wk(K).unsqueeze(2)  # [B,H,1,Lk,D]
            e = torch.tanh(Qe + Ke)       # [B,H,Lq,Lk,D]
            energy = self.v(e).squeeze(-1)  # [B,H,Lq,Lk]

        else:
            raise ValueError(f"Unknown atten_type: {self.atten_type}")

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim=-1)

        x = torch.matmul(self.dropout(attention), V)  # [B,H,Lq,D]
        x = x.permute(0, 2, 1, 3).contiguous()         # [B,Lq,H,D]
        x = x.view(batch_size, -1, self.hid_dim)       # [B,Lq,hid]
        x = self.fc_o(x)

        return x, attention


In [16]:
import torch.nn as nn

class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch, seq_len, hid_dim]
        x = self.dropout(torch.relu(self.fc_1(x)))
        x = self.fc_2(x)
        return x


In [17]:
class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, atten_type, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm        = nn.LayerNorm(hid_dim)
        self.self_attention       = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, atten_type, device)
        self.feedforward          = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout              = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        #src = [batch size, src len, hid dim]
        #src_mask = [batch size, 1, 1, src len]   #if the token is padding, it will be 1, otherwise 0
        _src, _ = self.self_attention(src, src, src, src_mask)
        src     = self.self_attn_layer_norm(src + self.dropout(_src))
        #src: [batch_size, src len, hid dim]

        _src    = self.feedforward(src)
        src     = self.ff_layer_norm(src + self.dropout(_src))
        #src: [batch_size, src len, hid dim]

        return src

In [18]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, atten_type, device, max_length = 512):
        super().__init__()
        self.device = device
        self.atten_type = atten_type
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers        = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, atten_type,device)
                                           for _ in range(n_layers)])
        self.dropout       = nn.Dropout(dropout)
        self.scale         = torch.sqrt(torch.FloatTensor([hid_dim])).to(self.device)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src len]
        #src_mask = [batch size, 1, 1, src len]
        
        batch_size = src.shape[0]
        src_len    = src.shape[1]
        
        pos        = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos: [batch_size, src_len]
        
        src        = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        #src: [batch_size, src_len, hid_dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
        #src: [batch_size, src_len, hid_dim]
        
        return src

In [19]:
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, atten_type, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm  = nn.LayerNorm(hid_dim)
        self.ff_layer_norm        = nn.LayerNorm(hid_dim)
        self.self_attention       = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, atten_type, device)
        self.encoder_attention    = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, atten_type, device)
        self.feedforward          = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout              = nn.Dropout(dropout)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len, hid dim]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
        
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg     = self.self_attn_layer_norm(trg + self.dropout(_trg))
        #trg = [batch_size, trg len, hid dim]
        
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg             = self.enc_attn_layer_norm(trg + self.dropout(_trg))
        #trg = [batch_size, trg len, hid dim]
        #attention = [batch_size, n heads, trg len, src len]
        
        _trg = self.feedforward(trg)
        trg  = self.ff_layer_norm(trg + self.dropout(_trg))
        #trg = [batch_size, trg len, hid dim]
        
        return trg, attention

In [20]:
class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, 
                 pf_dim, dropout, atten_type, device,max_length = 512):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers        = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, atten_type, device)
                                            for _ in range(n_layers)])
        self.fc_out        = nn.Linear(hid_dim, output_dim)
        self.dropout       = nn.Dropout(dropout)
        self.scale         = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
        
        batch_size = trg.shape[0]
        trg_len    = trg.shape[1]
        
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos: [batch_size, trg len]
        
        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        #trg: [batch_size, trg len, hid dim]
        
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
            
        #trg: [batch_size, trg len, hid dim]
        #attention: [batch_size, n heads, trg len, src len]
        
        output = self.fc_out(trg)
        #output = [batch_size, trg len, output_dim]
        
        return output, attention

In [21]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        self.params = {'encoder': encoder, 'decoder': decoder,
                       'src_pad_idx': src_pad_idx, 'trg_pad_idx': trg_pad_idx}
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        
    def make_src_mask(self, src):
        
        #src = [batch size, src len]
        
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #src_mask = [batch size, 1, 1, src len]

        return src_mask
    
    def make_trg_mask(self, trg):
        
        #trg = [batch size, trg len]
        
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        #trg_pad_mask = [batch size, 1, 1, trg len]
        
        trg_len = trg.shape[1]
        
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()
        #trg_sub_mask = [trg len, trg len]
            
        trg_mask = trg_pad_mask & trg_sub_mask
        #trg_mask = [batch size, 1, trg len, trg len]
        
        return trg_mask

    def forward(self, src, trg):
        
        #src = [batch size, src len]
        #trg = [batch size, trg len]
                
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        
        #src_mask = [batch size, 1, 1, src len]
        #trg_mask = [batch size, 1, trg len, trg len]
        
        enc_src = self.encoder(src, src_mask)
        #enc_src = [batch size, src len, hid dim]
                
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        #output = [batch size, trg len, output dim]
        #attention = [batch size, n heads, trg len, src len]
        
        return output, attention

In [23]:

import torch.optim as optim
import copy

input_dim  = len(vocab_transform[SRC_LANG])
output_dim = len(vocab_transform[TARG_LANG])

HID_DIM = 128
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

def initialize_weights(m):
    if hasattr(m, "weight") and m.weight is not None and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

ATTEN_TYPES = ["additive"]

models = {}
optimizers = {}
histories = {}  # {atten_type: {"train":[], "valid":[]}}

for atten_type in ATTEN_TYPES:
    enc = Encoder(input_dim, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, atten_type, device)
    dec = Decoder(output_dim, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, atten_type, device)

    model = Seq2SeqTransformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
    model.apply(initialize_weights)

    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    models[atten_type] = model
    optimizers[atten_type] = optimizer
    histories[atten_type] = {"train": [], "valid": []}

criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
clip = 1.0

print("Models ready:", list(models.keys()))


Models ready: ['additive']


In [24]:
print("models:", list(models.keys()))


models: ['additive']


In [25]:
def train(model, loader, optimizer, criterion, clip, loader_length):
    model.train()
    epoch_loss = 0
    
    for src, src_len, trg in loader:
        src = src.to(device)
        trg = trg.to(device)
        
        optimizer.zero_grad()

        # SHIFTING LOGIC: Feed the target without the last token (<eos>)
        output, _ = model(src, trg[:, :-1])
        
        # RESHAPE: Compare output against target shifted by one (removes <bos>)
        # This forces the model to predict the NEXT word in the sequence
        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)

        trg_output = trg[:, 1:].contiguous().view(-1)
        loss = criterion(output, trg_output)

        loss.backward()
        
        # Prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / loader_length

In [26]:
def evaluate(model, loader, criterion, loader_length):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for src, src_len, trg in loader:
            src = src.to(device)
            trg = trg.to(device)

            # Shifting logic: Feed target without the last token
            output, _ = model(src, trg[:, :-1])
            
            output_dim = output.shape[-1]
            
            # Reshape for loss calculation
            output = output.contiguous().view(-1, output_dim)
            trg_output = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output, trg_output)
            epoch_loss += loss.item()
        
    return epoch_loss / loader_length

In [38]:

import time, math

EPOCHS = 6  # keep fast for 1000 samples

best_valid = {k: float("inf") for k in models.keys()}

for atten_type in ATTEN_TYPES:
    print("\n===== Training:", atten_type, "=====")
    model = models[atten_type]
    optimizer = optimizers[atten_type]

    for epoch in range(EPOCHS):
        t0 = time.time()
        train_loss = train(model, train_loader, optimizer, criterion, clip, train_loader_length)
        valid_loss = evaluate(model, valid_loader, criterion, val_loader_length)
        t1 = time.time()

        histories[atten_type]["train"].append(train_loss)
        histories[atten_type]["valid"].append(valid_loss)

        if valid_loss < best_valid[atten_type]:
            best_valid[atten_type] = valid_loss
            torch.save({"atten_type": atten_type,
                        "state_dict": model.state_dict(),
                        "input_dim": input_dim,
                        "output_dim": output_dim}, f"model_{atten_type}.pt")

        print(f"{atten_type} | epoch {epoch+1:02d} | train {train_loss:.3f} ppl {math.exp(train_loss):.2f} | valid {valid_loss:.3f} ppl {math.exp(valid_loss):.2f} | {t1-t0:.1f}s")



===== Training: general =====


KeyboardInterrupt: 

In [41]:

import matplotlib.pyplot as plt
import numpy as np

def translate_with_attention(model, en_text, max_len=30):
    model.eval()

    # source ids + tokens
    src_ids = text_to_ids(en_text, SRC_LANG)
    src_tensor = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
    src_mask = (src_tensor != SRC_PAD_IDX).unsqueeze(1).unsqueeze(2)

    src_tokens = [vocab_transform[SRC_LANG].itos[i] for i in src_ids]

    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    trg_ids = [BOS_IDX]
    attentions = None

    for _ in range(max_len):
        trg_tensor = torch.tensor(trg_ids, dtype=torch.long).unsqueeze(0).to(device)
        L = trg_tensor.size(1)
        trg_mask = torch.tril(torch.ones((L, L), device=device)).bool().unsqueeze(0).unsqueeze(1)

        with torch.no_grad():
            out, attn = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

        probs = out[:, -1, :].squeeze(0)
        probs[UNK_IDX] = -1e10
        next_id = int(probs.argmax().item())
        trg_ids.append(next_id)

        attentions = attn  # last layer attention [1, heads, trg_len, src_len]

        if next_id == EOS_IDX:
            break

    trg_tokens = [vocab_transform[TARG_LANG].itos[i] for i in trg_ids]
    # clean tokens
    skip = {"<bos>", "<eos>", "<pad>", "<unk>"}
    out_tokens = [t for t in trg_tokens if t not in skip]
    out_text = " ".join(out_tokens).replace(" ##", "").strip()

    return out_text, attentions, src_tokens, trg_tokens

def plot_loss_curves(histories):
    for atten_type, h in histories.items():
        plt.figure()
        plt.plot(h["train"], label="train")
        plt.plot(h["valid"], label="valid")
        plt.title(f"Loss curves: {atten_type}")
        plt.xlabel("epoch")
        plt.ylabel("loss")
        plt.legend()
        plt.show()

def attention_heatmap(attn, src_tokens, trg_tokens, title="Attention"):
    # attn: [1, heads, trg_len, src_len]
    if attn is None:
        print("No attention returned")
        return
    attn = attn.squeeze(0).mean(0).cpu().numpy()  # [trg_len, src_len]

    # limit sizes for readability
    max_src = min(len(src_tokens), 20)
    max_trg = min(len(trg_tokens), 20)

    attn = attn[:max_trg, :max_src]
    src = src_tokens[:max_src]
    trg = trg_tokens[:max_trg]

    plt.figure(figsize=(10, 6))
    plt.imshow(attn, aspect="auto")
    plt.colorbar()
    plt.xticks(range(len(src)), src, rotation=45, ha="right")
    plt.yticks(range(len(trg)), trg)
    plt.title(title)
    plt.tight_layout()
    plt.show()

# ===== Performance table =====
print("Attentions | Train Loss | Train PPL | Valid Loss | Valid PPL")
for atten_type in ATTEN_TYPES:
    tr = histories[atten_type]["train"][-1]
    va = histories[atten_type]["valid"][-1]
    print(f"{atten_type:>13} | {tr:9.3f} | {math.exp(tr):9.2f} | {va:9.3f} | {math.exp(va):9.2f}")

# ===== Loss plots =====
plot_loss_curves(histories)

# ===== Attention maps (one example per model) =====
example = train_ds[0]["en"]
print("\nExample EN:", example)

for atten_type in ATTEN_TYPES:
    model = models[atten_type]
    out, attn, src_toks, trg_toks = translate_with_attention(model, example, max_len=30)
    print("\n", atten_type, "NE:", out)
    attention_heatmap(attn, src_toks, trg_toks, title=f"Attention map ({atten_type})")


Attentions | Train Loss | Train PPL | Valid Loss | Valid PPL
      general |     1.819 |      6.17 |     3.389 |     29.64
multiplicative |     1.988 |      7.30 |     3.410 |     30.27


IndexError: list index out of range