In [1]:
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.registration import register
from gymnasium.utils.env_checker import check_env

# Maak de custom gymnasium omgeving

In [12]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
import pygame
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random

class SnakeEnv(gym.Env):
    metadata = {"render_modes": ["human"]}

    def __init__(self, render_mode=None):
        super(SnakeEnv, self).__init__()
        self.grid_size = 20
        self.cell_size = 30
        self.action_space = spaces.Discrete(4)  # 0=UP, 1=RIGHT, 2=DOWN, 3=LEFT
        self.observation_space = spaces.Box(low=0, high=1, shape=(3, 7, 7), dtype=np.float32)
        
        self.view_radius = 3  # de afstand rond de head die de agent ziet
        self.render_mode = render_mode
        self.window = None
        self.clock = None

    def reset(self, seed=None, options=None):
        self.snake = [(5, 5), (5, 5), (5, 5)]
        self.direction = 1
        self.food = self._place_food()
        self.done = False
        if self.render_mode == "human":
            self._init_render()
        return self._get_obs(), {}

    def _place_food(self):
        while True:
            food = (random.randint(0, 19), random.randint(0, 19))
            if food not in self.snake:
                return food


    def _get_obs(self):
        view_size = 3  # 3 op elke kant van de kop → 7x7
        obs = np.zeros((3, 7, 7), dtype=np.float32)

        head_x, head_y = self.snake[0]

        for dy in range(-view_size, view_size + 1):
            for dx in range(-view_size, view_size + 1):
                x = head_x + dx
                y = head_y + dy
                view_x = dx + view_size
                view_y = dy + view_size

                if 0 <= x < self.grid_size and 0 <= y < self.grid_size:
                    if (x, y) == self.snake[0]:
                        obs[2, view_y, view_x] = 1  # Head
                    elif (x, y) in self.snake[1:]:
                        obs[0, view_y, view_x] = 1  # Body
                    elif (x, y) == self.food:
                        obs[1, view_y, view_x] = 1  # Food

        return obs

    def step(self, action):
        if self.done:
            return self._get_obs(), 0, True, False, {}

        if abs(action - self.direction) == 2:
            action = self.direction

        self.direction = action

        dx = [0, 1, 0, -1]
        dy = [-1, 0, 1, 0]
        head_x, head_y = self.snake[0]
        new_head = (head_x + dx[action], head_y + dy[action])

        if (new_head in self.snake or
            not 0 <= new_head[0] < self.grid_size or
            not 0 <= new_head[1] < self.grid_size):
            self.done = True
            return self._get_obs(), -10, True, False, {}

        self.snake.insert(0, new_head)

        fx, fy = self.food
        old_dist = abs(head_x - fx) + abs(head_y - fy)
        new_dist = abs(new_head[0] - fx) + abs(new_head[1] - fy)

        self.steps = 0
        self.steps += 1

        if new_head == self.food:
            reward = 50
            self.food = self._place_food()
            self.steps = 0  
        else:
            reward = (old_dist - new_dist) * 0.5 - 0.2
            self.snake.pop()

        if self.steps >= 100:
            reward= -10
            self.done = True
            return self._get_obs(), reward, True, False, {}


        if self.render_mode == "human":
            self.render()
            

        return self._get_obs(), reward, False, False, {}

    def _init_render(self):
        pygame.init()
        self.window = pygame.display.set_mode(
            (self.grid_size * self.cell_size, self.grid_size * self.cell_size))
        pygame.display.set_caption("Snake AI")
        self.clock = pygame.time.Clock()

    def render(self):
        if self.window is None:
            self._init_render()

        self.window.fill((0, 0, 0))
        for x, y in self.snake:
            pygame.draw.rect(
                self.window,
                (0, 255, 0),
                pygame.Rect(x * self.cell_size, y * self.cell_size, self.cell_size, self.cell_size)
            )

        fx, fy = self.food
        pygame.draw.rect(
            self.window,
            (255, 0, 0),
            pygame.Rect(fx * self.cell_size, fy * self.cell_size, self.cell_size, self.cell_size)
        )
                        # === Highlight agent's view radius ===
        if self.snake:
            head_x, head_y = self.snake[0]
            r = self.view_radius
            for dy in range(-r, r+1):
                for dx in range(-r, r+1):
                    nx, ny = head_x + dx, head_y + dy
                    if 0 <= nx < self.grid_size and 0 <= ny < self.grid_size:
                        pygame.draw.rect(
                            self.window,
                            (50, 50, 50),  # blauw = zichtgebied
                            pygame.Rect(nx * self.cell_size, ny * self.cell_size, self.cell_size, self.cell_size),
                            width=1  # alleen rand tekenen
                        )

        pygame.display.flip()
        self.clock.tick(10)

    def close(self):
        if self.window:
            pygame.quit()


In [27]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv
import gymnasium as gym

from stable_baselines3.common.callbacks import BaseCallback

class EvalAndRenderCallback(BaseCallback):
    def __init__(self, eval_env, render_freq=10_000, verbose=0):
        super().__init__(verbose)
        self.eval_env = eval_env
        self.render_freq = render_freq
        self.episodes_run = 0

    def _on_step(self) -> bool:
        if self.num_timesteps % self.render_freq == 0:
            obs, _ = self.eval_env.reset()
            done = False
            while not done:
                action, _ = self.model.predict(obs, deterministic=True)
                obs, reward, done, truncated, info = self.eval_env.step(action)
                self.eval_env.render()
        return True


# Maak een instance van je custom omgeving
env = SnakeEnv(render_mode=None)

# Check of alles compatibel is
#check_env(env)

# Wrap de omgeving (nodig voor Stable-Baselines3)
vec_env = DummyVecEnv([lambda: SnakeEnv(render_mode=None)])

# Initialiseer PPO
#model = PPO("MlpPolicy", vec_env, verbose=1, learning_rate=1e-4, n_steps=2048, batch_size=64, n_epochs=10,
#    tensorboard_log="./ppo_snake_tensorboard/",device='cpu')

model = PPO(
    "MlpPolicy",             # of "CnnPolicy" bij 3D observaties
    vec_env,
    verbose=1,
    learning_rate=2.5e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    tensorboard_log="./ppo_snake_tensorboard/",
    device="cpu"
)


eval_env = SnakeEnv(render_mode="human")
callback = EvalAndRenderCallback(eval_env=eval_env, render_freq=1000)

#model.learn(total_timesteps=100_000, callback=None)
model.learn(total_timesteps=100_000, callback=None)


model.save("ppo_snake_smallerview_5M_2.5learningrate")
# Om later opnieuw te laden:
# model = PPO.load("ppo_snake")

test_env = SnakeEnv(render_mode="human")
obs, _ = test_env.reset()
done = False

while not done:
    action, _states = model.predict(obs)
    obs, reward, done, truncated, info = test_env.step(action)
    test_env.render()

test_env.close()




Using cpu device
Logging to ./ppo_snake_tensorboard/PPO_23


TypeError: cannot unpack non-iterable NoneType object

In [13]:
import time
import numpy as np
from stable_baselines3 import PPO

# Recreate the environment (no training wrapper needed)
eval_env = SnakeEnv(render_mode="human")  # enable rendering

# Load your trained model
model = PPO.load("ppo_snake_smallerview_2M_2.5learningrate", env=eval_env,device='cpu')

num_episodes = 10
total_rewards = []
food_counts = []

highscore = 0

for episode in range(num_episodes):
    obs, _ = eval_env.reset()
    done = False
    episode_reward = 0
    step_count = 0
    food_eaten = 0

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        prev_snake_len = len(eval_env.snake)

        obs, reward, done, truncated, info = eval_env.step(action)
        episode_reward += reward
        step_count += 1

        # Check of de slang is gegroeid (dus voedsel gegeten)
        if len(eval_env.snake) > prev_snake_len:
            food_eaten += 1

        time.sleep(0.1)

    total_rewards.append(episode_reward)
    food_counts.append(food_eaten)
    highscore = max(highscore, food_eaten)

    print(f"\n✅ Episode {episode+1} finished.")
    print(f"🔸 Reward: {episode_reward:.2f}")
    print(f"🍎 Food eaten: {food_eaten}")
    print(f"🏆 Highscore so far: {highscore}")

# Na alle episodes
avg_reward = sum(total_rewards) / len(total_rewards)
avg_food = sum(food_counts) / len(food_counts)

print("\n==== Test Summary ====")
print(f"Average reward: {avg_reward:.2f}")
print(f"Average food per episode: {avg_food:.2f}")
print(f"Highscore (most food): {highscore}")


Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

✅ Episode 1 finished.
🔸 Reward: 83.10
🍎 Food eaten: 2
🏆 Highscore so far: 2

✅ Episode 2 finished.
🔸 Reward: -19.60
🍎 Food eaten: 0
🏆 Highscore so far: 2

✅ Episode 3 finished.
🔸 Reward: 41.00
🍎 Food eaten: 1
🏆 Highscore so far: 2

✅ Episode 4 finished.
🔸 Reward: 38.70
🍎 Food eaten: 1
🏆 Highscore so far: 2

✅ Episode 5 finished.
🔸 Reward: -21.60
🍎 Food eaten: 0
🏆 Highscore so far: 2

✅ Episode 6 finished.
🔸 Reward: 139.50
🍎 Food eaten: 3
🏆 Highscore so far: 3

✅ Episode 7 finished.
🔸 Reward: -14.60
🍎 Food eaten: 0
🏆 Highscore so far: 3

✅ Episode 8 finished.
🔸 Reward: -13.60
🍎 Food eaten: 0
🏆 Highscore so far: 3

✅ Episode 9 finished.
🔸 Reward: -10.60
🍎 Food eaten: 0
🏆 Highscore so far: 3

✅ Episode 10 finished.
🔸 Reward: -16.60
🍎 Food eaten: 0
🏆 Highscore so far: 3

==== Test Summary ====
Average reward: 20.57
Average food per episode: 0.70
Highscore (most food): 3


# Registreer de omgeving

In [36]:
%tensorboard --logdir='ppo_snake_tensorboard'
%reload_ext tensorboard

Reusing TensorBoard on port 6007 (pid 16380), started 0:00:50 ago. (Use '!kill 16380' to kill it.)