In [1]:
import urllib.request
import zipfile
import os

url = "https://object.pouta.csc.fi/OPUS-TED2020/v1/moses/en-sw.txt.zip"
output_path = "en-sw.txt.zip"

# Download the zip file
urllib.request.urlretrieve(url, output_path)

# Extract it
with zipfile.ZipFile(output_path, 'r') as zip_ref:
    zip_ref.extractall("ted2020_en_sw")

print("✅ Downloaded and extracted to 'ted2020_en_sw'")


✅ Downloaded and extracted to 'ted2020_en_sw'


In [2]:
en_file_path = "ted2020_en_sw/TED2020.en-sw.en"
sw_file_path = "ted2020_en_sw/TED2020.en-sw.sw"

# Count lines (each line = 1 sentence)
with open(en_file_path, "r", encoding="utf-8") as en_file:
    en_lines = en_file.readlines()

with open(sw_file_path, "r", encoding="utf-8") as sw_file:
    sw_lines = sw_file.readlines()

# Safety check (in case of mismatch)
assert len(en_lines) == len(sw_lines), "Mismatch between EN and SW sentence count"

print(f"✅ Total sentence pairs: {len(en_lines)}")


✅ Total sentence pairs: 9745


In [5]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import collections
from datasets import Dataset, DatasetDict, load_dataset

# ✅ Language map
languages = {
    "yo": "Yoruba",
    "ar": "Arabic",
    "zh": "Chinese",
    "ru": "Russian",
    "hi": "Hindi",
    "ja": "Japanese",
    "sw": "Swahili",
    "bn": "Bengali",
    "tr": "Turkish"
}

# ✅ Data quality filter function
def is_valid_translation_pair(en_text, target_text):
    """Filter out low-quality translation pairs"""
    if not en_text or not target_text:
        return False
    
    # Remove whitespace for comparison
    en_clean = en_text.strip()
    target_clean = target_text.strip()
    
    # Filter out identical pairs (likely technical strings)
    if en_clean == target_clean:
        return False
    
    # Filter out very short pairs (less than 3 words)
    if len(en_clean.split()) < 3 or len(target_clean.split()) < 2:
        return False
    
    # Filter out technical/UI strings
    technical_keywords = [
        'html', 'xml', 'mozilla', 'utf-8', 'iso-8859', 'ascii',
        'preferences', 'workspace', 'desktop file', 'invalid',
        'error', 'warning', 'debug', 'log', 'config'
    ]
    
    if any(keyword in en_clean.lower() for keyword in technical_keywords):
        return False
    
    # Filter out very long sentences (likely corrupted)
    if len(en_clean) > 200 or len(target_clean) > 200:
        return False
    
    # Filter out sentences that are mostly punctuation or numbers
    if len(''.join(c for c in en_clean if c.isalpha())) < 10:
        return False
    
    return True

loaded_datasets = {}  # lang_code: Dataset
min_count = float("inf")

# ✅ Load Swahili manually
def load_ted2020_en_sw(path="ted2020_en_sw"):
    """Load TED2020 English-Swahili dataset"""
    try:
        with open(os.path.join(path, "TED2020.en-sw.en"), encoding="utf-8") as f_en, \
             open(os.path.join(path, "TED2020.en-sw.sw"), encoding="utf-8") as f_sw:
            en_lines = f_en.readlines()
            sw_lines = f_sw.readlines()
        
        # Filter valid pairs during loading
        valid_pairs = []
        for en, sw in zip(en_lines, sw_lines):
            en_clean = en.strip()
            sw_clean = sw.strip()
            if is_valid_translation_pair(en_clean, sw_clean):
                valid_pairs.append({"en": en_clean, "sw": sw_clean})
        
        data = {"translation": valid_pairs}
        print(f"📁 Loaded {len(valid_pairs)} valid Swahili pairs (filtered from {len(en_lines)} total)")
        return Dataset.from_dict(data)
    except FileNotFoundError:
        print("⚠️  Swahili TED2020 files not found, skipping Swahili")
        return None

# ✅ Load datasets and find minimum size for balancing
print("📦 Loading datasets...")
for code in languages:
    print(f"Loading {languages[code]} ({code})...")
    
    if code == "sw":
        ds = load_ted2020_en_sw()
        if ds is None:
            continue
    else:
        try:
            # Try both directions for OPUS100
            try:
                ds = load_dataset("opus100", f"en-{code}")["train"]
            except:
                ds = load_dataset("opus100", f"{code}-en")["train"]
            
            # Filter the dataset for quality
            def filter_dataset(example):
                en_text = example["translation"]["en"]
                target_text = example["translation"][code]
                return is_valid_translation_pair(en_text, target_text)
            
            print(f"  Original size: {len(ds)}")
            ds = ds.filter(filter_dataset)
            print(f"  After filtering: {len(ds)}")
            
        except Exception as e:
            print(f"  ❌ Failed to load {code}: {e}")
            continue
    
    loaded_datasets[code] = ds
    min_count = min(min_count, len(ds))
    print(f"  ✅ Loaded {len(ds)} valid pairs")

print(f"\n📊 Balancing all datasets to {min_count} sentence pairs")

# ✅ Combine and clean datasets
combined_data = {"translation": [], "language": []}
total_added = 0

for code, ds in loaded_datasets.items():
    print(f"Processing {languages[code]}...")
    
    # Shuffle and sample
    sampled = ds.shuffle(seed=42).select(range(min_count))
    added_count = 0
    
    for row in sampled:
        en_text = row["translation"]["en"].strip()
        tgt_text = row["translation"][code].strip()
        
        # Double-check quality (in case some slipped through)
        if is_valid_translation_pair(en_text, tgt_text):
            combined_data["translation"].append({
                "en": en_text,
                code: tgt_text
            })
            combined_data["language"].append(code)
            added_count += 1
    
    print(f"  ✅ Added {added_count} pairs")
    total_added += added_count

print(f"\n📈 Total dataset size: {total_added} translation pairs")

# ✅ Create Hugging Face Dataset and split
full_dataset = Dataset.from_dict(combined_data)
split = full_dataset.train_test_split(test_size=0.1, seed=42)
final_dataset = DatasetDict({
    "train": split["train"],
    "test": split["test"]
})

# ✅ Save to disk
final_dataset.save_to_disk("balanced_mt_dataset")
print("✅ Final dataset saved to 'balanced_mt_dataset'")

# ✅ Quality check - show first 10 examples
print("\n🔍 Quality Check - First 10 examples:")
for i in range(min(10, len(final_dataset["train"]))):
    example = final_dataset["train"][i]
    lang = example["language"]
    en_text = example["translation"]["en"]
    target_text = example["translation"][lang]
    
    # Truncate long texts for display
    en_display = en_text[:60] + "..." if len(en_text) > 60 else en_text
    target_display = target_text[:60] + "..." if len(target_text) > 60 else target_text
    
    print(f"{i+1:2d}. {lang.upper():3s} | EN: {en_display}")
    print(f"    {' '*3} | {lang.upper():2s}: {target_display}")
    print()

# ✅ Dataset statistics
print("📊 Dataset Statistics:")
print(f"Train size: {len(final_dataset['train'])}")
print(f"Test size: {len(final_dataset['test'])}")

lang_counts = collections.Counter(combined_data["language"])
for code, count in lang_counts.items():
    print(f"{languages[code]:10s}: {count:,} pairs")

# ✅ Visualize balanced language distribution
plt.figure(figsize=(12, 6))
lang_names = [languages[k] for k in lang_counts.keys()]
counts = list(lang_counts.values())

sns.barplot(x=lang_names, y=counts, palette="viridis")
plt.title("Balanced Sentence Distribution Across Languages\n(After Quality Filtering)")
plt.xlabel("Language")
plt.ylabel("Translation Pairs")
plt.xticks(rotation=45)

# Add count labels on bars
for i, count in enumerate(counts):
    plt.text(i, count + max(counts)*0.01, f'{count:,}', 
             ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

# ✅ Additional analysis: Average sentence lengths
print("\n📏 Average Sentence Lengths:")
lang_lengths = {code: {'en': [], 'target': []} for code in languages.keys()}

for example in final_dataset["train"]:
    lang = example["language"]
    if lang in lang_lengths:
        en_len = len(example["translation"]["en"].split())
        target_len = len(example["translation"][lang].split())
        lang_lengths[lang]['en'].append(en_len)
        lang_lengths[lang]['target'].append(target_len)

for code, lengths in lang_lengths.items():
    if lengths['en']:  # Only if we have data for this language
        avg_en = sum(lengths['en']) / len(lengths['en'])
        avg_target = sum(lengths['target']) / len(lengths['target'])
        print(f"{languages[code]:10s}: EN={avg_en:.1f} words, {code.upper()}={avg_target:.1f} words")

print("\n✅ Dataset creation complete!")
print("💡 Usage example:")
print("from datasets import load_from_disk")
print("dataset = load_from_disk('balanced_mt_dataset')")
print("print(dataset['train'][0])")

ModuleNotFoundError: No module named 'matplotlib'

In [7]:
# Save train set
final_dataset["train"].save_to_disk("balanced_mt_dataset/train")

# Save test set
final_dataset["test"].save_to_disk("balanced_mt_dataset/test")


Saving the dataset (1/1 shards): 100%|██████████| 34376/34376 [00:00<00:00, 144270.24 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3820/3820 [00:00<00:00, 177533.73 examples/s]


In [1]:
from datasets import Dataset, DatasetDict, load_dataset
from datasets import load_from_disk

final_dataset = DatasetDict({
    "train": load_from_disk("balanced_mt_dataset/train"),
    "test": load_from_disk("balanced_mt_dataset/test")
})

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Check what's actually in your Yoruba dataset
yo_dataset = loaded_datasets["yo"]
print("First 5 Yoruba examples:")
for i in range(5):
    print(f"EN: {yo_dataset[i]['translation']['en']}")
    print(f"YO: {yo_dataset[i]['translation']['yo']}")
    print("---")

# Check other languages too
for code in ['ar', 'zh', 'hi']:
    ds = loaded_datasets[code]
    print(f"\nFirst example from {code}:")
    print(f"EN: {ds[0]['translation']['en']}")
    print(f"{code.upper()}: {ds[0]['translation'][code]}")

NameError: name 'loaded_datasets' is not defined

In [2]:
# Check first 10 examples to see the variety
for i in range(10):
    example = final_dataset["train"][i]
    lang = example["language"]
    en_text = example["translation"]["en"]
    target_text = example["translation"][lang]
    print(f"{i}: {lang.upper()} | EN: {en_text[:50]}... | {lang.upper()}: {target_text[:50]}...")

0: ZH | EN: In paragraph 1.50:... | ZH: 3. 在第1.50段:...
1: BN | EN: Would you like to send meeting invitations to part... | BN: আপনি কি অংশগ্রহণকারীদেরকে সভার আমন্ত্রনপত্র প্রেরণ...
2: RU | EN: As commanding officer, it's my job to interpret th... | RU: Капитан приказал нам вернуться за ним....
3: AR | EN: Would you turn back?... | AR: -هل لنا أن نعود؟...
4: JA | EN: I promise them to you if you will do me the favor ... | JA: 願いを聞いてもらえるなら 約束する...
5: SW | EN: I was no different.... | SW: Sikuwa tofauti...
6: ZH | EN: Girl, you could die at 40 from ajax poisoning.... | ZH: 姑娘 你可能40岁就死于清洁剂中毒...
7: SW | EN: In two generations, those had produced 3,800 grand... | SW: Ndani ya vizazi viwili, hao walizalisha wajukuu 3,...
8: RU | EN: UNIFEM has not devised an adequate tracking system... | RU: Необходимые системы контроля для этого ЮНИФЕМ пока...
9: AR | EN: I'm still analyzing this, but for the most part, i... | AR: مازلت أحلل هذا لكن الجزء الأكبر يبدو لي كمعلومات ح...


In [3]:
print("Train size:", len(final_dataset["train"]))
print("Test size:", len(final_dataset["test"]))


Train size: 34376
Test size: 3820


In [4]:
final_dataset["train"][1]

{'translation': {'ar': None,
  'bn': 'আপনি কি অংশগ্রহণকারীদেরকে সভার আমন্ত্রনপত্র প্রেরণ করতে চান?',
  'en': 'Would you like to send meeting invitations to participants?',
  'hi': None,
  'ja': None,
  'ru': None,
  'sw': None,
  'tr': None,
  'yo': None,
  'zh': None},
 'language': 'bn'}

In [None]:
# import os
# import csv
# import torch
# import warnings
# import gc
# from datasets import load_from_disk
# from transformers import (
#     EncoderDecoderModel,
#     AutoTokenizer,
#     AutoModelForSeq2SeqLM,
#     DataCollatorForSeq2Seq,
#     Seq2SeqTrainer,
#     Seq2SeqTrainingArguments,
#     logging as hf_logging,
#     MBartForConditionalGeneration,
#     MBart50TokenizerFast,
#     T5ForConditionalGeneration,
#     T5TokenizerFast
# )
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
# import numpy as np

# # Suppress warnings
# hf_logging.set_verbosity_error()
# warnings.filterwarnings("ignore")

# # Custom tokenizer settings - 3 sizes × 3 types = 9 combinations
# tokenizer_sizes = ["small", "medium", "large"]
# tokenizer_types = ["hf_bpe_hf", "hf_wordpiece_hf", "sp_unigram_hf"]

# # Single multilingual model configuration
# MODEL_CONFIG = {
#     "model_name": "facebook/mbart-large-50-many-to-many-mmt",
#     "type": "mbart",
#     "description": "Multilingual mBART model with custom tokenizers"
# }

# # Languages for evaluation (model handles all simultaneously)
# languages = {
#     "yo": "Yoruba",
#     "ar": "Arabic", 
#     "zh": "Chinese",
#     "ru": "Russian",
#     "hi": "Hindi",
#     "ja": "Japanese",
#     "swa": "Swahili",
#     "bn": "Bengali",
#     "tr": "Turkish"
# }
# SRC_LANG = "en"

# # Load complete multilingual dataset
# dataset_path = "balanced_mt_dataset"
# print("📦 Loading complete multilingual dataset...")
# full_dataset = load_from_disk(dataset_path)
# print(f"✅ Dataset loaded: {len(full_dataset['train'])} train, {len(full_dataset['test'])} test")
# print(f"🌍 All languages included: {list(languages.keys())}")

# # Enhanced BLEU computation for multilingual evaluation
# def compute_multilingual_bleu(eval_pred):
#     """Multilingual BLEU computation across all target languages"""
#     predictions, labels = eval_pred
    
#     if len(predictions.shape) == 3:
#         predictions = np.argmax(predictions, axis=-1)
    
#     decoded_preds = []
#     decoded_labels = []
    
#     for pred, label in zip(predictions, labels):
#         # Replace -100 with pad token for decoding
#         label = np.where(label != -100, label, tokenizer.pad_token_id)
        
#         decoded_pred = tokenizer.decode(pred, skip_special_tokens=True).strip()
#         decoded_label = tokenizer.decode(label, skip_special_tokens=True).strip()
        
#         decoded_preds.append(decoded_pred)
#         decoded_labels.append(decoded_label)
    
#     # Compute BLEU with smoothing
#     smoothing = SmoothingFunction().method1
#     bleu_scores = []
#     exact_matches = 0
    
#     for pred, label in zip(decoded_preds, decoded_labels):
#         if not pred.strip() or not label.strip():
#             bleu_scores.append(0.0)
#             continue
            
#         pred_tokens = pred.split()
#         label_tokens = label.split()
        
#         # Check exact match
#         if pred.lower().strip() == label.lower().strip():
#             exact_matches += 1
        
#         if len(pred_tokens) == 0 or len(label_tokens) == 0:
#             bleu_scores.append(0.0)
#             continue
        
#         try:
#             bleu = sentence_bleu(
#                 [label_tokens], 
#                 pred_tokens,
#                 smoothing_function=smoothing,
#                 weights=(0.25, 0.25, 0.25, 0.25)
#             )
#             bleu_scores.append(bleu)
#         except:
#             bleu_scores.append(0.0)
    
#     avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
#     exact_match_ratio = exact_matches / len(decoded_preds) if decoded_preds else 0.0
    
#     return {
#         "bleu": avg_bleu,
#         "exact_match": exact_match_ratio,
#         "avg_pred_length": np.mean([len(p.split()) for p in decoded_preds if p.strip()]) if decoded_preds else 0.0,
#         "avg_label_length": np.mean([len(l.split()) for l in decoded_labels if l.strip()]) if decoded_labels else 0.0,
#         "empty_predictions": sum(1 for p in decoded_preds if not p.strip()) / len(decoded_preds) if decoded_preds else 0.0
#     }

# def setup_custom_tokenizer_for_mbart(tokenizer_path):
#     """Setup custom tokenizer with proper mBART special tokens"""
#     # Load custom tokenizer
#     tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    
#     # Define mBART special tokens
#     special_tokens = {
#         'bos_token': '<s>',
#         'eos_token': '</s>',
#         'sep_token': '</s>',
#         'pad_token': '<pad>',
#         'unk_token': '<unk>',
#         'mask_token': '<mask>'
#     }
    
#     # Add missing special tokens
#     tokens_to_add = {}
#     for token_name, token_value in special_tokens.items():
#         if getattr(tokenizer, token_name, None) is None:
#             tokens_to_add[token_name] = token_value
    
#     if tokens_to_add:
#         tokenizer.add_special_tokens(tokens_to_add)
    
#     # Ensure we have the required tokens
#     assert tokenizer.bos_token is not None, "BOS token is required"
#     assert tokenizer.eos_token is not None, "EOS token is required"
#     assert tokenizer.pad_token is not None, "PAD token is required"
    
#     return tokenizer

# def configure_model_for_custom_tokenizer(model, tokenizer):
#     """Configure mBART model for custom tokenizer"""
#     # Resize embeddings
#     print("🔄 Resizing model embeddings to match custom tokenizer...")
#     model.resize_token_embeddings(len(tokenizer))
    
#     # Set token IDs
#     model.config.decoder_start_token_id = tokenizer.bos_token_id
#     model.config.pad_token_id = tokenizer.pad_token_id
#     model.config.bos_token_id = tokenizer.bos_token_id
#     model.config.eos_token_id = tokenizer.eos_token_id
#     model.config.sep_token_id = tokenizer.eos_token_id  # mBART uses EOS as SEP
#     model.config.forced_eos_token_id = tokenizer.eos_token_id
    
#     # Update generation config if it exists
#     if hasattr(model, 'generation_config') and model.generation_config is not None:
#         model.generation_config.decoder_start_token_id = tokenizer.bos_token_id
#         model.generation_config.pad_token_id = tokenizer.pad_token_id
#         model.generation_config.bos_token_id = tokenizer.bos_token_id
#         model.generation_config.eos_token_id = tokenizer.eos_token_id
#         model.generation_config.forced_eos_token_id = tokenizer.eos_token_id
    
#     return model

# # Setup logging
# log_dir = "./MT_models_multilingual_custom_tokenizers"
# os.makedirs(log_dir, exist_ok=True)
# results_log_file = os.path.join(log_dir, "multilingual_custom_tokenizer_results.csv")

# if not os.path.exists(results_log_file):
#     with open(results_log_file, 'w', newline='', encoding='utf-8') as csvfile:
#         writer = csv.writer(csvfile)
#         writer.writerow(["Model_ID", "Base_Model", "Tokenizer_Size", "Tokenizer_Type", 
#                         "Total_Languages", "Total_Train_Samples", "Total_Test_Samples", 
#                         "Custom_Vocab_Size", "Overall_BLEU", "Overall_Exact_Match", 
#                         "Avg_Pred_Length", "Avg_Label_Length", "Empty_Predictions", 
#                         "Training_Status", "Notes"])

# print("🚀 Starting multilingual training with 9 CUSTOM TOKENIZERS...")
# print(f"📊 Training approach: ONE model per tokenizer handling ALL {len(languages)} languages")

# # Main training loop: 3 sizes × 3 types = 9 models total
# for size in tokenizer_sizes:
#     for tok_type in tokenizer_types:
#         # Path to custom tokenizer
#         tokenizer_path = f"vocab_final/vocab_final{size}/{tok_type}"
#         model_id = f"multilingual_{size}_{tok_type}"
        
#         print(f"\n{'='*100}")
#         print(f"🚀 Training Model {tokenizer_sizes.index(size)*3 + tokenizer_types.index(tok_type) + 1}/9")
#         print(f"🔧 Custom Tokenizer: {size}_{tok_type}")
#         print(f"📦 Base Model: {MODEL_CONFIG['model_name']}")
#         print(f"🌍 Target Languages: {list(languages.keys())} (ALL SIMULTANEOUSLY)")
        
#         # Initialize variables
#         model = None
#         tokenizer = None
#         trainer = None
        
#         try:
#             # Load and setup custom tokenizer
#             print(f"🔧 Loading custom tokenizer: {tokenizer_path}")
#             tokenizer = setup_custom_tokenizer_for_mbart(tokenizer_path)
            
#             print(f"✅ Custom tokenizer loaded successfully!")
#             print(f"📊 Custom vocab size: {len(tokenizer)}")
#             print(f"🔑 Special tokens - BOS: {tokenizer.bos_token_id}, EOS: {tokenizer.eos_token_id}, PAD: {tokenizer.pad_token_id}")
            
#             # Load base multilingual model
#             print("🤖 Loading base multilingual mBART model...")
#             model = MBartForConditionalGeneration.from_pretrained(MODEL_CONFIG["model_name"])
            
#             # Configure model for custom tokenizer
#             model = configure_model_for_custom_tokenizer(model, tokenizer)
            
#             print(f"✅ Model configured with custom tokenizer!")
            
#             # Use COMPLETE multilingual dataset (all languages together)
#             print(f"📊 Using complete multilingual dataset:")
#             print(f"   • Train samples: {len(full_dataset['train'])}")
#             print(f"   • Test samples: {len(full_dataset['test'])}")
#             print(f"   • Languages: {len(languages)} ({list(languages.keys())})")
            
#             # FIXED: Preprocessing function for multilingual data with custom tokenizer
#             def preprocess_multilingual_custom(examples):
#                 """Preprocess multilingual data with custom tokenizer - FIXED VERSION"""
#                 sources = []
#                 targets = []
                
#                 # Handle both single examples and batches
#                 if not isinstance(examples["translation"], list):
#                     examples = {
#                         "translation": [examples["translation"]], 
#                         "language": [examples["language"]]
#                     }
                
#                 for translation, lang in zip(examples["translation"], examples["language"]):
#                     if isinstance(translation, dict) and lang in languages:
#                         source = translation.get(SRC_LANG, "")
#                         target = translation.get(lang, "")
                        
#                         if source and target:
#                             # Add language information to source for multilingual training
#                             source_formatted = f"translate English to {languages[lang]}: {source}"
#                             sources.append(source_formatted)
#                             targets.append(target)
                
#                 if not sources or not targets:
#                     return {"input_ids": [], "attention_mask": [], "labels": []}
                
#                 # Tokenize with custom tokenizer
#                 max_length = 128
                
#                 model_inputs = tokenizer(
#                     sources,
#                     max_length=max_length,
#                     truncation=True,
#                     padding="max_length",
#                     return_tensors=None,
#                     return_token_type_ids=False
#                 )
#                 model_inputs.pop("token_type_ids", None)

#                 labels = tokenizer(
#                     targets,
#                     max_length=max_length,
#                     truncation=True,
#                     padding="max_length",
#                     return_tensors=None,
#                     return_token_type_ids=False
#                 )
#                 labels.pop("token_type_ids", None)

#                 # CRITICAL FIX: Ensure EOS token is present in labels
#                 processed_labels = []
#                 for label_seq in labels["input_ids"]:
#                     # Remove padding tokens first
#                     label_tokens = [token for token in label_seq if token != tokenizer.pad_token_id]
                    
#                     # Ensure EOS token is at the end (required for mBART)
#                     if not label_tokens or label_tokens[-1] != tokenizer.eos_token_id:
#                         label_tokens.append(tokenizer.eos_token_id)
                    
#                     # Pad to max_length and replace padding with -100
#                     while len(label_tokens) < max_length:
#                         label_tokens.append(-100)
                    
#                     # Truncate if too long
#                     label_tokens = label_tokens[:max_length]
#                     processed_labels.append(label_tokens)
                
#                 model_inputs["labels"] = processed_labels
#                 return model_inputs
            
#             # Preprocess complete multilingual dataset
#             print("⚙️  Preprocessing complete multilingual dataset with custom tokenizer...")
#             processed_dataset = full_dataset.map(
#                 preprocess_multilingual_custom,
#                 batched=True,
#                 remove_columns=full_dataset["train"].column_names,
#                 desc=f"Preprocessing with {size}_{tok_type}",
#                 batch_size=100
#             )
            
#             train_dataset = processed_dataset["train"]
#             eval_dataset = processed_dataset["test"]
            
#             # Filter out empty examples
#             def filter_empty(example):
#                 return (len(example["input_ids"]) > 0 and 
#                        len(example["labels"]) > 0 and
#                        any(label != -100 for label in example["labels"]))  # Ensure non-empty labels
            
#             train_dataset = train_dataset.filter(filter_empty)
#             eval_dataset = eval_dataset.filter(filter_empty)
            
#             print(f"✅ Preprocessed multilingual dataset:")
#             print(f"   • Train samples: {len(train_dataset)}")
#             print(f"   • Eval samples: {len(eval_dataset)}")
            
#             if len(train_dataset) == 0 or len(eval_dataset) == 0:
#                 print("❌ No valid samples after preprocessing")
#                 continue
            
#             # Setup training directory
#             output_dir = f"./MT_models_multilingual_custom_tokenizers/{model_id}"
#             os.makedirs(output_dir, exist_ok=True)
            
#             # Training arguments for multilingual model
#             training_args = Seq2SeqTrainingArguments(
#                 output_dir=output_dir,
#                 num_train_epochs=3,
#                 per_device_train_batch_size=4,  # Reduced batch size for stability
#                 per_device_eval_batch_size=4,
#                 gradient_accumulation_steps=4,  # Increased to maintain effective batch size
#                 learning_rate=3e-5,  # Slightly lower learning rate
#                 weight_decay=0.01,
#                 warmup_steps=500,
#                 eval_strategy="epoch",
#                 save_strategy="epoch",
#                 save_total_limit=2,
#                 logging_steps=100,
#                 report_to="none",
#                 predict_with_generate=True,
#                 generation_max_length=128,
#                 generation_num_beams=2,
#                 fp16=torch.cuda.is_available(),
#                 load_best_model_at_end=True,
#                 metric_for_best_model="bleu",
#                 greater_is_better=True,
#                 dataloader_num_workers=0,
#                 remove_unused_columns=False,
#                 ignore_data_skip=True,  # Skip problematic data points
#             )
            
#             # FIXED: Data collator with proper configuration
#             data_collator = DataCollatorForSeq2Seq(
#                 tokenizer=tokenizer,
#                 model=model,
#                 padding=True,
#                 pad_to_multiple_of=8 if training_args.fp16 else None,
#                 return_tensors="pt"
#             )
            
#             # Trainer
#             trainer = Seq2SeqTrainer(
#                 model=model,
#                 args=training_args,
#                 train_dataset=train_dataset,
#                 eval_dataset=eval_dataset,
#                 tokenizer=tokenizer,
#                 data_collator=data_collator,
#                 compute_metrics=compute_multilingual_bleu
#             )
            
#             # Train multilingual model
#             print("🏋️  Starting multilingual training with custom tokenizer...")
#             trainer.train()
            
#             # Evaluate
#             print("📊 Final multilingual evaluation...")
#             eval_results = trainer.evaluate()
            
#             # Save model and custom tokenizer
#             print("💾 Saving multilingual model and custom tokenizer...")
#             trainer.save_model()
#             tokenizer.save_pretrained(output_dir)
            
#             # Log results
#             overall_bleu = eval_results.get("eval_bleu", 0.0)
#             overall_exact_match = eval_results.get("eval_exact_match", 0.0)
#             avg_pred_len = eval_results.get("eval_avg_pred_length", 0.0)
#             avg_label_len = eval_results.get("eval_avg_label_length", 0.0)
#             empty_preds = eval_results.get("eval_empty_predictions", 0.0)
            
#             with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
#                 writer = csv.writer(csvfile)
#                 writer.writerow([
#                     model_id, MODEL_CONFIG["model_name"], size, tok_type,
#                     len(languages), len(train_dataset), len(eval_dataset), len(tokenizer),
#                     round(overall_bleu, 4), round(overall_exact_match, 4), 
#                     round(avg_pred_len, 2), round(avg_label_len, 2),
#                     round(empty_preds, 4), "SUCCESS", 
#                     f"Multilingual model with custom {size}_{tok_type} tokenizer"
#                 ])
            
#             print(f"✅ Completed: {model_id}")
#             print(f"📈 Overall BLEU Score: {overall_bleu:.4f}")
#             print(f"🎯 Overall Exact Match: {overall_exact_match:.4f}")
            
#             # Quick translation tests for different languages
#             print(f"\n🧪 Quick multilingual translation tests:")
#             test_input = "Hello, how are you today?"
            
#             for test_lang in ["ar", "zh", "hi"]:  # Test 3 languages
#                 formatted_input = f"translate English to {languages[test_lang]}: {test_input}"
#                 inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, max_length=128, truncation=True)

#                 # Remove token_type_ids and move to same device
#                 inputs = {k: v for k, v in inputs.items() if k != "token_type_ids"}
#                 device = next(model.parameters()).device
#                 inputs = {k: v.to(device) for k, v in inputs.items()}

#                 with torch.no_grad():
#                     outputs = model.generate(
#                         **inputs,
#                         max_length=50,
#                         num_beams=2,
#                         early_stopping=True,
#                         do_sample=False,
#                         forced_eos_token_id=tokenizer.eos_token_id
#                     )

                
#                 translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
#                 print(f"  {test_lang} ({languages[test_lang]}): {translation}")
                
#         except Exception as e:
#             print(f"❌ Failed to train {model_id}: {str(e)}")
#             import traceback
#             traceback.print_exc()
            
#             with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
#                 writer = csv.writer(csvfile)
#                 writer.writerow([
#                     model_id, MODEL_CONFIG.get("model_name", ""), size, tok_type,
#                     len(languages), 0, 0, 0, 0, 0, 0, 0, 0, "TRAINING_FAILED", str(e)[:100]
#                 ])
        
#         finally:
#             # Cleanup
#             if model is not None:
#                 del model
#             if trainer is not None:
#                 del trainer
#             torch.cuda.empty_cache()
#             gc.collect()

# print("\n🎉 Multilingual training with custom tokenizers completed!")
# print(f"📋 Results saved to: {results_log_file}")
# print(f"🔢 Total models trained: 9 (3 sizes × 3 types)")
# print(f"🌍 Each model handles all {len(languages)} languages simultaneously")

📦 Loading complete multilingual dataset...
✅ Dataset loaded: 34376 train, 3820 test
🌍 All languages included: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr']
🚀 Starting multilingual training with 9 CUSTOM TOKENIZERS...
📊 Training approach: ONE model per tokenizer handling ALL 9 languages

🚀 Training Model 1/9
🔧 Custom Tokenizer: small_hf_bpe_hf
📦 Base Model: facebook/mbart-large-50-many-to-many-mmt
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalsmall/hf_bpe_hf
✅ Custom tokenizer loaded successfully!
📊 Custom vocab size: 15002
🔑 Special tokens - BOS: 15000, EOS: 15001, PAD: 0
🤖 Loading base multilingual mBART model...
🔄 Resizing model embeddings to match custom tokenizer...
✅ Model configured with custom tokenizer!
📊 Using complete multilingual dataset:
   • Train samples: 34376
   • Test samples: 3820
   • Languages: 9 (['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'])
⚙️

Preprocessing with large_hf_bpe_hf: 100%|██████████| 3820/3820 [00:01<00:00, 2154.90 examples/s]
Filter: 100%|██████████| 3386/3386 [00:00<00:00, 8621.32 examples/s]


✅ Preprocessed multilingual dataset:
   • Train samples: 30566
   • Eval samples: 3386
🏋️  Starting multilingual training with custom tokenizer...
{'loss': 11.1783, 'grad_norm': 5.359941005706787, 'learning_rate': 5.82e-06, 'epoch': 0.05234231876472128}
{'loss': 9.1278, 'grad_norm': 9.098174095153809, 'learning_rate': 1.182e-05, 'epoch': 0.10468463752944256}
{'loss': 8.1311, 'grad_norm': 6.191810607910156, 'learning_rate': 1.782e-05, 'epoch': 0.15702695629416383}
{'loss': 7.714, 'grad_norm': 5.863392353057861, 'learning_rate': 2.3820000000000002e-05, 'epoch': 0.2093692750588851}
{'loss': 7.3491, 'grad_norm': 5.360826015472412, 'learning_rate': 2.982e-05, 'epoch': 0.26171159382360637}
{'loss': 6.9923, 'grad_norm': 5.599084377288818, 'learning_rate': 2.9443913625071663e-05, 'epoch': 0.31405391258832765}
{'loss': 6.8377, 'grad_norm': 6.245006084442139, 'learning_rate': 2.8870628702465126e-05, 'epoch': 0.36639623135304894}
{'loss': 6.6141, 'grad_norm': 4.975226402282715, 'learning_rate': 2

Traceback (most recent call last):
  File "C:\Users\User 4\AppData\Local\Temp\ipykernel_26264\3434637004.py", line 409, in <module>
    with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\IPython\core\interactiveshell.py", line 324, in _modified_open
    return io_open(file, *args, **kwargs)
PermissionError: [Errno 13] Permission denied: './MT_models_multilingual_custom_tokenizers\\multilingual_custom_tokenizer_results.csv'


PermissionError: [Errno 13] Permission denied: './MT_models_multilingual_custom_tokenizers\\multilingual_custom_tokenizer_results.csv'

In [None]:
#-----Working but slow
# import os
# import csv
# import torch
# import warnings
# import gc
# from datasets import load_from_disk
# from transformers import (
#     EncoderDecoderModel,
#     AutoTokenizer,
#     AutoModelForSeq2SeqLM,
#     DataCollatorForSeq2Seq,
#     Seq2SeqTrainer,
#     Seq2SeqTrainingArguments,
#     logging as hf_logging,
#     MBartForConditionalGeneration,
#     MBart50TokenizerFast,
#     T5ForConditionalGeneration,
#     T5TokenizerFast,
#     BertTokenizerFast,
#     GPT2TokenizerFast,
#     BartForConditionalGeneration
# )
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
# import numpy as np

# # Suppress warnings
# hf_logging.set_verbosity_error()
# warnings.filterwarnings("ignore")

# # FIXED: Model-Tokenizer Compatibility Mapping
# MODEL_TOKENIZER_CONFIGS = {
#     "hf_bpe_hf": {
#         "base_model": "facebook/bart-large",  # BART uses BPE
#         "model_class": BartForConditionalGeneration,
#         "description": "BART with BPE tokenization"
#     },
#     "hf_wordpiece_hf": {
#         "base_model": "google/mt5-base",  # mT5 can work with WordPiece
#         "model_class": T5ForConditionalGeneration,
#         "description": "mT5 with WordPiece tokenization"
#     },
#     "sp_unigram_hf": {
#         "base_model": "google/mt5-base",  # mT5 uses SentencePiece Unigram
#         "model_class": T5ForConditionalGeneration,
#         "description": "mT5 with SentencePiece Unigram"
#     }
# }

# # Custom tokenizer settings
# tokenizer_sizes = ["small", "medium", "large"]
# tokenizer_types = ["hf_bpe_hf", "hf_wordpiece_hf", "sp_unigram_hf"]

# # Languages for evaluation
# languages = {
#     "yo": "Yoruba",
#     "ar": "Arabic", 
#     "zh": "Chinese",
#     "ru": "Russian",
#     "hi": "Hindi",
#     "ja": "Japanese",
#     "swa": "Swahili",
#     "bn": "Bengali",
#     "tr": "Turkish"
# }
# SRC_LANG = "en"

# # Load complete multilingual dataset
# dataset_path = "balanced_mt_dataset"
# print("📦 Loading complete multilingual dataset...")
# full_dataset = load_from_disk(dataset_path)
# print(f"✅ Dataset loaded: {len(full_dataset['train'])} train, {len(full_dataset['test'])} test")
# print(f"🌍 All languages included: {list(languages.keys())}")

# # Enhanced BLEU computation with progress tracking
# def compute_multilingual_bleu(eval_pred):
#     """Multilingual BLEU computation with progress tracking"""
#     import time
#     start_time = time.time()
    
#     predictions, labels = eval_pred
    
#     if len(predictions.shape) == 3:
#         predictions = np.argmax(predictions, axis=-1)
    
#     decoded_preds = []
#     decoded_labels = []
    
#     total_samples = len(predictions)
#     print(f"🔄 Starting evaluation of {total_samples} samples...")
    
#     # Process in chunks with progress updates
#     chunk_size = 50
#     for i in range(0, total_samples, chunk_size):
#         end_idx = min(i + chunk_size, total_samples)
#         chunk_preds = predictions[i:end_idx]
#         chunk_labels = labels[i:end_idx]
        
#         try:
#             for pred, label in zip(chunk_preds, chunk_labels):
#                 # CRITICAL FIX: Handle negative token IDs properly
#                 # Filter out negative IDs and pad tokens before decoding
#                 pred_clean = [token for token in pred if token >= 0 and token < len(tokenizer)]
#                 label_clean = [token for token in label if token != -100 and token >= 0 and token < len(tokenizer)]
                
#                 try:
#                     decoded_pred = tokenizer.decode(pred_clean, skip_special_tokens=True).strip() if pred_clean else ""
#                     decoded_label = tokenizer.decode(label_clean, skip_special_tokens=True).strip() if label_clean else ""
#                 except Exception as e:
#                     # Fallback for any decode errors
#                     decoded_pred = ""
#                     decoded_label = ""
                
#                 decoded_preds.append(decoded_pred)
#                 decoded_labels.append(decoded_label)
            
#             # Progress update every chunk
#             if (i + chunk_size) % (chunk_size * 5) == 0 or end_idx == total_samples:
#                 elapsed = time.time() - start_time
#                 print(f"📊 Evaluation progress: {end_idx}/{total_samples} samples ({elapsed:.1f}s elapsed)")
                
#         except Exception as e:
#             print(f"⚠️  Batch decode failed for chunk {i}-{end_idx}: {str(e)}")
#             # Add empty strings for failed batch
#             for _ in range(end_idx - i):
#                 decoded_preds.append("")
#                 decoded_labels.append("")
    
#     # Compute BLEU with smoothing
#     print("🧮 Computing BLEU scores...")
#     smoothing = SmoothingFunction().method1
#     bleu_scores = []
#     exact_matches = 0
    
#     for idx, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
#         if idx % 1000 == 0 and idx > 0:
#             print(f"   BLEU calculation: {idx}/{len(decoded_preds)} samples")
            
#         if not pred.strip() or not label.strip():
#             bleu_scores.append(0.0)
#             continue
            
#         pred_tokens = pred.split()
#         label_tokens = label.split()
        
#         # Check exact match
#         if pred.lower().strip() == label.lower().strip():
#             exact_matches += 1
        
#         if len(pred_tokens) == 0 or len(label_tokens) == 0:
#             bleu_scores.append(0.0)
#             continue
        
#         try:
#             bleu = sentence_bleu(
#                 [label_tokens], 
#                 pred_tokens,
#                 smoothing_function=smoothing,
#                 weights=(0.25, 0.25, 0.25, 0.25)
#             )
#             bleu_scores.append(bleu)
#         except:
#             bleu_scores.append(0.0)
    
#     avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
#     exact_match_ratio = exact_matches / len(decoded_preds) if decoded_preds else 0.0
    
#     total_time = time.time() - start_time
#     print(f"✅ Evaluation complete: BLEU={avg_bleu:.4f}, Exact Match={exact_match_ratio:.4f} ({total_time:.1f}s total)")
    
#     return {
#         "bleu": avg_bleu,
#         "exact_match": exact_match_ratio,
#         "avg_pred_length": np.mean([len(p.split()) for p in decoded_preds if p.strip()]) if decoded_preds else 0.0,
#         "avg_label_length": np.mean([len(l.split()) for l in decoded_labels if l.strip()]) if decoded_labels else 0.0,
#         "empty_predictions": sum(1 for p in decoded_preds if not p.strip()) / len(decoded_preds) if decoded_preds else 0.0
#     }

# def setup_custom_tokenizer_for_model(tokenizer_path, model_type):
#     """Setup custom tokenizer with proper special tokens for specific model"""
#     # Load custom tokenizer
#     tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    
#     # Model-specific special token configuration
#     if model_type == "bart":
#         special_tokens = {
#             'bos_token': '<s>',
#             'eos_token': '</s>',
#             'sep_token': '</s>',
#             'pad_token': '<pad>',
#             'unk_token': '<unk>',
#             'mask_token': '<mask>'
#         }
#     elif model_type == "t5":
#         special_tokens = {
#             'pad_token': '<pad>',
#             'eos_token': '</s>',
#             'unk_token': '<unk>',
#             'bos_token': '<pad>',  # T5 doesn't use BOS
#             'sep_token': '</s>',
#             'mask_token': '<extra_id_0>'
#         }
#     else:  # Default fallback
#         special_tokens = {
#             'bos_token': '<s>',
#             'eos_token': '</s>',
#             'sep_token': '</s>',
#             'pad_token': '<pad>',
#             'unk_token': '<unk>',
#             'mask_token': '<mask>'
#         }
    
#     # Add missing special tokens
#     tokens_to_add = {}
#     for token_name, token_value in special_tokens.items():
#         if getattr(tokenizer, token_name, None) is None:
#             tokens_to_add[token_name] = token_value
    
#     if tokens_to_add:
#         tokenizer.add_special_tokens(tokens_to_add)
    
#     # Ensure we have the required tokens
#     assert tokenizer.eos_token is not None, "EOS token is required"
#     assert tokenizer.pad_token is not None, "PAD token is required"
    
#     return tokenizer

# def configure_model_for_custom_tokenizer(model, tokenizer, model_type):
#     """Configure model for custom tokenizer based on model type"""
#     # Resize embeddings
#     print("🔄 Resizing model embeddings to match custom tokenizer...")
#     old_vocab_size = model.config.vocab_size
#     model.resize_token_embeddings(len(tokenizer))
#     print(f"   Vocab size: {old_vocab_size} → {len(tokenizer)}")
    
#     # Model-specific configuration
#     if model_type == "bart":
#         model.config.decoder_start_token_id = tokenizer.bos_token_id
#         model.config.pad_token_id = tokenizer.pad_token_id
#         model.config.bos_token_id = tokenizer.bos_token_id
#         model.config.eos_token_id = tokenizer.eos_token_id
#         model.config.sep_token_id = tokenizer.eos_token_id
#         model.config.forced_eos_token_id = tokenizer.eos_token_id
        
#     elif model_type == "t5":
#         model.config.pad_token_id = tokenizer.pad_token_id
#         model.config.eos_token_id = tokenizer.eos_token_id
#         model.config.decoder_start_token_id = tokenizer.pad_token_id  # T5 uses pad_token_id as decoder start
#         model.config.forced_eos_token_id = tokenizer.eos_token_id
    
#     # Update generation config if it exists
#     if hasattr(model, 'generation_config') and model.generation_config is not None:
#         model.generation_config.pad_token_id = tokenizer.pad_token_id
#         model.generation_config.eos_token_id = tokenizer.eos_token_id
#         model.generation_config.forced_eos_token_id = tokenizer.eos_token_id
#         if model_type == "bart":
#             model.generation_config.decoder_start_token_id = tokenizer.bos_token_id
#             model.generation_config.bos_token_id = tokenizer.bos_token_id
#         elif model_type == "t5":
#             model.generation_config.decoder_start_token_id = tokenizer.pad_token_id
    
#     return model

# # Setup logging
# log_dir = "./MT_models_multilingual_custom_tokenizers"
# os.makedirs(log_dir, exist_ok=True)
# results_log_file = os.path.join(log_dir, "multilingual_custom_tokenizer_results.csv")

# if not os.path.exists(results_log_file):
#     with open(results_log_file, 'w', newline='', encoding='utf-8') as csvfile:
#         writer = csv.writer(csvfile)
#         writer.writerow(["Model_ID", "Base_Model", "Tokenizer_Size", "Tokenizer_Type", 
#                         "Total_Languages", "Total_Train_Samples", "Total_Test_Samples", 
#                         "Custom_Vocab_Size", "Overall_BLEU", "Overall_Exact_Match", 
#                         "Avg_Pred_Length", "Avg_Label_Length", "Empty_Predictions", 
#                         "Training_Status", "Notes"])

# print("🚀 Starting multilingual training with COMPATIBLE model-tokenizer pairs...")
# print(f"📊 Training approach: ONE model per tokenizer handling ALL {len(languages)} languages")

# # Main training loop: 3 sizes × 3 types = 9 models total
# for size in tokenizer_sizes:
#     for tok_type in tokenizer_types:
#         # Get compatible model configuration
#         model_config = MODEL_TOKENIZER_CONFIGS[tok_type]
        
#         # Path to custom tokenizer
#         tokenizer_path = f"vocab_final/vocab_final{size}/{tok_type}"
#         model_id = f"multilingual_{size}_{tok_type}"
        
#         print(f"\n{'='*100}")
#         print(f"🚀 Training Model {tokenizer_sizes.index(size)*3 + tokenizer_types.index(tok_type) + 1}/9")
#         print(f"🔧 Custom Tokenizer: {size}_{tok_type}")
#         print(f"📦 Compatible Base Model: {model_config['base_model']}")
#         print(f"🌍 Target Languages: {list(languages.keys())} (ALL SIMULTANEOUSLY)")
        
#         # Initialize variables
#         model = None
#         tokenizer = None
#         trainer = None
        
#         try:
#             # Load and setup custom tokenizer
#             print(f"🔧 Loading custom tokenizer: {tokenizer_path}")
#             model_type = "bart" if "bart" in model_config['base_model'] else "t5"
#             tokenizer = setup_custom_tokenizer_for_model(tokenizer_path, model_type)
            
#             print(f"✅ Custom tokenizer loaded successfully!")
#             print(f"📊 Custom vocab size: {len(tokenizer)}")
#             print(f"🔑 Special tokens - EOS: {tokenizer.eos_token_id}, PAD: {tokenizer.pad_token_id}")
            
#             # Load compatible base model
#             print(f"🤖 Loading compatible base model: {model_config['base_model']}")
#             model = model_config['model_class'].from_pretrained(model_config['base_model'])
            
#             # Configure model for custom tokenizer
#             model = configure_model_for_custom_tokenizer(model, tokenizer, model_type)
            
#             print(f"✅ Model configured with custom tokenizer!")
            
#             # Use COMPLETE multilingual dataset (all languages together)
#             print(f"📊 Using complete multilingual dataset:")
#             print(f"   • Train samples: {len(full_dataset['train'])}")
#             print(f"   • Test samples: {len(full_dataset['test'])}")
#             print(f"   • Languages: {len(languages)} ({list(languages.keys())})")
            
#             # IMPROVED: Preprocessing function for multilingual data
#             def preprocess_multilingual_improved(examples):
#                 """Improved preprocessing for multilingual data with model-specific formatting"""
#                 sources = []
#                 targets = []
                
#                 # Handle both single examples and batches
#                 if not isinstance(examples["translation"], list):
#                     examples = {
#                         "translation": [examples["translation"]], 
#                         "language": [examples["language"]]
#                     }
                
#                 for translation, lang in zip(examples["translation"], examples["language"]):
#                     if isinstance(translation, dict) and lang in languages:
#                         source = translation.get(SRC_LANG, "")
#                         target = translation.get(lang, "")
                        
#                         if source.strip() and target.strip():  # Ensure non-empty
#                             # Model-specific formatting
#                             if model_type == "t5":
#                                 # T5 style: "translate English to German: Hello"
#                                 source_formatted = f"translate English to {languages[lang]}: {source}"
#                             else:  # BART style
#                                 # BART style with language token
#                                 source_formatted = f"{source} </s> {lang}_XX"  # Add language code
                            
#                             sources.append(source_formatted)
#                             targets.append(target)
                
#                 if not sources or not targets:
#                     return {"input_ids": [], "attention_mask": [], "labels": []}
                
#                 # Tokenize with custom tokenizer
#                 max_length = 256  # INCREASED for better performance
                
#                 # Input tokenization - REMOVE token_type_ids
#                 model_inputs = tokenizer(
#                     sources,
#                     max_length=max_length,
#                     truncation=True,
#                     padding="max_length",
#                     return_tensors=None,
#                     return_token_type_ids=False  # CRITICAL FIX
#                 )
#                 # Ensure token_type_ids is removed
#                 model_inputs.pop("token_type_ids", None)
                
#                 # Target tokenization  
#                 with tokenizer.as_target_tokenizer():
#                     labels = tokenizer(
#                         targets,
#                         max_length=max_length,
#                         truncation=True,
#                         padding="max_length",
#                         return_tensors=None,
#                         return_token_type_ids=False  # CRITICAL FIX
#                     )
#                 # Ensure token_type_ids is removed
#                 labels.pop("token_type_ids", None)
                
#                 # IMPROVED: Label processing with proper EOS handling
#                 processed_labels = []
#                 for label_seq in labels["input_ids"]:
#                     # Find actual end of sequence (before padding)
#                     try:
#                         pad_start = label_seq.index(tokenizer.pad_token_id)
#                         actual_tokens = label_seq[:pad_start]
#                     except ValueError:
#                         actual_tokens = label_seq
                    
#                     # Ensure EOS token at end
#                     if actual_tokens and actual_tokens[-1] != tokenizer.eos_token_id:
#                         actual_tokens.append(tokenizer.eos_token_id)
#                     elif not actual_tokens:
#                         actual_tokens = [tokenizer.eos_token_id]
                    
#                     # Create final sequence with -100 for padding
#                     final_labels = actual_tokens + [-100] * (max_length - len(actual_tokens))
#                     final_labels = final_labels[:max_length]  # Ensure correct length
                    
#                     processed_labels.append(final_labels)
                
#                 model_inputs["labels"] = processed_labels
#                 return model_inputs
            
#             # Preprocess complete multilingual dataset
#             print("⚙️  Preprocessing complete multilingual dataset with custom tokenizer...")
#             processed_dataset = full_dataset.map(
#                 preprocess_multilingual_improved,
#                 batched=True,
#                 remove_columns=full_dataset["train"].column_names,
#                 desc=f"Preprocessing with {size}_{tok_type}",
#                 batch_size=50,  # Smaller batch for stability
#                 num_proc=1  # Single process to avoid issues
#             )
            
#             train_dataset = processed_dataset["train"]
#             eval_dataset = processed_dataset["test"]
            
#             # Filter out empty examples
#             def filter_valid_examples(example):
#                 return (
#                     len(example["input_ids"]) > 0 and 
#                     len(example["labels"]) > 0 and
#                     any(label != -100 for label in example["labels"]) and
#                     sum(1 for token in example["input_ids"] if token != tokenizer.pad_token_id) > 0
#                 )
            
#             train_dataset = train_dataset.filter(filter_valid_examples)
#             eval_dataset = eval_dataset.filter(filter_valid_examples)
            
#             # Use smaller eval dataset to avoid memory issues
#             print(f"✅ Preprocessed multilingual dataset:")
#             print(f"   • Train samples: {len(train_dataset)}")
#             print(f"   • Eval samples: {len(eval_dataset)}")
            
#             # REDUCE eval dataset size to avoid hanging
#             if len(eval_dataset) > 1000:
#                 eval_dataset = eval_dataset.select(range(1000))
#                 print(f"   • Reduced eval samples to: {len(eval_dataset)} (to avoid memory issues)")
            
#             if len(train_dataset) == 0 or len(eval_dataset) == 0:
#                 print("❌ No valid samples after preprocessing")
#                 continue
            
#             # Setup training directory
#             output_dir = f"./MT_models_multilingual_custom_tokenizers/{model_id}"
#             os.makedirs(output_dir, exist_ok=True)
            
#             # IMPROVED: Training arguments with simpler evaluation
#             training_args = Seq2SeqTrainingArguments(
#                 output_dir=output_dir,
#                 num_train_epochs=5,  # More epochs for custom tokenizers
#                 per_device_train_batch_size=2,  # Smaller batch size
#                 per_device_eval_batch_size=1,  # REDUCED eval batch size
#                 gradient_accumulation_steps=8,  # Higher accumulation
#                 learning_rate=1e-4,  # Lower learning rate for stability
#                 weight_decay=0.01,
#                 warmup_ratio=0.1,  # Warmup as ratio
#                 eval_strategy="steps",  # Change to steps-based evaluation
#                 eval_steps=500,  # Evaluate every 500 steps instead of epoch end
#                 save_strategy="epoch",
#                 save_total_limit=2,
#                 logging_steps=50,
#                 report_to="none",
#                 predict_with_generate=True,
#                 generation_max_length=128,  # REDUCED generation length
#                 generation_num_beams=2,  # REDUCED beams
#                 fp16=torch.cuda.is_available(),
#                 load_best_model_at_end=False,  # DISABLED to avoid issues
#                 dataloader_num_workers=0,
#                 remove_unused_columns=False,  # CRITICAL: Keep all columns
#                 ignore_data_skip=True,
#                 label_smoothing_factor=0.1,  # Label smoothing
#                 max_grad_norm=1.0,  # Gradient clipping
#                 dataloader_pin_memory=False,  # Disable pin memory
#                 skip_memory_metrics=True,  # Skip memory tracking
                
                
#             )
            
#             # Data collator with explicit token_type_ids handling
#             data_collator = DataCollatorForSeq2Seq(
#                 tokenizer=tokenizer,
#                 model=model,
#                 padding=True,
#                 pad_to_multiple_of=8 if training_args.fp16 else None,
#                 return_tensors="pt",
#                 label_pad_token_id=-100
#             )
            
#             # CRITICAL FIX: Custom data collator that removes token_type_ids
#             class CustomDataCollator(DataCollatorForSeq2Seq):
#                 def __call__(self, features):
#                     batch = super().__call__(features)
#                     # Remove token_type_ids if present
#                     batch.pop("token_type_ids", None)
#                     return batch
            
#             data_collator = CustomDataCollator(
#                 tokenizer=tokenizer,
#                 model=model,
#                 padding=True,
#                 pad_to_multiple_of=8 if training_args.fp16 else None,
#                 return_tensors="pt",
#                 label_pad_token_id=-100
#             )
            
#             # Trainer
#             trainer = Seq2SeqTrainer(
#                 model=model,
#                 args=training_args,
#                 train_dataset=train_dataset,
#                 eval_dataset=eval_dataset,
#                 tokenizer=tokenizer,
#                 data_collator=data_collator,
#                 compute_metrics=compute_multilingual_bleu
#             )
            
#             # Train multilingual model
#             print("🏋️  Starting multilingual training with compatible model-tokenizer pair...")
#             trainer.train()
            
#             # Evaluate
#             print("📊 Final multilingual evaluation...")
#             eval_results = trainer.evaluate()
            
#             # Save model and custom tokenizer
#             print("💾 Saving multilingual model and custom tokenizer...")
#             trainer.save_model()
#             tokenizer.save_pretrained(output_dir)
            
#             # Log results
#             overall_bleu = eval_results.get("eval_bleu", 0.0)
#             overall_exact_match = eval_results.get("eval_exact_match", 0.0)
#             avg_pred_len = eval_results.get("eval_avg_pred_length", 0.0)
#             avg_label_len = eval_results.get("eval_avg_label_length", 0.0)
#             empty_preds = eval_results.get("eval_empty_predictions", 0.0)
            
#             with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
#                 writer = csv.writer(csvfile)
#                 writer.writerow([
#                     model_id, model_config["base_model"], size, tok_type,
#                     len(languages), len(train_dataset), len(eval_dataset), len(tokenizer),
#                     round(overall_bleu, 4), round(overall_exact_match, 4), 
#                     round(avg_pred_len, 2), round(avg_label_len, 2),
#                     round(empty_preds, 4), "SUCCESS", 
#                     f"Compatible {model_config['description']} with {size}_{tok_type}"
#                 ])
            
#             print(f"✅ Completed: {model_id}")
#             print(f"📈 Overall BLEU Score: {overall_bleu:.4f}")
#             print(f"🎯 Overall Exact Match: {overall_exact_match:.4f}")
            
#             # Quick translation tests for different languages
#             print(f"\n🧪 Quick multilingual translation tests:")
#             test_input = "Hello, how are you today?"
            
#             for test_lang in ["ar", "zh", "hi"]:  # Test 3 languages
#                 if model_type == "t5":
#                     formatted_input = f"translate English to {languages[test_lang]}: {test_input}"
#                 else:  # BART
#                     formatted_input = f"{test_input} </s> {test_lang}_XX"
                    
#                 inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, max_length=256, truncation=True)
                
#                 # Move to device
#                 device = next(model.parameters()).device
#                 inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
                
#                 with torch.no_grad():
#                     outputs = model.generate(
#                         **inputs,
#                         max_length=128,
#                         num_beams=4,
#                         early_stopping=True,
#                         do_sample=False,
#                         forced_eos_token_id=tokenizer.eos_token_id
#                     )
                
#                 translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
#                 print(f"  {test_lang} ({languages[test_lang]}): {translation}")
                
#         except Exception as e:
#             print(f"❌ Failed to train {model_id}: {str(e)}")
#             import traceback
#             traceback.print_exc()
            
#             with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
#                 writer = csv.writer(csvfile)
#                 writer.writerow([
#                     model_id, model_config.get("base_model", ""), size, tok_type,
#                     len(languages), 0, 0, 0, 0, 0, 0, 0, 0, "TRAINING_FAILED", str(e)[:100]
#                 ])
        
#         finally:
#             # Cleanup
#             if model is not None:
#                 del model
#             if trainer is not None:
#                 del trainer
#             torch.cuda.empty_cache()
#             gc.collect()

# print("\n🎉 Multilingual training with compatible model-tokenizer pairs completed!")
# print(f"📋 Results saved to: {results_log_file}")
# print(f"🔢 Total models trained: 9 (3 sizes × 3 types)")
# print(f"🌍 Each model handles all {len(languages)} languages simultaneously")
# print("\n📊 Expected improvements:")
# print("• BPE tokenizers → BART models (proper compatibility)")
# print("• WordPiece tokenizers → mT5 models (better multilingual support)")
# print("• Unigram tokenizers → mT5 models (native compatibility)")
# print("• Increased sequence length (256 vs 128)")
# print("• Better hyperparameters and training setup")

📦 Loading complete multilingual dataset...
✅ Dataset loaded: 34376 train, 3820 test
🌍 All languages included: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr']
🚀 Starting multilingual training with COMPATIBLE model-tokenizer pairs...
📊 Training approach: ONE model per tokenizer handling ALL 9 languages

🚀 Training Model 1/9
🔧 Custom Tokenizer: small_hf_bpe_hf
📦 Compatible Base Model: facebook/bart-large
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalsmall/hf_bpe_hf
✅ Custom tokenizer loaded successfully!
📊 Custom vocab size: 15002
🔑 Special tokens - EOS: 15001, PAD: 0
🤖 Loading compatible base model: facebook/bart-large
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 50265 → 15002
✅ Model configured with custom tokenizer!
📊 Using complete multilingual dataset:
   • Train samples: 34376
   • Test samples: 3820
   • Languages: 9 (['yo', 'ar', 'zh', 'ru', 'hi'

KeyboardInterrupt: 

In [5]:
##working ... but slow
import os
import csv
import torch
import warnings
import gc
from datasets import load_from_disk
from transformers import (
    EncoderDecoderModel,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    logging as hf_logging,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    T5ForConditionalGeneration,
    T5TokenizerFast,
    BertTokenizerFast,
    GPT2TokenizerFast,
    BartForConditionalGeneration
)
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np

# Suppress warnings
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore")

# FIXED: Model-Tokenizer Compatibility Mapping
MODEL_TOKENIZER_CONFIGS = {
    "hf_bpe_hf": {
        "base_model": "facebook/bart-large",  # BART uses BPE
        "model_class": BartForConditionalGeneration,
        "description": "BART with BPE tokenization"
    },
    "hf_wordpiece_hf": {
        "base_model": "google/mt5-base",  # mT5 can work with WordPiece
        "model_class": T5ForConditionalGeneration,
        "description": "mT5 with WordPiece tokenization"
    },
    "sp_unigram_hf": {
        "base_model": "google/mt5-base",  # mT5 uses SentencePiece Unigram
        "model_class": T5ForConditionalGeneration,
        "description": "mT5 with SentencePiece Unigram"
    }
}

# Custom tokenizer settings
tokenizer_sizes = ["small", "medium", "large"]
tokenizer_types = ["hf_bpe_hf", "hf_wordpiece_hf", "sp_unigram_hf"]

# Languages for evaluation
languages = {
    "yo": "Yoruba",
    "ar": "Arabic", 
    "zh": "Chinese",
    "ru": "Russian",
    "hi": "Hindi",
    "ja": "Japanese",
    "swa": "Swahili",
    "bn": "Bengali",
    "tr": "Turkish"
}
SRC_LANG = "en"

# Load complete multilingual dataset
dataset_path = "balanced_mt_dataset"
print("📦 Loading complete multilingual dataset...")
full_dataset = load_from_disk(dataset_path)
print(f"✅ Dataset loaded: {len(full_dataset['train'])} train, {len(full_dataset['test'])} test")
print(f"🌍 All languages included: {list(languages.keys())}")

# Enhanced BLEU computation with progress tracking
def compute_multilingual_bleu(eval_pred):
    """Multilingual BLEU computation with progress tracking"""
    import time
    start_time = time.time()
    
    predictions, labels = eval_pred
    
    if len(predictions.shape) == 3:
        predictions = np.argmax(predictions, axis=-1)
    
    decoded_preds = []
    decoded_labels = []
    
    total_samples = len(predictions)
    print(f"🔄 Starting evaluation of {total_samples} samples...")
    
    # Process in chunks with progress updates
    chunk_size = 50
    for i in range(0, total_samples, chunk_size):
        end_idx = min(i + chunk_size, total_samples)
        chunk_preds = predictions[i:end_idx]
        chunk_labels = labels[i:end_idx]
        
        try:
            for pred, label in zip(chunk_preds, chunk_labels):
                # CRITICAL FIX: Handle negative token IDs properly
                # Filter out negative IDs and pad tokens before decoding
                pred_clean = [token for token in pred if token >= 0 and token < len(tokenizer)]
                label_clean = [token for token in label if token != -100 and token >= 0 and token < len(tokenizer)]
                
                try:
                    decoded_pred = tokenizer.decode(pred_clean, skip_special_tokens=True).strip() if pred_clean else ""
                    decoded_label = tokenizer.decode(label_clean, skip_special_tokens=True).strip() if label_clean else ""
                except Exception as e:
                    # Fallback for any decode errors
                    decoded_pred = ""
                    decoded_label = ""
                
                decoded_preds.append(decoded_pred)
                decoded_labels.append(decoded_label)
            
            # Progress update every chunk
            if (i + chunk_size) % (chunk_size * 5) == 0 or end_idx == total_samples:
                elapsed = time.time() - start_time
                print(f"📊 Evaluation progress: {end_idx}/{total_samples} samples ({elapsed:.1f}s elapsed)")
                
        except Exception as e:
            print(f"⚠️  Batch decode failed for chunk {i}-{end_idx}: {str(e)}")
            # Add empty strings for failed batch
            for _ in range(end_idx - i):
                decoded_preds.append("")
                decoded_labels.append("")
    
    # Compute BLEU with smoothing
    print("🧮 Computing BLEU scores...")
    smoothing = SmoothingFunction().method1
    bleu_scores = []
    exact_matches = 0
    
    for idx, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
        if idx % 1000 == 0 and idx > 0:
            print(f"   BLEU calculation: {idx}/{len(decoded_preds)} samples")
            
        if not pred.strip() or not label.strip():
            bleu_scores.append(0.0)
            continue
            
        pred_tokens = pred.split()
        label_tokens = label.split()
        
        # Check exact match
        if pred.lower().strip() == label.lower().strip():
            exact_matches += 1
        
        if len(pred_tokens) == 0 or len(label_tokens) == 0:
            bleu_scores.append(0.0)
            continue
        
        try:
            bleu = sentence_bleu(
                [label_tokens], 
                pred_tokens,
                smoothing_function=smoothing,
                weights=(0.25, 0.25, 0.25, 0.25)
            )
            bleu_scores.append(bleu)
        except:
            bleu_scores.append(0.0)
    
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
    exact_match_ratio = exact_matches / len(decoded_preds) if decoded_preds else 0.0
    
    total_time = time.time() - start_time
    print(f"✅ Evaluation complete: BLEU={avg_bleu:.4f}, Exact Match={exact_match_ratio:.4f} ({total_time:.1f}s total)")
    
    return {
        "bleu": avg_bleu,
        "exact_match": exact_match_ratio,
        "avg_pred_length": np.mean([len(p.split()) for p in decoded_preds if p.strip()]) if decoded_preds else 0.0,
        "avg_label_length": np.mean([len(l.split()) for l in decoded_labels if l.strip()]) if decoded_labels else 0.0,
        "empty_predictions": sum(1 for p in decoded_preds if not p.strip()) / len(decoded_preds) if decoded_preds else 0.0
    }

def setup_custom_tokenizer_for_model(tokenizer_path, model_type):
    """Setup custom tokenizer with proper special tokens for specific model"""
    # Load custom tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    
    # Model-specific special token configuration
    if model_type == "bart":
        special_tokens = {
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '</s>',
            'pad_token': '<pad>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }
    elif model_type == "t5":
        special_tokens = {
            'pad_token': '<pad>',
            'eos_token': '</s>',
            'unk_token': '<unk>',
            'bos_token': '<pad>',  # T5 doesn't use BOS
            'sep_token': '</s>',
            'mask_token': '<extra_id_0>'
        }
    else:  # Default fallback
        special_tokens = {
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '</s>',
            'pad_token': '<pad>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }
    
    # Add missing special tokens
    tokens_to_add = {}
    for token_name, token_value in special_tokens.items():
        if getattr(tokenizer, token_name, None) is None:
            tokens_to_add[token_name] = token_value
    
    if tokens_to_add:
        tokenizer.add_special_tokens(tokens_to_add)
    
    # Ensure we have the required tokens
    assert tokenizer.eos_token is not None, "EOS token is required"
    assert tokenizer.pad_token is not None, "PAD token is required"
    
    return tokenizer

def configure_model_for_custom_tokenizer(model, tokenizer, model_type):
    """Configure model for custom tokenizer based on model type"""
    # Resize embeddings
    print("🔄 Resizing model embeddings to match custom tokenizer...")
    old_vocab_size = model.config.vocab_size
    model.resize_token_embeddings(len(tokenizer))
    print(f"   Vocab size: {old_vocab_size} → {len(tokenizer)}")
    
    # Model-specific configuration
    if model_type == "bart":
        model.config.decoder_start_token_id = tokenizer.bos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.bos_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.sep_token_id = tokenizer.eos_token_id
        model.config.forced_eos_token_id = tokenizer.eos_token_id
        
    elif model_type == "t5":
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.decoder_start_token_id = tokenizer.pad_token_id  # T5 uses pad_token_id as decoder start
        model.config.forced_eos_token_id = tokenizer.eos_token_id
    
    # Update generation config if it exists
    if hasattr(model, 'generation_config') and model.generation_config is not None:
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.generation_config.forced_eos_token_id = tokenizer.eos_token_id
        if model_type == "bart":
            model.generation_config.decoder_start_token_id = tokenizer.bos_token_id
            model.generation_config.bos_token_id = tokenizer.bos_token_id
        elif model_type == "t5":
            model.generation_config.decoder_start_token_id = tokenizer.pad_token_id
    
    return model

# Setup logging
log_dir = "./MT_models_multilingual_custom_tokenizers"
os.makedirs(log_dir, exist_ok=True)
results_log_file = os.path.join(log_dir, "multilingual_custom_tokenizer_results.csv")

if not os.path.exists(results_log_file):
    with open(results_log_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Model_ID", "Base_Model", "Tokenizer_Size", "Tokenizer_Type", 
                        "Total_Languages", "Total_Train_Samples", "Total_Test_Samples", 
                        "Custom_Vocab_Size", "Overall_BLEU", "Overall_Exact_Match", 
                        "Avg_Pred_Length", "Avg_Label_Length", "Empty_Predictions", 
                        "Training_Status", "Notes"])

print("🚀 Starting multilingual training with COMPATIBLE model-tokenizer pairs...")
print(f"📊 Training approach: ONE model per tokenizer handling ALL {len(languages)} languages")

# Main training loop: 3 sizes × 3 types = 9 models total
for size in tokenizer_sizes:
    for tok_type in tokenizer_types:
        # Get compatible model configuration
        model_config = MODEL_TOKENIZER_CONFIGS[tok_type]
        
        # Path to custom tokenizer
        tokenizer_path = f"vocab_final/vocab_final{size}/{tok_type}"
        model_id = f"multilingual_{size}_{tok_type}"
        
        print(f"\n{'='*100}")
        print(f"🚀 Training Model {tokenizer_sizes.index(size)*3 + tokenizer_types.index(tok_type) + 1}/9")
        print(f"🔧 Custom Tokenizer: {size}_{tok_type}")
        print(f"📦 Compatible Base Model: {model_config['base_model']}")
        print(f"🌍 Target Languages: {list(languages.keys())} (ALL SIMULTANEOUSLY)")
        
        # Initialize variables
        model = None
        tokenizer = None
        trainer = None
        
        try:
            # Load and setup custom tokenizer
            print(f"🔧 Loading custom tokenizer: {tokenizer_path}")
            model_type = "bart" if "bart" in model_config['base_model'] else "t5"
            tokenizer = setup_custom_tokenizer_for_model(tokenizer_path, model_type)
            
            print(f"✅ Custom tokenizer loaded successfully!")
            print(f"📊 Custom vocab size: {len(tokenizer)}")
            print(f"🔑 Special tokens - EOS: {tokenizer.eos_token_id}, PAD: {tokenizer.pad_token_id}")
            
            # Load compatible base model
            print(f"🤖 Loading compatible base model: {model_config['base_model']}")
            model = model_config['model_class'].from_pretrained(model_config['base_model'])
            
            # Configure model for custom tokenizer
            model = configure_model_for_custom_tokenizer(model, tokenizer, model_type)
            
            print(f"✅ Model configured with custom tokenizer!")
            
            # Use COMPLETE multilingual dataset (all languages together)
            print(f"📊 Using complete multilingual dataset:")
            print(f"   • Train samples: {len(full_dataset['train'])}")
            print(f"   • Test samples: {len(full_dataset['test'])}")
            print(f"   • Languages: {len(languages)} ({list(languages.keys())})")
            
            # IMPROVED: Preprocessing function for multilingual data
            def preprocess_multilingual_improved(examples):
                """Improved preprocessing for multilingual data with model-specific formatting"""
                sources = []
                targets = []
                
                # Handle both single examples and batches
                if not isinstance(examples["translation"], list):
                    examples = {
                        "translation": [examples["translation"]], 
                        "language": [examples["language"]]
                    }
                
                for translation, lang in zip(examples["translation"], examples["language"]):
                    if isinstance(translation, dict) and lang in languages:
                        source = translation.get(SRC_LANG, "")
                        target = translation.get(lang, "")
                        
                        if source.strip() and target.strip():  # Ensure non-empty
                            # Model-specific formatting
                            if model_type == "t5":
                                # T5 style: "translate English to German: Hello"
                                source_formatted = f"translate English to {languages[lang]}: {source}"
                            else:  # BART style
                                # BART style with language token
                                source_formatted = f"{source} </s> {lang}_XX"  # Add language code
                            
                            sources.append(source_formatted)
                            targets.append(target)
                
                if not sources or not targets:
                    return {"input_ids": [], "attention_mask": [], "labels": []}
                
                # Tokenize with custom tokenizer
                max_length = 256  # INCREASED for better performance
                
                # Input tokenization - REMOVE token_type_ids
                model_inputs = tokenizer(
                    sources,
                    max_length=max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors=None,
                    return_token_type_ids=False  # CRITICAL FIX
                )
                # Ensure token_type_ids is removed
                model_inputs.pop("token_type_ids", None)
                
                # Target tokenization  
                with tokenizer.as_target_tokenizer():
                    labels = tokenizer(
                        targets,
                        max_length=max_length,
                        truncation=True,
                        padding="max_length",
                        return_tensors=None,
                        return_token_type_ids=False  # CRITICAL FIX
                    )
                # Ensure token_type_ids is removed
                labels.pop("token_type_ids", None)
                
                # IMPROVED: Label processing with proper EOS handling
                processed_labels = []
                for label_seq in labels["input_ids"]:
                    # Find actual end of sequence (before padding)
                    try:
                        pad_start = label_seq.index(tokenizer.pad_token_id)
                        actual_tokens = label_seq[:pad_start]
                    except ValueError:
                        actual_tokens = label_seq
                    
                    # Ensure EOS token at end
                    if actual_tokens and actual_tokens[-1] != tokenizer.eos_token_id:
                        actual_tokens.append(tokenizer.eos_token_id)
                    elif not actual_tokens:
                        actual_tokens = [tokenizer.eos_token_id]
                    
                    # Create final sequence with -100 for padding
                    final_labels = actual_tokens + [-100] * (max_length - len(actual_tokens))
                    final_labels = final_labels[:max_length]  # Ensure correct length
                    
                    processed_labels.append(final_labels)
                
                model_inputs["labels"] = processed_labels
                return model_inputs
            
            # Preprocess complete multilingual dataset
            print("⚙️  Preprocessing complete multilingual dataset with custom tokenizer...")
            processed_dataset = full_dataset.map(
                preprocess_multilingual_improved,
                batched=True,
                remove_columns=full_dataset["train"].column_names,
                desc=f"Preprocessing with {size}_{tok_type}",
                batch_size=50,  # Smaller batch for stability
                num_proc=1  # Single process to avoid issues
            )
            
            train_dataset = processed_dataset["train"]
            eval_dataset = processed_dataset["test"]
            
            # Filter out empty examples
            def filter_valid_examples(example):
                return (
                    len(example["input_ids"]) > 0 and 
                    len(example["labels"]) > 0 and
                    any(label != -100 for label in example["labels"]) and
                    sum(1 for token in example["input_ids"] if token != tokenizer.pad_token_id) > 0
                )
            
            train_dataset = train_dataset.filter(filter_valid_examples)
            eval_dataset = eval_dataset.filter(filter_valid_examples)
            
            # Use smaller eval dataset to avoid memory issues
            print(f"✅ Preprocessed multilingual dataset:")
            print(f"   • Train samples: {len(train_dataset)}")
            print(f"   • Eval samples: {len(eval_dataset)}")
            
            # REDUCE eval dataset size to avoid hanging
            if len(eval_dataset) > 1000:
                eval_dataset = eval_dataset.select(range(1000))
                print(f"   • Reduced eval samples to: {len(eval_dataset)} (to avoid memory issues)")
            
            if len(train_dataset) == 0 or len(eval_dataset) == 0:
                print("❌ No valid samples after preprocessing")
                continue
            
            # Setup training directory
            output_dir = f"./MT_models_multilingual_custom_tokenizers/{model_id}"
            os.makedirs(output_dir, exist_ok=True)
            
            # IMPROVED: Training arguments with simpler evaluation
            training_args = Seq2SeqTrainingArguments(
                output_dir=output_dir,
                num_train_epochs=1,  # More epochs for custom tokenizers
                per_device_train_batch_size=2,  # Smaller batch size
                per_device_eval_batch_size=1,  # REDUCED eval batch size
                gradient_accumulation_steps=8,  # Higher accumulation
                learning_rate=1e-4,  # Lower learning rate for stability
                weight_decay=0.01,
                warmup_ratio=0.1,  # Warmup as ratio
                eval_strategy="steps",  # Change to steps-based evaluation
                eval_steps=500,  # Evaluate every 500 steps instead of epoch end
                save_strategy="epoch",
                save_total_limit=2,
                logging_steps=50,
                report_to="none",
                predict_with_generate=True,
                generation_max_length=128,  # REDUCED generation length
                generation_num_beams=2,  # REDUCED beams
                fp16=torch.cuda.is_available(),
                load_best_model_at_end=False,  # DISABLED to avoid issues
                dataloader_num_workers=0,
                remove_unused_columns=False,  # CRITICAL: Keep all columns
                ignore_data_skip=True,
                label_smoothing_factor=0.1,  # Label smoothing
                max_grad_norm=1.0,  # Gradient clipping
                dataloader_pin_memory=False,  # Disable pin memory
                skip_memory_metrics=True,  # Skip memory tracking
                
                
            )
            
            # Data collator with explicit token_type_ids handling
            data_collator = DataCollatorForSeq2Seq(
                tokenizer=tokenizer,
                model=model,
                padding=True,
                pad_to_multiple_of=8 if training_args.fp16 else None,
                return_tensors="pt",
                label_pad_token_id=-100
            )
            
            # CRITICAL FIX: Custom data collator that removes token_type_ids
            class CustomDataCollator(DataCollatorForSeq2Seq):
                def __call__(self, features):
                    batch = super().__call__(features)
                    # Remove token_type_ids if present
                    batch.pop("token_type_ids", None)
                    return batch
            
            data_collator = CustomDataCollator(
                tokenizer=tokenizer,
                model=model,
                padding=True,
                pad_to_multiple_of=8 if training_args.fp16 else None,
                return_tensors="pt",
                label_pad_token_id=-100
            )
            
            # Trainer
            trainer = Seq2SeqTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=compute_multilingual_bleu
            )
            
            # Train multilingual model
            print("🏋️  Starting multilingual training with compatible model-tokenizer pair...")
            trainer.train()
            
            # Evaluate
            print("📊 Final multilingual evaluation...")
            eval_results = trainer.evaluate()
            
            # Save model and custom tokenizer
            print("💾 Saving multilingual model and custom tokenizer...")
            trainer.save_model()
            tokenizer.save_pretrained(output_dir)
            
            # Log results
            overall_bleu = eval_results.get("eval_bleu", 0.0)
            overall_exact_match = eval_results.get("eval_exact_match", 0.0)
            avg_pred_len = eval_results.get("eval_avg_pred_length", 0.0)
            avg_label_len = eval_results.get("eval_avg_label_length", 0.0)
            empty_preds = eval_results.get("eval_empty_predictions", 0.0)
            
            with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    model_id, model_config["base_model"], size, tok_type,
                    len(languages), len(train_dataset), len(eval_dataset), len(tokenizer),
                    round(overall_bleu, 4), round(overall_exact_match, 4), 
                    round(avg_pred_len, 2), round(avg_label_len, 2),
                    round(empty_preds, 4), "SUCCESS", 
                    f"Compatible {model_config['description']} with {size}_{tok_type}"
                ])
            
            print(f"✅ Completed: {model_id}")
            print(f"📈 Overall BLEU Score: {overall_bleu:.4f}")
            print(f"🎯 Overall Exact Match: {overall_exact_match:.4f}")
            
            # Quick translation tests for different languages
            print(f"\n🧪 Quick multilingual translation tests:")
            test_input = "Hello, how are you today?"
            
            for test_lang in ["ar", "zh", "hi"]:  # Test 3 languages
                if model_type == "t5":
                    formatted_input = f"translate English to {languages[test_lang]}: {test_input}"
                else:  # BART
                    formatted_input = f"{test_input} </s> {test_lang}_XX"
                    
                inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, max_length=256, truncation=True)
                
                # Move to device
                device = next(model.parameters()).device
                inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
                
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_length=128,
                        num_beams=4,
                        early_stopping=True,
                        do_sample=False,
                        forced_eos_token_id=tokenizer.eos_token_id
                    )
                
                translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
                print(f"  {test_lang} ({languages[test_lang]}): {translation}")
                
        except Exception as e:
            print(f"❌ Failed to train {model_id}: {str(e)}")
            import traceback
            traceback.print_exc()
            
            with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    model_id, model_config.get("base_model", ""), size, tok_type,
                    len(languages), 0, 0, 0, 0, 0, 0, 0, 0, "TRAINING_FAILED", str(e)[:100]
                ])
        
        finally:
            # Cleanup
            if model is not None:
                del model
            if trainer is not None:
                del trainer
            torch.cuda.empty_cache()
            gc.collect()

print("\n🎉 Multilingual training with compatible model-tokenizer pairs completed!")
print(f"📋 Results saved to: {results_log_file}")
print(f"🔢 Total models trained: 9 (3 sizes × 3 types)")
print(f"🌍 Each model handles all {len(languages)} languages simultaneously")
print("\n📊 Expected improvements:")
print("• BPE tokenizers → BART models (proper compatibility)")
print("• WordPiece tokenizers → mT5 models (better multilingual support)")
print("• Unigram tokenizers → mT5 models (native compatibility)")
print("• Increased sequence length (256 vs 128)")
print("• Better hyperparameters and training setup")



📦 Loading complete multilingual dataset...
✅ Dataset loaded: 34376 train, 3820 test
🌍 All languages included: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr']
🚀 Starting multilingual training with COMPATIBLE model-tokenizer pairs...
📊 Training approach: ONE model per tokenizer handling ALL 9 languages

🚀 Training Model 1/9
🔧 Custom Tokenizer: small_hf_bpe_hf
📦 Compatible Base Model: facebook/bart-large
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalsmall/hf_bpe_hf
✅ Custom tokenizer loaded successfully!
📊 Custom vocab size: 15002
🔑 Special tokens - EOS: 15001, PAD: 0
🤖 Loading compatible base model: facebook/bart-large
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 50265 → 15002
✅ Model configured with custom tokenizer!
📊 Using complete multilingual dataset:
   • Train samples: 34376
   • Test samples: 3820
   • Languages: 9 (['yo', 'ar', 'zh', 'ru', 'hi'

KeyboardInterrupt: 

In [None]:
##working ... but slow
import os
import csv
import torch
import warnings
import gc
from datasets import load_from_disk
from transformers import (
    EncoderDecoderModel,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    logging as hf_logging,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    T5ForConditionalGeneration,
    T5TokenizerFast,
    BertTokenizerFast,
    GPT2TokenizerFast,
    BartForConditionalGeneration
)
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np

# Suppress warnings
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore")

# FIXED: Model-Tokenizer Compatibility Mapping
MODEL_TOKENIZER_CONFIGS = {
    "hf_bpe_hf": {
        "base_model": "facebook/bart-large",  # BART uses BPE
        "model_class": BartForConditionalGeneration,
        "description": "BART with BPE tokenization"
    },
    "hf_wordpiece_hf": {
        "base_model": "google/mt5-base",  # mT5 can work with WordPiece
        "model_class": T5ForConditionalGeneration,
        "description": "mT5 with WordPiece tokenization"
    },
    "sp_unigram_hf": {
        "base_model": "google/mt5-base",  # mT5 uses SentencePiece Unigram
        "model_class": T5ForConditionalGeneration,
        "description": "mT5 with SentencePiece Unigram"
    }
}

# Custom tokenizer settings
tokenizer_sizes = ["small", "medium", "large"]
tokenizer_types = ["hf_bpe_hf", "hf_wordpiece_hf", "sp_unigram_hf"]

# Languages for evaluation
languages = {
    "yo": "Yoruba",
    "ar": "Arabic", 
    "zh": "Chinese",
    "ru": "Russian",
    "hi": "Hindi",
    "ja": "Japanese",
    "swa": "Swahili",
    "bn": "Bengali",
    "tr": "Turkish"
}
SRC_LANG = "en"

# Load complete multilingual dataset
dataset_path = "balanced_mt_dataset"
print("📦 Loading complete multilingual dataset...")
full_dataset = load_from_disk(dataset_path)
print(f"✅ Dataset loaded: {len(full_dataset['train'])} train, {len(full_dataset['test'])} test")
print(f"🌍 All languages included: {list(languages.keys())}")

# Enhanced BLEU computation with progress tracking
def compute_multilingual_bleu(eval_pred):
    """Multilingual BLEU computation with progress tracking"""
    import time
    start_time = time.time()
    
    predictions, labels = eval_pred
    
    if len(predictions.shape) == 3:
        predictions = np.argmax(predictions, axis=-1)
    
    decoded_preds = []
    decoded_labels = []
    
    total_samples = len(predictions)
    print(f"🔄 Starting evaluation of {total_samples} samples...")
    
    # Process in chunks with progress updates
    chunk_size = 50
    for i in range(0, total_samples, chunk_size):
        end_idx = min(i + chunk_size, total_samples)
        chunk_preds = predictions[i:end_idx]
        chunk_labels = labels[i:end_idx]
        
        try:
            for pred, label in zip(chunk_preds, chunk_labels):
                # CRITICAL FIX: Handle negative token IDs properly
                # Filter out negative IDs and pad tokens before decoding
                pred_clean = [token for token in pred if token >= 0 and token < len(tokenizer)]
                label_clean = [token for token in label if token != -100 and token >= 0 and token < len(tokenizer)]
                
                try:
                    decoded_pred = tokenizer.decode(pred_clean, skip_special_tokens=True).strip() if pred_clean else ""
                    decoded_label = tokenizer.decode(label_clean, skip_special_tokens=True).strip() if label_clean else ""
                except Exception as e:
                    # Fallback for any decode errors
                    decoded_pred = ""
                    decoded_label = ""
                
                decoded_preds.append(decoded_pred)
                decoded_labels.append(decoded_label)
            
            # Progress update every chunk
            if (i + chunk_size) % (chunk_size * 5) == 0 or end_idx == total_samples:
                elapsed = time.time() - start_time
                print(f"📊 Evaluation progress: {end_idx}/{total_samples} samples ({elapsed:.1f}s elapsed)")
                
        except Exception as e:
            print(f"⚠️  Batch decode failed for chunk {i}-{end_idx}: {str(e)}")
            # Add empty strings for failed batch
            for _ in range(end_idx - i):
                decoded_preds.append("")
                decoded_labels.append("")
    
    # Compute BLEU with smoothing
    print("🧮 Computing BLEU scores...")
    smoothing = SmoothingFunction().method1
    bleu_scores = []
    exact_matches = 0
    
    for idx, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
        if idx % 1000 == 0 and idx > 0:
            print(f"   BLEU calculation: {idx}/{len(decoded_preds)} samples")
            
        if not pred.strip() or not label.strip():
            bleu_scores.append(0.0)
            continue
            
        pred_tokens = pred.split()
        label_tokens = label.split()
        
        # Check exact match
        if pred.lower().strip() == label.lower().strip():
            exact_matches += 1
        
        if len(pred_tokens) == 0 or len(label_tokens) == 0:
            bleu_scores.append(0.0)
            continue
        
        try:
            bleu = sentence_bleu(
                [label_tokens], 
                pred_tokens,
                smoothing_function=smoothing,
                weights=(0.25, 0.25, 0.25, 0.25)
            )
            bleu_scores.append(bleu)
        except:
            bleu_scores.append(0.0)
    
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
    exact_match_ratio = exact_matches / len(decoded_preds) if decoded_preds else 0.0
    
    total_time = time.time() - start_time
    print(f"✅ Evaluation complete: BLEU={avg_bleu:.4f}, Exact Match={exact_match_ratio:.4f} ({total_time:.1f}s total)")
    
    return {
        "bleu": avg_bleu,
        "exact_match": exact_match_ratio,
        "avg_pred_length": np.mean([len(p.split()) for p in decoded_preds if p.strip()]) if decoded_preds else 0.0,
        "avg_label_length": np.mean([len(l.split()) for l in decoded_labels if l.strip()]) if decoded_labels else 0.0,
        "empty_predictions": sum(1 for p in decoded_preds if not p.strip()) / len(decoded_preds) if decoded_preds else 0.0
    }

def setup_custom_tokenizer_for_model(tokenizer_path, model_type):
    """Setup custom tokenizer with proper special tokens for specific model"""
    # Load custom tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    
    # Model-specific special token configuration
    if model_type == "bart":
        special_tokens = {
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '</s>',
            'pad_token': '<pad>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }
    elif model_type == "t5":
        special_tokens = {
            'pad_token': '<pad>',
            'eos_token': '</s>',
            'unk_token': '<unk>',
            'bos_token': '<pad>',  # T5 doesn't use BOS
            'sep_token': '</s>',
            'mask_token': '<extra_id_0>'
        }
    else:  # Default fallback
        special_tokens = {
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '</s>',
            'pad_token': '<pad>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }
    
    # Add missing special tokens
    tokens_to_add = {}
    for token_name, token_value in special_tokens.items():
        if getattr(tokenizer, token_name, None) is None:
            tokens_to_add[token_name] = token_value
    
    if tokens_to_add:
        tokenizer.add_special_tokens(tokens_to_add)
    
    # Ensure we have the required tokens
    assert tokenizer.eos_token is not None, "EOS token is required"
    assert tokenizer.pad_token is not None, "PAD token is required"
    
    return tokenizer

def configure_model_for_custom_tokenizer(model, tokenizer, model_type):
    """Configure model for custom tokenizer based on model type"""
    # Resize embeddings
    print("🔄 Resizing model embeddings to match custom tokenizer...")
    old_vocab_size = model.config.vocab_size
    model.resize_token_embeddings(len(tokenizer))
    print(f"   Vocab size: {old_vocab_size} → {len(tokenizer)}")
    
    # Model-specific configuration
    if model_type == "bart":
        model.config.decoder_start_token_id = tokenizer.bos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.bos_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.sep_token_id = tokenizer.eos_token_id
        model.config.forced_eos_token_id = tokenizer.eos_token_id
        
    elif model_type == "t5":
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.decoder_start_token_id = tokenizer.pad_token_id  # T5 uses pad_token_id as decoder start
        model.config.forced_eos_token_id = tokenizer.eos_token_id
    
    # Update generation config if it exists
    if hasattr(model, 'generation_config') and model.generation_config is not None:
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.generation_config.forced_eos_token_id = tokenizer.eos_token_id
        if model_type == "bart":
            model.generation_config.decoder_start_token_id = tokenizer.bos_token_id
            model.generation_config.bos_token_id = tokenizer.bos_token_id
        elif model_type == "t5":
            model.generation_config.decoder_start_token_id = tokenizer.pad_token_id
    
    return model

# Setup logging
log_dir = "./MT_models_multilingual_custom_tokenizers"
os.makedirs(log_dir, exist_ok=True)
results_log_file = os.path.join(log_dir, "multilingual_custom_tokenizer_results.csv")

if not os.path.exists(results_log_file):
    with open(results_log_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Model_ID", "Base_Model", "Tokenizer_Size", "Tokenizer_Type", 
                        "Total_Languages", "Total_Train_Samples", "Total_Test_Samples", 
                        "Custom_Vocab_Size", "Overall_BLEU", "Overall_Exact_Match", 
                        "Avg_Pred_Length", "Avg_Label_Length", "Empty_Predictions", 
                        "Training_Status", "Notes"])

print("🚀 Starting multilingual training with COMPATIBLE model-tokenizer pairs...")
print(f"📊 Training approach: ONE model per tokenizer handling ALL {len(languages)} languages")

# Main training loop: 3 sizes × 3 types = 9 models total
for size in tokenizer_sizes:
    for tok_type in tokenizer_types:
        # Get compatible model configuration
        model_config = MODEL_TOKENIZER_CONFIGS[tok_type]
        
        # Path to custom tokenizer
        tokenizer_path = f"vocab_final/vocab_final{size}/{tok_type}"
        model_id = f"multilingual_{size}_{tok_type}"
        
        print(f"\n{'='*100}")
        print(f"🚀 Training Model {tokenizer_sizes.index(size)*3 + tokenizer_types.index(tok_type) + 1}/9")
        print(f"🔧 Custom Tokenizer: {size}_{tok_type}")
        print(f"📦 Compatible Base Model: {model_config['base_model']}")
        print(f"🌍 Target Languages: {list(languages.keys())} (ALL SIMULTANEOUSLY)")
        
        # Initialize variables
        model = None
        tokenizer = None
        trainer = None
        
        try:
            # Load and setup custom tokenizer
            print(f"🔧 Loading custom tokenizer: {tokenizer_path}")
            model_type = "bart" if "bart" in model_config['base_model'] else "t5"
            tokenizer = setup_custom_tokenizer_for_model(tokenizer_path, model_type)
            
            print(f"✅ Custom tokenizer loaded successfully!")
            print(f"📊 Custom vocab size: {len(tokenizer)}")
            print(f"🔑 Special tokens - EOS: {tokenizer.eos_token_id}, PAD: {tokenizer.pad_token_id}")
            
            # Load compatible base model
            print(f"🤖 Loading compatible base model: {model_config['base_model']}")
            model = model_config['model_class'].from_pretrained(model_config['base_model'])
            
            # Configure model for custom tokenizer
            model = configure_model_for_custom_tokenizer(model, tokenizer, model_type)
            
            print(f"✅ Model configured with custom tokenizer!")
            
            # Use COMPLETE multilingual dataset (all languages together)
            print(f"📊 Using complete multilingual dataset:")
            print(f"   • Train samples: {len(full_dataset['train'])}")
            print(f"   • Test samples: {len(full_dataset['test'])}")
            print(f"   • Languages: {len(languages)} ({list(languages.keys())})")
            
            # IMPROVED: Preprocessing function for multilingual data
            def preprocess_multilingual_improved(examples):
                """Improved preprocessing for multilingual data with model-specific formatting"""
                sources = []
                targets = []
                
                # Handle both single examples and batches
                if not isinstance(examples["translation"], list):
                    examples = {
                        "translation": [examples["translation"]], 
                        "language": [examples["language"]]
                    }
                
                for translation, lang in zip(examples["translation"], examples["language"]):
                    if isinstance(translation, dict) and lang in languages:
                        source = translation.get(SRC_LANG, "")
                        target = translation.get(lang, "")
                        
                        if source.strip() and target.strip():  # Ensure non-empty
                            # Model-specific formatting
                            if model_type == "t5":
                                # T5 style: "translate English to German: Hello"
                                source_formatted = f"translate English to {languages[lang]}: {source}"
                            else:  # BART style
                                # BART style with language token
                                source_formatted = f"{source} </s> {lang}_XX"  # Add language code
                            
                            sources.append(source_formatted)
                            targets.append(target)
                
                if not sources or not targets:
                    return {"input_ids": [], "attention_mask": [], "labels": []}
                
                # Tokenize with custom tokenizer
                max_length = 256  # INCREASED for better performance
                
                # Input tokenization - REMOVE token_type_ids
                model_inputs = tokenizer(
                    sources,
                    max_length=max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors=None,
                    return_token_type_ids=False  # CRITICAL FIX
                )
                # Ensure token_type_ids is removed
                model_inputs.pop("token_type_ids", None)
                
                # Target tokenization  
                with tokenizer.as_target_tokenizer():
                    labels = tokenizer(
                        targets,
                        max_length=max_length,
                        truncation=True,
                        padding="max_length",
                        return_tensors=None,
                        return_token_type_ids=False  # CRITICAL FIX
                    )
                # Ensure token_type_ids is removed
                labels.pop("token_type_ids", None)
                
                # IMPROVED: Label processing with proper EOS handling
                processed_labels = []
                for label_seq in labels["input_ids"]:
                    # Find actual end of sequence (before padding)
                    try:
                        pad_start = label_seq.index(tokenizer.pad_token_id)
                        actual_tokens = label_seq[:pad_start]
                    except ValueError:
                        actual_tokens = label_seq
                    
                    # Ensure EOS token at end
                    if actual_tokens and actual_tokens[-1] != tokenizer.eos_token_id:
                        actual_tokens.append(tokenizer.eos_token_id)
                    elif not actual_tokens:
                        actual_tokens = [tokenizer.eos_token_id]
                    
                    # Create final sequence with -100 for padding
                    final_labels = actual_tokens + [-100] * (max_length - len(actual_tokens))
                    final_labels = final_labels[:max_length]  # Ensure correct length
                    
                    processed_labels.append(final_labels)
                
                model_inputs["labels"] = processed_labels
                return model_inputs
            
            # Preprocess complete multilingual dataset
            print("⚙️  Preprocessing complete multilingual dataset with custom tokenizer...")
            processed_dataset = full_dataset.map(
                preprocess_multilingual_improved,
                batched=True,
                remove_columns=full_dataset["train"].column_names,
                desc=f"Preprocessing with {size}_{tok_type}",
                batch_size=50,  # Smaller batch for stability
                num_proc=1  # Single process to avoid issues
            )
            
            train_dataset = processed_dataset["train"]
            eval_dataset = processed_dataset["test"]
            
            # Filter out empty examples
            def filter_valid_examples(example):
                return (
                    len(example["input_ids"]) > 0 and 
                    len(example["labels"]) > 0 and
                    any(label != -100 for label in example["labels"]) and
                    sum(1 for token in example["input_ids"] if token != tokenizer.pad_token_id) > 0
                )
            
            train_dataset = train_dataset.filter(filter_valid_examples)
            eval_dataset = eval_dataset.filter(filter_valid_examples)
            
            # Use smaller eval dataset to avoid memory issues
            print(f"✅ Preprocessed multilingual dataset:")
            print(f"   • Train samples: {len(train_dataset)}")
            print(f"   • Eval samples: {len(eval_dataset)}")
            
            # REDUCE eval dataset size to avoid hanging
            if len(eval_dataset) > 1000:
                eval_dataset = eval_dataset.select(range(1000))
                print(f"   • Reduced eval samples to: {len(eval_dataset)} (to avoid memory issues)")
            
            if len(train_dataset) == 0 or len(eval_dataset) == 0:
                print("❌ No valid samples after preprocessing")
                continue
            
            # Setup training directory
            output_dir = f"./MT_models_multilingual_custom_tokenizers/{model_id}"
            os.makedirs(output_dir, exist_ok=True)
            
            # IMPROVED: Training arguments with simpler evaluation
            training_args = Seq2SeqTrainingArguments(
                output_dir=output_dir,
                num_train_epochs=1,  # More epochs for custom tokenizers
                per_device_train_batch_size=2,  # Smaller batch size
                per_device_eval_batch_size=1,  # REDUCED eval batch size
                gradient_accumulation_steps=8,  # Higher accumulation
                learning_rate=1e-4,  # Lower learning rate for stability
                weight_decay=0.01,
                warmup_ratio=0.1,  # Warmup as ratio
                eval_strategy="steps",  # Change to steps-based evaluation
                eval_steps=500,  # Evaluate every 500 steps instead of epoch end
                save_strategy="epoch",
                save_total_limit=2,
                logging_steps=50,
                report_to="none",
                predict_with_generate=True,
                generation_max_length=128,  # REDUCED generation length
                generation_num_beams=2,  # REDUCED beams
                fp16=torch.cuda.is_available(),
                load_best_model_at_end=False,  # DISABLED to avoid issues
                dataloader_num_workers=0,
                remove_unused_columns=False,  # CRITICAL: Keep all columns
                ignore_data_skip=True,
                label_smoothing_factor=0.1,  # Label smoothing
                max_grad_norm=1.0,  # Gradient clipping
                dataloader_pin_memory=False,  # Disable pin memory
                skip_memory_metrics=True,  # Skip memory tracking
                
                
            )
            
            # Data collator with explicit token_type_ids handling
            data_collator = DataCollatorForSeq2Seq(
                tokenizer=tokenizer,
                model=model,
                padding=True,
                pad_to_multiple_of=8 if training_args.fp16 else None,
                return_tensors="pt",
                label_pad_token_id=-100
            )
            
            # CRITICAL FIX: Custom data collator that removes token_type_ids
            class CustomDataCollator(DataCollatorForSeq2Seq):
                def __call__(self, features):
                    batch = super().__call__(features)
                    # Remove token_type_ids if present
                    batch.pop("token_type_ids", None)
                    return batch
            
            data_collator = CustomDataCollator(
                tokenizer=tokenizer,
                model=model,
                padding=True,
                pad_to_multiple_of=8 if training_args.fp16 else None,
                return_tensors="pt",
                label_pad_token_id=-100
            )
            
            # Trainer
            trainer = Seq2SeqTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=compute_multilingual_bleu
            )
            
            # Train multilingual model
            print("🏋️  Starting multilingual training with compatible model-tokenizer pair...")
            trainer.train()
            
            # Evaluate
            print("📊 Final multilingual evaluation...")
            eval_results = trainer.evaluate()
            
            # Save model and custom tokenizer
            print("💾 Saving multilingual model and custom tokenizer...")
            trainer.save_model()
            tokenizer.save_pretrained(output_dir)
            
            # Log results
            overall_bleu = eval_results.get("eval_bleu", 0.0)
            overall_exact_match = eval_results.get("eval_exact_match", 0.0)
            avg_pred_len = eval_results.get("eval_avg_pred_length", 0.0)
            avg_label_len = eval_results.get("eval_avg_label_length", 0.0)
            empty_preds = eval_results.get("eval_empty_predictions", 0.0)
            
            with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    model_id, model_config["base_model"], size, tok_type,
                    len(languages), len(train_dataset), len(eval_dataset), len(tokenizer),
                    round(overall_bleu, 4), round(overall_exact_match, 4), 
                    round(avg_pred_len, 2), round(avg_label_len, 2),
                    round(empty_preds, 4), "SUCCESS", 
                    f"Compatible {model_config['description']} with {size}_{tok_type}"
                ])
            
            print(f"✅ Completed: {model_id}")
            print(f"📈 Overall BLEU Score: {overall_bleu:.4f}")
            print(f"🎯 Overall Exact Match: {overall_exact_match:.4f}")
            
            # Quick translation tests for different languages
            print(f"\n🧪 Quick multilingual translation tests:")
            test_input = "Hello, how are you today?"
            
            for test_lang in ["ar", "zh", "hi"]:  # Test 3 languages
                if model_type == "t5":
                    formatted_input = f"translate English to {languages[test_lang]}: {test_input}"
                else:  # BART
                    formatted_input = f"{test_input} </s> {test_lang}_XX"
                    
                inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, max_length=256, truncation=True)
                
                # Move to device
                device = next(model.parameters()).device
                inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
                
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_length=128,
                        num_beams=4,
                        early_stopping=True,
                        do_sample=False,
                        forced_eos_token_id=tokenizer.eos_token_id
                    )
                
                translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
                print(f"  {test_lang} ({languages[test_lang]}): {translation}")
                
        except Exception as e:
            print(f"❌ Failed to train {model_id}: {str(e)}")
            import traceback
            traceback.print_exc()
            
            with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    model_id, model_config.get("base_model", ""), size, tok_type,
                    len(languages), 0, 0, 0, 0, 0, 0, 0, 0, "TRAINING_FAILED", str(e)[:100]
                ])
        
        finally:
            # Cleanup
            if model is not None:
                del model
            if trainer is not None:
                del trainer
            torch.cuda.empty_cache()
            gc.collect()

print("\n🎉 Multilingual training with compatible model-tokenizer pairs completed!")
print(f"📋 Results saved to: {results_log_file}")
print(f"🔢 Total models trained: 9 (3 sizes × 3 types)")
print(f"🌍 Each model handles all {len(languages)} languages simultaneously")
print("\n📊 Expected improvements:")
print("• BPE tokenizers → BART models (proper compatibility)")
print("• WordPiece tokenizers → mT5 models (better multilingual support)")
print("• Unigram tokenizers → mT5 models (native compatibility)")
print("• Increased sequence length (256 vs 128)")
print("• Better hyperparameters and training setup")



📦 Loading complete multilingual dataset...
✅ Dataset loaded: 34376 train, 3820 test
🌍 All languages included: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr']
🚀 Starting FAST multilingual training with COMPATIBLE model-tokenizer pairs...
📊 Training approach: ONE model per tokenizer handling ALL 9 languages
⚡ SPEED OPTIMIZATION: Pre-processing dataset for all tokenizer types...

🚀 Training Model 1/9
🔧 Custom Tokenizer: small_hf_bpe_hf
📦 Compatible Base Model: facebook/bart-large
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalsmall/hf_bpe_hf
✅ Custom tokenizer loaded successfully!
📊 Custom vocab size: 15002
🔑 Special tokens - EOS: 15001, PAD: 0
🤖 Loading compatible base model: facebook/bart-large
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 50265 → 15002
✅ Model configured with custom tokenizer!
📊 Using complete multilingual dataset:
   • Train samples: 3

In [10]:
fp16=torch.cuda.is_available(),
bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),  # Better than fp16

In [12]:
print(fp16)
print(bf16)

(True,)
(True,)


In [2]:
import os
import csv
import torch
import warnings
import gc
from datasets import load_from_disk
from transformers import (
    EncoderDecoderModel,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    logging as hf_logging,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    T5ForConditionalGeneration,
    T5TokenizerFast,
    BertTokenizerFast,
    GPT2TokenizerFast,
    BartForConditionalGeneration,
    TrainerCallback
)
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np

# Suppress warnings
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore")

# -------------------------
# USER-TUNABLE SPEED HYPERPARAMS
# -------------------------
FREEZE_ENCODER_LAYERS = 6    # how many encoder layers to freeze initially
FREEZE_EPOCHS = 2            # unfreeze after this many epochs
TRAIN_BATCH_SIZE = 4         # per-device train batch
GRAD_ACCUM_STEPS = 4         # gradient accumulation -> effective batch = TRAIN_BATCH_SIZE * GRAD_ACCUM_STEPS
EVAL_BEAMS_TRAIN = 2         # beams used during training/eval cycle (faster)
MAX_TOKENS = 256             # tokenization max length (you used 256 already)
PROCESSED_CACHE_ROOT = "processed_cached"  # per-tokenizer cache dir root

# FIXED: Model-Tokenizer Compatibility Mapping
MODEL_TOKENIZER_CONFIGS = {
    "hf_bpe_hf": {
        "base_model": "facebook/bart-large",
        "model_class": BartForConditionalGeneration,
        "description": "BART with BPE tokenization"
    },
    "hf_wordpiece_hf": {
        "base_model": "google/mt5-base",
        "model_class": T5ForConditionalGeneration,
        "description": "mT5 with WordPiece tokenization"
    },
    "sp_unigram_hf": {
        "base_model": "google/mt5-base",
        "model_class": T5ForConditionalGeneration,
        "description": "mT5 with SentencePiece Unigram"
    }
}

# Custom tokenizer settings
tokenizer_sizes = ["small", "medium", "large"]
tokenizer_types = ["hf_bpe_hf", "hf_wordpiece_hf", "sp_unigram_hf"]

# Languages for evaluation
languages = {
    "yo": "Yoruba",
    "ar": "Arabic",
    "zh": "Chinese",
    "ru": "Russian",
    "hi": "Hindi",
    "ja": "Japanese",
    "swa": "Swahili",
    "bn": "Bengali",
    "tr": "Turkish"
}
SRC_LANG = "en"

# Load complete multilingual dataset
dataset_path = "balanced_mt_dataset"
print("📦 Loading complete multilingual dataset...")
full_dataset = load_from_disk(dataset_path)
print(f"✅ Dataset loaded: {len(full_dataset['train'])} train, {len(full_dataset['test'])} test")
print(f"🌍 All languages included: {list(languages.keys())}")

# -------------------------
# Metric (keeps using global tokenizer at runtime)
# -------------------------
def compute_multilingual_bleu(eval_pred):
    """Multilingual BLEU computation with progress tracking"""
    import time
    start_time = time.time()
    predictions, labels = eval_pred

    # If predictions are logits/probs, pick argmax
    try:
        if len(predictions.shape) == 3:
            predictions = np.argmax(predictions, axis=-1)
    except Exception:
        pass

    decoded_preds = []
    decoded_labels = []
    total_samples = len(predictions)
    print(f"🔄 Starting evaluation of {total_samples} samples...")

    chunk_size = 50
    for i in range(0, total_samples, chunk_size):
        end_idx = min(i + chunk_size, total_samples)
        chunk_preds = predictions[i:end_idx]
        chunk_labels = labels[i:end_idx]

        try:
            for pred, label in zip(chunk_preds, chunk_labels):
                # Handle negative token IDs properly
                pred_clean = [token for token in pred if token >= 0 and token < len(tokenizer)]
                label_clean = [token for token in label if token != -100 and token >= 0 and token < len(tokenizer)]

                try:
                    decoded_pred = tokenizer.decode(pred_clean, skip_special_tokens=True).strip() if pred_clean else ""
                    decoded_label = tokenizer.decode(label_clean, skip_special_tokens=True).strip() if label_clean else ""
                except Exception:
                    decoded_pred = ""
                    decoded_label = ""

                decoded_preds.append(decoded_pred)
                decoded_labels.append(decoded_label)

            if (i + chunk_size) % (chunk_size * 5) == 0 or end_idx == total_samples:
                elapsed = time.time() - start_time
                print(f"📊 Evaluation progress: {end_idx}/{total_samples} samples ({elapsed:.1f}s elapsed)")
        except Exception as e:
            print(f"⚠️  Batch decode failed for chunk {i}-{end_idx}: {str(e)}")
            for _ in range(end_idx - i):
                decoded_preds.append("")
                decoded_labels.append("")

    # Compute BLEU with smoothing
    print("🧮 Computing BLEU scores...")
    smoothing = SmoothingFunction().method1
    bleu_scores = []
    exact_matches = 0

    for idx, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
        if idx % 1000 == 0 and idx > 0:
            print(f"   BLEU calculation: {idx}/{len(decoded_preds)} samples")

        if not pred.strip() or not label.strip():
            bleu_scores.append(0.0)
            continue

        pred_tokens = pred.split()
        label_tokens = label.split()

        if pred.lower().strip() == label.lower().strip():
            exact_matches += 1

        if len(pred_tokens) == 0 or len(label_tokens) == 0:
            bleu_scores.append(0.0)
            continue

        try:
            bleu = sentence_bleu(
                [label_tokens],
                pred_tokens,
                smoothing_function=smoothing,
                weights=(0.25, 0.25, 0.25, 0.25)
            )
            bleu_scores.append(bleu)
        except Exception:
            bleu_scores.append(0.0)

    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
    exact_match_ratio = exact_matches / len(decoded_preds) if decoded_preds else 0.0

    total_time = time.time() - start_time
    print(f"✅ Evaluation complete: BLEU={avg_bleu:.4f}, Exact Match={exact_match_ratio:.4f} ({total_time:.1f}s total)")

    return {
        "bleu": avg_bleu,
        "exact_match": exact_match_ratio,
        "avg_pred_length": np.mean([len(p.split()) for p in decoded_preds if p.strip()]) if decoded_preds else 0.0,
        "avg_label_length": np.mean([len(l.split()) for l in decoded_labels if l.strip()]) if decoded_labels else 0.0,
        "empty_predictions": sum(1 for p in decoded_preds if not p.strip()) / len(decoded_preds) if decoded_preds else 0.0
    }

# -------------------------
# Tokenizer helper + model config helpers (unchanged logic)
# -------------------------
def setup_custom_tokenizer_for_model(tokenizer_path, model_type):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)

    if model_type == "bart":
        special_tokens = {
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '</s>',
            'pad_token': '<pad>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }
    elif model_type == "t5":
        special_tokens = {
            'pad_token': '<pad>',
            'eos_token': '</s>',
            'unk_token': '<unk>',
            'bos_token': '<pad>',
            'sep_token': '</s>',
            'mask_token': '<extra_id_0>'
        }
    else:
        special_tokens = {
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '</s>',
            'pad_token': '<pad>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }

    tokens_to_add = {}
    for token_name, token_value in special_tokens.items():
        if getattr(tokenizer, token_name, None) is None:
            tokens_to_add[token_name] = token_value

    if tokens_to_add:
        tokenizer.add_special_tokens(tokens_to_add)

    assert tokenizer.eos_token is not None, "EOS token is required"
    assert tokenizer.pad_token is not None, "PAD token is required"

    return tokenizer

def configure_model_for_custom_tokenizer(model, tokenizer, model_type):
    print("🔄 Resizing model embeddings to match custom tokenizer...")
    old_vocab_size = model.config.vocab_size
    model.resize_token_embeddings(len(tokenizer))
    print(f"   Vocab size: {old_vocab_size} → {len(tokenizer)}")

    if model_type == "bart":
        model.config.decoder_start_token_id = tokenizer.bos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.bos_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.sep_token_id = tokenizer.eos_token_id
        model.config.forced_eos_token_id = tokenizer.eos_token_id
    elif model_type == "t5":
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.decoder_start_token_id = tokenizer.pad_token_id
        model.config.forced_eos_token_id = tokenizer.eos_token_id

    if hasattr(model, 'generation_config') and model.generation_config is not None:
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.generation_config.forced_eos_token_id = tokenizer.eos_token_id
        if model_type == "bart":
            model.generation_config.decoder_start_token_id = tokenizer.bos_token_id
            model.generation_config.bos_token_id = tokenizer.bos_token_id
        elif model_type == "t5":
            model.generation_config.decoder_start_token_id = tokenizer.pad_token_id

    return model

# Setup logging / results CSV
log_dir = "./MT_models_multilingual_custom_tokenizers"
os.makedirs(log_dir, exist_ok=True)
results_log_file = os.path.join(log_dir, "multilingual_custom_tokenizer_results.csv")

if not os.path.exists(results_log_file):
    with open(results_log_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Model_ID", "Base_Model", "Tokenizer_Size", "Tokenizer_Type",
                         "Total_Languages", "Total_Train_Samples", "Total_Test_Samples",
                         "Custom_Vocab_Size", "Overall_BLEU", "Overall_Exact_Match",
                         "Avg_Pred_Length", "Avg_Label_Length", "Empty_Predictions",
                         "Training_Status", "Notes"])

print("🚀 Starting multilingual training with COMPATIBLE model-tokenizer pairs...")
print(f"📊 Training approach: ONE model per tokenizer handling ALL {len(languages)} languages")

# Callback to unfreeze after FREEZE_EPOCHS
class UnfreezeCallback(TrainerCallback):
    def on_epoch_end(self, args, state, control, **kwargs):
        # state.epoch is float. Compare to FREEZE_EPOCHS - 1 because epoch increments at end
        if int(state.epoch) == FREEZE_EPOCHS:
            model_ref = kwargs.get("model")
            if model_ref is None:
                return
            print(f"🔓 Unfreezing embeddings + first {FREEZE_ENCODER_LAYERS} encoder layers at epoch {state.epoch}")
            # Unfreeze embed tokens
            try:
                for p in model_ref.model.encoder.embed_tokens.parameters():
                    p.requires_grad = True
            except Exception:
                pass
            # Unfreeze first N encoder layers
            try:
                for layer in model_ref.model.encoder.layers[:FREEZE_ENCODER_LAYERS]:
                    for p in layer.parameters():
                        p.requires_grad = True
            except Exception:
                pass

# Main training loop
for size in tokenizer_sizes:
    for tok_type in tokenizer_types:
        model_config = MODEL_TOKENIZER_CONFIGS[tok_type]
        tokenizer_path = f"vocab_final/vocab_final{size}/{tok_type}"
        model_id = f"multilingual_{size}_{tok_type}"

        print(f"\n{'='*100}")
        print(f"🚀 Training Model {tokenizer_sizes.index(size)*3 + tokenizer_types.index(tok_type) + 1}/9")
        print(f"🔧 Custom Tokenizer: {size}_{tok_type}")
        print(f"📦 Compatible Base Model: {model_config['base_model']}")
        print(f"🌍 Target Languages: {list(languages.keys())} (ALL SIMULTANEOUSLY)")

        model = None
        tokenizer = None
        trainer = None

        try:
            # Load tokenizer and model
            print(f"🔧 Loading custom tokenizer: {tokenizer_path}")
            model_type = "bart" if "bart" in model_config['base_model'] else "t5"
            tokenizer = setup_custom_tokenizer_for_model(tokenizer_path, model_type)
            print(f"✅ Custom tokenizer loaded: vocab_size={len(tokenizer)} | EOS={tokenizer.eos_token_id} PAD={tokenizer.pad_token_id}")

            print(f"🤖 Loading compatible base model: {model_config['base_model']}")
            model = model_config['model_class'].from_pretrained(model_config['base_model'])

            # Configure model/tokenizer
            model = configure_model_for_custom_tokenizer(model, tokenizer, model_type)

            # Speed-ups: gradient checkpointing and optional torch.compile + bf16
            try:
                model.gradient_checkpointing_enable()
                print("⚡ Gradient checkpointing enabled")
            except Exception:
                pass

            # Freeze embeddings + first layers
            try:
                for p in model.model.encoder.embed_tokens.parameters():
                    p.requires_grad = False
                for layer in model.model.encoder.layers[:FREEZE_ENCODER_LAYERS]:
                    for p in layer.parameters():
                        p.requires_grad = False
                print(f"🔒 Frozen embeddings + first {FREEZE_ENCODER_LAYERS} encoder layers")
            except Exception:
                pass

            # Try to compile (PyTorch 2.x)
            if hasattr(torch, "compile"):
                try:
                    model = torch.compile(model)
                    print("⚡ torch.compile applied (PyTorch 2.x)")
                except Exception as e:
                    print("⚠️ torch.compile failed or unsupported: ", e)

            # Preprocess & cache per tokenizer to avoid re-tokenizing multiple times
            cache_dir = os.path.join(PROCESSED_CACHE_ROOT, f"{size}_{tok_type}")
            if not os.path.exists(PROCESSED_CACHE_ROOT):
                os.makedirs(PROCESSED_CACHE_ROOT, exist_ok=True)

            if os.path.exists(cache_dir):
                print(f"Loading cached tokenized dataset from {cache_dir} ...")
                processed_dataset = load_from_disk(cache_dir)
            else:
                print("⚙️  Preprocessing complete multilingual dataset with custom tokenizer...")

                def preprocess_multilingual_improved(examples):
                    sources = []
                    targets = []

                    if not isinstance(examples["translation"], list):
                        examples = {
                            "translation": [examples["translation"]],
                            "language": [examples["language"]]
                        }

                    for translation, lang in zip(examples["translation"], examples["language"]):
                        if isinstance(translation, dict) and lang in languages:
                            source = translation.get(SRC_LANG, "")
                            target = translation.get(lang, "")

                            if source.strip() and target.strip():
                                if model_type == "t5":
                                    source_formatted = f"translate English to {languages[lang]}: {source}"
                                else:
                                    source_formatted = f"{source} </s> {lang}_XX"
                                sources.append(source_formatted)
                                targets.append(target)

                    if not sources or not targets:
                        return {"input_ids": [], "attention_mask": [], "labels": []}

                    max_length = MAX_TOKENS

                    model_inputs = tokenizer(
                        sources,
                        max_length=max_length,
                        truncation=True,
                        padding="max_length",
                        return_tensors=None,
                        return_token_type_ids=False
                    )
                    model_inputs.pop("token_type_ids", None)

                    with tokenizer.as_target_tokenizer():
                        labels = tokenizer(
                            targets,
                            max_length=max_length,
                            truncation=True,
                            padding="max_length",
                            return_tensors=None,
                            return_token_type_ids=False
                        )
                    labels.pop("token_type_ids", None)

                    processed_labels = []
                    for label_seq in labels["input_ids"]:
                        try:
                            pad_start = label_seq.index(tokenizer.pad_token_id)
                            actual_tokens = label_seq[:pad_start]
                        except ValueError:
                            actual_tokens = label_seq

                        if actual_tokens and actual_tokens[-1] != tokenizer.eos_token_id:
                            actual_tokens.append(tokenizer.eos_token_id)
                        elif not actual_tokens:
                            actual_tokens = [tokenizer.eos_token_id]

                        final_labels = actual_tokens + [-100] * (max_length - len(actual_tokens))
                        final_labels = final_labels[:max_length]
                        processed_labels.append(final_labels)

                    model_inputs["labels"] = processed_labels
                    return model_inputs

                processed_dataset = full_dataset.map(
                    preprocess_multilingual_improved,
                    batched=True,
                    remove_columns=full_dataset["train"].column_names,
                    desc=f"Preprocessing with {size}_{tok_type}",
                    batch_size=50,
                    num_proc=1
                )

                print(f"Saving tokenized cache to {cache_dir} ...")
                processed_dataset.save_to_disk(cache_dir)

            train_dataset = processed_dataset["train"]
            eval_dataset = processed_dataset["test"]

            # Filter out empty examples
            def filter_valid_examples(example):
                return (
                    len(example["input_ids"]) > 0 and
                    len(example["labels"]) > 0 and
                    any(label != -100 for label in example["labels"]) and
                    sum(1 for token in example["input_ids"] if token != tokenizer.pad_token_id) > 0
                )

            train_dataset = train_dataset.filter(filter_valid_examples)
            eval_dataset = eval_dataset.filter(filter_valid_examples)

            print(f"✅ Preprocessed multilingual dataset: Train={len(train_dataset)}, Eval={len(eval_dataset)}")

            # Reduce eval dataset to avoid memory issues (you can change threshold)
            if len(eval_dataset) > 1000:
                eval_dataset = eval_dataset.select(range(1000))
                print(f"   • Reduced eval samples to: {len(eval_dataset)} (to avoid memory issues)")

            if len(train_dataset) == 0 or len(eval_dataset) == 0:
                print("❌ No valid samples after preprocessing - skipping this config")
                continue

            # Output dir for this model
            output_dir = f"./MT_models_multilingual_custom_tokenizers/{model_id}"
            os.makedirs(output_dir, exist_ok=True)

            # Training args focused on speed & stability
            training_args = Seq2SeqTrainingArguments(
                output_dir=output_dir,
                num_train_epochs=5,
                per_device_train_batch_size=TRAIN_BATCH_SIZE,
                per_device_eval_batch_size=max(1, TRAIN_BATCH_SIZE),
                gradient_accumulation_steps=GRAD_ACCUM_STEPS,
                learning_rate=1e-4,
                weight_decay=0.01,
                warmup_ratio=0.1,
                eval_strategy="epoch",          # evaluate once per epoch (faster)
                save_strategy="epoch",
                save_total_limit=2,
                logging_steps=50,
                report_to="none",
                predict_with_generate=True,
                generation_max_length=128,
                generation_num_beams=EVAL_BEAMS_TRAIN,  # small beam during training/eval
                fp16=False,
                bf16=torch.cuda.is_available(),    # use bf16 if GPU supports it
                optim="adafactor",                 # memory-efficient optimizer
                load_best_model_at_end=False,
                dataloader_num_workers=0,
                remove_unused_columns=False,
                ignore_data_skip=True,
                label_smoothing_factor=0.1,
                max_grad_norm=1.0,
                dataloader_pin_memory=False,
                skip_memory_metrics=True,
            )

            # Data collator that removes token_type_ids
            class CustomDataCollator(DataCollatorForSeq2Seq):
                def __call__(self, features):
                    batch = super().__call__(features)
                    batch.pop("token_type_ids", None)
                    return batch

            data_collator = CustomDataCollator(
                tokenizer=tokenizer,
                model=model,
                padding=True,
                pad_to_multiple_of=8 if training_args.fp16 or training_args.bf16 else None,
                return_tensors="pt",
                label_pad_token_id=-100
            )

            # Trainer
            trainer = Seq2SeqTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=compute_multilingual_bleu,
                callbacks=[UnfreezeCallback()]
            )

            # Train
            print("🏋️  Starting multilingual training with compatible model-tokenizer pair...")
            trainer.train()

            # Final evaluation (you may want to run with larger beams here if desired)
            print("📊 Final multilingual evaluation...")
            eval_results = trainer.evaluate()

            # Save
            print("💾 Saving multilingual model and custom tokenizer...")
            trainer.save_model()
            tokenizer.save_pretrained(output_dir)

            # Log results
            overall_bleu = eval_results.get("eval_bleu", 0.0)
            overall_exact_match = eval_results.get("eval_exact_match", 0.0)
            avg_pred_len = eval_results.get("eval_avg_pred_length", 0.0)
            avg_label_len = eval_results.get("eval_avg_label_length", 0.0)
            empty_preds = eval_results.get("eval_empty_predictions", 0.0)

            with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    model_id, model_config["base_model"], size, tok_type,
                    len(languages), len(train_dataset), len(eval_dataset), len(tokenizer),
                    round(overall_bleu, 4), round(overall_exact_match, 4),
                    round(avg_pred_len, 2), round(avg_label_len, 2),
                    round(empty_preds, 4), "SUCCESS",
                    f"Compatible {model_config['description']} with {size}_{tok_type}"
                ])

            print(f"✅ Completed: {model_id}")
            print(f"📈 Overall BLEU Score: {overall_bleu:.4f}")
            print(f"🎯 Overall Exact Match: {overall_exact_match:.4f}")

            # Quick multilingual translation tests (use larger beams here for quality)
            print(f"\n🧪 Quick multilingual translation tests:")
            test_input = "Hello, how are you today?"
            for test_lang in ["ar", "zh", "hi"]:
                if model_type == "t5":
                    formatted_input = f"translate English to {languages[test_lang]}: {test_input}"
                else:
                    formatted_input = f"{test_input} </s> {test_lang}_XX"

                inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, max_length=256, truncation=True)
                device = next(model.parameters()).device
                inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_length=128,
                        num_beams=4,
                        early_stopping=True,
                        do_sample=False,
                        forced_eos_token_id=tokenizer.eos_token_id
                    )

                translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
                print(f"  {test_lang} ({languages[test_lang]}): {translation}")

        except Exception as e:
            print(f"❌ Failed to train {model_id}: {str(e)}")
            import traceback
            traceback.print_exc()
            with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    model_id, model_config.get("base_model", ""), size, tok_type,
                    len(languages), 0, 0, 0, 0, 0, 0, 0, 0, "TRAINING_FAILED", str(e)[:200]
                ])
        finally:
            if model is not None:
                del model
            if trainer is not None:
                del trainer
            torch.cuda.empty_cache()
            gc.collect()

print("\n🎉 Multilingual training with compatible model-tokenizer pairs completed!")
print(f"📋 Results saved to: {results_log_file}")
print(f"🔢 Total models trained: {len(tokenizer_sizes) * len(tokenizer_types)}")
print(f"🌍 Each model handles all {len(languages)} languages simultaneously")


📦 Loading complete multilingual dataset...
✅ Dataset loaded: 34376 train, 3820 test
🌍 All languages included: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr']
🚀 Starting multilingual training with COMPATIBLE model-tokenizer pairs...
📊 Training approach: ONE model per tokenizer handling ALL 9 languages

🚀 Training Model 1/9
🔧 Custom Tokenizer: small_hf_bpe_hf
📦 Compatible Base Model: facebook/bart-large
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalsmall/hf_bpe_hf
✅ Custom tokenizer loaded: vocab_size=15002 | EOS=15001 PAD=0
🤖 Loading compatible base model: facebook/bart-large
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 50265 → 15002
⚡ Gradient checkpointing enabled
🔒 Frozen embeddings + first 6 encoder layers
⚡ torch.compile applied (PyTorch 2.x)
⚙️  Preprocessing complete multilingual dataset with custom tokenizer...


Preprocessing with small_hf_bpe_hf: 100%|██████████| 34376/34376 [00:04<00:00, 7625.72 examples/s]
Preprocessing with small_hf_bpe_hf: 100%|██████████| 3820/3820 [00:00<00:00, 7224.78 examples/s]


Saving tokenized cache to processed_cached\small_hf_bpe_hf ...


Saving the dataset (1/1 shards): 100%|██████████| 30566/30566 [00:00<00:00, 429450.83 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3386/3386 [00:00<00:00, 338608.40 examples/s]
Filter: 100%|██████████| 30566/30566 [00:30<00:00, 1003.21 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 1010.21 examples/s]


✅ Preprocessed multilingual dataset: Train=30566, Eval=3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
❌ Failed to train multilingual_small_hf_bpe_hf: BartForConditionalGeneration.forward() got an unexpected keyword argument 'num_items_in_batch'


Traceback (most recent call last):
  File "C:\Users\User 4\AppData\Local\Temp\ipykernel_2508\769564752.py", line 544, in <module>
    trainer.train()
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2237, in train
    return inner_training_loop(
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2578, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3792, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3879, in compute_loss
    outputs = model(**inputs)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl


🚀 Training Model 2/9
🔧 Custom Tokenizer: small_hf_wordpiece_hf
📦 Compatible Base Model: google/mt5-base
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalsmall/hf_wordpiece_hf
✅ Custom tokenizer loaded: vocab_size=23634 | EOS=23632 PAD=0
🤖 Loading compatible base model: google/mt5-base
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 250112 → 23634
⚡ Gradient checkpointing enabled
⚡ torch.compile applied (PyTorch 2.x)
⚙️  Preprocessing complete multilingual dataset with custom tokenizer...


Preprocessing with small_hf_wordpiece_hf: 100%|██████████| 34376/34376 [00:05<00:00, 5775.57 examples/s]
Preprocessing with small_hf_wordpiece_hf: 100%|██████████| 3820/3820 [00:00<00:00, 6416.41 examples/s]


Saving tokenized cache to processed_cached\small_hf_wordpiece_hf ...


Saving the dataset (1/1 shards): 100%|██████████| 30566/30566 [00:00<00:00, 436078.30 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3386/3386 [00:00<00:00, 338608.40 examples/s]
Filter: 100%|██████████| 30566/30566 [00:42<00:00, 714.26 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 871.01 examples/s]


✅ Preprocessed multilingual dataset: Train=30566, Eval=3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
❌ Failed to train multilingual_small_hf_wordpiece_hf: T5ForConditionalGeneration.forward() got an unexpected keyword argument 'num_items_in_batch'

🚀 Training Model 3/9
🔧 Custom Tokenizer: small_sp_unigram_hf
📦 Compatible Base Model: google/mt5-base
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalsmall/sp_unigram_hf
✅ Custom tokenizer loaded: vocab_size=15002 | EOS=2 PAD=15000
🤖 Loading compatible base model: google/mt5-base


Traceback (most recent call last):
  File "C:\Users\User 4\AppData\Local\Temp\ipykernel_2508\769564752.py", line 544, in <module>
    trainer.train()
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2237, in train
    return inner_training_loop(
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2578, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3792, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3879, in compute_loss
    outputs = model(**inputs)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl

🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 250112 → 15002
⚡ Gradient checkpointing enabled
⚡ torch.compile applied (PyTorch 2.x)
⚙️  Preprocessing complete multilingual dataset with custom tokenizer...


Preprocessing with small_sp_unigram_hf: 100%|██████████| 34376/34376 [00:05<00:00, 6356.57 examples/s]
Preprocessing with small_sp_unigram_hf: 100%|██████████| 3820/3820 [00:00<00:00, 6204.83 examples/s]


Saving tokenized cache to processed_cached\small_sp_unigram_hf ...


Saving the dataset (1/1 shards): 100%|██████████| 30566/30566 [00:00<00:00, 446145.72 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3386/3386 [00:00<00:00, 338632.62 examples/s]
Filter: 100%|██████████| 30566/30566 [00:35<00:00, 871.39 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 876.84 examples/s]


✅ Preprocessed multilingual dataset: Train=30566, Eval=3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
❌ Failed to train multilingual_small_sp_unigram_hf: T5ForConditionalGeneration.forward() got an unexpected keyword argument 'num_items_in_batch'

🚀 Training Model 4/9
🔧 Custom Tokenizer: medium_hf_bpe_hf
📦 Compatible Base Model: facebook/bart-large
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalmedium/hf_bpe_hf
✅ Custom tokenizer loaded: vocab_size=30002 | EOS=30001 PAD=0
🤖 Loading compatible base model: facebook/bart-large


Traceback (most recent call last):
  File "C:\Users\User 4\AppData\Local\Temp\ipykernel_2508\769564752.py", line 544, in <module>
    trainer.train()
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2237, in train
    return inner_training_loop(
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2578, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3792, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3879, in compute_loss
    outputs = model(**inputs)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl

🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 50265 → 30002
⚡ Gradient checkpointing enabled
🔒 Frozen embeddings + first 6 encoder layers
⚡ torch.compile applied (PyTorch 2.x)
⚙️  Preprocessing complete multilingual dataset with custom tokenizer...


Preprocessing with medium_hf_bpe_hf: 100%|██████████| 34376/34376 [00:05<00:00, 6654.87 examples/s]
Preprocessing with medium_hf_bpe_hf: 100%|██████████| 3820/3820 [00:00<00:00, 6558.21 examples/s]


Saving tokenized cache to processed_cached\medium_hf_bpe_hf ...


Saving the dataset (1/1 shards): 100%|██████████| 30566/30566 [00:00<00:00, 470238.84 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3386/3386 [00:00<00:00, 338600.32 examples/s]
Filter: 100%|██████████| 30566/30566 [00:34<00:00, 874.98 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 881.27 examples/s]


✅ Preprocessed multilingual dataset: Train=30566, Eval=3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
❌ Failed to train multilingual_medium_hf_bpe_hf: BartForConditionalGeneration.forward() got an unexpected keyword argument 'num_items_in_batch'

🚀 Training Model 5/9
🔧 Custom Tokenizer: medium_hf_wordpiece_hf
📦 Compatible Base Model: google/mt5-base
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalmedium/hf_wordpiece_hf
✅ Custom tokenizer loaded: vocab_size=30002 | EOS=30000 PAD=0
🤖 Loading compatible base model: google/mt5-base


Traceback (most recent call last):
  File "C:\Users\User 4\AppData\Local\Temp\ipykernel_2508\769564752.py", line 544, in <module>
    trainer.train()
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2237, in train
    return inner_training_loop(
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2578, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3792, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3879, in compute_loss
    outputs = model(**inputs)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl

🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 250112 → 30002
⚡ Gradient checkpointing enabled
⚡ torch.compile applied (PyTorch 2.x)
⚙️  Preprocessing complete multilingual dataset with custom tokenizer...


Preprocessing with medium_hf_wordpiece_hf: 100%|██████████| 34376/34376 [00:05<00:00, 6701.98 examples/s]
Preprocessing with medium_hf_wordpiece_hf: 100%|██████████| 3820/3820 [00:00<00:00, 6588.85 examples/s]


Saving tokenized cache to processed_cached\medium_hf_wordpiece_hf ...


Saving the dataset (1/1 shards): 100%|██████████| 30566/30566 [00:00<00:00, 446120.88 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3386/3386 [00:00<00:00, 307740.44 examples/s]
Filter: 100%|██████████| 30566/30566 [00:34<00:00, 881.09 examples/s]
Filter: 100%|██████████| 3386/3386 [00:04<00:00, 842.33 examples/s]


✅ Preprocessed multilingual dataset: Train=30566, Eval=3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
❌ Failed to train multilingual_medium_hf_wordpiece_hf: T5ForConditionalGeneration.forward() got an unexpected keyword argument 'num_items_in_batch'

🚀 Training Model 6/9
🔧 Custom Tokenizer: medium_sp_unigram_hf
📦 Compatible Base Model: google/mt5-base
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalmedium/sp_unigram_hf


Traceback (most recent call last):
  File "C:\Users\User 4\AppData\Local\Temp\ipykernel_2508\769564752.py", line 544, in <module>
    trainer.train()
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2237, in train
    return inner_training_loop(
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2578, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3792, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3879, in compute_loss
    outputs = model(**inputs)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl

✅ Custom tokenizer loaded: vocab_size=30002 | EOS=2 PAD=30000
🤖 Loading compatible base model: google/mt5-base
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 250112 → 30002
⚡ Gradient checkpointing enabled
⚡ torch.compile applied (PyTorch 2.x)
⚙️  Preprocessing complete multilingual dataset with custom tokenizer...


Preprocessing with medium_sp_unigram_hf: 100%|██████████| 34376/34376 [00:05<00:00, 6325.41 examples/s]
Preprocessing with medium_sp_unigram_hf: 100%|██████████| 3820/3820 [00:00<00:00, 6183.96 examples/s]


Saving tokenized cache to processed_cached\medium_sp_unigram_hf ...


Saving the dataset (1/1 shards): 100%|██████████| 30566/30566 [00:00<00:00, 459581.78 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3386/3386 [00:00<00:00, 338600.32 examples/s]
Filter: 100%|██████████| 30566/30566 [00:35<00:00, 856.78 examples/s]
Filter: 100%|██████████| 3386/3386 [00:04<00:00, 844.55 examples/s]


✅ Preprocessed multilingual dataset: Train=30566, Eval=3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
❌ Failed to train multilingual_medium_sp_unigram_hf: T5ForConditionalGeneration.forward() got an unexpected keyword argument 'num_items_in_batch'

🚀 Training Model 7/9
🔧 Custom Tokenizer: large_hf_bpe_hf
📦 Compatible Base Model: facebook/bart-large
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finallarge/hf_bpe_hf


Traceback (most recent call last):
  File "C:\Users\User 4\AppData\Local\Temp\ipykernel_2508\769564752.py", line 544, in <module>
    trainer.train()
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2237, in train
    return inner_training_loop(
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2578, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3792, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3879, in compute_loss
    outputs = model(**inputs)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl

✅ Custom tokenizer loaded: vocab_size=50002 | EOS=50001 PAD=0
🤖 Loading compatible base model: facebook/bart-large
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 50265 → 50002
⚡ Gradient checkpointing enabled
🔒 Frozen embeddings + first 6 encoder layers
⚡ torch.compile applied (PyTorch 2.x)
⚙️  Preprocessing complete multilingual dataset with custom tokenizer...


Preprocessing with large_hf_bpe_hf: 100%|██████████| 34376/34376 [00:05<00:00, 6445.92 examples/s]
Preprocessing with large_hf_bpe_hf: 100%|██████████| 3820/3820 [00:00<00:00, 6313.32 examples/s]


Saving tokenized cache to processed_cached\large_hf_bpe_hf ...


Saving the dataset (1/1 shards): 100%|██████████| 30566/30566 [00:00<00:00, 452752.10 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3386/3386 [00:00<00:00, 307820.48 examples/s]
Filter: 100%|██████████| 30566/30566 [00:36<00:00, 848.79 examples/s]
Filter: 100%|██████████| 3386/3386 [00:04<00:00, 690.38 examples/s]


✅ Preprocessed multilingual dataset: Train=30566, Eval=3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
❌ Failed to train multilingual_large_hf_bpe_hf: BartForConditionalGeneration.forward() got an unexpected keyword argument 'num_items_in_batch'


Traceback (most recent call last):
  File "C:\Users\User 4\AppData\Local\Temp\ipykernel_2508\769564752.py", line 544, in <module>
    trainer.train()
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2237, in train
    return inner_training_loop(
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2578, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3792, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3879, in compute_loss
    outputs = model(**inputs)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl


🚀 Training Model 8/9
🔧 Custom Tokenizer: large_hf_wordpiece_hf
📦 Compatible Base Model: google/mt5-base
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finallarge/hf_wordpiece_hf
✅ Custom tokenizer loaded: vocab_size=50002 | EOS=50000 PAD=0
🤖 Loading compatible base model: google/mt5-base
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 250112 → 50002
⚡ Gradient checkpointing enabled
⚡ torch.compile applied (PyTorch 2.x)
⚙️  Preprocessing complete multilingual dataset with custom tokenizer...


Preprocessing with large_hf_wordpiece_hf: 100%|██████████| 34376/34376 [00:07<00:00, 4540.14 examples/s]
Preprocessing with large_hf_wordpiece_hf: 100%|██████████| 3820/3820 [00:00<00:00, 6455.45 examples/s]


Saving tokenized cache to processed_cached\large_hf_wordpiece_hf ...


Saving the dataset (1/1 shards): 100%|██████████| 30566/30566 [00:00<00:00, 394325.47 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3386/3386 [00:00<00:00, 307800.46 examples/s]
Filter: 100%|██████████| 30566/30566 [00:36<00:00, 846.23 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 996.17 examples/s] 


✅ Preprocessed multilingual dataset: Train=30566, Eval=3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
❌ Failed to train multilingual_large_hf_wordpiece_hf: T5ForConditionalGeneration.forward() got an unexpected keyword argument 'num_items_in_batch'

🚀 Training Model 9/9
🔧 Custom Tokenizer: large_sp_unigram_hf
📦 Compatible Base Model: google/mt5-base
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finallarge/sp_unigram_hf


Traceback (most recent call last):
  File "C:\Users\User 4\AppData\Local\Temp\ipykernel_2508\769564752.py", line 544, in <module>
    trainer.train()
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2237, in train
    return inner_training_loop(
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 2578, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3792, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\transformers\trainer.py", line 3879, in compute_loss
    outputs = model(**inputs)
  File "c:\Users\User 4\.conda\envs\cuda_mt_env\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl

✅ Custom tokenizer loaded: vocab_size=50002 | EOS=2 PAD=50000
🤖 Loading compatible base model: google/mt5-base
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 250112 → 50002
⚡ Gradient checkpointing enabled
⚡ torch.compile applied (PyTorch 2.x)
⚙️  Preprocessing complete multilingual dataset with custom tokenizer...


Preprocessing with large_sp_unigram_hf: 100%|██████████| 34376/34376 [00:04<00:00, 7775.44 examples/s]
Preprocessing with large_sp_unigram_hf: 100%|██████████| 3820/3820 [00:00<00:00, 7571.12 examples/s]


Saving tokenized cache to processed_cached\large_sp_unigram_hf ...


Saving the dataset (1/1 shards): 100%|██████████| 30566/30566 [00:00<00:00, 488992.58 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 3386/3386 [00:00<00:00, 376229.56 examples/s]
Filter:  26%|██▌       | 8000/30566 [00:08<00:23, 967.10 examples/s]


KeyboardInterrupt: 

In [None]:
# import os
# import csv
# import torch
# import warnings
# import gc
# from datasets import load_from_disk
# from transformers import (
#     EncoderDecoderModel,
#     AutoTokenizer,
#     AutoModelForSeq2SeqLM,
#     DataCollatorForSeq2Seq,
#     Seq2SeqTrainer,
#     Seq2SeqTrainingArguments,
#     logging as hf_logging,
#     MBartForConditionalGeneration,
#     MBart50TokenizerFast,
#     T5ForConditionalGeneration,
#     T5TokenizerFast
# )
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
# import numpy as np

# # Suppress warnings
# hf_logging.set_verbosity_error()
# warnings.filterwarnings("ignore")

# # Configuration with BETTER base models
# BETTER_MODELS = {
#     "mbart_small": {
#         "model_name": "facebook/mbart-large-50-many-to-many-mmt",
#         "tokenizer_name": "facebook/mbart-large-50-many-to-many-mmt",
#         "type": "mbart",
#         "description": "Pre-trained multilingual MT model"
#     },
#     # "mt5_small": {
#     #     "model_name": "google/mt5-small",
#     #     "tokenizer_name": "google/mt5-small", 
#     #     "type": "t5",
#     #     "description": "Multilingual T5 for translation"
#     # },
#     # "opus_mt": {
#     #     "model_name": "Helsinki-NLP/opus-mt-en-mul", 
#     #     "tokenizer_name": "Helsinki-NLP/opus-mt-en-mul",
#     #     "type": "marian",
#     #     "description": "OPUS multilingual translation"
#     # }
# }

# # Languages with proper mBART language codes
# MBART_LANG_CODES = {
#     "yo": "yo_NG",  # Yoruba
#     "ar": "ar_AR",  # Arabic  
#     "zh": "zh_CN",  # Chinese
#     "ru": "ru_RU",  # Russian
#     "hi": "hi_IN",  # Hindi
#     "ja": "ja_XX",  # Japanese
#     "sw": "sw_KE",  # Swahili (if available, else skip)
#     "bn": "bn_IN",  # Bengali
#     "tr": "tr_TR",  # Turkish
#     "en": "en_XX"   # English
# }

# # Load dataset
# dataset_path = "balanced_mt_dataset"
# print("📦 Loading multilingual dataset...")
# full_dataset = load_from_disk(dataset_path)
# print(f"✅ Dataset loaded: {len(full_dataset['train'])} train, {len(full_dataset['test'])} test")

# # Check language distribution
# train_lang_dist = {}
# for example in full_dataset['train']:
#     lang = example['language']
#     train_lang_dist[lang] = train_lang_dist.get(lang, 0) + 1

# print("\n📊 Language distribution:")
# available_langs = []
# for lang_code, count in train_lang_dist.items():
#     if count > 0:
#         available_langs.append(lang_code)
#         mbart_code = MBART_LANG_CODES.get(lang_code, "UNKNOWN")
#         print(f"  {lang_code}: {count:,} pairs (mBART: {mbart_code})")

# # Enhanced BLEU computation
# def compute_bleu_advanced(eval_pred):
#     """Advanced BLEU computation with better handling"""
#     predictions, labels = eval_pred
    
#     if len(predictions.shape) == 3:
#         predictions = np.argmax(predictions, axis=-1)
    
#     decoded_preds = []
#     decoded_labels = []
    
#     for pred, label in zip(predictions, labels):
#         # Replace -100 with pad token for decoding
#         label = np.where(label != -100, label, tokenizer.pad_token_id)
        
#         decoded_pred = tokenizer.decode(pred, skip_special_tokens=True).strip()
#         decoded_label = tokenizer.decode(label, skip_special_tokens=True).strip()
        
#         decoded_preds.append(decoded_pred)
#         decoded_labels.append(decoded_label)
    
#     # Compute BLEU with smoothing
#     smoothing = SmoothingFunction().method1
#     bleu_scores = []
#     exact_matches = 0
    
#     for pred, label in zip(decoded_preds, decoded_labels):
#         if not pred.strip() or not label.strip():
#             bleu_scores.append(0.0)
#             continue
            
#         pred_tokens = pred.split()
#         label_tokens = label.split()
        
#         # Check exact match
#         if pred.lower().strip() == label.lower().strip():
#             exact_matches += 1
        
#         if len(pred_tokens) == 0 or len(label_tokens) == 0:
#             bleu_scores.append(0.0)
#             continue
        
#         try:
#             bleu = sentence_bleu(
#                 [label_tokens], 
#                 pred_tokens,
#                 smoothing_function=smoothing,
#                 weights=(0.25, 0.25, 0.25, 0.25)
#             )
#             bleu_scores.append(bleu)
#         except:
#             bleu_scores.append(0.0)
    
#     avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
#     exact_match_ratio = exact_matches / len(decoded_preds) if decoded_preds else 0.0
    
#     return {
#         "bleu": avg_bleu,
#         "exact_match": exact_match_ratio,
#         "avg_pred_length": np.mean([len(p.split()) for p in decoded_preds if p.strip()]) if decoded_preds else 0.0,
#         "avg_label_length": np.mean([len(l.split()) for l in decoded_labels if l.strip()]) if decoded_labels else 0.0,
#         "empty_predictions": sum(1 for p in decoded_preds if not p.strip()) / len(decoded_preds) if decoded_preds else 0.0
#     }

# # Setup logging
# log_dir = "./MT_models_better"
# os.makedirs(log_dir, exist_ok=True)
# results_log_file = os.path.join(log_dir, "better_multilingual_results.csv")

# if not os.path.exists(results_log_file):
#     with open(results_log_file, 'w', newline='', encoding='utf-8') as csvfile:
#         writer = csv.writer(csvfile)
#         writer.writerow(["Model_ID", "Base_Model", "Model_Type", "Total_Train_Samples", "Total_Test_Samples", 
#                         "Vocab_Size", "BLEU_Score", "Exact_Match", "Avg_Pred_Length", "Avg_Label_Length", 
#                         "Empty_Predictions", "Training_Status", "Notes"])

# print("🚀 Starting training with BETTER pre-trained models...")

# for model_id, model_config in BETTER_MODELS.items():
#     print(f"\n{'='*100}")
#     print(f"🔧 Training with {model_config['description']}")
#     print(f"📦 Model: {model_config['model_name']}")
    
#     # Initialize these variables at the beginning of each iteration
#     model = None
#     tokenizer = None
#     trainer = None
    
#     try:
#         # Load model and tokenizer based on type
#         if model_config["type"] == "mbart":
#             print("🤖 Loading mBART for multilingual translation...")
#             model = MBartForConditionalGeneration.from_pretrained(model_config["model_name"])
#             tokenizer = MBart50TokenizerFast.from_pretrained(model_config["tokenizer_name"])
            
#             # Set source language for mBART
#             tokenizer.src_lang = "en_XX"
            
#         elif model_config["type"] == "t5":
#             print("🤖 Loading mT5 for multilingual translation...")
#             model = T5ForConditionalGeneration.from_pretrained(model_config["model_name"])
#             tokenizer = T5TokenizerFast.from_pretrained(model_config["tokenizer_name"])
            
#         elif model_config["type"] == "marian":
#             print("🤖 Loading Marian OPUS-MT model...")
#             model = AutoModelForSeq2SeqLM.from_pretrained(model_config["model_name"])
#             tokenizer = AutoTokenizer.from_pretrained(model_config["tokenizer_name"])
            
#         else:
#             print(f"❌ Unknown model type: {model_config['type']}")
#             continue
            
#         print(f"✅ Model loaded successfully!")
#         print(f"📊 Vocab size: {len(tokenizer)}")
        
#         # Preprocessing function for different model types
#         def preprocess_for_better_models(examples):
#             """Preprocess for pre-trained multilingual models"""
#             sources = []
#             targets = []
            
#             # Handle both single examples and batches
#             if not isinstance(examples["translation"], list):
#                 # Single example
#                 examples = {
#                     "translation": [examples["translation"]], 
#                     "language": [examples["language"]]
#                 }
            
#             for translation, language in zip(examples["translation"], examples["language"]):
#                 # Check if translation is dict or needs parsing
#                 if isinstance(translation, dict):
#                     source = translation.get("en", "")
#                     target = translation.get(language, "")
#                 else:
#                     # Handle string format or other formats
#                     print(f"⚠️ Unexpected translation format: {type(translation)}")
#                     continue
                    
#                 if not source or not target:
#                     continue
                
#                 if model_config["type"] == "mbart":
#                     # mBART format: no special prefixes needed, handled by tokenizer
#                     sources.append(source)
#                     targets.append(target)
                    
#                 elif model_config["type"] == "t5":
#                     # T5 format: "translate English to {language}: {text}"
#                     lang_name = {
#                         "yo": "Yoruba", "ar": "Arabic", "zh": "Chinese", 
#                         "ru": "Russian", "hi": "Hindi", "ja": "Japanese",
#                         "sw": "Swahili", "bn": "Bengali", "tr": "Turkish"
#                     }.get(language, language)
#                     sources.append(f"translate English to {lang_name}: {source}")
#                     targets.append(target)
                    
#                 else:  # Marian
#                     # Marian format: source as-is
#                     sources.append(source)
#                     targets.append(target)
            
#             if not sources or not targets:
#                 print("⚠️ No valid source-target pairs found")
#                 return {"input_ids": [], "attention_mask": [], "labels": []}
            
#             # Dynamic length calculation on sample
#             sample_size = min(100, len(sources))
#             sample_sources = sources[:sample_size]
#             sample_targets = targets[:sample_size]
                
#             source_lengths = [len(tokenizer.encode(s, add_special_tokens=True)) for s in sample_sources]
#             target_lengths = [len(tokenizer.encode(t, add_special_tokens=True)) for t in sample_targets]
            
#             max_source_length = min(256, int(np.percentile(source_lengths, 95)) + 20)
#             max_target_length = min(256, int(np.percentile(target_lengths, 95)) + 20)
            
#             print(f"  📏 Max lengths - Source: {max_source_length}, Target: {max_target_length}")
            
#             # Tokenize inputs
#             model_inputs = tokenizer(
#                 sources,
#                 max_length=max_source_length,
#                 truncation=True,
#                 padding="max_length",
#                 return_tensors=None
#             )
            
#             # Handle target tokenization based on model type
#             if model_config["type"] == "mbart":
#                 # For mBART, need to handle target language properly
#                 # Get target language for first example (assuming batch has same target lang)
#                 first_lang = None
#                 for translation, language in zip(examples["translation"], examples["language"]):
#                     if isinstance(translation, dict):
#                         first_lang = language
#                         break
                
#                 if first_lang and first_lang in MBART_LANG_CODES:
#                     # Set target language
#                     target_lang = MBART_LANG_CODES[first_lang]
#                     tokenizer.tgt_lang = target_lang
                    
#                     # Tokenize with target language context
#                     with tokenizer.as_target_tokenizer():
#                         labels = tokenizer(
#                             targets,
#                             max_length=max_target_length,
#                             truncation=True,
#                             padding="max_length",
#                             return_tensors=None
#                         )
#                 else:
#                     # Fallback without target language setting
#                     labels = tokenizer(
#                         targets,
#                         max_length=max_target_length,
#                         truncation=True,
#                         padding="max_length",
#                         return_tensors=None
#                     )
#             else:
#                 # For T5 and Marian
#                 labels = tokenizer(
#                     targets,
#                     max_length=max_target_length,
#                     truncation=True,
#                     padding="max_length",
#                     return_tensors=None
#                 )
            
#             # Replace padding with -100 for loss computation
#             labels["input_ids"] = [
#                 [(label if label != tokenizer.pad_token_id else -100) for label in label_seq]
#                 for label_seq in labels["input_ids"]
#             ]
            
#             model_inputs["labels"] = labels["input_ids"]
#             return model_inputs
        
#         # Preprocess dataset
#         print("⚙️  Preprocessing data for better models...")
        
#         # Process in smaller batches to avoid memory issues
#         processed_dataset = full_dataset.map(
#             preprocess_for_better_models,
#             batched=True,
#             remove_columns=full_dataset["train"].column_names,
#             desc="Preprocessing",
#             batch_size=100  # Reduced batch size
#         )
        
#         train_dataset = processed_dataset["train"]
#         eval_dataset = processed_dataset["test"]
        
#         # Filter out empty examples
#         def filter_empty(example):
#             return len(example["input_ids"]) > 0 and len(example["labels"]) > 0
        
#         train_dataset = train_dataset.filter(filter_empty)
#         eval_dataset = eval_dataset.filter(filter_empty)
        
#         print(f"✅ Preprocessed: {len(train_dataset)} train, {len(eval_dataset)} eval samples")
        
#         if len(train_dataset) == 0 or len(eval_dataset) == 0:
#             print("❌ No valid samples after preprocessing")
#             continue
        
#         # Setup training
#         output_dir = f"./MT_models_better/{model_id}"
#         os.makedirs(output_dir, exist_ok=True)
        
#         # Calculate steps for evaluation and saving
#         steps_per_epoch = len(train_dataset) // (8 * 4)  # batch_size * gradient_accumulation_steps
#         eval_steps = max(200, steps_per_epoch // 2)
#         save_steps = eval_steps
        
#         # Optimized training arguments for pre-trained models
#         training_args = Seq2SeqTrainingArguments(
#             output_dir=output_dir,
#             num_train_epochs=1,  # Reduced epochs for testing
#             per_device_train_batch_size=4,  # Reduced batch size
#             per_device_eval_batch_size=4,
#             gradient_accumulation_steps=8,  # Increased to maintain effective batch size
#             learning_rate=3e-5,  # Slightly higher LR
#             weight_decay=0.01,
#             warmup_steps=min(500, steps_per_epoch),
#             eval_strategy="steps",
#             eval_steps=eval_steps,
#             save_strategy="steps", 
#             save_steps=save_steps,
#             save_total_limit=2,
#             logging_steps=100,
#             report_to="none",
#             predict_with_generate=True,
#             generation_max_length=128,
#             generation_num_beams=2,  # Reduced beams for speed
#             fp16=torch.cuda.is_available(),
#             load_best_model_at_end=True,
#             metric_for_best_model="bleu",
#             greater_is_better=True,
#             dataloader_num_workers=0,  # Reduced workers
#             remove_unused_columns=False,
#             label_smoothing_factor=0.1,
#             dataloader_pin_memory=False,  # Disable pin memory to save GPU memory
#         )
        
#         # Data collator
#         data_collator = DataCollatorForSeq2Seq(
#             tokenizer=tokenizer,
#             model=model,
#             padding=True,
#             pad_to_multiple_of=8 if training_args.fp16 else None
#         )
        
#         # Trainer
#         trainer = Seq2SeqTrainer(
#             model=model,
#             args=training_args,
#             train_dataset=train_dataset,
#             eval_dataset=eval_dataset,
#             tokenizer=tokenizer,
#             data_collator=data_collator,
#             compute_metrics=compute_bleu_advanced
#         )
        
#         # Train
#         print("🏋️  Starting training with better pre-trained model...")
#         trainer.train()
        
#         # Evaluate
#         print("📊 Final evaluation...")
#         eval_results = trainer.evaluate()
        
#         # Save
#         print("💾 Saving model...")
#         trainer.save_model()
#         tokenizer.save_pretrained(output_dir)
        
#         # Log results
#         bleu_score = eval_results.get("eval_bleu", 0.0)
#         exact_match = eval_results.get("eval_exact_match", 0.0)
#         avg_pred_len = eval_results.get("eval_avg_pred_length", 0.0)
#         avg_label_len = eval_results.get("eval_avg_label_length", 0.0)
#         empty_preds = eval_results.get("eval_empty_predictions", 0.0)
        
#         with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
#             writer = csv.writer(csvfile)
#             writer.writerow([
#                 model_id, model_config["model_name"], model_config["type"],
#                 len(train_dataset), len(eval_dataset), len(tokenizer),
#                 round(bleu_score, 4), round(exact_match, 4), 
#                 round(avg_pred_len, 2), round(avg_label_len, 2),
#                 round(empty_preds, 4), "SUCCESS", model_config["description"]
#             ])
        
#         print(f"✅ Completed: {model_id}")
#         print(f"📈 BLEU Score: {bleu_score:.4f}")
#         print(f"🎯 Exact Match: {exact_match:.4f}")
#         print(f"📏 Lengths - Pred: {avg_pred_len:.1f}, Label: {avg_label_len:.1f}")
        
#         # Quick translation test
#         print("\n🧪 Quick translation test:")
#         test_input = "Hello, how are you today?"
        
#         if model_config["type"] == "t5":
#             test_input = "translate English to Arabic: " + test_input
#         elif model_config["type"] == "mbart":
#             tokenizer.src_lang = "en_XX"
#             tokenizer.tgt_lang = "ar_AR"  # Test with Arabic
        
#         # inputs = tokenizer(test_input, return_tensors="pt", padding=True)
#         inputs = tokenizer(test_input, return_tensors="pt", padding=True).to(model.device)
        
#         with torch.no_grad():
#             outputs = model.generate(
#                 **inputs,
#                 max_length=50,
#                 num_beams=2,
#                 early_stopping=True,
#                 do_sample=False
#             )
        
#         translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
#         print(f"  Input: {test_input}")
#         print(f"  Output: {translation}")
        
#     except Exception as e:
#         print(f"❌ Failed to train {model_id}: {str(e)}")
#         import traceback
#         traceback.print_exc()
        
#         with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
#             writer = csv.writer(csvfile)
#             writer.writerow([
#                 model_id, model_config.get("model_name", ""), model_config.get("type", ""),
#                 0, 0, 0, 0, 0, 0, 0, 0, "TRAINING_FAILED", str(e)[:100]
#             ])
    
#     finally:
#         # Cleanup - ensure variables exist before deleting
#         if 'model' in locals() and model is not None:
#             del model
#         if 'trainer' in locals() and trainer is not None:
#             del trainer
#         if 'tokenizer' in locals() and tokenizer is not None:
#             del tokenizer
#         torch.cuda.empty_cache()
#         gc.collect()

# print("\n🎉 Better model training completed!")
# print(f"📋 Results: {results_log_file}")

# # Usage examples
# print(f"\n💡 Usage examples:")
# print(f"# mBART:")
# print(f"from transformers import MBartForConditionalGeneration, MBart50TokenizerFast")
# print(f"model = MBartForConditionalGeneration.from_pretrained('./MT_models_better/mbart_small')")
# print(f"tokenizer = MBart50TokenizerFast.from_pretrained('./MT_models_better/mbart_small')")

# print(f"\n# Expected improvements:")
# print(f"  • BLEU: 0.15-0.35 (instead of 0.0044)")
# print(f"  • Better multilingual understanding")
# print(f"  • More natural translations")

📦 Loading multilingual dataset...
✅ Dataset loaded: 34376 train, 3820 test

📊 Language distribution:
  zh: 3,807 pairs (mBART: zh_CN)
  bn: 3,829 pairs (mBART: bn_IN)
  ru: 3,824 pairs (mBART: ru_RU)
  ar: 3,838 pairs (mBART: ar_AR)
  ja: 3,808 pairs (mBART: ja_XX)
  sw: 3,810 pairs (mBART: sw_KE)
  yo: 3,820 pairs (mBART: yo_NG)
  tr: 3,814 pairs (mBART: tr_TR)
  hi: 3,826 pairs (mBART: hi_IN)
🚀 Starting training with BETTER pre-trained models...

🔧 Training with Pre-trained multilingual MT model
📦 Model: facebook/mbart-large-50-many-to-many-mmt
🤖 Loading mBART for multilingual translation...
✅ Model loaded successfully!
📊 Vocab size: 250054
⚙️  Preprocessing data for better models...
✅ Preprocessed: 34376 train, 3820 eval samples
🏋️  Starting training with better pre-trained model...
{'loss': 7.3144, 'grad_norm': 5.2044358253479, 'learning_rate': 5.940000000000001e-06, 'epoch': 0.09308820107051431}
{'loss': 5.6617, 'grad_norm': 4.844237804412842, 'learning_rate': 1.1940000000000001e-

In [15]:
!python -m pip install sentencepiece




In [16]:
from transformers import MBart50Tokenizer

tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")


ImportError: 
MBart50Tokenizer requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.


In [17]:
!python -m pip uninstall -y sentencepiece
!python -m  pip install sentencepiece --no-cache-dir


Found existing installation: sentencepiece 0.2.0
Uninstalling sentencepiece-0.2.0:
  Successfully uninstalled sentencepiece-0.2.0
Collecting sentencepiece
  Downloading sentencepiece-0.2.0-cp310-cp310-win_amd64.whl.metadata (8.3 kB)
Downloading sentencepiece-0.2.0-cp310-cp310-win_amd64.whl (991 kB)
   ---------------------------------------- 0.0/991.5 kB ? eta -:--:--
   --------------------------------------- 991.5/991.5 kB 11.6 MB/s eta 0:00:00
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.2.0


In [18]:

import sentencepiece
print(sentencepiece.__version__)


0.2.0


In [19]:
!which python  # or use %system in Windows


'which' is not recognized as an internal or external command,
operable program or batch file.


In [20]:
!where python


c:\Users\User 4\.conda\envs\multilingual_mt\python.exe
C:\Users\User 4\AppData\Local\Programs\Python\Python310\python.exe
C:\Users\User 4\AppData\Local\Programs\Python\Python311\python.exe
C:\Users\User 4\AppData\Local\Microsoft\WindowsApps\python.exe


In [5]:
from transformers import MBart50Tokenizer
tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")


In [None]:
##working ... but slow (running final code)
import os
import csv
import torch
import warnings
import gc
from datasets import load_from_disk
from transformers import (
    EncoderDecoderModel,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    logging as hf_logging,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    T5ForConditionalGeneration,
    T5TokenizerFast,
    BertTokenizerFast,
    GPT2TokenizerFast,
    BartForConditionalGeneration
)
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np

# Suppress warnings
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore")

# FIXED: Model-Tokenizer Compatibility Mapping
MODEL_TOKENIZER_CONFIGS = {
    "hf_bpe_hf": {
        "base_model": "facebook/bart-large",  # BART uses BPE
        "model_class": BartForConditionalGeneration,
        "description": "BART with BPE tokenization"
    },
    "hf_wordpiece_hf": {
        "base_model": "facebook/bart-large",  # BART uses BPE
        "model_class": BartForConditionalGeneration,
        "description": "BART with BPE tokenization"
    },
    "sp_unigram_hf": {
        "base_model": "facebook/bart-large",  # BART uses BPE
        "model_class": BartForConditionalGeneration,
        "description": "BART with BPE tokenization"
    }
    # "hf_wordpiece_hf": {
    #     "base_model": "google/mt5-base",  # mT5 can work with WordPiece
    #     "model_class": T5ForConditionalGeneration,
    #     "description": "mT5 with WordPiece tokenization"
    # },
    # "sp_unigram_hf": {
    #     "base_model": "google/mt5-base",  # mT5 uses SentencePiece Unigram
    #     "model_class": T5ForConditionalGeneration,
    #     "description": "mT5 with SentencePiece Unigram"
    # }
}

# Custom tokenizer settings
# tokenizer_sizes = ["small", "medium", "large"]
tokenizer_sizes = [ "medium", "large"]
tokenizer_types = ["hf_bpe_hf", "hf_wordpiece_hf", "sp_unigram_hf"]
# tokenizer_types = ["hf_wordpiece_hf", "sp_unigram_hf"]

# Languages for evaluation
languages = {
    "yo": "Yoruba",
    "ar": "Arabic", 
    "zh": "Chinese",
    "ru": "Russian",
    "hi": "Hindi",
    "ja": "Japanese",
    "swa": "Swahili",
    "bn": "Bengali",
    "tr": "Turkish"
}
SRC_LANG = "en"

# Load complete multilingual dataset
dataset_path = "balanced_mt_dataset"
print("📦 Loading complete multilingual dataset...")
full_dataset = load_from_disk(dataset_path)
print(f"✅ Dataset loaded: {len(full_dataset['train'])} train, {len(full_dataset['test'])} test")
print(f"🌍 All languages included: {list(languages.keys())}")

# Enhanced BLEU computation with progress tracking
def compute_multilingual_bleu(eval_pred):
    """Multilingual BLEU computation with progress tracking"""
    import time
    start_time = time.time()
    
    predictions, labels = eval_pred
    
    if len(predictions.shape) == 3:
        predictions = np.argmax(predictions, axis=-1)
    
    decoded_preds = []
    decoded_labels = []
    
    total_samples = len(predictions)
    print(f"🔄 Starting evaluation of {total_samples} samples...")
    
    # Process in chunks with progress updates
    chunk_size = 50
    for i in range(0, total_samples, chunk_size):
        end_idx = min(i + chunk_size, total_samples)
        chunk_preds = predictions[i:end_idx]
        chunk_labels = labels[i:end_idx]
        
        try:
            for pred, label in zip(chunk_preds, chunk_labels):
                # CRITICAL FIX: Handle negative token IDs properly
                # Filter out negative IDs and pad tokens before decoding
                pred_clean = [token for token in pred if token >= 0 and token < len(tokenizer)]
                label_clean = [token for token in label if token != -100 and token >= 0 and token < len(tokenizer)]
                
                try:
                    decoded_pred = tokenizer.decode(pred_clean, skip_special_tokens=True).strip() if pred_clean else ""
                    decoded_label = tokenizer.decode(label_clean, skip_special_tokens=True).strip() if label_clean else ""
                except Exception as e:
                    # Fallback for any decode errors
                    decoded_pred = ""
                    decoded_label = ""
                
                decoded_preds.append(decoded_pred)
                decoded_labels.append(decoded_label)
            
            # Progress update every chunk
            if (i + chunk_size) % (chunk_size * 5) == 0 or end_idx == total_samples:
                elapsed = time.time() - start_time
                print(f"📊 Evaluation progress: {end_idx}/{total_samples} samples ({elapsed:.1f}s elapsed)")
                
        except Exception as e:
            print(f"⚠️  Batch decode failed for chunk {i}-{end_idx}: {str(e)}")
            # Add empty strings for failed batch
            for _ in range(end_idx - i):
                decoded_preds.append("")
                decoded_labels.append("")
    
    # Compute BLEU with smoothing
    print("🧮 Computing BLEU scores...")
    smoothing = SmoothingFunction().method1
    bleu_scores = []
    exact_matches = 0
    
    for idx, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
        if idx % 1000 == 0 and idx > 0:
            print(f"   BLEU calculation: {idx}/{len(decoded_preds)} samples")
            
        if not pred.strip() or not label.strip():
            bleu_scores.append(0.0)
            continue
            
        pred_tokens = pred.split()
        label_tokens = label.split()
        
        # Check exact match
        if pred.lower().strip() == label.lower().strip():
            exact_matches += 1
        
        if len(pred_tokens) == 0 or len(label_tokens) == 0:
            bleu_scores.append(0.0)
            continue
        
        try:
            bleu = sentence_bleu(
                [label_tokens], 
                pred_tokens,
                smoothing_function=smoothing,
                weights=(0.25, 0.25, 0.25, 0.25)
            )
            bleu_scores.append(bleu)
        except:
            bleu_scores.append(0.0)
    
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
    exact_match_ratio = exact_matches / len(decoded_preds) if decoded_preds else 0.0
    
    total_time = time.time() - start_time
    print(f"✅ Evaluation complete: BLEU={avg_bleu:.4f}, Exact Match={exact_match_ratio:.4f} ({total_time:.1f}s total)")
    
    return {
        "bleu": avg_bleu,
        "exact_match": exact_match_ratio,
        "avg_pred_length": np.mean([len(p.split()) for p in decoded_preds if p.strip()]) if decoded_preds else 0.0,
        "avg_label_length": np.mean([len(l.split()) for l in decoded_labels if l.strip()]) if decoded_labels else 0.0,
        "empty_predictions": sum(1 for p in decoded_preds if not p.strip()) / len(decoded_preds) if decoded_preds else 0.0
    }

def setup_custom_tokenizer_for_model(tokenizer_path, model_type):
    """Setup custom tokenizer with proper special tokens for specific model"""
    # Load custom tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    
    # Model-specific special token configuration
    if model_type == "bart":
        special_tokens = {
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '</s>',
            'pad_token': '<pad>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }
    elif model_type == "t5":
        special_tokens = {
            'pad_token': '<pad>',
            'eos_token': '</s>',
            'unk_token': '<unk>',
            'bos_token': '<pad>',  # T5 doesn't use BOS
            'sep_token': '</s>',
            'mask_token': '<extra_id_0>'
        }
    else:  # Default fallback
        special_tokens = {
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '</s>',
            'pad_token': '<pad>',
            'unk_token': '<unk>',
            'mask_token': '<mask>'
        }
    
    # Add missing special tokens
    tokens_to_add = {}
    for token_name, token_value in special_tokens.items():
        if getattr(tokenizer, token_name, None) is None:
            tokens_to_add[token_name] = token_value
    
    if tokens_to_add:
        tokenizer.add_special_tokens(tokens_to_add)
    
    # Ensure we have the required tokens
    assert tokenizer.eos_token is not None, "EOS token is required"
    assert tokenizer.pad_token is not None, "PAD token is required"
    
    return tokenizer

def configure_model_for_custom_tokenizer(model, tokenizer, model_type):
    """Configure model for custom tokenizer based on model type"""
    # Resize embeddings
    print("🔄 Resizing model embeddings to match custom tokenizer...")
    old_vocab_size = model.config.vocab_size
    model.resize_token_embeddings(len(tokenizer))
    print(f"   Vocab size: {old_vocab_size} → {len(tokenizer)}")
    
    # Model-specific configuration
    if model_type == "bart":
        model.config.decoder_start_token_id = tokenizer.bos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.bos_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.sep_token_id = tokenizer.eos_token_id
        model.config.forced_eos_token_id = tokenizer.eos_token_id
        
    elif model_type == "t5":
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.decoder_start_token_id = tokenizer.pad_token_id  # T5 uses pad_token_id as decoder start
        model.config.forced_eos_token_id = tokenizer.eos_token_id
    
    # Update generation config if it exists
    if hasattr(model, 'generation_config') and model.generation_config is not None:
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.generation_config.forced_eos_token_id = tokenizer.eos_token_id
        if model_type == "bart":
            model.generation_config.decoder_start_token_id = tokenizer.bos_token_id
            model.generation_config.bos_token_id = tokenizer.bos_token_id
        elif model_type == "t5":
            model.generation_config.decoder_start_token_id = tokenizer.pad_token_id
    
    return model

# Setup logging
log_dir = "./MT_models_multilingual_custom_tokenizers"
os.makedirs(log_dir, exist_ok=True)
results_log_file = os.path.join(log_dir, "multilingual_custom_tokenizer_results.csv")

if not os.path.exists(results_log_file):
    with open(results_log_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Model_ID", "Base_Model", "Tokenizer_Size", "Tokenizer_Type", 
                        "Total_Languages", "Total_Train_Samples", "Total_Test_Samples", 
                        "Custom_Vocab_Size", "Overall_BLEU", "Overall_Exact_Match", 
                        "Avg_Pred_Length", "Avg_Label_Length", "Empty_Predictions", 
                        "Training_Status", "Notes"])

print("🚀 Starting multilingual training with COMPATIBLE model-tokenizer pairs...")
print(f"📊 Training approach: ONE model per tokenizer handling ALL {len(languages)} languages")

# Main training loop: 3 sizes × 3 types = 9 models total
for size in tokenizer_sizes:
    for tok_type in tokenizer_types:
        # Get compatible model configuration
        model_config = MODEL_TOKENIZER_CONFIGS[tok_type]
        
        # Path to custom tokenizer
        tokenizer_path = f"vocab_final/vocab_final{size}/{tok_type}"
        model_id = f"multilingual_{size}_{tok_type}"
        
        print(f"\n{'='*100}")
        print(f"🚀 Training Model {tokenizer_sizes.index(size)*3 + tokenizer_types.index(tok_type) + 1}/9")
        print(f"🔧 Custom Tokenizer: {size}_{tok_type}")
        print(f"📦 Compatible Base Model: {model_config['base_model']}")
        print(f"🌍 Target Languages: {list(languages.keys())} (ALL SIMULTANEOUSLY)")
        
        # Initialize variables
        model = None
        tokenizer = None
        trainer = None
        
        try:
            # Load and setup custom tokenizer
            print(f"🔧 Loading custom tokenizer: {tokenizer_path}")
            model_type = "bart" if "bart" in model_config['base_model'] else "t5"
            tokenizer = setup_custom_tokenizer_for_model(tokenizer_path, model_type)
            
            print(f"✅ Custom tokenizer loaded successfully!")
            print(f"📊 Custom vocab size: {len(tokenizer)}")
            print(f"🔑 Special tokens - EOS: {tokenizer.eos_token_id}, PAD: {tokenizer.pad_token_id}")
            
            # Load compatible base model
            print(f"🤖 Loading compatible base model: {model_config['base_model']}")
            model = model_config['model_class'].from_pretrained(model_config['base_model'])
            
            # Configure model for custom tokenizer
            model = configure_model_for_custom_tokenizer(model, tokenizer, model_type)
            
            print(f"✅ Model configured with custom tokenizer!")
            
            # Use COMPLETE multilingual dataset (all languages together)
            print(f"📊 Using complete multilingual dataset:")
            print(f"   • Train samples: {len(full_dataset['train'])}")
            print(f"   • Test samples: {len(full_dataset['test'])}")
            print(f"   • Languages: {len(languages)} ({list(languages.keys())})")
            
            # IMPROVED: Preprocessing function for multilingual data
            def preprocess_multilingual_improved(examples):
                """Improved preprocessing for multilingual data with model-specific formatting"""
                sources = []
                targets = []
                
                # Handle both single examples and batches
                if not isinstance(examples["translation"], list):
                    examples = {
                        "translation": [examples["translation"]], 
                        "language": [examples["language"]]
                    }
                
                for translation, lang in zip(examples["translation"], examples["language"]):
                    if isinstance(translation, dict) and lang in languages:
                        source = translation.get(SRC_LANG, "")
                        target = translation.get(lang, "")
                        
                        if source.strip() and target.strip():  # Ensure non-empty
                            # Model-specific formatting
                            if model_type == "t5":
                                # T5 style: "translate English to German: Hello"
                                source_formatted = f"translate English to {languages[lang]}: {source}"
                            else:  # BART style
                                # BART style with language token
                                source_formatted = f"{source} </s> {lang}_XX"  # Add language code
                            
                            sources.append(source_formatted)
                            targets.append(target)
                
                if not sources or not targets:
                    return {"input_ids": [], "attention_mask": [], "labels": []}
                
                # Tokenize with custom tokenizer
                max_length = 256  # INCREASED for better performance
                
                # Input tokenization - REMOVE token_type_ids
                model_inputs = tokenizer(
                    sources,
                    max_length=max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors=None,
                    return_token_type_ids=False  # CRITICAL FIX
                )
                # Ensure token_type_ids is removed
                model_inputs.pop("token_type_ids", None)
                
                # Target tokenization  
                with tokenizer.as_target_tokenizer():
                    labels = tokenizer(
                        targets,
                        max_length=max_length,
                        truncation=True,
                        padding="max_length",
                        return_tensors=None,
                        return_token_type_ids=False  # CRITICAL FIX
                    )
                # Ensure token_type_ids is removed
                labels.pop("token_type_ids", None)
                
                # IMPROVED: Label processing with proper EOS handling
                processed_labels = []
                for label_seq in labels["input_ids"]:
                    # Find actual end of sequence (before padding)
                    try:
                        pad_start = label_seq.index(tokenizer.pad_token_id)
                        actual_tokens = label_seq[:pad_start]
                    except ValueError:
                        actual_tokens = label_seq
                    
                    # Ensure EOS token at end
                    if actual_tokens and actual_tokens[-1] != tokenizer.eos_token_id:
                        actual_tokens.append(tokenizer.eos_token_id)
                    elif not actual_tokens:
                        actual_tokens = [tokenizer.eos_token_id]
                    
                    # Create final sequence with -100 for padding
                    final_labels = actual_tokens + [-100] * (max_length - len(actual_tokens))
                    final_labels = final_labels[:max_length]  # Ensure correct length
                    
                    processed_labels.append(final_labels)
                
                model_inputs["labels"] = processed_labels
                return model_inputs
            
            # Preprocess complete multilingual dataset
            print("⚙️  Preprocessing complete multilingual dataset with custom tokenizer...")
            processed_dataset = full_dataset.map(
                preprocess_multilingual_improved,
                batched=True,
                remove_columns=full_dataset["train"].column_names,
                desc=f"Preprocessing with {size}_{tok_type}",
                batch_size=50,  # Smaller batch for stability
                num_proc=1  # Single process to avoid issues
            )
            
            train_dataset = processed_dataset["train"]
            eval_dataset = processed_dataset["test"]
            
            # Filter out empty examples
            def filter_valid_examples(example):
                return (
                    len(example["input_ids"]) > 0 and 
                    len(example["labels"]) > 0 and
                    any(label != -100 for label in example["labels"]) and
                    sum(1 for token in example["input_ids"] if token != tokenizer.pad_token_id) > 0
                )
            
            train_dataset = train_dataset.filter(filter_valid_examples)
            eval_dataset = eval_dataset.filter(filter_valid_examples)
            
            # Use smaller eval dataset to avoid memory issues
            print(f"✅ Preprocessed multilingual dataset:")
            print(f"   • Train samples: {len(train_dataset)}")
            print(f"   • Eval samples: {len(eval_dataset)}")
            
            # REDUCE eval dataset size to avoid hanging
            if len(eval_dataset) > 1000:
                eval_dataset = eval_dataset.select(range(1000))
                print(f"   • Reduced eval samples to: {len(eval_dataset)} (to avoid memory issues)")
            
            if len(train_dataset) == 0 or len(eval_dataset) == 0:
                print("❌ No valid samples after preprocessing")
                continue
            
            # Setup training directory
            output_dir = f"./MT_models_multilingual_custom_tokenizers/{model_id}"
            os.makedirs(output_dir, exist_ok=True)
            
            # IMPROVED: Training arguments with simpler evaluation
            training_args = Seq2SeqTrainingArguments(
                output_dir=output_dir,
                num_train_epochs=5,  # More epochs for custom tokenizers
                per_device_train_batch_size=8,  # Smaller batch size
                per_device_eval_batch_size=8,  # REDUCED eval batch size
                gradient_accumulation_steps=8,  # Higher accumulation
                learning_rate=1e-4,  # Lower learning rate for stability
                weight_decay=0.01,
                warmup_ratio=0.1,  # Warmup as ratio
                eval_strategy="steps",  # Change to steps-based evaluation
                eval_steps=500,  # Evaluate every 500 steps instead of epoch end
                save_strategy="epoch",
                save_total_limit=2,
                logging_steps=50,
                report_to="none",
                predict_with_generate=True,
                generation_max_length=128,  # REDUCED generation length
                generation_num_beams=2,  # REDUCED beams
                fp16=torch.cuda.is_available(),
                load_best_model_at_end=False,  # DISABLED to avoid issues
                dataloader_num_workers=0,
                remove_unused_columns=False,  # CRITICAL: Keep all columns
                ignore_data_skip=True,
                label_smoothing_factor=0.1,  # Label smoothing
                max_grad_norm=1.0,  # Gradient clipping
                dataloader_pin_memory=False,  # Disable pin memory
                skip_memory_metrics=True,  # Skip memory tracking
                
                
            )
            
            # Data collator with explicit token_type_ids handling
            data_collator = DataCollatorForSeq2Seq(
                tokenizer=tokenizer,
                model=model,
                padding=True,
                pad_to_multiple_of=8 if training_args.fp16 else None,
                return_tensors="pt",
                label_pad_token_id=-100
            )
            
            # CRITICAL FIX: Custom data collator that removes token_type_ids
            class CustomDataCollator(DataCollatorForSeq2Seq):
                def __call__(self, features):
                    batch = super().__call__(features)
                    # Remove token_type_ids if present
                    batch.pop("token_type_ids", None)
                    return batch
            
            data_collator = CustomDataCollator(
                tokenizer=tokenizer,
                model=model,
                padding=True,
                pad_to_multiple_of=8 if training_args.fp16 else None,
                return_tensors="pt",
                label_pad_token_id=-100
            )
            
            # Trainer
            trainer = Seq2SeqTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=compute_multilingual_bleu
            )
            
            # Train multilingual model
            print("🏋️  Starting multilingual training with compatible model-tokenizer pair...")
            trainer.train()
            
            # Evaluate
            print("📊 Final multilingual evaluation...")
            eval_results = trainer.evaluate()
            
            # Save model and custom tokenizer
            print("💾 Saving multilingual model and custom tokenizer...")
            trainer.save_model()
            tokenizer.save_pretrained(output_dir)
            
            # Log results
            overall_bleu = eval_results.get("eval_bleu", 0.0)
            overall_exact_match = eval_results.get("eval_exact_match", 0.0)
            avg_pred_len = eval_results.get("eval_avg_pred_length", 0.0)
            avg_label_len = eval_results.get("eval_avg_label_length", 0.0)
            empty_preds = eval_results.get("eval_empty_predictions", 0.0)
            
            with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    model_id, model_config["base_model"], size, tok_type,
                    len(languages), len(train_dataset), len(eval_dataset), len(tokenizer),
                    round(overall_bleu, 4), round(overall_exact_match, 4), 
                    round(avg_pred_len, 2), round(avg_label_len, 2),
                    round(empty_preds, 4), "SUCCESS", 
                    f"Compatible {model_config['description']} with {size}_{tok_type}"
                ])
            
            print(f"✅ Completed: {model_id}")
            print(f"📈 Overall BLEU Score: {overall_bleu:.4f}")
            print(f"🎯 Overall Exact Match: {overall_exact_match:.4f}")
            
            # Quick translation tests for different languages
            print(f"\n🧪 Quick multilingual translation tests:")
            test_input = "Hello, how are you today?"
            
            for test_lang in ["ar", "zh", "hi"]:  # Test 3 languages
                if model_type == "t5":
                    formatted_input = f"translate English to {languages[test_lang]}: {test_input}"
                else:  # BART
                    formatted_input = f"{test_input} </s> {test_lang}_XX"
                    
                inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, max_length=256, truncation=True)
                
                # Move to device
                device = next(model.parameters()).device
                inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
                
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_length=128,
                        num_beams=4,
                        early_stopping=True,
                        do_sample=False,
                        forced_eos_token_id=tokenizer.eos_token_id
                    )
                
                translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
                print(f"  {test_lang} ({languages[test_lang]}): {translation}")
                
        except Exception as e:
            print(f"❌ Failed to train {model_id}: {str(e)}")
            import traceback
            traceback.print_exc()
            
            with open(results_log_file, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    model_id, model_config.get("base_model", ""), size, tok_type,
                    len(languages), 0, 0, 0, 0, 0, 0, 0, 0, "TRAINING_FAILED", str(e)[:100]
                ])
        
        finally:
            # Cleanup
            if model is not None:
                del model
            if trainer is not None:
                del trainer
            torch.cuda.empty_cache()
            gc.collect()

print("\n🎉 Multilingual training with compatible model-tokenizer pairs completed!")
print(f"📋 Results saved to: {results_log_file}")
print(f"🔢 Total models trained: 9 (3 sizes × 3 types)")
print(f"🌍 Each model handles all {len(languages)} languages simultaneously")
print("\n📊 Expected improvements:")
print("• BPE tokenizers → BART models (proper compatibility)")
print("• WordPiece tokenizers → mT5 models (better multilingual support)")
print("• Unigram tokenizers → mT5 models (native compatibility)")
print("• Increased sequence length (256 vs 128)")
print("• Better hyperparameters and training setup")



📦 Loading complete multilingual dataset...
✅ Dataset loaded: 34376 train, 3820 test
🌍 All languages included: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr']
🚀 Starting multilingual training with COMPATIBLE model-tokenizer pairs...
📊 Training approach: ONE model per tokenizer handling ALL 9 languages

🚀 Training Model 1/9
🔧 Custom Tokenizer: medium_hf_bpe_hf
📦 Compatible Base Model: facebook/bart-large
🌍 Target Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr'] (ALL SIMULTANEOUSLY)
🔧 Loading custom tokenizer: vocab_final/vocab_finalmedium/hf_bpe_hf
✅ Custom tokenizer loaded successfully!
📊 Custom vocab size: 30002
🔑 Special tokens - EOS: 30001, PAD: 0
🤖 Loading compatible base model: facebook/bart-large
🔄 Resizing model embeddings to match custom tokenizer...
   Vocab size: 50265 → 30002
✅ Model configured with custom tokenizer!
📊 Using complete multilingual dataset:
   • Train samples: 34376
   • Test samples: 3820
   • Languages: 9 (['yo', 'ar', 'zh', 'ru', 'h

Preprocessing with medium_hf_wordpiece_hf: 100%|██████████| 34376/34376 [00:05<00:00, 6544.68 examples/s]
Preprocessing with medium_hf_wordpiece_hf: 100%|██████████| 3820/3820 [00:00<00:00, 6431.39 examples/s]
Filter: 100%|██████████| 30566/30566 [00:34<00:00, 889.34 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 872.15 examples/s]


✅ Preprocessed multilingual dataset:
   • Train samples: 30566
   • Eval samples: 3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
{'loss': 12.2691, 'grad_norm': 13.021721839904785, 'learning_rate': 1.6317991631799166e-05, 'epoch': 0.10468463752944256}
{'loss': 9.1893, 'grad_norm': 6.743889331817627, 'learning_rate': 3.723849372384937e-05, 'epoch': 0.2093692750588851}
{'loss': 8.4695, 'grad_norm': 84.4217529296875, 'learning_rate': 5.8158995815899583e-05, 'epoch': 0.31405391258832765}
{'loss': 7.9399, 'grad_norm': 3.676532745361328, 'learning_rate': 7.90794979079498e-05, 'epoch': 0.4187385501177702}
{'loss': 7.2826, 'grad_norm': 3.3325424194335938, 'learning_rate': 0.0001, 'epoch': 0.5234231876472127}
{'loss': 6.8487, 'grad_norm': 305.4402160644531, 'learning_rate': 9.767549976754998e-05, 'epoch': 0.6281078251766553}
{'loss': 7.3415, 'grad_norm': 2.721489429473877, 'learning_rate': 9.535099953509

Preprocessing with medium_sp_unigram_hf: 100%|██████████| 34376/34376 [00:04<00:00, 7796.38 examples/s]
Preprocessing with medium_sp_unigram_hf: 100%|██████████| 3820/3820 [00:00<00:00, 7538.05 examples/s]
Filter: 100%|██████████| 30566/30566 [00:30<00:00, 1012.19 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 1011.39 examples/s]


✅ Preprocessed multilingual dataset:
   • Train samples: 30566
   • Eval samples: 3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
{'loss': 11.3399, 'grad_norm': 3.4217100143432617, 'learning_rate': 1.7573221757322174e-05, 'epoch': 0.10468463752944256}
{'loss': 8.7204, 'grad_norm': 4.216411590576172, 'learning_rate': 3.849372384937239e-05, 'epoch': 0.2093692750588851}
{'loss': 8.2838, 'grad_norm': 3.684976577758789, 'learning_rate': 5.94142259414226e-05, 'epoch': 0.31405391258832765}
{'loss': 7.8448, 'grad_norm': 2.833085536956787, 'learning_rate': 8.03347280334728e-05, 'epoch': 0.4187385501177702}
{'loss': 7.2812, 'grad_norm': 3.3034207820892334, 'learning_rate': 9.9860529986053e-05, 'epoch': 0.5234231876472127}
{'loss': 6.8879, 'grad_norm': 3.2613837718963623, 'learning_rate': 9.753602975360298e-05, 'epoch': 0.6281078251766553}
{'loss': 6.5371, 'grad_norm': 2.1150360107421875, 'learning_rate': 

Preprocessing with large_hf_bpe_hf: 100%|██████████| 34376/34376 [00:04<00:00, 7873.34 examples/s]
Preprocessing with large_hf_bpe_hf: 100%|██████████| 3820/3820 [00:00<00:00, 7625.90 examples/s]
Filter: 100%|██████████| 30566/30566 [00:29<00:00, 1038.61 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 1033.72 examples/s]


✅ Preprocessed multilingual dataset:
   • Train samples: 30566
   • Eval samples: 3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
{'loss': 12.3333, 'grad_norm': 39.86501693725586, 'learning_rate': 1.6317991631799166e-05, 'epoch': 0.10468463752944256}
{'loss': 9.5082, 'grad_norm': 4.491843223571777, 'learning_rate': 3.723849372384937e-05, 'epoch': 0.2093692750588851}
{'loss': 8.5616, 'grad_norm': 2.9539096355438232, 'learning_rate': 5.8158995815899583e-05, 'epoch': 0.31405391258832765}
{'loss': 7.9838, 'grad_norm': 3.2830698490142822, 'learning_rate': 7.90794979079498e-05, 'epoch': 0.4187385501177702}
{'loss': 7.5525, 'grad_norm': 3.104830741882324, 'learning_rate': 0.0001, 'epoch': 0.5234231876472127}
{'loss': 7.3278, 'grad_norm': 3.121720552444458, 'learning_rate': 9.767549976754998e-05, 'epoch': 0.6281078251766553}
{'loss': 7.1855, 'grad_norm': 4.011468410491943, 'learning_rate': 9.53509995350

Preprocessing with large_hf_wordpiece_hf: 100%|██████████| 34376/34376 [00:04<00:00, 7985.20 examples/s]
Preprocessing with large_hf_wordpiece_hf: 100%|██████████| 3820/3820 [00:00<00:00, 7898.79 examples/s]
Filter: 100%|██████████| 30566/30566 [00:29<00:00, 1042.53 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 1041.63 examples/s]


✅ Preprocessed multilingual dataset:
   • Train samples: 30566
   • Eval samples: 3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
{'loss': 12.4722, 'grad_norm': 8.382189750671387, 'learning_rate': 1.6736401673640167e-05, 'epoch': 0.10468463752944256}
{'loss': 9.6524, 'grad_norm': 8.992897987365723, 'learning_rate': 3.765690376569038e-05, 'epoch': 0.2093692750588851}
{'loss': 9.0752, 'grad_norm': 4.165016174316406, 'learning_rate': 5.857740585774059e-05, 'epoch': 0.31405391258832765}
{'loss': 8.3755, 'grad_norm': 12.15662956237793, 'learning_rate': 7.949790794979079e-05, 'epoch': 0.4187385501177702}
{'loss': 8.0045, 'grad_norm': 2.9229161739349365, 'learning_rate': 9.9953509995351e-05, 'epoch': 0.5234231876472127}
{'loss': 7.7123, 'grad_norm': 24.226835250854492, 'learning_rate': 9.762900976290097e-05, 'epoch': 0.6281078251766553}
{'loss': 7.4933, 'grad_norm': 3.233417510986328, 'learning_rate': 

Preprocessing with large_sp_unigram_hf: 100%|██████████| 34376/34376 [00:04<00:00, 7965.07 examples/s]
Preprocessing with large_sp_unigram_hf: 100%|██████████| 3820/3820 [00:00<00:00, 7660.81 examples/s]
Filter: 100%|██████████| 30566/30566 [00:29<00:00, 1032.83 examples/s]
Filter: 100%|██████████| 3386/3386 [00:03<00:00, 1024.02 examples/s]


✅ Preprocessed multilingual dataset:
   • Train samples: 30566
   • Eval samples: 3386
   • Reduced eval samples to: 1000 (to avoid memory issues)
🏋️  Starting multilingual training with compatible model-tokenizer pair...
{'loss': 11.5044, 'grad_norm': 3.63551664352417, 'learning_rate': 1.7573221757322174e-05, 'epoch': 0.10468463752944256}
{'loss': 8.9993, 'grad_norm': 3.296987295150757, 'learning_rate': 3.849372384937239e-05, 'epoch': 0.2093692750588851}
{'loss': 8.2906, 'grad_norm': 2.983065605163574, 'learning_rate': 5.94142259414226e-05, 'epoch': 0.31405391258832765}
{'loss': 7.8263, 'grad_norm': 3.1163668632507324, 'learning_rate': 8.03347280334728e-05, 'epoch': 0.4187385501177702}
{'loss': 7.4074, 'grad_norm': 2.5006611347198486, 'learning_rate': 9.9860529986053e-05, 'epoch': 0.5234231876472127}
{'loss': 7.3326, 'grad_norm': 3.549058437347412, 'learning_rate': 9.753602975360298e-05, 'epoch': 0.6281078251766553}
{'loss': 7.029, 'grad_norm': 2.484065532684326, 'learning_rate': 9.52

In [3]:
#Only mt5

import os
import csv
import torch
import warnings
import gc
from datasets import load_from_disk
from transformers import (
    MT5ForConditionalGeneration,  # CRITICAL: Use MT5, not T5
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    logging as hf_logging
)
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np

# Suppress warnings
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore")

# MT5-ONLY Model Configuration
MT5_TOKENIZER_CONFIGS = {
    "hf_wordpiece_hf": {
        "base_model": "google/mt5-base",
        "description": "mT5 with WordPiece tokenization"
    },
    "sp_unigram_hf": {
        "base_model": "google/mt5-base",
        "description": "mT5 with SentencePiece Unigram"
    }
}

# Settings
tokenizer_sizes = ["small", "medium", "large"]
mt5_tokenizer_types = ["hf_wordpiece_hf", "sp_unigram_hf"]  # Only MT5 compatible

# Languages
languages = {
    "yo": "Yoruba", "ar": "Arabic", "zh": "Chinese", "ru": "Russian",
    "hi": "Hindi", "ja": "Japanese", "swa": "Swahili", "bn": "Bengali", "tr": "Turkish"
}
SRC_LANG = "en"

# Load dataset
dataset_path = "balanced_mt_dataset"
print("📦 Loading complete multilingual dataset...")
full_dataset = load_from_disk(dataset_path)
print(f"✅ Dataset loaded: {len(full_dataset['train'])} train, {len(full_dataset['test'])} test")
print(f"🌍 Languages: {list(languages.keys())}")

def setup_mt5_tokenizer(tokenizer_path):
    """Setup custom tokenizer specifically for mT5"""
    print(f"🔧 Loading mT5-compatible tokenizer: {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    
    # mT5-specific special tokens
    mt5_special_tokens = {
        'pad_token': '<pad>',
        'eos_token': '</s>',
        'unk_token': '<unk>',
        'bos_token': '<pad>',  # mT5 uses <pad> as decoder start
        'sep_token': '</s>',
    }
    
    # Add missing special tokens
    tokens_to_add = {}
    for token_name, token_value in mt5_special_tokens.items():
        if getattr(tokenizer, token_name, None) is None:
            tokens_to_add[token_name] = token_value
    
    if tokens_to_add:
        num_added = tokenizer.add_special_tokens(tokens_to_add)
        print(f"   ✅ Added {num_added} special tokens")
    
    print(f"   📊 Vocab size: {len(tokenizer)}")
    print(f"   🔑 EOS: {tokenizer.eos_token_id}, PAD: {tokenizer.pad_token_id}")
    
    return tokenizer

def configure_mt5_model(model, tokenizer):
    """Configure mT5 model for custom tokenizer"""
    print("🔄 Configuring mT5 model...")
    
    # Resize embeddings
    old_vocab_size = model.config.vocab_size
    model.resize_token_embeddings(len(tokenizer))
    print(f"   Resized embeddings: {old_vocab_size} → {len(tokenizer)}")
    
    # Configure model parameters
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.decoder_start_token_id = tokenizer.pad_token_id
    model.config.forced_eos_token_id = tokenizer.eos_token_id
    model.config.vocab_size = len(tokenizer)
    
    # Initialize new embeddings properly
    with torch.no_grad():
        if old_vocab_size < len(tokenizer):
            # Input embeddings
            embed_layer = model.get_input_embeddings()
            new_embeddings = embed_layer.weight[old_vocab_size:].data
            new_embeddings.normal_(mean=0.0, std=0.02)
            
            # Output embeddings (lm_head)
            if hasattr(model, 'lm_head') and hasattr(model.lm_head, 'weight'):
                new_out_embeddings = model.lm_head.weight[old_vocab_size:].data
                new_out_embeddings.normal_(mean=0.0, std=0.02)
    
    # Update generation config
    if hasattr(model, 'generation_config') and model.generation_config is not None:
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.generation_config.decoder_start_token_id = tokenizer.pad_token_id
    
    print("   ✅ mT5 model configured successfully!")
    return model

def preprocess_mt5_data(examples, tokenizer):
    """Preprocessing specifically optimized for mT5"""
    sources = []
    targets = []
    
    # Handle batching
    if not isinstance(examples["translation"], list):
        examples = {
            "translation": [examples["translation"]], 
            "language": [examples["language"]]
        }
    
    for translation, lang in zip(examples["translation"], examples["language"]):
        if isinstance(translation, dict) and lang in languages:
            source = translation.get(SRC_LANG, "")
            target = translation.get(lang, "")
            
            if source.strip() and target.strip():
                # mT5 task format
                source_formatted = f"translate English to {languages[lang]}: {source}"
                sources.append(source_formatted)
                targets.append(target)
    
    if not sources:
        return {"input_ids": [], "attention_mask": [], "labels": []}
    
    # Tokenization settings
    max_length = 128
    
    # Tokenize sources
    model_inputs = tokenizer(
        sources,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors=None,
        add_special_tokens=True,
        return_token_type_ids=False
    )
    
    # Tokenize targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            targets,
            max_length=max_length,
            truncation=True,
            padding="max_length", 
            return_tensors=None,
            add_special_tokens=True,
            return_token_type_ids=False
        )
    
    # Process labels (convert pad tokens to -100)
    processed_labels = []
    for label_seq in labels["input_ids"]:
        processed_seq = [
            -100 if token == tokenizer.pad_token_id else token 
            for token in label_seq
        ]
        processed_labels.append(processed_seq)
    
    model_inputs["labels"] = processed_labels
    return model_inputs

class MT5DataCollator(DataCollatorForSeq2Seq):
    """Custom data collator for mT5"""
    def __call__(self, features):
        # Clean features
        cleaned_features = []
        for feature in features:
            cleaned_feature = {k: v for k, v in feature.items() if k != "token_type_ids"}
            cleaned_features.append(cleaned_feature)
        
        batch = super().__call__(cleaned_features)
        batch.pop("token_type_ids", None)
        
        # Fix labels
        if "labels" in batch:
            labels = batch["labels"]
            labels = torch.where(
                labels == self.tokenizer.pad_token_id,
                torch.tensor(-100, dtype=labels.dtype, device=labels.device),
                labels
            )
            batch["labels"] = labels
        
        return batch

def compute_bleu_metrics(eval_pred):
    """Compute BLEU scores with progress tracking"""
    import time
    start_time = time.time()
    
    predictions, labels = eval_pred
    if len(predictions.shape) == 3:
        predictions = np.argmax(predictions, axis=-1)
    
    decoded_preds = []
    decoded_labels = []
    
    total_samples = len(predictions)
    print(f"🔄 Evaluating {total_samples} samples...")
    
    for i, (pred, label) in enumerate(zip(predictions, labels)):
        if i % 250 == 0 and i > 0:
            elapsed = time.time() - start_time
            print(f"   Progress: {i}/{total_samples} ({elapsed:.1f}s)")
        
        # Clean tokens
        pred_clean = [token for token in pred if token >= 0 and token < len(tokenizer)]
        label_clean = [token for token in label if token != -100 and token >= 0 and token < len(tokenizer)]
        
        try:
            decoded_pred = tokenizer.decode(pred_clean, skip_special_tokens=True).strip() if pred_clean else ""
            decoded_label = tokenizer.decode(label_clean, skip_special_tokens=True).strip() if label_clean else ""
        except:
            decoded_pred = ""
            decoded_label = ""
        
        decoded_preds.append(decoded_pred)
        decoded_labels.append(decoded_label)
    
    # Compute BLEU
    print("🧮 Computing BLEU scores...")
    smoothing = SmoothingFunction().method1
    bleu_scores = []
    exact_matches = 0
    
    for pred, label in zip(decoded_preds, decoded_labels):
        if not pred.strip() or not label.strip():
            bleu_scores.append(0.0)
            continue
        
        if pred.lower().strip() == label.lower().strip():
            exact_matches += 1
        
        pred_tokens = pred.split()
        label_tokens = label.split()
        
        if len(pred_tokens) == 0 or len(label_tokens) == 0:
            bleu_scores.append(0.0)
            continue
        
        try:
            bleu = sentence_bleu([label_tokens], pred_tokens, smoothing_function=smoothing)
            bleu_scores.append(bleu)
        except:
            bleu_scores.append(0.0)
    
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
    exact_match = exact_matches / len(decoded_preds) if decoded_preds else 0.0
    
    total_time = time.time() - start_time
    print(f"✅ Evaluation complete: BLEU={avg_bleu:.4f}, EM={exact_match:.4f} ({total_time:.1f}s)")
    
    return {
        "bleu": avg_bleu,
        "exact_match": exact_match,
        "avg_pred_length": np.mean([len(p.split()) for p in decoded_preds if p.strip()]) if decoded_preds else 0.0,
        "avg_label_length": np.mean([len(l.split()) for l in decoded_labels if l.strip()]) if decoded_labels else 0.0,
        "empty_predictions": sum(1 for p in decoded_preds if not p.strip()) / len(decoded_preds) if decoded_preds else 0.0
    }

# Setup logging
log_dir = "./MT5_models_only"
os.makedirs(log_dir, exist_ok=True)
results_file = os.path.join(log_dir, "mt5_results.csv")

if not os.path.exists(results_file):
    with open(results_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["Model_ID", "Tokenizer_Size", "Tokenizer_Type", "Custom_Vocab_Size", 
                        "Train_Samples", "Eval_Samples", "BLEU_Score", "Exact_Match", 
                        "Training_Status", "Notes"])

print("🚀 Starting MT5-ONLY multilingual training...")
print(f"📊 Will train {len(tokenizer_sizes)} sizes × {len(mt5_tokenizer_types)} types = {len(tokenizer_sizes) * len(mt5_tokenizer_types)} MT5 models")

# Main MT5 training loop
model_count = 0
for size in tokenizer_sizes:
    for tok_type in mt5_tokenizer_types:
        model_count += 1
        tokenizer_path = f"vocab_final/vocab_final{size}/{tok_type}"
        model_id = f"mt5_{size}_{tok_type}"
        
        print(f"\n{'='*80}")
        print(f"🚀 Training MT5 Model {model_count}/{len(tokenizer_sizes) * len(mt5_tokenizer_types)}")
        print(f"🔧 Custom Tokenizer: {size}_{tok_type}")
        print(f"🌍 Target Languages: ALL {len(languages)} languages")
        
        model = None
        tokenizer = None
        trainer = None
        
        try:
            # Setup MT5 tokenizer
            tokenizer = setup_mt5_tokenizer(tokenizer_path)
            
            # Load MT5 model
            print("🤖 Loading google/mt5-base...")
            model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")
            
            # Configure model
            model = configure_mt5_model(model, tokenizer)
            
            # Preprocess dataset
            print("⚙️  Preprocessing multilingual dataset for MT5...")
            processed_dataset = full_dataset.map(
                lambda x: preprocess_mt5_data(x, tokenizer),
                batched=True,
                remove_columns=full_dataset["train"].column_names,
                desc=f"MT5 preprocessing {size}_{tok_type}",
                batch_size=32,
                num_proc=1
            )
            
            # Filter valid examples
            def is_valid(example):
                return (
                    len(example["input_ids"]) > 0 and 
                    len(example["labels"]) > 0 and
                    any(label != -100 for label in example["labels"])
                )
            
            train_dataset = processed_dataset["train"].filter(is_valid)
            eval_dataset = processed_dataset["test"].filter(is_valid)
            
            # Limit eval dataset size
            if len(eval_dataset) > 500:
                eval_dataset = eval_dataset.select(range(500))
                print(f"   Reduced eval to {len(eval_dataset)} samples")
            
            print(f"   ✅ Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
            
            if len(train_dataset) == 0:
                print("❌ No valid training samples!")
                continue
            
            # Setup training
            output_dir = f"./MT5_models_only/{model_id}"
            os.makedirs(output_dir, exist_ok=True)
            
            # MT5-optimized training arguments
            training_args = Seq2SeqTrainingArguments(
                output_dir=output_dir,
                num_train_epochs=1,
                per_device_train_batch_size=4,
                per_device_eval_batch_size=4,
                gradient_accumulation_steps=16,
                learning_rate=3e-5,  # Lower for MT5
                weight_decay=0.01,
                warmup_steps=100,
                eval_strategy="steps",
                eval_steps=1000,
                save_strategy="epoch",
                save_total_limit=1,
                logging_steps=100,
                report_to="none",
                predict_with_generate=True,
                generation_max_length=64,
                generation_num_beams=2,
                fp16=False,  # Disabled for stability
                load_best_model_at_end=False,
                dataloader_num_workers=0,
                remove_unused_columns=False,
                ignore_data_skip=True,
                max_grad_norm=0.5,
                dataloader_pin_memory=False,
                skip_memory_metrics=True,
                optim="adamw_torch",
                lr_scheduler_type="cosine"
            )
            
            # Data collator
            data_collator = MT5DataCollator(
                tokenizer=tokenizer,
                model=model,
                padding=True,
                return_tensors="pt",
                label_pad_token_id=-100
            )
            
            # Trainer
            trainer = Seq2SeqTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=compute_bleu_metrics
            )
            
            # Train
            print("🏋️  Starting MT5 training...")
            trainer.train()
            
            # Evaluate
            print("📊 Final evaluation...")
            eval_results = trainer.evaluate()
            
            # Save
            print("💾 Saving MT5 model and tokenizer...")
            trainer.save_model()
            tokenizer.save_pretrained(output_dir)
            
            # Log results
            bleu_score = eval_results.get("eval_bleu", 0.0)
            exact_match = eval_results.get("eval_exact_match", 0.0)
            
            with open(results_file, 'a', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow([
                    model_id, size, tok_type, len(tokenizer),
                    len(train_dataset), len(eval_dataset), 
                    round(bleu_score, 4), round(exact_match, 4),
                    "SUCCESS", f"MT5 with {size}_{tok_type}"
                ])
            
            print(f"✅ Completed: {model_id}")
            print(f"📈 BLEU: {bleu_score:.4f}, Exact Match: {exact_match:.4f}")
            
            # Quick test
            print("\n🧪 Quick translation test:")
            test_input = "translate English to Hindi: Hello, how are you?"
            inputs = tokenizer(test_input, return_tensors="pt", max_length=128, truncation=True)
            
            device = next(model.parameters()).device
            inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
            
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=64,
                    num_beams=4,
                    early_stopping=True,
                    forced_eos_token_id=tokenizer.eos_token_id
                )
            
            translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
            print(f"   Translation: {translation}")
            
        except Exception as e:
            print(f"❌ Failed: {model_id} - {str(e)}")
            import traceback
            traceback.print_exc()
            
            with open(results_file, 'a', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow([
                    model_id, size, tok_type, 0, 0, 0, 0, 0,
                    "FAILED", str(e)[:100]
                ])
        
        finally:
            # Cleanup
            if model is not None:
                del model
            if trainer is not None:
                del trainer
            torch.cuda.empty_cache()
            gc.collect()

print(f"\n🎉 MT5-only training completed!")
print(f"📋 Results saved to: {results_file}")
print(f"🔢 Total MT5 models: {len(tokenizer_sizes) * len(mt5_tokenizer_types)}")

📦 Loading complete multilingual dataset...
✅ Dataset loaded: 34376 train, 3820 test
🌍 Languages: ['yo', 'ar', 'zh', 'ru', 'hi', 'ja', 'swa', 'bn', 'tr']
🚀 Starting MT5-ONLY multilingual training...
📊 Will train 3 sizes × 2 types = 6 MT5 models

🚀 Training MT5 Model 1/6
🔧 Custom Tokenizer: small_hf_wordpiece_hf
🌍 Target Languages: ALL 9 languages
🔧 Loading mT5-compatible tokenizer: vocab_final/vocab_finalsmall/hf_wordpiece_hf
   ✅ Added 2 special tokens
   📊 Vocab size: 23634
   🔑 EOS: 23632, PAD: 0
🤖 Loading google/mt5-base...
🔄 Configuring mT5 model...
   Resized embeddings: 250112 → 23634
   ✅ mT5 model configured successfully!
⚙️  Preprocessing multilingual dataset for MT5...


MT5 preprocessing small_hf_wordpiece_hf: 100%|██████████| 34376/34376 [00:16<00:00, 2136.99 examples/s]
MT5 preprocessing small_hf_wordpiece_hf: 100%|██████████| 3820/3820 [00:01<00:00, 2127.55 examples/s]
Filter: 100%|██████████| 30566/30566 [00:03<00:00, 8754.06 examples/s]
Filter: 100%|██████████| 3386/3386 [00:00<00:00, 8423.02 examples/s]


   Reduced eval to 500 samples
   ✅ Train: 30566, Eval: 500
🏋️  Starting MT5 training...
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0, 'epoch': 0.2093692750588851}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0, 'epoch': 0.4187385501177702}


KeyboardInterrupt: 