In [2]:
# ============================================================
# Colab Cell #1 — Setup, Mount, Paths, Installs, Quick Scan
# ============================================================
import sys, os, re, random, json, math, unicodedata, shutil, glob
from pathlib import Path
from datetime import datetime

# --------------------------
# 0) Basic environment info
# --------------------------
print("Python:", sys.version)
try:
    import torch
    print("PyTorch:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("CUDA device:", torch.cuda.get_device_name(0))
except Exception as e:
    print("Torch not yet available:", e)

# --------------------------
# 1) Mount Google Drive
# --------------------------
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    ROOT = Path("/content")
    GDRIVE = Path("/content/drive/MyDrive")
else:
    # Fallback for local dev (optional)
    ROOT = Path.cwd()
    GDRIVE = ROOT

PROJECT_DIR = ROOT / "nmt_urdu_roman"
for p in ["data","artifacts","models","logs","src","runs"]:
    (PROJECT_DIR / p).mkdir(parents=True, exist_ok=True)
print("Project root:", PROJECT_DIR)

# --------------------------
# 2) Pip installs
# --------------------------
# Keep installs minimal in first cell; add others later when needed
!pip -q install --upgrade pip
!pip -q install sacrebleu==2.4.2 jiwer==3.0.4 python-Levenshtein==0.25.1 sentencepiece==0.2.0 pyarrow==17.0.0 pandas==2.2.2 pandarallel==1.6.5

import pandas as pd
import sacrebleu, sentencepiece as spm
from jiwer import cer
import Levenshtein

# --------------------------
# 3) Reproducibility
# --------------------------
import numpy as np
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
try:
    import torch
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
except:
    pass
print("Seed fixed to", SEED)

# --------------------------
# 4) Dataset wiring
# --------------------------
# >>>> IMPORTANT: Set this to your Drive path where poet folders live <<<<
# Example: /content/drive/MyDrive/marwah/dataset
CANDIDATE_PATHS = [
    GDRIVE / "marwah" / "dataset",
    GDRIVE / "dataset" / "urdu_ghazals_rekhta",
    PROJECT_DIR / "data" / "urdu_ghazals_rekhta"
]

DATASET_DIR = None
for p in CANDIDATE_PATHS:
    if p.exists() and any((p / d).exists() for d in [
        "ahmad-faraz","akbar-allahabadi","allama-iqbal","ameer-khusrau",
        "mirza-ghalib","parveen-shakir","faiz-ahmad-faiz"
    ]):
        DATASET_DIR = p
        break

if DATASET_DIR is None:
    print("⚠️ Could not find your local poet folders in Drive.")
    print("Attempting to clone GitHub dataset into project data ...")
    !rm -rf /content/urdu_ghazals_rekhta
    !git clone -q https://github.com/amir9ume/urdu_ghazals_rekhta.git /content/urdu_ghazals_rekhta
    src = Path("/content/urdu_ghazals_rekhta")
    if src.exists():
        # try to locate poet folders
        poet_dirs = [d for d in src.iterdir() if d.is_dir()]
        if poet_dirs:
            target = PROJECT_DIR / "data" / "urdu_ghazals_rekhta"
            target.mkdir(parents=True, exist_ok=True)
            # Copy structure lightly to our data dir
            for d in poet_dirs:
                shutil.copytree(d, target / d.name, dirs_exist_ok=True)
            DATASET_DIR = target
        else:
            print("❌ Clone succeeded but poet directories not found. Please set DATASET_DIR manually.")
else:
    print("✅ Found dataset at:", DATASET_DIR)

# Persist a config file
cfg = {
    "created_at": datetime.utcnow().isoformat() + "Z",
    "seed": SEED,
    "dataset_dir": str(DATASET_DIR) if DATASET_DIR else None,
    "device": "cuda" if ("torch" in sys.modules and torch.cuda.is_available()) else "cpu",
    "project_dir": str(PROJECT_DIR)
}
with open(PROJECT_DIR / "project_config.json", "w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2, ensure_ascii=False)
print("Saved config:", PROJECT_DIR / "project_config.json")

# --------------------------
# 5) Quick scan of poets
# --------------------------
def list_poets(base: Path, limit=30):
    if base is None or not base.exists():
        print("Dataset path not set. Please update CANDIDATE_PATHS.")
        return []
    poets = sorted([d.name for d in base.iterdir() if d.is_dir()])
    print(f"Found {len(poets)} poet folders (showing up to {limit}):")
    for name in poets[:limit]:
        print(" -", name)
    return poets

poets = list_poets(DATASET_DIR)

# Probe a few files inside first 2 poets to understand file patterns
def sample_files(base: Path, poets_list, n_per_poet=3):
    samples = []
    for poet in poets_list[:2]:  # keep light
        folder = base / poet
        if not folder.exists():
            continue
        files = sorted([*folder.rglob("*")])
        text_like = [p for p in files if p.suffix.lower() in (".txt", ".csv", ".json", ".md")]
        chosen = text_like[:n_per_poet] if text_like else files[:n_per_poet]
        for f in chosen:
            size_kb = os.path.getsize(f) / 1024.0
            samples.append({"poet": poet, "path": str(f), "size_kb": round(size_kb, 2)})
    df = pd.DataFrame(samples)
    if not df.empty:
        from caas_jupyter_tools import display_dataframe_to_user
        display_dataframe_to_user("Dataset quick scan", df)
    else:
        print("No text-like files found in first two poets; structure may be nested differently.")
    return samples

_ = sample_files(DATASET_DIR, poets)

print("\n✅ Setup complete. If the scan table opened, review the paths & sizes.")
print("If your dataset lives elsewhere in Drive, update CANDIDATE_PATHS above and re-run this cell.")


Python: 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]
PyTorch: 2.8.0+cu126
CUDA available: True
CUDA device: Tesla T4
Mounted at /content/drive
Project root: /content/nmt_urdu_roman
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m60.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[33m  DEPRECATION: Building 'pandarallel' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'pandarallel'. Discussion can be found at https://github.com/pypa/pip/issues/6334[0m[33m
[0m  Building wheel for pandarallel (setup.py) ... [?25l[?25hdone
Seed fixed to 42
⚠️ Could not find your local poet folders in Drive.
Attempting to clone

  "created_at": datetime.utcnow().isoformat() + "Z",


ModuleNotFoundError: No module named 'caas_jupyter_tools'

In [8]:
!pip -q install --force-reinstall "sentencepiece==0.1.99"


  Preparing metadata (setup.py) ... [?25l[?25hdone
[33m  DEPRECATION: Building 'sentencepiece' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'sentencepiece'. Discussion can be found at https://github.com/pypa/pip/issues/6334[0m[33m
[0m  Building wheel for sentencepiece (setup.py) ... [?25l[?25hdone


In [3]:
# ============================================================
# Colab Cell #1B — Point to /MyDrive/dataset + safe quick scan
# ============================================================
import json, os
from pathlib import Path
from datetime import datetime

# If you changed your folder name, update this path:
DATASET_DIR = Path("/content/drive/MyDrive/dataset")

assert DATASET_DIR.exists(), f"Dataset path not found: {DATASET_DIR}"

# Basic sanity: look for a few poet folders
EXPECTED = {
    "ahmad-faraz","akbar-allahabadi","allama-iqbal","altaf-hussain-hali",
    "ameer-khusrau","bahadur-shah-zafar","dagh-dehlvi","fahmida-riaz",
    "faiz-ahmad-faiz","firaq-gorakhpuri","gulzar","habib-jalib",
    "jaan-nisar-akhtar","jaun-eliya","javed-akhtar","meer-taqi-meer",
    "mirza-ghalib","parveen-shakir"
}
present = {d.name for d in DATASET_DIR.iterdir() if d.is_dir()}
print(f"✅ DATASET_DIR set to: {DATASET_DIR}")
print(f"Found {len(present)} top-level folders.")

missing = sorted(list(EXPECTED - present))
if missing:
    print("Note: some expected poet folders not found (ok if your dump differs). Example missing:", missing[:8])

# Show a tiny table-like listing (no special display tools needed)
def sample_files(base: Path, n_per_poet=3, poets_limit=5):
    rows = []
    poets = sorted([d for d in base.iterdir() if d.is_dir()])[:poets_limit]
    for poet_dir in poets:
        files = sorted([*poet_dir.rglob("*")])
        text_like = [p for p in files if p.suffix.lower() in (".txt", ".csv", ".json", ".md")]
        chosen = text_like[:n_per_poet] if text_like else files[:n_per_poet]
        for f in chosen:
            try:
                size_kb = os.path.getsize(f) / 1024.0
            except Exception:
                size_kb = -1
            rows.append((poet_dir.name, str(f.relative_to(base)), round(size_kb, 2)))
    print("\nSample of discovered files:")
    print("poet\t\trelative_path\t\tsize_kb")
    for poet, rel, sz in rows[:30]:
        print(f"{poet}\t{rel}\t{sz}")

sample_files(DATASET_DIR)

# Persist the config with the corrected path
PROJECT_DIR = Path("/content/nmt_urdu_roman")
cfg_path = PROJECT_DIR / "project_config.json"
try:
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
except Exception:
    cfg = {}

cfg.update({
    "dataset_dir": str(DATASET_DIR),
    "updated_at": datetime.utcnow().isoformat() + "Z",
})
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2, ensure_ascii=False)

print("\n🔧 Config updated at:", cfg_path)


✅ DATASET_DIR set to: /content/drive/MyDrive/dataset
Found 30 top-level folders.

Sample of discovered files:
poet		relative_path		size_kb
ahmad-faraz	ahmad-faraz/.DS_Store	6.0
ahmad-faraz	ahmad-faraz/en	4.0
ahmad-faraz	ahmad-faraz/en/aankh-se-duur-na-ho-dil-se-utar-jaaegaa-ahmad-faraz-ghazals	0.46
akbar-allahabadi	akbar-allahabadi/en	4.0
akbar-allahabadi	akbar-allahabadi/en/aah-jo-dil-se-nikaalii-jaaegii-akbar-allahabadi-ghazals	0.38
akbar-allahabadi	akbar-allahabadi/en/aaj-aaraaish-e-gesuu-e-dotaa-hotii-hai-akbar-allahabadi-ghazals	1.22
allama-iqbal	allama-iqbal/.DS_Store	6.0
allama-iqbal	allama-iqbal/en	4.0
allama-iqbal	allama-iqbal/en/aalam-e-aab-o-khaak-o-baad-sirr-e-ayaan-hai-tuu-ki-main-allama-iqbal-ghazals	0.46
altaf-hussain-hali	altaf-hussain-hali/en	4.0
altaf-hussain-hali	altaf-hussain-hali/en/aage-badhe-na-qissa-e-ishq-e-butaan-se-ham-altaf-hussain-hali-ghazals-3	0.48
altaf-hussain-hali	altaf-hussain-hali/en/ab-vo-aglaa-saa-iltifaat-nahiin-altaf-hussain-hali-ghazals	0.49
ame

  "updated_at": datetime.utcnow().isoformat() + "Z",


In [4]:
# ============================================================
# Colab Cell #2 — Parse Rekhta dump → Urdu/Roman pairs (raw)
# ============================================================
from pathlib import Path
import os, io, re, json, unicodedata, itertools
import pandas as pd
from tqdm import tqdm

# Load config written earlier
import json
cfg_path = Path("/content/nmt_urdu_roman/project_config.json")
with open(cfg_path, "r", encoding="utf-8") as f:
    CFG = json.load(f)
PROJECT_DIR = Path(CFG["project_dir"])
DATASET_DIR = Path(CFG["dataset_dir"])
OUT_DIR = PROJECT_DIR / "data"
OUT_DIR.mkdir(parents=True, exist_ok=True)

print("Using DATASET_DIR =", DATASET_DIR)

# --------------------------
# 1) Normalization helpers
# --------------------------
ZW_CHARS = "".join([
    "\u200b", "\u200c", "\u200d", "\ufeff"  # ZW space, joiners, BOM
])

URDU_MAP = {
    # Common confusables → canonical
    "ي": "ی",   # Arabic yeh → Farsi yeh
    "ك": "ک",   # Arabic kaf → Farsi kaf
    "ہ": "ہ",   # already Urdu heh-doachashmee
    "ۀ": "ہ",   # heh with hamza → heh
    "ھ": "ہ",   # heh goal → heh (approx)
    "ۃ": "ہ",
    "أ": "ا", "إ": "ا", "آ": "آ", "ٱ": "ا",
    "ؤ": "و", "ئ": "ی",
    "ٔ": "", "ٰ": "", "ٌ": "", "ً": "", "ٍ": "", "ْ": "", "ّ": "",  # remove tashkeel
    "ـ": "",  # tatweel
}

PUNCT_KEEP = set(list("،؛؟!.,:;!?…—–-()[]{}\"'`’“”«»"))

def normalize_urdu(s: str) -> str:
    if not isinstance(s, str): return ""
    s = unicodedata.normalize("NFC", s)
    # Remove zero-width & BOM
    s = re.sub(f"[{re.escape(ZW_CHARS)}]", "", s)
    # Map confusables
    s = "".join(URDU_MAP.get(ch, ch) for ch in s)
    # Collapse spaces
    s = re.sub(r"\s+", " ", s).strip()
    return s

def normalize_roman(s: str) -> str:
    if not isinstance(s, str): return ""
    s = unicodedata.normalize("NFC", s)
    s = s.replace("’", "'").replace("“", '"').replace("”", '"')
    s = re.sub(f"[{re.escape(ZW_CHARS)}]", "", s)
    # unify long vowels a bit (gentle)
    s = re.sub(r"\baa\b", "aa", s, flags=re.I)
    s = re.sub(r"\bee\b", "ee", s, flags=re.I)
    s = re.sub(r"\boo\b", "oo", s, flags=re.I)
    # collapse spaces
    s = re.sub(r"\s+", " ", s).strip()
    return s

def looks_like_title(line: str) -> bool:
    # Heuristic: titles are often long slugs or very short with no Urdu letters
    urdu_chars = re.findall(r"[\u0600-\u06FF]", line)
    if not line: return True
    if len(urdu_chars) == 0 and len(line.split()) <= 3:
        return True
    return False

# --------------------------
# 2) File readers
# --------------------------
def read_text_file(p: Path) -> list[str]:
    """Read a text file with unknown extension; try utf-8 then latin-1 as fallback."""
    try_enc = ["utf-8", "utf-8-sig", "cp1256", "latin-1"]
    for enc in try_enc:
        try:
            with open(p, "r", encoding=enc, errors="ignore") as f:
                raw = f.read()
            break
        except Exception:
            continue
    else:
        return []

    lines = [ln.strip() for ln in raw.splitlines()]
    # remove empty & very short decoration lines
    lines = [ln for ln in lines if ln.strip() != ""]
    return lines

def collect_language_folder(poet_dir: Path, lang_code: str) -> dict[str, list[str]]:
    """
    Returns a dict: { ghazal_key -> list_of_lines } for given language subfolder.
    lang_code in {"ur","en","hi"} (en ~ English transliteration ~ Roman Urdu).
    Handles files WITH or WITHOUT extensions.
    """
    base = poet_dir / lang_code
    if not base.exists():
        return {}
    files = sorted([p for p in base.rglob("*") if p.is_file()])
    out = {}
    for f in files:
        key = f.name  # keep full filename (with/without ext) to maximize pairing
        # Also allow pairing by stem if we need later
        lines = read_text_file(f)
        # remove leading title-ish junk line if it looks like a header
        if lines and looks_like_title(lines[0]) and len(lines) > 1:
            lines = lines[1:]
        out[key] = lines
    return out

# --------------------------
# 3) Pairing Urdu ↔ Roman per ghazal file
# --------------------------
def align_lines(ur_lines: list[str], rom_lines: list[str]) -> list[tuple[str, str]]:
    """
    Align lines by index. If lengths differ, cut to min length.
    Apply normalization per line.
    Drop pairs that become empty after normalization.
    """
    n = min(len(ur_lines), len(rom_lines))
    pairs = []
    for i in range(n):
        ur = normalize_urdu(ur_lines[i])
        rom = normalize_roman(rom_lines[i])
        # filter trivial separators
        if ur and rom and not re.fullmatch(r"[-–—…\.\*]+", ur) and not re.fullmatch(r"[-–—…\.\*]+", rom):
            pairs.append((ur, rom))
    return pairs

def best_key_match(key: str, candidates: set[str]) -> str | None:
    """
    Try exact key, then stem match, then relaxed slug match.
    """
    if key in candidates:
        return key
    stem = Path(key).stem
    by_stem = {Path(c).stem: c for c in candidates}
    if stem in by_stem:
        return by_stem[stem]
    # relaxed: remove trailing numerals like -1/-2, etc.
    stem_relaxed = re.sub(r"[-_]{0,1}\d+$", "", stem)
    for s, full in by_stem.items():
        if re.sub(r"[-_]{0,1}\d+$", "", s) == stem_relaxed:
            return full
    return None

# --------------------------
# 4) Walk poets and build a dataframe of pairs
# --------------------------
rows = []
poets = sorted([d for d in DATASET_DIR.iterdir() if d.is_dir()])

for poet_dir in tqdm(poets, desc="Poets"):
    poet = poet_dir.name
    ur_map = collect_language_folder(poet_dir, "ur")
    en_map = collect_language_folder(poet_dir, "en")  # Roman Urdu (English transliteration)
    if not en_map:
        # Some dumps may use 'roman' instead of 'en'; try that too
        en_map = collect_language_folder(poet_dir, "roman")
    if not ur_map or not en_map:
        # Keep note but continue
        print(f"Note: skipping '{poet}' (ur={len(ur_map)}, en/roman={len(en_map)})")
        continue

    en_keys = set(en_map.keys())
    for key_ur, ur_lines in ur_map.items():
        key_en = best_key_match(key_ur, en_keys)
        if key_en is None:
            # attempt match by stem across all en files quickly
            continue
        rom_lines = en_map.get(key_en, [])
        pairs = align_lines(ur_lines, rom_lines)
        ghazal_id = Path(key_ur).stem
        for idx, (src_ur, tgt_rom) in enumerate(pairs):
            rows.append({
                "poet": poet,
                "ghazal_id": ghazal_id,
                "line_idx": idx,
                "src_ur": src_ur,
                "tgt_rom_gold": tgt_rom,  # gold transliteration when present
                "tgt_rom_rule": None,     # placeholder (we’ll fill later if needed)
            })

df = pd.DataFrame(rows)
print("\nCollected pairs:", len(df))
print("Unique poets parsed:", df["poet"].nunique() if not df.empty else 0)

# Basic sanity stats
if not df.empty:
    lens_src = df["src_ur"].str.len()
    lens_tgt = df["tgt_rom_gold"].str.len()
    print("Avg src len:", round(lens_src.mean(), 1), "| 95p:", int(lens_src.quantile(0.95)))
    print("Avg tgt len:", round(lens_tgt.mean(), 1), "| 95p:", int(lens_tgt.quantile(0.95)))

# Save raw (gold-present) pairs
RAW_PATH = OUT_DIR / "train_raw_gold.parquet"
if not df.empty:
    df.to_parquet(RAW_PATH, index=False)
    print("✅ Saved:", RAW_PATH)
else:
    print("⚠️ No aligned pairs found yet. We may need to handle alternative layouts.")

# Peek a few rows
print("\nSample rows:")
print(df.sample(min(10, len(df)), random_state=42) if not df.empty else "No data")


Using DATASET_DIR = /content/drive/MyDrive/dataset


Poets: 100%|██████████| 30/30 [27:45<00:00, 55.53s/it]



Collected pairs: 20983
Unique poets parsed: 30
Avg src len: 33.2 | 95p: 43
Avg tgt len: 40.3 | 95p: 52
✅ Saved: /content/nmt_urdu_roman/data/train_raw_gold.parquet

Sample rows:
                     poet                                          ghazal_id  \
10172        javed-akhtar  mujh-ko-yaqiin-hai-sach-kahtii-thiin-jo-bhii-a...   
14968        mirza-ghalib  maze-jahaan-ke-apnii-nazar-men-khaak-nahiin-mi...   
13406        mirza-ghalib  dost-gam-khvaarii-men-merii-saii-farmaavenge-k...   
3211   bahadur-shah-zafar  khvaah-kar-insaaf-zaalim-khvaah-kar-bedaad-tuu...   
6914     firaq-gorakhpuri  samajhtaa-huun-ki-tuu-mujh-se-judaa-hai-firaq-...   
18471          nida-fazli  safar-men-dhuup-to-hogii-jo-chal-sako-to-chalo...   
11114    jigar-moradabadi  kuchh-is-adaa-se-aaj-vo-pahluu-nashiin-rahe-ji...   
3725   bahadur-shah-zafar  zulf-jo-rukh-par-tire-ai-mehr-e-talat-khul-gai...   
19854  wali-mohammad-wali  aaj-distaa-hai-haal-kuchh-kaa-kuchh-wali-moham...   
10594    jigar-morada

In [5]:
# ============================================================
# Colab Cell #3 — Rule translit, unified corpus, 50/25/25 split,
#                  char vocabs saved to artifacts/
# ============================================================
from pathlib import Path
import json, re, math, random
import pandas as pd
from collections import Counter, defaultdict
from sklearn.model_selection import train_test_split

# Load config and paths
cfg = json.load(open("/content/nmt_urdu_roman/project_config.json", "r", encoding="utf-8"))
PROJECT_DIR = Path(cfg["project_dir"])
DATA_DIR = PROJECT_DIR / "data"
ARTI = PROJECT_DIR / "artifacts"
ARTI.mkdir(parents=True, exist_ok=True)

RAW_GOLD = DATA_DIR / "train_raw_gold.parquet"
assert RAW_GOLD.exists(), f"Missing: {RAW_GOLD}"
df = pd.read_parquet(RAW_GOLD)
print("Loaded gold pairs:", len(df))

# --------------------------
# 1) Rule-based Urdu→Roman
# --------------------------
# This is a light transliteration (not phonetic-perfect), good as a fallback.
UR2ROM = {
    "ا":"a", "آ":"aa", "أ":"a", "إ":"i", "ٱ":"a", "ء":"'", "ؤ":"o", "ئ":"i",
    "ب":"b", "پ":"p", "ت":"t", "ٹ":"t", "ث":"s", "ج":"j", "چ":"ch", "ح":"h",
    "خ":"kh", "د":"d", "ڈ":"d", "ذ":"z", "ر":"r", "ڑ":"r", "ز":"z", "ژ":"zh",
    "س":"s", "ش":"sh", "ص":"s", "ض":"z", "ط":"t", "ظ":"z", "ع":"'", "غ":"gh",
    "ف":"f", "ق":"q", "ک":"k", "گ":"g", "ل":"l", "م":"m", "ن":"n", "ں":"n",
    "و":"v",  # will handle vowels via context rules below
    "ہ":"h", "ھ":"h", "ء":"'", "ٔ":"", "ٰ":"", "ى":"a", "ی":"y", "ے":"e",
    "،":",", "؛":";", "؟":"?", "۔":".", "—":"-", "–":"-", "ْ":"", "ّ":""
}
# Digits and spaces
for d_ar, d_en in zip("٠١٢٣٤٥٦٧٨٩", "0123456789"):
    UR2ROM[d_ar] = d_en

# Simple vowel context rules
def urdu_to_roman_rule(text: str) -> str:
    if not isinstance(text, str): return ""
    out = []
    for i, ch in enumerate(text):
        rom = UR2ROM.get(ch)
        if rom is None:
            # Urdu whitespace / Latin punctuation passthrough
            if re.match(r"[\u0600-\u06FF]", ch):
                rom = ""  # unknown Urdu char → empty
            else:
                rom = ch
        out.append(rom)
    s = "".join(out)

    # Heuristics:
    #   - handle long vowels around y/w (ی/و became y/v)
    #   - basic combos
    s = re.sub(r"\bkh", "kh", s)  # keep
    s = re.sub(r"gh", "gh", s)

    # Normalize multiple apostrophes/spaces
    s = re.sub(r"'+", "'", s)
    s = re.sub(r"\s+", " ", s).strip()

    # Light tidy for v/w and y/i/u heuristics (very rough):
    s = re.sub(r"\bv", "w", s)  # often و is 'w' in roman urdu
    s = re.sub(r"([aeiou])y\b", r"\1i", s)
    return s

# Fill rule transliteration only where gold missing (future-proof)
needs_rule = df["tgt_rom_gold"].isna() | df["tgt_rom_gold"].eq("")
if needs_rule.any():
    df.loc[needs_rule, "tgt_rom_rule"] = df.loc[needs_rule, "src_ur"].apply(urdu_to_roman_rule)

# Choose final target: prefer gold, else rule
df["tgt_rom"] = df["tgt_rom_gold"].fillna("").replace("", pd.NA)
df["tgt_rom"] = df["tgt_rom"].fillna(df["tgt_rom_rule"].fillna(""))

# Drop empty/degenerate rows
before = len(df)
df = df[(df["src_ur"].str.len() >= 2) & (df["tgt_rom"].str.len() >= 2)]
df = df[~df["src_ur"].str.fullmatch(r"[-–—…\.\*]+")]
df = df[~df["tgt_rom"].str.fullmatch(r"[-–—…\.\*]+")]
after = len(df)
print(f"Filtered: {before} → {after}")

# --------------------------
# 2) Ghazal-level split (50/25/25)
#    to avoid same ghazal leaking across sets
# --------------------------
key = df["poet"].astype(str) + "§" + df["ghazal_id"].astype(str)
unique_gh = key.drop_duplicates().tolist()
random.Random(42).shuffle(unique_gh)

n = len(unique_gh)
n_train = int(0.50 * n)
n_val   = int(0.25 * n)
train_gh = set(unique_gh[:n_train])
val_gh   = set(unique_gh[n_train:n_train+n_val])
test_gh  = set(unique_gh[n_train+n_val:])

def assign_split(row):
    k = f"{row.poet}§{row.ghazal_id}"
    if k in train_gh: return "train"
    if k in val_gh:   return "val"
    return "test"

df["split"] = df.apply(assign_split, axis=1)

print(df["split"].value_counts())

# Save split files
for sp in ["train", "val", "test"]:
    outp = DATA_DIR / f"pairs_{sp}.parquet"
    df[df["split"] == sp][["poet","ghazal_id","line_idx","src_ur","tgt_rom"]].to_parquet(outp, index=False)
    print("✅ Saved:", outp)

# --------------------------
# 3) Build character-level vocabs
# --------------------------
SPECIAL = ["<pad>", "<s>", "</s>", "<unk>"]

def build_char_vocab(series: pd.Series, extra_chars=None):
    counter = Counter()
    for s in series.astype(str).tolist():
        for ch in s:
            counter[ch] += 1
    chars = sorted(counter.keys())
    if extra_chars:
        for ch in extra_chars:
            if ch not in chars: chars.append(ch)
    # index mapping
    itos = SPECIAL + chars
    stoi = {ch:i for i,ch in enumerate(itos)}
    meta = {"size": len(itos), "num_special": len(SPECIAL)}
    return {"itos": itos, "stoi": stoi, "meta": meta}

train_df = df[df["split"]=="train"]

src_vocab = build_char_vocab(train_df["src_ur"])
tgt_vocab = build_char_vocab(train_df["tgt_rom"])

with open(ARTI / "vocab_src_char.json", "w", encoding="utf-8") as f:
    json.dump(src_vocab, f, ensure_ascii=False, indent=2)
with open(ARTI / "vocab_tgt_char.json", "w", encoding="utf-8") as f:
    json.dump(tgt_vocab, f, ensure_ascii=False, indent=2)

print("\nVocab sizes — src:", src_vocab["meta"]["size"], "tgt:", tgt_vocab["meta"]["size"])
print("Special tokens:", SPECIAL)

# Quick peek
print("\nSamples:")
print(df[df["split"]=="train"][["src_ur","tgt_rom"]].head(5))


Loaded gold pairs: 20983
Filtered: 20983 → 20983
split
train    10493
test      5255
val       5235
Name: count, dtype: int64
✅ Saved: /content/nmt_urdu_roman/data/pairs_train.parquet
✅ Saved: /content/nmt_urdu_roman/data/pairs_val.parquet
✅ Saved: /content/nmt_urdu_roman/data/pairs_test.parquet

Vocab sizes — src: 50 tgt: 46
Special tokens: ['<pad>', '<s>', '</s>', '<unk>']

Samples:
                                    src_ur  \
24      اب اور کیا کسی سے مراسم بڑہاییں ہم   
25  یہ بہی بہت ہے تجہ کو اگر بہول جاییں ہم   
26      صحرایے زندگی میں کویی دوسرا نہ تہا   
27       سنتے رہے ہیں آپ ہی اپنی صداییں ہم   
28        اس زندگی میں اتنی فراغت کسے نصیب   

                                           tgt_rom  
24         ab aur kyā kisī se marāsim baḌhā.eñ ham  
25  ye bhī bahut hai tujh ko agar bhuul jaa.eñ ham  
26            sahrā-e-zindagī meñ koī dūsrā na thā  
27         sunte rahe haiñ aap hī apnī sadā.eñ ham  
28         is zindagī meñ itnī farāġhat kise nasīb  


In [10]:
# ============================================================
# Colab Cell #4 — Train SentencePiece (patched) + Dataset/Loaders
#   - Trains separate SPM models for src (Urdu) & tgt (Roman)
#   - Uses normalization_rule_name="identity" to avoid nmt_nfkc error
#   - Builds PyTorch Dataset (char/spm), DataLoaders, and saves exp config
# ============================================================
from pathlib import Path
import json, os, io, random
import pandas as pd
import numpy as np
import sentencepiece as spm
import torch
from torch.utils.data import Dataset, DataLoader

# ----- Paths & config
cfg = json.load(open("/content/nmt_urdu_roman/project_config.json","r",encoding="utf-8"))
PROJECT_DIR = Path(cfg["project_dir"])
DATA_DIR    = PROJECT_DIR / "data"
ARTI        = PROJECT_DIR / "artifacts"
ARTI.mkdir(parents=True, exist_ok=True)

TRAIN_PQ = DATA_DIR / "pairs_train.parquet"
VAL_PQ   = DATA_DIR / "pairs_val.parquet"
TEST_PQ  = DATA_DIR / "pairs_test.parquet"

df_train = pd.read_parquet(TRAIN_PQ)
df_val   = pd.read_parquet(VAL_PQ)
df_test  = pd.read_parquet(TEST_PQ)

# ============================================================
# (A) Train SentencePiece tokenizers (patched)
#     Separate models for src/tgt. Reserve: pad=0, bos=1, eos=2, unk=3
# ============================================================
SRC_TXT = ARTI / "spm_src_train.txt"
TGT_TXT = ARTI / "spm_tgt_train.txt"
SRC_MODEL = ARTI / "spm_src.model"
TGT_MODEL = ARTI / "spm_tgt.model"

if not SRC_TXT.exists():
    SRC_TXT.write_text("\n".join(df_train["src_ur"].astype(str).tolist()), encoding="utf-8")
if not TGT_TXT.exists():
    TGT_TXT.write_text("\n".join(df_train["tgt_rom"].astype(str).tolist()), encoding="utf-8")

SPM_SRC_VOCAB = 2000
SPM_TGT_VOCAB = 2000

def _clean_partial_models(model_prefix_str: str):
    # delete tiny/corrupt artifacts to allow clean retrain
    for ext in (".model", ".vocab"):
        f = Path(model_prefix_str + ext)
        if f.exists() and f.stat().st_size < 1024:
            f.unlink(missing_ok=True)

def train_spm(input_path, model_path, vocab_size, character_coverage=0.9995, model_type="bpe"):
    model_prefix = str(Path(model_path).with_suffix(""))
    _clean_partial_models(model_prefix)
    if Path(model_prefix + ".model").exists():
        return  # already trained

    # Key fix: normalization_rule_name="identity" (prevents 'nmt_nfkc' missing error)
    spm.SentencePieceTrainer.Train(
        input=str(input_path),
        model_prefix=model_prefix,
        vocab_size=int(vocab_size),
        model_type=model_type,
        character_coverage=float(character_coverage),
        pad_id=0, bos_id=1, eos_id=2, unk_id=3,
        input_sentence_size=1_000_000,
        shuffle_input_sentence=True,
        normalization_rule_name="identity",
    )

# Train patched SPM models
train_spm(SRC_TXT, SRC_MODEL, SPM_SRC_VOCAB, character_coverage=0.9995, model_type="bpe")
train_spm(TGT_TXT, TGT_MODEL, SPM_TGT_VOCAB, character_coverage=0.9995, model_type="bpe")

# Load processors
sp_src = spm.SentencePieceProcessor()
sp_tgt = spm.SentencePieceProcessor()
sp_src.load(str(SRC_MODEL))
sp_tgt.load(str(TGT_MODEL))

print("SPM src size:", sp_src.get_piece_size(), "| tgt size:", sp_tgt.get_piece_size())

# ============================================================
# (B) Tokenization utilities
#     - CHAR mode uses vocab jsons from Cell #3
#     - SPM mode uses the SentencePiece models above
# ============================================================
SPECIAL = ["<pad>", "<s>", "</s>", "<unk>"]
v_src_char = json.load(open(ARTI / "vocab_src_char.json", "r", encoding="utf-8"))
v_tgt_char = json.load(open(ARTI / "vocab_tgt_char.json", "r", encoding="utf-8"))
stoi_src_char = v_src_char["stoi"]; itos_src_char = v_src_char["itos"]
stoi_tgt_char = v_tgt_char["stoi"]; itos_tgt_char = v_tgt_char["itos"]

PAD_ID = 0; BOS_ID = 1; EOS_ID = 2; UNK_ID = 3

def encode_char_src(s: str):
    ids = [BOS_ID]
    for ch in s:
        ids.append(stoi_src_char.get(ch, UNK_ID))
    ids.append(EOS_ID)
    return ids

def encode_char_tgt(s: str):
    ids = [BOS_ID]
    for ch in s:
        ids.append(stoi_tgt_char.get(ch, UNK_ID))
    ids.append(EOS_ID)
    return ids

def encode_spm_src(s: str):
    return [BOS_ID] + sp_src.encode(s, out_type=int, enable_sampling=False) + [EOS_ID]

def encode_spm_tgt(s: str):
    return [BOS_ID] + sp_tgt.encode(s, out_type=int, enable_sampling=False) + [EOS_ID]

# ============================================================
# (C) PyTorch Dataset + Collate (supports 'char' or 'spm')
# ============================================================
class NMTDataset(Dataset):
    def __init__(self, df: pd.DataFrame, mode="char"):
        self.df = df.reset_index(drop=True)
        self.mode = mode
        assert mode in ("char", "spm")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        src = str(row["src_ur"])
        tgt = str(row["tgt_rom"])
        if self.mode == "char":
            src_ids = encode_char_src(src)
            tgt_ids = encode_char_tgt(tgt)
        else:
            src_ids = encode_spm_src(src)
            tgt_ids = encode_spm_tgt(tgt)
        return {
            "src_ids": torch.tensor(src_ids, dtype=torch.long),
            "tgt_ids": torch.tensor(tgt_ids, dtype=torch.long),
            "len_src": len(src_ids),
            "len_tgt": len(tgt_ids),
        }

def collate_fn(batch):
    # Sort by src length (desc) for minor efficiency
    batch = sorted(batch, key=lambda x: x["len_src"], reverse=True)
    src_lens = [b["len_src"] for b in batch]
    tgt_lens = [b["len_tgt"] for b in batch]
    max_src = max(src_lens)
    max_tgt = max(tgt_lens)

    def pad_seq(seq, max_len):
        out = torch.full((max_len,), PAD_ID, dtype=torch.long)
        out[:len(seq)] = seq
        return out

    src_pad = torch.stack([pad_seq(b["src_ids"], max_src) for b in batch], dim=0)
    tgt_pad = torch.stack([pad_seq(b["tgt_ids"], max_tgt) for b in batch], dim=0)

    return {
        "src_ids": src_pad,    # [B, Tsrc]
        "tgt_ids": tgt_pad,    # [B, Ttgt]
        "src_lens": torch.tensor(src_lens, dtype=torch.long),
        "tgt_lens": torch.tensor(tgt_lens, dtype=torch.long),
    }

# ============================================================
# (D) Build DataLoaders (both modes so you can A/B later)
# ============================================================
def make_loaders(mode="char", batch_size=64, num_workers=2, shuffle_train=True):
    ds_train = NMTDataset(df_train, mode=mode)
    ds_val   = NMTDataset(df_val,   mode=mode)
    ds_test  = NMTDataset(df_test,  mode=mode)

    # pin_memory only helps on CUDA; harmless on CPU but can warn
    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=shuffle_train,
                          num_workers=num_workers, collate_fn=collate_fn, pin_memory=torch.cuda.is_available())
    dl_val   = DataLoader(ds_val,   batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, collate_fn=collate_fn, pin_memory=torch.cuda.is_available())
    dl_test  = DataLoader(ds_test,  batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, collate_fn=collate_fn, pin_memory=torch.cuda.is_available())
    return ds_train, ds_val, ds_test, dl_train, dl_val, dl_test

# Quick smoke test (char mode)
_, _, _, dl_tr_char, dl_v_char, _ = make_loaders(mode="char", batch_size=32)
batch_char = next(iter(dl_tr_char))
print("CHAR batch shapes:",
      batch_char["src_ids"].shape, batch_char["tgt_ids"].shape)

# Quick smoke test (spm mode)
_, _, _, dl_tr_spm, dl_v_spm, _ = make_loaders(mode="spm", batch_size=32)
batch_spm = next(iter(dl_tr_spm))
print("SPM  batch shapes:",
      batch_spm["src_ids"].shape, batch_spm["tgt_ids"].shape)

# ============================================================
# (E) Save a default experiment config
# ============================================================
exp_cfg = {
    "tokenization": "char",   # or "spm"
    "embedding_dim": 256,
    "hidden_size": 256,
    "enc_layers": 2,
    "dec_layers": 4,
    "dropout": 0.3,
    "learning_rate": 5e-4,
    "batch_size": 64,
    "teacher_forcing_start": 1.0,
    "teacher_forcing_end": 0.5,
    "epochs": 20,
    "grad_clip": 1.0,
    "beam_size": 5
}
with open(ARTI / "exp_default.json","w",encoding="utf-8") as f:
    json.dump(exp_cfg, f, indent=2)
print("Saved default exp config:", ARTI / "exp_default.json")


SPM src size: 2000 | tgt size: 2000
CHAR batch shapes: torch.Size([32, 41]) torch.Size([32, 49])
SPM  batch shapes: torch.Size([32, 17]) torch.Size([32, 25])
Saved default exp config: /content/nmt_urdu_roman/artifacts/exp_default.json


In [11]:
# ============================================================
# Colab Cell #5 — BiLSTM Encoder + 4-Layer LSTM Decoder (Luong)
#   • Restart-safe: redefines everything needed for training
#   • Works with tokenization = "char" or "spm"
#   • Trains, evaluates on val (loss, ppl, BLEU, CER), saves best
# ============================================================
import os, json, math, time, numpy as np
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import sentencepiece as spm
import sacrebleu
from jiwer import cer as jiwer_cer

# ---------- Paths / config ----------
cfg = json.load(open("/content/nmt_urdu_roman/project_config.json","r",encoding="utf-8"))
PROJECT_DIR = Path(cfg["project_dir"])
ARTI        = PROJECT_DIR / "artifacts"
MODELS_DIR  = PROJECT_DIR / "models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Load experiment base (you can override when calling train_model)
exp = json.load(open(ARTI / "exp_default.json","r",encoding="utf-8"))

# Special tokens
PAD_ID = 0; BOS_ID = 1; EOS_ID = 2; UNK_ID = 3

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# ---------- Bring in loaders from Cell #4 (and validate) ----------
try:
    make_loaders
except NameError:
    raise RuntimeError("❌ make_loaders is undefined. Please re-run Cell #4 first.")

def get_loaders(mode, batch):
    return make_loaders(mode=mode, batch_size=batch, num_workers=2, shuffle_train=True)

# ---------- Tokenizers / decoders ----------
def _load_tokenizers(tokenization):
    if tokenization == "char":
        v_src = json.load(open(ARTI / "vocab_src_char.json","r",encoding="utf-8"))
        v_tgt = json.load(open(ARTI / "vocab_tgt_char.json","r",encoding="utf-8"))
        SRC_VSIZE = len(v_src["itos"]); TGT_VSIZE = len(v_tgt["itos"])
        itos_tgt  = v_tgt["itos"]
        def decode_char(ids):
            out=[]
            for i in ids:
                if i in (PAD_ID, BOS_ID, EOS_ID): continue
                out.append(itos_tgt[i] if 0 <= i < len(itos_tgt) else "")
            return "".join(out).strip()
        return {"mode":"char","SRC_V":SRC_VSIZE,"TGT_V":TGT_VSIZE,
                "decode":decode_char, "sp_src":None, "sp_tgt":None, "v_tgt":v_tgt}
    else:
        sp_src = spm.SentencePieceProcessor(); sp_src.load(str(ARTI / "spm_src.model"))
        sp_tgt = spm.SentencePieceProcessor(); sp_tgt.load(str(ARTI / "spm_tgt.model"))
        SRC_VSIZE = sp_src.get_piece_size(); TGT_VSIZE = sp_tgt.get_piece_size()
        def decode_spm(ids):
            ids = [i for i in ids if i not in (PAD_ID, BOS_ID, EOS_ID)]
            try: return sp_tgt.decode(ids).strip()
            except: return sp_tgt.decode_pieces([sp_tgt.id_to_piece(i) for i in ids]).strip()
        return {"mode":"spm","SRC_V":SRC_VSIZE,"TGT_V":TGT_VSIZE,
                "decode":decode_spm, "sp_src":sp_src, "sp_tgt":sp_tgt, "v_tgt":None}

# ---------- Model ----------
class Encoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_size, num_layers=2, dropout=0.3, bidirectional=True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_ID)
        self.bilstm = nn.LSTM(
            emb_dim, hid_size, num_layers=num_layers, dropout=dropout,
            bidirectional=bidirectional, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)
        self.bidirectional = bidirectional
        self.hid_size = hid_size
        self.num_layers = num_layers

    def forward(self, src_ids):
        emb = self.dropout(self.embedding(src_ids))          # [B, T, E]
        outputs, (hn, cn) = self.bilstm(emb)                 # outputs: [B, T, 2H] if bi
        return outputs, (hn, cn)

class LuongAttention(nn.Module):
    # general: s_t^T * W * h_i
    def __init__(self, dec_hid, enc_hid_bi):
        super().__init__()
        self.W = nn.Linear(dec_hid, enc_hid_bi, bias=False)
    def forward(self, dec_h_t, enc_outputs, mask=None):
        # dec_h_t: [B,H], enc_outputs: [B,Tsrc,Henc]
        score = torch.bmm(self.W(dec_h_t).unsqueeze(1), enc_outputs.transpose(1,2))  # [B,1,Tsrc]
        if mask is not None:
            score = score.masked_fill(mask.unsqueeze(1), -1e9)
        attn = torch.softmax(score, dim=-1)       # [B,1,Tsrc]
        ctx  = torch.bmm(attn, enc_outputs).squeeze(1)  # [B,Henc]
        return ctx, attn.squeeze(1)

class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_size, enc_hid_bi, num_layers=4, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_ID)
        self.input_size = emb_dim + enc_hid_bi
        self.lstm = nn.LSTM(self.input_size, hid_size, num_layers=num_layers,
                            dropout=dropout, batch_first=True)
        self.attn = LuongAttention(hid_size, enc_hid_bi)
        self.fc_out = nn.Linear(hid_size + enc_hid_bi, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers

    def forward_step(self, y_prev, h_c, enc_outputs, enc_pad_mask, ctx_prev):
        # y_prev: [B]
        emb = self.dropout(self.embedding(y_prev))           # [B,E]
        lstm_in = torch.cat([emb, ctx_prev], dim=-1).unsqueeze(1)  # [B,1,E+Henc]
        out, h_c = self.lstm(lstm_in, h_c)                   # out: [B,1,Hdec]
        h_t = out.squeeze(1)                                 # [B,Hdec]
        ctx_t, attn = self.attn(h_t, enc_outputs, enc_pad_mask)    # [B,Henc]
        logits = self.fc_out(torch.cat([h_t, ctx_t], dim=-1))      # [B,V]
        return logits, h_c, ctx_t, attn

    # explicit forward so Module is happy
    def forward(self, y_prev, h_c, enc_outputs, enc_pad_mask, ctx_prev):
        return self.forward_step(y_prev, h_c, enc_outputs, enc_pad_mask, ctx_prev)

class Bridge(nn.Module):
    def __init__(self, enc_hid, dec_hid, dec_layers, bidirectional=True):
        super().__init__()
        mul = 2 if bidirectional else 1
        self.h_proj = nn.Linear(enc_hid*mul, dec_hid)
        self.c_proj = nn.Linear(enc_hid*mul, dec_hid)
        self.dec_layers = dec_layers
    def forward(self, enc_hn, enc_cn):
        # enc_hn: [L*mul, B, H]
        top_h = torch.cat([enc_hn[-2], enc_hn[-1]], dim=-1)  # [B, 2H]
        top_c = torch.cat([enc_cn[-2], enc_cn[-1]], dim=-1)  # [B, 2H]
        h0 = torch.tanh(self.h_proj(top_h))                  # [B,Hd]
        c0 = torch.tanh(self.c_proj(top_c))                  # [B,Hd]
        h0 = h0.unsqueeze(0).repeat(self.dec_layers, 1, 1)   # [Ldec,B,Hd]
        c0 = c0.unsqueeze(0).repeat(self.dec_layers, 1, 1)
        return (h0, c0)

class Seq2Seq(nn.Module):
    def __init__(self, src_vsize, tgt_vsize, emb_dim, hid_size, enc_layers, dec_layers, dropout=0.3):
        super().__init__()
        self.encoder = Encoder(src_vsize, emb_dim, hid_size, enc_layers, dropout, bidirectional=True)
        enc_hid_bi = hid_size * 2
        self.decoder = Decoder(tgt_vsize, emb_dim, hid_size, enc_hid_bi, dec_layers, dropout)
        self.bridge  = Bridge(hid_size, hid_size, dec_layers, bidirectional=True)

    def forward(self, src_ids, tgt_ids, teacher_forcing=1.0):
        # src_ids: [B,Tsrc], tgt_ids: [B,Ttgt]
        B, Tt = tgt_ids.size()
        enc_outputs, (hn, cn) = self.encoder(src_ids)            # enc_outputs: [B,Tsrc,2H]
        enc_pad_mask = (src_ids == PAD_ID)

        dec_hc = self.bridge(hn, cn)
        ctx = torch.zeros(src_ids.size(0), enc_outputs.size(-1), device=src_ids.device)
        logits_all = []
        y_prev = tgt_ids[:,0]  # BOS

        for t in range(1, Tt):
            use_tf = (np.random.rand() < teacher_forcing)
            logits, dec_hc, ctx, _ = self.decoder(y_prev, dec_hc, enc_outputs, enc_pad_mask, ctx)
            logits_all.append(logits.unsqueeze(1))
            y_prev = tgt_ids[:,t] if use_tf else torch.argmax(logits, dim=-1)

        return torch.cat(logits_all, dim=1)  # [B, Tt-1, V]

# ---------- Training / Evaluation ----------
def sequence_nll(logits, tgt_ids):
    # logits: [B,T-1,V], tgt_ids: [B,T]
    B, Tm1, V = logits.shape
    gold = tgt_ids[:,1:1+Tm1]                      # [B,T-1]
    loss = F.cross_entropy(
        logits.reshape(B*Tm1, V),
        gold.reshape(B*Tm1),
        ignore_index=PAD_ID
    )
    return loss

def greedy_decode(model, batch, max_len=150):
    model.eval()
    with torch.no_grad():
        src = batch["src_ids"].to(DEVICE)
        enc_outputs, (hn, cn) = model.encoder(src)
        enc_mask = (src == PAD_ID)
        dec_hc = model.bridge(hn, cn)
        ctx = torch.zeros(src.size(0), enc_outputs.size(-1), device=src.device)
        y_prev = torch.full((src.size(0),), BOS_ID, dtype=torch.long, device=src.device)
        outs = []
        for _ in range(max_len):
            logits, dec_hc, ctx, _ = model.decoder(y_prev, dec_hc, enc_outputs, enc_mask, ctx)
            y_prev = torch.argmax(logits, dim=-1)
            outs.append(y_prev.unsqueeze(1))
        return torch.cat(outs, dim=1)  # [B, L]

def _ids_to_strs_tgt(batch_ids, decode_fn):
    # batch_ids: Tensor [B,L] or list[list[int]]
    arr = batch_ids.cpu().tolist() if isinstance(batch_ids, torch.Tensor) else batch_ids
    return [decode_fn(ids) for ids in arr]

def eval_on_loader(model, dl, decode_fn, mode):
    model.eval()
    total_loss, total_tokens = 0.0, 0
    preds_text, refs_text = [], []
    with torch.no_grad():
        for batch in dl:
            src = batch["src_ids"].to(DEVICE)
            tgt = batch["tgt_ids"].to(DEVICE)

            # teacher-forced loss
            logits = model(src, tgt, teacher_forcing=1.0)
            loss = sequence_nll(logits, tgt)

            gold = tgt[:, 1:1+logits.size(1)]
            ntok = (gold != PAD_ID).sum().item()
            total_loss += loss.item() * ntok
            total_tokens += ntok

            # greedy decode for metrics
            gen = greedy_decode(model, batch, max_len=tgt.size(1)-1)
            preds_text.extend(_ids_to_strs_tgt(gen, decode_fn))
            refs_text.extend(_ids_to_strs_tgt(tgt, decode_fn))

    avg_nll = total_loss / max(total_tokens, 1)
    ppl = math.exp(avg_nll)
    bleu = sacrebleu.corpus_bleu(preds_text, [refs_text]).score
    cer_scores = [jiwer_cer(r, p) for p, r in zip(preds_text, refs_text)]
    cer_mean = float(np.mean(cer_scores)) if cer_scores else 1.0
    return avg_nll, ppl, bleu, cer_mean, preds_text[:5], refs_text[:5]

def train_model(exp_overrides=None):
    # Merge overrides with default exp
    cfg_local = dict(exp)
    if exp_overrides: cfg_local.update(exp_overrides)

    # Tokenizers & decode function
    tok = _load_tokenizers(cfg_local["tokenization"])
    SRC_VSIZE, TGT_VSIZE = tok["SRC_V"], tok["TGT_V"]
    decode_fn = tok["decode"]

    # Data
    _, _, _, dl_train, dl_val, _ = get_loaders(cfg_local["tokenization"], batch=cfg_local["batch_size"])

    # Model + opt
    model = Seq2Seq(
        src_vsize=SRC_VSIZE, tgt_vsize=TGT_VSIZE,
        emb_dim=cfg_local["embedding_dim"], hid_size=cfg_local["hidden_size"],
        enc_layers=cfg_local["enc_layers"], dec_layers=cfg_local["dec_layers"],
        dropout=cfg_local["dropout"]
    ).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=cfg_local["learning_rate"])

    # Schedules
    EPOCHS   = cfg_local["epochs"]
    TF_START = cfg_local["teacher_forcing_start"]
    TF_END   = cfg_local["teacher_forcing_end"]
    def tf_ratio(epoch_idx):
        if EPOCHS <= 1: return TF_END
        return TF_START + (TF_END - TF_START) * (epoch_idx / (EPOCHS-1))

    best_bleu, best_path = -1.0, None

    for ep in range(1, EPOCHS+1):
        model.train()
        t0 = time.time()
        tf = tf_ratio(ep-1)
        total_loss, total_tok = 0.0, 0

        for batch in dl_train:
            src = batch["src_ids"].to(DEVICE)
            tgt = batch["tgt_ids"].to(DEVICE)
            logits = model(src, tgt, teacher_forcing=tf)
            loss = sequence_nll(logits, tgt)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            clip_grad_norm_(model.parameters(), cfg_local["grad_clip"])
            opt.step()

            gold = tgt[:,1:1+logits.size(1)]
            ntok = (gold != PAD_ID).sum().item()
            total_loss += loss.item() * ntok
            total_tok += ntok

        train_nll = total_loss / max(total_tok,1)
        train_ppl = math.exp(train_nll)

        # Validate
        val_nll, val_ppl, val_bleu, val_cer, samp_pred, samp_ref = eval_on_loader(model, dl_val, decode_fn, tok["mode"])

        dt = time.time()-t0
        print(f"[Ep {ep:02d}] tf={tf:.2f} | train ppl={train_ppl:.2f} | "
              f"val ppl={val_ppl:.2f} | BLEU={val_bleu:.2f} | CER={val_cer:.3f} | {dt:.1f}s")

        # Save best by BLEU
        if val_bleu > best_bleu:
            best_bleu = val_bleu
            label = f"{tok['mode']}_E{cfg_local['embedding_dim']}_H{cfg_local['hidden_size']}_enc{cfg_local['enc_layers']}_dec{cfg_local['dec_layers']}_drop{cfg_local['dropout']}"
            best_path = MODELS_DIR / f"bilstm4lstm_{label}_best.pt"
            torch.save({
                "model_state": model.state_dict(),
                "exp": cfg_local,
                "bleu": float(best_bleu),
                "tokenization": tok["mode"]
            }, best_path)
            print("  ↳ Saved best:", best_path)

        # A couple of samples
        for i in range(min(2, len(samp_pred))):
            print(f"  pred: {samp_pred[i]}")
            print(f"  ref : {samp_ref[i]}")

    print("\nBest BLEU:", best_bleu, " | ckpt:", best_path)
    return best_path

print("✅ Cell #5 ready. Starting a quick run …")

# --- Kick a quick baseline (optional: reduce epochs for smoke test) ---
# exp["epochs"] = 3  # uncomment for a super-fast check
best_ckpt = train_model()   # uses exp_default.json settings (char, E=256,H=256,…)
print("Done.")


Device: cuda
✅ Cell #5 ready. Starting a quick run …
[Ep 01] tf=1.00 | train ppl=14.27 | val ppl=4.59 | BLEU=0.56 | CER=0.665 | 49.7s
  ↳ Saved best: /content/nmt_urdu_roman/models/bilstm4lstm_char_E256_H256_enc2_dec4_drop0.3_best.pt
  pred: taaaah raaā ke ke mo to mai mhhī bahā hoñ hoñ ho hai
  ref : taaza rifāqat ke mausam tak maiñ bhī jiyā huuñ vo bhī jiyā hai
  pred: ham na sss ke shhrr ko hhhhā āo no no nank lanā haiā haiā ha
  ref : ham ne us ke shahr ko chhoḌā aur āñkhoñ ko muuñd liyā hai
[Ep 02] tf=0.97 | train ppl=3.49 | val ppl=2.02 | BLEU=10.89 | CER=0.426 | 49.0s
  ↳ Saved best: /content/nmt_urdu_roman/models/bilstm4lstm_char_E256_H256_enc2_dec4_drop0.3_best.pt
  pred: taāza rafāqt ke muusm tak meñ bhī jhī jayñ joñ voā joā joā j
  ref : taaza rifāqat ke mausam tak maiñ bhī jiyā huuñ vo bhī jiyā hai
  pred: ham ne as ke shhhr ko chhvā āo aankhoñ ko maundd layā laidd la
  ref : ham ne us ke shahr ko chhoḌā aur āñkhoñ ko muuñd liyā hai
[Ep 03] tf=0.95 | train ppl=2.08 | val pp

In [14]:
# ==========================================
# Cell 5.1 — EOS-safe greedy decoding
#   • Stops per sequence on EOS
#   • Returns sequences trimmed to EOS
#   • Optional simple no-repeat-3-gram guard
# ==========================================
import torch

PAD_ID = 0; BOS_ID = 1; EOS_ID = 2

def _has_repeat_ngram(ids, n=3):
    if len(ids) < 2*n:
        return False
    last = tuple(ids[-n:])
    for i in range(len(ids)-n*2+1):
        if tuple(ids[i:i+n]) == last:
            return True
    return False

def greedy_decode(model, batch, max_len=200, no_repeat_ngram=0):
    """
    Returns a LongTensor [B, L] where each row is terminated by EOS (if produced)
    and padded with EOS to equal length L.
    """
    model.eval()
    with torch.no_grad():
        src = batch["src_ids"].to(next(model.parameters()).device)
        enc_outputs, (hn, cn) = model.encoder(src)
        enc_mask = (src == PAD_ID)
        dec_hc = model.bridge(hn, cn)
        ctx = torch.zeros(src.size(0), enc_outputs.size(-1), device=src.device)
        y_prev = torch.full((src.size(0),), BOS_ID, dtype=torch.long, device=src.device)

        B = src.size(0)
        finished = torch.zeros(B, dtype=torch.bool, device=src.device)
        outputs = [[] for _ in range(B)]

        for _ in range(max_len):
            logits, dec_hc, ctx, _ = model.decoder(y_prev, dec_hc, enc_outputs, enc_mask, ctx)

            # simple no-repeat-ngram guard (char/subword)
            if no_repeat_ngram and no_repeat_ngram > 1:
                for b in range(B):
                    if not finished[b] and len(outputs[b]) >= no_repeat_ngram-1:
                        # block last (n-1) history repeating as a new n-gram
                        # (cheap heuristic: downweight the top-1 if it causes repeat)
                        top1 = torch.argmax(logits[b])
                        if _has_repeat_ngram(outputs[b] + [int(top1)], n=no_repeat_ngram):
                            # pick 2nd best instead
                            top2 = torch.topk(logits[b], k=2).indices[1]
                            logits[b, top1] = -1e9
                            logits[b, top2] += 1e-3  # tiny nudge

            y_prev = torch.argmax(logits, dim=-1)

            for b in range(B):
                if finished[b]:
                    continue
                token = int(y_prev[b])
                outputs[b].append(token)
                if token == EOS_ID:
                    finished[b] = True

            if finished.all():
                break

        # pad with EOS to rectangular tensor
        max_out = max(len(x) for x in outputs) if outputs else 1
        out_tensor = torch.full((B, max_out), EOS_ID, dtype=torch.long, device=src.device)
        for b, seq in enumerate(outputs):
            if len(seq):
                out_tensor[b, :len(seq)] = torch.tensor(seq, dtype=torch.long, device=src.device)
        return out_tensor


In [15]:
# ============================================================
# Colab Cell #6 — Load BEST ckpt → Test metrics + translate()
#   (standalone; assumes Cells 3–5 are already run)
# ============================================================
import os, json, math, random
from pathlib import Path
import numpy as np
import pandas as pd
import torch, sentencepiece as spm
import sacrebleu
from jiwer import cer as jiwer_cer

# ---- paths / device
cfg = json.load(open("/content/nmt_urdu_roman/project_config.json","r",encoding="utf-8"))
PROJECT_DIR = Path(cfg["project_dir"])
ARTI        = PROJECT_DIR / "artifacts"
MODELS_DIR  = PROJECT_DIR / "models"
DATA_DIR    = PROJECT_DIR / "data"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def find_best_ckpt():
    cands = sorted(MODELS_DIR.glob("bilstm4lstm_*_best.pt"), key=os.path.getmtime, reverse=True)
    assert cands, "No *_best.pt found. Train first."
    return cands[0]

best_path = find_best_ckpt()
ckpt = torch.load(best_path, map_location=DEVICE)
exp_used = ckpt["exp"]
print("Loaded:", best_path.name, "| tokenization:", ckpt.get("tokenization"), "| BLEU(val):", ckpt.get("bleu"))

# ---- Seq2Seq & greedy_decode must exist from Cell #5
assert "Seq2Seq" in globals() and "greedy_decode" in globals(), "Please re-run Cell #5."

# ---- vocab sizes + decode
PAD_ID = 0; BOS_ID = 1; EOS_ID = 2; UNK_ID = 3

if exp_used["tokenization"] == "char":
    v_src = json.load(open(ARTI / "vocab_src_char.json","r",encoding="utf-8"))
    v_tgt = json.load(open(ARTI / "vocab_tgt_char.json","r",encoding="utf-8"))
    SRC_V = len(v_src["itos"]); TGT_V = len(v_tgt["itos"])
    itos_tgt = v_tgt["itos"]
    def decode_fn(ids):
        if isinstance(ids, torch.Tensor): ids = ids.tolist()
        return "".join(itos_tgt[i] for i in ids if i not in (PAD_ID,BOS_ID,EOS_ID)).strip()
else:
    sp_src = spm.SentencePieceProcessor(); sp_src.load(str(ARTI / "spm_src.model"))
    sp_tgt = spm.SentencePieceProcessor(); sp_tgt.load(str(ARTI / "spm_tgt.model"))
    SRC_V = sp_src.get_piece_size(); TGT_V = sp_tgt.get_piece_size()
    def decode_fn(ids):
        if isinstance(ids, torch.Tensor): ids = ids.tolist()
        ids = [i for i in ids if i not in (PAD_ID,BOS_ID,EOS_ID)]
        try: return sp_tgt.decode(ids).strip()
        except: return sp_tgt.decode_pieces([sp_tgt.id_to_piece(i) for i in ids]).strip()

# ---- rebuild model AFTER we know SRC_V/TGT_V
model = Seq2Seq(SRC_V, TGT_V,
                exp_used["embedding_dim"], exp_used["hidden_size"],
                exp_used["enc_layers"],   exp_used["dec_layers"],
                exp_used["dropout"]).to(DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

# ---- test loader via make_loaders from Cell #4
assert "make_loaders" in globals(), "Please re-run Cell #4."
_, _, _, _, _, dl_test = make_loaders(mode=exp_used["tokenization"], batch_size=exp_used["batch_size"])

def eval_on_loader(model, dl):
    preds_text, refs_text = [], []
    total_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for batch in dl:
            src = batch["src_ids"].to(DEVICE)
            tgt = batch["tgt_ids"].to(DEVICE)

            # loss (teacher-forced)
            logits = model(src, tgt, teacher_forcing=1.0)
            B, Tm1, V = logits.shape
            gold = tgt[:,1:1+Tm1]
            loss = torch.nn.functional.cross_entropy(
                logits.reshape(B*Tm1, V),
                gold.reshape(B*Tm1),
                ignore_index=PAD_ID
            )
            ntok = (gold != PAD_ID).sum().item()
            total_loss += loss.item() * ntok
            total_tokens += ntok

            # greedy decode
            gen = greedy_decode(model, batch, max_len=tgt.size(1)-1)
            preds_text.extend([decode_fn(row) for row in gen])
            refs_text.extend([decode_fn(row) for row in tgt])

    avg_nll = total_loss / max(total_tokens,1)
    ppl = math.exp(avg_nll)
    bleu = sacrebleu.corpus_bleu(preds_text, [refs_text]).score
    cer_scores = [jiwer_cer(r, p) for p, r in zip(preds_text, refs_text)]
    cer_mean = float(np.mean(cer_scores)) if cer_scores else 1.0
    return ppl, bleu, cer_mean

test_ppl, test_bleu, test_cer = eval_on_loader(model, dl_test)
print(f"\n=== TEST ===\nPPL: {test_ppl:.3f} | BLEU: {test_bleu:.2f} | CER: {test_cer:.3f}")

# ---- simple translator
def translate(texts, max_len=200):
    batch = {"src_ids": [], "tgt_ids": []}
    if exp_used["tokenization"] == "char":
        stoi = json.load(open(ARTI / "vocab_src_char.json","r",encoding="utf-8"))["stoi"]
        for t in texts:
            ids = [BOS_ID] + [stoi.get(ch, UNK_ID) for ch in t] + [EOS_ID]
            batch["src_ids"].append(torch.tensor(ids, dtype=torch.long))
            batch["tgt_ids"].append(torch.tensor([BOS_ID, EOS_ID], dtype=torch.long))
    else:
        for t in texts:
            ids = [BOS_ID] + sp_src.encode(t, out_type=int) + [EOS_ID]
            batch["src_ids"].append(torch.tensor(ids, dtype=torch.long))
            batch["tgt_ids"].append(torch.tensor([BOS_ID, EOS_ID], dtype=torch.long))

    def pad(lst):
        m = max(len(x) for x in lst)
        out = torch.full((len(lst), m), PAD_ID, dtype=torch.long)
        for i,x in enumerate(lst): out[i,:len(x)] = x
        return out

    batch = {k: pad(v).to(DEVICE) for k,v in batch.items()}
    gen = greedy_decode(model, batch, max_len=max_len)
    return [decode_fn(row) for row in gen]

# demo on a few test lines
test_df = pd.read_parquet(DATA_DIR / "pairs_test.parquet")
samps = test_df.sample(5, random_state=7)[["src_ur","tgt_rom"]].to_dict("records")
preds = translate([s["src_ur"] for s in samps], max_len=120)  # shorter max_len


print("\n=== Qualitative (5 random from TEST) ===")
for i, s in enumerate(samps):
    print(f"{i+1:02d}. UR : {s['src_ur']}")
    print(f"    GT : {s['tgt_rom']}")
    print(f"    PR : {preds[i]}")


Loaded: bilstm4lstm_char_E256_H256_enc2_dec4_drop0.3_best.pt | tokenization: char | BLEU(val): 63.03771018975481

=== TEST ===
PPL: 1.137 | BLEU: 75.15 | CER: 0.049

=== Qualitative (5 random from TEST) ===
01. UR : ہو رہے گا کچہ نہ کچہ گہبراییں کیا
    GT : ho rahegā kuchh na kuchh ghabrā.eñ kyā
    PR : ho rahegā kuchh na kuchh ghabrā.eñ kyā
02. UR : لگ گیی چپ حالیؔ رنجور کو
    GT : lag ga.ī chup 'hālī'-e-ranjūr ko
    PR : lag ga.ī chup 'hālī' ranjūr ko
03. UR : مری طرح بہی کویی میرا غم گسار نہ ہو
    GT : mirī tarah bhī koī merā ġham-gusār na ho
    PR : mirī tarah bhī koī merā ġham-gusār na ho
04. UR : سورج دماغ لوگ بہی ابلاغ فکر میں
    GT : sūraj-dimāġh log bhī ablāġh-e-fikr meñ
    PR : sūraj-e-dimāġh log bhī iblāġh-e-fikr meñ
05. UR : مویے شیشہ دیدۂ ساغر کی مژگانی کرے
    GT : mū-e-shīsha dīda-e-sāġhar kī mizhgānī kare
    PR : mū-e-shīsha dīda-e-sāġhar kī mizhgānī kare


In [22]:
# ============================================================
# Mini experiment sweep (robust to eval_on_loader signature)
#   - Runs: char_H512, char_LR1e-3, spm_baseline (if SPM available)
#   - Saves CSV + BLEU bar plot
# ============================================================
import os, json, time, math, numpy as np, pandas as pd, matplotlib.pyplot as plt
from pathlib import Path
import torch, inspect

# --- Expect Cell #4 and #5 already run
assert 'train_model' in globals(), "train_model not found. Run Cell #5."
assert 'Seq2Seq' in globals(), "Seq2Seq not found. Run Cell #5."
assert 'DEVICE' in globals(), "DEVICE not found. Run Cell #5."
assert 'ARTI' in globals(), "ARTI not found. Run Cell #5."
assert 'get_loaders' in globals(), "get_loaders not found. Run Cell #5."

# ---- Shim eval_on_loader to support BOTH signatures
assert 'eval_on_loader' in globals(), "eval_on_loader not found. Run Cell #5."
_orig_eval_on_loader = eval_on_loader  # keep original

def eval_on_loader(*args, **kwargs):
    """
    Wrapper that supports:
      - old: eval_on_loader(model, dl)
      - new: eval_on_loader(model, dl, decode_fn, mode)
    """
    sig = inspect.signature(_orig_eval_on_loader)
    if len(sig.parameters) == 2:
        # old signature — just hand it model, dl
        model, dl = args[:2]
        return _orig_eval_on_loader(model, dl)
    # new signature — pass through
    return _orig_eval_on_loader(*args, **kwargs)

# ---- Paths / base config
cfg = json.load(open("/content/nmt_urdu_roman/project_config.json","r",encoding="utf-8"))
PROJECT_DIR = Path(cfg["project_dir"])
RUNS_DIR = PROJECT_DIR / "runs"; RUNS_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR = PROJECT_DIR / "models"

base = json.load(open(ARTI / "exp_default.json","r",encoding="utf-8"))
BASE_EPOCHS = 6
BASE_BATCH  = base.get("batch_size", 64)

# SPM availability
SPM_OK = (ARTI/'spm_src.model').exists() and (ARTI/'spm_tgt.model').exists()
if not SPM_OK:
    print("⚠️ SPM model files not found. The SPM config will be skipped.")

# ---- Sweep definitions
sweep = [
    {"name": "char_H512",   "overrides": {"tokenization": "char", "hidden_size": 512, "epochs": BASE_EPOCHS, "batch_size": BASE_BATCH}},
    {"name": "char_LR1e-3", "overrides": {"tokenization": "char", "learning_rate": 1e-3, "epochs": BASE_EPOCHS, "batch_size": BASE_BATCH}},
]
if SPM_OK:
    sweep.append({"name": "spm_baseline", "overrides": {"tokenization": "spm", "epochs": BASE_EPOCHS, "batch_size": BASE_BATCH}})

# ---- Helper to get decode function only when needed
def _decode_fn_for(tokenization):
    if tokenization == "char":
        v_tgt = json.load(open(ARTI / "vocab_tgt_char.json","r",encoding="utf-8"))
        itos_tgt = v_tgt["itos"]
        PAD_ID, BOS_ID, EOS_ID = 0,1,2
        def decode_char(ids):
            return "".join(itos_tgt[i] for i in ids if i not in (PAD_ID,BOS_ID,EOS_ID) and 0<=i<len(itos_tgt)).strip()
        return decode_char
    else:
        import sentencepiece as spm
        PAD_ID, BOS_ID, EOS_ID = 0,1,2
        sp_tgt = spm.SentencePieceProcessor(); sp_tgt.load(str(ARTI / "spm_tgt.model"))
        def decode_spm(ids):
            ids = [i for i in ids if i not in (PAD_ID,BOS_ID,EOS_ID)]
            try: return sp_tgt.decode(ids).strip()
            except: return sp_tgt.decode_pieces([sp_tgt.id_to_piece(i) for i in ids]).strip()
        return decode_spm

results = []

for job in sweep:
    name = job["name"]
    overrides = dict(base); overrides.update(job["overrides"])
    print(f"\n=== Running: {name} ===")
    t0 = time.time()
    try:
        ckpt_path = train_model(exp_overrides=overrides)
    except Exception as e:
        print(f"❌ Training failed for {name}: {e}")
        continue
    dt = time.time() - t0

    if not ckpt_path or not Path(ckpt_path).exists():
        print(f"❌ No checkpoint produced for {name}")
        continue

    meta = torch.load(ckpt_path, map_location="cpu")
    best_bleu = float(meta.get("bleu", float('nan')))
    exp_used  = meta.get("exp", overrides)
    tokenization = exp_used["tokenization"]

    # Rebuild model to recompute val ppl/CER
    if tokenization == "char":
        v_src = json.load(open(ARTI / "vocab_src_char.json","r",encoding="utf-8"))
        v_tgt = json.load(open(ARTI / "vocab_tgt_char.json","r",encoding="utf-8"))
        src_v = len(v_src["itos"]); tgt_v = len(v_tgt["itos"])
    else:
        import sentencepiece as spm
        sp_src = spm.SentencePieceProcessor(); sp_src.load(str(ARTI / "spm_src.model"))
        sp_tgt = spm.SentencePieceProcessor(); sp_tgt.load(str(ARTI / "spm_tgt.model"))
        src_v = sp_src.get_piece_size(); tgt_v = sp_tgt.get_piece_size()

    model = Seq2Seq(
        src_vsize=src_v, tgt_vsize=tgt_v,
        emb_dim=exp_used["embedding_dim"], hid_size=exp_used["hidden_size"],
        enc_layers=exp_used["enc_layers"], dec_layers=exp_used["dec_layers"],
        dropout=exp_used["dropout"]
    ).to(DEVICE)
    state = torch.load(ckpt_path, map_location=DEVICE)["model_state"]
    model.load_state_dict(state)

    # loaders + eval on VAL (signature-agnostic)
    _, _, _, dl_train_tmp, dl_val_tmp, _ = get_loaders(tokenization, batch=exp_used["batch_size"])
    decode_fn = _decode_fn_for(tokenization)
    try:
        val_nll, val_ppl, val_bleu2, val_cer, _, _ = eval_on_loader(model, dl_val_tmp, decode_fn, tokenization)
    except TypeError:
        val_nll, val_ppl, val_bleu2, val_cer, _, _ = eval_on_loader(model, dl_val_tmp)

    results.append({
        "name": name,
        "tokenization": tokenization,
        "embedding_dim": exp_used["embedding_dim"],
        "hidden_size": exp_used["hidden_size"],
        "learning_rate": exp_used["learning_rate"],
        "epochs": exp_used["epochs"],
        "batch_size": exp_used["batch_size"],
        "best_bleu_val_ckpt": float(best_bleu),
        "val_ppl_reval": float(val_ppl),
        "val_cer_reval": float(val_cer),
        "train_time_s": round(dt, 1),
        "ckpt_path": str(ckpt_path),
    })

# ---- Save/show table safely
if results:
    df = pd.DataFrame(results).sort_values("best_bleu_val_ckpt", ascending=False)
    out_csv = RUNS_DIR / "exp_results_sweep.csv"
    df.to_csv(out_csv, index=False)
    print("\n✅ Saved sweep results:", out_csv)
    display(df)

    # BLEU bar chart
    plt.figure(figsize=(8,4))
    x = np.arange(len(df))
    plt.bar(x, df["best_bleu_val_ckpt"])
    plt.xticks(x, df["name"], rotation=15, ha='right')
    plt.ylabel("Best Val BLEU (ckpt meta)")
    plt.title("Mini Sweep — Best Validation BLEU")
    plt.tight_layout()
    plt.show()
else:
    print("\n⚠️ No successful runs to summarize (results list is empty).")



=== Running: char_H512 ===
❌ Training failed for char_H512: not enough values to unpack (expected 6, got 3)

=== Running: char_LR1e-3 ===
❌ Training failed for char_LR1e-3: not enough values to unpack (expected 6, got 3)

=== Running: spm_baseline ===
❌ Training failed for spm_baseline: list index out of range

⚠️ No successful runs to summarize (results list is empty).


In [16]:
# Save everything you need to /MyDrive
from pathlib import Path
import shutil, json, torch

ROOT = Path("/content")
PROJ = ROOT / "nmt_urdu_roman"
GDR  = Path("/content/drive/MyDrive/nmt_urdu_roman_export")
GDR.mkdir(parents=True, exist_ok=True)

for p in ["models","artifacts","exp_results.csv"]:
    src = (PROJ / p) if p.endswith(".csv") else (PROJ / "runs" / p if (PROJ/"runs"/p).exists() else PROJ / p)
    if src.exists():
        dst = GDR / p
        if src.is_dir():
            shutil.copytree(src, dst, dirs_exist_ok=True)
        else:
            shutil.copy2(src, dst)

# small manifest
ckpt = "/content/nmt_urdu_roman/models/bilstm4lstm_char_E256_H256_enc2_dec4_drop0.3_best.pt"
meta = torch.load(ckpt, map_location="cpu")
(Path(GDR/"MANIFEST.json")).write_text(json.dumps({
    "best_ckpt": ckpt,
    "tokenization": meta.get("tokenization"),
    "bleu_val": float(meta.get("bleu", -1)),
}, indent=2), encoding="utf-8")

print("✅ Exported to:", GDR)


✅ Exported to: /content/drive/MyDrive/nmt_urdu_roman_export


In [17]:
# ============================================================
# Colab Cell #7 — Reusable inference (greedy or beam), batch API
# ============================================================
import json, torch
from pathlib import Path
import sentencepiece as spm

# Paths
cfg = json.load(open("/content/nmt_urdu_roman/project_config.json","r",encoding="utf-8"))
PROJECT_DIR = Path(cfg["project_dir"])
ARTI        = PROJECT_DIR / "artifacts"
MODELS_DIR  = PROJECT_DIR / "models"

# Load best checkpoint (adjust if you saved multiple)
ckpts = sorted(MODELS_DIR.glob("bilstm4lstm_*_best.pt"))
assert ckpts, "No best checkpoints found in models/."
CKPT_PATH = ckpts[-1]
ckpt = torch.load(CKPT_PATH, map_location="cpu")
exp_used = ckpt["exp"]
tokenization = ckpt.get("tokenization", exp_used.get("tokenization", "char"))
print(f"Using ckpt: {CKPT_PATH.name} | tokenization={tokenization}")

# Bring model defs from Cell #5
assert "Seq2Seq" in globals(), "Please re-run Cell #5 first (model classes)."

# Vocab / tokenizers
PAD_ID=0; BOS_ID=1; EOS_ID=2; UNK_ID=3
if tokenization == "char":
    v_src = json.load(open(ARTI / "vocab_src_char.json","r",encoding="utf-8"))
    v_tgt = json.load(open(ARTI / "vocab_tgt_char.json","r",encoding="utf-8"))
    stoi_src, itos_tgt = v_src["stoi"], v_tgt["itos"]
    SRC_V, TGT_V = len(v_src["itos"]), len(v_tgt["itos"])
    def enc_src(s):
        return [BOS_ID] + [stoi_src.get(ch, UNK_ID) for ch in s] + [EOS_ID]
    def dec_tgt(ids):
        return "".join(itos_tgt[i] for i in ids if i not in (PAD_ID,BOS_ID,EOS_ID)).strip()
else:
    sp_src = spm.SentencePieceProcessor(); sp_src.load(str(ARTI/"spm_src.model"))
    sp_tgt = spm.SentencePieceProcessor(); sp_tgt.load(str(ARTI/"spm_tgt.model"))
    SRC_V, TGT_V = sp_src.get_piece_size(), sp_tgt.get_piece_size()
    def enc_src(s):
        return [BOS_ID] + sp_src.encode(s, out_type=int) + [EOS_ID]
    def dec_tgt(ids):
        ids = [i for i in ids if i not in (PAD_ID,BOS_ID,EOS_ID)]
        try: return sp_tgt.decode(ids).strip()
        except: return sp_tgt.decode_pieces([sp_tgt.id_to_piece(i) for i in ids]).strip()

# Rebuild model & load weights
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Seq2Seq(SRC_V, TGT_V,
                exp_used["embedding_dim"], exp_used["hidden_size"],
                exp_used["enc_layers"], exp_used["dec_layers"],
                exp_used["dropout"]).to(DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

# --- Greedy (EOS-safe) from Cell 5.1 (assume it's already redefined) ---
assert "greedy_decode" in globals(), "Please run the EOS-safe greedy from earlier."

# --- Lightweight Beam Search (beam=5) ---
import heapq
def beam_search(model, src_ids, beam_size=5, max_len=200, length_penalty=0.7):
    """
    src_ids: 1D LongTensor on DEVICE including BOS/EOS around source.
    returns: list[int] best hypothesis token ids (no BOS, ends at EOS if emitted).
    """
    model.eval()
    with torch.no_grad():
        src = src_ids.unsqueeze(0)  # [1, Tsrc]
        enc_outputs, (hn, cn) = model.encoder(src)
        enc_mask = (src == PAD_ID)
        dec_hc = model.bridge(hn, cn)
        ctx = torch.zeros(1, enc_outputs.size(-1), device=src.device)
        # beams: list of tuples (-score, seq, dec_hc, ctx, last_token)
        start = (0.0, [], dec_hc, ctx, torch.tensor([BOS_ID], device=src.device))
        beams = [start]
        finished = []

        for _ in range(max_len):
            new_beams = []
            for score, seq, dec_hc, ctx, y_prev in beams:
                if len(seq) and seq[-1] == EOS_ID:
                    finished.append((score, seq))
                    continue
                logits, dec_hc_n, ctx_n, _ = model.decoder(y_prev, dec_hc, enc_outputs, enc_mask, ctx)
                logp = torch.log_softmax(logits, dim=-1).squeeze(0)  # [V]
                topk = torch.topk(logp, k=beam_size)
                for k in range(beam_size):
                    tok = int(topk.indices[k])
                    sc = score - float(topk.values[k])  # negative log-prob
                    new_beams.append((sc, seq+[tok], dec_hc_n, ctx_n, torch.tensor([tok], device=src.device)))
            # keep top-K
            beams = heapq.nsmallest(beam_size, new_beams, key=lambda x: x[0])
            # early stop if we already have K finished with EOS
            if len(finished) >= beam_size and all(seq and seq[-1]==EOS_ID for _,seq in finished[:beam_size]):
                break

        cand = finished if finished else beams
        # length penalty: favor slightly longer but reasonable sequences
        cand = [(sc / (len(seq) ** length_penalty if len(seq)>0 else 1.0), seq) for sc, seq in cand]
        best = min(cand, key=lambda x: x[0])[1]
        return best

# --- Public functions ---
def translate_one(text, decoder="greedy", max_len=None, beam_size=5):
    ids = enc_src(text)
    if max_len is None:
        max_len = max(30, len(ids) + 10)  # safe cap
    src_tensor = torch.tensor(ids, dtype=torch.long, device=DEVICE)

    if decoder == "beam":
        out = beam_search(model, src_tensor, beam_size=beam_size, max_len=max_len)
        return dec_tgt(out)
    else:
        batch = {
            "src_ids": src_tensor.unsqueeze(0),
            "tgt_ids": torch.tensor([[BOS_ID, EOS_ID]], dtype=torch.long, device=DEVICE)
        }
        gen = greedy_decode(model, batch, max_len=max_len)
        return dec_tgt(gen[0].tolist())

def translate_batch(texts, decoder="greedy", max_len=None, beam_size=5):
    outs = []
    for t in texts:
        outs.append(translate_one(t, decoder=decoder, max_len=max_len, beam_size=beam_size))
    return outs

print("✅ Inference ready. Example:")
print("UR  :", "مزے جہان کے اپنی نظر میں خاک نہیں")
print("ROM :", translate_one("مزے جہان کے اپنی نظر میں خاک نہیں", decoder="beam"))


Using ckpt: bilstm4lstm_char_E256_H256_enc2_dec4_drop0.3_best.pt | tokenization=char
✅ Inference ready. Example:
UR  : مزے جہان کے اپنی نظر میں خاک نہیں
ROM : maze jahān ke apnī nazar meñ ḳhaak nahīñ


In [19]:
# ============================================================
# Cell #8 (fixed) — Batch translate & export CSV
#   * Robust beam search tuples (score, seq, h_c, ctx)
#   * Works with tokenization="char" or "spm"
#   * Saves /content/nmt_urdu_roman/runs/preds_test.csv
# ============================================================
import os, json, math
from pathlib import Path
import torch
import torch.nn.functional as F
import pandas as pd

# --------- Common constants / device ---------
PAD_ID = 0; BOS_ID = 1; EOS_ID = 2; UNK_ID = 3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- Paths and config ---------
cfg = json.load(open("/content/nmt_urdu_roman/project_config.json","r",encoding="utf-8"))
PROJECT_DIR = Path(cfg["project_dir"])
ARTI        = PROJECT_DIR / "artifacts"
DATA_DIR    = PROJECT_DIR / "data"
RUNS_DIR    = PROJECT_DIR / "runs"
RUNS_DIR.mkdir(parents=True, exist_ok=True)

# --------- Helper: load checkpoint / model / tokenizers ---------
def load_model_and_tok(ckpt_path=None):
    # Try to reuse globals if they exist
    global best_model, exp_used
    if ckpt_path is None:
        # Pick the best *.pt in models dir if not given
        models = sorted((PROJECT_DIR / "models").glob("*.pt"), key=os.path.getmtime, reverse=True)
        assert models, "No checkpoints found in /models. Train first."
        ckpt_path = models[0]

    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    exp_local = ckpt["exp"]
    tokenization = ckpt.get("tokenization", exp_local.get("tokenization", "char"))

    # Build model skeleton (reuse class from Cell #5)
    assert "Seq2Seq" in globals(), "Seq2Seq not found. Re-run Cell #5."
    # Get vocab sizes
    if tokenization == "char":
        v_src = json.load(open(ARTI / "vocab_src_char.json","r",encoding="utf-8"))
        v_tgt = json.load(open(ARTI / "vocab_tgt_char.json","r",encoding="utf-8"))
        src_v = len(v_src["itos"]); tgt_v = len(v_tgt["itos"])
    else:
        import sentencepiece as spm
        sp_src = spm.SentencePieceProcessor(); sp_src.load(str(ARTI / "spm_src.model"))
        sp_tgt = spm.SentencePieceProcessor(); sp_tgt.load(str(ARTI / "spm_tgt.model"))
        src_v = sp_src.get_piece_size(); tgt_v = sp_tgt.get_piece_size()

    model = Seq2Seq(src_v, tgt_v,
                    exp_local["embedding_dim"], exp_local["hidden_size"],
                    exp_local["enc_layers"], exp_local["dec_layers"],
                    exp_local["dropout"]).to(DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    # Tokenizers & decode from Cell #5 utility
    assert "_load_tokenizers" in globals(), "_load_tokenizers not found. Re-run Cell #5."
    tok = _load_tokenizers(tokenization)

    # For char-mode encoding we need src stoi
    if tokenization == "char":
        v_src = json.load(open(ARTI / "vocab_src_char.json","r",encoding="utf-8"))
        stoi_src = v_src["stoi"]
    else:
        stoi_src = None  # we’ll use tok["sp_src"]

    return model, exp_local, tok, stoi_src, ckpt_path

best_model, exp_used, tok, stoi_src_char, used_ckpt = load_model_and_tok()
print(f"Using ckpt: {Path(used_ckpt).name} | tokenization={exp_used['tokenization']}")

# --------- Encoding / decoding helpers ---------
def encode_src_text(s: str, tokenization: str):
    if tokenization == "char":
        ids = [BOS_ID] + [stoi_src_char.get(ch, UNK_ID) for ch in s] + [EOS_ID]
    else:
        ids = [BOS_ID] + tok["sp_src"].encode(s, out_type=int) + [EOS_ID]
    t = torch.tensor(ids, dtype=torch.long, device=DEVICE).unsqueeze(0)  # [1,T]
    return t

def decode_ids(ids: torch.Tensor):
    # ids: [L] on cpu or gpu
    arr = ids.tolist() if isinstance(ids, torch.Tensor) else ids
    return tok["decode"](arr)

# --------- Greedy (quick) ---------
def greedy_one(model, src_ids, max_len=200):
    model.eval()
    with torch.no_grad():
        enc_out, (hn, cn) = model.encoder(src_ids)
        enc_mask = (src_ids == PAD_ID)
        dec_hc = model.bridge(hn, cn)
        ctx = torch.zeros(src_ids.size(0), enc_out.size(-1), device=src_ids.device)
        y_prev = torch.full((src_ids.size(0),), BOS_ID, dtype=torch.long, device=src_ids.device)
        outs = []
        for _ in range(max_len):
            logits, dec_hc, ctx, _ = model.decoder(y_prev, dec_hc, enc_out, enc_mask, ctx)
            y_prev = torch.argmax(logits, dim=-1)
            outs.append(y_prev)
            if (y_prev == EOS_ID).all(): break
        return torch.stack(outs, dim=1).squeeze(0)  # [L]

# --------- Beam search (fixed tuples) ---------
def beam_one(model, src_ids, beam_size=5, max_len=200, length_penalty=0.8):
    """
    Returns best seq ids (Tensor [L]).
    Tuples in beams/finished are (neg_logprob, seq_list, (h,c), ctx)
    """
    model.eval()
    with torch.no_grad():
        # Encode once
        enc_out, (hn, cn) = model.encoder(src_ids)
        enc_mask = (src_ids == PAD_ID)
        dec_hc0 = model.bridge(hn, cn)
        ctx0 = torch.zeros(src_ids.size(0), enc_out.size(-1), device=src_ids.device)

        # Init beams
        beams = [(0.0, [BOS_ID], dec_hc0, ctx0)]  # negative log-prob score
        finished = []

        for _ in range(max_len):
            new_beams = []
            for score, seq, dec_hc, ctx in beams:
                last = seq[-1]
                if last == EOS_ID:
                    finished.append((score, seq, dec_hc, ctx))
                    continue
                y_prev = torch.tensor([last], dtype=torch.long, device=src_ids.device)
                logits, dec_hc_new, ctx_new, _ = model.decoder(y_prev, dec_hc, enc_out, enc_mask, ctx)
                logp = F.log_softmax(logits, dim=-1).squeeze(0)  # [V]
                topk = torch.topk(logp, beam_size)
                for k in range(beam_size):
                    tok_id = int(topk.indices[k].item())
                    sc = float(-topk.values[k].item())  # negative log prob
                    new_seq = seq + [tok_id]
                    new_beams.append((score + sc, new_seq, dec_hc_new, ctx_new))

            # prune
            new_beams.sort(key=lambda x: x[0])
            beams = new_beams[:beam_size]

            # early stop if all finished
            if len(finished) >= beam_size and all(b[-1] == EOS_ID for _, b, _, _ in finished[-beam_size:]):
                break

        # Candidate pool: prefer finished, else current beams
        cand = finished if finished else beams

        # Apply length penalty on a view with only (score, seq)
        normed = []
        for sc, seq, _, _ in cand:
            L = max(1, len(seq))
            normed.append((sc / (L ** length_penalty), seq))
        best_seq = min(normed, key=lambda x: x[0])[1]

        # strip leading BOS, truncate at EOS
        if best_seq and best_seq[0] == BOS_ID:
            best_seq = best_seq[1:]
        if EOS_ID in best_seq:
            best_seq = best_seq[:best_seq.index(EOS_ID)]
        return torch.tensor(best_seq, dtype=torch.long)

# --------- Public helpers ---------
def translate_one(urdu_text: str, decoder="beam", beam_size=5, max_len=200):
    src_ids = encode_src_text(urdu_text, exp_used["tokenization"])
    if decoder == "greedy":
        out = greedy_one(best_model, src_ids, max_len=max_len)
    else:
        out = beam_one(best_model, src_ids, beam_size=beam_size, max_len=max_len)
    return decode_ids(out)

def translate_batch(texts, decoder="beam", beam_size=5, max_len=200):
    outs = []
    for s in texts:
        outs.append(translate_one(str(s), decoder=decoder, beam_size=beam_size, max_len=max_len))
    return outs

# --------- Run on test split (or your own list) and export ---------
test_df = pd.read_parquet(DATA_DIR / "pairs_test.parquet")[["src_ur","tgt_rom"]].copy()

# Example for custom list instead:
# custom = ["مزے جہان کے اپنی نظر میں خاک نہیں", "اب کے ہم بچھڑے تو شاید کبھی خوابوں میں ملیں"]
# test_df = pd.DataFrame({"src_ur": custom, "tgt_rom": [""]*len(custom)})

preds = translate_batch(test_df["src_ur"].tolist(), decoder="beam", beam_size=5)
out = test_df.assign(pred=preds)

csv_path = RUNS_DIR / "preds_test.csv"
out.to_csv(csv_path, index=False)
print("✅ Saved:", csv_path)
out.head(10)


Using ckpt: bilstm4lstm_char_E256_H256_enc2_dec4_drop0.3_best.pt | tokenization=char
✅ Saved: /content/nmt_urdu_roman/runs/preds_test.csv


Unnamed: 0,src_ur,tgt_rom,pred
0,عاشقی میں میرؔ جیسے خواب مت دیکہا کرو,āshiqī meñ 'mīr' jaise ḳhvāb mat dekhā karo,āshiqī meñ 'mīr' jaise ḳhvāb mat dekhā karo
1,باولے ہو جاو گے مہتاب مت دیکہا کرو,bāvle ho jāoge mahtāb mat dekhā karo,bāvale ho jāoge mahtāb mat dekhā karo
2,جستہ جستہ پڑہ لیا کرنا مضامین وفا,jasta jasta paḌh liyā karnā mazāmīn-e-vafā,jasta jasta paḌh liyā karnā mazāmīn-e-vafā
3,پر کتاب عشق کا ہر باب مت دیکہا کرو,par kitāb-e-ishq kā har baab mat dekhā karo,par kitāb-e-ishq kā har baab mat dekhā karo
4,اس تماشے میں الٹ جاتی ہیں اکثر کشتیاں,is tamāshe meñ ulaT jaatī haiñ aksar kashtiyāñ,is tamāshe meñ ulaT jaatī haiñ aksar kushtiyāñ
5,ڈوبنے والوں کو زیر آب مت دیکہا کرو,Dūbne vāloñ ko zer-e-āb mat dekhā karo,Dūbne vāloñ ko zer-e-āb mat dekhā karo
6,مے کدے میں کیا تکلف مے کشی میں کیا حجاب,mai-kade meñ kyā takalluf mai-kashī meñ kyā hijāb,mai-kade meñ kyā takalluf-e-mai-kashī meñ kyā ...
7,بزم ساقی میں ادب آداب مت دیکہا کرو,bazm-e-sāqī meñ adab ādāb mat dekhā karo,bazm-e-sāqī meñ adab ādāb mat dekhā karo
8,ہم سے درویشوں کے گہر آو تو یاروں کی طرح,ham se durveshoñ ke ghar aao to yāroñ kī tarah,ham se darveshoñ ke ghar aao to yāroñ kī tarah
9,ہر جگہ خس خانہ و برفاب مت دیکہا کرو,har jagah ḳhas-ḳhāna o barfāb mat dekhā karo,har jagah ḳhas-ḳhāna-o-barfāb mat dekhā karo


In [20]:
# ============================================================
# Cell #9 — Quick error analysis + light post-processing
#   - Shows top confusions
#   - Normalizes hyphens/ezāfe/wa and a couple frequent typos
#   - Re-scores BLEU/CER after cleanup and saves new CSV
# ============================================================
import re, pandas as pd, numpy as np, sacrebleu
from jiwer import cer as jiwer_cer
from pathlib import Path

PROJECT_DIR = Path("/content/nmt_urdu_roman")
RUNS_DIR    = PROJECT_DIR / "runs"
DATA_DIR    = PROJECT_DIR / "data"

df = pd.read_csv(RUNS_DIR / "preds_test.csv", sep=",")
print("Rows:", len(df))

# ---- quick error peek
def token_diff_rows(df, n=10):
    rows = []
    for i, r in df.iterrows():
        gt = str(r["tgt_rom"]).split()
        pr = str(r["pred"]).split()
        if gt != pr:
            rows.append((i, r["src_ur"], r["tgt_rom"], r["pred"]))
        if len(rows) >= n: break
    return rows

print("\nExamples with diffs (first 10):")
for i, ur, gt, pr in token_diff_rows(df, n=10):
    print(f"- {i:04d} UR: {ur}\n   GT: {gt}\n   PR: {pr}\n")

# ---- light post-processing
EZAFE = r"(?:\s*-\s*e\s*-\s*|\s+e\s+|\s*e\s*-\s*|\s*-\s*e\s*)"
WA    = r"(?:\s*-\s*o\s*-\s*|\s+o\s+|\s*o\s*-\s*|\s*-\s*o\s*)"

def tidy_roman(s: str) -> str:
    if not isinstance(s, str): return ""
    t = s

    # unify ezāfe and wa to "-e-" / "-o-"
    t = re.sub(EZAFE, "-e-", t)
    t = re.sub(WA, "-o-", t)

    # collapse multiple hyphens/spaces around hyphens
    t = re.sub(r"\s*-\s*", "-", t)
    t = re.sub(r"\s{2,}", " ", t).strip()

    # common char fixes
    # kashti vs kushti (if your ground truth prefers 'kashtiyāñ')
    t = re.sub(r"\bkushti(yāñ|yān|yā|yān?)\b", r"kashti\1", t, flags=re.I)

    # normalize dotted 'ḳhaak' variants -> 'ḳhaak' / 'khāk' style harmonization is repo-specific; skip heavy changes
    # small punctuation spacing
    t = re.sub(r"\s+([,؛۔?!])", r"\1", t)

    return t

df["pred_clean"] = df["pred"].map(tidy_roman)

# ---- score before/after
preds_raw   = df["pred"].tolist()
preds_clean = df["pred_clean"].tolist()
refs        = df["tgt_rom"].tolist()

bleu_raw   = sacrebleu.corpus_bleu(preds_raw,   [refs]).score
bleu_clean = sacrebleu.corpus_bleu(preds_clean, [refs]).score
cer_raw    = float(np.mean([jiwer_cer(r, p) for p, r in zip(preds_raw, refs)]))
cer_clean  = float(np.mean([jiwer_cer(r, p) for p, r in zip(preds_clean, refs)]))

print(f"\nBLEU raw   : {bleu_raw:.2f} | CER raw   : {cer_raw:.3f}")
print(f"BLEU clean : {bleu_clean:.2f} | CER clean : {cer_clean:.3f}")

# ---- save cleaned file
out_path = RUNS_DIR / "preds_test_clean.csv"
df[["src_ur","tgt_rom","pred_clean"]].to_csv(out_path, index=False)
print("✅ Saved cleaned preds:", out_path)


Rows: 5255

Examples with diffs (first 10):
- 0001 UR: باولے ہو جاو گے مہتاب مت دیکہا کرو
   GT: bāvle ho jāoge mahtāb mat dekhā karo
   PR: bāvale ho jāoge mahtāb mat dekhā karo

- 0004 UR: اس تماشے میں الٹ جاتی ہیں اکثر کشتیاں
   GT: is tamāshe meñ ulaT jaatī haiñ aksar kashtiyāñ
   PR: is tamāshe meñ ulaT jaatī haiñ aksar kushtiyāñ

- 0006 UR: مے کدے میں کیا تکلف مے کشی میں کیا حجاب
   GT: mai-kade meñ kyā takalluf mai-kashī meñ kyā hijāb
   PR: mai-kade meñ kyā takalluf-e-mai-kashī meñ kyā hijāb

- 0008 UR: ہم سے درویشوں کے گہر آو تو یاروں کی طرح
   GT: ham se durveshoñ ke ghar aao to yāroñ kī tarah
   PR: ham se darveshoñ ke ghar aao to yāroñ kī tarah

- 0009 UR: ہر جگہ خس خانہ و برفاب مت دیکہا کرو
   GT: har jagah ḳhas-ḳhāna o barfāb mat dekhā karo
   PR: har jagah ḳhas-ḳhāna-o-barfāb mat dekhā karo

- 0010 UR: مانگے تانگے کی قباییں دیر تک رہتی نہیں
   GT: māñge-tāñge kī qabā.eñ der tak rahtī nahīñ
   PR: māñge tāñge kī qabā.eñ der tah rahtī nahīñ

- 0011 UR: یار لوگوں کے لقب الق

In [23]:
# If not already mounted:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Copy folders into Drive
!cp -r /content/nmt_urdu_roman /content/drive/MyDrive/
!cp -r /content/urdu_ghazals_rekhta /content/drive/MyDrive/

print("Saved to MyDrive/nmt_urdu_roman and MyDrive/urdu_ghazals_rekhta")


Mounted at /content/drive
Saved to MyDrive/nmt_urdu_roman and MyDrive/urdu_ghazals_rekhta
