## Setup: Install Dependencies

In [None]:
# Installation pour Colab (Python 3.12+)
!pip install --upgrade pip setuptools wheel -q
!pip install transformers[torch] datasets trl wandb accelerate -q

# V√©rification
import torch
import transformers
import trl
print("‚úÖ Installation r√©ussie!")
print(f"PyTorch: {torch.__version__}")
print(f"Transformers: {transformers.__version__}")
print(f"TRL: {trl.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 0. Mount Google Drive (Optional - for Colab/Kaggle)

In [None]:
# Mount Google Drive pour √©conomiser temps et quota
import os

try:
    # google.colab is only importable inside a Colab runtime; anywhere else the
    # import raises ImportError and the local fallback below is used.
    from google.colab import drive

    drive.mount('/content/drive')
    SAVE_BASE_PATH = '/content/drive/MyDrive/dpo_ppo_training'
    os.makedirs(SAVE_BASE_PATH, exist_ok=True)
    USE_DRIVE = True
    print(f"‚úÖ Google Drive mont√©. Mod√®les sauvegard√©s sur: {SAVE_BASE_PATH}")
except ImportError:
    # Not running on Colab: keep every artifact under a local directory.
    USE_DRIVE, SAVE_BASE_PATH = False, './results'
    print(f"‚ö†Ô∏è  Pas de Google Drive d√©tect√©. Stockage local: {SAVE_BASE_PATH}")

## 1. Configuration GRPO

**Paramètres clés GRPO :**
- `num_sample_generations` : Nombre de générations par prompt (4-8 typique)
- `USE_REWARD_MODEL` : False = classifier direct, True = reward model entraîné
- `batch_size` : Nombre de prompts par batch
- `learning_rate` : Taux d'apprentissage

In [None]:
import torch
import numpy as np
import wandb
import json
from datetime import datetime
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    pipeline,
)
from trl import GRPOConfig, GRPOTrainer

# -------------------------------------
# MAIN CONFIGURATION
# -------------------------------------

# GRPO hyperparameters.
NUM_SAMPLE_GENERATIONS = 4  # completions sampled per prompt (4-8 recommended)
USE_REWARD_MODEL = False  # False: sentiment classifier directly, True: trained reward model

BATCH_SIZE = 32  # prompts per batch (tune for a T4 GPU)
MINI_BATCH_SIZE = 8  # mini-batch for GRPO optimization
LEARNING_RATE = 2e-5
NUM_EPOCHS = 1
MAX_NEW_TOKENS = 24  # tokens generated per completion

# Model locations, all rooted at SAVE_BASE_PATH.
SFT_MODEL_PATH = "{}/sft_model".format(SAVE_BASE_PATH)
REWARD_MODEL_PATH = "{}/reward_model".format(SAVE_BASE_PATH)  # only read when USE_REWARD_MODEL=True
GRPO_MODEL_PATH = "{}/grpo_model".format(SAVE_BASE_PATH)

# Preference-pair file the prompts are extracted from.
PREFERENCE_PAIRS_PATH = "{}/datasets/preference_pairs.json".format(SAVE_BASE_PATH)

# Echo the effective settings so they are recorded in the notebook output.
print("=" * 80)
print("GRPO Configuration")
print("=" * 80)
for _label, _value in (
    ("Num Sample Generations", NUM_SAMPLE_GENERATIONS),
    ("Use Reward Model", USE_REWARD_MODEL),
    ("Batch Size", BATCH_SIZE),
    ("Mini Batch Size", MINI_BATCH_SIZE),
    ("Learning Rate", LEARNING_RATE),
    ("Epochs", NUM_EPOCHS),
    ("Max New Tokens", MAX_NEW_TOKENS),
    ("SFT Model", SFT_MODEL_PATH),
    ("GRPO Model Output", GRPO_MODEL_PATH),
):
    print(f"{_label}: {_value}")
print("=" * 80 + "\n")

## 2. Load Dataset (Prompts Only)

In [None]:
print("="*80)
print("√âTAPE 1: Chargement du dataset de prompts")
print("="*80)

# The prompts come from the preference-pair file, which must already exist;
# fail early with an actionable message when it does not.
if not os.path.exists(PREFERENCE_PAIRS_PATH):
    raise FileNotFoundError(
        f"‚ùå Dataset introuvable: {PREFERENCE_PAIRS_PATH}\n"
        "Veuillez d'abord g√©n√©rer les paires de pr√©f√©rences (SFT notebook, cell 15)."
    )

print(f"\nüì• Chargement des prompts depuis: {PREFERENCE_PAIRS_PATH}")
with open(PREFERENCE_PAIRS_PATH, 'r', encoding='utf-8') as f:
    preference_pairs = json.load(f)

# De-duplicate while keeping first-seen order. `list(set(...))` would order the
# prompts by randomized string hashes, so the dataset order (and every seeded
# sample drawn from it later) would change from one kernel session to the next.
unique_prompts = list(dict.fromkeys(pair["prompt"] for pair in preference_pairs))
if not unique_prompts:
    # An empty prompt list cannot be trained on; stop here rather than deep
    # inside the trainer.
    raise ValueError(f"Aucun prompt dans {PREFERENCE_PAIRS_PATH}")
print(f"‚úÖ {len(unique_prompts)} prompts uniques extraits")

# Show a few (truncated) prompts as a sanity check.
print(f"\nExemples de prompts:")
for i in range(min(3, len(unique_prompts))):
    print(f"  {i+1}. {unique_prompts[i][:80]}...")

# Wrap the prompts in a single-column Hugging Face dataset.
grpo_dataset = Dataset.from_dict({
    "query": unique_prompts,
})

print(f"\n‚úÖ Dataset GRPO cr√©√© avec {len(grpo_dataset)} prompts")
print(f"   Colonnes: {grpo_dataset.column_names}")

## 3. Load Models (Policy, Reference, Reward)

In [None]:
print("="*80)
print("√âTAPE 2: Chargement des mod√®les")
print("="*80)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Frozen, inference-only models can run in half precision on GPU.
inference_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# ===== 1. Load Policy Model (SFT - will be trained) =====
print(f"\nüì• Chargement du policy model (SFT)...")
if not os.path.exists(SFT_MODEL_PATH):
    raise FileNotFoundError(
        f"‚ùå Mod√®le SFT introuvable: {SFT_MODEL_PATH}\n"
        "Veuillez d'abord entra√Æner le mod√®le SFT."
    )

# The trainable policy keeps fp32 master weights. The training config enables
# fp16 mixed precision on GPU, and the AMP gradient scaler cannot unscale
# gradients of parameters that are themselves stored in fp16.
# NOTE(review): fp32 weights use more GPU memory; BATCH_SIZE may need lowering.
policy_model = AutoModelForCausalLM.from_pretrained(
    SFT_MODEL_PATH,
    torch_dtype=torch.float32
)
policy_model.to(device)
tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL_PATH)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS so batched generation can pad.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print(f"‚úÖ Policy model charg√© ({policy_model.num_parameters() / 1e9:.2f}B params)")

# ===== 2. Load Reference Model (Frozen SFT) =====
print(f"\nüì• Chargement du reference model (SFT frozen)...")
ref_model = AutoModelForCausalLM.from_pretrained(
    SFT_MODEL_PATH,
    torch_dtype=inference_dtype
)
ref_model.to(device)
for param in ref_model.parameters():
    param.requires_grad = False
ref_model.eval()
print(f"‚úÖ Reference model charg√© (frozen)")

# ===== 3. Load Reward Model or Classifier =====
print(f"\nüì• Chargement du reward model...")

if USE_REWARD_MODEL:
    # Option A: trained reward model. If it is missing on disk, fall back to
    # the public sentiment classifier by flipping the flag.
    if not os.path.exists(REWARD_MODEL_PATH):
        print(f"‚ö†Ô∏è  USE_REWARD_MODEL=True mais mod√®le introuvable: {REWARD_MODEL_PATH}")
        print(f"‚ö†Ô∏è  Basculement vers le classifier direct (siebert)")
        USE_REWARD_MODEL = False
    else:
        reward_model = AutoModelForSequenceClassification.from_pretrained(
            REWARD_MODEL_PATH,
            torch_dtype=inference_dtype
        )
        reward_model.to(device)
        reward_model.eval()
        reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_PATH)
        print(f"‚úÖ Reward model charg√© depuis: {REWARD_MODEL_PATH}")

if not USE_REWARD_MODEL:
    # Option B: direct sentiment classifier (siebert).
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        "siebert/sentiment-roberta-large-english",
        torch_dtype=inference_dtype
    )
    reward_model.to(device)
    reward_model.eval()
    reward_tokenizer = AutoTokenizer.from_pretrained("siebert/sentiment-roberta-large-english")
    print(f"‚úÖ Classifier siebert charg√© (labels: NEGATIVE=0, POSITIVE=1)")

# The reward model is only ever used for scoring, never updated.
for param in reward_model.parameters():
    param.requires_grad = False

print(f"\n{'='*80}")
print(f"‚úÖ Tous les mod√®les charg√©s avec succ√®s")
print(f"{'='*80}\n")

## 4. Initialize W&B Logging

In [None]:
# Authenticate with Weights & Biases (prompts for an API key when needed).
wandb.login()

# Name the run after the generation count plus a timestamp so repeated runs
# do not collide.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = "grpo_imdb_gen{}_{}".format(NUM_SAMPLE_GENERATIONS, timestamp)

# Start the run and record the hyperparameters alongside it.
wandb.init(
    project="grpo",
    name=run_name,
    config=dict(
        model="gpt2-large",
        dataset="imdb",
        num_prompts=len(grpo_dataset),
        num_sample_generations=NUM_SAMPLE_GENERATIONS,
        use_reward_model=USE_REWARD_MODEL,
        batch_size=BATCH_SIZE,
        mini_batch_size=MINI_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        num_epochs=NUM_EPOCHS,
        max_new_tokens=MAX_NEW_TOKENS,
        device=str(device),
    ),
)

print(
    f"‚úÖ W&B initialized: {run_name}",
    f"   Project: grpo",
    f"   Generations per prompt: {NUM_SAMPLE_GENERATIONS}",
    sep="\n",
)

## 5. Configure GRPO Trainer

In [None]:
print("="*80)
print("√âTAPE 3: Configuration du GRPO Trainer")
print("="*80)

# GRPO configuration.
# GRPOConfig names its fields differently from the PPO-style configs:
#   - completions per prompt  -> `num_generations`
#   - completion length limit -> `max_completion_length`
# and sampling is always on, so there is no `do_sample` field. Passing unknown
# keyword arguments makes the constructor raise TypeError.
# TRL also requires the effective batch size to be divisible by
# `num_generations`; it validates this itself.
grpo_config = GRPOConfig(
    output_dir=GRPO_MODEL_PATH,

    # GRPO-specific
    num_generations=NUM_SAMPLE_GENERATIONS,

    # Training
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,

    # Optimization
    max_grad_norm=1.0,
    weight_decay=0.05,

    # Generation
    max_completion_length=MAX_NEW_TOKENS,
    temperature=1.0,
    top_p=1.0,

    # Logging
    logging_steps=10,
    logging_first_step=True,
    report_to=["wandb"],

    # Saving
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,

    # Hardware optimization
    fp16=torch.cuda.is_available(),
    bf16=False,
    gradient_checkpointing=True,

    # Other
    remove_unused_columns=False,
    seed=42,
)

print(f"‚úÖ GRPO Config cr√©√©")
print(f"\nüìä Configuration:")
print(f"   - Generations per prompt: {NUM_SAMPLE_GENERATIONS}")
print(f"   - Batch size: {BATCH_SIZE}")
print(f"   - Learning rate: {LEARNING_RATE}")
print(f"   - Max new tokens: {MAX_NEW_TOKENS}")
print(f"   - Total prompts: {len(grpo_dataset)}")
print(f"   - FP16: {grpo_config.fp16}")
print(f"   - Gradient checkpointing: {grpo_config.gradient_checkpointing}")

## 6. Reward Function

In [None]:
# Section banner for the reward-function step.
print("=" * 80, "√âTAPE 4: D√©finition de la fonction de reward", "=" * 80, sep="\n")

def compute_rewards(texts, batch_size=16):
    """
    Compute sentiment rewards for generated texts.

    Texts are scored in batches with the module-level `reward_model` /
    `reward_tokenizer`. Truncation is done by the tokenizer (512 tokens) only:
    slicing the raw string to 512 *characters* first would drop most of the
    usable context and, for long prompts, the generated completion itself.

    Args:
        texts: Iterable of generated texts (prompt + response).
        batch_size: Maximum number of texts per forward pass (must be >= 1).

    Returns:
        rewards: List of float reward scores, one per input text, in order
            (higher = more positive sentiment). With the classifier this is
            the POSITIVE-class probability; with the trained reward model it
            is the first logit.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    texts = list(texts)
    if not texts:
        return []

    # Some sequence-classification heads cannot locate the last real token in
    # a padded batch unless the model config defines a pad token id; score one
    # text at a time in that case, as the per-text loop used to do.
    if getattr(reward_model.config, "pad_token_id", None) is None:
        batch_size = 1

    rewards = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]

        inputs = reward_tokenizer(
            batch,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(device)

        with torch.no_grad():
            # Cast to fp32 so a half-precision softmax cannot underflow to 0.
            logits = reward_model(**inputs).logits.float()

        if USE_REWARD_MODEL:
            # Trained reward model: first logit is the scalar score.
            scores = logits[:, 0]
        else:
            # Classifier: probability of the POSITIVE class (index 1), in 0-1.
            scores = torch.softmax(logits, dim=-1)[:, 1]

        rewards.extend(scores.tolist())

    return rewards

# Smoke-test the reward function on one clearly positive and one clearly
# negative sentence.
print("\nüß™ Test de la fonction de reward:\n")
test_texts = [
    "This movie is absolutely amazing and wonderful!",
    "This movie is terrible and boring.",
]

test_rewards = compute_rewards(test_texts)
for text, reward in zip(test_texts, test_rewards):
    print(f"Text: {text}\nReward: {reward:.4f}\n")

# Report which scoring backend is active.
if USE_REWARD_MODEL:
    reward_mode_label = 'Reward Model'
else:
    reward_mode_label = 'Classifier (siebert)'
print(f"‚úÖ Fonction de reward op√©rationnelle")
print(f"   Mode: {reward_mode_label}")

## 7. Initialize GRPO Trainer

In [None]:
print("="*80)
print("√âTAPE 5: Initialisation du GRPO Trainer")
print("="*80)


def grpo_reward_func(prompts, completions, **kwargs):
    """Adapt `compute_rewards` to the reward-function signature GRPOTrainer calls.

    GRPOTrainer invokes reward functions with `prompts` and `completions`
    (plus extra keyword arguments) and expects one float per completion.

    Args:
        prompts: Prompt strings, one per completion.
        completions: Generated continuation strings.
        **kwargs: Extra fields forwarded by the trainer; ignored here.

    Returns:
        List of float rewards computed on prompt + completion.
    """
    return compute_rewards([p + c for p, c in zip(prompts, completions)])


# GRPOTrainer reads prompts from a column named "prompt"; the dataset built
# earlier stores them under "query", so expose them under the expected name.
if "prompt" in grpo_dataset.column_names:
    grpo_train_dataset = grpo_dataset
else:
    grpo_train_dataset = grpo_dataset.rename_column("query", "prompt")

# Initialize GRPO Trainer.
# GRPOTrainer has no `ref_model`, `tokenizer` or `reward_function` arguments:
# the tokenizer is passed as `processing_class`, rewards as `reward_funcs`, and
# the trainer creates its own frozen reference copy of the initial policy when
# the KL coefficient (`beta` in GRPOConfig) is non-zero.
grpo_trainer = GRPOTrainer(
    model=policy_model,
    reward_funcs=grpo_reward_func,
    args=grpo_config,
    train_dataset=grpo_train_dataset,
    processing_class=tokenizer,
)

print(f"\n‚úÖ GRPO Trainer initialis√©")
print(f"\nüìä R√©sum√©:")
print(f"   - Policy model: GPT-2-Large (SFT)")
print(f"   - Reference model: GPT-2-Large (SFT frozen)")
print(f"   - Reward: {'Trained model' if USE_REWARD_MODEL else 'Classifier direct'}")
print(f"   - Prompts: {len(grpo_dataset)}")
print(f"   - Generations/prompt: {NUM_SAMPLE_GENERATIONS}")
print(f"   - Total generations: {len(grpo_dataset) * NUM_SAMPLE_GENERATIONS}")
print(f"\nüöÄ Pr√™t pour l'entra√Ænement GRPO!")

## 8. GRPO Training

In [None]:
print("=" * 80, "√âTAPE 6: Entra√Ænement GRPO", "=" * 80, sep="\n")

# Summarize the run that is about to start.
print(
    f"\nüöÄ D√©marrage de l'entra√Ænement GRPO...",
    f"   - {len(grpo_dataset)} prompts",
    f"   - {NUM_SAMPLE_GENERATIONS} g√©n√©rations par prompt",
    f"   - {NUM_EPOCHS} epoch(s)\n",
    sep="\n",
)

# Release cached GPU memory left over from the earlier cells before training.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("‚úÖ GPU cache cleared\n")

# Run the GRPO optimization loop.
grpo_trainer.train()

print("\n" + "=" * 80)
print(f"‚úÖ Entra√Ænement GRPO termin√©!")
print("=" * 80)

## 9. Save GRPO Model

In [None]:
print("="*80)
print("√âTAPE 7: Sauvegarde du mod√®le GRPO")
print("="*80)

# Save the trained policy and its tokenizer side by side.
print(f"\nüíæ Sauvegarde du mod√®le GRPO dans: {GRPO_MODEL_PATH}")
policy_model.save_pretrained(GRPO_MODEL_PATH)
tokenizer.save_pretrained(GRPO_MODEL_PATH)

print(f"‚úÖ Mod√®le GRPO sauvegard√©!")
print(f"\nüìÅ Fichiers cr√©√©s:")
# List what is actually on disk instead of a hard-coded set of names: the
# weight file name depends on the serialization format used by save_pretrained.
# Only regular files are shown, so checkpoint sub-directories are skipped.
for saved_name in sorted(os.listdir(GRPO_MODEL_PATH)):
    if os.path.isfile(os.path.join(GRPO_MODEL_PATH, saved_name)):
        print(f"   - {GRPO_MODEL_PATH}/{saved_name}")

# Close W&B run
wandb.finish()
print(f"\n‚úÖ W&B run closed")

print(f"\n{'='*80}")
print(f"‚úÖ GRPO TRAINING COMPLETE!")
print(f"{'='*80}")
print(f"\nüéØ Prochaines √©tapes:")
print(f"   1. √âvaluer le mod√®le GRPO sur le test set")
print(f"   2. Comparer avec SFT, DPO, PPO")
print(f"   3. Tester d'autres valeurs de num_sample_generations (2, 8, 16)")
print(f"   4. G√©n√©rer la courbe reward-KL (Figure 2)")

## 10. Test GRPO Model (Optional)

In [None]:
print("="*80)
print("Test du mod√®le GRPO sur quelques exemples")
print("="*80)

# Reload the saved GRPO model from disk so the test exercises the checkpoint
# that was actually written.
test_model = AutoModelForCausalLM.from_pretrained(GRPO_MODEL_PATH)
test_tokenizer = AutoTokenizer.from_pretrained(GRPO_MODEL_PATH)

if torch.cuda.is_available():
    test_model = test_model.to("cuda")

test_model.eval()

# Draw up to 5 random prompts. `random.sample` raises ValueError when asked
# for more items than exist, so cap the sample size at the dataset size.
import random
random.seed(42)
num_test_prompts = min(5, len(grpo_dataset))
test_indices = random.sample(range(len(grpo_dataset)), num_test_prompts)
test_samples = grpo_dataset.select(test_indices)

print(f"\nG√©n√©ration sur {num_test_prompts} prompts al√©atoires du dataset:\n")

for i, sample in enumerate(test_samples, 1):
    prompt = sample["query"]

    print(f"{'‚îÄ'*80}")
    print(f"Prompt {i}: {prompt}")

    # Tokenize
    inputs = test_tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")

    # Generate
    with torch.no_grad():
        outputs = test_model.generate(
            **inputs,
            max_new_tokens=40,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=test_tokenizer.eos_token_id
        )

    # Decode. The continuation is taken at the token level: the decoded text
    # is not guaranteed to reproduce the prompt character for character, so
    # slicing the string by len(prompt) can cut at the wrong place.
    generated_text = test_tokenizer.decode(outputs[0], skip_special_tokens=True)
    prompt_token_count = inputs["input_ids"].shape[1]
    continuation = test_tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    ).strip()

    print(f"\nGRPO Generated: {continuation}")

    # Score the full text (prompt + continuation), as during training.
    reward = compute_rewards([generated_text])[0]
    print(f"Reward: {reward:.4f}")
    print()

print(f"{'='*80}")
print(f"‚úÖ Test termin√©!")
print(f"\nüí° Le mod√®le GRPO devrait g√©n√©rer du texte avec un sentiment positif √©lev√©")

## 11. Compare SFT vs GRPO (Side-by-side)

In [None]:
print("="*80)
print("Comparaison SFT vs GRPO sur les m√™mes prompts")
print("="*80)

# Load SFT model
print("\nüì• Chargement du mod√®le SFT...")
sft_model = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH)
sft_tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL_PATH)

if torch.cuda.is_available():
    sft_model = sft_model.to("cuda")

sft_model.eval()
print("‚úÖ Mod√®le SFT charg√©")

# The GRPO model is normally loaded by the optional test cell. Load it here
# when that cell was skipped so this cell does not depend on hidden state.
if "test_model" not in globals() or "test_tokenizer" not in globals():
    test_model = AutoModelForCausalLM.from_pretrained(GRPO_MODEL_PATH)
    test_tokenizer = AutoTokenizer.from_pretrained(GRPO_MODEL_PATH)
    if torch.cuda.is_available():
        test_model = test_model.to("cuda")
    test_model.eval()


def generate_and_score(model, model_tokenizer, prompt, max_new_tokens=30):
    """Sample one continuation of `prompt` and score it.

    Args:
        model: Causal LM used for generation.
        model_tokenizer: Tokenizer matching `model`.
        prompt: Text to continue.
        max_new_tokens: Maximum number of generated tokens.

    Returns:
        (continuation, reward): the generated text without the prompt, and the
        reward that `compute_rewards` assigns to prompt + continuation.
    """
    inputs = model_tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=model_tokenizer.eos_token_id
        )

    full_text = model_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Split at the token level: the decoded text may not start with the exact
    # prompt string, so slicing by len(prompt) is unreliable.
    prompt_token_count = inputs["input_ids"].shape[1]
    continuation = model_tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    ).strip()
    return continuation, compute_rewards([full_text])[0]


# Test on 3 prompts
test_prompts = [
    "This movie is",
    "I really enjoyed",
    "The acting was"
]

print(f"\nG√©n√©ration comparative sur {len(test_prompts)} prompts:\n")

for i, prompt in enumerate(test_prompts, 1):
    print(f"{'='*80}")
    print(f"Prompt {i}: {prompt}")
    print(f"{'='*80}")

    # Same sampling settings for both models; SFT is generated first.
    sft_continuation, sft_reward = generate_and_score(sft_model, sft_tokenizer, prompt)
    grpo_continuation, grpo_reward = generate_and_score(test_model, test_tokenizer, prompt)

    # Relative change is undefined for a zero baseline, and dividing by the
    # magnitude keeps the sign meaningful when the baseline is negative.
    reward_delta = grpo_reward - sft_reward
    if sft_reward != 0:
        relative_change = f"{reward_delta / abs(sft_reward) * 100:+.1f}%"
    else:
        relative_change = "n/a"

    # Display comparison
    print(f"\nüìù SFT:   {sft_continuation}")
    print(f"   Reward: {sft_reward:.4f}")
    print(f"\nüéØ GRPO:  {grpo_continuation}")
    print(f"   Reward: {grpo_reward:.4f}")
    print(f"\nüìä Am√©lioration: {reward_delta:.4f} ({relative_change})")
    print()

print(f"{'='*80}")
print(f"‚úÖ Comparaison termin√©e!")
print(f"\nüí° GRPO devrait g√©n√©rer du texte avec un reward plus √©lev√© que SFT")