## Setup: Install Dependencies

In [None]:
# Installation pour Colab (Python 3.12+)
!pip install --upgrade pip setuptools wheel -q
!pip install transformers[torch] datasets trl wandb accelerate -q

# V√©rification
import torch
import transformers
import trl
print("‚úÖ Installation r√©ussie!")
print(f"PyTorch: {torch.__version__}")
print(f"Transformers: {transformers.__version__}")
print(f"TRL: {trl.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 0. Mount Google Drive (Optional - for Colab/Kaggle)

In [None]:
# Choose where models and datasets are stored: Google Drive when running on
# Colab, otherwise a local directory. Defines SAVE_BASE_PATH and USE_DRIVE,
# which the following cells read.
import os

try:
    from google.colab import drive
except ImportError:
    # google.colab could not be imported: not running on Colab.
    USE_DRIVE = False
    SAVE_BASE_PATH = './results'
else:
    drive.mount('/content/drive')
    USE_DRIVE = True
    SAVE_BASE_PATH = '/content/drive/MyDrive/dpo_ppo_training'

# Create the storage root in both cases (the local fallback used to be left
# uncreated, unlike the Drive directory).
os.makedirs(SAVE_BASE_PATH, exist_ok=True)

if USE_DRIVE:
    print(f"✅ Google Drive monté. Modèles sauvegardés sur: {SAVE_BASE_PATH}")
else:
    print(f"⚠️  Pas de Google Drive détecté. Stockage local: {SAVE_BASE_PATH}")

## 1. Configuration DPO

**Paramètres configurables :**
- `beta` : Paramètre de régularisation KL (0.1, 0.5, 1.0 typiques)
- `batch_size` : À ajuster selon GPU T4
- `learning_rate` : Taux d'apprentissage
- `num_epochs` : Nombre d'époques

In [None]:
import torch
import numpy as np
import wandb
import json
from datetime import datetime
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from trl import DPOTrainer, DPOConfig

# -------------------------------------
# MAIN CONFIGURATION
# -------------------------------------

# --- DPO hyperparameters ---
BETA = 0.1                       # KL regularization strength (0.1, 0.5, 1.0 per the DPO paper)
BATCH_SIZE = 4                   # per-device batch size; tune for a T4 GPU (4-8 for DPO)
GRADIENT_ACCUMULATION_STEPS = 4  # effective batch = 4 * 4 = 16
LEARNING_RATE = 1e-6             # DPO with BF16 (1e-6 stable, 5e-7 very conservative)
NUM_EPOCHS = 1
MAX_LENGTH = 512                 # maximum sequence length
MAX_PROMPT_LENGTH = 128          # maximum prompt length

# --- Model locations, all under the storage root chosen earlier ---
SFT_MODEL_PATH = SAVE_BASE_PATH + "/sft_model"
DPO_MODEL_PATH = "{}/dpo_model_beta{}".format(SAVE_BASE_PATH, BETA)

# --- Print a summary of the configuration, one "label: value" line each ---
_RULE = "=" * 80
_config_summary = {
    "Beta (KL regularization)": BETA,
    "Batch Size": BATCH_SIZE,
    "Gradient Accumulation": GRADIENT_ACCUMULATION_STEPS,
    "Effective Batch Size": BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
    "Learning Rate": LEARNING_RATE,
    "Epochs": NUM_EPOCHS,
    "Max Length": MAX_LENGTH,
    "Max Prompt Length": MAX_PROMPT_LENGTH,
    "SFT Model": SFT_MODEL_PATH,
    "DPO Model Output": DPO_MODEL_PATH,
}

print(_RULE, "DPO Configuration", _RULE, sep="\n")
for _label, _value in _config_summary.items():
    print(f"{_label}: {_value}")
print(_RULE + "\n")


## 2. Load Preference Pairs Dataset

In [None]:
print("="*80)
print("ÉTAPE 1: Chargement du dataset de paires de préférences")
print("="*80)

# The preference pairs are read from a JSON file under the storage root.
pairs_path = f"{SAVE_BASE_PATH}/datasets/preference_pairs.json"

if not os.path.exists(pairs_path):
    raise FileNotFoundError(
        f"❌ Dataset introuvable: {pairs_path}\n"
        "Veuillez d'abord exécuter le notebook de génération des paires de préférences."
    )

# Load pairs
print(f"\n📥 Chargement des paires depuis: {pairs_path}")
with open(pairs_path, 'r', encoding='utf-8') as f:
    preference_pairs = json.load(f)

# Fail early, with a clear message, on an empty or malformed file instead of
# hitting a KeyError/TypeError further down.
REQUIRED_PAIR_KEYS = ("prompt", "chosen", "rejected")
if not isinstance(preference_pairs, list) or not preference_pairs:
    raise ValueError(
        f"❌ Format invalide: {pairs_path} doit contenir une liste non vide de paires."
    )
for pair_index, pair in enumerate(preference_pairs):
    missing_keys = (
        list(REQUIRED_PAIR_KEYS)
        if not isinstance(pair, dict)
        else [key for key in REQUIRED_PAIR_KEYS if key not in pair]
    )
    if missing_keys:
        raise ValueError(
            f"❌ Paire {pair_index} invalide dans {pairs_path}: clés manquantes {missing_keys}"
        )

print(f"✅ {len(preference_pairs)} paires de préférences chargées")

# Show the first (up to) two pairs, each field truncated to 60 characters.
print(f"\nExemples de paires:")
for i, pair in enumerate(preference_pairs[:2]):
    print(f"\nPaire {i+1}:")
    print(f"  Prompt:   {pair['prompt'][:60]}...")
    print(f"  Chosen:   {pair['chosen'][:60]}...")
    print(f"  Rejected: {pair['rejected'][:60]}...")

# Build a column-oriented HuggingFace dataset with exactly the three columns
# prompt / chosen / rejected (any extra keys in the JSON are dropped).
dpo_dataset = Dataset.from_dict({
    key: [pair[key] for pair in preference_pairs] for key in REQUIRED_PAIR_KEYS
})

print(f"\n✅ Dataset DPO créé avec {len(dpo_dataset)} paires")
print(f"   Colonnes: {dpo_dataset.column_names}")

## 3. Load Models (Policy & Reference)

In [None]:
print("="*80)
print("ÉTAPE 2: Chargement des modèles")
print("="*80)


def _load_sft_model(path, dtype=torch.bfloat16):
    """Load the causal LM stored at `path` with weights in `dtype`.

    The model is moved to the GPU when CUDA is available; otherwise it stays
    on the CPU. Returns the loaded model.
    """
    loaded = AutoModelForCausalLM.from_pretrained(path, torch_dtype=dtype)
    if torch.cuda.is_available():
        loaded = loaded.to("cuda")
    return loaded


# Both the policy and the reference start from the SFT checkpoint, so it must exist.
if not os.path.exists(SFT_MODEL_PATH):
    raise FileNotFoundError(
        f"❌ Modèle SFT introuvable: {SFT_MODEL_PATH}\n"
        "Veuillez d'abord entraîner le modèle SFT."
    )

# Load tokenizer
print(f"\n📥 Chargement du tokenizer depuis: {SFT_MODEL_PATH}")
tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL_PATH)
if tokenizer.pad_token is None:
    # No dedicated padding token: reuse end-of-sequence for padding.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
print("✅ Tokenizer chargé")

# NOTE(review): bfloat16 is requested unconditionally, while the configuration
# comments mention a T4 GPU — confirm the target device supports BF16.

# Policy model: the copy that DPO will train.
print(f"\n📥 Chargement du policy model (SFT)...")
model = _load_sft_model(SFT_MODEL_PATH)

if torch.cuda.is_available():
    print(f"✅ Policy model chargé sur GPU ({model.num_parameters() / 1e9:.2f}B params)")
else:
    print(f"⚠️  Policy model chargé sur CPU")

# Reference model: a second copy of the SFT weights that is never updated.
print(f"\n📥 Chargement du reference model (SFT frozen)...")
ref_model = _load_sft_model(SFT_MODEL_PATH)

if torch.cuda.is_available():
    print(f"✅ Reference model chargé sur GPU (frozen)")
else:
    print(f"⚠️  Reference model chargé sur CPU")

# Freeze the reference: no gradients, evaluation mode.
ref_model.requires_grad_(False)
ref_model.eval()

print(f"\n{'='*80}")
print(f"✅ Modèles chargés avec succès")
print(f"{'='*80}\n")

## 4. Initialize W&B Logging

In [None]:
# Authenticate with Weights & Biases.
wandb.login()

# Single source of truth for the project name (used by init and by the summary below).
WANDB_PROJECT = "dpo_ppo"

# Timestamped run name so repeated runs with the same beta stay distinguishable.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f"dpo_imdb_beta{BETA}_{timestamp}"

wandb.init(
    project=WANDB_PROJECT,
    name=run_name,
    config={
        # NOTE(review): hard-coded label; nothing in this notebook verifies that
        # the checkpoint at SFT_MODEL_PATH is gpt2-large — confirm.
        "model": "gpt2-large",
        # Record the checkpoint actually used as the starting point.
        "sft_model_path": SFT_MODEL_PATH,
        "dataset": "imdb_preference_pairs",
        "num_pairs": len(dpo_dataset),
        "batch_size": BATCH_SIZE,
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
        "effective_batch_size": BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
        "learning_rate": LEARNING_RATE,
        "num_epochs": NUM_EPOCHS,
        "beta": BETA,
        "max_length": MAX_LENGTH,
        "max_prompt_length": MAX_PROMPT_LENGTH,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }
)

print(f"✅ W&B initialized: {run_name}")
print(f"   Project: {WANDB_PROJECT}")
print(f"   Beta: {BETA}")

## 5. Configure DPO Trainer

In [None]:
# DPOConfig and DPOTrainer are already imported in the configuration cell.

print("="*80)
print("ÉTAPE 3: Configuration du DPO Trainer")
print("="*80)

# Number of linear warmup steps. Only `warmup_steps` is passed below: the
# original also set `warmup_ratio=0.1`, but `warmup_steps` takes precedence
# when both are given, so the ratio had no effect.
WARMUP_STEPS = 50

effective_batch_size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
# Rough count of optimizer steps over the whole run (all epochs, at least one per epoch).
estimated_steps = NUM_EPOCHS * max(1, len(dpo_dataset) // effective_batch_size)

# DPO configuration (replaces TrainingArguments + beta)
dpo_config = DPOConfig(
    output_dir=DPO_MODEL_PATH,

    # DPO-specific
    beta=BETA,
    max_length=MAX_LENGTH,
    max_prompt_length=MAX_PROMPT_LENGTH,

    # Training
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_steps=WARMUP_STEPS,  # longer warmup for stability
    weight_decay=0.05,
    max_grad_norm=1.0,  # gradient clipping to avoid exploding gradients

    # Logging
    logging_steps=10,
    logging_first_step=True,
    report_to=["wandb"],

    # Saving: a checkpoint every 100 steps, keeping only the 2 most recent
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,

    # Precision: BF16 on, FP16 off (FP16 is considered unstable for DPO here)
    # NOTE(review): bf16=True is unconditional — confirm the target GPU supports it.
    fp16=False,
    bf16=True,
    gradient_checkpointing=True,

    # Other
    remove_unused_columns=False,
    seed=42,
)

# Initialize DPO Trainer
print(f"\n🚀 Initialisation du DPO Trainer...")
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=dpo_config,
    train_dataset=dpo_dataset,
    processing_class=tokenizer,  # recent TRL takes the tokenizer via processing_class
)

print(f"✅ DPO Trainer initialisé")
print(f"\n📊 Configuration:")
print(f"   - Beta: {BETA}")
print(f"   - Batch size: {BATCH_SIZE}")
print(f"   - Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
print(f"   - Effective batch size: {effective_batch_size}")
print(f"   - Learning rate: {LEARNING_RATE}")
print(f"   - Total pairs: {len(dpo_dataset)}")
print(f"   - Estimated steps: {estimated_steps}")
print(f"   - FP16: {dpo_config.fp16}")
print(f"   - BF16: {dpo_config.bf16}")
print(f"   - Gradient checkpointing: {dpo_config.gradient_checkpointing}")
print(f"   - Max grad norm: {dpo_config.max_grad_norm}")

# With a small dataset the warmup can cover the entire run; make that visible.
if estimated_steps <= WARMUP_STEPS:
    print(
        f"\n⚠️  Warmup ({WARMUP_STEPS} steps) >= steps estimés ({estimated_steps}): "
        f"le learning rate n'atteindra pas sa valeur cible"
    )

print(f"\n✅ BF16 activé : stable pour DPO (même plage dynamique que FP32, 2x plus rapide)")

In [None]:
from pathlib import Path

print("="*80)
print("ÉTAPE 4: Entraînement DPO")
print("="*80)


def _latest_checkpoint(output_dir):
    """Return the highest-numbered `checkpoint-<step>` directory in `output_dir`.

    Only directories whose suffix after the last '-' is a decimal number are
    considered, so stray files or entries such as `checkpoint-best` are never
    picked. Returns the path as a string, or None when nothing matches
    (including when `output_dir` does not exist).
    """
    candidates = [
        path for path in Path(output_dir).glob("checkpoint-*")
        if path.is_dir() and path.name.rsplit("-", 1)[-1].isdecimal()
    ]
    if not candidates:
        return None
    newest = max(candidates, key=lambda path: int(path.name.rsplit("-", 1)[-1]))
    return str(newest)


print(f"\n🚀 Démarrage de l'entraînement DPO...")
print(f"   - {len(dpo_dataset)} paires de préférences")
print(f"   - {NUM_EPOCHS} epoch(s)")
print(f"   - Beta = {BETA}\n")

# Release cached GPU memory before training starts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("✅ GPU cache cleared\n")

# Resume from the most recent checkpoint in the output directory, if any.
RESUME_FROM_CHECKPOINT = True
resume_checkpoint = None
if RESUME_FROM_CHECKPOINT:
    resume_checkpoint = _latest_checkpoint(DPO_MODEL_PATH)
    if resume_checkpoint:
        print(f"🔄 Reprise depuis le checkpoint: {resume_checkpoint}")
    else:
        print("⚠️  Aucun checkpoint trouvé, entraînement from scratch")

# Train (resume_from_checkpoint=None starts a fresh run).
dpo_trainer.train(resume_from_checkpoint=resume_checkpoint)

print(f"\n{'='*80}")
print(f"✅ Entraînement DPO terminé!")
print(f"{'='*80}")


## 6. DPO Training

## 7. Save DPO Model

In [None]:
print("="*80)
print("ÉTAPE 5: Sauvegarde du modèle DPO")
print("="*80)

# Write the trained model and the tokenizer to the same directory.
print(f"\n💾 Sauvegarde du modèle DPO dans: {DPO_MODEL_PATH}")
dpo_trainer.save_model(DPO_MODEL_PATH)
tokenizer.save_pretrained(DPO_MODEL_PATH)

print(f"✅ Modèle DPO sauvegardé!")

# List what is actually on disk instead of a hard-coded set of names: the
# weight file name is not fixed by this notebook, so the previous
# `pytorch_model.bin` line could be wrong. Sub-directories (e.g. training
# checkpoints) are skipped.
saved_files = sorted(
    entry for entry in os.listdir(DPO_MODEL_PATH)
    if os.path.isfile(os.path.join(DPO_MODEL_PATH, entry))
)
print(f"\n📁 Fichiers créés:")
for filename in saved_files:
    print(f"   - {DPO_MODEL_PATH}/{filename}")

# Close W&B run
wandb.finish()
print(f"\n✅ W&B run closed")

print(f"\n{'='*80}")
print(f"✅ DPO TRAINING COMPLETE!")
print(f"{'='*80}")
print(f"\n🎯 Prochaines étapes:")
print(f"   1. Évaluer le modèle DPO sur le test set")
print(f"   2. Comparer avec SFT et PPO")
print(f"   3. Tester d'autres valeurs de beta (0.5, 1.0)")
print(f"   4. Générer la courbe reward-KL (Figure 2)")

## 8. Test DPO Model (Optional)

In [None]:
import random

print("="*80)
print("Test du modèle DPO sur quelques exemples")
print("="*80)

NUM_TEST_SAMPLES = 5   # how many dataset prompts to try (capped by the dataset size)
PREVIEW_CHARS = 100    # how many characters of each text to display


def _completion_preview(prompt, text, limit=PREVIEW_CHARS):
    """Return the first `limit` characters of `text`, without a leading `prompt`.

    The prompt is stripped only when `text` really starts with it, so this
    works whether the stored completion repeats the prompt or not. (The
    previous `text[len(prompt):100]` mixed an offset with an absolute end
    index and could print a truncated or empty string.)
    """
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text[:limit]


# Reload the saved DPO model from disk, so the test exercises what was written.
test_model = AutoModelForCausalLM.from_pretrained(DPO_MODEL_PATH)
test_tokenizer = AutoTokenizer.from_pretrained(DPO_MODEL_PATH)

if torch.cuda.is_available():
    test_model = test_model.to("cuda")

test_model.eval()

# Pick random prompts from the dataset; seeded so the selection is repeatable.
# The sample size is capped so small datasets do not raise in random.sample.
random.seed(42)
num_samples = min(NUM_TEST_SAMPLES, len(dpo_dataset))
test_indices = random.sample(range(len(dpo_dataset)), num_samples)
test_samples = dpo_dataset.select(test_indices)

# Seed torch as well so the sampled generations are repeatable.
torch.manual_seed(42)

print(f"\nGénération sur {num_samples} prompts aléatoires du dataset:\n")

for i, sample in enumerate(test_samples, 1):
    prompt = sample["prompt"]
    chosen = sample["chosen"]
    rejected = sample["rejected"]

    print(f"{'─'*80}")
    print(f"Prompt {i}: {prompt}")

    # Tokenize
    inputs = test_tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")

    # Generate
    with torch.no_grad():
        outputs = test_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=test_tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens: slicing by token count is exact,
    # whereas slicing the decoded string by len(prompt) assumes the prompt
    # decodes back to the identical text.
    prompt_token_count = inputs["input_ids"].shape[1]
    continuation = test_tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    ).strip()

    print(f"\nDPO Generated: {continuation[:PREVIEW_CHARS]}...")
    print(f"\nOriginal Chosen:   {_completion_preview(prompt, chosen)}...")
    print(f"Original Rejected: {_completion_preview(prompt, rejected)}...")
    print()

print(f"{'='*80}")
print(f"✅ Test terminé!")
print(f"\n💡 Le modèle DPO devrait générer du texte plus proche de 'chosen' que de 'rejected'")

## 9. Compare SFT vs DPO (Side-by-side)

In [None]:
print("="*80)
print("Comparaison SFT vs DPO sur les mêmes prompts")
print("="*80)


def generate_continuation(model, tokenizer, prompt, max_new_tokens=30, seed=42):
    """Sample a continuation of `prompt` from `model` and return only the new text.

    Args:
        model: causal LM exposing `generate`.
        tokenizer: tokenizer matching `model`.
        prompt: text to continue.
        max_new_tokens: maximum number of tokens to generate.
        seed: torch seed set right before sampling, so every model is sampled
            under the same random state and the comparison is repeatable.

    Returns:
        The decoded generated tokens (prompt tokens excluded), stripped of
        surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")

    torch.manual_seed(seed)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the tokens generated after the prompt.
    prompt_token_count = inputs["input_ids"].shape[1]
    return tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    ).strip()


# Load SFT model
print("\n📥 Chargement du modèle SFT...")
sft_model = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH)
sft_tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL_PATH)

if torch.cuda.is_available():
    sft_model = sft_model.to("cuda")

sft_model.eval()
print("✅ Modèle SFT chargé")

# The DPO model normally comes from the (optional) test cell; load it here
# when that cell was skipped so this cell also runs on its own.
if "test_model" not in globals() or "test_tokenizer" not in globals():
    test_model = AutoModelForCausalLM.from_pretrained(DPO_MODEL_PATH)
    test_tokenizer = AutoTokenizer.from_pretrained(DPO_MODEL_PATH)
    if torch.cuda.is_available():
        test_model = test_model.to("cuda")
    test_model.eval()

# Fixed prompts used for the side-by-side comparison.
test_prompts = [
    "This movie is",
    "I really enjoyed",
    "The acting was"
]

print(f"\nGénération comparative sur {len(test_prompts)} prompts:\n")

for i, prompt in enumerate(test_prompts, 1):
    print(f"{'='*80}")
    print(f"Prompt {i}: {prompt}")
    print(f"{'='*80}")

    sft_continuation = generate_continuation(sft_model, sft_tokenizer, prompt)
    dpo_continuation = generate_continuation(test_model, test_tokenizer, prompt)

    # Display comparison
    print(f"\n📝 SFT:  {sft_continuation}")
    print(f"🎯 DPO:  {dpo_continuation}")
    print()

print(f"{'='*80}")
print(f"✅ Comparaison terminée!")
print(f"\n💡 DPO devrait générer du texte avec un sentiment plus positif (selon les préférences)")