In [1]:
# Run setup from config notebook
%run 0_config_setup.ipynb

GPU Available: True
Number of GPUs: 2
  GPU 0: NVIDIA GeForce RTX 5090
    Memory: 33.67 GB
  GPU 1: NVIDIA GeForce RTX 5090
    Memory: 33.67 GB

Total VRAM: 67.34 GB
Hugging Face token loaded from cache.
‚úÖ Detected project directory: .
üìÅ Data directory: data
üìÅ Models directory: models
‚ö†Ô∏è Translation data not found at: Data/english-arabic
‚úÖ Directory structure and model paths are ready!
HPC Config: 2 GPUs, Flash Attention: False
COMET Config: model=Unbabel/wmt22-cometkiwi-da, batch_size=64, gpu=1
‚úÖ Configuration loaded successfully!
   - COMET enabled for scoring
   - LoRA enabled: True
   - KL penalty: 0.15
   - Sample size: 100,000
   - Batch sizes: Generation=64, RM=8, PPO=32
Utility functions loaded!
Configuration saved to config.json
GPU Available: True
Number of GPUs: 2
  GPU 0: NVIDIA GeForce RTX 5090
    Memory: 33.67 GB
  GPU 1: NVIDIA GeForce RTX 5090
    Memory: 33.67 GB

Total VRAM: 67.34 GB
Hugging Face token loaded from cache.
‚úÖ Detected project directo

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import json
import random
import gc
import os
import time
from pathlib import Path

# ===========================
# CONFIGURATION: EVALUATION METHOD
# ===========================
USE_COMET = False  # Set to True to use COMET model, False for fast heuristic-based scoring

set_seed(SEED)

print("Synthetic Data Generation Pipeline")
print("=" * 80)
print(f"Scoring method: {'COMET-based' if USE_COMET else 'Heuristic-based'}")
print(f"Heuristic metrics: Length ratio, Punctuation presence, Non-empty validation")
print("=" * 80)


Synthetic Data Generation Pipeline
Scoring method: Heuristic-based
Heuristic metrics: Length ratio, Punctuation presence, Non-empty validation


## Load SFT Model

In [3]:
# Clear GPU cache and set CUDA environment variables
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()
    print("GPU cache cleared")

# CUDA environment variables for optimized memory management
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Unload COMET model if loaded to free memory for SFT model
if USE_COMET and 'comet_model' in globals() and comet_model is not None:
    try:
        del comet_model
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        print("COMET unloaded from GPU memory")
    except Exception as e:
        print(f"Warning: Failed to unload COMET: {e}")

# ===========================
# MODEL LOADING CONFIGURATION
# ===========================
FORCE_CPU = False
USE_BFLOAT16 = True

print("\nModel Loading Configuration:")
print(f"Total GPUs: {NUM_GPUS}")
print(f"Total VRAM: {NUM_GPUS * 31.36:.2f}GB")
print(f"SFT Model: 28.9B parameters")
print(f"Quantization: 8-bit")
print(f"Precision: {'bfloat16' if USE_BFLOAT16 else 'float32'}")


GPU cache cleared

Model Loading Configuration:
Total GPUs: 2
Total VRAM: 62.72GB
SFT Model: 28.9B parameters
Quantization: 8-bit
Precision: bfloat16


In [4]:
print("\nLoading SFT model from Hugging Face...")
model_name = "ModelSpace/GemmaX2-28-9B-v0.1"

if FORCE_CPU:
    print("Loading model on CPU")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map={"": "cpu"},
        torch_dtype=torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True
    )
else:
    print(f"Loading model with 8-bit quantization across {NUM_GPUS} GPUs")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    
    max_memory = {
        0: "31GB",
        1: "31GB",
        "cpu": "64GB"
    }
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        max_memory=max_memory
    )

# Set padding token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id

tokenizer.padding_side = "left"

# Report device placement
try:
    if hasattr(model, 'hf_device_map'):
        devices_used = set(str(v) for v in model.hf_device_map.values())
        print(f"Model split across: {devices_used}")
except Exception as e:
    print(f"Device info: {e}")

total_params = sum(p.numel() for p in model.parameters()) / 1e9
print(f"Model size: {total_params:.2f}B parameters")
print("Model ready for inference")


'(MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /ModelSpace/GemmaX2-28-9B-v0.1/resolve/main/tokenizer_config.json (Caused by NameResolutionError("HTTPSConnection(host=\'huggingface.co\', port=443): Failed to resolve \'huggingface.co\' ([Errno -3] Temporary failure in name resolution)"))'), '(Request ID: 5635c47c-525e-4453-af7f-559d92dee5d7)')' thrown while requesting HEAD https://huggingface.co/ModelSpace/GemmaX2-28-9B-v0.1/resolve/main/tokenizer_config.json
Retrying in 1s [Retry 1/5].
Retrying in 1s [Retry 1/5].



Loading SFT model from Hugging Face...
Loading model with 8-bit quantization across 2 GPUs


`torch_dtype` is deprecated! Use `dtype` instead!
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Model split across: {'1', '0'}
Model size: 9.24B parameters
Model ready for inference


## Load Training Data

In [5]:
# ===========================
# LOAD TRAINING DATA (EN + FR)
# ===========================
USE_SAMPLES = False  # Set False for full dataset, True for samples

print("\nLoading source data for synthetic translation generation...")
print(f"Data source: {'SAMPLES' if USE_SAMPLES else 'FULL'}\n")

all_data = []

# Load English data
english_inputs_path = PROJECT_DIR / ("data/english_inputs_samples.json" if USE_SAMPLES else "data/english_inputs.json")

if english_inputs_path.exists():
    with open(english_inputs_path, 'r', encoding='utf-8') as f:
        english_data = json.load(f)
    
    if isinstance(english_data, list):
        for item in english_data:
            if isinstance(item, str):
                all_data.append({'source': item, 'source_lang': 'en'})
            elif isinstance(item, dict):
                text = item.get('text', item.get('source', item.get('sentence', '')))
                if text:
                    all_data.append({'source': text, 'source_lang': 'en'})
    print(f"Loaded {len(english_data)} English samples")
else:
    print(f"Warning: {english_inputs_path.name} not found")

# Load French data
french_inputs_path = PROJECT_DIR / "data/french_inputs.json"

if french_inputs_path.exists():
    with open(french_inputs_path, 'r', encoding='utf-8') as f:
        french_data = json.load(f)
    
    if isinstance(french_data, list):
        for item in french_data:
            if isinstance(item, str):
                all_data.append({'source': item, 'source_lang': 'fr'})
            elif isinstance(item, dict):
                text = item.get('text', item.get('source', item.get('sentence', '')))
                if text:
                    all_data.append({'source': text, 'source_lang': 'fr'})
    print(f"Loaded {len(french_data)} French samples")
else:
    print(f"Warning: french_inputs.json not found")

print(f"\nTotal available data: {len(all_data):,} samples")

# ===========================
# SAMPLE DATA FOR BALANCED TRAINING
# ===========================
SAMPLE_SIZE_PER_LANG = 10_000  # 10K per language = 20K total

en_data = [s for s in all_data if s['source_lang'] == 'en']
fr_data = [s for s in all_data if s['source_lang'] == 'fr']

print(f"Available by language:")
print(f"  English: {len(en_data):,} samples")
print(f"  French: {len(fr_data):,} samples")

random.shuffle(en_data)
random.shuffle(fr_data)

en_samples = en_data[:min(SAMPLE_SIZE_PER_LANG, len(en_data))]
fr_samples = fr_data[:min(SAMPLE_SIZE_PER_LANG, len(fr_data))]

training_samples = en_samples + fr_samples
random.shuffle(training_samples)

total_samples = len(training_samples)
en_pct = 100 * len(en_samples) / total_samples if total_samples > 0 else 0
fr_pct = 100 * len(fr_samples) / total_samples if total_samples > 0 else 0

print(f"\nSampled {total_samples:,} samples for generation:")
print(f"  English to Arabic: {len(en_samples):,} ({en_pct:.1f}%)")
print(f"  French to Arabic: {len(fr_samples):,} ({fr_pct:.1f}%)")



Loading source data for synthetic translation generation...
Data source: FULL

Loaded 3294856 English samples
Loaded 3294856 English samples
Loaded 484003 French samples

Total available data: 3,778,859 samples
Available by language:
  English: 3,294,856 samples
  French: 484,003 samples
Loaded 484003 French samples

Total available data: 3,778,859 samples
Available by language:
  English: 3,294,856 samples
  French: 484,003 samples

Sampled 20,000 samples for generation:
  English to Arabic: 10,000 (50.0%)
  French to Arabic: 10,000 (50.0%)

Sampled 20,000 samples for generation:
  English to Arabic: 10,000 (50.0%)
  French to Arabic: 10,000 (50.0%)


## Generate Translation Candidates

In [6]:
# ===========================
# GENERATION CONFIGURATION
# ===========================
MEGA_BATCH_SIZE = 64  # Batch size for parallel generation
NUM_CANDIDATES = 4
MAX_NEW_TOKENS = 128

print("\nGeneration Configuration:")
print(f"  Batch size: {MEGA_BATCH_SIZE}")
print(f"  Candidates per source: {NUM_CANDIDATES}")
print(f"  Max tokens: {MAX_NEW_TOKENS}")
print(f"  Methods: 4 (temperature, top-k, nucleus, greedy)")
print(f"  Note: Batch size reduced to prevent CUDA OOM errors")



Generation Configuration:
  Batch size: 64
  Candidates per source: 4
  Max tokens: 128
  Methods: 4 (temperature, top-k, nucleus, greedy)
  Note: Batch size reduced to prevent CUDA OOM errors


In [7]:
# ===========================
# TRANSLATION GENERATION METHODS
# ===========================
# Four different sampling strategies for diverse translation candidates
# Each method generates exactly ONE candidate per source

def generate_with_temperature(sources, langs):
    """High temperature sampling for diverse outputs"""
    prompts = [format_translation_prompt(src, lang) for src, lang in zip(sources, langs)]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=1.2,
        top_p=0.95,
        top_k=50,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id
    )
    
    generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    candidates = []
    
    for i in range(len(sources)):
        text = generated_texts[i]
        translation = text.split("Arabic translation:")[-1].strip() if "Arabic translation:" in text else text.strip()
        candidates.append({
            'translation': translation,
            'method': 'temperature',
            'config': {'temperature': 1.2, 'top_p': 0.95, 'top_k': 50}
        })
    return candidates


def generate_with_topk(sources, langs):
    """Conservative top-k sampling"""
    prompts = [format_translation_prompt(src, lang) for src, lang in zip(sources, langs)]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_k=30,
        top_p=0.9,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id
    )
    
    generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    candidates = []
    
    for i in range(len(sources)):
        text = generated_texts[i]
        translation = text.split("Arabic translation:")[-1].strip() if "Arabic translation:" in text else text.strip()
        candidates.append({
            'translation': translation,
            'method': 'top_k',
            'config': {'temperature': 0.7, 'top_k': 30, 'top_p': 0.9}
        })
    return candidates


def generate_with_nucleus(sources, langs):
    """Nucleus (top-p) sampling for balanced diversity"""
    prompts = [format_translation_prompt(src, lang) for src, lang in zip(sources, langs)]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        top_k=0,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id
    )
    
    generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    candidates = []
    
    for i in range(len(sources)):
        text = generated_texts[i]
        translation = text.split("Arabic translation:")[-1].strip() if "Arabic translation:" in text else text.strip()
        candidates.append({
            'translation': translation,
            'method': 'nucleus',
            'config': {'temperature': 0.9, 'top_p': 0.95}
        })
    return candidates


def generate_with_greedy(sources, langs):
    """Greedy decoding for consistent outputs"""
    prompts = [format_translation_prompt(src, lang) for src, lang in zip(sources, langs)]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id
    )
    
    generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    candidates = []
    
    for i in range(len(sources)):
        text = generated_texts[i]
        translation = text.split("Arabic translation:")[-1].strip() if "Arabic translation:" in text else text.strip()
        candidates.append({
            'translation': translation,
            'method': 'greedy',
            'config': {'do_sample': False}
        })
    return candidates


GENERATION_METHODS = {
    'temperature': generate_with_temperature,
    'top_k': generate_with_topk,
    'nucleus': generate_with_nucleus,
    'greedy': generate_with_greedy
}

print("\nGeneration methods configured:")
print("  1. Temperature Sampling (high randomness)")
print("  2. Top-K Sampling (conservative)")
print("  3. Nucleus Sampling (balanced)")
print("  4. Greedy Decoding (deterministic)")


Generation methods configured:
  1. Temperature Sampling (high randomness)
  2. Top-K Sampling (conservative)
  3. Nucleus Sampling (balanced)
  4. Greedy Decoding (deterministic)


## Alternative: vLLM for Maximum Speed (Optional)

If you have vLLM installed (`pip install vllm`), uncomment and run the cell below for 3-10x faster generation.

In [8]:
# Optional: Use vLLM for 3-10x faster generation
# Install: pip install vllm

USE_VLLM = False

if USE_VLLM:
    try:
        from vllm import LLM, SamplingParams
        
        print("Loading model with vLLM (tensor parallel)...")
        del model
        torch.cuda.empty_cache()
        
        vllm_model = LLM(
            model="ModelSpace/GemmaX2-28-9B-v0.1",
            tensor_parallel_size=NUM_GPUS,
            dtype="bfloat16",
            max_model_len=512,
            trust_remote_code=True,
        )
        
        vllm_sampling = SamplingParams(
            temperature=0.9,
            top_p=0.95,
            max_tokens=MAX_NEW_TOKENS,
            n=NUM_CANDIDATES
        )
        
        def generate_with_vllm(batch_sources, batch_langs, num_candidates=4):
            """Fast generation with vLLM"""
            prompts = [format_translation_prompt(src, lang) 
                       for src, lang in zip(batch_sources, batch_langs)]
            outputs = vllm_model.generate(prompts, vllm_sampling)
            
            all_candidates = []
            for output in outputs:
                candidates = []
                for completion in output.outputs:
                    text = completion.text
                    translation = text.split("Arabic translation:")[-1].strip() if "Arabic translation:" in text else text.strip()
                    candidates.append({
                        'translation': translation,
                        'config': {'temperature': 0.9, 'top_p': 0.95}
                    })
                all_candidates.append(candidates)
            return all_candidates
        
        print("vLLM loaded successfully (3-10x faster)")
        
    except ImportError:
        print("vLLM not installed. Using transformers instead.")
        print("Install with: pip install vllm")
    except Exception as e:
        print(f"vLLM failed to load: {e}. Using transformers instead.")
else:
    print("Using transformers for generation (set USE_VLLM=True for faster inference)")


Using transformers for generation (set USE_VLLM=True for faster inference)


## Score Translations

In [9]:
# ===========================
# TRANSLATION QUALITY SCORING
# ===========================

def score_candidates(translation_text: str, source_text: str) -> float:
    """Score a single translation using heuristics.
    
    Metrics: length ratio, punctuation presence, non-empty validation
    Returns: Quality score between 0 and 1
    """
    src_len = len(source_text.split())
    src_punct = sum(1 for c in source_text if c in '.!?,;:')
    
    tgt = translation_text
    tgt_len = len(tgt.split())
    
    # Length ratio score
    if src_len > 0:
        length_ratio = min(tgt_len, src_len) / max(tgt_len, src_len)
    else:
        length_ratio = 0.5 if tgt_len == 0 else 0.0
    
    # Punctuation presence score
    tgt_punct = sum(1 for c in tgt if c in '.!?,;:')
    punct_score = 1.0 if (src_punct > 0 and tgt_punct > 0) or (src_punct == 0 and tgt_punct == 0) else 0.7
    
    # Non-empty score
    non_empty_score = 1.0 if len(tgt.strip()) > 0 and '[ERROR]' not in tgt else 0.0
    
    # Combined score
    quality_score = (length_ratio * 0.5 + punct_score * 0.3 + non_empty_score * 0.2)
    quality_score = max(0.0, min(1.0, quality_score))
    
    return quality_score


def create_all_preference_pairs(source_text: str, candidates_with_methods: list) -> list:
    """Create preference pairs from 4 candidates using all pairwise comparisons.
    
    Args:
        source_text: Source text
        candidates_with_methods: List of dicts with 'translation' and 'method' keys
    
    Returns:
        List of preference pairs
    """
    pairs = []
    
    if len(candidates_with_methods) < 2:
        return pairs
    
    # Score each candidate
    scored_candidates = []
    for cand in candidates_with_methods:
        score = score_candidates(cand['translation'], source_text)
        scored_candidates.append({
            'translation': cand['translation'],
            'method': cand['method'],
            'score': score
        })
    
    # Create all pairwise comparisons
    for i in range(len(scored_candidates)):
        for j in range(i + 1, len(scored_candidates)):
            cand_i = scored_candidates[i]
            cand_j = scored_candidates[j]
            
            # Ensure we have a meaningful preference (at least 0.01 difference)
            score_diff = abs(cand_i['score'] - cand_j['score'])
            
            if score_diff >= 0.01:
                if cand_i['score'] > cand_j['score']:
                    chosen = cand_i
                    rejected = cand_j
                else:
                    chosen = cand_j
                    rejected = cand_i
                
                pairs.append({
                    'chosen': chosen['translation'],
                    'rejected': rejected['translation'],
                    'chosen_score': chosen['score'],
                    'rejected_score': rejected['score'],
                    'score_margin': abs(chosen['score'] - rejected['score']),
                    'chosen_method': chosen['method'],
                    'rejected_method': rejected['method']
                })
    
    return pairs

print("Scoring functions configured (heuristic-based)")

Scoring functions configured (heuristic-based)


## Generate Synthetic Preference Dataset

In [10]:
# Test scoring functions
print("Testing scoring and preference pair functions...\n")

test_source = "Hello world"
test_candidates = [
    {'translation': 'Hello world', 'method': 'greedy'},
    {'translation': 'Hi there', 'method': 'temperature'},
    {'translation': '[ERROR]', 'method': 'nucleus'},
    {'translation': 'Bonjour le monde', 'method': 'top_k'}
]

print("Scored candidates:")
for cand in test_candidates:
    score = score_candidates(cand['translation'], test_source)
    print(f"  '{cand['translation'][:30]}' ({cand['method']}): {score:.3f}")

pairs = create_all_preference_pairs(test_source, test_candidates)
print(f"\nPreference pairs created: {len(pairs)}")
for pair in pairs:
    print(f"  Chosen: '{pair['chosen'][:30]}' ({pair['chosen_score']:.3f})")
    print(f"  Rejected: '{pair['rejected'][:30]}' ({pair['rejected_score']:.3f})")
    print(f"  Margin: {pair['score_margin']:.3f}")

print("\nFunctions working correctly!")

Testing scoring and preference pair functions...

Scored candidates:
  'Hello world' (greedy): 1.000
  'Hi there' (temperature): 1.000
  '[ERROR]' (nucleus): 0.550
  'Bonjour le monde' (top_k): 0.833

Preference pairs created: 5
  Chosen: 'Hello world' (1.000)
  Rejected: '[ERROR]' (0.550)
  Margin: 0.450
  Chosen: 'Hello world' (1.000)
  Rejected: 'Bonjour le monde' (0.833)
  Margin: 0.167
  Chosen: 'Hi there' (1.000)
  Rejected: '[ERROR]' (0.550)
  Margin: 0.450
  Chosen: 'Hi there' (1.000)
  Rejected: 'Bonjour le monde' (0.833)
  Margin: 0.167
  Chosen: 'Bonjour le monde' (0.833)
  Rejected: '[ERROR]' (0.550)
  Margin: 0.283

Functions working correctly!


In [11]:
# ===========================
# MAIN GENERATION AND SCORING LOOP WITH CHECKPOINTING
# ===========================

# Clear GPU memory before starting
print("Clearing GPU memory...")
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()
    print("‚úì GPU memory cleared")

print("\n" + "=" * 80)
print("MULTI-METHOD SYNTHETIC DATA GENERATION")
print(f"Scoring: {'COMET-based' if USE_COMET else 'Heuristic-based'}")
print("=" * 80)

print(f"\nConfiguration:")
print(f"  Total samples: {len(training_samples):,}")
print(f"  Batch size: {MEGA_BATCH_SIZE}")
print(f"  Candidates per sample: 4 (one per method)")
print(f"  Generation methods: 4")
print(f"  GPUs: {NUM_GPUS}")

en_count = sum(1 for s in training_samples if s['source_lang'] == 'en')
fr_count = sum(1 for s in training_samples if s['source_lang'] == 'fr')
print(f"\nLanguage distribution:")
print(f"  English to Arabic: {en_count:,} ({100*en_count/len(training_samples):.1f}%)")
print(f"  French to Arabic: {fr_count:,} ({100*fr_count/len(training_samples):.1f}%)")
print("=" * 80)

# Output file paths
en_ar_candidates_file = OUTPUTS_DIR / "english_arabic_candidates.jsonl"
fr_ar_candidates_file = OUTPUTS_DIR / "french_arabic_candidates.jsonl"
en_ar_preferences_file = OUTPUTS_DIR / "en-ar-preferences.jsonl"
fr_ar_preferences_file = OUTPUTS_DIR / "fr-ar-preferences.jsonl"

# Checkpoint file paths
checkpoint_dir = DATA_DIR / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_file = checkpoint_dir / "generation_checkpoint.json"
en_candidates_checkpoint = checkpoint_dir / "en_candidates_checkpoint.jsonl"
fr_candidates_checkpoint = checkpoint_dir / "fr_candidates_checkpoint.jsonl"
en_preferences_checkpoint = checkpoint_dir / "en_preferences_checkpoint.jsonl"
fr_preferences_checkpoint = checkpoint_dir / "fr_preferences_checkpoint.jsonl"
stats_checkpoint = checkpoint_dir / "stats_checkpoint.json"

# Load checkpoint if exists
resume_from_batch = 0
en_candidates_data = []
fr_candidates_data = []
en_preferences_data = []
fr_preferences_data = []
method_stats = {method_name: {'count': 0, 'avg_score': 0, 'scores': []} for method_name in GENERATION_METHODS.keys()}
quality_scores_collected = []
en_pairs_count = 0
fr_pairs_count = 0
errors_count = 0

if checkpoint_file.exists():
    print("\n‚è≥ Loading checkpoint...")
    try:
        with open(checkpoint_file, 'r') as f:
            checkpoint = json.load(f)
        resume_from_batch = checkpoint['last_completed_batch'] + 1
        en_pairs_count = checkpoint['en_pairs_count']
        fr_pairs_count = checkpoint['fr_pairs_count']
        errors_count = checkpoint['errors_count']
        
        # Load checkpoint data files
        if en_candidates_checkpoint.exists():
            with open(en_candidates_checkpoint, 'r', encoding='utf-8') as f:
                en_candidates_data = [json.loads(line) for line in f if line.strip()]
        
        if fr_candidates_checkpoint.exists():
            with open(fr_candidates_checkpoint, 'r', encoding='utf-8') as f:
                fr_candidates_data = [json.loads(line) for line in f if line.strip()]
        
        if en_preferences_checkpoint.exists():
            with open(en_preferences_checkpoint, 'r', encoding='utf-8') as f:
                en_preferences_data = [json.loads(line) for line in f if line.strip()]
        
        if fr_preferences_checkpoint.exists():
            with open(fr_preferences_checkpoint, 'r', encoding='utf-8') as f:
                fr_preferences_data = [json.loads(line) for line in f if line.strip()]
        
        if stats_checkpoint.exists():
            with open(stats_checkpoint, 'r') as f:
                checkpoint_stats = json.load(f)
            method_stats = checkpoint_stats.get('method_stats', method_stats)
            quality_scores_collected = checkpoint_stats.get('quality_scores_collected', [])
        
        print(f"‚úì Checkpoint loaded!")
        print(f"  Resuming from batch {resume_from_batch}")
        print(f"  EN candidates: {len(en_candidates_data):,}")
        print(f"  FR candidates: {len(fr_candidates_data):,}")
        print(f"  EN preference pairs: {len(en_preferences_data):,}")
        print(f"  FR preference pairs: {len(fr_preferences_data):,}")
    except Exception as e:
        print(f"‚úó Error loading checkpoint: {e}")
        print("  Starting from beginning...")
        resume_from_batch = 0
else:
    print("\nüìù No checkpoint found. Starting from beginning...")

num_batches = (len(training_samples) + MEGA_BATCH_SIZE - 1) // MEGA_BATCH_SIZE
start_time = time.time()
samples_processed = resume_from_batch * MEGA_BATCH_SIZE
checkpoint_interval = 20  # Save checkpoint more frequently (every 20 batches)

for batch_idx in tqdm(range(resume_from_batch, num_batches), desc="Processing batches", initial=resume_from_batch, total=num_batches):
    start_idx = batch_idx * MEGA_BATCH_SIZE
    end_idx = min(start_idx + MEGA_BATCH_SIZE, len(training_samples))
    batch_samples = training_samples[start_idx:end_idx]
    
    batch_sources = [s['source'] for s in batch_samples]
    batch_langs = [s['source_lang'] for s in batch_samples]
    
    try:
        # Generate with all 4 methods
        all_method_results = {}
        for method_name, method_func in GENERATION_METHODS.items():
            try:
                # Clear GPU memory before each method
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    gc.collect()
                
                method_candidates = method_func(batch_sources, batch_langs)
                all_method_results[method_name] = method_candidates
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print(f"\n‚ö†Ô∏è  OOM Error in method {method_name}: Clearing memory and retrying...")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                        torch.cuda.reset_peak_memory_stats()
                    gc.collect()
                    time.sleep(2)  # Wait a moment before retrying
                    try:
                        method_candidates = method_func(batch_sources, batch_langs)
                        all_method_results[method_name] = method_candidates
                        print(f"‚úì Retry successful for {method_name}")
                    except Exception as retry_err:
                        print(f"‚úó Retry failed: {retry_err}")
                        all_method_results[method_name] = [
                            {'translation': '[ERROR]', 'method': method_name, 'config': {}} 
                            for _ in batch_samples
                        ]
                        errors_count += 1
                else:
                    print(f"Error in method {method_name}: {e}")
                    all_method_results[method_name] = [
                        {'translation': '[ERROR]', 'method': method_name, 'config': {}} 
                        for _ in batch_samples
                    ]
                    errors_count += 1
        
        # Process each sample
        for sample_idx, sample in enumerate(batch_samples):
            source_text = sample['source']
            source_lang = sample['source_lang']
            
            # Collect 4 candidates (one per method)
            four_candidates = []
            for method_name in GENERATION_METHODS.keys():
                if sample_idx < len(all_method_results[method_name]):
                    cand = all_method_results[method_name][sample_idx]
                    four_candidates.append({
                        'translation': cand['translation'],
                        'method': cand['method']
                    })
                    
                    # Track statistics
                    score = score_candidates(cand['translation'], source_text)
                    quality_scores_collected.append(score)
                    method_stats[method_name]['count'] += 1
                    method_stats[method_name]['scores'].append(score)
            
            # Create candidate record
            candidate_record = {
                'source': source_text,
                'source_lang': source_lang,
                'candidates': [
                    {
                        'translation': c['translation'],
                        'method': c['method']
                    }
                    for c in four_candidates
                ]
            }
            
            # Save to appropriate candidates file
            if source_lang == 'en':
                en_candidates_data.append(candidate_record)
            elif source_lang == 'fr':
                fr_candidates_data.append(candidate_record)
            
            # Create preference pairs from 4 candidates
            preference_pairs = create_all_preference_pairs(source_text, four_candidates)
            
            for pair in preference_pairs:
                preference_record = {
                    'source': source_text,
                    'source_lang': source_lang,
                    'chosen': pair['chosen'],
                    'rejected': pair['rejected'],
                    'chosen_score': pair['chosen_score'],
                    'rejected_score': pair['rejected_score'],
                    'margin': pair['score_margin'],
                    'chosen_method': pair['chosen_method'],
                    'rejected_method': pair['rejected_method']
                }
                
                if source_lang == 'en':
                    en_preferences_data.append(preference_record)
                    en_pairs_count += 1
                elif source_lang == 'fr':
                    fr_preferences_data.append(preference_record)
                    fr_pairs_count += 1
        
        samples_processed += len(batch_samples)
        
        # Aggressive memory management
        if batch_idx % 10 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
    
    except Exception as e:
        errors_count += 1
        if errors_count <= 5:
            print(f"Error in batch {batch_idx}: {e}")
        continue
    
    # Save checkpoint periodically
    if (batch_idx + 1) % checkpoint_interval == 0 or batch_idx == num_batches - 1:
        checkpoint_data = {
            'last_completed_batch': batch_idx,
            'samples_processed': samples_processed,
            'en_pairs_count': en_pairs_count,
            'fr_pairs_count': fr_pairs_count,
            'errors_count': errors_count,
            'timestamp': time.time()
        }
        
        with open(checkpoint_file, 'w') as f:
            json.dump(checkpoint_data, f, indent=2)
        
        # Save checkpoint data files
        with open(en_candidates_checkpoint, 'w', encoding='utf-8') as f:
            for item in en_candidates_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        
        with open(fr_candidates_checkpoint, 'w', encoding='utf-8') as f:
            for item in fr_candidates_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        
        with open(en_preferences_checkpoint, 'w', encoding='utf-8') as f:
            for item in en_preferences_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        
        with open(fr_preferences_checkpoint, 'w', encoding='utf-8') as f:
            for item in fr_preferences_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        
        checkpoint_stats = {
            'method_stats': method_stats,
            'quality_scores_collected': quality_scores_collected
        }
        with open(stats_checkpoint, 'w') as f:
            json.dump(checkpoint_stats, f, indent=2)
    
    # Progress update
    if (batch_idx + 1) % 20 == 0:
        elapsed = time.time() - start_time
        rate = samples_processed / elapsed if elapsed > 0 else 0
        remaining = (len(training_samples) - samples_processed) / rate if rate > 0 else 0
        
        if quality_scores_collected:
            avg_score = sum(quality_scores_collected) / len(quality_scores_collected)
        else:
            avg_score = 0
        
        print(f"\nProgress (Batch {batch_idx + 1}/{num_batches}):")
        print(f"  Samples: {samples_processed:,}/{len(training_samples):,} ({100*samples_processed/len(training_samples):.1f}%)")
        print(f"  Preference pairs: {en_pairs_count + fr_pairs_count:,}")
        print(f"    EN->AR: {en_pairs_count:,}, FR->AR: {fr_pairs_count:,}")
        print(f"  Avg quality score: {avg_score:.4f}")
        print(f"  Rate: {rate:.1f} samples/sec")
        print(f"  ETA: {remaining/3600:.2f} hours")
        print(f"  Errors: {errors_count}")
        print(f"  üíæ Checkpoint saved")

# Finalize statistics
for method_name in GENERATION_METHODS.keys():
    if method_stats[method_name]['scores']:
        method_stats[method_name]['avg_score'] = sum(method_stats[method_name]['scores']) / len(method_stats[method_name]['scores'])

total_time = time.time() - start_time

print("\n" + "=" * 80)
print("GENERATION COMPLETE")
print("=" * 80)
print(f"  Samples processed: {samples_processed:,}")
print(f"  EN-AR candidates: {len(en_candidates_data):,}")
print(f"  FR-AR candidates: {len(fr_candidates_data):,}")
print(f"  Total preference pairs: {en_pairs_count + fr_pairs_count:,}")
print(f"    English to Arabic: {en_pairs_count:,}")
print(f"    French to Arabic: {fr_pairs_count:,}")
print(f"  Total time: {total_time/3600:.2f} hours")
print(f"  Average rate: {samples_processed/total_time:.1f} samples/sec")
print(f"  Errors: {errors_count}")

if quality_scores_collected:
    overall_avg = sum(quality_scores_collected) / len(quality_scores_collected)
    print(f"\nOverall avg quality score: {overall_avg:.4f}")

print("\nMethod statistics:")
for method_name in GENERATION_METHODS.keys():
    print(f"  {method_name}: avg={method_stats[method_name]['avg_score']:.4f}")

Clearing GPU memory...
‚úì GPU memory cleared

MULTI-METHOD SYNTHETIC DATA GENERATION
Scoring: Heuristic-based

Configuration:
  Total samples: 20,000
  Batch size: 64
  Candidates per sample: 4 (one per method)
  Generation methods: 4
  GPUs: 2

Language distribution:
  English to Arabic: 10,000 (50.0%)
  French to Arabic: 10,000 (50.0%)

üìù No checkpoint found. Starting from beginning...
‚úì GPU memory cleared

MULTI-METHOD SYNTHETIC DATA GENERATION
Scoring: Heuristic-based

Configuration:
  Total samples: 20,000
  Batch size: 64
  Candidates per sample: 4 (one per method)
  Generation methods: 4
  GPUs: 2

Language distribution:
  English to Arabic: 10,000 (50.0%)
  French to Arabic: 10,000 (50.0%)

üìù No checkpoint found. Starting from beginning...


Processing batches:   0%|          | 0/313 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Processing batches:   6%|‚ñã         | 20/313 [21:48<5:23:07, 66.17s/it]


Progress (Batch 20/313):
  Samples: 1,280/20,000 (6.4%)
  Preference pairs: 5,804
    EN->AR: 2,801, FR->AR: 3,003
  Avg quality score: 0.8857
  Rate: 1.0 samples/sec
  ETA: 5.32 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  13%|‚ñà‚ñé        | 40/313 [42:32<4:39:00, 61.32s/it]


Progress (Batch 40/313):
  Samples: 2,560/20,000 (12.8%)
  Preference pairs: 11,620
    EN->AR: 5,665, FR->AR: 5,955
  Avg quality score: 0.8843
  Rate: 1.0 samples/sec
  ETA: 4.83 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  19%|‚ñà‚ñâ        | 60/313 [1:02:49<4:08:52, 59.02s/it]


Progress (Batch 60/313):
  Samples: 3,840/20,000 (19.2%)
  Preference pairs: 17,464
    EN->AR: 8,372, FR->AR: 9,092
  Avg quality score: 0.8839
  Rate: 1.0 samples/sec
  ETA: 4.41 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  26%|‚ñà‚ñà‚ñå       | 80/313 [1:22:56<4:13:04, 65.17s/it]


Progress (Batch 80/313):
  Samples: 5,120/20,000 (25.6%)
  Preference pairs: 23,232
    EN->AR: 11,068, FR->AR: 12,164
  Avg quality score: 0.8838
  Rate: 1.0 samples/sec
  ETA: 4.02 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  32%|‚ñà‚ñà‚ñà‚ñè      | 100/313 [1:44:13<3:35:40, 60.75s/it]


Progress (Batch 100/313):
  Samples: 6,400/20,000 (32.0%)
  Preference pairs: 29,067
    EN->AR: 13,808, FR->AR: 15,259
  Avg quality score: 0.8841
  Rate: 1.0 samples/sec
  ETA: 3.69 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  38%|‚ñà‚ñà‚ñà‚ñä      | 120/313 [2:04:39<3:31:12, 65.66s/it]


Progress (Batch 120/313):
  Samples: 7,680/20,000 (38.4%)
  Preference pairs: 34,946
    EN->AR: 16,650, FR->AR: 18,296
  Avg quality score: 0.8840
  Rate: 1.0 samples/sec
  ETA: 3.33 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  45%|‚ñà‚ñà‚ñà‚ñà‚ñç     | 140/313 [2:24:47<2:49:59, 58.96s/it]


Progress (Batch 140/313):
  Samples: 8,960/20,000 (44.8%)
  Preference pairs: 40,808
    EN->AR: 19,530, FR->AR: 21,278
  Avg quality score: 0.8842
  Rate: 1.0 samples/sec
  ETA: 2.97 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  51%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 160/313 [2:45:21<2:48:28, 66.07s/it]


Progress (Batch 160/313):
  Samples: 10,240/20,000 (51.2%)
  Preference pairs: 46,639
    EN->AR: 22,407, FR->AR: 24,232
  Avg quality score: 0.8844
  Rate: 1.0 samples/sec
  ETA: 2.63 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  58%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä    | 180/313 [3:05:53<2:19:11, 62.79s/it]


Progress (Batch 180/313):
  Samples: 11,520/20,000 (57.6%)
  Preference pairs: 52,542
    EN->AR: 25,161, FR->AR: 27,381
  Avg quality score: 0.8846
  Rate: 1.0 samples/sec
  ETA: 2.28 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  64%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç   | 200/313 [3:26:01<1:51:35, 59.25s/it]


Progress (Batch 200/313):
  Samples: 12,800/20,000 (64.0%)
  Preference pairs: 58,377
    EN->AR: 28,138, FR->AR: 30,239
  Avg quality score: 0.8849
  Rate: 1.0 samples/sec
  ETA: 1.93 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 220/313 [3:46:09<1:37:03, 62.62s/it]


Progress (Batch 220/313):
  Samples: 14,080/20,000 (70.4%)
  Preference pairs: 64,217
    EN->AR: 31,133, FR->AR: 33,084
  Avg quality score: 0.8847
  Rate: 1.0 samples/sec
  ETA: 1.58 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  77%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã  | 240/313 [4:06:55<1:16:06, 62.55s/it]


Progress (Batch 240/313):
  Samples: 15,360/20,000 (76.8%)
  Preference pairs: 70,068
    EN->AR: 33,960, FR->AR: 36,108
  Avg quality score: 0.8847
  Rate: 1.0 samples/sec
  ETA: 1.24 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 260/313 [4:28:24<56:47, 64.30s/it]  


Progress (Batch 260/313):
  Samples: 16,640/20,000 (83.2%)
  Preference pairs: 75,866
    EN->AR: 36,883, FR->AR: 38,983
  Avg quality score: 0.8849
  Rate: 1.0 samples/sec
  ETA: 0.90 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  89%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ | 280/313 [4:49:52<37:50, 68.80s/it]


Progress (Batch 280/313):
  Samples: 17,920/20,000 (89.6%)
  Preference pairs: 81,743
    EN->AR: 39,807, FR->AR: 41,936
  Avg quality score: 0.8846
  Rate: 1.0 samples/sec
  ETA: 0.56 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches:  96%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 300/313 [5:11:30<12:48, 59.14s/it]


Progress (Batch 300/313):
  Samples: 19,200/20,000 (96.0%)
  Preference pairs: 87,534
    EN->AR: 42,463, FR->AR: 45,071
  Avg quality score: 0.8845
  Rate: 1.0 samples/sec
  ETA: 0.22 hours
  Errors: 0
  üíæ Checkpoint saved


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 313/313 [5:24:11<00:00, 62.15s/it]


GENERATION COMPLETE
  Samples processed: 20,000
  EN-AR candidates: 10,000
  FR-AR candidates: 10,000
  Total preference pairs: 91,232
    English to Arabic: 44,324
    French to Arabic: 46,908
  Total time: 5.40 hours
  Average rate: 1.0 samples/sec
  Errors: 0

Overall avg quality score: 0.8847

Method statistics:
  temperature: avg=0.8820
  top_k: avg=0.8857
  nucleus: avg=0.8856
  greedy: avg=0.8855





## Save Synthetic Dataset

In [12]:
# Save datasets to separate files
print(f"Saving datasets...\n")

# Save EN-AR candidates
print(f"Saving {len(en_candidates_data):,} EN-AR candidates to {en_ar_candidates_file.name}...")
with open(en_ar_candidates_file, 'w', encoding='utf-8') as f:
    for item in en_candidates_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')
print(f"  ‚úì Saved")

# Save FR-AR candidates
print(f"Saving {len(fr_candidates_data):,} FR-AR candidates to {fr_ar_candidates_file.name}...")
with open(fr_ar_candidates_file, 'w', encoding='utf-8') as f:
    for item in fr_candidates_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')
print(f"  ‚úì Saved")

# Save EN-AR preferences
print(f"Saving {len(en_preferences_data):,} EN-AR preference pairs to {en_ar_preferences_file.name}...")
with open(en_ar_preferences_file, 'w', encoding='utf-8') as f:
    for item in en_preferences_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')
print(f"  ‚úì Saved")

# Save FR-AR preferences
print(f"Saving {len(fr_preferences_data):,} FR-AR preference pairs to {fr_ar_preferences_file.name}...")
with open(fr_ar_preferences_file, 'w', encoding='utf-8') as f:
    for item in fr_preferences_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')
print(f"  ‚úì Saved")

# Calculate statistics
en_pairs = en_preferences_data
fr_pairs = fr_preferences_data

# Analyze which methods produced best pairs
method_pair_stats = {}
all_pairs = en_pairs + fr_pairs
for method_name in GENERATION_METHODS.keys():
    chosen_count = sum(1 for item in all_pairs if item['chosen_method'] == method_name)
    rejected_count = sum(1 for item in all_pairs if item['rejected_method'] == method_name)
    method_pair_stats[method_name] = {
        'chosen_count': chosen_count,
        'rejected_count': rejected_count
    }

# Prepare statistics
stats = {
    'en_candidates': len(en_candidates_data),
    'fr_candidates': len(fr_candidates_data),
    'total_preference_pairs': len(all_pairs),
    'en_pairs': len(en_pairs),
    'fr_pairs': len(fr_pairs),
    'avg_margin': sum(item['margin'] for item in all_pairs) / len(all_pairs) if all_pairs else 0,
    'avg_chosen_score': sum(item['chosen_score'] for item in all_pairs) / len(all_pairs) if all_pairs else 0,
    'avg_rejected_score': sum(item['rejected_score'] for item in all_pairs) / len(all_pairs) if all_pairs else 0,
    'language_breakdown': {
        'english': {
            'candidates': len(en_candidates_data),
            'pairs': len(en_pairs),
            'avg_margin': sum(item['margin'] for item in en_pairs) / len(en_pairs) if en_pairs else 0,
            'avg_chosen_score': sum(item['chosen_score'] for item in en_pairs) / len(en_pairs) if en_pairs else 0,
        },
        'french': {
            'candidates': len(fr_candidates_data),
            'pairs': len(fr_pairs),
            'avg_margin': sum(item['margin'] for item in fr_pairs) / len(fr_pairs) if fr_pairs else 0,
            'avg_chosen_score': sum(item['chosen_score'] for item in fr_pairs) / len(fr_pairs) if fr_pairs else 0,
        }
    },
    'method_breakdown': {
        method_name: {
            'pairs_as_chosen': method_pair_stats[method_name]['chosen_count'],
            'pairs_as_rejected': method_pair_stats[method_name]['rejected_count'],
            'avg_score': method_stats[method_name]['avg_score'],
            'total_candidates': method_stats[method_name]['count']
        }
        for method_name in GENERATION_METHODS.keys()
    },
    'scoring_method': 'comet' if USE_COMET else 'heuristic'
}

stats_path = OUTPUTS_DIR / "generation_stats.json"
with open(stats_path, 'w') as f:
    json.dump(stats, f, indent=2)

# Print summary
print("\n" + "=" * 80)
print("DATASET STATISTICS")
print("=" * 80)
print(f"\nCandidates:")
print(f"  English to Arabic: {stats['language_breakdown']['english']['candidates']:,}")
print(f"  French to Arabic: {stats['language_breakdown']['french']['candidates']:,}")
print(f"  Total: {stats['en_candidates'] + stats['fr_candidates']:,}")

print(f"\nPreference Pairs:")
print(f"  Total pairs: {stats['total_preference_pairs']:,}")
print(f"  Average margin: {stats['avg_margin']:.4f}")
print(f"  Average chosen score: {stats['avg_chosen_score']:.4f}")

print(f"\nLanguage Breakdown:")
print(f"  English to Arabic:")
print(f"    Candidates: {stats['language_breakdown']['english']['candidates']:,}")
print(f"    Preference pairs: {stats['language_breakdown']['english']['pairs']:,}")
print(f"    Avg margin: {stats['language_breakdown']['english']['avg_margin']:.4f}")
print(f"    Avg score: {stats['language_breakdown']['english']['avg_chosen_score']:.4f}")

print(f"  French to Arabic:")
print(f"    Candidates: {stats['language_breakdown']['french']['candidates']:,}")
print(f"    Preference pairs: {stats['language_breakdown']['french']['pairs']:,}")
print(f"    Avg margin: {stats['language_breakdown']['french']['avg_margin']:.4f}")
print(f"    Avg score: {stats['language_breakdown']['french']['avg_chosen_score']:.4f}")

print(f"\nMethod Breakdown (Preference Pairs):")
for method_name, info in stats['method_breakdown'].items():
    print(f"  {method_name}:")
    print(f"    Chosen: {info['pairs_as_chosen']:,}")
    print(f"    Rejected: {info['pairs_as_rejected']:,}")
    print(f"    Avg score: {info['avg_score']:.4f}")

print(f"\nGenerated files:")
print(f"  - {en_ar_candidates_file.name}")
print(f"  - {fr_ar_candidates_file.name}")
print(f"  - {en_ar_preferences_file.name}")
print(f"  - {fr_ar_preferences_file.name}")
print(f"  - {stats_path.name}")

print(f"\nStatistics saved to {stats_path}")
print("=" * 80)

Saving datasets...

Saving 10,000 EN-AR candidates to english_arabic_candidates.jsonl...
  ‚úì Saved
Saving 10,000 FR-AR candidates to french_arabic_candidates.jsonl...
  ‚úì Saved
Saving 44,324 EN-AR preference pairs to en-ar-preferences.jsonl...
  ‚úì Saved
Saving 46,908 FR-AR preference pairs to fr-ar-preferences.jsonl...
  ‚úì Saved
Saving 46,908 FR-AR preference pairs to fr-ar-preferences.jsonl...
  ‚úì Saved

DATASET STATISTICS

Candidates:
  English to Arabic: 10,000
  French to Arabic: 10,000
  Total: 20,000

Preference Pairs:
  Total pairs: 91,232
  Average margin: 0.0633
  Average chosen score: 0.9166

Language Breakdown:
  English to Arabic:
    Candidates: 10,000
    Preference pairs: 44,324
    Avg margin: 0.0549
    Avg score: 0.9224
  French to Arabic:
    Candidates: 10,000
    Preference pairs: 46,908
    Avg margin: 0.0712
    Avg score: 0.9111

Method Breakdown (Preference Pairs):
  temperature:
    Chosen: 24,553
    Rejected: 24,405
    Avg score: 0.8820
  top_k:
 

## Sample Preference Pairs

In [13]:
# Display sample generated data
print("Sample Generated Data\n")
print("=" * 80)

# Show sample candidates
print("SAMPLE TRANSLATION CANDIDATES (4 methods per source)")
print("=" * 80)

en_examples = random.sample(en_candidates_data, min(2, len(en_candidates_data)))
fr_examples = random.sample(fr_candidates_data, min(1, len(fr_candidates_data)))

for i, item in enumerate(en_examples + fr_examples, 1):
    lang_label = 'EN‚ÜíAR' if item['source_lang'] == 'en' else 'FR‚ÜíAR'
    print(f"\nExample {i}: {lang_label}")
    print(f"Source: {item['source'][:100]}")
    print(f"Candidates:")
    for j, cand in enumerate(item['candidates'], 1):
        print(f"  {j}. [{cand['method']}]: {cand['translation'][:100]}")

# Show sample preferences
print("\n" + "=" * 80)
print("SAMPLE PREFERENCE PAIRS")
print("=" * 80)

en_pref_examples = random.sample(en_preferences_data, min(2, len(en_preferences_data)))
fr_pref_examples = random.sample(fr_preferences_data, min(1, len(fr_preferences_data)))

for i, item in enumerate(en_pref_examples + fr_pref_examples, 1):
    lang_label = 'EN‚ÜíAR' if item['source_lang'] == 'en' else 'FR‚ÜíAR'
    print(f"\nExample {i}: {lang_label}")
    print(f"Source: {item['source'][:100]}")
    print(f"\n‚úì Chosen ({item['chosen_method']}, score: {item['chosen_score']:.3f}):")
    print(f"  {item['chosen'][:100]}")
    print(f"\n‚úó Rejected ({item['rejected_method']}, score: {item['rejected_score']:.3f}):")
    print(f"  {item['rejected'][:100]}")
    print(f"\nMargin: {item['margin']:.3f}")

# Method performance summary
print("\n" + "=" * 80)
print("METHOD PERFORMANCE SUMMARY")
print("=" * 80)

all_pairs = en_preferences_data + fr_preferences_data
for method_name in GENERATION_METHODS.keys():
    chosen = sum(1 for item in all_pairs if item['chosen_method'] == method_name)
    rejected = sum(1 for item in all_pairs if item['rejected_method'] == method_name)
    total = chosen + rejected
    if total > 0:
        chosen_pct = 100 * chosen / total
        print(f"\n{method_name}:")
        print(f"  Chosen: {chosen:,} ({chosen_pct:.1f}%)")
        print(f"  Rejected: {rejected:,} ({100-chosen_pct:.1f}%)")
        print(f"  Total: {total:,}")

print("\n" + "=" * 80)

Sample Generated Data

SAMPLE TRANSLATION CANDIDATES (4 methods per source)

Example 1: EN‚ÜíAR
Source: Implementation. Under a mandate generally deriving from Article 98 of the Charter of the United Nati
Candidates:
  1. [temperature]: ÿßŸÑÿ™ŸÜŸÅŸäÿ∞ - ÿ®ŸÖŸàÿ¨ÿ® ŸàŸÑÿßŸäÿ© ŸÖÿ≥ÿ™ŸÖÿØÿ© ÿπŸÖŸàŸÖÿß ŸÖŸÜ ÿ£ÿ≠ŸÉÿßŸÖ ÿßŸÑŸÖÿßÿØÿ© 98 ŸÖŸÜ ŸÖŸäÿ´ÿßŸÇ ÿßŸÑÿ£ŸÖŸÖ ÿßŸÑŸÖÿ™ÿ≠ÿØÿ©ÿå ŸàÿßŸÑŸÖŸàÿßŸÅŸÇÿ© ÿßŸÑÿµÿ±Ÿäÿ≠ÿ© ÿ£Ÿà ÿß
  2. [top_k]: ÿßŸÑÿ™ŸÜŸÅŸäÿ∞ - ÿ®ŸÖŸàÿ¨ÿ® ŸàŸÑÿßŸäÿ© ŸÖÿ≥ÿ™ŸÖÿØÿ© ÿ®ÿ¥ŸÉŸÑ ÿπÿßŸÖ ŸÖŸÜ ÿßŸÑŸÖÿßÿØÿ© 98 ŸÖŸÜ ŸÖŸäÿ´ÿßŸÇ ÿßŸÑÿ£ŸÖŸÖ ÿßŸÑŸÖÿ™ÿ≠ÿØÿ©ÿå ŸàÿßŸÑŸÖŸàÿßŸÅŸÇÿ© ÿßŸÑÿµÿ±Ÿäÿ≠ÿ© ÿ£Ÿà ÿßŸÑÿ∂ŸÖ
  3. [nucleus]: ÿßŸÑÿ™ŸÜŸÅŸäÿ∞ - ŸÅŸä ÿ•ÿ∑ÿßÿ± ŸàŸÑÿßŸäÿ© ŸÖÿ≥ÿ™ŸÖÿØÿ© ÿπŸÖŸàŸÖÿß ŸÖŸÜ ÿßŸÑŸÖÿßÿØÿ© 98 ŸÖŸÜ ŸÖŸäÿ´ÿßŸÇ ÿßŸÑÿ£ŸÖŸÖ ÿßŸÑŸÖÿ™ÿ≠ÿØÿ©ÿå ŸàÿßŸÑŸÖŸàÿßŸÅŸÇÿ© ÿßŸÑÿµÿ±Ÿäÿ≠ÿ© ÿ£Ÿà ÿßŸÑÿ∂ŸÖŸÜ
  4. [greedy]: ÿßŸÑÿ™ŸÜŸÅŸäÿ∞ - ÿ®ŸÖŸàÿ¨ÿ® ŸàŸÑÿßŸäÿ© ŸÖÿ≥ÿ™ŸÖÿØÿ© ÿπŸÖŸàŸÖÿß ŸÖŸÜ ÿßŸÑŸÖÿßÿØÿ© 98 ŸÖŸÜ ŸÖŸäÿ´ÿßŸÇ ÿßŸÑÿ£ŸÖŸÖ ÿßŸÑŸÖÿ™ÿ≠ÿØÿ©ÿå ŸàÿßŸÑŸÖŸàÿßŸÅŸÇÿ© ÿßŸÑÿµÿ±Ÿäÿ≠ÿ© ÿ£Ÿà ÿßŸÑÿ∂ŸÖ

## Next Step

Proceed to **notebook 2** to train the reward model using this synthetic preference data.