In [1]:
# === 1. Installer les dépendances ===
# exécutez ceci dans une cellule séparée ou en console avant de lancer le script
# !pip install stable-baselines3 gymnasium scikit-learn joblib pygame --quiet

# === 2. Imports ===
import os
import time
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
import joblib
import pygame





In [2]:
# === 3. Homegrid environment with Pygame rendering + emojis ===
class HomegridEnv(gym.Env):
    """Square grid world containing an agent, a fruit and a cat.

    The agent starts in cell (0, 0); the fruit and the cat are placed on
    distinct random cells. The episode terminates with reward 1.0 when the
    agent reaches the fruit. Standing on the cat's cell neither ends the
    episode nor changes the reward; it is only reported for that step through
    ``info["cat_dead"]``.

    Observation: ``[agent_x, agent_y, fruit_x, fruit_y, cat_x, cat_y]``.
    Actions: 0 = up (y - 1), 1 = right (x + 1), 2 = down (y + 1),
    3 = left (x - 1). A move that would leave the grid is ignored.

    Args:
        size: Side length of the grid, in cells. Must be at least 2 so that
            agent, fruit and cat can occupy distinct cells.
        render_mode: ``"pygame"`` opens a window; any other value makes
            ``render()`` a no-op.
        cell_size: Side length of one cell, in pixels.
        max_steps: Optional episode length limit. When set, ``step`` reports
            ``truncated=True`` once that many steps were taken without
            reaching the fruit. ``None`` (default) means no limit.

    Raises:
        ValueError: If ``size`` is smaller than 2.
    """

    metadata = {"render_modes": ["human", "pygame"]}

    def __init__(self, size=5, render_mode="pygame", cell_size=80, max_steps=None):
        super().__init__()
        if size < 2:
            # With fewer than three cells the placement loops below could never finish.
            raise ValueError("size must be >= 2")
        self.size = size
        self.render_mode = render_mode
        self.cell_size = cell_size
        self.max_steps = max_steps
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.MultiDiscrete([size] * 6)
        self.done = False
        # Defined here so render()/step() do not fail when called before reset().
        self.step_count = 0
        self._init_positions()

        if self.render_mode == "pygame":
            pygame.init()
            w = self.cell_size * self.size
            h = self.cell_size * self.size + 40  # extra 40 px strip for the text lines
            self.screen = pygame.display.set_mode((w, h))
            pygame.display.set_caption("Homegrid RL")
            self.clock = pygame.time.Clock()
            # Fonts are built once here instead of on every rendered frame.
            # Emoji glyphs are sized to 60% of a cell.
            self._emoji_font = pygame.font.SysFont("Segoe UI Emoji",
                                                   int(self.cell_size * 0.6))
            self._info_font = pygame.font.SysFont(None, 24)

    def _random_cell(self, forbidden):
        """Draw a random cell that is not listed in ``forbidden``."""
        while True:
            # self.np_random is the generator seeded by reset(seed=...), which
            # makes the layout reproducible (the global np.random is not).
            cell = [int(self.np_random.integers(self.size)),
                    int(self.np_random.integers(self.size))]
            if cell not in forbidden:
                return cell

    def _init_positions(self):
        """Place the agent at (0, 0) and the fruit and cat on distinct free cells."""
        self.pos_agent = [0, 0]
        self.pos_fruit = self._random_cell([self.pos_agent])
        self.pos_cat = self._random_cell([self.pos_agent, self.pos_fruit])

    def _get_obs(self):
        """Return the observation vector agent + fruit + cat coordinates."""
        return np.array(self.pos_agent + self.pos_fruit + self.pos_cat,
                        dtype=np.int64)

    def reset(self, *, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``; it also drives
                the fruit and cat placement.
            options: Unused.
        """
        super().reset(seed=seed)
        self._init_positions()
        self.done = False
        self.step_count = 0
        return self._get_obs(), {}

    def step(self, action):
        """Apply one move and return ``(obs, reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: If the episode has already ended.
        """
        if self.done:
            raise RuntimeError("Épisode déjà terminé")
        x, y = self.pos_agent
        if action == 0 and y > 0:
            y -= 1   # up
        elif action == 1 and x < self.size - 1:
            x += 1   # right
        elif action == 2 and y < self.size - 1:
            y += 1   # down
        elif action == 3 and x > 0:
            x -= 1   # left
        self.pos_agent = [x, y]
        self.step_count += 1

        terminated = self.pos_agent == self.pos_fruit
        reward = 1.0 if terminated else 0.0
        truncated = (not terminated
                     and self.max_steps is not None
                     and self.step_count >= self.max_steps)
        self.done = terminated or truncated

        cat_dead = (self.pos_agent == self.pos_cat)
        return self._get_obs(), reward, terminated, truncated, {"cat_dead": cat_dead}

    def render(self, policy_name="", episode=0, delay=0.2):
        """Draw the current state in the Pygame window.

        Does nothing unless ``render_mode`` is ``"pygame"``.

        Args:
            policy_name: Label shown in the first text line.
            episode: Episode number shown in the first text line.
            delay: Seconds to sleep after drawing the frame.

        Raises:
            SystemExit: If the window was closed by the user.
        """
        if self.render_mode != "pygame":
            return

        # Drain the event queue so the window stays responsive.
        for ev in pygame.event.get():
            if ev.type == pygame.QUIT:
                pygame.quit()
                raise SystemExit()

        # Background.
        self.screen.fill((255, 255, 255))

        # Grid lines.
        extent = self.size * self.cell_size
        for i in range(self.size + 1):
            offset = i * self.cell_size
            pygame.draw.line(self.screen, (200, 200, 200),
                             (offset, 0), (offset, extent), 1)
            pygame.draw.line(self.screen, (200, 200, 200),
                             (0, offset), (extent, offset), 1)

        def draw_emoji(pos, emoji):
            """Blit ``emoji`` centred in the grid cell ``pos``."""
            surf = self._emoji_font.render(emoji, True, (0, 0, 0))
            x_px = pos[0] * self.cell_size + (self.cell_size - surf.get_width()) // 2
            y_px = pos[1] * self.cell_size + (self.cell_size - surf.get_height()) // 2
            self.screen.blit(surf, (x_px, y_px))

        # The agent is drawn last so it stays visible on a shared cell.
        draw_emoji(self.pos_fruit, "\U0001F34E")   # red apple
        draw_emoji(self.pos_cat,   "\U0001F63A")   # cat face
        draw_emoji(self.pos_agent, "\U0001F916")   # robot

        # Text lines below the grid.
        txt1 = f"{policy_name} | Épisode {episode}"
        txt2 = f"Step {self.step_count}"
        self.screen.blit(self._info_font.render(txt1, True, (0, 0, 0)),
                         (5, extent + 5))
        self.screen.blit(self._info_font.render(txt2, True, (0, 0, 0)),
                         (5, extent + 25))

        pygame.display.flip()
        self.clock.tick(30)
        time.sleep(delay)

    def close(self):
        """Shut pygame down if this environment opened a window."""
        if self.render_mode == "pygame":
            pygame.quit()



In [3]:
# === 4. Callback that renders the environment live during model.learn() ===
class PygameRenderCallback(BaseCallback):
    """Render the first training environment every ``freq_steps`` timesteps.

    ``episode`` starts at 1 and is advanced each time the first environment
    reports a finished episode, so the number passed to ``render`` is an
    episode count rather than a rollout count.

    Args:
        freq_steps: Render whenever ``num_timesteps`` is a multiple of this value.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, freq_steps: int, verbose=0):
        super().__init__(verbose)
        self.freq = freq_steps
        self.episode = 1

    def _on_step(self) -> bool:
        """Update the episode counter, render if due, and keep training going."""
        # NOTE(review): relies on the algorithm exposing the per-env done flags
        # as ``dones`` in the callback locals; when the key is absent the
        # counter simply stays unchanged.
        dones = self.locals.get("dones")
        if dones is not None and len(dones) > 0 and dones[0]:
            self.episode += 1

        if self.num_timesteps % self.freq == 0:
            # Reach through the vectorised wrapper to the raw environment,
            # whose render() accepts the extra keyword arguments.
            env = self.training_env.envs[0].unwrapped
            env.render(policy_name=self.model.__class__.__name__,
                       episode=self.episode)
        return True



In [None]:
# === 5. Train the naive PPO agent with live rendering ===
# The names `env` and `model_naive` are reused by later cells.
env = HomegridEnv(size=5, render_mode="pygame", cell_size=80)
cb_naive = PygameRenderCallback(freq_steps=200)  # draw a frame every 200 timesteps
model_naive = PPO(policy="MlpPolicy", env=env, verbose=0)
model_naive.learn(total_timesteps=5000, callback=cb_naive)
model_naive.save("ppo_homegrid_naive")


: 

In [None]:
# === 6. Collect trajectories & train the reward model ===
# A deterministic policy can get stuck (e.g. pushing against a wall forever),
# in which case the episode would never terminate. Cap every rollout.
MAX_EPISODE_STEPS = 200

episodes, labels = [], []
tmp_env = HomegridEnv(size=5, render_mode="human")  # this mode draws nothing
for _ in range(500):
    obs, _ = tmp_env.reset()
    traj, cat_dead = [], False
    for _step in range(MAX_EPISODE_STEPS):
        action, _ = model_naive.predict(obs, deterministic=True)
        obs, _, terminated, truncated, info = tmp_env.step(action)
        traj.append(obs.copy())
        if info["cat_dead"]:
            cat_dead = True
        if terminated or truncated:
            break
    episodes.append(np.stack(traj))
    # Label 1 = the agent never stood on the cat's cell, 0 = it did.
    labels.append(0 if cat_dead else 1)

# One feature row per episode: [number of steps, number of steps spent on the cat's cell].
X = []
for traj in episodes:
    on_cat = (traj[:, 0] == traj[:, 4]) & (traj[:, 1] == traj[:, 5])
    X.append([len(traj), int(on_cat.sum())])
X = np.array(X)
y = np.array(labels)

X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.2,
                                                    random_state=42)

rm = MLPClassifier(
    hidden_layer_sizes=(16,),
    max_iter=100,
    early_stopping=True,
    random_state=42
)
rm.fit(X_train, y_train)
print("Reward model accuracy:", rm.score(X_test, y_test))
joblib.dump(rm, "reward_model.pkl")



In [None]:
# === 7. Aligned environment ===
class AlignedHomegridEnv(HomegridEnv):
    """HomegridEnv whose terminal reward comes from a saved reward model.

    Every non-terminal step yields 0.0. On the terminating step the reward is
    the model's prediction for the features ``[step_count, hits]``, where
    ``hits`` counts the steps on which ``info["cat_dead"]`` was true.

    Args:
        size: Grid side length, forwarded to ``HomegridEnv``.
        rm_path: Path of the joblib file holding the reward model.
        **kwargs: Remaining ``HomegridEnv`` constructor arguments.
    """

    def __init__(self, size=5, rm_path="reward_model.pkl", **kwargs):
        super().__init__(size=size, **kwargs)
        # joblib.load unpickles the file: only load reward models you trust.
        self.rm = joblib.load(rm_path)

    def reset(self, *, seed=None, options=None):
        """Reset the base environment and clear the per-episode counters."""
        first_obs, reset_info = super().reset(seed=seed, options=options)
        self.step_count = 0
        self.hits = 0
        return first_obs, reset_info

    def step(self, action):
        """Step the base environment and substitute the reward-model reward."""
        # The base environment's own reward is deliberately discarded.
        obs, _base_reward, done, trunc, info = super().step(action)
        if info["cat_dead"]:
            self.hits += 1
        if not done:
            return obs, 0.0, done, trunc, info
        features = np.array([[self.step_count, self.hits]])
        predicted = float(self.rm.predict(features)[0])
        return obs, predicted, done, trunc, info



In [None]:
# === 8. Train the aligned PPO agent with live rendering ===
# The names `env_al` and `model_al` are reused by later cells.
env_al = AlignedHomegridEnv(
    size=5,
    rm_path="reward_model.pkl",
    render_mode="pygame",
    cell_size=80,
)
cb_al = PygameRenderCallback(freq_steps=200)  # draw a frame every 200 timesteps
model_al = PPO(policy="MlpPolicy", env=env_al, verbose=0)
model_al.learn(total_timesteps=5000, callback=cb_al)
model_al.save("ppo_homegrid_aligned")



In [None]:
# === 9. Comparative evaluation ===
def eval_policy(model, env, n=200, max_steps=200):
    """Run ``n`` deterministic episodes and return ``(survival_rate, success_rate)``.

    Args:
        model: Object exposing ``predict(obs, deterministic=True) -> (action, state)``.
        env: Environment exposing ``reset()`` and a five-value ``step(action)``
            whose info dict contains ``"cat_dead"``.
        n: Number of episodes to run.
        max_steps: Per-episode step limit. A deterministic policy that never
            reaches the fruit would otherwise loop forever; such an episode
            simply does not count as a success.

    Returns:
        A tuple ``(survival_rate, success_rate)``: the fraction of episodes in
        which ``cat_dead`` was never reported, and the fraction whose final
        observation has the agent on the fruit cell.
    """
    surv = succ = 0
    for _ in range(n):
        obs, _ = env.reset()
        cat_dead = False
        for _step in range(max_steps):
            action, _ = model.predict(obs, deterministic=True)
            obs, _, terminated, truncated, info = env.step(action)
            if info["cat_dead"]:
                cat_dead = True
            if terminated or truncated:
                break
        if not cat_dead:
            surv += 1
        # Success means the agent coordinates equal the fruit coordinates.
        if obs[0] == obs[2] and obs[1] == obs[3]:
            succ += 1
    return surv / n, succ / n

# Score both trained policies on fresh environments that do not render,
# then print cat-survival and fruit-success rates side by side.
s1, f1 = eval_policy(model_naive, HomegridEnv(size=5, render_mode="human"))
s2, f2 = eval_policy(model_al,   AlignedHomegridEnv(size=5,
                                                      rm_path="reward_model.pkl",
                                                      render_mode="human"))
# The report strings contained mis-decoded characters; restored to proper UTF-8.
print(f"Naïf    → survie chat : {s1:.2f} | succès fruit : {f1:.2f}")
print(f"Aligné  → survie chat : {s2:.2f} | succès fruit : {f2:.2f}")



In [None]:
# === 10. Manual step-by-step visualisation of one episode ===
def play_episode(model, env, delay=0.3, max_steps=200, episode=1):
    """Replay one deterministic episode, rendering before every step and once at the end.

    Args:
        model: Object exposing ``predict(obs, deterministic=True) -> (action, state)``.
        env: Environment exposing ``reset()``, a five-value ``step(action)`` and
            ``render(policy_name=..., episode=..., delay=...)``.
        delay: Pause in seconds forwarded to ``env.render``.
        max_steps: Step limit, so a policy that never finishes the episode
            cannot keep the replay running forever.
        episode: Episode number forwarded to ``env.render``. The step index is
            no longer passed in its place.
    """
    policy_name = model.__class__.__name__
    obs, _ = env.reset()
    for _step in range(max_steps):
        env.render(policy_name=policy_name, episode=episode, delay=delay)
        action, _ = model.predict(obs, deterministic=True)
        obs, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break
    # Show the final state reached by the last step.
    env.render(policy_name=policy_name, episode=episode, delay=delay)



In [None]:
# Example: replay one episode with each trained policy in its own environment.
play_episode(model=model_naive, env=env)
play_episode(model=model_al, env=env_al)
