## 1. Installation des dépendances

In [1]:
# Installation des packages n√©cessaires
!pip install gymnasium stable-baselines3[extra] torch numpy matplotlib tensorboard -q
print("Installation termin√©e !")

[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/188.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[90m‚ï∫[0m [32m184.3/188.0 kB[0m [31m9.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m188.0/188.0 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstallation termin√©e !


## 2. Définition de l'environnement Snake (CNN)

In [2]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
import pygame

# --- Geometry and timing ---
# The window is square; the playing grid has WINDOW_* // BLOCK_SIZE cells per side.
WINDOW_WIDTH, WINDOW_HEIGHT = 600, 600
BLOCK_SIZE = 20  # side length of one grid cell, in pixels
SPEED = 20       # used as "render_fps" in the environment metadata

# --- RGB colors (only used when rendering) ---
WHITE = (255, 255, 255)
BLACK = (15, 15, 25)    # background fill
RED = (255, 80, 80)     # food
GREEN = (76, 175, 80)
CYAN = (0, 188, 212)    # snake body

class SnakeEnvCnn(gym.Env):
    """Snake game exposed as a Gymnasium environment with an image-like observation.

    Observation
        ``float32`` array of shape ``(1, grid_h, grid_w)``: 0.0 for empty cells,
        0.3 for body cells, 0.7 for the head and 1.0 for the food.

    Actions (relative to the current heading)
        0 keeps the heading, 1 and 2 rotate it by a quarter turn in opposite
        directions (see ``_move`` for the exact tables).

    Rewards
        -10 on collision or starvation timeout (episode terminates), +10 when
        the food is eaten, otherwise -0.01 plus +1 / -1 depending on whether
        the Manhattan distance to the food strictly decreased.
    """

    metadata = {"render_modes": ["human"], "render_fps": SPEED}

    def __init__(self, render_mode=None):
        """Build the spaces and put the game in its initial state.

        Args:
            render_mode: ``"human"`` to draw every step in a pygame window,
                ``None`` (default) for headless use.
        """
        super().__init__()

        self.render_mode = render_mode
        self.window = None
        self.clock = None

        self.grid_w = WINDOW_WIDTH // BLOCK_SIZE
        self.grid_h = WINDOW_HEIGHT // BLOCK_SIZE

        # ACTIONS: 0 = keep heading, 1 / 2 = quarter turns (see _move)
        self.action_space = spaces.Discrete(3)

        # OBSERVATION: single-channel normalized image (1, H, W)
        self.observation_space = spaces.Box(
            low=0.0, high=1.0,
            shape=(1, self.grid_h, self.grid_w),
            dtype=np.float32
        )

        # Initialise the game state so the attributes exist before the first
        # explicit reset() call.
        self.reset()

    # ---------------- RESET ----------------
    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: optional seed forwarded to ``gym.Env.reset``; it seeds
                ``self.np_random``, which drives food placement.
            options: accepted for API compatibility, unused.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        self.direction = 1  # 0=LEFT, 1=RIGHT, 2=UP, 3=DOWN

        cx = self.grid_w // 2
        cy = self.grid_h // 2

        # Three-cell snake centred on the grid, heading right.
        self.head = [cx, cy]
        self.snake = [
            self.head,
            [cx - 1, cy],
            [cx - 2, cy]
        ]

        self.score = 0
        self.frame_iteration = 0
        self._place_food()
        self.prev_distance = self._distance_to_food()

        return self._get_observation(), {}

    # ---------------- STEP ----------------
    def step(self, action):
        """Advance the game by one move.

        Args:
            action: 0, 1 or 2 (see ``_move``).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always ``False`` and ``info`` is an empty dict.
        """
        self.frame_iteration += 1
        self._move(action)

        reward = -0.01  # small per-step cost
        terminated = False
        truncated = False

        # Death by collision, or starvation: too many steps for the current length.
        if self._is_collision() or self.frame_iteration > 100 * len(self.snake):
            reward = -10
            terminated = True
            return self._get_observation(), reward, terminated, truncated, {}

        if self.head == self.food:
            reward = 10
            self.score += 1
            if len(self.snake) >= self.grid_w * self.grid_h:
                # The snake covers the whole board: there is no free cell left
                # for new food, so the game is over (won).
                terminated = True
            else:
                self._place_food()
                self.prev_distance = self._distance_to_food()
        else:
            # No food eaten: drop the tail so the length stays constant.
            self.snake.pop()

            distance = self._distance_to_food()
            if distance < self.prev_distance:
                reward += 1
            else:
                reward -= 1

            self.prev_distance = distance

        if self.render_mode == "human":
            self._render_frame()

        return self._get_observation(), reward, terminated, truncated, {}

    # ---------------- OBSERVATION ----------------
    def _get_observation(self):
        """Rasterize the game state into a ``(1, grid_h, grid_w)`` float32 array."""
        grid = np.zeros((self.grid_h, self.grid_w), dtype=np.float32)

        # Bounds checks are needed because after a wall collision the head
        # (which is also snake[0]) lies outside the grid.
        for pt in self.snake:
            x, y = pt
            if 0 <= x < self.grid_w and 0 <= y < self.grid_h:
                grid[y, x] = 0.3

        hx, hy = self.head
        if 0 <= hx < self.grid_w and 0 <= hy < self.grid_h:
            grid[hy, hx] = 0.7

        fx, fy = self.food
        grid[fy, fx] = 1.0

        # Add the leading channel axis.
        return np.expand_dims(grid, axis=0)

    # ---------------- UTILS ----------------
    def _distance_to_food(self):
        """Return the Manhattan distance between the head and the food."""
        return abs(self.head[0] - self.food[0]) + abs(self.head[1] - self.food[1])

    def _place_food(self):
        """Put the food on a uniformly random cell not occupied by the snake.

        Uses ``self.np_random`` so that ``reset(seed=...)`` makes food
        placement reproducible.

        Raises:
            RuntimeError: if the snake already covers every cell.
        """
        if len(self.snake) >= self.grid_w * self.grid_h:
            raise RuntimeError("No free cell left to place food")

        while True:
            # Cast to plain ints so list equality with snake cells / head works.
            pos = [
                int(self.np_random.integers(0, self.grid_w)),
                int(self.np_random.integers(0, self.grid_h))
            ]
            if pos not in self.snake:
                self.food = pos
                break

    def _is_collision(self, pt=None):
        """Return True if ``pt`` (default: the head) is off-grid or on the body."""
        if pt is None:
            pt = self.head

        if pt[0] < 0 or pt[0] >= self.grid_w or pt[1] < 0 or pt[1] >= self.grid_h:
            return True

        # snake[0] is the head itself, so only compare against the rest.
        if pt in self.snake[1:]:
            return True

        return False

    # ---------------- RELATIVE MOVEMENT ----------------
    def _move(self, action):
        """Update the heading from a relative action and advance the head one cell.

        The new head is inserted at the front of ``self.snake``; removing the
        tail is left to ``step``.

        NOTE: with y growing downward, the table used for action 1 turns the
        snake clockwise on screen (RIGHT -> DOWN) and the one for action 2
        turns it counter-clockwise (RIGHT -> UP). Earlier comments labelled
        them "left" and "right" respectively, i.e. the other way round. The
        tables are kept unchanged so existing checkpoints remain valid.
        """
        # Unit vectors indexed by direction: LEFT, RIGHT, UP, DOWN
        dirs = [(-1, 0), (1, 0), (0, -1), (0, 1)]

        if action == 1:   # quarter turn, clockwise on screen
            self.direction = [2, 3, 1, 0][self.direction]
        elif action == 2: # quarter turn, counter-clockwise on screen
            self.direction = [3, 2, 0, 1][self.direction]

        dx, dy = dirs[self.direction]
        new_head = [self.head[0] + dx, self.head[1] + dy]

        self.head = new_head
        self.snake.insert(0, self.head)

    # ---------------- RENDER ----------------
    def _render_frame(self):
        """Draw the current state in a pygame window, creating it on first use."""
        if self.window is None:
            pygame.init()
            self.window = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
            self.clock = pygame.time.Clock()

        # Let pygame process OS window events; without this the window is
        # reported as unresponsive after a few seconds.
        pygame.event.pump()

        self.window.fill(BLACK)

        for pt in self.snake:
            pygame.draw.rect(
                self.window,
                CYAN,
                pygame.Rect(pt[0]*BLOCK_SIZE, pt[1]*BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)
            )

        pygame.draw.rect(
            self.window,
            RED,
            pygame.Rect(self.food[0]*BLOCK_SIZE, self.food[1]*BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)
        )

        pygame.display.flip()
        # Cap the frame rate at metadata["render_fps"].
        self.clock.tick(self.metadata["render_fps"])

    def close(self):
        """Shut pygame down if a window was opened; safe to call repeatedly."""
        if self.window is not None:
            pygame.quit()
            # Forget the dead surface so a later render re-creates the window.
            self.window = None
            self.clock = None


## 3. 🧠 Définition du réseau CNN personnalisé

In [3]:
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomCNN(BaseFeaturesExtractor):
    """Small convolutional feature extractor for channel-first grid observations.

    Two strided convolutions (5x5 then 3x3, both stride 2) followed by a
    flatten and one fully connected layer producing ``features_dim`` features.

    Args:
        observation_space: Box space whose shape is ``(channels, height, width)``.
        features_dim: size of the feature vector returned by ``forward``.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)

        in_channels = observation_space.shape[0]
        conv_stack = [
            nn.Conv2d(in_channels, 32, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Push one all-zero observation through the conv stack to discover the
        # flattened size, instead of computing it by hand.
        probe = th.zeros(1, *observation_space.shape)
        with th.no_grad():
            flat_size = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_size, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to a batch of feature vectors."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


## 4. ⚙️ Configuration de l'entraînement

In [4]:
import os
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, CallbackList
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor


# === CONFIGURATION ===
# Everything a reader may want to tune before training lives here.
N_ENVS = 2               # number of parallel environments (Colab has about 2 CPUs)
TIMESTEPS = 3_000_000    # total training steps (raise for better results)
SAVE_FREQ = 300_000      # write a checkpoint every SAVE_FREQ steps

# Output folders
MODELS_DIR = "checkpoints/PPO_CNN_COLAB"
LOG_DIR = "logs"

# Create both folders up front; existing ones are left alone.
for _output_dir in (MODELS_DIR, LOG_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Echo the configuration, one line per setting.
print(
    f"üìä Configuration :",
    f"   - Environnements parall√®les : {N_ENVS}",
    f"   - Steps totaux : {TIMESTEPS:,}",
    f"   - Sauvegarde tous les : {SAVE_FREQ:,} steps",
    f"   - Dossier mod√®les : {MODELS_DIR}",
    f"   - Dossier logs : {LOG_DIR}",
    sep="\n",
)

üìä Configuration :
   - Environnements parall√®les : 2
   - Steps totaux : 3,000,000
   - Sauvegarde tous les : 300,000 steps
   - Dossier mod√®les : checkpoints/PPO_CNN_COLAB
   - Dossier logs : logs


## 5. 🚀 Lancement de l'entraînement

In [5]:
# Pick the GPU when one is available
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"üñ•Ô∏è Device utilis√© : {device}")
if device == "cuda":
    print(f"   GPU : {torch.cuda.get_device_name(0)}")

# Vectorized training environments.
# make_vec_env wraps each environment in a Monitor itself, so the class is
# passed directly instead of a lambda that would add a second Monitor layer.
print(f"\nüîÑ Cr√©ation de {N_ENVS} environnements parall√®les...")
env = make_vec_env(
    SnakeEnvCnn,
    n_envs=N_ENVS,
    vec_env_cls=DummyVecEnv
)

# Separate single environment used only for periodic evaluation.
eval_env = DummyVecEnv([
    lambda: Monitor(SnakeEnvCnn())
])

# Callback frequencies are counted in calls to the vectorized env.step(), and
# each call advances N_ENVS timesteps, so both frequencies are divided by
# N_ENVS to express them in total timesteps.
checkpoint_callback = CheckpointCallback(
    save_freq=max(SAVE_FREQ // N_ENVS, 1),
    save_path=MODELS_DIR,
    name_prefix="snake_cnn"
)

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=MODELS_DIR,
    log_path=LOG_DIR,
    eval_freq=max(50_000 // N_ENVS, 1),  # every 50k total timesteps
    n_eval_episodes=10,                  # average over 10 games
    deterministic=True,
    render=False
)

# Plug the custom CNN in as the policy's feature extractor
policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=256),
)

# Build the PPO model
print("üß† Cr√©ation du mod√®le PPO avec CNN...")
model = PPO(
    "CnnPolicy",
    env,
    verbose=1,
    tensorboard_log=LOG_DIR,
    learning_rate=0.0003,
    policy_kwargs=policy_kwargs,
    batch_size=256,
    n_steps=1024,
    gamma=0.99,
    device=device
)

print("\n" + "="*50)
print("üéÆ D√âMARRAGE DE L'ENTRA√éNEMENT")
print("="*50)
print(f"L'IA va jouer {TIMESTEPS:,} coups...")
print("Cela peut prendre 30min √† 2h selon la configuration.")
print("="*50 + "\n")

üñ•Ô∏è Device utilis√© : cpu

üîÑ Cr√©ation de 2 environnements parall√®les...
üß† Cr√©ation du mod√®le PPO avec CNN...
Using cpu device

üéÆ D√âMARRAGE DE L'ENTRA√éNEMENT
L'IA va jouer 3,000,000 coups...
Cela peut prendre 30min √† 2h selon la configuration.



  return datetime.utcnow().replace(tzinfo=utc)


In [6]:
# Run the training loop with periodic checkpoints and evaluations
model.learn(
    total_timesteps=TIMESTEPS,
    callback=CallbackList([checkpoint_callback, eval_callback]),
    progress_bar=True  # show a progress bar
)

# Final save
final_path = f"{MODELS_DIR}/snake_cnn_final"
model.save(final_path)

print("\n" + "="*50)
print("‚úÖ ENTRA√éNEMENT TERMIN√â !")
print("="*50)
print(f"Mod√®le final sauvegard√© : {final_path}.zip")

# Release both the training and the evaluation environments
env.close()
eval_env.close()

Logging to logs/PPO_1


Output()

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 146      |
|    ep_rew_mean     | -12.5    |
| time/              |          |
|    fps             | 927      |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 128          |
|    ep_rew_mean          | -16          |
| time/                   |              |
|    fps                  | 334          |
|    iterations           | 2            |
|    time_elapsed         | 12           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0012274908 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.1         |
|    explained_variance   | -0.000278    |
|    learning_r

## 6. 📈 Visualisation des logs TensorBoard

In [None]:
# Charger TensorBoard dans le notebook
%load_ext tensorboard
%tensorboard --logdir logs

## 7. 📥 Télécharger le modèle entraîné

In [None]:
# List the saved models, oldest first
import glob
import os

# Sort by modification time rather than by name: a lexicographic sort puts
# "..._900000_steps.zip" after "..._1200000_steps.zip", so the last entry
# would not reliably be the most recent model.
models = sorted(glob.glob(f"{MODELS_DIR}/*.zip"), key=os.path.getmtime)

print("üìÅ Mod√®les disponibles :")
for i, m in enumerate(models):
    print(f"   [{i}] {m}")

In [None]:
# Download one model through the Colab file helper
from google.colab import files

# The last entry of `models` is the one that gets downloaded.
if not models:
    print("‚ùå Aucun mod√®le trouv√© !")
else:
    files.download(models[-1])
    print(f"\nüì• T√©l√©chargement de : {models[-1]}")

In [None]:
# Optional: bundle every checkpoint into a single archive and download it
import shutil

# Imported here as well so this cell does not depend on another cell
# having been run first.
from google.colab import files

# make_archive returns the path of the file it created; download exactly that.
archive_path = shutil.make_archive("snake_models", 'zip', MODELS_DIR)
files.download(archive_path)
print("üì• Archive de tous les mod√®les t√©l√©charg√©e !")

## 8. 🧪 Test rapide du modèle (sans rendu)

In [None]:
# Load the trained model and play a few headless games
from stable_baselines3 import PPO

# Load the model saved at the end of training
test_model = PPO.load(f"{MODELS_DIR}/snake_cnn_final")

# Fresh environment without rendering
test_env = SnakeEnvCnn()

# Play 10 games
scores = []
for episode in range(10):
    obs, _ = test_env.reset()
    done = False

    while not done:
        action, _ = test_model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = test_env.step(int(action))
        # An episode ends on either flag.
        done = terminated or truncated

    # Use the environment's own counter of food eaten. Counting steps with a
    # positive reward would also count every move that merely gets closer to
    # the food, because the shaping bonus makes those rewards positive too.
    score = test_env.score

    scores.append(score)
    print(f"Partie {episode+1}/10 : Score = {score}")

test_env.close()

print(f"\nüìä Score moyen sur 10 parties : {sum(scores)/len(scores):.1f}")
print(f"   Meilleur score : {max(scores)}")
print(f"   Pire score : {min(scores)}")

---

## 📝 Notes

### Pour utiliser le modèle en local :

1. Téléchargez le fichier `.zip` du modèle
2. Placez-le dans `checkpoints/PPO_CNN/` de votre projet local
3. Lancez `python test_play_cnn.py`

### Pour améliorer les résultats :

- Augmentez `TIMESTEPS` (5M, 10M...)
- Ajustez `learning_rate` (0.0001, 0.00003...)
- Modifiez les récompenses dans l'environnement
- Ajoutez des récompenses intermédiaires (se rapprocher de la pomme)

---

**Auteur** : Samy EH - Projet SY23 - Janvier 2026