<a href="https://colab.research.google.com/github/GomathyDhanya/SnakeAgent/blob/main/Snake.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install stable-baselines3[extra] gymnasium shimmy

In [None]:

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
from stable_baselines3 import PPO
from IPython.display import clear_output

class SnakeLidarEnv(gym.Env):
    """Snake on a 10x10 grid observed through eight ray casts ("lidar").

    Actions (``Discrete(3)``):
        0 keeps the current heading, 1 advances the heading one step through
        the order 0 -> 1 -> 2 -> 3 -> 0, 2 moves it one step the other way.

    Observation (20 ``float32`` values in [0, 1]):
        * 8 values: ``1 / distance`` to the first wall or snake cell along
          each direction in ``_RAY_DIRECTIONS``;
        * 8 values: 1 if the food lies on that ray before the obstacle;
        * 4 values: one-hot encoding of the current heading.

    Rewards:
        +10 for eating food, -10 for hitting a wall or the body or for
        exceeding ``100 * len(snake)`` steps since the last reset, else 0.
    """

    # Heading index -> (dx, dy) applied to the head on every step.
    _MOVES = ((1, 0), (0, 1), (-1, 0), (0, -1))

    # Directions of the eight vision rays, in observation order.
    _RAY_DIRECTIONS = (
        (1, 0), (1, 1), (0, 1), (-1, 1),
        (-1, 0), (-1, -1), (0, -1), (1, -1),
    )

    def __init__(self):
        """Create the spaces, set the board size and start a first episode."""
        super().__init__()

        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=0, high=1, shape=(20,), dtype=np.float32)

        self.w = 10
        self.h = 10
        # Populate the episode attributes so the env is usable right away.
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode with a length-3 snake at the board centre.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``; it seeds
                ``self.np_random``, which drives food placement.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self.direction = 0
        self.head = np.array([self.w // 2, self.h // 2])
        self.snake = [self.head.copy(), self.head - [1, 0], self.head - [2, 0]]
        self.score = 0
        self.food = None
        self._place_food()
        self.frame_iteration = 0
        return self._get_state(), {}

    def _place_food(self):
        """Put the food on a uniformly chosen cell not covered by the snake.

        Draws from ``self.np_random`` so that ``reset(seed=...)`` makes the
        food sequence reproducible.

        Raises:
            RuntimeError: If the snake covers every cell of the board.
        """
        occupied = {(int(seg[0]), int(seg[1])) for seg in self.snake}
        free_cells = [
            (x, y)
            for x in range(self.w)
            for y in range(self.h)
            if (x, y) not in occupied
        ]
        if not free_cells:
            raise RuntimeError("cannot place food: the snake covers the whole board")
        choice = self.np_random.integers(len(free_cells))
        self.food = np.array(free_cells[choice])

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: 0 to keep the heading, 1 or 2 to turn (see class docs).
                Any other value is treated like 0.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always ``False``; the step-budget timeout is
            reported through ``terminated``.
        """
        self.frame_iteration += 1

        if action == 1:
            self.direction = (self.direction + 1) % 4
        elif action == 2:
            self.direction = (self.direction - 1) % 4

        self.head = self.head + np.array(self._MOVES[self.direction])

        out_of_bounds = not (
            0 <= self.head[0] < self.w and 0 <= self.head[1] < self.h
        )
        # The body is checked before it moves, so the current tail cell
        # still counts as an obstacle.
        hit_body = any(np.array_equal(self.head, seg) for seg in self.snake[1:])
        timed_out = self.frame_iteration > 100 * len(self.snake)

        if out_of_bounds or hit_body or timed_out:
            # The snake list is left untouched on a losing move.
            return self._get_state(), -10, True, False, {}

        terminated = False
        reward = 0

        self.snake.insert(0, self.head.copy())
        if np.array_equal(self.head, self.food):
            self.score += 1
            reward = 10
            if len(self.snake) == self.w * self.h:
                # Board completely filled: nowhere left for food, so the
                # episode ends here instead of searching for a free cell.
                terminated = True
            else:
                self._place_food()
        else:
            self.snake.pop()

        return self._get_state(), reward, terminated, False, {}

    def _get_state(self):
        """Build the 20-value observation described in the class docstring."""
        head_x, head_y = int(self.head[0]), int(self.head[1])
        food_cell = (int(self.food[0]), int(self.food[1]))
        # Built once per observation so each ray cell is an O(1) lookup.
        occupied = {(int(seg[0]), int(seg[1])) for seg in self.snake}

        vision_dist = []
        vision_food = []

        for dx, dy in self._RAY_DIRECTIONS:
            x, y = head_x, head_y
            dist = 0
            found_food = 0.0

            while True:
                x += dx
                y += dy
                dist += 1

                blocked = (
                    not (0 <= x < self.w and 0 <= y < self.h)
                    or (x, y) in occupied
                )
                if blocked:
                    break

                if (x, y) == food_cell:
                    found_food = 1.0

            # dist is at least 1, so the reciprocal is always in (0, 1].
            vision_dist.append(1.0 / dist)
            vision_food.append(found_food)

        dir_one_hot = [0.0, 0.0, 0.0, 0.0]
        dir_one_hot[self.direction] = 1.0

        return np.array(vision_dist + vision_food + dir_one_hot, dtype=np.float32)

    def render_ascii(self):
        """Print the score and the board: 'H' head, 'o' body, 'F' food."""
        board = [['.' for _ in range(self.w)] for _ in range(self.h)]
        board[self.food[1]][self.food[0]] = 'F'
        for i, pt in enumerate(self.snake):
            char = 'H' if i == 0 else 'o'
            if 0 <= pt[1] < self.h and 0 <= pt[0] < self.w:
                board[pt[1]][pt[0]] = char
        print(f"Score: {self.score}")
        print("+" + "-" * self.w + "+")
        for row in board:
            print("|" + "".join(row) + "|")
        print("+" + "-" * self.w + "+")

# Build the environment and train a PPO agent with an MLP policy for
# 200,000 environment steps (verbose=1 prints training progress).
env = SnakeLidarEnv()
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=200000)


# Evaluate the trained policy on the same env instance over 10 episodes and
# report the mean and standard deviation of the episode reward.
from stable_baselines3.common.evaluation import evaluate_policy
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"New Mean Reward: {mean_reward:.2f} +/- {std_reward:.2f}")


In [None]:

# Replay one episode with the trained policy, redrawing the board each step.
obs, _ = env.reset()
done = False
while not done:
    action, _ = model.predict(obs)
    obs, reward, terminated, truncated, _ = env.step(action)
    # An episode is over when either flag is set; reading only `terminated`
    # would loop forever on an env that ends by truncation.
    done = terminated or truncated
    clear_output(wait=True)
    env.render_ascii()