## 1. Setup et Imports

In [1]:
import os
import sys
import torch
from torch.utils.data import DataLoader

def find_project_root(start, marker="src"):
    """Return the nearest directory at or above ``start`` that contains ``marker``.

    Args:
        start: Directory where the upward search begins.
        marker: Name of the sub-directory that identifies the project root.

    Returns:
        Absolute path of the first matching directory, or ``None`` when the
        filesystem root is reached without finding ``marker``.
    """
    current = os.path.abspath(start)
    while True:
        if os.path.isdir(os.path.join(current, marker)):
            return current
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root
            return None
        current = parent


# Add the project root to sys.path so the `src.*` imports below resolve.
# The recorded run of this cell failed with "No module named 'src'" when ROOT
# was hard-coded to '..', so look upward from the working directory for the
# folder that actually contains `src/`; keep '..' only as a last resort.
ROOT = find_project_root(os.getcwd()) or os.path.abspath('..')
if ROOT not in sys.path:
    sys.path.append(ROOT)

from src.dpo.models import load_models
from src.dpo.data import PromptDataset, prompt_collate_fn
from src.ppo.ppo_trainer import PPOTrainer
from src.dpo.utils import load_yaml_config

# Confirm the imports worked and report which backends PyTorch detects.
print("‚úÖ Imports r√©ussis")
for label, value in (
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
    ("MPS available", torch.backends.mps.is_available()),
):
    print(f"{label}: {value}")

ModuleNotFoundError: No module named 'src'

In [None]:

# Use the CUDA GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device s√©lectionn√©: {device}")

## 3. Vérification et Préparation des Données

In [None]:
# V√©rifier si les donn√©es de prompts existent
prompts_path = "data/processed/sentiment/prompts.jsonl"

if not os.path.exists(prompts_path):
    print("‚ö†Ô∏è  Fichier prompts.jsonl non trouv√©!")
    print("üì• Ex√©cution de prepare_prompts.py...")
    !python scripts/prepare_prompts.py
    print("‚úÖ Prompts pr√©par√©s")
else:
    # Compter le nombre de prompts
    with open(prompts_path, 'r') as f:
        num_prompts = sum(1 for line in f if line.strip())
    print(f"‚úÖ {num_prompts} prompts trouv√©s dans {prompts_path}")

## 4. Configuration des Hyperparamètres

Vous pouvez modifier ces paramètres selon vos besoins:

In [None]:
# Load the default PPO configuration from its YAML file.
config_path = "configs/ppo_sentiment.yaml"
config = load_yaml_config(config_path)

# Print a summary of the main settings, one group per config section.
# Each pair is (label to print, key to look up inside that section).
print("üìä Configuration PPO:")
print(f"  Model: {config['model']['name']}")
for label, key in (
    ("Batch size", "batch_size"),
    ("Learning rate", "learning_rate"),
    ("Epochs", "num_epochs"),
):
    print(f"  {label}: {config['training'][key]}")

print("\nüéØ Param√®tres PPO:")
for label, key in (
    ("Clip epsilon", "clip_epsilon"),
    ("Value coef", "value_coef"),
    ("Entropy coef", "entropy_coef"),
    ("Target KL", "target_kl"),
    ("PPO epochs", "num_ppo_epochs"),
):
    print(f"  {label}: {config['ppo'][key]}")

print("\nüéÆ G√©n√©ration:")
for label, key in (
    ("Max length", "max_length"),
    ("Temperature", "temperature"),
):
    print(f"  {label}: {config['generation'][key]}")

### Modifier les paramètres (optionnel)

Décommentez et modifiez si vous voulez changer certains paramètres:

In [None]:
# Example overrides: uncomment any line below to change the loaded `config`
# in place before the models and trainer are created.
# config['training']['batch_size'] = 1  # reduce if you run out of memory
# config['training']['num_epochs'] = 2  # more epochs
# config['ppo']['num_ppo_epochs'] = 2   # fewer PPO epochs per batch
# config['generation']['max_length'] = 64  # shorter responses

# This message is printed whether or not anything above was uncommented.
print("‚öôÔ∏è  Configuration personnalis√©e (si modifi√©e)")

## 5. Chargement des Modèles

In [None]:
# Announce the load up front; the message warns it can take a few minutes.
print(
    "üì¶ Chargement des mod√®les...\n"
    "   Cela peut prendre quelques minutes..."
)

# Model name and dtype both come from the "model" section of the config.
model_cfg = config["model"]
model_name = model_cfg["name"]
dtype = model_cfg["dtype"]

# load_models returns a bundle; this notebook reads its .tokenizer,
# .policy_model and .device attributes.
# (presumably it also holds the reference model — verify in src.dpo.models)
mb = load_models(model_name, dtype=dtype, device=device)
tokenizer = mb.tokenizer

print(f"‚úÖ Mod√®les charg√©s: {model_name}")
print(f"   Policy model: {mb.policy_model.num_parameters():,} param√®tres")
print(f"   Device: {mb.device}")

## 6. Préparation du DataLoader

In [None]:
# Build the prompt dataset from the path given in the config.
prompt_dataset = PromptDataset(config["data"]["prompt_path"])
max_prompt_length = config["data"]["max_prompt_length"]

print(f"üìö Dataset charg√©: {len(prompt_dataset)} prompts")


def collate(batch):
    """Forward ``batch`` to ``prompt_collate_fn`` with this notebook's settings.

    The tokenizer and maximum prompt length are taken from the notebook
    globals ``tokenizer`` and ``max_prompt_length`` at call time.
    """
    collate_kwargs = {
        "tokenizer": tokenizer,
        "max_prompt_length": max_prompt_length,
    }
    return prompt_collate_fn(batch, **collate_kwargs)


# Shuffled loader over the prompts, batched according to the training config.
loader_options = {
    "batch_size": config["training"]["batch_size"],
    "shuffle": True,
    "collate_fn": collate,
}
prompt_loader = DataLoader(prompt_dataset, **loader_options)

print(f"‚úÖ DataLoader cr√©√©: {len(prompt_loader)} batches")

## 7. Initialisation du PPO Trainer

In [None]:
print("üèóÔ∏è  Initialisation du PPO Trainer...")

# Hand the model bundle, the prompt loader and the full config to the trainer.
trainer_kwargs = {
    "model_bundle": mb,
    "prompt_loader": prompt_loader,
    "config": config,
}
trainer = PPOTrainer(**trainer_kwargs)

print("‚úÖ Trainer initialis√©")
# Each entry is (label to print, config section, key inside that section).
for label, section, key in (
    ("Reward model", "reward_model", "name"),
    ("Save dir", "logging", "save_dir"),
):
    print(f"   {label}: {config[section][key]}")

## 8. Entraînement PPO 🚀

**⚠️ Attention**: L'entraînement peut prendre du temps, surtout sur CPU!

**Temps estim√©**:
- Sur GPU: ~30 min - 1h
- Sur CPU: 2-4h (voire plus selon le hardware)

In [None]:
import time
import traceback

banner = "=" * 60

print(banner)
print("üöÄ D√©marrage de l'entra√Ænement PPO")
print(f"   Exp√©rience: {config['experiment_name']}")
print(f"   Device: {device}")
print(banner)

# Use a monotonic clock for the elapsed time: a run can last hours, and the
# wall clock (time.time) may be adjusted while it is running.
start_time = time.perf_counter()

try:
    # Run the full PPO training loop.
    trainer.train()

except KeyboardInterrupt:
    # Let the user stop a long run without a traceback.
    # NOTE(review): the second message says partial checkpoints were saved,
    # but nothing in this cell checks that — confirm PPOTrainer does so.
    print("\n‚ö†Ô∏è  Entra√Ænement interrompu par l'utilisateur")
    print("   Les checkpoints partiels ont √©t√© sauvegard√©s")

except Exception as e:
    # Report the failure with its traceback; the error is not re-raised, so
    # later cells still run.
    # NOTE(review): consider re-raising so "Run All" stops on a failed run.
    print(f"\n‚ùå Erreur pendant l'entra√Ænement: {e}")
    traceback.print_exc()

else:
    # Training finished normally: report the total duration in minutes.
    elapsed = time.perf_counter() - start_time
    print("\n" + banner)
    print("‚úÖ Entra√Ænement termin√©!")
    print(f"   Temps total: {elapsed/60:.2f} minutes")
    print(banner)

## 9. Vérification des Checkpoints

In [None]:
import glob

# Collect the *.pt files that sit directly inside the configured save directory.
save_dir = config['logging']['save_dir']
checkpoints = glob.glob(os.path.join(save_dir, "*.pt"))

if not checkpoints:
    print("‚ö†Ô∏è  Aucun checkpoint trouv√©")
else:
    print(f"‚úÖ {len(checkpoints)} checkpoint(s) sauvegard√©(s):")
    bytes_per_mb = 1024**2
    # List the files in sorted (lexicographic) path order with their size in MB.
    for ckpt_path in sorted(checkpoints):
        size_mb = os.path.getsize(ckpt_path) / bytes_per_mb
        print(f"   üìÅ {os.path.basename(ckpt_path)} ({size_mb:.1f} MB)")

## 10. Test de Génération Rapide

In [None]:
# Quick smoke test: sample a short continuation for a few fixed prompts
# using the policy model held in the bundle.
test_prompts = [
    "The movie was",
    "I think this product is",
    "The customer service was",
]

separator = "=" * 60

print("üß™ Test de g√©n√©ration avec le mod√®le PPO:")
print(separator)

policy = mb.policy_model
policy.eval()  # switch the model to evaluation mode before generating

# Sampling is enabled (do_sample=True) and no seed is set in this cell, so
# the generated text can differ from one run to the next.
with torch.no_grad():
    for prompt in test_prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = policy.generate(
            **inputs,
            max_new_tokens=32,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
        # Decode the first returned sequence.
        # (presumably this includes the prompt tokens as well — TODO confirm)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"\nüí¨ Prompt: {prompt}\n‚ú® Response: {response}")

print("\n" + separator)

## 🎯 Prochaines Étapes

Maintenant que l'entraînement est terminé, vous pouvez:

1. **Évaluer le modèle** avec `eval_sentiment2.py`:
   ```bash
   python scripts/eval_sentiment2.py --method ppo --num_samples 3
   ```

2. **Comparer avec DPO et GRPO**:
   ```bash
   python scripts/compare_methods.py
   ```

3. **Visualiser les résultats**:
   ```bash
   python scripts/visualize_comparison.py
   ```

4. **Modifier les hyperparamètres** et réentraîner pour comparer les performances

## 📊 Notes

- Les checkpoints sont sauvegardés dans `checkpoints/sentiment_ppo/`
- Le modèle PPO inclut un **value head** en plus de la policy
- L'entraînement utilise un **reward model** basé sur le sentiment
- Le **target KL** permet l'early stopping pour la stabilité
- Sur CPU, l'entraînement est plus lent mais plus stable que sur MPS (macOS)