# Kriol → English NMT (NLLB, single-GPU)

This notebook trains a Kriol → English translator using NLLB (facebook/nllb-200-distilled-600M) with the Hugging Face Trainer.

- Cleans and preprocesses pairs, caching a cleaned CSV to speed reruns
- Trains a single-GPU baseline and saves a `final/` checkpoint with HF artifacts and `.pth`
- Optional: back-translation plan and custom tokenizer scaffolding (placeholders)

References:
- NLLB model card: https://huggingface.co/facebook/nllb-200-distilled-600M
- Transformers Seq2Seq docs: https://huggingface.co/docs/transformers/en/tasks/translation


### Step 1 — Environment & imports

In [1]:
import os
import re
from typing import List

import torch
from torch.utils.data import Dataset

import pandas as pd
from sklearn.model_selection import train_test_split

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer as Trainer,
    Seq2SeqTrainingArguments as TrainingArguments,
)

print(torch.__version__)
print(torch.cuda.is_available())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  from .autonotebook import tqdm as notebook_tqdm


2.8.0+cu129
True


### Step 2 — Config

In [2]:

class CFG:
    # Model & paths
    MODEL_NAME = "facebook/nllb-200-distilled-600M"
    OUTPUT_DIR = "../model/"
    CFG_JSON = "CFG.json"

    # Data
    DATA_FILE = "../data/train_data.xlsx"
    CLEAN_DATA_FILE = "../data/train_data_cleaned.csv"
    SRC_COL = "kriol"
    TGT_COL = "english"
    VAL_SIZE = 0.1
    SEED = 42

    # NLLB language tags
    SRC_LANG = None  # e.g., "eng_Latn" if source were English; set properly when defined
    TGT_LANG = "eng_Latn"

    # Preprocessing
    APPLY_ENGLISH_LID = True
    MAX_TOKENS = 128
    LEN_RATIO = 3.0
    STRIP_PUNCT_SRC = True
    STRIP_PUNCT_TGT = False

    # Cleaning control
    SKIP_CLEAN_IF_EXISTS = True

    # Cross-validation
    USE_CV = False
    K_FOLDS = 5
    CV_FOLD = 0

    # Training
    NUM_EPOCHS = 12
    BATCH_SIZE = 8
    LR = 5e-5
    MAX_LEN = 128
    EVAL_EVERY_STEPS = 2000  # step-based eval printing frequency

    # Decoding
    BEAM_SIZE = 6
    LENGTH_PENALTY = 1.0
    EARLY_STOPPING = True

    GRADIENT_CHECKPOINTING = True

    # Back-translation
    ENABLE_BT = False
    EN2KR_MODEL = "Helsinki-NLP/opus-mt-en-mul"
    EN2KR_DIR = "../model/en2kriol"
    SYNTH_CSV = "../data/synthetic/en_to_kriol_v1.csv"
    SYNTH_CSV_SAMPLE = "../data/synthetic/en_to_kriol_v1_sample.csv"
    BT_FAST = False
    BT_BEAM_SIZE = 1  # will be adjusted below
    BT_LENGTH_PENALTY = 1.0
    BT_EARLY_STOPPING = True
    BT_BATCH = 64

    # Synthetic integration
    INTEGRATE_SYNTH = False
    SYNTH_MAX_RATIO = 1.0

    # Tokenizer (custom placeholder)
    USE_SPM = False
    SPM_DIR = "../outputs/tokenizers/spm_kriol_en_v1"

    # Decoding/generation extras
    GEN_MAX_NEW_TOKENS = 64

    # COMET
    COMET_MODEL = "Unbabel/wmt22-comet-da"
    COMET_BATCH = 16

    # Trainer args
    WARMUP_STEPS = 500
    GRAD_ACCUM_STEPS = 2
    LABEL_SMOOTHING = 0.1
    LOGGING_STEPS = 50
    SAVE_STEPS = 1000
    SAVE_TOTAL_LIMIT = 3
    FP16 = True
    REPORT_TO = ["tensorboard"]

    # Augmented training
    RETRAIN_WITH_SYNTH = False
    AUG_OUTPUT_DIR = "../model/final_aug"

# dynamic ties
CFG.BT_BEAM_SIZE = 1 if CFG.BT_FAST else 4

os.makedirs(CFG.OUTPUT_DIR, exist_ok=True)

# Save CFG to JSON for reproducibility
try:
    import json
    cfg_path = os.path.join(CFG.OUTPUT_DIR, "CFG.json")
    cfg_dict = {k: getattr(CFG, k) for k in dir(CFG) if k.isupper()}
    with open(cfg_path, "w", encoding="utf-8") as f:
        json.dump(cfg_dict, f, indent=2)
except Exception as _e:
    print("CFG save failed:", _e)

torch.manual_seed(CFG.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CFG.SEED)


### Step 3 — Load data

In [3]:
# Prefer cleaned CSV if present; otherwise load raw and proceed to Step 4
if os.path.exists(CFG.CLEAN_DATA_FILE):
    df = pd.read_csv(CFG.CLEAN_DATA_FILE)
else:
    if CFG.DATA_FILE.endswith(".csv"):
        df = pd.read_csv(CFG.DATA_FILE)
    elif CFG.DATA_FILE.endswith((".xlsx", ".xls")):
        df = pd.read_excel(CFG.DATA_FILE)
    else:
        raise ValueError("Unsupported data file format: use .csv or .xlsx")

assert CFG.SRC_COL in df.columns and CFG.TGT_COL in df.columns, f"Columns not found: {CFG.SRC_COL}, {CFG.TGT_COL}"

# Keep minimal shape only; Step 4 handles full cleaning
df = df[[CFG.SRC_COL, CFG.TGT_COL]].dropna()
df = df[(df[CFG.SRC_COL].astype(str).str.strip() != "") & (df[CFG.TGT_COL].astype(str).str.strip() != "")]



### Step 4 — Clean and persist dataset (merged cleaner + execution)

In [4]:
import re
import html
import unicodedata

try:
    import langid as _langid
except Exception:
    _langid = None


def _to_str(x):
    return "" if x is None else str(x)


def _normalize_unicode(s: str) -> str:
    if not isinstance(s, str):
        s = _to_str(s)
    s = unicodedata.normalize("NFC", s)
    s = re.sub(r"[\u200B-\u200F\u202A-\u202E\u2066-\u2069\uFEFF]", "", s)
    s = "".join(ch for ch in s if (ch in "\t\n\r" or unicodedata.category(ch)[0] != "C"))
    return s


def _fix_mojibake(s: str) -> str:
    replacements = {
        "â€™": "'", "â€˜": "'", "â€œ": '"', "â€�": '"',
        "â€“": "-", "â€”": "-", "Â ": " ", "Â ": " ",
    }
    for k, v in replacements.items():
        s = s.replace(k, v)
    return html.unescape(s)


def _normalize_quotes_and_spacing_no_punct_space(s: str) -> str:
    # Standardize quotes/dashes and collapse spaces; do not add spaces around punctuation here
    s = s.replace("“", '"').replace("”", '"').replace("‘", "'").replace("’", "'")
    s = s.replace("–", "-").replace("—", "-")
    s = s.replace("\r", " ").replace("\n", " ")
    s = re.sub(r"\s+", " ", s).strip()
    return s


def _strip_punct(s: str) -> str:
    return re.sub(r"[^\w\s]", "", s)


def _simple_tok_count(s: str) -> int:
    return 0 if not s else len(s.split())


def _kriol_token_score(s: str) -> int:
    kriol_markers = {"bin","langa","blanga","det","im","imbin","garra","olabat","nomo","wal","mob","deya",
    "garram","gin","seim"}
    tokens = set(_to_str(s).lower().split())
    return sum(1 for t in tokens if t in kriol_markers)


def _maybe_swap(src: str, tgt: str, use_langid: bool = True):
    kriol_src = _kriol_token_score(src)
    kriol_tgt = _kriol_token_score(tgt)
    en_src = en_tgt = 0
    if use_langid and _langid is not None:
        try:
            en_src = 1 if _langid.classify(_to_str(src))[0] == "en" else 0
            en_tgt = 1 if _langid.classify(_to_str(tgt))[0] == "en" else 0
        except Exception:
            pass
    should_swap = (en_src > en_tgt and kriol_tgt > kriol_src and (kriol_tgt - kriol_src) >= 1)
    return (tgt, src) if should_swap else (src, tgt)


def clean_parallel_dataframe(
    df,
    src_col: str,
    tgt_col: str,
    lowercase_src: bool = True,
    lowercase_tgt: bool = False,
    strip_punct_src: bool = True,
    strip_punct_tgt: bool = False,
    max_tokens: int = 128,
    len_ratio: float = 3.0,
    apply_english_lid_on_tgt: bool = False,
    try_swap_misplaced_rows: bool = True,
    drop_identical_pairs: bool = True,
):
    work = df[[src_col, tgt_col]].copy()

    for c in (src_col, tgt_col):
        work[c] = (
            work[c]
            .astype(str)
            .map(_normalize_unicode)
            .map(_fix_mojibake)
            .map(_normalize_quotes_and_spacing_no_punct_space)
        )

    if try_swap_misplaced_rows:
        work[[src_col, tgt_col]] = work.apply(
            lambda r: _maybe_swap(r[src_col], r[tgt_col], use_langid=True), axis=1, result_type="expand"
        )

    if lowercase_src:
        work[src_col] = work[src_col].str.lower()
    if lowercase_tgt:
        work[tgt_col] = work[tgt_col].str.lower()
    if strip_punct_src:
        work[src_col] = work[src_col].map(_strip_punct)
    if strip_punct_tgt:
        work[tgt_col] = work[tgt_col].map(_strip_punct)

    work = work[(work[src_col].str.strip() != "") & (work[tgt_col].str.strip() != "")]
    if drop_identical_pairs:
        work = work[work[src_col] != work[tgt_col]]

    work = work.drop_duplicates(subset=[src_col, tgt_col])

    def _keep_len(row) -> bool:
        s_len = _simple_tok_count(row[src_col])
        t_len = _simple_tok_count(row[tgt_col])
        if s_len == 0 or t_len == 0:
            return False
        if s_len > max_tokens or t_len > max_tokens:
            return False
        ratio = max(s_len / max(1, t_len), t_len / max(1, s_len))
        return ratio <= len_ratio

    work = work[work.apply(_keep_len, axis=1)]

    if apply_english_lid_on_tgt and _langid is not None:
        try:
            work = work[work[tgt_col].map(lambda s: _langid.classify(_to_str(s))[0] == "en")]
        except Exception:
            pass

    return work.reset_index(drop=True)


# Execute cleaning and persist (skippable when cleaned exists)
if CFG.SKIP_CLEAN_IF_EXISTS and os.path.exists(CFG.CLEAN_DATA_FILE):
    print(f"Using existing cleaned dataset: {CFG.CLEAN_DATA_FILE}")
    df = pd.read_csv(CFG.CLEAN_DATA_FILE)
else:
    cleaned = clean_parallel_dataframe(
        df,
        src_col=CFG.SRC_COL,
        tgt_col=CFG.TGT_COL,
        lowercase_src=True,
        lowercase_tgt=False,
        strip_punct_src=CFG.STRIP_PUNCT_SRC,
        strip_punct_tgt=CFG.STRIP_PUNCT_TGT,
        max_tokens=CFG.MAX_TOKENS,
        len_ratio=CFG.LEN_RATIO,
        apply_english_lid_on_tgt=CFG.APPLY_ENGLISH_LID,
        try_swap_misplaced_rows=True,
    )

    os.makedirs(os.path.dirname(CFG.CLEAN_DATA_FILE), exist_ok=True)
    cleaned.to_csv(CFG.CLEAN_DATA_FILE, index=False)
    print(f"Saved cleaned dataset: {CFG.CLEAN_DATA_FILE} rows={len(cleaned)}")

    # Use cleaned data downstream
    df = cleaned.copy()


Using existing cleaned dataset: ../data/train_data_cleaned.csv


### Step 5 — Preprocess & normalize

This step prepares pairs for Kriol→English only:
- Lowercase both sides; normalize whitespace
- Deduplicate pairs to avoid train/val leakage
- Length filter: max 128 tokens on each side
- Length ratio filter: src/tgt and tgt/src ≤ 3.0
- Optional: language ID filter for English targets (off by default)

Note: These filters run before the split to ensure clean train/val sets.


In [5]:
try:
    import langid
except Exception:
    langid = None


def _normalize_common(s: str) -> str:
    s = str(s)
    s = s.replace("“", '"').replace("”", '"').replace("‘", "'").replace("’", "'")
    s = re.sub(r"\s+", " ", s).strip()
    s = re.sub(r"\s*([,;:?!\.])\s*", r" \1 ", s)
    return re.sub(r"\s+", " ", s)

def _strip_punct(s: str) -> str:
    # remove common punctuation; keep alphanumerics and spaces
    return re.sub(r"[^\w\s]", "", s)

def normalize_src_text(s: str) -> str:
    s = _normalize_common(s)
    if CFG.STRIP_PUNCT_SRC:
        s = _strip_punct(s)
    return s.lower()

def normalize_tgt_text(s: str) -> str:
    s = _normalize_common(s)
    if CFG.STRIP_PUNCT_TGT:
        s = _strip_punct(s)
    return s


def simple_token_count(text: str) -> int:
    return len(str(text).split())


def passes_filters(row) -> bool:
    src = normalize_src_text(row[CFG.SRC_COL])
    tgt = normalize_tgt_text(row[CFG.TGT_COL])
    if src == "" or tgt == "":
        return False
    # length tokens
    s_len = simple_token_count(src)
    t_len = simple_token_count(tgt)
    if s_len > CFG.MAX_TOKENS or t_len > CFG.MAX_TOKENS:
        return False
    # ratio
    if s_len > 0 and t_len > 0:
        if s_len / t_len > CFG.LEN_RATIO or t_len / s_len > CFG.LEN_RATIO:
            return False
    # optional English LID on target
    if CFG.APPLY_ENGLISH_LID and langid is not None:
        lid, _ = langid.classify(tgt)
        if lid != "en":
            return False
    return True

# Apply normalization + filters before split
df_filtered = df[[CFG.SRC_COL, CFG.TGT_COL]].dropna().copy()
df_filtered[CFG.SRC_COL] = df_filtered[CFG.SRC_COL].apply(normalize_src_text)
df_filtered[CFG.TGT_COL] = df_filtered[CFG.TGT_COL].apply(normalize_tgt_text)
df_filtered = df_filtered.drop_duplicates(subset=[CFG.SRC_COL, CFG.TGT_COL])
df_filtered = df_filtered[df_filtered.apply(passes_filters, axis=1)]

# Split: K-Fold when enabled, else single random split
if CFG.USE_CV:
    assert CFG.K_FOLDS >= 2, "K_FOLDS must be >= 2 for cross-validation"
    kf = KFold(n_splits=CFG.K_FOLDS, shuffle=True, random_state=CFG.SEED)
    folds = list(kf.split(df_filtered))
    fold_idx = CFG.CV_FOLD % CFG.K_FOLDS
    train_index, val_index = folds[fold_idx]
    train_df = df_filtered.iloc[train_index].reset_index(drop=True)
    val_df = df_filtered.iloc[val_index].reset_index(drop=True)
    print(f"KFold split: fold {fold_idx+1}/{CFG.K_FOLDS} -> train {len(train_df)} val {len(val_df)}")
else:
    train_df, val_df = train_test_split(df_filtered, test_size=CFG.VAL_SIZE, random_state=CFG.SEED)
    print("After filters:", len(train_df), len(val_df))



After filters: 20424 2270


### Step 6 — Back-translation plan (design, no-op)

Goal: augment Kriol→English training with synthetic pairs generated by an English→Kriol reverse model.

Plan:
- Train/load reverse MarianMT (en→kriol) with same normalization rules.
- Generate synthetic Kriol for English monolingual (start with our training English).
- Build synthetic pairs (kriol_syn, english_orig), dedup + filter, and merge with real pairs using sample weighting.
- Re-train forward model and evaluate COMET.

Artifacts:
- Reverse model at `model/en2kriol/`
- Synthetic CSV at `data/synthetic/en_to_kriol_v1.csv`

Note: This cell does not execute generation; code scaffold follows below and is disabled by default.


In [6]:
# Back-translation scaffold (enable with CFG.ENABLE_BT)

def maybe_load_en2kriol():
    tok = AutoTokenizer.from_pretrained(CFG.EN2KR_MODEL)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(CFG.EN2KR_MODEL).to(device)
    return tok, mdl

@torch.no_grad()
def en_to_kriol_generate(tok, mdl, texts: List[str]) -> List[str]:
    mdl.eval()
    inputs = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=CFG.MAX_LEN)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    out = mdl.generate(
        **inputs,
        max_new_tokens=CFG.GEN_MAX_NEW_TOKENS,
        num_beams=CFG.BT_BEAM_SIZE,
        length_penalty=CFG.BT_LENGTH_PENALTY,
        early_stopping=CFG.BT_EARLY_STOPPING,
    )
    return tok.batch_decode(out, skip_special_tokens=True)

# Generate and save synthetic data when enabled
if CFG.ENABLE_BT:
    from math import ceil
    try:
        from tqdm import tqdm
    except Exception:
        tqdm = None

    tok_r, mdl_r = maybe_load_en2kriol()
    eng_all = train_df[CFG.TGT_COL].tolist()
    BATCH_BT = CFG.BT_BATCH
    kriol_syn_all: List[str] = []

    total = len(eng_all)
    it = range(0, total, BATCH_BT)
    if tqdm is not None:
        it = tqdm(it, total=ceil(total / BATCH_BT), desc="Back-translation")

    for i in it:
        batch = eng_all[i:i+BATCH_BT]
        kriol_syn_all.extend(en_to_kriol_generate(tok_r, mdl_r, batch))

    os.makedirs(os.path.dirname(CFG.SYNTH_CSV), exist_ok=True)
    synth = pd.DataFrame({CFG.SRC_COL: kriol_syn_all, CFG.TGT_COL: eng_all})
    synth.to_csv(CFG.SYNTH_CSV, index=False)

    # Save a small preview sample for manual QA (e.g., 100 rows)
    try:
        n_preview = min(100, len(synth))
        synth.head(n_preview).to_csv(CFG.SYNTH_CSV_SAMPLE, index=False)
    except Exception as _e:
        print("Could not save synthetic preview:", _e)

    print(f"Saved synthetic pairs: {CFG.SYNTH_CSV} rows: {len(synth)} (preview: {CFG.SYNTH_CSV_SAMPLE})")



⚠️ DO NOT RUN NOW (design only)

### Step 7 — Tokenizer training plan (no-op)

- Objective: Prepare a shared SentencePiece tokenizer plan for Kriol↔English without changing current training.
- Corpus: All training sentences from both Kriol and English sides combined.
- Preprocessing: lowercase, normalize whitespace/punctuation, keep dialectal spellings, drop empties, deduplicate.
- Hyperparameters: vocab_size=8000 (up to 12000 if OOV>1.5%), character_coverage=0.9995, model_type=unigram; special tokens: <pad>, <s>, </s>, <unk>.
- Training flags (indicative):
  - --model_type=unigram --vocab_size=8000 --character_coverage=0.9995
  - --shuffle_input_sentence=true --max_sentence_length=2048 --num_threads=[CPU cores]
- Artifacts: spm_kriol_en_v1.model, spm_kriol_en_v1.vocab + a JSON with training config.
- Adoption criteria: Switch only if corpus grows >30% or OOV >1.5% and A/B shows COMET improvement.

Evaluation protocol (A/B): Train/score with Marian default vs SentencePiece on same split; report COMET delta and per-segment examples.



In [7]:
# Tokenizer scaffolding (no-op; does not alter training unless enabled)
USE_SPM = CFG.USE_SPM
SPM_DIR = CFG.SPM_DIR  # where spm.model/spm.vocab would live


def prepare_tokenizer_corpus(df, src_col: str, tgt_col: str):
    """Return a list of cleaned lines for SPM training (lowercase, strip, dedup)."""
    src = df[src_col].astype(str).str.lower().str.strip()
    tgt = df[tgt_col].astype(str).str.lower().str.strip()
    lines = pd.concat([src, tgt], ignore_index=True)
    lines = lines[lines != ""].drop_duplicates()
    return lines.tolist()


def train_sentencepiece_corpus(lines, model_prefix: str, vocab_size: int = 8000):
    """Sketch only: real training will be added later. No side effects here."""
    # import sentencepiece as spm
    # spm.SentencePieceTrainer.Train(
    #     input=data_path,
    #     model_prefix=model_prefix,
    #     vocab_size=vocab_size,
    #     character_coverage=0.9995,
    #     model_type="unigram",
    # )
    pass


def load_tokenizer(marian_model_name: str, use_spm: bool = USE_SPM):
    """Return tokenizer, preferring SPM dir if enabled, else Marian default."""
    if use_spm and os.path.isdir(SPM_DIR):
        return AutoTokenizer.from_pretrained(SPM_DIR)
    return AutoTokenizer.from_pretrained(marian_model_name)

# Note: current notebook flow continues to use Marian tokenizer by default.



### Step 8 — Generate synthetic data (optional)
Creates English→Kriol synthetic pairs using a reverse model when enabled in Step 2.


### Step 9 — Integrate synthetic data (optional)
Applies the same normalization & filtering rules as real data, dedups, and preserves real-only validation (train-only merge).


### Step 10 — Custom tokenizer training (not finished)
Trains a shared SentencePiece tokenizer on combined Kriol+English corpus when enabled in Step 2. This is different from the default Marian tokenizer. It’s only adopted if metrics improve.


### Step 10.1 — Custom tokenizer (placeholder)

(Empty for now — we will implement a custom tokenizer later.)


In [8]:

SYNTH_CSV_PATH = CFG.SYNTH_CSV

try:
    if CFG.INTEGRATE_SYNTH and os.path.exists(SYNTH_CSV_PATH):
        syn_raw = pd.read_csv(SYNTH_CSV_PATH)
        if not {CFG.SRC_COL, CFG.TGT_COL}.issubset(set(syn_raw.columns)):
            raise ValueError(f"Synthetic CSV missing required columns: {CFG.SRC_COL}, {CFG.TGT_COL}")

        # Normalize + filter synthetic with same functions
        syn_df = syn_raw[[CFG.SRC_COL, CFG.TGT_COL]].dropna().copy()
        syn_df[CFG.SRC_COL] = syn_df[CFG.SRC_COL].apply(normalize_src_text)
        syn_df[CFG.TGT_COL] = syn_df[CFG.TGT_COL].apply(normalize_tgt_text)
        syn_df = syn_df.drop_duplicates(subset=[CFG.SRC_COL, CFG.TGT_COL])
        syn_df = syn_df[syn_df.apply(passes_filters, axis=1)]

        # Merge synthetic into training only; keep validation purely real
        before_train = len(train_df)
        max_synth = int(CFG.SYNTH_MAX_RATIO * before_train)
        if len(syn_df) > max_synth:
            syn_df = syn_df.sample(n=max_synth, random_state=CFG.SEED)

        train_df = pd.concat([train_df, syn_df], ignore_index=True)
        train_df = train_df.drop_duplicates(subset=[CFG.SRC_COL, CFG.TGT_COL])
        after_train = len(train_df)

        print(
            f"Synthetic integration complete: +{after_train - before_train} pairs (train: {after_train}, val: {len(val_df)})"
        )
    else:
        print(f"Synthetic integration OFF or CSV not found at {SYNTH_CSV_PATH}. Skipping integration.")
except Exception as e:
    print("Synthetic integration error:", e)



Synthetic integration OFF or CSV not found at ../data/synthetic/en_to_kriol_v1.csv. Skipping integration.


### Step 11 — Tokenizer & Model

In [9]:

# NLLB requires language codes for tokenizer; set with safe fallback
src_lang = CFG.SRC_LANG if CFG.SRC_LANG else "eng_Latn"
tgt_lang = CFG.TGT_LANG if CFG.TGT_LANG else "eng_Latn"
if CFG.SRC_LANG is None:
    print("[warn] CFG.SRC_LANG is unset. Using fallback src_lang=eng_Latn. Set CFG.SRC_LANG explicitly when decided.")

# Load tokenizer/model (8-bit quantization removed)
tokenizer = AutoTokenizer.from_pretrained(CFG.MODEL_NAME, src_lang=src_lang, tgt_lang=tgt_lang)
model = AutoModelForSeq2SeqLM.from_pretrained(CFG.MODEL_NAME)
# Disable cache during training to avoid decoder arg conflicts
if hasattr(model, "config"):
    model.config.use_cache = False

# Ensure decoder start and BOS/EOS tokens are set for NLLB/M2M100
if hasattr(model, "config") and hasattr(tokenizer, "lang_code_to_id"):
    forced_bos = tokenizer.convert_tokens_to_ids(tgt_lang)
    if forced_bos is not None and forced_bos != tokenizer.unk_token_id:
        model.config.forced_bos_token_id = forced_bos
    if getattr(model.config, "decoder_start_token_id", None) is None:
        model.config.decoder_start_token_id = model.config.forced_bos_token_id

# Gradient checkpointing for VRAM (disable to avoid rare HF decoder arg issues)
if False and CFG.GRADIENT_CHECKPOINTING and hasattr(model, "gradient_checkpointing_enable"):
    model.gradient_checkpointing_enable()

assert torch.cuda.is_available(), "CUDA is not available. Please check your GPU drivers and PyTorch install."
device = torch.device("cuda")
model.to(device)
print(device, torch.cuda.get_device_name(0))


[warn] CFG.SRC_LANG is unset. Using fallback src_lang=eng_Latn. Set CFG.SRC_LANG explicitly when decided.
cuda NVIDIA GeForce RTX 5060 Laptop GPU


### Step 12 — Dataset

In [10]:

class PairedTextDataset(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer: AutoTokenizer, max_len: int):
        self.src = df[CFG.SRC_COL].tolist()
        self.tgt = df[CFG.TGT_COL].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx: int):
        src_text = str(self.src[idx])
        tgt_text = str(self.tgt[idx])
        model_inputs = self.tokenizer(
            src_text,
            max_length=self.max_len,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )
        labels = self.tokenizer(
            text_target=tgt_text,
            max_length=self.max_len,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )
        item = {k: v.squeeze(0) for k, v in model_inputs.items()}
        item["labels"] = labels["input_ids"].squeeze(0)
        return item

train_ds = PairedTextDataset(train_df, tokenizer, CFG.MAX_LEN)
val_ds = PairedTextDataset(val_df, tokenizer, CFG.MAX_LEN)
len(train_ds), len(val_ds)


(20424, 2270)

### Step 13 — Trainer setup (DDP-ready)

In [11]:
# Utility: Trainer that drops unintended *_embeds keys to avoid HF arg conflicts
from transformers import Seq2SeqTrainer

class CleanSeq2SeqTrainer(Seq2SeqTrainer):
    def _prepare_inputs(self, inputs):
        # Sanitize at input-prep stage too
        inputs.pop("decoder_inputs_embeds", None)
        inputs.pop("inputs_embeds", None)
        inputs.pop("decoder_input_ids", None)
        allowed = {"input_ids", "attention_mask", "labels", "decoder_attention_mask"}
        filtered = {k: v for k, v in inputs.items() if k in allowed}
        return super()._prepare_inputs(filtered)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Drop any embed keys and strictly whitelist safe args for seq2seq training
        inputs.pop("decoder_inputs_embeds", None)
        inputs.pop("inputs_embeds", None)
        inputs.pop("decoder_input_ids", None)
        allowed = {"input_ids", "attention_mask", "labels", "decoder_attention_mask"}
        filtered = {k: v for k, v in inputs.items() if k in allowed}
        # Call model explicitly with safe kwargs to avoid decoder ids/embeds conflicts
        outputs = model(
            decoder_input_ids=None,
            decoder_inputs_embeds=None,
            use_cache=False,
            **filtered,
        )
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs.loss
        return (loss, outputs) if return_outputs else loss



In [12]:
label_pad_token_id = -100
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=label_pad_token_id, padding=True)


args = TrainingArguments(
    output_dir=CFG.OUTPUT_DIR,
    num_train_epochs=CFG.NUM_EPOCHS,
    per_device_train_batch_size=CFG.BATCH_SIZE,
    per_device_eval_batch_size=CFG.BATCH_SIZE,
    learning_rate=CFG.LR,
    warmup_steps=CFG.WARMUP_STEPS,
    gradient_accumulation_steps=CFG.GRAD_ACCUM_STEPS,
    label_smoothing_factor=CFG.LABEL_SMOOTHING,
    optim="adamw_torch",
    logging_steps=CFG.LOGGING_STEPS,
    save_steps=CFG.SAVE_STEPS,
    save_total_limit=CFG.SAVE_TOTAL_LIMIT,
    fp16=CFG.FP16,
    report_to=CFG.REPORT_TO,
    eval_accumulation_steps=1,
    remove_unused_columns=False,
    label_names=["labels"],
)

trainer = CleanSeq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    data_collator=collator,
)



  trainer = CleanSeq2SeqTrainer(


### Step 14 — Train

In [13]:

trainer.train()

# Evaluate once to log metrics
metrics = trainer.evaluate()
print("eval_loss:", metrics.get("eval_loss"))

# Launch TensorBoard from notebook
%load_ext tensorboard
%tensorboard --logdir "../model"


  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


Step,Training Loss


KeyboardInterrupt: 

### Save HF artifacts and a .pth checkpoint

In [None]:

final_dir = os.path.join(CFG.OUTPUT_DIR, "final")
os.makedirs(final_dir, exist_ok=True)
trainer.save_model(final_dir)
model_path = os.path.join(final_dir, "model_state.pth")
torch.save(model.state_dict(), model_path)
print(f"Saved .pth to: {model_path}")


### Step 15 — Inference helper (final only)

In [None]:
final_dir = os.path.join(CFG.OUTPUT_DIR, "final")

final_tok = AutoTokenizer.from_pretrained(final_dir, src_lang=(CFG.SRC_LANG or "eng_Latn"), tgt_lang=(CFG.TGT_LANG or "eng_Latn"))
final_model = AutoModelForSeq2SeqLM.from_pretrained(final_dir).to(device)

@torch.no_grad()
def generate_with(model, tok, texts):
    inputs = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=CFG.MAX_LEN)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    forced_bos = tok.convert_tokens_to_ids(CFG.TGT_LANG or "eng_Latn")
    out = model.generate(
        **inputs,
        max_new_tokens=CFG.GEN_MAX_NEW_TOKENS,
        num_beams=CFG.BEAM_SIZE,
        length_penalty=CFG.LENGTH_PENALTY,
        early_stopping=CFG.EARLY_STOPPING,
        forced_bos_token_id=forced_bos,
    )
    return tok.batch_decode(out, skip_special_tokens=True)

# Pick 3 random samples from training data and show Kriol / predicted / original
rows = val_df.sample(n=3)
kriols = rows[CFG.SRC_COL].astype(str).tolist()
eng_refs = rows[CFG.TGT_COL].astype(str).tolist()
eng_preds = generate_with(final_model, final_tok, kriols)

for i, (kriol, pred, ref) in enumerate(zip(kriols, eng_preds, eng_refs), start=1):
    print(f"Sample {i}")
    print("Kriol:", kriol)
    print("English predicted:", pred)
    print("English original:", ref)
    print("-")



### Step 16 — Retrain with synthetic data (optional)
If Step 8 integrated synthetic pairs and CFG.RETRAIN_WITH_SYNTH=True, reinitialize datasets and run a second training to produce `final_aug/`.


In [None]:
# Optional retraining on merged dataset
if CFG.RETRAIN_WITH_SYNTH and CFG.INTEGRATE_SYNTH:
    # Rebuild datasets from potentially expanded train_df
    train_ds_aug = PairedTextDataset(train_df, tokenizer, CFG.MAX_LEN)
    collator_aug = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=-100, padding=True)

    args_aug = TrainingArguments(
        output_dir=CFG.AUG_OUTPUT_DIR,
        num_train_epochs=CFG.NUM_EPOCHS,
        per_device_train_batch_size=CFG.BATCH_SIZE,
        per_device_eval_batch_size=CFG.BATCH_SIZE,
        learning_rate=CFG.LR,
        warmup_steps=CFG.WARMUP_STEPS,
        gradient_accumulation_steps=CFG.GRAD_ACCUM_STEPS,
        label_smoothing_factor=CFG.LABEL_SMOOTHING,
        optim="adamw_torch",
        logging_steps=CFG.LOGGING_STEPS,
        save_steps=CFG.SAVE_STEPS,
        save_total_limit=CFG.SAVE_TOTAL_LIMIT,
        fp16=CFG.FP16,
        report_to=CFG.REPORT_TO,
        remove_unused_columns=False,
        label_names=["labels"],
    )

    trainer_aug = Trainer(
        model=model,
        args=args_aug,
        train_dataset=train_ds_aug,
        eval_dataset=val_df,  # keep same real-only val
        tokenizer=tokenizer,
        data_collator=collator_aug,
    )

    trainer_aug.train()

    # Save augmented final
    os.makedirs(CFG.AUG_OUTPUT_DIR, exist_ok=True)
    trainer_aug.save_model(CFG.AUG_OUTPUT_DIR)
    torch.save(model.state_dict(), os.path.join(CFG.AUG_OUTPUT_DIR, "model_state.pth"))



### Step 17 — COMET evaluation (final only)
Scores validation translations with Unbabel COMET if available; otherwise prints a note (Python 3.13 may lack wheels).


In [None]:
# COMET: evaluate final only
try:
    from comet import download_model, load_from_checkpoint

    BATCH = CFG.COMET_BATCH
    refs = val_df[CFG.TGT_COL].tolist()
    srcs = val_df[CFG.SRC_COL].tolist()

    def batched_hyps_final():
        hyps = []
        for i in range(0, len(val_df), BATCH):
            hyps.extend(generate_with(final_model, final_tok, srcs[i:i+BATCH]))
        return hyps

    hyps_final = batched_hyps_final()
    data_final = [{"src": s, "mt": h, "ref": r} for s, h, r in zip(srcs, hyps_final, refs)]

    model_path = download_model(CFG.COMET_MODEL)
    comet_model = load_from_checkpoint(model_path)

    def get_score(output):
        if isinstance(output, dict):
            return output.get("system_score") or output.get("score") or output.get("mean_score")
        try:
            _, s = output
            return s
        except Exception:
            return output

    def get_segments(output):
        if isinstance(output, dict):
            segs = output.get("segments") or output.get("scores") or output.get("segment_scores")
            if isinstance(segs, list):
                return segs
        return None

    out_final = comet_model.predict(data_final, batch_size=BATCH, gpus=1 if torch.cuda.is_available() else 0)

    sf = get_score(out_final)
    try:
        print("COMET (final):", f"{float(sf):.4f}")
    except Exception:
        print("COMET (final, raw):", sf)

    # Save COMET outputs to disk
    final_dir = os.path.join(CFG.OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)

    def safe_float(x):
        try:
            return float(x)
        except Exception:
            return None

    # Write system score
    sf_f = safe_float(sf)
    try:
        with open(os.path.join(final_dir, "system_score.txt"), "w", encoding="utf-8") as f:
            f.write(f"{sf_f if sf_f is not None else sf}\n")
    except Exception as _e:
        print("Could not save final system score:", _e)

    # Write per-segment CSV (src, mt, ref, score)
    seg_final = get_segments(out_final)

    try:
        if isinstance(seg_final, list) and len(seg_final) == len(hyps_final):
            df_final = pd.DataFrame({
                "src": srcs,
                "mt": hyps_final,
                "ref": refs,
                "comet_score": seg_final,
            })
            df_final.to_csv(os.path.join(final_dir, "comet_segments.csv"), index=False)
    except Exception as _e:
        print("Could not save final segments CSV:", _e)

except Exception as e:
    print("COMET evaluation unavailable:", e)

