In [None]:
# Print the GPU status table for this session before any training starts.
!nvidia-smi

In [None]:
!pip install -q evaluate sacrebleu

# B1. Imports & Configuration

In [None]:
import os
import re
import gc
import pandas as pd
import torch
import evaluate
from datasets import Dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM, 
    DataCollatorForSeq2Seq, 
    Seq2SeqTrainer, 
    Seq2SeqTrainingArguments,
    set_seed,
 )

# --- Configuration ---
# Directory of the pretrained checkpoint; both the tokenizer and the model are loaded from it.
MODEL_PATH = "/kaggle/input/models-for-dpc/pretrained_models/t5-base"
# Directory that is expected to contain train.csv.
DATA_DIR = "/kaggle/input/deep-past-initiative-machine-translation"
# Destination for the fine-tuned model and tokenizer.
OUTPUT_DIR = "/kaggle/working/t5-base-fine-tuned"

# Token cap applied to encoder inputs, to labels, and to validation-time generation.
MAX_LENGTH = 266 
# Task prefix prepended to every source string before tokenization.
PREFIX = "translate Akkadian to English: "

# Fix the random seed so repeated runs are comparable.
set_seed(42)

# B2. Data Loading & Alignment

In [None]:
# Translation table: Unicode subscript digits (and subscript x) -> ASCII characters.
_SUBSCRIPT_CHARS = "₀₁₂₃₄₅₆₇₈₉ₓ"
_ASCII_CHARS = "0123456789x"
SUBSCRIPT_TRANS = str.maketrans(dict(zip(_SUBSCRIPT_CHARS, _ASCII_CHARS)))


def normalize_subscripts(text: str) -> str:
    """Return ``text`` with subscript digits and ``ₓ`` rewritten as plain ASCII."""
    ascii_text = text.translate(SUBSCRIPT_TRANS)
    return ascii_text

def clean_translit(text):
    """Normalize transliteration by stripping scribal marks and gaps.

    The steps run in this order: subscripts to ASCII, ellipses to a gap
    marker, removal of ``[...]``, ``<<...>>``, half-brackets and ``(...)``
    spans, unwrapping of ``{...}`` and ``<...>``, removal of ``! ? / : ·``,
    and finally whitespace collapsing (newlines included).

    Args:
        text: Raw transliteration. Any non-``str`` value yields ``""``.

    Returns:
        The cleaned, single-line string with no leading/trailing whitespace.
    """
    if not isinstance(text, str):
        return ""
    text = normalize_subscripts(text)
    # A Unicode ellipsis and an ASCII run of three or more dots both mark a gap.
    text = text.replace("…", " <big_gap> ")
    # In a raw string a literal dot is written "\." -- doubling the backslash
    # would instead match a literal backslash followed by any character.
    text = re.sub(r"\.{3,}", " <big_gap> ", text)
    text = re.sub(r"\[[^\]]*\]", " ", text)   # bracketed spans are dropped with their content
    text = re.sub(r"<<[^>]*>>", " ", text)    # double-angle spans are dropped with their content
    text = re.sub(r"[˹˺]", " ", text)         # half-bracket characters
    text = re.sub(r"\([^)]*\)", " ", text)    # parenthesised spans are dropped with their content
    text = re.sub(r"\{([^}]*)\}", r"\1", text)  # keep the content, drop the braces
    # Keep the content, drop the angle brackets. This also unwraps the gap
    # marker inserted above, so gaps end up as the bare word "big_gap".
    text = re.sub(r"<([^>]*)>", r"\1", text)
    text = re.sub(r"[!?/:·]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()

def clean_translation(text):
    """Tidy an English translation: replace ``…`` with a space and squeeze whitespace.

    Non-string input (for example a missing value) is mapped to ``""``.
    """
    if isinstance(text, str):
        without_ellipsis = text.replace("…", " ")
        single_spaced = re.sub(r"\s+", " ", without_ellipsis)
        return single_spaced.strip()
    return ""

def filter_quality(df):
    """Drop pairs that are too short, badly length-mismatched, or duplicated.

    A pair is kept when both sides have at least 3 whitespace-separated
    tokens and the source/target token ratio lies in ``[0.2, 5]``. Exact
    ``(src, tgt)`` duplicates are then removed, keeping the first occurrence.

    The input frame is left untouched; the token counts are held in local
    Series instead of being written into ``df`` as helper columns.

    Args:
        df: DataFrame with string columns "src" and "tgt". An empty frame
            (even one without any columns) is accepted.

    Returns:
        A new DataFrame containing the surviving rows with their original
        index labels. Columns named "src_len"/"tgt_len" are never returned.
    """
    helper_cols = ["src_len", "tgt_len"]

    # A frame built from zero rows may have no "src"/"tgt" columns at all.
    if df.empty:
        columns = [c for c in df.columns if c not in helper_cols]
        return pd.DataFrame(columns=columns or ["src", "tgt"])

    src_len = df["src"].str.split().str.len()
    tgt_len = df["tgt"].str.split().str.len()

    long_enough = (src_len >= 3) & (tgt_len >= 3)
    # Rows with a zero/NaN target length give inf/NaN here, but they are
    # already excluded by `long_enough`.
    ratio = src_len / tgt_len
    balanced = (ratio >= 0.2) & (ratio <= 5)

    kept = df[long_enough & balanced].drop_duplicates(subset=["src", "tgt"])
    return kept.drop(columns=helper_cols, errors="ignore")

def load_and_align_data(filepath):
    """
    Aligns Akkadian transliterations to English translations.

    For every CSV row the raw transliteration is split on newlines and each
    line is cleaned on its own; the translation is cleaned and split into
    sentences after ``.``, ``!`` or ``?``. When both sides have the same
    number of segments (and more than one), the segments are paired
    one-to-one. Otherwise the whole document becomes a single pair.
    The collected pairs are finally passed through ``filter_quality``.

    Args:
        filepath: Path to a CSV with "transliteration" and "translation" columns.

    Returns:
        pandas.DataFrame with "src" and "tgt" columns (possibly empty).
    """
    df = pd.read_csv(filepath)
    aligned_rows = []

    print(f"Raw documents: {len(df)}")

    for _, row in df.iterrows():
        raw_src = row.get("transliteration", "")
        tgt = clean_translation(row.get("translation", ""))

        # Split BEFORE cleaning: clean_translit collapses every whitespace run
        # (newlines included), so splitting its output could never yield more
        # than one line and the line-level branch below would be unreachable.
        raw_lines = raw_src.split("\n") if isinstance(raw_src, str) else []
        src_lines = [line for line in map(clean_translit, raw_lines) if len(line) > 1]
        tgt_sents = [t.strip() for t in re.split(r'(?<=[.!?])\s+', tgt) if len(t.strip()) > 1]

        if len(src_lines) == len(tgt_sents) and len(src_lines) > 1:
            aligned_rows.extend(
                {"src": s, "tgt": t} for s, t in zip(src_lines, tgt_sents)
            )
        else:
            # Document-level fallback: clean the whole text as one string.
            merged_src = clean_translit(raw_src)
            if len(merged_src) > 3 and len(tgt) > 3:
                aligned_rows.append({"src": merged_src, "tgt": tgt})

    print(f"Aligned training examples (pre-filter): {len(aligned_rows)}")
    if aligned_rows:
        out_df = filter_quality(pd.DataFrame(aligned_rows))
    else:
        # No usable pairs: return an empty frame that still has the expected columns.
        out_df = pd.DataFrame(columns=["src", "tgt"])
    print(f"Aligned training examples (post-filter): {len(out_df)}")
    return out_df

# Build the aligned (src, tgt) pair table from the training CSV.
df = load_and_align_data(f"{DATA_DIR}/train.csv")
dataset = Dataset.from_pandas(df)
# Hold out 5% of the pairs; this "test" split is used as the validation set below.
dataset = dataset.train_test_split(test_size=0.05, seed=42)
print("Data loaded successfully.")

# B3. Tokenization

In [None]:
# Load the tokenizer from the same directory as the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

def preprocess_function(examples):
    """Tokenize a batch of source/target pairs for seq2seq training.

    Args:
        examples: Batch mapping with "src" and "tgt" lists of strings, as
            supplied by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output holding the encoded, prefixed sources plus a
        "labels" entry with the encoded targets. Sequences are truncated to
        ``MAX_LENGTH`` tokens but deliberately left unpadded.
    """
    # Add prefix to inputs
    inputs = [PREFIX + doc for doc in examples["src"]]

    # No padding here: DataCollatorForSeq2Seq pads each batch to its own
    # longest sequence and fills label padding with -100, so padding every
    # example to MAX_LENGTH up front only adds wasted computation.
    return tokenizer(
        inputs,
        text_target=examples["tgt"],
        max_length=MAX_LENGTH,
        truncation=True,
    )

# Apply processing
# Both splits go through the same preprocessing, one batch of examples at a time.
tokenized_train = dataset["train"].map(preprocess_function, batched=True)
tokenized_val = dataset["test"].map(preprocess_function, batched=True)

# B4. Model Setup

In [None]:
# Load the pretrained seq2seq weights that will be fine-tuned.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)

# Batch collator for the trainer. Label positions it pads are filled with
# -100 so that they are ignored in the loss calculation.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer, 
    model=model,
    label_pad_token_id=-100
)

# B5. Training Configuration

In [None]:
# --- CORRECTED B5. Training Configuration (Disk Space Safe) ---
# No checkpoints are written while training; the final weights are saved
# explicitly with trainer.save_model() in the last cell.
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    
    # --- DISK SPACE FIXES ---
    save_strategy="no",           # Do NOT save checkpoints during training
    eval_strategy="epoch",        # Evaluate every epoch
    load_best_model_at_end=False, # Must be False if we aren't saving checkpoints
    # ------------------------
    
    learning_rate=2e-4, 
    
    # 4 examples per device x 4 accumulation steps = 16 examples per optimizer step (per device).
    per_device_train_batch_size=4, 
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    
    num_train_epochs=7,
    weight_decay=0.01,
    predict_with_generate=True,
    
    # Original author's note: "T5 is safe with fp16".
    # NOTE(review): T5 checkpoints have a history of overflowing in fp16 --
    # watch the logged loss for NaN/inf and confirm this setting on the target GPU.
    fp16=True,
    report_to="none",
    
    # Accuracy-focused tweaks
    label_smoothing_factor=0.1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.04,
    # NOTE(review): 280 differs from MAX_LENGTH (266), which caps the labels and
    # the manual validation generation -- confirm the mismatch is intended.
    generation_max_length=280,
    generation_num_beams=5
)

# B6. Execution

In [None]:
# Clear cache before training
torch.cuda.empty_cache()
gc.collect()

# Wire together the model, training arguments, tokenized splits and collator.
# NOTE(review): newer transformers releases rename the `tokenizer=` argument
# to `processing_class=` -- confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Starting T5-Base Training...")
trainer.train()

In [None]:
# Quick validation BLEU/chrF on held-out split
# Both metrics are loaded through the `evaluate` library installed at the top of the notebook.
metric_bleu = evaluate.load("sacrebleu")
metric_chrf = evaluate.load("chrf")

def dedup_repeats(text: str) -> str:
    """Cap every run of an identical whitespace-separated token at two copies.

    The text is split on whitespace and re-joined with single spaces, so
    surrounding and repeated whitespace is normalized as a side effect.
    """
    kept = []
    for token in text.split():
        # Skip the token only when the two most recently kept tokens equal it.
        already_doubled = kept[-2:] == [token, token]
        if not already_doubled:
            kept.append(token)
    return " ".join(kept)

def postprocess_text(preds):
    """Tidy decoded predictions and return them as a new list.

    For each prediction: strip it, remove whitespace before punctuation,
    insert a space between punctuation and a following ASCII letter, cap
    repeated tokens via ``dedup_repeats``, upper-case a lower-case first
    character, append a period when the text does not end in ``.``, ``!`` or
    ``?``, and replace any run of two or more of those marks with one period.
    Empty predictions stay empty.
    """
    def _tidy(pred):
        pred = pred.strip()
        pred = re.sub(r"\s+([.,!?;:])", r"\1", pred)
        pred = re.sub(r"([.,!?;:])([A-Za-z])", r"\1 \2", pred)
        pred = dedup_repeats(pred)
        if pred:
            if pred[0].islower():
                pred = pred[0].upper() + pred[1:]
            if pred[-1] not in ".!?":
                pred += "."
        pred = re.sub(r"([.!?]){2,}", ".", pred)
        return pred.strip()

    return [_tidy(pred) for pred in preds]

# Cleaned source strings of the held-out split (the prefix is added in generate_batch).
val_texts = dataset["test"]["src"]
# Each reference is wrapped in its own one-element list: one reference per prediction.
val_refs = [[t] for t in dataset["test"]["tgt"]]

def generate_batch(texts):
    """Translate a list of source strings and return the decoded outputs.

    Each string is prefixed with ``PREFIX``, the batch is tokenized (truncated
    to ``MAX_LENGTH``, padded to the longest item) and moved to the model's
    device, then decoded with 6-beam search using the options below.
    """
    decoding_options = dict(
        max_length=MAX_LENGTH,
        min_length=6,
        num_beams=6,
        no_repeat_ngram_size=3,
        length_penalty=1.05,
        early_stopping=True,
    )

    prompts = [PREFIX + doc for doc in texts]
    encoded = tokenizer(
        prompts,
        max_length=MAX_LENGTH,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    encoded = encoded.to(model.device)

    generated = model.generate(**encoded, **decoding_options)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)

# Generate translations for the held-out sources in chunks of 16 strings.
# NOTE(review): nothing here switches the model to eval mode explicitly --
# confirm it is not left in training mode (dropout active) at this point.
preds = []
for i in range(0, len(val_texts), 16):
    preds.extend(generate_batch(val_texts[i:i+16]))

# Only the predictions are post-processed; the references are scored as loaded.
preds = postprocess_text(preds)
bleu = metric_bleu.compute(predictions=preds, references=val_refs)
chrf = metric_chrf.compute(predictions=preds, references=val_refs)
print({"bleu": bleu["score"], "chrf": chrf["score"]})

# B7. Save Model

In [None]:
# Persist the fine-tuned weights and the tokenizer side by side in OUTPUT_DIR.
print(f"Saving model to {OUTPUT_DIR}...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Notebook B (T5) Complete.")