# Démo RL & Alignement : LineWorld

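Ce notebook illustre comment un agent **naïf** optimise sa tâche (récolter un fruit) en sacrifiant un chat, et comment on peut **réaligner** son comportement en apprenant un modèle de récompense à partir de labels qui simulent un jugement humain ou LLM (ici, le label est calculé automatiquement : le chat a-t-il survécu à l'épisode ?).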

---

## 1. Installation des dépendances

```bash
pip install gymnasium stable-baselines3 scikit-learn joblib numpy matplotlib
```

Dans une cellule Jupyter, préfixez la commande par `!` (`!pip install ...`).


In [None]:
# === 1. Imports & Environnement ===
!pip install stable-baselines3 gymnasium scikit-learn joblib --quiet

import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
import joblib

class LineWorld(gym.Env):
    """One-dimensional corridor where the agent must walk to a fruit.

    The agent starts at cell 0, the fruit sits on the last cell and the cat on
    cell ``size // 2``. Moves are +/-1 and clamped to the corridor, so for
    ``size >= 3`` every path to the fruit crosses the cat's cell: no policy can
    both reach the fruit and never step on the cat.

    Observation: ``[agent_position, fruit_position, cat_position]`` (int64).
    Actions: ``0`` = move left, ``1`` = move right.
    Reward: ``1`` on the step that reaches the fruit, ``0`` otherwise.

    Args:
        size: number of cells in the corridor.
        max_steps: number of moves after which the episode is truncated
            (defaults to ``4 * size``). Without a limit, a policy that never
            reaches the fruit would produce an endless episode.
    """

    metadata = {'render_modes': ['human']}

    def __init__(self, size=7, max_steps=None):
        super().__init__()
        self.size = size
        self.max_steps = 4 * size if max_steps is None else max_steps
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(0, size - 1, shape=(3,), dtype=np.int64)
        self.reset()

    def _obs(self):
        """Return the current observation with the dtype declared in observation_space."""
        return np.array([self.pos_agent, self.pos_fruit, self.pos_chat], dtype=np.int64)

    def reset(self, *, seed=None, options=None):
        """Put agent, fruit and cat back on their start cells; return ``(obs, {})``."""
        super().reset(seed=seed)
        self.pos_agent = 0
        self.pos_fruit = self.size - 1
        self.pos_chat = self.size // 2
        self.steps_taken = 0
        self.done = False
        return self._obs(), {}

    def step(self, action):
        """Move the agent one cell and return the Gymnasium 5-tuple.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``terminated`` is
            True when the fruit is reached, ``truncated`` when ``max_steps``
            moves were taken without reaching it. ``info["cat_dead"]`` is True
            when this move ends on the cat's cell.

        Raises:
            RuntimeError: if called after the episode has ended.
        """
        if self.done:
            raise RuntimeError("Episode terminé")
        if action == 0 and self.pos_agent > 0:
            self.pos_agent -= 1
        elif action == 1 and self.pos_agent < self.size - 1:
            self.pos_agent += 1
        self.steps_taken += 1

        terminated = self.pos_agent == self.pos_fruit
        truncated = not terminated and self.steps_taken >= self.max_steps
        reward = 1 if terminated else 0
        self.done = terminated or truncated
        cat_dead = self.pos_agent == self.pos_chat
        return self._obs(), reward, terminated, truncated, {"cat_dead": cat_dead}

    def render(self):
        """Print the corridor; the agent is drawn over the fruit, the fruit over the cat."""
        line = ["·"] * self.size
        line[self.pos_chat] = "😺"
        line[self.pos_fruit] = "🍎"
        line[self.pos_agent] = "🤖"
        print("".join(line))


# === 2. Train the naive agent (reward = fruit only) ===
env = LineWorld(size=7)
model_naive = PPO("MlpPolicy", env, verbose=0)
model_naive.learn(total_timesteps=10_000)
model_naive.save("ppo_naive")


# === 3. Generate trajectories & labels ===
# Each episode is labelled 1 if the cat survives (the agent never steps on its
# cell) and 0 otherwise; this label stands in for a human/LLM judgement.
# The naive policy is rolled out deterministically from a fixed start state, so
# all of its episodes are identical and share one label. A reward model trained
# on them alone would see a single class, so half of the episodes are instead
# random walks with a per-episode probability of moving right; left-leaning
# walks can be truncated before ever reaching the cat.
rng = np.random.default_rng(0)
episodes, labels = [], []

for _ in range(500):
    obs, _ = env.reset()
    use_policy = rng.random() < 0.5
    p_right = rng.random()
    traj, done, cat_dead = [], False, False
    while not done:
        if use_policy:
            action, _ = model_naive.predict(obs, deterministic=True)
        else:
            action = int(rng.random() < p_right)
        obs, _, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        traj.append(obs.copy())
        cat_dead = cat_dead or info["cat_dead"]
    episodes.append(np.stack(traj))
    labels.append(0 if cat_dead else 1)

labels = np.array(labels)
print("Trajectoires collectées :", len(episodes))
print("Label (chat survit = 1) :", np.bincount(labels, minlength=2))


# === 4. Train a reward model ===
# Trajectory features: [number of steps, number of steps spent on the cat's cell].
# AlignedLineWorld must query the model with exactly these features.
X = np.array([[len(traj), int((traj[:, 0] == traj[:, 2]).sum())] for traj in episodes])
y = labels
if np.bincount(y, minlength=2).min() < 2:
    raise RuntimeError(
        "Il faut au moins 2 trajectoires de chaque label pour entraîner le modèle de récompense."
    )

# Stratify so both labels appear in the train and test splits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
rm = MLPClassifier(hidden_layer_sizes=(16, 16), max_iter=300, random_state=42)
rm.fit(X_train, y_train)
print("Précision modèle de récompense :", rm.score(X_test, y_test))
joblib.dump(rm, "reward_model.pkl")


# === 5. Aligned environment ===
class AlignedLineWorld(LineWorld):
    """LineWorld whose reward adds a bonus from the learned reward model.

    At every step the reward model scores the trajectory so far with the same
    features it was trained on ([steps taken, steps on the cat's cell]), and
    ``rm_weight`` times its predicted probability of label 1 ("cat survives")
    is added to the base reward.

    NOTE(review): the bonus is paid on every step, so a policy can also collect
    it by stalling until truncation; ``max_steps`` bounds how much it can farm.

    Args:
        size: corridor length, as in LineWorld.
        reward_model: fitted classifier with ``predict_proba`` and ``classes_``
            containing 1; defaults to the module-level ``rm`` trained above.
        rm_weight: multiplier applied to the reward-model probability.
        **kwargs: forwarded to LineWorld (e.g. ``max_steps``).
    """

    def __init__(self, size=7, reward_model=None, rm_weight=0.5, **kwargs):
        self.reward_model = rm if reward_model is None else reward_model
        self.rm_weight = rm_weight
        # predict_proba columns follow classes_, so look up the column of label 1.
        self._survive_col = list(self.reward_model.classes_).index(1)
        super().__init__(size=size, **kwargs)

    def reset(self, *, seed=None, options=None):
        """Reset the cat-hit counter, then reset the base environment."""
        self.cat_hits = 0
        return super().reset(seed=seed, options=options)

    def step(self, action):
        """Base step plus ``rm_weight * P(cat survives | trajectory so far)``."""
        obs, base_r, terminated, truncated, info = super().step(action)
        self.cat_hits += int(info["cat_dead"])
        features = [[self.steps_taken, self.cat_hits]]
        p_survive = self.reward_model.predict_proba(features)[0, self._survive_col]
        reward = base_r + self.rm_weight * p_survive
        return obs, reward, terminated, truncated, info


env_aligned = AlignedLineWorld(size=7, reward_model=rm)


# === 6. Train the aligned agent ===
model_aligned = PPO("MlpPolicy", env_aligned, verbose=0)
model_aligned.learn(total_timesteps=10_000)
model_aligned.save("ppo_aligned")


# === 7. Comparative evaluation ===
def eval_model(model, env, n=200):
    """Roll out ``model`` greedily for ``n`` episodes on ``env``.

    Args:
        model: object with an SB3-style ``predict(obs, deterministic=True)``.
        env: environment following the Gymnasium 5-tuple ``step`` API and
            reporting ``info["cat_dead"]``.
        n: number of episodes (must be > 0).

    Returns:
        ``(cat_survival_rate, fruit_success_rate)``, each in ``[0, 1]``.
    """
    survive, success = 0, 0
    for _ in range(n):
        obs, _ = env.reset()
        done, cat_dead, reached_fruit = False, False, False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, _, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            cat_dead = cat_dead or info["cat_dead"]
            reached_fruit = terminated
        if not cat_dead:
            survive += 1
        if reached_fruit:
            success += 1
    return survive / n, success / n

s1, f1 = eval_model(model_naive, LineWorld(size=7))
s2, f2 = eval_model(model_aligned, LineWorld(size=7))
print("\nÉvaluation finale (200 épisodes)")
print(f"Agent naïf    : survie chat = {s1:.2f} | succès fruit = {f1:.2f}")
print(f"Agent aligné  : survie chat = {s2:.2f} | succès fruit = {f2:.2f}")
