# Urdu→Roman pipeline — WordPiece tokenization

This notebook swaps your previous BPE setup to **WordPiece** and keeps your training/validation flow.
It is pre-wired to your dataset files and columns as used in your earlier notebook.

## 1) Initialization

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:


import os, sys, math, random, time, json, pathlib, re
from pathlib import Path

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Reproducibility
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED);
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# ---- Paths (pre-wired to your dataset) ----
DATA_DIR = Path("/content/drive/MyDrive/NLP")
TRAIN_CSV = DATA_DIR / "train.csv"
VALID_CSV = DATA_DIR / "valid.csv"
TEST_CSV  = DATA_DIR / "test.csv"

# Columns discovered in your previous notebook
SRC_COL = "urdu"
TGT_COL = "roman"

EXP_DIR = DATA_DIR / "experiments"
EXP_DIR.mkdir(parents=True, exist_ok=True)

# WordPiece tokenizer outputs
TOK_DIR = DATA_DIR / "tokenizers"
TOK_DIR.mkdir(parents=True, exist_ok=True)
WP_VOCAB_TXT = TOK_DIR / "roman_wp-vocab.txt"
WP_JSON = TOK_DIR / "roman_wp-tokenizer.json"

BEST_CKPT = EXP_DIR / "bilstm_wp_best.pt"

for p in [TRAIN_CSV, VALID_CSV]:
    print("Exists:", p, p.exists())

# pip installs (idempotent)
try:
    import tokenizers  # huggingface tokenizers (fast)
except ImportError:
    !pip -q install tokenizers > /dev/null
    import tokenizers

print("tokenizers version:", tokenizers.__version__)


Device: cuda
Exists: /content/drive/MyDrive/NLP/train.csv True
Exists: /content/drive/MyDrive/NLP/valid.csv True
tokenizers version: 0.22.0


## 2) Vocabularies: Urdu char vocab (source) + Roman WordPiece (target, with [EOS])

In [3]:


from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import NFD, Lowercase, Strip, Sequence as NormSequence
from tokenizers.processors import TemplateProcessing

# Load training CSV
df_train = pd.read_csv(TRAIN_CSV)
assert SRC_COL in df_train.columns and TGT_COL in df_train.columns, f"Need '{SRC_COL}' and '{TGT_COL}'"

# ---------- Target (Roman) WordPiece with [EOS] ----------
roman_texts = df_train[TGT_COL].dropna().astype(str).tolist()
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = NormSequence([NFD(), Lowercase(), Strip()])
tokenizer.pre_tokenizer = Whitespace()

special_tokens = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]","[EOS]"]
trainer = WordPieceTrainer(
    vocab_size=5000,
    special_tokens=special_tokens,
    min_frequency=2,
    continuing_subword_prefix="-"
)
tokenizer.train_from_iterator(roman_texts, trainer=trainer)
tokenizer.post_processor = TemplateProcessing(
    single="$A",
    pair="$A [SEP] $B:1",
    special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))] if tokenizer.token_to_id("[SEP]") is not None else []
)

# Save tokenizer artifacts
tokenizer.save(str(WP_JSON))
with open(WP_VOCAB_TXT, "w", encoding="utf-8") as f:
    for i in range(tokenizer.get_vocab_size()):
        f.write(tokenizer.id_to_token(i) + "\n")

pad_id_tgt = tokenizer.token_to_id("[PAD]")
eos_id_tgt = tokenizer.token_to_id("[EOS]")
unk_id_tgt = tokenizer.token_to_id("[UNK]")
assert eos_id_tgt is not None and pad_id_tgt is not None, "Target tokenizer must include [EOS]/[PAD]"

print("Saved target tokenizer:", WP_JSON, "| vocab size:", tokenizer.get_vocab_size())

# ---------- Source (Urdu) char-level vocab with specials ----------
SRC_PAD = "<pad>"
SRC_UNK = "<unk>"
SRC_SPECIALS = [SRC_PAD, SRC_UNK]

def _char_set_from_series(series):
    seen = []
    for s in series.astype(str):
        for ch in s:
            if ch not in seen:
                seen.append(ch)
    return seen

src_chars = _char_set_from_series(df_train[SRC_COL])
SRC_TOKENS = SRC_SPECIALS + src_chars
SRC_token2id = {t: i for i, t in enumerate(SRC_TOKENS)}
SRC_id2token = {i: t for t, i in SRC_token2id.items()}
pad_id_src = SRC_token2id[SRC_PAD]
unk_id_src = SRC_token2id[SRC_UNK]

print("Source (Urdu) vocab size:", len(SRC_TOKENS))


Saved target tokenizer: /content/drive/MyDrive/NLP/tokenizers/roman_wp-tokenizer.json | vocab size: 5000
Source (Urdu) vocab size: 54


## 3) Dataset & DataLoader (LM over WordPiece tokens)

In [4]:
"""## 3) Dataset & DataLoaders (Urdu→Roman seq2seq with BOS/EOS) — FIXED LENGTHS"""

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Re-load tokenizer from disk to be safe
from tokenizers import Tokenizer as _Tok
tokenizer = _Tok.from_file(str(WP_JSON))
pad_id_tgt = tokenizer.token_to_id("[PAD]")
eos_id_tgt = tokenizer.token_to_id("[EOS]")
unk_id_tgt = tokenizer.token_to_id("[UNK]")

# Use [CLS] as BOS (present in your WordPiece specials). Fallbacks stay as before.
bos_id_tgt = tokenizer.token_to_id("[CLS]") or tokenizer.token_to_id("[SEP]") or eos_id_tgt

MAX_SRC_LEN = 256
MAX_TGT_LEN = 256

def src_tokenize_chars(s: str):
    return list(str(s))

def encode_src_urdu(s: str):
    toks = src_tokenize_chars(s)[:MAX_SRC_LEN]
    ids = [SRC_token2id.get(t, unk_id_src) for t in toks]
    return torch.tensor(ids, dtype=torch.long)

def encode_tgt_roman(s: str):
    # WordPiece for target; add EOS; cap length so +EOS fits
    enc = tokenizer.encode(str(s))
    ids = enc.ids[:max(0, MAX_TGT_LEN - 1)] + [eos_id_tgt]
    return torch.tensor(ids, dtype=torch.long)

class PairDataset(Dataset):
    def __init__(self, csv_path):
        df = pd.read_csv(csv_path)
        assert SRC_COL in df.columns and TGT_COL in df.columns
        df = df[[SRC_COL, TGT_COL]].dropna()
        df = df[df[SRC_COL].astype(str).str.strip() != ""]
        df = df[df[TGT_COL].astype(str).str.strip() != ""]
        self.src = df[SRC_COL].astype(str).tolist()
        self.tgt = df[TGT_COL].astype(str).tolist()
    def __len__(self): return len(self.src)
    def __getitem__(self, i):
        return self.src[i], self.tgt[i]

def collate_batch(batch):
    src_seqs, tgt_seqs = [], []
    for s, t in batch:
        src_seqs.append(encode_src_urdu(s))
        tgt_seqs.append(encode_tgt_roman(t))           # each ends with EOS

    # ✅ FIX: decoder input length must equal target length
    # dec_inp = [BOS] + tgt[:-1]
    dec_inp = [torch.tensor([bos_id_tgt] + seq[:-1].tolist(), dtype=torch.long) for seq in tgt_seqs]

    src_pad = pad_sequence(src_seqs, batch_first=True, padding_value=pad_id_src)
    dec_pad = pad_sequence(dec_inp,  batch_first=True, padding_value=pad_id_tgt)
    tgt_pad = pad_sequence(tgt_seqs, batch_first=True, padding_value=pad_id_tgt)

    src_len = torch.tensor([len(s) for s in src_seqs], dtype=torch.long)
    tgt_len = torch.tensor([len(s) for s in tgt_seqs], dtype=torch.long)

    return {
        "src": src_pad.to(device),
        "src_len": src_len.to(device),
        "dec_inp": dec_pad.to(device),
        "tgt": tgt_pad.to(device),
        "tgt_len": tgt_len.to(device),
    }

BATCH_SIZE = 64
train_loader = DataLoader(PairDataset(TRAIN_CSV), batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_batch)
valid_loader = DataLoader(PairDataset(VALID_CSV), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

print("Batches — train:", len(train_loader), "| valid:", len(valid_loader))


Batches — train: 163 | valid: 82


## 4) Model — BiLSTM Language Model

In [5]:
"""## 4) Model — BiLSTM Encoder + Attention LSTM Decoder (with state-depth bridge)"""

import torch.nn.functional as F

SRC_V = len(SRC_TOKENS)
TGT_V = tokenizer.get_vocab_size()
EMB_DIM_SRC = 256
EMB_DIM_TGT = 256

# You can set these above; kept here for clarity
HID_DIM = 512
ENC_LAYERS = 2      # e.g., 2-layer BiLSTM encoder
DEC_LAYERS = 3      # <- you set decoder to 4
DROPOUT = 0.3       # <- your new dropout

class Encoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, n_layers, dropout, pad_idx):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.emb_drop = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            emb_dim, hid_dim // 2, num_layers=n_layers,
            dropout=dropout if n_layers > 1 else 0.0,
            bidirectional=True, batch_first=True
        )
    def forward(self, src, src_len):
        # src: [B, S]
        e = self.emb_drop(self.emb(src))  # [B, S, E]
        packed = nn.utils.rnn.pack_padded_sequence(e, src_len.cpu(), batch_first=True, enforce_sorted=False)
        h, (hn, cn) = self.lstm(packed)
        h, _ = nn.utils.rnn.pad_packed_sequence(h, batch_first=True)  # [B, S, H]

        # Merge directions for each layer
        def _merge_bidirectional(x):
            # x: [2*n_layers, B, H/2]
            layers = []
            for l in range(0, x.size(0), 2):
                f = x[l]     # [B, H/2]
                b = x[l+1]   # [B, H/2]
                layers.append(torch.cat([f, b], dim=-1))  # [B, H]
            return torch.stack(layers, dim=0)             # [n_layers, B, H]
        hn = _merge_bidirectional(hn)  # [L_enc, B, H]
        cn = _merge_bidirectional(cn)  # [L_enc, B, H]
        return h, (hn, cn)             # enc_out [B, S, H], states [L_enc, B, H]

class LuongAttention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.Wa = nn.Linear(hid_dim, hid_dim, bias=False)
    def forward(self, dec_h, enc_out, src_mask):
        # dec_h: [B, H], enc_out: [B, S, H], src_mask: [B, S] (bool)
        q = self.Wa(dec_h).unsqueeze(1)                      # [B, 1, H]
        scores = torch.bmm(q, enc_out.transpose(1, 2)).squeeze(1)  # [B, S]
        scores = scores.masked_fill(~src_mask, -1e9)
        attn = F.softmax(scores, dim=-1)                     # [B, S]
        ctx = torch.bmm(attn.unsqueeze(1), enc_out).squeeze(1)      # [B, H]
        return ctx, attn

class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, n_layers, dropout, pad_idx):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.emb_drop = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            emb_dim + hid_dim, hid_dim, num_layers=n_layers,
            dropout=dropout if n_layers > 1 else 0.0, batch_first=True
        )
        self.attn = LuongAttention(hid_dim)
        self.out_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(hid_dim, vocab_size)

    def forward(self, dec_inp, enc_out, src_mask, hidden):
        """
        dec_inp: [B, T] (BOS + tgt[:-1])
        enc_out: [B, S, H]
        src_mask: [B, S] (True for real tokens)
        hidden: (h0, c0) each [L_dec, B, H] after bridging
        returns logits [B, T, V]
        """
        B, T = dec_inp.shape
        e = self.emb_drop(self.emb(dec_inp))        # [B, T, E]
        outputs = []
        h, c = hidden
        for t in range(T):
            dec_h_top = h[-1]                       # [B, H]
            ctx, _ = self.attn(dec_h_top, enc_out, src_mask)  # [B, H]
            x_t = torch.cat([e[:, t, :], ctx], dim=-1).unsqueeze(1)  # [B, 1, E+H]
            o, (h, c) = self.lstm(x_t, (h, c))     # o: [B, 1, H]
            logits_t = self.proj(self.out_drop(o.squeeze(1)))        # [B, V]
            outputs.append(logits_t)
        logits = torch.stack(outputs, dim=1)        # [B, T, V]
        return logits, (h, c)

class Seq2Seq(nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec

    def _bridge_states(self, hn_enc, cn_enc):
        """
        Adapt encoder states [L_enc,B,H] to decoder depth [L_dec,B,H].
        - If decoder deeper: pad extra layers with zeros.
        - If decoder shallower: take the last L_dec layers.
        """
        L_enc = hn_enc.size(0)
        L_dec = self.dec.lstm.num_layers
        if L_enc == L_dec:
            return hn_enc, cn_enc
        if L_enc < L_dec:
            pad_h = hn_enc.new_zeros(L_dec - L_enc, hn_enc.size(1), hn_enc.size(2))
            pad_c = cn_enc.new_zeros(L_dec - L_enc, cn_enc.size(1), cn_enc.size(2))
            hn = torch.cat([hn_enc, pad_h], dim=0)
            cn = torch.cat([cn_enc, pad_c], dim=0)
            return hn, cn
        # L_enc > L_dec → keep the top-most layers
        return hn_enc[-L_dec:], cn_enc[-L_dec:]

    def forward(self, src, src_len, dec_inp):
        enc_out, (hn, cn) = self.enc(src, src_len)          # enc_out: [B, S, H], hn/cn: [L_enc,B,H]
        src_mask = (src != pad_id_src)                      # [B, S] bool
        hn, cn = self._bridge_states(hn, cn)                # now [L_dec, B, H]
        logits, _ = self.dec(dec_inp, enc_out, src_mask, (hn, cn))
        return logits

# Instantiate
encoder = Encoder(
    vocab_size=SRC_V, emb_dim=EMB_DIM_SRC, hid_dim=HID_DIM,
    n_layers=ENC_LAYERS, dropout=DROPOUT, pad_idx=pad_id_src
).to(device)

decoder = Decoder(
    vocab_size=TGT_V, emb_dim=EMB_DIM_TGT, hid_dim=HID_DIM,
    n_layers=DEC_LAYERS, dropout=DROPOUT, pad_idx=pad_id_tgt
).to(device)

model = Seq2Seq(encoder, decoder).to(device)
print(model)


Seq2Seq(
  (enc): Encoder(
    (emb): Embedding(54, 256, padding_idx=0)
    (emb_drop): Dropout(p=0.3, inplace=False)
    (lstm): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  )
  (dec): Decoder(
    (emb): Embedding(5000, 256, padding_idx=0)
    (emb_drop): Dropout(p=0.3, inplace=False)
    (lstm): LSTM(768, 512, num_layers=3, batch_first=True, dropout=0.3)
    (attn): LuongAttention(
      (Wa): Linear(in_features=512, out_features=512, bias=False)
    )
    (out_drop): Dropout(p=0.3, inplace=False)
    (proj): Linear(in_features=512, out_features=5000, bias=True)
  )
)


## 5) Training

In [6]:
"""## 3) Dataset & DataLoaders (Urdu→Roman seq2seq with BOS/EOS) — FIXED LENGTHS"""

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Re-load tokenizer from disk to be safe
from tokenizers import Tokenizer as _Tok
tokenizer = _Tok.from_file(str(WP_JSON))
pad_id_tgt = tokenizer.token_to_id("[PAD]")
eos_id_tgt = tokenizer.token_to_id("[EOS]")
unk_id_tgt = tokenizer.token_to_id("[UNK]")

# Use [CLS] as BOS (present in your WordPiece specials). Fallbacks stay as before.
bos_id_tgt = tokenizer.token_to_id("[CLS]") or tokenizer.token_to_id("[SEP]") or eos_id_tgt

MAX_SRC_LEN = 256
MAX_TGT_LEN = 256

def src_tokenize_chars(s: str):
    return list(str(s))

def encode_src_urdu(s: str):
    toks = src_tokenize_chars(s)[:MAX_SRC_LEN]
    ids = [SRC_token2id.get(t, unk_id_src) for t in toks]
    return torch.tensor(ids, dtype=torch.long)

def encode_tgt_roman(s: str):
    # WordPiece for target; add EOS; cap length so +EOS fits
    enc = tokenizer.encode(str(s))
    ids = enc.ids[:max(0, MAX_TGT_LEN - 1)] + [eos_id_tgt]
    return torch.tensor(ids, dtype=torch.long)

class PairDataset(Dataset):
    def __init__(self, csv_path):
        df = pd.read_csv(csv_path)
        assert SRC_COL in df.columns and TGT_COL in df.columns
        df = df[[SRC_COL, TGT_COL]].dropna()
        df = df[df[SRC_COL].astype(str).str.strip() != ""]
        df = df[df[TGT_COL].astype(str).str.strip() != ""]
        self.src = df[SRC_COL].astype(str).tolist()
        self.tgt = df[TGT_COL].astype(str).tolist()
    def __len__(self): return len(self.src)
    def __getitem__(self, i):
        return self.src[i], self.tgt[i]

def collate_batch(batch):
    src_seqs, tgt_seqs = [], []
    for s, t in batch:
        src_seqs.append(encode_src_urdu(s))
        tgt_seqs.append(encode_tgt_roman(t))           # each ends with EOS

    # ✅ FIX: decoder input length must equal target length
    # dec_inp = [BOS] + tgt[:-1]
    dec_inp = [torch.tensor([bos_id_tgt] + seq[:-1].tolist(), dtype=torch.long) for seq in tgt_seqs]

    src_pad = pad_sequence(src_seqs, batch_first=True, padding_value=pad_id_src)
    dec_pad = pad_sequence(dec_inp,  batch_first=True, padding_value=pad_id_tgt)
    tgt_pad = pad_sequence(tgt_seqs, batch_first=True, padding_value=pad_id_tgt)

    src_len = torch.tensor([len(s) for s in src_seqs], dtype=torch.long)
    tgt_len = torch.tensor([len(s) for s in tgt_seqs], dtype=torch.long)

    return {
        "src": src_pad.to(device),
        "src_len": src_len.to(device),
        "dec_inp": dec_pad.to(device),
        "tgt": tgt_pad.to(device),
        "tgt_len": tgt_len.to(device),
    }

BATCH_SIZE = 32
train_loader = DataLoader(PairDataset(TRAIN_CSV), batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_batch)
valid_loader = DataLoader(PairDataset(VALID_CSV), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

print("Batches — train:", len(train_loader), "| valid:", len(valid_loader))


Batches — train: 326 | valid: 163


In [7]:
# === CELL: Train Now — epochs, early stopping, checkpoint (no 'verbose' arg) ===
import math, time, os, torch, torch.nn as nn
from tqdm.auto import tqdm

# ---- Safety: paths / globals ----
try:
    BEST_CKPT
except NameError:
    BEST_CKPT = "bilstm_seq2seq_best.pt"

assert 'train_loader' in globals() and 'valid_loader' in globals(), "train_loader/valid_loader not defined. Run the DataLoaders cell first."
assert len(train_loader) > 0 and len(valid_loader) > 0, "Your loaders are empty. Check your CSV paths/splits/filters."
assert 'bos_id_tgt' in globals() and 'eos_id_tgt' in globals(), "BOS/EOS ids not found. Run the tokenizer/vocab + DataLoaders cells."

# ---- Loss / Optimizer / Schedule ----
criterion = nn.CrossEntropyLoss(ignore_index=pad_id_tgt)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)

# NOTE: your torch version doesn't support 'verbose' here
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.7, patience=2  # <- no verbose
)

# ---- Config ----
EPOCHS   = 100        # change if needed
CLIP     = 1.0
PATIENCE = 2         # early stopping patience on valid loss

def _align_logits_targets(logits, tgt):
    """Ensure time dims match to avoid shape errors."""
    if logits.size(1) != tgt.size(1):
        T = min(logits.size(1), tgt.size(1))
        logits = logits[:, :T, :]
        tgt    = tgt[:, :T]
    return logits, tgt

def _count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print(f"Model trainable params: {_count_params(model):,}")

def train_epoch():
    model.train()
    tot_loss, tot_tok = 0.0, 0
    pbar = tqdm(train_loader, leave=False, desc="Train")
    for batch in pbar:
        src     = batch["src"]
        src_len = batch["src_len"]
        dec_inp = batch["dec_inp"]   # BOS + tgt[:-1]
        tgt     = batch["tgt"]       # ends with EOS

        optimizer.zero_grad(set_to_none=True)
        logits = model(src, src_len, dec_inp)
        logits, tgt = _align_logits_targets(logits, tgt)

        B, T, V = logits.shape
        loss = criterion(logits.reshape(B*T, V), tgt.reshape(B*T))
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        optimizer.step()

        with torch.no_grad():
            n = (tgt != pad_id_tgt).sum().item()
        tot_loss += loss.item() * max(1, n)
        tot_tok  += max(1, n)
        pbar.set_postfix(loss=f"{tot_loss/max(1,tot_tok):.4f}")
    return tot_loss / max(1, tot_tok)

@torch.no_grad()
def valid_epoch():
    model.eval()
    tot_loss, tot_tok = 0.0, 0
    pbar = tqdm(valid_loader, leave=False, desc="Valid")
    for batch in pbar:
        src     = batch["src"]
        src_len = batch["src_len"]
        dec_inp = batch["dec_inp"]
        tgt     = batch["tgt"]

        logits = model(src, src_len, dec_inp)
        logits, tgt = _align_logits_targets(logits, tgt)

        B, T, V = logits.shape
        loss = criterion(logits.reshape(B*T, V), tgt.reshape(B*T))
        n = (tgt != pad_id_tgt).sum().item()
        tot_loss += loss.item() * max(1, n)
        tot_tok  += max(1, n)
        pbar.set_postfix(loss=f"{tot_loss/max(1,tot_tok):.4f}")
    return tot_loss / max(1, tot_tok)

# ---- Run epochs now ----
best_val = float("inf")
epochs_no_improve = 0

for epoch in range(1, EPOCHS + 1):
    print(f"\nEpoch {epoch}/{EPOCHS}")
    tr = train_epoch()
    va = valid_epoch()

    # report LR change (since scheduler has no 'verbose' in this torch version)
    prev_lr = optimizer.param_groups[0]['lr']
    scheduler.step(va)
    new_lr = optimizer.param_groups[0]['lr']
    lr_msg = f"{new_lr:.2e}" + ("  (↓ lr)" if new_lr < prev_lr else "")

    print(f"➡️  train NLL: {tr:.4f} | valid NLL: {va:.4f} | lr={lr_msg}")

    if va < best_val - 1e-6:
        best_val = va
        epochs_no_improve = 0
        torch.save(
            {"model": model.state_dict(),
             "config": {
                 "SRC_V": len(SRC_TOKENS), "TGT_V": tokenizer.get_vocab_size(),
                 "EMB_DIM_SRC": EMB_DIM_SRC, "EMB_DIM_TGT": EMB_DIM_TGT,
                 "HID_DIM": HID_DIM, "ENC_LAYERS": ENC_LAYERS, "DEC_LAYERS": DEC_LAYERS,
                 "DROPOUT": DROPOUT,
                 "pad_id_src": pad_id_src, "pad_id_tgt": pad_id_tgt,
                 "bos_id_tgt": bos_id_tgt, "eos_id_tgt": eos_id_tgt,
             }},
            BEST_CKPT
        )
        print(f"✅ Saved best checkpoint: {BEST_CKPT}")
    else:
        epochs_no_improve += 1
        print(f"⏸️  no improvement ({epochs_no_improve}/{PATIENCE})")

    if epochs_no_improve >= PATIENCE:
        print("⏹️ Early stopping triggered.")
        break

print("\nTraining finished. Best valid NLL:", f"{best_val:.4f}")


Model trainable params: 13,578,632

Epoch 1/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 6.2434 | valid NLL: 5.9281 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 2/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 5.7425 | valid NLL: 5.6302 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 3/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 5.5133 | valid NLL: 5.5200 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 4/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 5.3884 | valid NLL: 5.3923 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 5/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 5.2814 | valid NLL: 5.2757 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 6/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 5.1357 | valid NLL: 5.0920 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 7/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.9921 | valid NLL: 5.0186 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 8/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.8653 | valid NLL: 4.8753 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 9/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.7462 | valid NLL: 4.7407 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 10/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.6277 | valid NLL: 4.6092 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 11/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.5053 | valid NLL: 4.5133 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 12/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.4012 | valid NLL: 4.4377 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 13/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.2956 | valid NLL: 4.3258 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 14/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.1903 | valid NLL: 4.2351 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 15/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.0876 | valid NLL: 4.1533 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 16/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 4.0029 | valid NLL: 4.0447 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 17/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.9124 | valid NLL: 3.9889 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 18/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.8361 | valid NLL: 4.0390 | lr=1.00e-04
⏸️  no improvement (1/2)

Epoch 19/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.7566 | valid NLL: 3.8497 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 20/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.6763 | valid NLL: 3.7893 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 21/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.5914 | valid NLL: 3.6805 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 22/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.5085 | valid NLL: 3.6188 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 23/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.4285 | valid NLL: 3.5364 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 24/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.3642 | valid NLL: 3.4848 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 25/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.2771 | valid NLL: 3.4443 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 26/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.2079 | valid NLL: 3.4068 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 27/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.1384 | valid NLL: 3.2944 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 28/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 3.0739 | valid NLL: 3.2703 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 29/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.9951 | valid NLL: 3.1863 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 30/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.9409 | valid NLL: 3.1154 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 31/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.8685 | valid NLL: 3.0914 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 32/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.8024 | valid NLL: 3.0232 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 33/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.7421 | valid NLL: 2.9638 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 34/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.6821 | valid NLL: 2.9234 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 35/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.6163 | valid NLL: 2.8802 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 36/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.5600 | valid NLL: 2.8210 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 37/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.5025 | valid NLL: 2.7696 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 38/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.4384 | valid NLL: 2.7101 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 39/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.3912 | valid NLL: 2.6853 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 40/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.3312 | valid NLL: 2.7115 | lr=1.00e-04
⏸️  no improvement (1/2)

Epoch 41/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.2759 | valid NLL: 2.6177 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 42/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.2222 | valid NLL: 2.5471 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 43/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.1681 | valid NLL: 2.5000 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 44/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.1216 | valid NLL: 2.4747 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 45/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.0734 | valid NLL: 2.4388 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 46/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 2.0179 | valid NLL: 2.3987 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 47/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.9727 | valid NLL: 2.3706 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 48/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.9185 | valid NLL: 2.3454 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 49/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.8762 | valid NLL: 2.2790 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 50/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.8364 | valid NLL: 2.2482 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 51/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.7910 | valid NLL: 2.2350 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 52/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.7610 | valid NLL: 2.1899 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 53/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.7138 | valid NLL: 2.2504 | lr=1.00e-04
⏸️  no improvement (1/2)

Epoch 54/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.6647 | valid NLL: 2.1645 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 55/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.6211 | valid NLL: 2.1468 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 56/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.5941 | valid NLL: 2.0842 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 57/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.5550 | valid NLL: 2.0724 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 58/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.5121 | valid NLL: 2.0229 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 59/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.4931 | valid NLL: 2.0156 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 60/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.4466 | valid NLL: 1.9950 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 61/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.4066 | valid NLL: 1.9529 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 62/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.3757 | valid NLL: 1.9588 | lr=1.00e-04
⏸️  no improvement (1/2)

Epoch 63/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.3504 | valid NLL: 1.9318 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 64/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.3176 | valid NLL: 1.8878 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 65/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.2874 | valid NLL: 1.8690 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 66/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.2579 | valid NLL: 1.8704 | lr=1.00e-04
⏸️  no improvement (1/2)

Epoch 67/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.2316 | valid NLL: 1.8322 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 68/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.2027 | valid NLL: 1.8172 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 69/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.1712 | valid NLL: 1.7906 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 70/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.1419 | valid NLL: 1.7954 | lr=1.00e-04
⏸️  no improvement (1/2)

Epoch 71/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.1208 | valid NLL: 1.7621 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 72/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.0889 | valid NLL: 1.7575 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 73/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.0645 | valid NLL: 1.7549 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 74/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.0394 | valid NLL: 1.7379 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 75/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 1.0212 | valid NLL: 1.7364 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 76/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 0.9955 | valid NLL: 1.6962 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 77/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 0.9710 | valid NLL: 1.6780 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 78/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 0.9524 | valid NLL: 1.6671 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 79/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 0.9330 | valid NLL: 1.6808 | lr=1.00e-04
⏸️  no improvement (1/2)

Epoch 80/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 0.9012 | valid NLL: 1.6587 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 81/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 0.8908 | valid NLL: 1.6418 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 82/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 0.8602 | valid NLL: 1.6153 | lr=1.00e-04
✅ Saved best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt

Epoch 83/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 0.8412 | valid NLL: 1.6295 | lr=1.00e-04
⏸️  no improvement (1/2)

Epoch 84/100


Train:   0%|          | 0/326 [00:00<?, ?it/s]

Valid:   0%|          | 0/163 [00:00<?, ?it/s]

➡️  train NLL: 0.8209 | valid NLL: 1.6239 | lr=1.00e-04
⏸️  no improvement (2/2)
⏹️ Early stopping triggered.

Training finished. Best valid NLL: 1.6153


## 6) Validation — qualitative samples

In [8]:
"""## 6) Validation — Urdu→Roman greedy decode (with state-depth bridge)"""

import torch

@torch.no_grad()
def greedy_decode(src_text: str, max_len: int = 256):
    """
    Greedy decode that:
      1) runs the encoder,
      2) BRIDGES encoder states to decoder depth, and
      3) steps the decoder until [EOS] or max_len.
    Works for any (ENC_LAYERS, DEC_LAYERS).
    """
    # ---- encode source (Urdu) ----
    src_ids = encode_src_urdu(src_text).unsqueeze(0).to(device)  # [1, S]
    src_len = torch.tensor([src_ids.shape[1]], dtype=torch.long, device=device)

    # ---- run encoder ----
    enc_out, (hn, cn) = encoder(src_ids, src_len)               # hn/cn: [L_enc, 1, H]
    src_mask = (src_ids != pad_id_src)                          # [1, S]

    # ---- bridge encoder states to decoder depth ----
    L_enc = hn.size(0)
    L_dec = decoder.lstm.num_layers
    if L_enc < L_dec:
        pad_h = hn.new_zeros(L_dec - L_enc, hn.size(1), hn.size(2))
        pad_c = cn.new_zeros(L_dec - L_enc, cn.size(1), cn.size(2))
        h, c = torch.cat([hn, pad_h], dim=0), torch.cat([cn, pad_c], dim=0)
    elif L_enc > L_dec:
        h, c = hn[-L_dec:], cn[-L_dec:]
    else:
        h, c = hn, cn

    # ---- step decoder greedily ----
    dec_id = torch.tensor([[bos_id_tgt]], dtype=torch.long, device=device)  # [1, 1]
    out_ids = []

    for _ in range(max_len):
        # We feed a single step each time; decoder returns logits for this step
        logits, (h, c) = decoder(dec_id, enc_out, src_mask, (h, c))  # logits: [1, 1, V]
        next_id = int(logits[:, -1, :].argmax(dim=-1).item())
        if next_id == eos_id_tgt:
            break
        out_ids.append(next_id)
        dec_id = torch.tensor([[next_id]], dtype=torch.long, device=device)

    # ---- detokenize target ----
    return tokenizer.decode(out_ids, skip_special_tokens=True).strip()

# Show a few SRC / PRED / GOLD lines from VALID_CSV
import pandas as pd
df_v = pd.read_csv(VALID_CSV)
samples = list(zip(df_v[SRC_COL].astype(str).tolist(),
                   df_v[TGT_COL].astype(str).tolist()))[:10]

for i, (src_line, gold_line) in enumerate(samples, 1):
    pred = greedy_decode(src_line, max_len=MAX_TGT_LEN)
    print(f"[{i}] SRC : {src_line[:160]}")
    print(f"    PRED: {pred[:160]}")
    print(f"    GOLD: {gold_line[:160]}\n")


[1] SRC : یہ مسیحائی اسے بھول گئی ہے محسنؔ
    PRED: ye masiha . i use bhuul ga . i hai ' mohsin '
    GOLD: ye masiha.i use bhuul ga.i hai 'mohsin'

[2] SRC : شیر مردوں سے ہوا بیشۂ تحقیق تہی
    PRED: bazm - e - gardun se hua mahram - e - zanjir vahi
    GOLD: sher mardon se hua besha-e-tehqiq tahi

[3] SRC : بلائے جاں ہے ادا تیری اک جہاں کے لیے
    PRED: bala - e - jan hai ada teri ik jahan ke liye
    GOLD: bala-e-jan hai ada teri ik jahan ke liye

[4] SRC : وقت بدلا تو تری رائے بدل جائے گی
    PRED: vaqt - e - bala to tiri ada . e badal ja . egi
    GOLD: vaqt badla to tiri raa.e badal ja.egi

[5] SRC : صحرا میں اے خدا کوئی دیوار بھی نہیں
    PRED: sahra men ai khuda koi divar bhi nahin
    GOLD: sahra men ai khuda koi divar bhi nahin

[6] SRC : لو آج ہم نے توڑ دیا رشتۂ امید
    PRED: lo aaj ham ne tod diya khurshid - e - khayal
    GOLD: lo aaj ham ne tod diya rishta-e-umid

[7] SRC : قاصدا ہم فقیر لوگوں کا
    PRED: sahra ham - zaban - e - malamat ka
    GOLD: qasida ham faqir lo

## 7) Evaluation — Perplexity, BLEU-4 (char), CER, Exact Match


In [12]:
"""## 7) Evaluation — Perplexity, BLEU, CER (with fixed eval collate + postprocess)"""

import math, pandas as pd, torch, torch.nn as nn, re
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm

# ---- Assumes these exist from previous cells ----
# device, model (Seq2Seq), encoder, decoder
# tokenizer, pad_id_src, pad_id_tgt, bos_id_tgt, eos_id_tgt
# SRC_COL, TGT_COL, TRAIN_CSV, VALID_CSV, TEST_CSV
# encode_src_urdu(text) -> 1D LongTensor (Urdu chars to ids)
# encode_tgt_roman(text) -> 1D LongTensor (WordPiece ids + EOS at end)
# greedy_decode(src_text, max_len=256) -> str
assert 'device' in globals() and 'model' in globals(), "Make sure model/device are defined."

# ---------- Normalization / post-processing ----------
def roman_postprocess(s: str) -> str:
    s = (s or "").strip()
    # collapse spaced hyphens/dots like " - e - " or " . "
    s = re.sub(r"\s*-\s*", "-", s)
    s = re.sub(r"\s*\.\s*", ".", s)
    # collapse repeats
    s = re.sub(r"-{2,}", "-", s)
    s = re.sub(r"\.{2,}", ".", s)
    # tidy spaces
    s = re.sub(r"\s+", " ", s).strip()
    return s

APPLY_NORM_TO_GOLD = True  # set False to score raw references

# ---------- 7.1 Test DataLoader with CORRECT decoder inputs ----------
class PairDatasetEval(torch.utils.data.Dataset):
    def __init__(self, csv_path):
        df = pd.read_csv(csv_path)
        df = df[[SRC_COL, TGT_COL]].dropna()
        df = df[df[SRC_COL].astype(str).str.strip() != ""]
        df = df[df[TGT_COL].astype(str).str.strip() != ""]
        self.src = df[SRC_COL].astype(str).tolist()
        self.tgt = df[TGT_COL].astype(str).tolist()
    def __len__(self): return len(self.src)
    def __getitem__(self, i): return self.src[i], self.tgt[i]

def collate_batch_eval(batch):
    src_seqs, tgt_seqs = [], []
    for s, t in batch:
        src_seqs.append(encode_src_urdu(s))
        tgt_ids = encode_tgt_roman(t)   # must already end with EOS
        tgt_seqs.append(tgt_ids)

    # dec_inp must be same length as tgt: dec_inp = [BOS] + tgt[:-1]
    dec_inp = [torch.tensor([bos_id_tgt] + seq[:-1].tolist(), dtype=torch.long) for seq in tgt_seqs]

    src_pad = pad_sequence(src_seqs, batch_first=True, padding_value=pad_id_src)
    dec_pad = pad_sequence(dec_inp,  batch_first=True, padding_value=pad_id_tgt)
    tgt_pad = pad_sequence(tgt_seqs, batch_first=True, padding_value=pad_id_tgt)

    src_len = torch.tensor([len(s) for s in src_seqs], dtype=torch.long)

    return {
        "src": src_pad.to(device),
        "src_len": src_len.to(device),
        "dec_inp": dec_pad.to(device),
        "tgt": tgt_pad.to(device),
    }

BATCH_SIZE_EVAL = 64
test_loader = DataLoader(PairDatasetEval(TEST_CSV), batch_size=BATCH_SIZE_EVAL,
                         shuffle=False, collate_fn=collate_batch_eval)
print("Test batches:", len(test_loader))

# ---------- 7.2 Load best checkpoint if present ----------
try:
    _ckpt = torch.load(BEST_CKPT, map_location=device)
    model.load_state_dict(_ckpt["model"])
    print(f"Loaded checkpoint: {BEST_CKPT}")
except Exception as e:
    print("Note: could not load BEST_CKPT (using current weights):", e)

# ---------- 7.3 Perplexity (teacher forcing, PAD-masked CE) ----------
criterion_eval = nn.CrossEntropyLoss(ignore_index=pad_id_tgt)

@torch.no_grad()
def evaluate_perplexity():
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    for batch in tqdm(test_loader, leave=False, desc="Eval PPL"):
        src, src_len = batch["src"], batch["src_len"]
        dec_inp, tgt = batch["dec_inp"], batch["tgt"]
        logits = model(src, src_len, dec_inp)  # [B, T, V]
        # Ensure time dims match (safety)
        if logits.size(1) != tgt.size(1):
            T = min(logits.size(1), tgt.size(1))
            logits, tgt = logits[:, :T, :], tgt[:, :T]
        B, T, V = logits.shape
        loss = criterion_eval(logits.reshape(B*T, V), tgt.reshape(B*T))
        n = (tgt != pad_id_tgt).sum().item()
        total_loss += loss.item() * max(1, n)
        total_tokens += max(1, n)
    nll = total_loss / max(1, total_tokens)
    ppl = math.exp(nll)
    return nll, ppl

# ---------- 7.4 BLEU (simple corpus BLEU over whitespace tokens) ----------
def _brevity_penalty(ref_len, hyp_len):
    if hyp_len > ref_len: return 1.0
    if hyp_len == 0: return 0.0
    return math.exp(1 - ref_len / hyp_len)

def _ngram_counts(tokens, n):
    return {tuple(tokens[i:i+n]): 1 for i in range(len(tokens)-n+1)}

def _modified_precision(hyp, refs, n):
    hyp_ngrams = _ngram_counts(hyp, n)
    if not hyp_ngrams: return 0.0, 0
    max_ref_counts = {}
    for r in refs:
        for g in _ngram_counts(r, n):
            max_ref_counts[g] = max(max_ref_counts.get(g, 0), 1)
    overlap = 0
    total = 0
    for i in range(len(hyp)-n+1):
        g = tuple(hyp[i:i+n])
        total += 1
        overlap += 1 if max_ref_counts.get(g, 0) > 0 else 0
    if total == 0: return 0.0, 0
    return overlap / total, total

def corpus_bleu(hyps, refs, max_order=4, smooth_eps=1e-9):
    weights = [1.0 / max_order] * max_order
    log_p_sum = 0.0
    hyp_len_total = 0
    ref_len_total = 0
    for h, r in zip(hyps, refs):
        ht = h.strip().split()
        rt = r.strip().split()
        hyp_len_total += len(ht)
        ref_len_total += len(rt)
        precisions = []
        for n in range(1, max_order+1):
            p, _ = _modified_precision(ht, [rt], n)
            precisions.append(p if p > 0 else smooth_eps)
        log_p_sum += sum(w * math.log(p) for w, p in zip(weights, precisions))
    geo_mean = math.exp(log_p_sum / max(1, len(hyps)))
    bp = _brevity_penalty(ref_len_total, hyp_len_total)
    return bp * geo_mean

# ---------- 7.5 CER (character error rate) ----------
def edit_distance(a: str, b: str) -> int:
    m, n = len(a), len(b)
    if m < n: a, b, m, n = b, a, n, m
    prev = list(range(n+1))
    for i in range(1, m+1):
        curr = [i] + [0]*n
        ca = a[i-1]
        for j in range(1, n+1):
            cb = b[j-1]
            cost = 0 if ca == cb else 1
            curr[j] = min(prev[j] + 1, curr[j-1] + 1, prev[j-1] + cost)
        prev = curr
    return prev[n]

def char_error_rate(hyps, refs):
    total_ed = 0
    total_chars = 0
    for h, r in zip(hyps, refs):
        ed = edit_distance(h, r)
        total_ed += ed
        total_chars += max(1, len(r))
    return total_ed / max(1, total_chars), total_ed / max(1, len(hyps))

# ---------- 7.6 Run full evaluation ----------
@torch.no_grad()
def run_full_eval(show_samples=10, use_beam=False, max_len=256):
    # 1) Perplexity (teacher forcing)
    nll, ppl = evaluate_perplexity()

    # 2) Decode entire test set for BLEU/CER
    df_t = pd.read_csv(TEST_CSV)
    src_list = df_t[SRC_COL].astype(str).tolist()
    ref_list_raw = df_t[TGT_COL].astype(str).tolist()
    ref_list = [roman_postprocess(r) for r in ref_list_raw] if APPLY_NORM_TO_GOLD else ref_list_raw

    hyp_list = []
    dec_fn = (lambda s: beam_decode(s, max_len=max_len, beam_size=5)) if 'beam_decode' in globals() and use_beam else (lambda s: greedy_decode(s, max_len=max_len))
    for s in tqdm(src_list, leave=False, desc=("Beam decode (test)" if use_beam else "Greedy decode (test)")):
        hyp_list.append(roman_postprocess(dec_fn(s)))

    bleu = corpus_bleu(hyp_list, ref_list)
    cer, avg_edit = char_error_rate(hyp_list, ref_list)

    print("\n=== Test Metrics ===")
    print(f"Per-token NLL : {nll:.4f}")
    print(f"Perplexity    : {ppl:.3f}")
    print(f"Corpus BLEU   : {bleu*100:.2f}")
    print(f"CER           : {cer*100:.2f}%")
    print(f"Avg Edit Dist : {avg_edit:.3f} chars/sample")

    # 3) Qualitative samples
    print("\n=== Qualitative Samples ===")
    for i in range(min(show_samples, len(src_list))):
        print(f"[{i+1}] SRC : {src_list[i][:160]}")
        print(f"     PRED: {hyp_list[i][:160]}")
        print(f"     GOLD: {ref_list[i][:160]}\n")

    # 4) Worst-by-CER quick view
    rows = []
    for s, h, r in zip(src_list, hyp_list, ref_list):
        ed = edit_distance(h, r)
        cer_i = ed / max(1, len(r))
        rows.append((cer_i, ed, len(r), s, h, r))
    rows.sort(reverse=True, key=lambda x: x[0])
    worst_n = min(15, len(rows))
    print(f"=== Worst {worst_n} by CER ===")
    for k in range(worst_n):
        cer_i, ed, rlen, s, h, r = rows[k]
        print(f"#{k+1:02d} CER={cer_i*100:.1f}%  ED={ed}  | SRC: {s[:60]}")
        print(f"    PRED: {h[:120]}")
        print(f"    GOLD: {r[:120]}\n")

# Go!
run_full_eval(show_samples=10, use_beam=True, max_len=256)


Test batches: 82
Loaded checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt


Eval PPL:   0%|          | 0/82 [00:00<?, ?it/s]

Beam decode (test):   0%|          | 0/5214 [00:00<?, ?it/s]


=== Test Metrics ===
Per-token NLL : 1.5954
Perplexity    : 4.930
Corpus BLEU   : 0.85
CER           : 20.08%
Avg Edit Dist : 8.097 chars/sample

=== Qualitative Samples ===
[1] SRC : کرے گر رند درد دل سے ہاو ہوئے مستانا
     PRED: kare gar rang-e-dard dil se hua hue begana
     GOLD: kare gar rind dard-e-dil se hav-hu-e-mastana

[2] SRC : مگر نوشتۂ قسمت کسی کو کیا معلوم
     PRED: magar rishta-e-qismat kisi ko kya ma.alum
     GOLD: magar navishta-e-qismat kisi ko kya ma.alum

[3] SRC : دہر کے غم سے ہوا ربط تو ہم بھول گئے
     PRED: dahr ke gham se hua rabt to ham bhuul ga.e
     GOLD: dahr ke gham se hua rabt to ham bhuul ga.e

[4] SRC : زمیں کی کیسی وکالت ہو پھر نہیں چلتی
     PRED: zamin ki kaisi karvan ho phir nahin dhuan
     GOLD: zamin ki kaisi vakalat ho phir nahin chalti

[5] SRC : فیضؔ زندہ رہیں وہ ہیں تو سہی
     PRED: ' faiz ' ullah tumhin vo hain to sahi
     GOLD: 'faiz' zinda rahen vo hain to sahi

[6] SRC : عشق کو رہنما کیا تو نے
     PRED: ishq ko hangama kiya tu ne


## 8) Artifacts summary

In [10]:

print("Tokenizer JSON :", WP_JSON, WP_JSON.exists())
print("Tokenizer vocab:", WP_VOCAB_TXT, WP_VOCAB_TXT.exists())
print("Best checkpoint:", BEST_CKPT, BEST_CKPT.exists())


Tokenizer JSON : /content/drive/MyDrive/NLP/tokenizers/roman_wp-tokenizer.json True
Tokenizer vocab: /content/drive/MyDrive/NLP/tokenizers/roman_wp-vocab.txt True
Best checkpoint: /content/drive/MyDrive/NLP/experiments/bilstm_wp_best.pt True
