## Setup: Install Dependencies

In [None]:
# Installation pour Colab (Python 3.12+)
# Pin TRL to 0.26.0 for API compatibility
!pip install --upgrade pip setuptools wheel -q
!pip install transformers[torch] datasets trl==0.26.0 wandb accelerate -q

# V√©rification
import torch
import transformers
import trl
print("‚úÖ Installation r√©ussie!")
print(f"PyTorch: {torch.__version__}")
print(f"Transformers: {transformers.__version__}")
print(f"TRL: {trl.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 0. Mount Google Drive (Optional - for Colab/Kaggle)

In [None]:
# Choose where artifacts are stored: Google Drive on Colab (saves time and
# quota across sessions), otherwise a local folder.
import os

# Only the import is used to detect Colab. Mounting and directory creation
# happen outside the `try`, so a failure there surfaces as a real error
# instead of being misreported as "not running on Colab".
try:
    from google.colab import drive
except ImportError:
    # Not on Colab: fall back to local storage.
    SAVE_BASE_PATH = './results'
    USE_DRIVE = False
    # Create the folder here too, mirroring what the Drive branch does.
    os.makedirs(SAVE_BASE_PATH, exist_ok=True)
    print(f"‚ö†Ô∏è  Pas de Google Drive d√©tect√©. Stockage local: {SAVE_BASE_PATH}")
else:
    drive.mount('/content/drive')
    SAVE_BASE_PATH = '/content/drive/MyDrive/dpo_ppo_training'
    os.makedirs(SAVE_BASE_PATH, exist_ok=True)
    print(f"‚úÖ Google Drive mont√©. Mod√®les sauvegard√©s sur: {SAVE_BASE_PATH}")
    USE_DRIVE = True

## 1. Configuration PPO

**Paramètres configurables :**
- `USE_REWARD_MODEL` : False = classifier direct, True = reward model entraîné
- `target_kl` : Divergence KL cible (3, 6, 9, 12)
- `batch_size` : À ajuster selon GPU T4
- `max_new_tokens` : Tokens générés par prompt (24 par défaut)

In [None]:
import torch
import numpy as np
import wandb
import json
from datetime import datetime
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    pipeline
)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# =====================================
# MAIN CONFIGURATION
# =====================================

# PPO mode: rewards come from the trained reward model only.
USE_REWARD_MODEL = True  # Always use the trained reward model

# Ground-truth ("oracle") reward mode switch. Later cells and
# `compute_rewards` read this flag, so it must exist; it stays disabled
# because this notebook is configured for the trained reward model.
USE_GT_REWARD = False

# PPO hyperparameters
TARGET_KL = 3.0  # Target KL divergence
BATCH_SIZE = 128  # Adjust to the available GPU (T4: 64-128)
MINI_BATCH_SIZE = 32  # Usually batch_size / 4
LEARNING_RATE = 2e-5
NUM_EPOCHS = 1
MAX_NEW_TOKENS = 24  # Tokens generated per prompt

# Model locations, all rooted at SAVE_BASE_PATH (set by the Drive cell).
SFT_MODEL_PATH = f"{SAVE_BASE_PATH}/sft_model"
REWARD_MODEL_PATH = f"{SAVE_BASE_PATH}/reward_model"  # Trained reward model
PPO_MODEL_PATH = f"{SAVE_BASE_PATH}/ppo_model"

# Echo the effective configuration so it is captured in the cell output.
print(f"{'='*80}")
print(f"PPO Configuration (Reward Model)")
print(f"{'='*80}")
print(f"Target KL: {TARGET_KL}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Mini Batch Size: {MINI_BATCH_SIZE}")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Max New Tokens: {MAX_NEW_TOKENS}")
print(f"SFT Model: {SFT_MODEL_PATH}")
print(f"Reward Model: {REWARD_MODEL_PATH}")
print(f"{'='*80}\n")

## 2. Load Dataset (Prompts from DPO)

In [None]:
print("="*80)
print("√âTAPE 1: Chargement du dataset de prompts")
print("="*80)

# The prompts are taken from the preference pairs produced for DPO.
pairs_path = f"{SAVE_BASE_PATH}/datasets/preference_pairs.json"

# Fail early with an actionable message when the upstream notebook was not run.
if not os.path.exists(pairs_path):
    raise FileNotFoundError(
        f"‚ùå Dataset introuvable: {pairs_path}\n"
        "Veuillez d'abord ex√©cuter le notebook de g√©n√©ration des paires de pr√©f√©rences."
    )

print(f"\nüì• Chargement des prompts depuis: {pairs_path}")
with open(pairs_path, 'r', encoding='utf-8') as f:
    preference_pairs = json.load(f)

# Collect unique prompts in first-seen order.
prompts = []
prompt_labels = {}  # prompt -> sentiment label (1=positive, 0=negative)

for pair in preference_pairs:
    prompt = pair["prompt"]
    # `prompt_labels` holds exactly the prompts already collected, so the dict
    # lookup replaces a linear scan of the `prompts` list.
    if prompt not in prompt_labels:
        prompts.append(prompt)
        # Every prompt is labelled 1: the "chosen" side is assumed positive.
        prompt_labels[prompt] = 1

print(f"‚úÖ {len(prompts)} prompts uniques charg√©s")
print(f"\nExemples de prompts:")
for i in range(min(3, len(prompts))):
    print(f"  {i+1}. {prompts[i][:50]}...")

# Wrap the prompts in a single-column ("query") dataset for the trainer.
ppo_dataset = Dataset.from_dict({"query": prompts})
print(f"\n‚úÖ Dataset PPO cr√©√© avec {len(ppo_dataset)} prompts")
print(f"{'='*80}")
# The configuration cell may not define USE_GT_REWARD; a missing flag is
# treated as "ground-truth mode disabled" instead of raising NameError.
if globals().get("USE_GT_REWARD", False):
    print(f"‚≠ê MODE: PPO-GT (Ground Truth Oracle)")
else:
    reward_type = "Learned Reward Model" if USE_REWARD_MODEL else "Siebert Classifier"
    print(f"üìö MODE: PPO Standard ({reward_type})")
print(f"{'='*80}\n")

## 3. Load Models (Policy, Reference, Reward)

In [None]:
print("="*80)
print("√âTAPE 2: Chargement des mod√®les")
print("="*80)

from transformers import GenerationConfig

# One dtype for every model in this cell: bfloat16 on GPU, float32 on CPU.
# NOTE(review): the configuration comments target a T4, which has no native
# bfloat16 support — confirm this dtype is intended on that hardware.
model_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Tokenizer shared by the policy and reference models.
print(f"\nüì• Chargement du tokenizer depuis: {SFT_MODEL_PATH}")
tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL_PATH)
if tokenizer.pad_token is None:
    # Reuse EOS as the padding token when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
print("‚úÖ Tokenizer charg√©")

# Policy model: the SFT checkpoint wrapped with a value head.
print(f"\nüì• Chargement du policy model (SFT + value head)...")
policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    SFT_MODEL_PATH,
    torch_dtype=model_dtype
)

if torch.cuda.is_available():
    policy_model = policy_model.to("cuda")
    print(f"‚úÖ Policy model charg√© sur GPU ({policy_model.pretrained_model.num_parameters() / 1e9:.2f}B params)")
else:
    print(f"‚ö†Ô∏è  Policy model charg√© sur CPU")

# Reference model: a second copy of the SFT checkpoint, frozen below.
print(f"\nüì• Chargement du reference model (SFT frozen)...")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    SFT_MODEL_PATH,
    torch_dtype=model_dtype
)

if torch.cuda.is_available():
    ref_model = ref_model.to("cuda")
    print(f"‚úÖ Reference model charg√© sur GPU (frozen)")
else:
    print(f"‚ö†Ô∏è  Reference model charg√© sur CPU")

# Freeze the reference model: no gradients, eval mode.
for param in ref_model.parameters():
    param.requires_grad = False
ref_model.eval()

# Make sure a generation config with a pad token id exists on both the
# underlying model and its value-head wrapper.
for m in [policy_model, ref_model]:
    base = m.pretrained_model
    try:
        gen_cfg = base.generation_config
    except AttributeError:
        gen_cfg = None
    if gen_cfg is None:
        gen_cfg = GenerationConfig.from_model_config(base.config)
    if gen_cfg.pad_token_id is None:
        gen_cfg.pad_token_id = tokenizer.pad_token_id
    base.generation_config = gen_cfg
    m.generation_config = gen_cfg

# Trained reward model (sequence classifier) and its own tokenizer.
print(f"\nüì• Chargement du reward model (entra√Æn√©)...")
if not os.path.exists(REWARD_MODEL_PATH):
    raise FileNotFoundError(
        f"‚ùå Reward model introuvable: {REWARD_MODEL_PATH}\n"
        "Veuillez d'abord entra√Æner le reward model dans Train_Reward_Model_NVIDIA.ipynb"
    )

reward_model = AutoModelForSequenceClassification.from_pretrained(
    REWARD_MODEL_PATH,
    num_labels=2,  # Binary: positive (1) / negative (0)
    torch_dtype=model_dtype
)
reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_PATH)

# The reward tokenizer is later called with padding=True, which needs a pad
# token. Apply the same EOS fallback used for the policy tokenizer and keep
# the model config in sync. Both steps are no-ops when a pad token exists.
if reward_tokenizer.pad_token is None:
    reward_tokenizer.pad_token = reward_tokenizer.eos_token
if reward_model.config.pad_token_id is None:
    reward_model.config.pad_token_id = reward_tokenizer.pad_token_id

if torch.cuda.is_available() and reward_model.device.type != 'cuda':
    reward_model = reward_model.to("cuda")

# The reward model is only used for scoring: eval mode, no gradients.
reward_model.eval()
for param in reward_model.parameters():
    param.requires_grad = False

# Report the device the model actually ended up on (previously always "GPU").
reward_device = "GPU" if reward_model.device.type == "cuda" else "CPU"
print(f"‚úÖ Reward model charg√© sur {reward_device} (frozen)")

print(f"\n{'='*80}")
print(f"‚úÖ Tous les mod√®les charg√©s avec succ√®s")
print(f"{'='*80}\n")

In [None]:
# Critic: a separate copy of the SFT checkpoint wrapped with a value head,
# in bfloat16 when CUDA is available and float32 otherwise.
print("\nüì• Chargement du value model (critic)...")
value_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    SFT_MODEL_PATH,
    torch_dtype=(torch.float32 if not torch.cuda.is_available() else torch.bfloat16),
)
if not torch.cuda.is_available():
    print("‚ö†Ô∏è  Value model charg√© sur CPU")
else:
    value_model = value_model.to("cuda")
    print("‚úÖ Value model charg√© sur GPU")

# Ensure a generation config (with a pad token id) is present on the wrapped
# model and mirrored on the wrapper itself.
from transformers import GenerationConfig
base = value_model.pretrained_model
# A missing attribute is handled like an explicit None.
gen_cfg = getattr(base, "generation_config", None)
if gen_cfg is None:
    gen_cfg = GenerationConfig.from_model_config(base.config)
if gen_cfg.pad_token_id is None:
    gen_cfg.pad_token_id = tokenizer.pad_token_id
# Assignment targets are bound left to right: wrapped model first, then wrapper.
base.generation_config = value_model.generation_config = gen_cfg

## 4. Initialize W&B Logging

In [None]:
# Authenticate with Weights & Biases before opening a run.
wandb.login()

# Run name encodes the KL setting and a start timestamp.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f"ppo_imdb_reward_kl{TARGET_KL}_{timestamp}"

# Open the run and record the hyperparameters alongside it.
wandb.init(
    project="ppo",
    name=run_name,
    config=dict(
        model="gpt2-large",
        dataset="imdb_prompts",
        num_prompts=len(ppo_dataset),
        batch_size=BATCH_SIZE,
        mini_batch_size=MINI_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        num_epochs=NUM_EPOCHS,
        target_kl=TARGET_KL,
        max_new_tokens=MAX_NEW_TOKENS,
        reward_model_type="trained",
        device=("cpu" if not torch.cuda.is_available() else "cuda"),
    ),
)

print(f"‚úÖ W&B initialized: {run_name}")
print(f"   Project: ppo")
print(f"   Target KL: {TARGET_KL}")

## 5. Configure PPO Trainer

In [None]:
print("="*80)
print("√âTAPE 3: Configuration du PPO Trainer")
print("="*80)

# PPO hyperparameters.
# NOTE(review): TARGET_KL is documented in this notebook as a *target KL
# divergence*, but it is passed below as the KL penalty *coefficient* —
# confirm that a coefficient of this magnitude is intended.
# NOTE(review): NUM_EPOCHS and MAX_NEW_TOKENS are printed below but are not
# passed to PPOConfig — verify how the trainer picks up epoch count and
# generation length.
ppo_config = PPOConfig(
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    mini_batch_size=MINI_BATCH_SIZE,
    gradient_accumulation_steps=1,
    kl_coef=TARGET_KL,  # Weight of the KL term
    num_ppo_epochs=4,
    seed=42,
    report_to="wandb",  # TRL 0.26.x relies on TrainingArguments.report_to
    run_name=run_name
)

# Build the trainer. The constructor that takes `args=`, `reward_model=` and
# `value_model=` receives the tokenizer as `processing_class`; the previous
# `tokenizer=` keyword is not accepted by the pinned TRL release.
# NOTE(review): `ppo_dataset` only carries a raw-text "query" column;
# presumably this trainer expects tokenized inputs — verify.
print(f"\nüöÄ Initialisation du PPO Trainer...")
ppo_trainer = PPOTrainer(
    args=ppo_config,
    model=policy_model,
    ref_model=ref_model,
    processing_class=tokenizer,
    train_dataset=ppo_dataset,
    reward_model=reward_model,
    value_model=value_model,
)

# Summary of the effective settings.
print(f"‚úÖ PPO Trainer initialis√©")
print(f"\nüìä Configuration:")
print(f"   - Batch size: {BATCH_SIZE}")
print(f"   - Mini batch size: {MINI_BATCH_SIZE}")
print(f"   - Learning rate: {LEARNING_RATE}")
print(f"   - KL coef: {TARGET_KL}")
print(f"   - PPO epochs per batch: 4")
print(f"   - Max new tokens: {MAX_NEW_TOKENS}")
print(f"   - Total prompts: {len(ppo_dataset)}")
# Floor division: a trailing partial batch is not counted.
print(f"   - Estimated batches: {len(ppo_dataset) // BATCH_SIZE}")

## 6. Define Reward Function

In [None]:
def compute_rewards(texts, queries=None):
    """
    Score generated texts and return one reward per text.

    Two modes, selected by module-level flags:
    - ``USE_GT_REWARD`` truthy: every text receives a constant reward of 1.0.
      A missing flag is treated as False.
    - otherwise: each text is scored individually with ``reward_model``.
      With ``USE_REWARD_MODEL`` the reward is the raw logit at index 0;
      without it, the reward is the softmax probability of label 1.

    Args:
        texts: List of generated text strings (prompt + response).
        queries: Unused; kept so existing callers passing it keep working.

    Returns:
        List of reward scores (floats), in the same order as ``texts``.
    """
    # The flag may not be defined by the configuration cell; look it up
    # defensively instead of raising NameError.
    if globals().get("USE_GT_REWARD", False):
        rewards = [1.0] * len(texts)
        print(f"  [PPO-GT] Using ground truth rewards: {len(rewards)} samples with reward=1.0")
        return rewards

    rewards = []
    for text in texts:
        # Pass the full text: the tokenizer enforces the 512-*token* limit.
        # The earlier `text[:512]` slice cut at 512 *characters*, which could
        # drop the generated response before it was ever scored.
        inputs = reward_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(reward_model.device)

        with torch.no_grad():
            logits = reward_model(**inputs).logits

        if USE_REWARD_MODEL:
            # Trained reward model: raw logit used as a scalar reward.
            # NOTE(review): the model is loaded with num_labels=2 and label 1
            # documented as "positive"; index 0 would then be the negative
            # class logit — confirm which index carries the reward.
            reward = logits[0, 0].item()
        else:
            # Direct classifier: reward is the probability of label 1.
            reward = torch.softmax(logits, dim=-1)[0, 1].item()

        rewards.append(reward)

    return rewards

# Test reward function
# Smoke test: score one clearly positive and one clearly negative review.
print("Testing reward function...")
test_texts = [
    "This movie is amazing and wonderful!",
    "This movie is terrible and boring."
]
test_rewards = compute_rewards(test_texts)

# Minimum contract: one finite score per input text.
assert len(test_rewards) == len(test_texts), "compute_rewards must return one reward per text"
assert all(np.isfinite(r) for r in test_rewards), f"non-finite reward(s): {test_rewards}"

print(f"\nTest rewards:")
for text, reward in zip(test_texts, test_rewards):
    print(f"  Text: {text[:50]}...")
    print(f"  Reward: {reward:.4f}\n")

# Only report success when the positive example outscores the negative one.
# Equal scores are expected in constant-reward (ground-truth) mode.
if test_rewards[0] <= test_rewards[1]:
    print(
        "WARNING: the positive example did not score higher than the negative one "
        f"({test_rewards[0]:.4f} <= {test_rewards[1]:.4f}). "
        "Check the reward model and which logit index is used as the reward."
    )
else:
    print("‚úÖ Reward function working correctly")

## 7. PPO Training Loop

In [None]:
from pathlib import Path
from tqdm import tqdm

print("="*80)
print("√âTAPE 4: Entra√Ænement PPO")
print("="*80)

# Checkpointing configuration.
RESUME_FROM_CHECKPOINT = True
CHECKPOINT_DIR = Path(PPO_MODEL_PATH) / "checkpoints"
CHECKPOINT_SAVE_STEPS = 200  # Save accelerator state every N steps
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)


def _checkpoint_step(ckpt_dir):
    """Return the step number encoded in a `step_<N>` directory name, or -1."""
    suffix = ckpt_dir.name.split("_")[-1]
    return int(suffix) if suffix.isdigit() else -1


# Resume from the checkpoint with the highest step number, if any.
latest_ckpt = None
resume_step = 0  # Number of optimisation steps already covered by the checkpoint
if RESUME_FROM_CHECKPOINT:
    ckpts = sorted(CHECKPOINT_DIR.glob("step_*"), key=_checkpoint_step)
    if ckpts:
        latest_ckpt = str(ckpts[-1])
        # A directory without a numeric suffix carries no step information.
        resume_step = max(_checkpoint_step(ckpts[-1]), 0)
        print(f"üîÑ Reprise depuis le checkpoint: {latest_ckpt}")
        ppo_trainer.accelerator.load_state(latest_ckpt)
    else:
        print("‚ö†Ô∏è  Aucun checkpoint trouv√©, entra√Ænement from scratch")

# Sampling settings used when generating responses.
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": MAX_NEW_TOKENS,
}

print(f"\nüöÄ D√©marrage de l'entra√Ænement PPO...")
print(f"   - {len(ppo_dataset)} prompts")
print(f"   - {len(ppo_dataset) // BATCH_SIZE} batches")
print(f"   - {NUM_EPOCHS} epoch(s)\n")

# Continue counting from the restored step. Restarting at 0 would make new
# checkpoints reuse lower `step_<N>` names, so the next resume would pick a
# stale directory, and would replay batches that were already trained on.
global_step = resume_step
# The dataloader length is only needed to skip already-processed batches.
steps_per_epoch = len(ppo_trainer.dataloader) if resume_step else 0

# NOTE(review): `dataloader`, `generate`, `step` and `log_stats` form a
# step-wise trainer API; confirm the pinned TRL release still exposes them.
for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    print(f"{'='*80}\n")

    for batch_idx, batch in enumerate(tqdm(ppo_trainer.dataloader, desc=f"Epoch {epoch+1}")):
        # When resuming, skip the batches the checkpoint already accounts for.
        # Only the count is restored; the exact batch order may differ.
        if epoch * steps_per_epoch + batch_idx < resume_step:
            continue

        query_tensors = batch["input_ids"]

        # Sample a response for every prompt in the batch.
        response_tensors = ppo_trainer.generate(
            query_tensors,
            return_prompt=False,
            **generation_kwargs
        )

        # Decode the responses back to text for scoring and logging.
        batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

        # Score prompt + response and wrap each reward in a tensor.
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        rewards = compute_rewards(texts)
        rewards = [torch.tensor(r) for r in rewards]

        # One PPO optimisation step on this batch.
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

        # Forward statistics and sample texts to the configured logger.
        ppo_trainer.log_stats(
            stats,
            batch,
            rewards,
            columns_to_log=["query", "response"]
        )

        global_step += 1

        # Print a short progress line every 10 steps.
        if global_step % 10 == 0:
            mean_reward = np.mean([r.item() for r in rewards])
            mean_kl = stats.get("objective/kl", 0.0)
            print(f"\nBatch {global_step}: Mean Reward = {mean_reward:.4f}, KL = {mean_kl:.4f}")

        # Periodic checkpoint of the accelerator state.
        if global_step % CHECKPOINT_SAVE_STEPS == 0:
            ckpt_path = CHECKPOINT_DIR / f"step_{global_step}"
            ppo_trainer.accelerator.save_state(str(ckpt_path))
            print(f"üíæ Checkpoint sauvegard√©: {ckpt_path}")

print(f"\n{'='*80}")
print(f"‚úÖ Entra√Ænement PPO termin√©!")
print(f"{'='*80}")


## 8. Save PPO Model

In [None]:
banner = "=" * 80
print(banner)
print("√âTAPE 5: Sauvegarde du mod√®le PPO")
print(banner)

# Make sure the output directory exists before writing anything into it.
os.makedirs(PPO_MODEL_PATH, exist_ok=True)
print(f"\nüíæ Sauvegarde du mod√®le PPO dans: {PPO_MODEL_PATH}")

# Persist the policy and its tokenizer side by side.
# NOTE(review): `save_pretrained` is called on the value-head wrapper;
# whether the value head is left out of the saved weights depends on the
# wrapper implementation — confirm.
for artifact in (policy_model, tokenizer):
    artifact.save_pretrained(PPO_MODEL_PATH)

print(f"‚úÖ Mod√®le PPO sauvegard√©!")
print(f"\nüìÅ Fichiers cr√©√©s:")
for created_file in ("pytorch_model.bin", "config.json", "tokenizer.json"):
    print(f"   - {PPO_MODEL_PATH}/{created_file}")

# End the W&B run now that training artifacts are on disk.
wandb.finish()
print(f"\n‚úÖ W&B run closed")

print(f"\n{banner}")
print(f"‚úÖ PPO TRAINING COMPLETE!")
print(banner)
print(f"\nüéØ Prochaines √©tapes:")
for next_step in (
    "   1. √âvaluer le mod√®le PPO sur le test set",
    "   2. Comparer avec SFT et DPO",
    "   3. G√©n√©rer la courbe reward-KL (Figure 2)",
):
    print(next_step)

## 9. Test PPO Model (Optional)

In [None]:
print("="*80)
print("Test du mod√®le PPO sur quelques exemples")
print("="*80)

# Reload the saved PPO policy as a plain causal LM for a qualitative check.
test_model = AutoModelForCausalLM.from_pretrained(PPO_MODEL_PATH)
test_tokenizer = AutoTokenizer.from_pretrained(PPO_MODEL_PATH)

if torch.cuda.is_available():
    test_model = test_model.to("cuda")

test_model.eval()

# Draw up to 5 prompts with a fixed seed. The sample size is capped at the
# number of available prompts, since `random.sample` raises ValueError when
# asked for more items than the population holds.
import random
random.seed(42)
test_prompts = random.sample(prompts, min(5, len(prompts)))

print(f"\nG√©n√©ration sur {len(test_prompts)} prompts al√©atoires:\n")

for i, prompt in enumerate(test_prompts, 1):
    print(f"{'‚îÄ'*80}")
    print(f"Prompt {i}: {prompt}")

    # Tokenize the prompt and move it next to the model.
    inputs = test_tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")

    # Sample a continuation.
    with torch.no_grad():
        outputs = test_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=test_tokenizer.eos_token_id
        )

    # Full text (prompt + continuation) is what the reward function scores.
    generated_text = test_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Isolate the continuation by token position. Slicing the decoded string
    # by `len(prompt)` is wrong whenever decoding does not reproduce the
    # prompt character for character.
    prompt_len = inputs["input_ids"].shape[1]
    continuation = test_tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    print(f"Generated: {continuation}")

    # Score the full generated text.
    reward = compute_rewards([generated_text])[0]
    print(f"Reward: {reward:.4f}")
    print()

print(f"{'='*80}")
print(f"‚úÖ Test termin√©!")