In [2]:
import random
import numpy as np
# import gym
import gymnasium as gym
import numpy as np
from pathlib import Path

# Snake game logic

In [None]:
from snake_game import FastSnakeGame, SnakeGame

In [1]:
# Probe which torch backends are usable on this machine: DirectML first, then CUDA.
import torch, sys
try:
    import torch_directml
    dml = torch_directml.device()
    # Allocate a tiny tensor on the device to check that it actually works.
    x = torch.randn(2,2, device=dml)
    print("DML OK:", x.device)
except Exception as e:
    # Best-effort probe: report the failure on stderr instead of stopping the notebook.
    print("DML failed:", e, file=sys.stderr)
print("CUDA available:", torch.cuda.is_available())

DML OK: privateuseone:0
CUDA available: False


# Gym env 

In [None]:

# Length of the feature vector built by SnakeEnv.feature_gen:
# 4 neighbor cells + 1 distance + 1 angle.
OBS_SHAPE = 6
# Shape tuple handed to the observation Box space.
SHAPE = (OBS_SHAPE, )

class SnakeEnv(gym.Env):
    """
    Gymnasium environment wrapping the snake game behind a compact feature observation.

    The observation has OBS_SHAPE values: the 4 grid cells around the head
    (up, left, down, right; out-of-bounds cells are reported as 3), the
    head-to-food distance divided by the grid size, and the head-to-food angle.

    step(action): applies the action to the game and returns
        (observation, reward, terminated, truncated, info).
    reset(): rebuilds the game and returns (observation, info).
    render(): asks the underlying game to draw itself.
    close(): asks the underlying game to quit.
    """

    def __init__(self, game_size:int=0, fast_game:bool=True):
        """
        Args:
            game_size: grid size forwarded to the game constructor and used to
                normalize distances.
            fast_game: use FastSnakeGame when True, SnakeGame otherwise.
        """
        super().__init__()
        self.action_space = gym.spaces.Discrete(4) # Output
        self.observation_space = gym.spaces.Box(low=-4, high=5, shape=SHAPE, dtype=np.float64)
        self.render_mode = "human"  # Default render mode
        # StableBaselines throws error if these are not defined
        self.spec = None
        # Gymnasium looks the supported modes up under the "render_modes" key.
        self.metadata = {"render_modes":["human"]}
        self.game_size = game_size
        self.SnakeGameHandler = SnakeGame if not fast_game else FastSnakeGame
        self._init()

    def _init(self):
        """Create a fresh game and reset the reward bookkeeping."""
        self.snake_game = self.SnakeGameHandler(self.game_size)
        food = np.array(self.snake_game.food)
        head = np.array(self.snake_game.snake[0])
        self._last_distance = self.euclidean_distance(head=head, food=food)
        self.previous_score = 0

        # Filled in by feature_gen on every observation.
        self.food_position = None
        self.snake_positions = None
        self.angle_to_food = None
        self.angle = None

    def seed(self, seed=42): # needed with make_vec_env
        """Seed numpy's global random generator."""
        np.random.seed(seed)

    def set_snake_and_food_position(self, raw_obs):
        """Locate the food (cells equal to 2) and the snake (cells equal to 1) in the raw grid."""
        self.food_position = np.argwhere(raw_obs == 2).flatten()
        self.snake_positions = np.argwhere(raw_obs == 1).flatten()

    def euclidean_distance(self, head=None, food=None):
        """
        Euclidean distance between the snake's head and the food, divided by the grid size.

        When `head` / `food` are omitted, the positions cached by
        set_snake_and_food_position are used.
        """
        head = head if head is not None else self.snake_positions.take((0, 1))
        food = food if food is not None else self.food_position
        new_distance = np.linalg.norm(head - food)
        # game_size defaults to 0; fall back to no normalization instead of
        # dividing by zero and producing inf/nan observations.
        scale = self.game_size if self.game_size else 1
        return new_distance/scale

    def get_neighbors(self, grid, head, out_of_bounds_value=3):
        """Return the 4 cells around `head` (up, left, down, right), using padding for the borders."""
        # Surround the grid with a 1-cell border holding out_of_bounds_value.
        padded = np.pad(grid, 1, mode='constant', constant_values=out_of_bounds_value)
        # Shift the indices to address the padded array.
        i, j = head
        pi, pj = i + 1, j + 1
        neighbors = np.array([
            padded[pi-1, pj],  # up
            padded[pi, pj-1],  # left
            padded[pi+1, pj],  # down
            padded[pi, pj+1],  # right
        ])
        return neighbors

    def angle_between_snake_head_and_food(self, head, food):
        """Return, as a 1-element array, arctan2 of the head-to-food offset (second axis over first axis)."""
        delta_x = food[0] - head[0]
        delta_y = food[1] - head[1]
        self.angle = np.array([np.arctan2(delta_y, delta_x)])
        # Keep the attribute declared in _init in sync with the legacy `angle` one.
        self.angle_to_food = self.angle
        return self.angle

    def feature_gen(self, raw_obs):
        """Build the observation vector [4 neighbors, distance, angle] from the raw grid."""
        self.set_snake_and_food_position(raw_obs)
        new_distance = self.euclidean_distance()
        distance = np.array([new_distance])

        # NOTE(review): np.argwhere lists cells in row-major order, so this is the
        # first snake cell found in the grid, which is only the head if the grid
        # encoding guarantees it — confirm against the game's raw_obs layout.
        head = self.snake_positions.take((0, 1))
        # Direct neighbors of the snake's head.
        neighbors = self.get_neighbors(raw_obs, head)
        angle = self.angle_between_snake_head_and_food(head, self.food_position)

        obs = np.concatenate([neighbors, distance, angle])
        return obs

    def get_reward(self, score, done):
        """
        Reward shaping: +10 when the score changed, -10 when the game ended,
        otherwise +1 if the head did not move away from the food and -1 if it did.
        """
        new_distance = self.euclidean_distance()
        if self.previous_score != score:
            reward = 10
            self.previous_score = score
        elif done:
            reward = -10
        else:
            reward =  1 if new_distance <= self._last_distance else -1
        self._last_distance = new_distance
        return reward

    @property
    def obs(self):
        """Current observation computed from the game's raw grid."""
        raw_obs = self.snake_game.raw_obs
        return self.feature_gen(raw_obs)

    def step(self, action):
        """Advance the game by one action and return the Gymnasium 5-tuple."""
        raw_obs, score, done, info = self.snake_game.step(action)
        obs = self.feature_gen(raw_obs)
        reward = self.get_reward(score, done)
        terminated = done  # done = self.snake_game.game_over
        truncated = False  # No time limit is enforced here.
        # Gymnasium requires `info` to be a dict.
        if not isinstance(info, dict):
            info = {}
        return obs, reward, terminated, truncated, info

    def reset(self, seed=42, options=None):
        """
        Start a new episode and return (observation, info).

        `options` is accepted for Gymnasium API compatibility and is not used.
        NOTE(review): a bare reset() reseeds numpy with 42 every time, so
        callers that omit `seed` restart the same numpy random stream on every
        episode — confirm this is intended.
        """
        super().reset(seed=seed)
        self.seed(seed)
        self._init()
        return self.obs, {}

    def render(self, mode="human"):
        """Draw the current game state; only the "human" mode is handled."""
        if mode == "human":
            self.snake_game.render()

    def close(self):
        """Shut the underlying game down."""
        self.snake_game.quit()


In [33]:
from stable_baselines3.common.env_checker import check_env

# Sanity-check that SnakeEnv follows the environment API expected by Stable Baselines3.
# NOTE(review): SnakeEnv() is built with its default game_size=0 here, and the env
# divides distances by game_size — confirm that 0 is a meaningful size for the game.
env = SnakeEnv()
check_env(env)

# Reinforcement learning

In [None]:
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCNN(BaseFeaturesExtractor):
    """
    Convolutional features extractor for image-like observations.

    Four Conv2d + MaxPool2d stages, an adaptive max-pool down to 4x4, a flatten
    and a linear projection to `features_dim`. The first convolution expects a
    single input channel.

    Args:
        observation_space: observation space forwarded to BaseFeaturesExtractor.
        features_dim: number of output features (also drives the conv widths).
    """
    def __init__(self, observation_space, features_dim=64):
        super().__init__(observation_space, features_dim)
        self.cnn = nn.Sequential(
                nn.Conv2d(1, features_dim//2, kernel_size=3, stride=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(features_dim//2, features_dim, kernel_size=3, stride=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(features_dim, features_dim, kernel_size=3, stride=1),
                # kernel_size is a required argument of nn.MaxPool2d; use the same
                # 2x2 / stride-2 pooling as the other stages.
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(features_dim, features_dim, kernel_size=3, stride=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                # Fixes the spatial size to 4x4 whatever the input resolution,
                # so the Linear layer below always sees features_dim * 16 inputs.
                nn.AdaptiveMaxPool2d((4,4)),
                nn.Flatten(),
                nn.Linear(features_dim * 4 * 4, features_dim),
        )

    def forward(self, X):
        """Run the convolutional stack on a batch and return the feature vectors."""
        out = self.cnn(X)
        return out

# Feature extractor for the Snake environment: three Linear + ReLU layers, no softmax
class LinearQNet(BaseFeaturesExtractor):
    """
    Fully connected features extractor.

    The observation is flattened and passed through three Linear layers of
    width `features_dim`, each followed by a ReLU.
    """
    def __init__(self, observation_space, features_dim=32):
        super().__init__(observation_space, features_dim)
        self.flatten = nn.Flatten()

        # Push one sampled observation (with a leading batch axis) through the
        # flatten layer to discover the input width of the first Linear layer.
        with torch.no_grad():
            probe = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.flatten(probe).shape[1]

        layers = [nn.Linear(n_flatten, features_dim), nn.ReLU()]
        for _ in range(2):
            layers += [nn.Linear(features_dim, features_dim), nn.ReLU()]
        self.linear = nn.Sequential(*layers)

    def forward(self, X):
        """Flatten the batch and apply the fully connected stack."""
        return self.linear(self.flatten(X))

    
def evaluate_model(model, eval_env, num_episodes=10, max_steps=1000):
    """
    Run `model` greedily on `eval_env` and return the average episode reward.

    Works with both a single Gymnasium env (reset -> (obs, info), step -> 5-tuple)
    and a Stable-Baselines3 style vectorized env (reset -> obs, step -> 4-tuple
    of per-env arrays). For a vectorized env the result is an array holding one
    average per sub-environment.

    Args:
        model: object exposing `predict(obs, deterministic=True)`.
        eval_env: environment to evaluate on.
        num_episodes: number of evaluation rounds to average over.
        max_steps: hard cap on the number of steps per round.

    Returns:
        Sum of the per-round rewards divided by `num_episodes`.
    """
    all_rewards = []
    for _ in range(num_episodes):
        obs = eval_env.reset()
        # If the environment returns a tuple (obs, info), keep obs only.
        if isinstance(obs, tuple):
            obs = obs[0]

        total_rewards = 0
        finished = None  # per-env mask of episodes that have already ended
        for _step in range(max_steps):
            action, _states = model.predict(obs, deterministic=True)
            step_result = eval_env.step(action)

            # Handle both return formats (gym vs gymnasium).
            if len(step_result) == 5:  # gymnasium: obs, reward, terminated, truncated, info
                obs, reward, terminated, truncated, info = step_result
                # logical_or works for plain bools and for arrays alike.
                done = np.logical_or(terminated, truncated)
            else:  # gym / VecEnv: obs, reward, done, info
                obs, reward, done, info = step_result

            done = np.asarray(done, dtype=bool)
            reward = np.asarray(reward, dtype=float)
            if finished is None:
                finished = np.zeros_like(done)

            # A vectorized env auto-resets finished sub-envs; ignore the rewards
            # they produce after their first episode has ended.
            total_rewards = total_rewards + np.where(finished, 0.0, reward)
            finished = finished | done
            if finished.all():
                break

        all_rewards.append(total_rewards)
    average_reward = sum(all_rewards) / num_episodes
    return average_reward


In [31]:
# Smoke test: run the LinearQNet extractor on one randomly sampled observation.
eval_env = SnakeEnv()
model2 = LinearQNet(eval_env.observation_space)
model2.forward(torch.as_tensor(eval_env.observation_space.sample()[None]).float())  # Add batch dimension for prediction

tensor([[0.0000, 0.9491, 0.0000, 0.0000, 0.0593, 0.0000, 0.3417, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.1041, 0.0000, 0.2752, 0.0000, 0.1235, 0.3380,
         0.0000, 0.4349, 0.0000, 0.0000, 0.2786, 0.0000, 0.0000, 0.4551, 0.0000,
         0.1392, 0.0339, 0.2765, 0.3157, 0.0000]], grad_fn=<ReluBackward0>)

# Training

The `make_vec_env` function from Stable Baselines3 creates vectorized environments. <br>
A vectorized environment steps several instances of an environment as one batch (sequentially in a single process with the default `DummyVecEnv`, or in parallel subprocesses with `SubprocVecEnv`), <br>
providing a more efficient way to collect experiences (states, actions, rewards, etc.) during training.


In [10]:
from stable_baselines3 import PPO,DQN, A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from stable_baselines3.common.utils import get_schedule_fn

In [None]:
def get_env(n_envs:int=5, use_frame_stack:bool=False, n_stack:int=4, game_size:int=30):
    """
    Build a vectorized SnakeEnv, optionally wrapped in a frame stack.

    Args:
        n_envs: number of SnakeEnv copies inside the vectorized env.
        use_frame_stack: wrap the env in VecFrameStack when True.
        n_stack: number of stacked observations (only used with use_frame_stack).
        game_size: grid size given to every SnakeEnv. Defaults to 30, the same
            default as ModelTrainer and ModelRender, rather than depending on a
            global defined elsewhere.

    Returns:
        The vectorized (and possibly frame-stacked) environment.
    """
    # No vec_env_cls is given, so make_vec_env uses its default vectorized env class.
    env = make_vec_env(
        lambda: SnakeEnv(game_size=game_size),
        n_envs=n_envs,
        seed=42
    )
    if use_frame_stack:
        env = VecFrameStack(env, n_stack=n_stack, channels_order='first')
    return env

class ModelTrainer:
    """
    Build, train and save a Stable-Baselines3 agent ("PPO" or "DQN") on SnakeEnv.

    Args:
        model_name: "PPO" or "DQN"; anything else raises ValueError.
        policy_kwargs: forwarded to the algorithm (e.g. a custom features extractor).
        game_size: grid size of the training and evaluation environments.
        n_envs: number of environments in the vectorized env.
        n_stack: number of stacked observations when use_frame_stack is True.
        use_frame_stack: wrap the environments in VecFrameStack.
        verbose: verbosity level forwarded to the algorithm.
    """
    def __init__(self, model_name:str,
                policy_kwargs=None,
                game_size:int=30,
                n_envs:int=5,
                n_stack:int=4, use_frame_stack:bool=False,
                verbose:int=2,
                ):
        self.model_name = model_name
        self.policy_kwargs = policy_kwargs
        self.game_size = game_size
        self.n_envs = n_envs
        self.n_stack = n_stack
        self.use_frame_stack = use_frame_stack
        self.verbose = verbose
        self.train_env = get_env(use_frame_stack=self.use_frame_stack, n_envs=n_envs, n_stack=n_stack, game_size=game_size)
        self.model = self.get_model(self.model_name, policy_kwargs=self.policy_kwargs)

    def get_model(self, model_name, policy_kwargs=None):
        """
        Instantiate the requested algorithm on the training environment.

        Raises:
            ValueError: if `model_name` is neither "PPO" nor "DQN".
        """
        if model_name == "PPO":
            # MlpPolicy for vector-based observations, CnnPolicy for image-based ones.
            model = PPO("MlpPolicy", self.train_env,
                        policy_kwargs=policy_kwargs,
                        verbose=self.verbose,
                        learning_rate=get_schedule_fn(0.0003),
                        n_steps=100)
        elif model_name == "DQN":
            model = DQN("MlpPolicy", self.train_env,
                        policy_kwargs=policy_kwargs,  # Enable custom features extractor
                        verbose=self.verbose,
                        learning_rate=1e-3,            # Fixed learning rate (no schedule needed for DQN)
                        buffer_size=10000,             # Size of replay buffer
                        learning_starts=1000,          # Start learning after this many steps
                        target_update_interval=500,    # Update target network every 500 steps
                        train_freq=4,                  # Train every 4 steps
                        gradient_steps=1,              # Number of gradient steps per training
                        exploration_fraction=0.3,      # Fraction of training for exploration
                        exploration_initial_eps=1.0,   # Initial exploration probability
                        exploration_final_eps=0.05)    # Final exploration probability
        else:
            raise ValueError(f"Model {model_name} is not supported.")
        return model

    def train(self, multiplicator:float=10):
        """
        Train for int(100_000 * multiplicator) timesteps, evaluating every 10_000 steps.

        Logs go to stdout and to TensorBoard files under "save_logs"
        (view them with: tensorboard --logdir=save_logs).
        """
        new_logger = configure("save_logs", ["stdout", "tensorboard"])
        self.model.set_logger(new_logger)
        # IMPORTANT: evaluate on the same kind of environment as the one used for training.
        eval_env = get_env(use_frame_stack=self.use_frame_stack, game_size=self.game_size, n_stack=self.n_stack, n_envs=self.n_envs)

        total_timesteps = int(100_000 * multiplicator)
        eval_interval = 10_000   # Increased interval since DQN learns differently
        num_eval_episodes = 5

        try:
            # Training loop with periodic evaluation
            for start in range(0, total_timesteps, eval_interval):
                # Never train past the requested budget on the last (or only) chunk.
                chunk = min(eval_interval, total_timesteps - start)
                # reset_num_timesteps=False keeps the timestep counter (and the
                # logs that depend on it) running across the successive learn() calls.
                self.model.learn(total_timesteps=chunk, reset_num_timesteps=False)
                avg_reward = evaluate_model(self.model, eval_env, num_episodes=num_eval_episodes)
                print(f"Evaluation average reward: {avg_reward}")
        finally:
            # Release the evaluation environments even if training is interrupted.
            eval_env.close()

    def save(self, name=""):
        """Save the model as "<model_name>_<name>_snake" (the library appends ".zip")."""
        self.model.save(f"{self.model_name}_{name}_snake")


In [35]:
# Custom features extractor settings, handed to ModelTrainer through policy_kwargs.
policy_kwargs = dict(features_extractor_class=LinearQNet)

# Use CnnPolicy if obs is image-like, MlpPolicy if obs is vector-like.
# NOTE(review): policy_kwargs is not passed to the call below, so this run does
# not use the custom extractor — confirm this is intended.
model = ModelTrainer("PPO", game_size=10, verbose=0)
model.train(1/10)  # Train the model
model.save("0")

Logging to save_logs


We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=100 and n_envs=5)


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4.01     |
|    ep_rew_mean     | -9       |
| time/              |          |
|    fps             | 2034     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 500      |
---------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 3.92       |
|    ep_rew_mean          | -8.73      |
| time/                   |            |
|    fps                  | 1265       |
|    iterations           | 2          |
|    time_elapsed         | 0          |
|    total_timesteps      | 1000       |
| train/                  |            |
|    approx_kl            | 0.01685164 |
|    clip_fraction        | 0.143      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.38      |
|    explained_variance   | -0.00183   |
|    learning_rate        | 0.0003     |
|   

In [None]:
# PPO with stacked observations and the custom features extractor from policy_kwargs;
# game_size is left at ModelTrainer's default of 30.
model = ModelTrainer("PPO", use_frame_stack=True, policy_kwargs=policy_kwargs)      
model.train()  # Train the model
model.save("custom_policy")

In [None]:
# model = ModelTrainer("DQN", use_frame_stack=False)
# model.train()  # Train the model
# model.save()

# Tensorboard

In [None]:
! tensorboard --logdir=path_to_save_logs

# Testing

In [None]:
class ModelRender:
    """
    Load a saved PPO or DQN agent and replay it on a rendered SnakeEnv.

    Args:
        name: file name of the saved model without ".zip"; must contain "PPO" or "DQN".
        use_frame_stack: whether the model was trained on stacked observations.
        game_size: grid size of the environment.
        n_stack: number of stacked observations (only used with use_frame_stack).

    Raises:
        FileNotFoundError: if "<cwd>/<name>.zip" does not exist.
        ValueError: if `name` contains neither "PPO" nor "DQN".
    """
    def __init__(self, name:str, use_frame_stack:bool=True, game_size:int=30, n_stack:int=4):
        # Validate everything before creating the game so nothing is left open on error.
        path = Path().cwd() / f"{name}.zip"
        if not path.exists():
            raise FileNotFoundError(f"Model file {path} does not exist. Please train the model first.")
        if "PPO" in name:
            algorithm = PPO
        elif "DQN" in name:
            algorithm = DQN
        else:
            raise ValueError(f"Model {name} is not supported for rendering.")

        self.n = n_stack
        self.use_frame_stack = use_frame_stack
        self.name = name
        self.env = SnakeEnv(game_size=game_size)
        # Loading with an env makes the library check that the saved model is
        # compatible with the environment's spaces.
        self.model = algorithm.load(name, env=get_env(use_frame_stack=use_frame_stack, game_size=game_size, n_stack=n_stack))

    def render(self, max_steps=None):
        """
        Play one episode with the loaded model, rendering and printing every step.

        Args:
            max_steps: optional cap on the number of steps; None (default) plays
                until the environment reports the end of the episode.
        """
        try:
            obs, _ = self.env.reset()
            terminated = False
            truncated = False
            self.env.render()
            step = 0
            stacked = None  # rolling buffer holding the last n observations
            while not (terminated or truncated):
                if max_steps is not None and step >= max_steps:
                    break

                model_obs = obs
                if self.use_frame_stack:
                    # Mirror VecFrameStack: the buffer starts as zeros, older frames
                    # are shifted towards the front and the newest one goes last.
                    frame = np.asarray(obs, dtype=float).ravel()
                    if stacked is None:
                        stacked = np.zeros(frame.size * self.n)
                    stacked = np.concatenate([stacked[frame.size:], frame])
                    model_obs = stacked

                action, _info = self.model.predict(model_obs, deterministic=True)
                print(f"Action taken: {action}")
                obs, reward, terminated, truncated, _ = self.env.step(action)
                step += 1
                print(f"Reward received: {reward}")
                print("distance to food:", obs.take(-2))
                print(f"step :{step}")
                self.env.render()
        finally:
            # Always shut the game down, even if prediction or stepping fails.
            self.env.close()

In [None]:
# Replay the PPO agent that the training cell saves as "PPO_0_snake":
# it is trained on a 10x10 grid without frame stacking.
renderer = ModelRender("PPO_0_snake", use_frame_stack=False, game_size=10)
renderer.render()

In [None]:
# Replay a PPO agent on a 30x30 grid (frame stacking on, ModelRender's default).
# NOTE(review): no cell in this notebook saves a model named "PPO_N30_snake" —
# confirm where this file comes from.
model = ModelRender("PPO_N30_snake", game_size=30)
model.render()

In [None]:
# Replay the frame-stacked PPO agent that the custom-policy training cell saves
# as "PPO_custom_policy_snake" (30x30 grid, stacking enabled by default).
renderer = ModelRender("PPO_custom_policy_snake", game_size=30)
renderer.render()

In [None]:

# Replay a DQN agent on a 10x10 grid without frame stacking.
# NOTE(review): the DQN training cell is commented out in this notebook, so
# "DQN__snake.zip" must already exist on disk for this cell to run.
model = ModelRender("DQN__snake", use_frame_stack=False, game_size=10)
model.render()