In [1]:

import gymnasium as gym
import numpy as np
from stable_baselines3 import DQN, PPO
from stable_baselines3.common.evaluation import evaluate_policy
import pygame
from gymnasium import spaces
import random

In [2]:
class SnakeEnv(gym.Env):
    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': 4}

    def __init__(self, grid_size=10, render_mode=None):
        super(SnakeEnv, self).__init__()

        self.grid_size = grid_size
        self.window_size = grid_size * 20
        self.render_mode = render_mode

        # Initialize Pygame if human rendering is required
        if self.render_mode == 'human':
            pygame.init()
            self.screen = pygame.display.set_mode((self.window_size, self.window_size))
            self.clock = pygame.time.Clock()

        # Action space: 0=UP, 1=RIGHT, 2=DOWN, 3=LEFT
        self.action_space = spaces.Discrete(4)

        # Observation space: [head_x, head_y, food_x, food_y]
        self.observation_space = spaces.Box(
            low=0, high=1,
            shape=(4,),  # Changed from 8 to 4
            dtype=np.float32
        )

    def reset(self, *, seed=None, options=None):
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        # Initialize snake in the middle
        self.snake = [(self.grid_size//2, self.grid_size//2)]
        self.direction = 1  # Start moving right
        self.food = self._place_food()
        self.steps = 0
        self.score = 0

        return self._get_obs(), {}

    def _place_food(self):
        while True:
            food = (random.randint(0, self.grid_size-1),
                   random.randint(0, self.grid_size-1))
            if food not in self.snake:
                return food

    def _get_obs(self):
        head = self.snake[0]
        return np.array([
            head[0] / self.grid_size,  # head x position
            head[1] / self.grid_size,  # head y position
            self.food[0] / self.grid_size,  # food x position
            self.food[1] / self.grid_size,  # food y position
        ], dtype=np.float32)

    def step(self, action):
        self.steps += 1
        reward = 0
        done = False

        # Update direction based on action
        self.direction = action

        # Get new head position
        head = self.snake[0]
        if action == 0:  # UP
            new_head = (head[0], head[1] - 1)
        elif action == 1:  # RIGHT
            new_head = (head[0] + 1, head[1])
        elif action == 2:  # DOWN
            new_head = (head[0], head[1] + 1)
        else:  # LEFT
            new_head = (head[0] - 1, head[1])

        # Check for collision with walls or self
        if (new_head[0] < 0 or new_head[0] >= self.grid_size or
            new_head[1] < 0 or new_head[1] >= self.grid_size or
            new_head in self.snake):
            return self._get_obs(), -10, True, False, {}

        # Move snake
        self.snake.insert(0, new_head)

        # Check if food is eaten
        if new_head == self.food:
            self.score += 1
            reward = 10
            self.food = self._place_food()
        else:
            self.snake.pop()
            # Small reward/penalty based on distance to food
            prev_dist = abs(head[0] - self.food[0]) + abs(head[1] - self.food[1])
            new_dist = abs(new_head[0] - self.food[0]) + abs(new_head[1] - self.food[1])
            reward = 0.1 if new_dist < prev_dist else -0.1

        # End episode if it's taking too long
        if self.steps >= 100:
            done = True

        return self._get_obs(), reward, done, False, {}

In [3]:



# Training function
def train_and_evaluate(model_class, env, total_timesteps, model_name):
    model = model_class("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=total_timesteps)

    # Evaluate the model
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
    print(f"{model_name} - Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

    return model, mean_reward



In [4]:
def main():
    # Create environment
    env = SnakeEnv(render_mode=None)

    # Training parameters
    total_timesteps = 100000

    # Train DQN
    print("Training DQN...")
    dqn_model = DQN(
        "MlpPolicy",
        env,
        learning_rate=1e-3,
        buffer_size=50000,
        learning_starts=1000,
        batch_size=64,
        gamma=0.99,
        verbose=1
    )

    dqn_model.learn(total_timesteps=total_timesteps)
    mean_reward_dqn, std_reward_dqn = evaluate_policy(dqn_model, env, n_eval_episodes=10)
    print(f"DQN Mean reward: {mean_reward_dqn:.2f} +/- {std_reward_dqn:.2f}")

    # Train PPO
    print("\nTraining PPO...")
    ppo_model = PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        verbose=1
    )

    ppo_model.learn(total_timesteps=total_timesteps)
    mean_reward_ppo, std_reward_ppo = evaluate_policy(ppo_model, env, n_eval_episodes=10)
    print(f"PPO Mean reward: {mean_reward_ppo:.2f} +/- {std_reward_ppo:.2f}")

    # Compare results
    print("\nFinal Comparison:")
    print(f"DQN average reward: {mean_reward_dqn:.2f}")
    print(f"PPO average reward: {mean_reward_ppo:.2f}")

if __name__ == "__main__":
    main()

Training DQN...
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 53.8     |
|    ep_rew_mean      | -10.1    |
|    exploration_rate | 0.98     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 10252    |
|    time_elapsed     | 0        |
|    total_timesteps  | 215      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 41.6     |
|    ep_rew_mean      | -8.77    |
|    exploration_rate | 0.968    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 9628     |
|    time_elapsed     | 0        |
|    total_timesteps  | 333      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 34.2     |
|    ep_rew_mea



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 31.8     |
|    ep_rew_mean     | -8.04    |
| time/              |          |
|    fps             | 1323     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 36.8        |
|    ep_rew_mean          | -7.43       |
| time/                   |             |
|    fps                  | 933         |
|    iterations           | 2           |
|    time_elapsed         | 4           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010150967 |
|    clip_fraction        | 0.0621      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.38       |
|    explained_variance   | 0.00988     |
|    learning_rate        | 0.

In [7]:
# Save DQN model
dqn_model.save("dqn_snake_model")

# Save PPO model
ppo_model.save("ppo_snake_model")


NameError: name 'dqn_model' is not defined

In [8]:
dqn_model = DQN.load("snake_dqn_model")
ppo_model = PPO.load("snake_ppo_model")


FileNotFoundError: [Errno 2] No such file or directory: 'snake_dqn_model.zip'