## 1. Installation des dépendances

In [None]:
# Installation des packages n√©cessaires
!pip install gymnasium stable-baselines3[extra] torch numpy matplotlib tensorboard -q
print("‚úÖ Installation termin√©e !")

## 2. Définition de l'environnement Snake (CNN)

In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
import pygame

# Same constants as before (V1 environment).
# -- Board geometry in pixels, and rendering speed (frames per second) --
WINDOW_WIDTH, WINDOW_HEIGHT = 600, 600
BLOCK_SIZE = 20  # side of one grid cell, in pixels
SPEED = 20

# -- Modern colour palette (RGB tuples) --
WHITE = (255, 255, 255)
BLACK = (15, 15, 25)  # very dark blue-grey, used as the window background
DARK_GRAY = (30, 30, 40)
RED = (255, 80, 80)
ORANGE = (255, 165, 0)
GREEN = (76, 175, 80)
BLUE1 = (66, 165, 245)
BLUE2 = (33, 150, 243)
CYAN = (0, 188, 212)
YELLOW = (255, 235, 59)

class SnakeEnvCnn(gym.Env):
    """Snake game as a Gymnasium environment with a grid ("image") observation for a CNN policy.

    The board is ``WINDOW_WIDTH x WINDOW_HEIGHT`` pixels divided into square cells of
    ``BLOCK_SIZE`` pixels. Positions (``self.head``, the items of ``self.snake``, and
    ``self.food``) are stored in pixels as ``[x, y]`` lists.

    Action space: ``Discrete(4)`` -- 0 = left, 1 = right, 2 = up, 3 = down. An action
    that would reverse the snake onto itself is ignored (the current direction is kept).

    Observation space: ``uint8`` array of shape ``(1, grid_h, grid_w)`` (channel, row,
    column) with 0 = empty, 80 = body, 180 = head, 255 = apple.
    """

    metadata = {'render_modes': ['human'], 'render_fps': SPEED}

    def __init__(self, render_mode=None):
        """Create the environment.

        :param render_mode: ``"human"`` to draw every step in a pygame window, or
            ``None`` for headless use (training).
        """
        super(SnakeEnvCnn, self).__init__()
        self.w = WINDOW_WIDTH
        self.h = WINDOW_HEIGHT
        self.render_mode = render_mode
        # pygame objects, created lazily on the first rendered frame
        self.window = None
        self.clock = None
        self.font = None
        self.small_font = None

        # Number of cells along each axis (30 x 30 with the default constants)
        self.grid_w = int(self.w / BLOCK_SIZE)
        self.grid_h = int(self.h / BLOCK_SIZE)

        # Actions: left / right / up / down (see _move)
        self.action_space = spaces.Discrete(4)

        # Observations: a one-channel grayscale "image" of shape (channel, height, width).
        # Values: 0 = empty, 80 = body, 180 = head, 255 = apple
        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=(1, self.grid_h, self.grid_w),
            dtype=np.uint8
        )

        # Previous head-to-apple distance, used for reward shaping
        self.prev_distance = None

    def reset(self, seed=None, options=None):
        """Start a new game: a 3-cell snake in the middle of the board, heading right.

        :param seed: optional seed for ``self.np_random``, which drives apple placement.
        :param options: unused; accepted for Gymnasium API compatibility.
        :return: ``(observation, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        self.direction = 1  # 1 = right
        self.head = [self.w/2, self.h/2]
        # Body extends to the left of the head
        self.snake = [self.head,
                      [self.head[0]-BLOCK_SIZE, self.head[1]],
                      [self.head[0]-(2*BLOCK_SIZE), self.head[1]]]
        self.score = 0
        self.frame_iteration = 0
        self._place_food()
        self.prev_distance = self._get_distance()
        if self.render_mode == "human":
            self._render_frame()
        return self._get_observation(), {}

    def _get_distance(self):
        """Return the Manhattan distance (in pixels) between the head and the apple."""
        return abs(self.head[0] - self.food[0]) + abs(self.head[1] - self.food[1])

    def step(self, action):
        """Advance the game by one move.

        Rewards:
          * -10 and end of episode on a collision (wall or own body), or when the
            episode has lasted more than ``100 * len(self.snake)`` frames (anti-loop
            timeout, reported as ``terminated`` rather than ``truncated``);
          * +20 for eating the apple (the snake grows by one cell);
          * otherwise +1 if the head got closer to the apple (Manhattan distance),
            -1 if it got farther, 0 if the distance is unchanged.

        :param action: 0 = left, 1 = right, 2 = up, 3 = down.
        :return: ``(observation, reward, terminated, truncated, info)``; ``truncated``
            is always ``False`` and ``info`` is an empty dict.
        """
        self.frame_iteration += 1
        self._move(action)

        game_over = False
        reward = 0

        # Collision (or timeout) = game over
        if self._is_collision() or self.frame_iteration > 100*len(self.snake):
            game_over = True
            reward = -10
            return self._get_observation(), reward, game_over, False, {}

        # Distance after the move, for reward shaping
        new_distance = self._get_distance()

        if self.head == self.food:
            # Apple eaten: big reward. The tail is not popped, so the snake grows.
            self.score += 1
            reward = 20
            self._place_food()
            self.prev_distance = self._get_distance()
        else:
            # Normal move: drop the tail so the length stays the same
            self.snake.pop()

            # Reward shaping on the head-to-apple distance:
            # closer = +1, farther = -1, unchanged = 0
            if new_distance < self.prev_distance:
                reward = 1
            elif new_distance > self.prev_distance:
                reward = -1

            self.prev_distance = new_distance

        if self.render_mode == "human":
            self._render_frame()

        return self._get_observation(), reward, game_over, False, {}

    def _get_observation(self):
        """Rasterise the game state into a ``(1, grid_h, grid_w)`` uint8 grid."""
        # Empty grid (black background = 0)
        grid = np.zeros((self.grid_h, self.grid_w), dtype=np.uint8)

        # Body (dark gray = 80); out-of-board cells (e.g. a head that just hit a wall) are skipped
        for pt in self.snake:
            x = int(pt[0] / BLOCK_SIZE)
            y = int(pt[1] / BLOCK_SIZE)
            if 0 <= x < self.grid_w and 0 <= y < self.grid_h:
                grid[y, x] = 80

        # Head (light gray = 180), drawn over the body so the agent knows where it is
        hx = int(self.head[0] / BLOCK_SIZE)
        hy = int(self.head[1] / BLOCK_SIZE)
        if 0 <= hx < self.grid_w and 0 <= hy < self.grid_h:
            grid[hy, hx] = 180

        # Apple (white = 255); always placed on a board cell by _place_food
        fx = int(self.food[0] / BLOCK_SIZE)
        fy = int(self.food[1] / BLOCK_SIZE)
        grid[fy, fx] = 255

        # Add the channel dimension -> (1, H, W), channel-first as PyTorch CNNs expect
        return np.expand_dims(grid, axis=0)

    # The game-logic and drawing helpers below are mostly carried over from the V1
    # environment (snake_env.py).
    def _place_food(self):
        """Put the apple on a uniformly random cell that the snake does not occupy.

        Draws from ``self.np_random`` so ``reset(seed=...)`` makes apple placement
        reproducible. If the snake covers every cell there is nowhere to put the apple.
        In that case ``self.food`` is left unchanged, and the next move necessarily hits
        the body or a wall, which ends the episode.
        """
        occupied = {(pt[0], pt[1]) for pt in self.snake}
        free_cells = [
            (gx * BLOCK_SIZE, gy * BLOCK_SIZE)
            for gy in range(self.grid_h)
            for gx in range(self.grid_w)
            if (gx * BLOCK_SIZE, gy * BLOCK_SIZE) not in occupied
        ]
        if not free_cells:
            return
        x, y = free_cells[int(self.np_random.integers(len(free_cells)))]
        self.food = [x, y]

    def _is_collision(self, pt=None):
        """Return True if ``pt`` (default: the head) is off the board or on the snake's body.

        ``self.snake[1:]`` still contains the current tail cell when ``step`` calls
        this (the tail is popped afterwards), so moving into that cell counts as a
        collision.
        """
        if pt is None:
            pt = self.head
        if pt[0] > self.w - BLOCK_SIZE or pt[0] < 0 or pt[1] > self.h - BLOCK_SIZE or pt[1] < 0:
            return True
        if pt in self.snake[1:]:
            return True
        return False

    def _move(self, action):
        """Update the direction from ``action`` (refusing 180-degree turns), then move the head.

        The new head is inserted at the front of ``self.snake``. The caller decides
        whether to pop the tail.
        """
        # 0 = left, 1 = right, 2 = up, 3 = down; an action opposite to the current
        # direction is ignored
        if action == 0 and self.direction != 1:
            self.direction = 0
        elif action == 1 and self.direction != 0:
            self.direction = 1
        elif action == 2 and self.direction != 3:
            self.direction = 2
        elif action == 3 and self.direction != 2:
            self.direction = 3
        x = self.head[0]
        y = self.head[1]
        if self.direction == 1:
            x += BLOCK_SIZE
        elif self.direction == 0:
            x -= BLOCK_SIZE
        elif self.direction == 3:
            y += BLOCK_SIZE  # pygame's y axis points down
        elif self.direction == 2:
            y -= BLOCK_SIZE
        self.head = [x, y]
        self.snake.insert(0, self.head)

    def render(self):
        """Gymnasium ``render()`` entry point: draw the current frame in ``"human"`` mode.

        In that mode ``reset()`` and ``step()`` already draw automatically. This
        method lets callers force a redraw instead of getting the base class's
        ``NotImplementedError``.
        """
        if self.render_mode == "human":
            self._render_frame()

    def _render_frame(self):
        """Draw one frame in the pygame window (creating the window on first use)."""
        if self.window is None:
            pygame.init()
            self.window = pygame.display.set_mode((self.w, self.h))
            self.clock = pygame.time.Clock()
            self.font = pygame.font.Font(None, 48)
            self.small_font = pygame.font.Font(None, 32)
            pygame.display.set_caption("🐍 Snake AI CNN Training 🐍")

        # Background
        self.window.fill(BLACK)

        # Faint grid lines in the background
        grid_color = (50, 50, 70)
        for x in range(0, self.w, BLOCK_SIZE):
            pygame.draw.line(self.window, grid_color, (x, 0), (x, self.h), 1)
        for y in range(0, self.h, BLOCK_SIZE):
            pygame.draw.line(self.window, grid_color, (0, y), (self.w, y), 1)

        self._draw_apple()
        self._draw_snake()
        self._draw_score()

        pygame.display.flip()
        # Cap the frame rate at metadata["render_fps"]
        self.clock.tick(self.metadata["render_fps"])

    def _draw_apple(self):
        """Draw the apple: an orange halo, a red/orange rounded square and a yellow highlight."""
        x, y = int(self.food[0]), int(self.food[1])

        # Halo around the apple. pygame.draw does not blend, so the alpha component
        # (50) has no transparency effect here and the halo is drawn opaque.
        glow_radius = BLOCK_SIZE // 2 + 3
        pygame.draw.circle(self.window, (255, 100, 0, 50), (x + BLOCK_SIZE//2, y + BLOCK_SIZE//2), glow_radius)

        # Main body of the apple (two nested rounded squares to fake a gradient)
        pygame.draw.rect(self.window, RED, pygame.Rect(x+2, y+2, BLOCK_SIZE-4, BLOCK_SIZE-4), border_radius=4)
        pygame.draw.rect(self.window, ORANGE, pygame.Rect(x+3, y+3, BLOCK_SIZE-6, BLOCK_SIZE-6), border_radius=3)

        # Highlight
        pygame.draw.circle(self.window, YELLOW, (x + 7, y + 7), 3)

    def _draw_snake(self):
        """Draw the snake with a colour gradient from cyan (head) to BLUE1 (tail)."""
        snake_length = len(self.snake)

        for i, pt in enumerate(self.snake):
            x, y = int(pt[0]), int(pt[1])

            # Linear interpolation: ratio 0 at the head (cyan) -> 1 at the tail (BLUE1)
            ratio = i / max(snake_length - 1, 1)
            color = (
                int(BLUE1[0] + (CYAN[0] - BLUE1[0]) * (1 - ratio)),
                int(BLUE1[1] + (CYAN[1] - BLUE1[1]) * (1 - ratio)),
                int(BLUE1[2] + (CYAN[2] - BLUE1[2]) * (1 - ratio))
            )

            # Body segment (rounded, 1 px margin on each side)
            pygame.draw.rect(self.window, color, pygame.Rect(x+1, y+1, BLOCK_SIZE-2, BLOCK_SIZE-2), border_radius=3)

            # Head: full-cell cyan square with two white "eyes"
            if i == 0:
                pygame.draw.rect(self.window, CYAN, pygame.Rect(x, y, BLOCK_SIZE, BLOCK_SIZE), border_radius=4)
                pygame.draw.circle(self.window, WHITE, (x + 6, y + 6), 2)
                pygame.draw.circle(self.window, WHITE, (x + 14, y + 6), 2)

    def _draw_score(self):
        """Draw the score / length / frame panel in the top-right corner."""
        score_text = self.font.render(f"Score: {self.score}", True, GREEN)
        length_text = self.small_font.render(f"Length: {len(self.snake)}", True, CYAN)
        frame_text = self.small_font.render(f"Frame: {self.frame_iteration}", True, WHITE)

        # Panel position (top-right corner, 10 px from the right edge)
        panel_width = 180
        panel_x = self.w - panel_width - 10

        # Opaque dark panel with a cyan border, so the text stays readable over the board
        pygame.draw.rect(self.window, DARK_GRAY, (panel_x, 5, panel_width, 90), border_radius=5)
        pygame.draw.rect(self.window, CYAN, (panel_x, 5, panel_width, 90), 2, border_radius=5)

        self.window.blit(score_text, (panel_x + 10, 10))
        self.window.blit(length_text, (panel_x + 10, 45))
        self.window.blit(frame_text, (panel_x + 10, 70))

    def close(self):
        """Close the pygame window if one was opened; safe to call more than once."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            # Forget the dead window so a later render re-initialises pygame
            # instead of drawing on a closed display
            self.window = None
            self.clock = None
            self.font = None
            self.small_font = None

## 3. 🧠 Définition du réseau CNN personnalisé

In [None]:
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor for the single-channel Snake grid.

    Layers:
      - Conv2d(in_channels -> 32, kernel 4, stride 1, no padding) + ReLU
      - Conv2d(32 -> 64, kernel 4, stride 1, no padding) + ReLU
      - Flatten -> Linear(features_dim) + ReLU

    Each unpadded 4x4 convolution with stride 1 shrinks both spatial sides by 3.
    A 30x30 grid therefore reaches the linear layer as 64 x 24 x 24 activations.

    :param observation_space: channel-first ``Box`` observation space
    :param features_dim: size of the output feature vector
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]

        conv_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # Infer the flattened size by pushing one sample observation through the conv stack
        dummy_batch = th.as_tensor(observation_space.sample()[None]).float()
        with th.no_grad():
            flat_size = self.cnn(dummy_batch).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_size, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to ``features_dim`` features."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)

print("✅ CustomCNN défini !")

## 4. ⚙️ Configuration de l'entraînement

In [None]:
import os
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# === CONFIGURATION ===
N_ENVS = 8             # Number of parallel environments (Colab has ~2 CPUs, but it still works)
TIMESTEPS = 2_000_000  # Total number of environment steps (increase for better results)
SAVE_FREQ = 100_000    # Save a checkpoint every SAVE_FREQ total steps

# Output directories
MODELS_DIR = "checkpoints/PPO_CNN_COLAB"
LOG_DIR = "logs"

os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)

print("📊 Configuration :")
print(f"   - Environnements parallèles : {N_ENVS}")
print(f"   - Steps totaux : {TIMESTEPS:,}")
print(f"   - Sauvegarde tous les : {SAVE_FREQ:,} steps")
print(f"   - Dossier modèles : {MODELS_DIR}")
print(f"   - Dossier logs : {LOG_DIR}")

## 5. 🚀 Lancement de l'entraînement

In [None]:
# Check whether a GPU is available
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Device utilisé : {device}")
if device == "cuda":
    print(f"   GPU : {torch.cuda.get_device_name(0)}")

# Vectorized environments: SubprocVecEnv runs each of the N_ENVS copies in its own process
print(f"\n🔄 Création de {N_ENVS} environnements parallèles...")
env = make_vec_env(
    SnakeEnvCnn,
    n_envs=N_ENVS,
    vec_env_cls=SubprocVecEnv
)

# Periodic checkpoints. save_freq counts callback calls, and one call is one VecEnv step,
# i.e. N_ENVS environment steps. Dividing by N_ENVS writes a checkpoint roughly every
# SAVE_FREQ total steps.
checkpoint_callback = CheckpointCallback(
    save_freq=max(SAVE_FREQ // N_ENVS, 1),
    save_path=MODELS_DIR,
    name_prefix="snake_cnn"
)

# Use the custom CNN as the policy's feature extractor
policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=256),
)

# PPO model. n_steps is per environment, so each rollout collects n_steps * N_ENVS transitions.
print("🧠 Création du modèle PPO avec CNN...")
model = PPO(
    "CnnPolicy",
    env,
    verbose=1,
    tensorboard_log=LOG_DIR,
    learning_rate=0.0003,
    policy_kwargs=policy_kwargs,
    batch_size=256,
    n_steps=1024,
    gamma=0.99,
    device=device
)

print("\n" + "="*50)
print("🎮 DÉMARRAGE DE L'ENTRAÎNEMENT")
print("="*50)
print(f"L'IA va jouer {TIMESTEPS:,} coups...")
print("Cela peut prendre 30min à 2h selon la configuration.")
print("="*50 + "\n")

In [None]:
# Run the training
model.learn(
    total_timesteps=TIMESTEPS,
    callback=checkpoint_callback,
    progress_bar=True  # progress bar (uses tqdm/rich, installed with stable-baselines3[extra])
)

# Final save (the saved file gets a ".zip" extension)
final_path = f"{MODELS_DIR}/snake_cnn_final"
model.save(final_path)

print("\n" + "="*50)
print("✅ ENTRAÎNEMENT TERMINÉ !")
print("="*50)
print(f"Modèle final sauvegardé : {final_path}.zip")

# Shut down the environment subprocesses
env.close()

## 6. 📈 Visualisation des logs TensorBoard

In [None]:
# Load the TensorBoard extension in the notebook and display the training logs.
# "logs" is the LOG_DIR passed to the PPO model as tensorboard_log.
%load_ext tensorboard
%tensorboard --logdir logs

## 7. 📥 Télécharger le modèle entraîné

In [None]:
# List the saved models: step checkpoints in training order, then the final model
import glob
import re


def _checkpoint_sort_key(path):
    """Sort "..._<steps>_steps.zip" checkpoints by step count, then other files (e.g. the final model) by name."""
    match = re.search(r"_(\d+)_steps\.zip$", path)
    if match:
        return (0, int(match.group(1)), path)
    return (1, 0, path)


# A plain string sort would put "..._900000_steps" after "..._1000000_steps",
# so models[-1] would not be the latest checkpoint when no final model exists.
models = sorted(glob.glob(f"{MODELS_DIR}/*.zip"), key=_checkpoint_sort_key)

print("📁 Modèles disponibles :")
for i, m in enumerate(models):
    print(f"   [{i}] {m}")

In [None]:
# Download the last entry of the `models` list (the final model when it exists)
from google.colab import files

if models:
    files.download(models[-1])
    print(f"\n📥 Téléchargement de : {models[-1]}")
else:
    print("❌ Aucun modèle trouvé !")

In [None]:
# Optional: zip every checkpoint into one archive and download it
import shutil

from google.colab import files  # imported here too so this cell also works on its own

# make_archive returns the path of the archive it created
archive_path = shutil.make_archive("snake_models", 'zip', MODELS_DIR)
files.download(archive_path)
print("📥 Archive de tous les modèles téléchargée !")

## 8. 🧪 Test rapide du modèle (sans rendu)

In [None]:
# Load the trained model and evaluate it on a few games
from stable_baselines3 import PPO

N_TEST_GAMES = 10

# Load the final model saved at the end of training
test_model = PPO.load(f"{MODELS_DIR}/snake_cnn_final")

# Headless test environment (no rendering)
test_env = SnakeEnvCnn()

scores = []
try:
    for episode in range(N_TEST_GAMES):
        obs, _ = test_env.reset()
        done = False

        while not done:
            action, _ = test_model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = test_env.step(action)
            done = terminated or truncated

        # Read the apple count from the environment. With reward shaping, every step
        # that moves closer to the apple also yields a positive reward (+1), so
        # counting "reward > 0" would grossly overestimate the score.
        score = test_env.score
        scores.append(score)
        print(f"Partie {episode+1}/{N_TEST_GAMES} : Score = {score}")
finally:
    test_env.close()

print(f"\n📊 Score moyen sur {N_TEST_GAMES} parties : {sum(scores)/len(scores):.1f}")
print(f"   Meilleur score : {max(scores)}")
print(f"   Pire score : {min(scores)}")

---

## 📝 Notes

### Pour utiliser le modèle en local :

1. Téléchargez le fichier `.zip` du modèle
2. Placez-le dans `checkpoints/PPO_CNN/` de votre projet local
3. Lancez `python test_play_cnn.py`

### Pour améliorer les résultats :

- Augmentez `TIMESTEPS` (5M, 10M...)
- Ajustez `learning_rate` (0.0001, 0.00003...)
- Modifiez les récompenses dans l'environnement
- Ajustez les récompenses intermédiaires (le bonus/malus « se rapprocher / s'éloigner de la pomme » est déjà en place dans `step`)

---

**Auteur** : Samy EH - Projet SY23 - Janvier 2026