In [1]:
!nvidia-smi

Thu Jan  8 08:08:08 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla T4                       Off |   00

# A1. Install required libraries

In [2]:
!pip install -q evaluate sacrebleu

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25h

# A2. Imports & config

In [3]:
import os
import gc
import re
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed
)
import evaluate

# Memory/precision safety tweaks (helps avoid OOM on P100/T4)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = False
    torch.set_float32_matmul_precision("medium")
except Exception:
    pass

set_seed(42)

2026-01-08 08:08:28.427385: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767859708.621266      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767859708.681230      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767859709.163048      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767859709.163087      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767859709.163090      55 computation_placer.cc:177] computation placer alr

# A3. Set constants (DO NOT change yet)

In [4]:
MODEL_PATH = "/kaggle/input/models-for-dpc/pretrained_models/byt5-base"
DATA_DIR = "/kaggle/input/deep-past-initiative-machine-translation"
OUTPUT_DIR = "/kaggle/working/byt5-base-saved"

# ByT5 is character-based. 360-400 provides good coverage without excessive memory
MAX_LENGTH = 380
PREFIX = "translate Akkadian to English: "

# OOM guard: allow dynamic reduction controlled by env var
try:
    env_max_len = int(os.getenv("BYT5_MAX_LENGTH", "0"))
    if env_max_len >= 280:
        MAX_LENGTH = env_max_len
        print(f"[INFO] MAX_LENGTH overridden by env: {MAX_LENGTH}")
except Exception:
    pass

# A4. Data Loading & Cleaning

In [5]:
SUBSCRIPT_TRANS = str.maketrans({"₀": "0", "₁": "1", "₂": "2", "₃": "3", "₄": "4", "₅": "5", "₆": "6", "₇": "7", "₈": "8", "₉": "9", "ₓ": "x"})

def normalize_subscripts(text: str) -> str:
    return text.translate(SUBSCRIPT_TRANS)

def replace_gaps(text, keep_gaps=True):
    """Replace various gap notations with standardized tokens
    
    Args:
        keep_gaps: If True, keeps gap tokens (for test-like data).
                   If False, removes them (for clean training).
    """
    if pd.isna(text): 
        return text
    
    # Complex gap patterns (order matters)
    text = re.sub(r'\.3(?:\s+\.3)+\.{3}(?:\s+\.{3})+\s+\.{3}(?:\s+\.{3})+', '<big_gap>', text)
    text = re.sub(r'\.3(?:\s+\.3)+\.{3}(?:\s+\.{3})+', '<big_gap>', text)
    text = re.sub(r'\.{3}(?:\s+\.{3})+', '<big_gap>', text)

    # Simple gap patterns
    text = re.sub(r'xx', '<gap>', text)
    text = re.sub(r' x ', ' <gap> ', text)
    text = re.sub(r'……', '<big_gap>', text)
    text = re.sub(r'\.\.\.\.\.\.', '<big_gap>', text)
    text = re.sub(r'…', '<big_gap>', text)
    text = re.sub(r'\.\.\.', '<big_gap>', text)
    
    # Bracketed gaps
    text = re.sub(r'\[\.\.\.+\]', '<big_gap>', text)
    text = re.sub(r'\[x+\]', '<gap>', text)
    
    if not keep_gaps:
        # Remove gaps for clean training
        text = re.sub(r'<big_gap>', '', text)
        text = re.sub(r'<gap>', '', text)

    return text

def clean_translit(text, keep_gaps=True):
    """Normalize transliteration following competition guidance."""
    if not isinstance(text, str):
        return ""
    text = normalize_subscripts(text)
    # Apply gap replacement - KEEP gaps for domain matching
    text = replace_gaps(text, keep_gaps=keep_gaps)
    # Only remove scribal markers, keep gaps
    text = re.sub(r"<<[^>]*>>", " ", text)               # errant signs
    text = re.sub(r"[˹˺]", " ", text)                    # half brackets
    text = re.sub(r"\([^)]*\)", " ", text)             # comments/erasures
    text = re.sub(r"\{([^}]*)\}", r"\1", text)         # determinatives
    text = re.sub(r"<([^>]*)>", r"\1", text)            # scribal insertions keep content
    text = re.sub(r"[!?/:·]", " ", text)                 # scribal punctuation
    text = re.sub(r"\s+", " ", text)
    return text.strip()

def clean_translation(text, has_gaps=False):
    """Clean translation, optionally keeping gap indicators"""
    if not isinstance(text, str):
        return ""
    if not has_gaps:
        text = text.replace("…", " ")
    # Keep ... if source has gaps
    text = re.sub(r"\s+", " ", text)
    return text.strip()

def filter_quality(df):
    df["src_len"] = df["transliteration"].str.split().str.len()
    df["tgt_len"] = df["translation"].str.split().str.len()
    df = df[(df["src_len"] >= 3) & (df["tgt_len"] >= 3)]
    ratio = (df["src_len"] / df["tgt_len"]).clip(upper=6)
    df = df[(ratio >= 0.2) & (ratio <= 5)]
    df = df.drop_duplicates(subset=["transliteration", "translation"])
    return df.drop(columns=["src_len", "tgt_len"])

# -----------------------------------------------------------------------------
# ADVANCED DATA ALIGNMENT (Using Sentences_Oare_FirstWord_LinNum.csv)
# -----------------------------------------------------------------------------
def load_sentence_alignment():
    """Load sentence alignment data if available"""
    sent_align_path = f"{DATA_DIR}/Sentences_Oare_FirstWord_LinNum.csv"
    if os.path.exists(sent_align_path):
        print("✓ Loading sentence alignment data...")
        return pd.read_csv(sent_align_path)
    else:
        print("⚠️  Sentence alignment file not found, using fallback alignment")
        return None

def align_with_sentence_map(df, sent_map):
    """Use explicit sentence mapping for perfect alignment
    
    This function uses the Sentences_Oare_FirstWord_LinNum.csv file which contains:
    - text_uuid: Document ID (matches oare_id in train.csv)
    - translation: The English sentence (THIS IS KEY - use it directly!)
    - first_word_transcription: First word of Akkadian sentence
    - sentence_obj_in_text: Sentence order within document
    """
    if sent_map is None:
        return None
    
    print("Building aligned dataset using sentence map translations...")
    aligned_rows = []
    
    # Create lookup for full transliterations from train.csv
    df['clean_translit_full'] = df['transliteration'].apply(lambda x: clean_translit(str(x), keep_gaps=True))
    text_lookup = df.set_index('oare_id')['clean_translit_full'].to_dict()
    
    # Group by document
    for text_id, group in sent_map.groupby('text_uuid'):
        if text_id not in text_lookup:
            continue
        
        full_akkadian = text_lookup[text_id]
        if len(full_akkadian) < 10:
            continue
        
        # Sort sentences by their order in the text
        sorted_sents = group.sort_values('sentence_obj_in_text')
        
        # Extract translations from the map (these are already sentence-aligned!)
        map_translations = [str(row.translation).strip() for _, row in sorted_sents.iterrows()]
        
        # Try to split the Akkadian text into matching chunks
        # Strategy: Use gaps and newlines as natural boundaries
        akkadian_chunks = [s.strip() for s in re.split(r'(?:<big_gap>|<gap>|\n)+', full_akkadian) if len(s.strip()) > 3]
        
        # If chunk counts match, pair them up (high confidence)
        if len(akkadian_chunks) == len(map_translations):
            for akk_chunk, eng_sent in zip(akkadian_chunks, map_translations):
                if len(akk_chunk) > 3 and len(eng_sent) > 3:
                    aligned_rows.append({
                        "transliteration": akk_chunk,
                        "translation": clean_translation(eng_sent)
                    })
        else:
            # Fallback: If counts don't match, still use map translations
            # but try to extract Akkadian by first-word matching
            for _, sent_row in sorted_sents.iterrows():
                first_word = str(sent_row.first_word_transcription).strip() if hasattr(sent_row, 'first_word_transcription') else ""
                eng_sent = str(sent_row.translation).strip()
                
                if len(first_word) > 2 and first_word in full_akkadian:
                    # Find sentence starting with this word (heuristic)
                    start_pos = full_akkadian.find(first_word)
                    if start_pos >= 0:
                        # Extract until next gap or reasonable length
                        remaining = full_akkadian[start_pos:start_pos+200]
                        end_pos = re.search(r'<big_gap>|<gap>|\n', remaining)
                        akk_sent = remaining[:end_pos.start()] if end_pos else remaining
                        
                        if len(akk_sent) > 5 and len(eng_sent) > 3:
                            aligned_rows.append({
                                "transliteration": akk_sent.strip(),
                                "translation": clean_translation(eng_sent)
                            })
    
    if aligned_rows:
        aligned_df = pd.DataFrame(aligned_rows)
        print(f"✓ Extracted {len(aligned_df)} sentence pairs from map file")
        return aligned_df
    
    return None

def load_and_align_data(filepath):
    """
    Enhanced alignment with sentence-level mapping support
    """
    df = pd.read_csv(filepath)
    print(f"Raw documents: {len(df)}")
    
    # Try to use sentence alignment map first
    sent_map = load_sentence_alignment()
    if sent_map is not None:
        aligned_df = align_with_sentence_map(df, sent_map)
        if aligned_df is not None and len(aligned_df) > 100:
            print(f"✓ Aligned using sentence map: {len(aligned_df)} examples")
            return filter_quality(aligned_df)
    
    # Fallback: Original alignment logic
    aligned_rows = []

    for _, row in df.iterrows():
        src = clean_translit(row.get("transliteration", ""), keep_gaps=True)
        tgt = clean_translation(row.get("translation", ""))

        src_lines = [s.strip() for s in src.split("\n") if s.strip()]
        tgt_sents = [t.strip() for t in re.split(r'(?<=[.!?])\s+', tgt) if t.strip()]

        if len(src_lines) == len(tgt_sents) and len(src_lines) > 1:
            for s, t in zip(src_lines, tgt_sents):
                if len(s) > 3 and len(t) > 3:
                    aligned_rows.append({"transliteration": s, "translation": t})
        else:
            merged_src = src.replace("\n", " ")
            if len(merged_src) > 3 and len(tgt) > 3:
                aligned_rows.append({"transliteration": merged_src, "translation": tgt})

    print(f"Aligned training examples (pre-filter): {len(aligned_rows)}")
    out_df = filter_quality(pd.DataFrame(aligned_rows))
    print(f"Aligned training examples (post-filter): {len(out_df)}")
    return out_df

# -----------------------------------------------------------------------------
# MINE PUBLICATIONS.CSV FOR ADDITIONAL TRAINING DATA (OPTIMIZED)
# -----------------------------------------------------------------------------
from tqdm.auto import tqdm

def mine_publications_data():
    """Extract translations from publications.csv to augment training data (Optimized)"""
    pub_path = f"{DATA_DIR}/publications.csv"
    pub_texts_path = f"{DATA_DIR}/published_texts.csv"
    
    if not os.path.exists(pub_path) or not os.path.exists(pub_texts_path):
        print("⚠️  Publications data not found, skipping augmentation")
        return pd.DataFrame()
    
    print("\n" + "="*60)
    print("MINING PUBLICATIONS FOR ADDITIONAL TRAINING DATA (FAST MODE)")
    print("="*60)
    
    # Load data
    pubs = pd.read_csv(pub_path)
    pub_texts = pd.read_csv(pub_texts_path)
    
    print(f"Total publication pages: {len(pubs)}")
    
    # OPTIMIZATION 1: Pre-filter pages that contain keywords
    # We only care about pages that explicitly mention 'translation' or 'English'
    # This reduces search space from potentially 216k -> ~10-20k relevant pages
    # Check which column contains the text content
    text_col = None
    for col in ['page_text', 'text', 'content', 'ocr_text']:
        if col in pubs.columns:
            text_col = col
            break
    
    if text_col is None:
        print("⚠️  Could not find text column in publications.csv")
        return pd.DataFrame()
    
    relevant_mask = pubs[text_col].astype(str).str.contains(r'translation|English', case=False, regex=True)
    pubs_filtered = pubs[relevant_mask].copy()
    print(f"Pages with translation keywords: {len(pubs_filtered)}")
    
    if len(pubs_filtered) == 0:
        print("⚠️  No pages contain translation keywords")
        return pd.DataFrame()
    
    augmented_data = []
    
    # OPTIMIZATION 2: Limit texts and use progress bar
    # We check the top 1500 candidate texts against the filtered pages
    candidates = pub_texts.head(1500)
    
    print("Searching for matches...")
    for _, pub_text in tqdm(candidates.iterrows(), total=len(candidates), desc="Mining"):
        text_id = pub_text.get("oare_id") or pub_text.get("cdli_id")
        translit = pub_text.get("transliteration", "")
        
        if pd.isna(text_id) or pd.isna(translit) or len(str(translit)) < 10:
            continue
        
        text_id_str = str(text_id)
        
        # OPTIMIZATION 3: Vectorized search on specific column only
        # Much faster than row.astype(str)
        matches = pubs_filtered[pubs_filtered[text_col].astype(str).str.contains(text_id_str, regex=False)]
        
        if not matches.empty:
            # Take the first match
            pub_content = str(matches.iloc[0].get(text_col, ''))
            
            # Extract translation using regex
            # Looks for "Translation: [English Text]" pattern
            trans_match = re.search(r'(?:translation|English|translates?)[:\s]+([A-Z][^.]{20,300}[.!?])', 
                                   pub_content, re.IGNORECASE)
            
            if trans_match:
                translation = trans_match.group(1).strip()
                augmented_data.append({
                    "transliteration": clean_translit(str(translit), keep_gaps=True),
                    "translation": clean_translation(translation, has_gaps='<gap>' in str(translit))
                })
    
    aug_df = pd.DataFrame(augmented_data)
    if len(aug_df) > 0:
        aug_df = filter_quality(aug_df)
        print(f"✓ Mined {len(aug_df)} additional training pairs from publications")
    else:
        print("⚠️  No additional pairs extracted (try adjusting regex or increasing candidates)")
    
    return aug_df

# Load main training data
train_df = load_and_align_data(f"{DATA_DIR}/train.csv")

# Mine publications for additional translations
mined_df = mine_publications_data()

# Augment with mined data if available
if len(mined_df) > 0:
    train_df = pd.concat([train_df, mined_df], ignore_index=True)
    train_df = train_df.drop_duplicates(subset=["transliteration", "translation"])
    print(f"\n✓ Total training examples after augmentation: {len(train_df)}")

# Check published texts availability for later use
print("\n" + "="*60)
print("CHECKING PUBLISHED TEXTS")
print("="*60)

pub_texts_path = f"{DATA_DIR}/published_texts.csv"
if os.path.exists(pub_texts_path):
    pub_df = pd.read_csv(pub_texts_path)
    print(f"Published texts available: {len(pub_df)}")
    print("Note: Will use these for monolingual pre-training")
else:
    print("⚠️  Published texts not found")

# Create dataset and split
dataset = Dataset.from_pandas(train_df)
dataset = dataset.train_test_split(test_size=0.05, seed=42)

print(f"\nFinal dataset:")
print(f"  Train: {len(dataset['train'])} examples")
print(f"  Validation: {len(dataset['test'])} examples")


Raw documents: 1561
✓ Loading sentence alignment data...
Building aligned dataset using sentence map translations...
✓ Extracted 51 sentence pairs from map file
Aligned training examples (pre-filter): 1561
Aligned training examples (post-filter): 1528

MINING PUBLICATIONS FOR ADDITIONAL TRAINING DATA (FAST MODE)
Total publication pages: 216602
Pages with translation keywords: 12500
Searching for matches...


Mining:   0%|          | 0/1500 [00:00<?, ?it/s]

⚠️  No additional pairs extracted (try adjusting regex or increasing candidates)

CHECKING PUBLISHED TEXTS
Published texts available: 7953
Note: Will use these for monolingual pre-training

Final dataset:
  Train: 1451 examples
  Validation: 77 examples


# A5 . Tokenization

In [6]:
print("Loading Tokenizer from:", MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

def preprocess_function(examples):
    inputs = [PREFIX + doc for doc in examples["transliteration"]]
    targets = examples["translation"]

    model_inputs = tokenizer(
        inputs, 
        max_length=MAX_LENGTH, 
        truncation=True, 
        padding="max_length" # Consistent padding helps training stability
    )
    
    labels = tokenizer(
        targets, 
        max_length=MAX_LENGTH, 
        truncation=True, 
        padding="max_length"
    )

    # Replace padding token id with -100 so it's ignored by the loss function
    labels["input_ids"] = [
        [(l if l != tokenizer.pad_token_id else -100) for l in label] 
        for label in labels["input_ids"]
    ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Process datasets
tokenized_train = dataset["train"].map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)
tokenized_val = dataset["test"].map(preprocess_function, batched=True, remove_columns=dataset["test"].column_names)

Loading Tokenizer from: /kaggle/input/models-for-dpc/pretrained_models/byt5-base


Map:   0%|          | 0/1451 [00:00<?, ? examples/s]

Map:   0%|          | 0/77 [00:00<?, ? examples/s]

# A6. Model Setup

In [7]:
print("Loading Model from:", MODEL_PATH)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)

# Data Collator handles dynamic padding during batching
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer, 
    model=model,
    label_pad_token_id=-100
)

Loading Model from: /kaggle/input/models-for-dpc/pretrained_models/byt5-base


# A6. Optional: Monolingual Pre-Training on Akkadian Texts

This step teaches the model Akkadian grammar and morphology BEFORE translation training.
Uses published_texts.csv (8,000+ Akkadian texts) with Masked Language Modeling (MLM).

Benefits:
- Model learns to handle gaps naturally
- Better understanding of Akkadian word structure
- Improves low-resource translation performance

Set ENABLE_MONO_PRETRAIN=True to enable (adds ~30min training time).

In [8]:
# Monolingual Pre-Training Configuration
ENABLE_MONO_PRETRAIN = bool(int(os.getenv("ENABLE_MONO_PRETRAIN", "1")))  # Set to 1 to enable

if ENABLE_MONO_PRETRAIN:
    print("\n" + "="*60)
    print("MONOLINGUAL PRE-TRAINING ON AKKADIAN TEXTS")
    print("="*60)
    
    pub_texts_path = f"{DATA_DIR}/published_texts.csv"
    
    if os.path.exists(pub_texts_path):
        # Load Akkadian-only texts
        pub_texts_df = pd.read_csv(pub_texts_path)
        akkadian_texts = pub_texts_df['transliteration'].dropna().astype(str).tolist()
        akkadian_texts = [clean_translit(t, keep_gaps=True) for t in akkadian_texts]
        akkadian_texts = [t for t in akkadian_texts if len(t.split()) >= 5 and len(t.split()) <= 200]
        akkadian_texts = akkadian_texts[:5000]  # Limit for time
        
        print(f"Loaded {len(akkadian_texts)} Akkadian texts for pre-training")
        
        # Simple MLM approach: Mask random spans
        from transformers import DataCollatorForSeq2Seq
        
        def create_mlm_examples(texts):
            """Create masked language modeling examples"""
            mlm_examples = []
            for text in texts:
                tokens = text.split()
                if len(tokens) < 5:
                    continue
                
                # Mask 15% of tokens
                n_mask = max(1, int(len(tokens) * 0.15))
                mask_positions = np.random.choice(len(tokens), size=n_mask, replace=False)
                
                masked_text = []
                for i, token in enumerate(tokens):
                    if i in mask_positions:
                        masked_text.append("<extra_id_0>")  # T5-style sentinel
                    else:
                        masked_text.append(token)
                
                input_text = " ".join(masked_text)
                target_text = " ".join([tokens[i] for i in mask_positions])
                
                mlm_examples.append({
                    "transliteration": input_text,
                    "translation": target_text
                })
            
            return mlm_examples
        
        mlm_data = create_mlm_examples(akkadian_texts)
        print(f"Created {len(mlm_data)} MLM training examples")
        
        # Create MLM dataset
        mlm_dataset = Dataset.from_pandas(pd.DataFrame(mlm_data))
        
        def preprocess_mlm(examples):
            inputs = [PREFIX + doc for doc in examples["transliteration"]]
            targets = examples["translation"]
            model_inputs = tokenizer(
                inputs,
                max_length=MAX_LENGTH,
                truncation=True,
                padding="max_length"
            )
            with tokenizer.as_target_tokenizer():
                labels = tokenizer(
                    targets,
                    max_length=MAX_LENGTH,
                    truncation=True,
                    padding="max_length"
                )
            model_inputs["labels"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label]
                for label in labels["input_ids"]
            ]
            return model_inputs
        
        tokenized_mlm = mlm_dataset.map(preprocess_mlm, batched=True)
        
        # Short MLM pre-training (1-2 epochs)
        mlm_args = Seq2SeqTrainingArguments(
            output_dir=f"{OUTPUT_DIR}_mlm",
            num_train_epochs=1,
            learning_rate=3e-4,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,
            fp16=True,
            save_strategy="no",
            eval_strategy="no",
            logging_steps=50,
            report_to="none"
        )
        
        mlm_trainer = Seq2SeqTrainer(
            model=model,
            args=mlm_args,
            train_dataset=tokenized_mlm,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )
        
        print("Starting monolingual pre-training (1 epoch on Akkadian texts)...")
        try:
            mlm_trainer.train()
            print("✓ Monolingual pre-training complete")
            print("Model now understands Akkadian grammar and gaps better!")
        except Exception as e:
            print(f"⚠️  MLM pre-training failed: {e}")
            print("Continuing with main training...")
    
    else:
        print("⚠️  published_texts.csv not found, skipping monolingual pre-training")
else:
    print("\n⚠️  Monolingual pre-training disabled (set ENABLE_MONO_PRETRAIN=1 to enable)")



MONOLINGUAL PRE-TRAINING ON AKKADIAN TEXTS
Loaded 5000 Akkadian texts for pre-training
Created 5000 MLM training examples


Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

  mlm_trainer = Seq2SeqTrainer(


Starting monolingual pre-training (1 epoch on Akkadian texts)...




Step,Training Loss
50,1.6891
100,1.2615
150,1.1826


✓ Monolingual pre-training complete
Model now understands Akkadian grammar and gaps better!


# A7. Training Arguments

In [9]:
# --- 5. Training Arguments (OPTIMIZED for Quality & Score 31+) ---
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,

    # --- VALIDATION STRATEGY ---
    save_strategy="no",                   # No checkpoints to save disk space
    eval_strategy="no",                   # Skip eval during training for speed
    load_best_model_at_end=False,
    
    learning_rate=3e-4,                   # Higher LR for character-level model

    # --- MEMORY-OPTIMIZED BUT EFFECTIVE ---
    per_device_train_batch_size=1,        # Memory-safe on P100/T4
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,       # Effective batch = 16
    gradient_checkpointing=True,          # Reduce memory usage
    
    num_train_epochs=12,                  # Increased for better convergence
    weight_decay=0.01,
    predict_with_generate=False,          # Save memory
    fp16=True,                            # Mixed precision training
    report_to="none",
    logging_steps=50,                     # Monitor progress

    # Quality optimizations
    label_smoothing_factor=0.1,           # Regularization
    lr_scheduler_type="cosine",           # Smooth learning rate decay
    warmup_ratio=0.08,                    # Longer warmup for stability
    generation_max_length=420,
    generation_num_beams=8
)

# A8. Trainer

In [10]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Force aggressive memory cleanup
import gc
torch.cuda.empty_cache()
gc.collect()

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

  trainer = Seq2SeqTrainer(


# A9. Execution

In [11]:
gc.collect()
torch.cuda.empty_cache()

print("Starting Training with Memory Fixes...")

# OOM-safe training wrapper
try:
    trainer.train()
except RuntimeError as e:
    if "out of memory" in str(e).lower():
        print("[WARNING] CUDA OOM detected. Attempting recovery: reducing MAX_LENGTH and accumulation.")
        # Reduce max length slightly to free memory for remaining steps
        try:
            MAX_LENGTH = max(320, int(MAX_LENGTH * 0.9))
            print(f"New MAX_LENGTH: {MAX_LENGTH}")
        except Exception:
            pass
        torch.cuda.empty_cache(); gc.collect()
        # Continue training from current state if possible
        trainer.train(resume_from_checkpoint=None)
    else:
        raise


Starting Training with Memory Fixes...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss


New MAX_LENGTH: 342


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 16.19 MiB is free. Process 3364 has 14.72 GiB memory in use. Of the allocated memory 14.19 GiB is allocated by PyTorch, and 350.15 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Evaluate on validation split with sacreBLEU and chrF AFTER training (memory-safe)
print("\n=== POST-TRAINING VALIDATION ===")
metric_bleu = evaluate.load("sacrebleu")
metric_chrf = evaluate.load("chrf")

def dedup_repeats(text: str) -> str:
    toks = text.split()
    out = []
    for t in toks:
        if len(out) >= 2 and t == out[-1] == out[-2]:
            continue
        out.append(t)
    return " ".join(out)

def postprocess_text(preds):
    out = []
    for p in preds:
        p = p.strip()
        p = re.sub(r"\s+([.,!?;:])", r"\1", p)
        p = re.sub(r"([.,!?;:])([A-Za-z])", r"\1 \2", p)
        p = dedup_repeats(p)
        if p and p[0].islower():
            p = p[0].upper() + p[1:]
        if p and p[-1] not in ".!?":
            p += "."
        p = re.sub(r"([.!?]){2,}", ".", p)
        out.append(p.strip())
    return out

val_texts = dataset["test"]["transliteration"]
val_refs = [[t] for t in dataset["test"]["translation"]]

def generate_batch(texts):
    batch_inputs = [PREFIX + doc for doc in texts]
    enc = tokenizer(batch_inputs, max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors="pt").to(model.device)
    gen = model.generate(
        **enc,
        max_length=MAX_LENGTH,
        min_length=6,
        num_beams=4,
        no_repeat_ngram_size=3,
        length_penalty=1.05,
        early_stopping=True,
    )
    return tokenizer.batch_decode(gen, skip_special_tokens=True)

preds = []
for i in range(0, len(val_texts), 8):
    preds.extend(generate_batch(val_texts[i:i+8]))

preds = postprocess_text(preds)
bleu = metric_bleu.compute(predictions=preds, references=val_refs)
chrf = metric_chrf.compute(predictions=preds, references=val_refs)
print(f"Validation BLEU: {bleu['score']:.2f}, chrF: {chrf['score']:.2f}")

# A10. Save Final Model

In [None]:
print(f"Saving model to {OUTPUT_DIR}...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print("Notebook A Complete.")

In [None]:
# A11. Optional Self-Training Augmentation (Small, OOM-Safe)
ENABLE_SELF_TRAIN = False
MAX_PSEUDO = int(os.getenv("BYT5_MAX_PSEUDO", "500"))  # keep small to avoid OOM

if ENABLE_SELF_TRAIN:
    print("\n=== SELF-TRAINING AUGMENTATION (ByT5) ===")
    pub_path = f"{DATA_DIR}/published_texts.csv"
    if os.path.exists(pub_path):
        pub_df = pd.read_csv(pub_path)
        translits = pub_df.get("transliteration", pd.Series([])).dropna().astype(str).tolist()
        translits = [clean_translit(t) for t in translits]
        translits = [t for t in translits if 5 <= len(t.split()) <= 180]
        translits = translits[:MAX_PSEUDO]
        print(f"Generating pseudo translations for {len(translits)} extra transliterations...")

        def generate_batch(texts):
            batch_inputs = [PREFIX + doc for doc in texts]
            enc = tokenizer(batch_inputs, max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors="pt").to(model.device)
            gen = model.generate(
                **enc,
                max_length=min(MAX_LENGTH, 400),
                min_length=6,
                num_beams=6,
                no_repeat_ngram_size=3,
                length_penalty=1.05,
                early_stopping=True,
            )
            return tokenizer.batch_decode(gen, skip_special_tokens=True)

        pseudo_trans = []
        for i in range(0, len(translits), 8):  # small batch to avoid OOM
            try:
                batch_preds = generate_batch(translits[i:i+8])
                pseudo_trans.extend(batch_preds)
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("[WARNING] OOM during pseudo generation; skipping remaining.")
                    break
                else:
                    raise

        # Postprocess & filter
        def dedup_repeats(text: str) -> str:
            toks = text.split()
            out = []
            for t in toks:
                if len(out) >= 2 and t == out[-1] == out[-2]:
                    continue
                out.append(t)
            return " ".join(out)
        def postprocess_text(preds):
            out = []
            for p in preds:
                p = p.strip()
                p = re.sub(r"\s+([.,!?;:])", r"\1", p)
                p = re.sub(r"([.,!?;:])([A-Za-z])", r"\1 \2", p)
                p = dedup_repeats(p)
                if p and p[0].islower():
                    p = p[0].upper() + p[1:]
                if p and p[-1] not in ".!?":
                    p += "."
                p = re.sub(r"([.!?]){2,}", ".", p)
                out.append(p.strip())
            return out

        pseudo_trans = postprocess_text(pseudo_trans)
        aug_df = pd.DataFrame({"transliteration": translits[:len(pseudo_trans)], "translation": pseudo_trans})
        aug_df["src_len"] = aug_df["transliteration"].str.split().str.len()
        aug_df["tgt_len"] = aug_df["translation"].str.split().str.len()
        ratio = (aug_df["tgt_len"] / aug_df["src_len"]).clip(upper=6)
        aug_df = aug_df[(aug_df["tgt_len"] >= 4) & (ratio >= 0.5) & (ratio <= 6)]
        aug_df = aug_df.drop(columns=["src_len", "tgt_len"])
        print(f"Pseudo pairs retained after filtering: {len(aug_df)}")

        base_train = pd.read_csv(f"{DATA_DIR}/train.csv")
        base_train = base_train.dropna(subset=["transliteration", "translation"]).astype(str)
        base_train["transliteration"] = base_train["transliteration"].map(clean_translit)
        base_train["translation"] = base_train["translation"].map(clean_translation)
        combined = pd.concat([
            base_train[["transliteration", "translation"]],
            aug_df[["transliteration", "translation"]]
        ], axis=0).drop_duplicates().reset_index(drop=True)
        print(f"Total combined training pairs: {len(combined)}")

        ds_combined = Dataset.from_pandas(combined)
        def preprocess_function_aug(examples):
            inputs = [PREFIX + ex for ex in examples["transliteration"]]
            targets = examples["translation"]
            model_inputs = tokenizer(
                inputs,
                max_length=MAX_LENGTH,
                truncation=True,
                padding="max_length"
            )
            with tokenizer.as_target_tokenizer():
                labels = tokenizer(
                    targets,
                    max_length=MAX_LENGTH,
                    truncation=True,
                    padding="max_length"
                )
            model_inputs["labels"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label]
                for label in labels["input_ids"]
            ]
            return model_inputs
        tokenized_combined = ds_combined.map(preprocess_function_aug, batched=True)

        training_args_aug = Seq2SeqTrainingArguments(
            output_dir=OUTPUT_DIR,
            save_strategy="no",
            eval_strategy="no",
            load_best_model_at_end=False,
            learning_rate=2.5e-4,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=16,
            num_train_epochs=1,  # keep short to avoid OOM/time
            fp16=True,
            report_to="none"
        )
        trainer_aug = Seq2SeqTrainer(
            model=model,
            args=training_args_aug,
            train_dataset=tokenized_combined,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )
        print("Starting second-stage training (ByT5) with augmented data...")
        try:
            trainer_aug.train()
        except RuntimeError as e:
            print(f"[WARNING] Augmentation training skipped due to error: {e}")
        print("Augmentation stage complete.")

        print(f"Saving augmented model to {OUTPUT_DIR}...")
        trainer_aug.save_model(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)
    else:
        print("published_texts.csv not found; skipping self-training.")