## 1. Setup et Imports

In [None]:
# Switch to the project directory first if needed, by uncommenting:
# %cd Direct-Preference-Optimization/

import os

# Report the current working directory of the kernel.
current_dir = os.getcwd()
print(f"üìÇ R√©pertoire actuel: {current_dir}")

In [None]:
%%capture
# Installation silencieuse des packages
!pip install torch transformers datasets accelerate bitsandbytes sentencepiece protobuf tqdm pyyaml scikit-learn

In [None]:
import os
import sys
import torch
from torch.utils.data import DataLoader

# Register the current directory (as an absolute path) on sys.path before the
# `src.*` imports below; the entry is only added when it is not already there.
ROOT = os.path.abspath(os.curdir)
if ROOT not in sys.path:
    sys.path.append(ROOT)

from src.core.models import load_models
from src.core.data import PromptDataset, prompt_collate_fn
from src.ppo.ppo_trainer_no_vh import PPOTrainerNoValueHead
from src.core.utils import load_yaml_config

# One status line each: import success, PyTorch version, CUDA availability.
print(
    "‚úÖ Imports r√©ussis",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

## 2. Vérification des Données

In [None]:
prompts_path = "data/processed/sentiment/prompts.jsonl"

if not os.path.exists(prompts_path):
    print("‚ö†Ô∏è  Fichier prompts.jsonl non trouv√©!")
    print("Ex√©cution de prepare_prompts.py...")
    !python scripts/prepare_prompts.py
else:
    with open(prompts_path, 'r') as f:
        num_prompts = sum(1 for line in f if line.strip())
    print(f"‚úÖ {num_prompts} prompts trouv√©s")

## 3. Configuration

In [None]:
# Load the NO VALUE HEAD configuration file.
config_path = "configs/ppo_sentiment_no_vh.yaml"
config = load_yaml_config(config_path)

# General training settings.
print("‚öôÔ∏è  Configuration PPO (NO VALUE HEAD):")
print(f"  Model: {config['model']['name']}")
training_cfg = config['training']
print(f"  Batch size: {training_cfg['batch_size']}")
print(f"  Learning rate: {training_cfg['learning_rate']}")
print(f"  Epochs: {training_cfg['num_epochs']}")

# PPO-specific settings; this summary prints no value coefficient.
print("\nüìä Param√®tres PPO:")
ppo_cfg = config['ppo']
print(f"  Clip epsilon: {ppo_cfg['clip_epsilon']}")
print(f"  Entropy coef: {ppo_cfg['entropy_coef']}")
print("  ‚ùå PAS de value_coef (pas de value head)")
print(f"  Target KL: {ppo_cfg['target_kl']}")
print(f"  PPO epochs: {ppo_cfg['num_ppo_epochs']}")

## 4. Chargement des Modèles

In [None]:
# Use CUDA when PyTorch reports it as available; otherwise use the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"üì± Device: {device}")

In [None]:
# Model identifier and dtype setting come from the "model" config section.
model_name, dtype = config["model"]["name"], config["model"]["dtype"]

# Load the model bundle and keep a direct handle on its tokenizer.
mb = load_models(model_name, dtype=dtype, device=device)
tokenizer = mb.tokenizer

print(f"‚úÖ Mod√®les charg√©s: {model_name}")
policy_param_count = mb.policy_model.num_parameters()
print(f"   Policy model: {policy_param_count:,} param√®tres")
print(f"   Device: {mb.device}")
print("   ‚ùå PAS de value head - √âconomie de param√®tres!")

## 5. Préparation du DataLoader

In [None]:
# Both the prompt file path and the prompt length limit come from the "data"
# section of the configuration.
data_cfg = config["data"]
prompt_dataset = PromptDataset(data_cfg["prompt_path"])
max_prompt_length = data_cfg["max_prompt_length"]

print(f"Dataset: {len(prompt_dataset)} prompts")


def collate(batch):
    """Collate a batch with prompt_collate_fn, using the notebook's tokenizer and max_prompt_length."""
    return prompt_collate_fn(batch, tokenizer=tokenizer, max_prompt_length=max_prompt_length)


# Shuffled loader over the prompt dataset, batched by the configured size.
loader_options = dict(
    batch_size=config["training"]["batch_size"],
    shuffle=True,
    collate_fn=collate,
)
prompt_loader = DataLoader(prompt_dataset, **loader_options)

print(f"‚úÖ DataLoader cr√©√©: {len(prompt_loader)} batches")

## 6. Initialisation du PPO Trainer (No VH)

In [None]:
print("‚öôÔ∏è  Initialisation du PPO Trainer (NO VALUE HEAD)...")

# The trainer receives the model bundle, the prompt loader and the full config.
trainer = PPOTrainerNoValueHead(model_bundle=mb, prompt_loader=prompt_loader, config=config)

# Summary of the settings read back from the configuration.
print("\n‚úÖ Trainer initialis√©")
reward_model_name = config['reward_model']['name']
print(f"   Reward model: {reward_model_name}")
checkpoint_dir = config['logging']['save_dir']
print(f"   Save dir: {checkpoint_dir}")
print("   ‚úì Mode: PPO sans Value Head")

## 7. Entraînement

In [None]:
import time
import traceback

# Banner summarising the run before training starts.
print("="*60)
print(f"üöÄ D√âMARRAGE DE L'ENTRA√éNEMENT PPO (NO VALUE HEAD)")
print(f"   Exp√©rience: {config['experiment_name']}")
print(f"   Device: {device}")
print(f"   Epochs: {config['training']['num_epochs']}")
print(f"   Batch size: {config['training']['batch_size']}")
print(f"   Total batches: {len(prompt_loader)}")
print("="*60)

# perf_counter is monotonic, so the elapsed time is unaffected by wall-clock
# adjustments during a long run.
start_time = time.perf_counter()

# Only the training call is guarded: a failure while printing the summary must
# not be reported as a training error.
try:
    trainer.train()

except KeyboardInterrupt:
    print("\n‚ö†Ô∏è  Entra√Ænement interrompu par l'utilisateur")

except Exception as e:
    # Notebook top-level boundary: report the error with its traceback and let
    # the remaining cells stay runnable.
    print("\n" + "="*60)
    print(f"‚ùå ERREUR: {type(e).__name__}")
    print(f"   Message: {e}")
    print("="*60)
    traceback.print_exc()

else:
    elapsed = time.perf_counter() - start_time
    num_epochs = config['training']['num_epochs']
    print("\n" + "="*60)
    print(f"‚úÖ Entra√Ænement termin√©!")
    print(f"   Temps total: {elapsed/60:.2f} minutes")
    # Skip the per-epoch average when the epoch count is zero, which would
    # otherwise divide by zero.
    if num_epochs:
        print(f"   Temps moyen par epoch: {elapsed/num_epochs/60:.2f} minutes")
    print("="*60)

## 8. Vérification des Checkpoints

In [None]:
import glob

# Collect every *.pt file directly inside the configured save directory.
save_dir = config['logging']['save_dir']
checkpoints = glob.glob(os.path.join(save_dir, "*.pt"))

if not checkpoints:
    print("‚ö†Ô∏è  Aucun checkpoint trouv√©")
else:
    print(f"‚úÖ {len(checkpoints)} checkpoint(s) sauvegard√©(s):")
    # List the files in sorted path order with their size in mebibytes.
    for ckpt_path in sorted(checkpoints):
        ckpt_name = os.path.basename(ckpt_path)
        size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
        print(f"   üìÅ {ckpt_name} ({size_mb:.1f} MB)")

## 9. Test de Génération

In [None]:
# Short prompts used to sample from the policy model.
test_prompts = [
    "The movie was",
    "I think this product is",
    "The customer service was",
]

separator = "=" * 60
print("üß™ Test de g√©n√©ration (PPO No VH):")
print(separator)

# Sampling settings shared by every prompt: up to 32 new tokens, temperature 0.7.
sampling_options = dict(max_new_tokens=32, temperature=0.7, do_sample=True)

policy = mb.policy_model
policy.eval()
with torch.no_grad():
    for prompt in test_prompts:
        # Tokenize one prompt and move the tensors to the selected device.
        encoded = tokenizer(prompt, return_tensors="pt").to(device)
        generated = policy.generate(
            **encoded,
            **sampling_options,
            pad_token_id=tokenizer.pad_token_id,
        )
        # Decode the first returned sequence without special tokens.
        response = tokenizer.decode(generated[0], skip_special_tokens=True)
        print(f"\nüí¨ Prompt: {prompt}")
        print(f"‚ú® Response: {response}")

print("\n" + separator)

## 📊 Comparaison PPO vs PPO No VH

| Aspect | PPO (avec VH) | PPO (sans VH) |
|--------|---------------|---------------|
| **Paramètres** | +1-5M (value head) | Aucun paramètre supplémentaire |
| **VRAM** | ~17-18 GB | ~15-16 GB |
| **Complexité** | Plus complexe (value loss) | Plus simple |
| **Variance** | Moins de variance | Peut avoir plus de variance |
| **Stabilité** | Généralement plus stable | Dépend du reward model |
| **Vitesse** | Légèrement plus rapide | Appels supplémentaires au RM |

## 💡 Quand utiliser quelle version?

**PPO avec Value Head:**
- Problèmes complexes avec trajectoires longues
- Besoin de variance minimale
- Reward model coûteux en calcul

**PPO sans Value Head:**
- Contraintes de VRAM
- Reward model rapide et fiable
- Besoin de simplicité
- Expérimentation rapide