# Snake game logic

In [None]:
import random
import numpy as np

N = 10  # Default board edge length, in cells.


class SnakeGame:
    """Minimal Snake game on a square ``size`` x ``size`` board.

    Positions are ``(x, y)`` tuples and ``snake[0]`` is the head. The
    observation grid is indexed ``obs[y, x]`` and uses 0 for an empty cell,
    1 for a snake cell and 2 for the food cell.
    """

    # Action index -> (dx, dy) applied to the head: 0-Up, 1-Down, 2-Right, 3-Left.
    _DIRECTIONS = ((0, -1), (0, 1), (1, 0), (-1, 0))

    def __init__(self, size=N):
        """Create a game on a ``size`` x ``size`` board and start it."""
        self.size = size
        self.reset()

    def reset(self):
        """Restart with a one-cell snake in the board centre and fresh food."""
        self.snake = [(self.size // 2, self.size // 2)]
        self.score = 0
        # Must be cleared before placing food: _place_food() ends the game
        # itself when the board has no free cell.
        self.game_over = False
        self.food = None
        self._place_food()
        # step() returns self.obs on its game-over paths, so it has to exist
        # even if the very first move ends the game.
        self.obs = self.get_observation()

    def _place_food(self):
        """Move the food to a random cell that is not occupied by the snake.

        When the snake already covers every cell there is nowhere to put the
        food; the food is then removed and the game is marked as over instead
        of sampling forever.
        """
        if len(self.snake) >= self.size * self.size:
            self.food = None
            self.game_over = True
            return
        # Rejection sampling; right after eating, the old food cell is the
        # head, so at least one re-roll always happens in that case.
        while self.food is None or self.food in self.snake:
            self.food = (random.randint(0, self.size - 1), random.randint(0, self.size - 1))

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: Index into the direction table
                (0-Up, 1-Down, 2-Right, 3-Left).

        Returns:
            Tuple ``(obs, score, game_over, info)`` where ``obs`` is the grid
            from :meth:`get_observation` and ``info`` is an empty dict. Once
            the game is over the last observation is returned unchanged.
        """
        if self.game_over:
            return self.obs, self.score, self.game_over, {}

        dx, dy = self._DIRECTIONS[action]
        head_x, head_y = self.snake[0]
        new_head = (head_x + dx, head_y + dy)

        # The move is fatal if it hits the body or leaves the board.
        # Note: the tail cell still counts as body here, as in the original rules.
        out_of_bounds = not (0 <= new_head[0] < self.size and 0 <= new_head[1] < self.size)
        if new_head in self.snake or out_of_bounds:
            self.game_over = True
            return self.obs, self.score, self.game_over, {}

        self.snake.insert(0, new_head)

        if new_head == self.food:
            # Eating: keep the tail so the snake grows by one cell.
            self.score += 1
            self._place_food()
        else:
            self.snake.pop()

        self.obs = self.get_observation()
        return self.obs, self.score, self.game_over, {}

    def get_observation(self):
        """Return the board as a ``(size, size)`` array (0 empty, 1 snake, 2 food)."""
        obs = np.zeros((self.size, self.size))
        for x, y in self.snake:
            obs[y, x] = 1
        # No food exists once the snake has filled the whole board.
        if self.food is not None:
            x, y = self.food
            obs[y, x] = 2
        return obs

    def render(self):
        """Print the observation grid row by row, followed by a blank line."""
        for line in self.get_observation():
            print(line, end="\n")
        print()

    def quit(self):
        """No-op; there are no resources to release."""
        pass

# Gym env 

In [None]:
import gym
from gym import spaces
import numpy as np

class SnakeEnv(gym.Env):
    """Gym environment wrapping :class:`SnakeGame`.

    step(action): takes an action, updates the game state, and returns the new
        observation, the reward gained (or lost), whether the game is over
        (done), and the info dict from the game.
    reset(): starts a new game and returns its initial observation. It is used
        at the beginning of every episode.
    render(): prints the current board.
    close(): performs any necessary cleanup.

    The observation handed to the agent is a ``(1, N, N)`` map whose cell
    ``[0, row, col]`` holds the Euclidean distance from that cell to the food.
    """

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)  # Output
        # Observations are distances, i.e. real-valued and bounded by the
        # board diagonal, so the space is declared as float32 rather than as
        # the raw 0/1/2 integer grid.
        self.observation_space = spaces.Box(
            low=0.0,
            high=float(np.sqrt(2.0 * (N - 1) ** 2)),
            shape=(1, N, N),
            dtype=np.float32,
        )
        self.snake_game = None
        self.previous_score = 0
        self.last_distance = np.inf
        # StableBaselines throws error if these are not defined
        self.spec = None
        self.metadata = None

    def seed(self, seed=42):  # needed with make_vec_env
        """Seed the random number generators used by the environment."""
        # SnakeGame places food with the stdlib ``random`` module, so seeding
        # NumPy alone would leave the game itself unseeded.
        random.seed(seed)
        np.random.seed(seed)

    def euclidean_distance_centroid(self, obs):
        """Return the distance between the snake's centroid and the food.

        Args:
            obs: Raw game grid (1 marks snake cells, 2 marks the food cell).

        Returns:
            Euclidean distance in cells, or ``0.0`` when the grid contains no
            food cell.
        """
        snake_positions = np.argwhere(obs == 1)
        food_positions = np.argwhere(obs == 2)
        if food_positions.size == 0 or snake_positions.size == 0:
            return 0.0
        snake_centroid = np.mean(snake_positions, axis=0)
        return np.linalg.norm(snake_centroid - food_positions[0])

    def feature_gen_euclidean_distance_to_food(self, raw_obs):
        """Build the agent observation: per-cell distance to the food.

        Args:
            raw_obs: Raw 2-D game grid (2 marks the food cell).

        Returns:
            float32 array of shape ``(1, rows, cols)``; entry ``[0, r, c]`` is
            the Euclidean distance from cell ``(r, c)`` to the food cell. All
            zeros when the grid contains no food cell.
        """
        rows, cols = raw_obs.shape
        food_positions = np.argwhere(raw_obs == 2)
        if food_positions.size == 0:
            return np.zeros((1, rows, cols), dtype=np.float32)

        # argwhere yields (row, col); keep the same axis order for the index
        # grids so the zero of the map sits exactly on the food cell.
        food_row, food_col = food_positions[0].astype(float)
        row_idx, col_idx = np.indices((rows, cols))
        distances = np.sqrt((row_idx - food_row) ** 2 + (col_idx - food_col) ** 2)
        return distances.reshape((1, rows, cols)).astype(np.float32)

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done, info)``.

        Reward: +100 when the score changed (food eaten), -10 when the game
        ended, otherwise +0.1 if the snake centroid got closer to the food
        than on the previous step and -0.01 if it did not.
        """
        raw_obs, score, done, info = self.snake_game.step(action)

        # Calculate the Euclidean distance between the snake and the food
        new_distance = self.euclidean_distance_centroid(raw_obs)

        # Check if the snake has eaten food and update the reward
        if self.previous_score != score:
            reward = 100
            self.previous_score = score
        elif done:
            reward = -10
        else:
            reward = 1 / 10 if new_distance < self.last_distance else -1 / 100
        self.last_distance = new_distance

        obs = self.feature_gen_euclidean_distance_to_food(raw_obs)
        return obs, reward, done, info

    def reset(self):
        """Start a new game and return its initial ``(1, N, N)`` observation."""
        self.snake_game = SnakeGame()
        # Per-episode reward bookkeeping must not leak from the previous
        # episode; a stale previous_score would look like "food eaten" on the
        # first step of the new game.
        self.previous_score = 0
        self.last_distance = np.inf
        # Return the same feature map that step() returns, not the raw grid.
        return self.feature_gen_euclidean_distance_to_food(self.snake_game.get_observation())

    def render(self, mode='human'):
        """Print the board when ``mode`` is ``'human'``; other modes do nothing."""
        if mode == 'human':
            self.snake_game.render()

    def close(self):
        """Forward the cleanup request to the underlying game."""
        self.snake_game.quit()


# Reinforcement learning

In [None]:

import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


# The make_vec_env function from Stable Baselines 3 is used to create vectorized environments. 
# Vectorized environments allow you to run multiple instances of an environment in parallel, 
# providing a more efficient way to collect experiences (states, actions, rewards, etc.) during training.


class CustomCNN(BaseFeaturesExtractor):
    """Small convolutional features extractor for channel-first grid observations.

    Three unpadded, stride-1 convolutions (kernel sizes 3, 2, 2) are followed
    by a flatten and a linear layer that projects to ``features_dim`` features.
    """

    def __init__(self, observation_space, features_dim: int=128):
        """Build the network for ``observation_space``.

        Args:
            observation_space: Space whose samples have the channel count as
                their first dimension.
            features_dim: Size of the feature vector produced by ``forward``.
        """
        super().__init__(observation_space, features_dim)

        in_channels = observation_space.sample().shape[0]
        conv_stack = [
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=2, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=2, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Measure the flattened size with one dummy forward pass instead of
        # deriving it by hand from the kernel sizes.
        with torch.no_grad():
            probe = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(probe).shape[1]

        projection = nn.Linear(n_flatten, features_dim)
        self.linear = nn.Sequential(projection, nn.ReLU())

    # NOTE(review): the action head is not built here; presumably the PPO
    # policy adds its final layer from the env's `action_space` - confirm.
    def forward(self, observations):
        """Return the ``features_dim``-sized features for a batch of observations."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)

    
def evaluate_model(model, eval_env, num_episodes=10, max_steps=1000):
    """Run the model greedily on ``eval_env`` and return the mean episode reward.

    Args:
        model: Object with ``predict(obs, deterministic=True)`` returning
            ``(action, state)``.
        eval_env: Environment with ``reset()``, a 4-tuple ``step(action)`` and
            an ``observation_space.shape`` attribute.
        num_episodes: Number of episodes to average over.
        max_steps: Upper bound on steps per episode, so an agent that never
            finishes cannot stall the evaluation.

    Returns:
        Sum of per-episode total rewards divided by ``num_episodes``; ``0.0``
        when ``num_episodes`` is not positive.
    """
    if num_episodes <= 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    # Feed the model observations in the shape the env declares, rather than
    # relying on a module-level board size.
    obs_shape = eval_env.observation_space.shape

    episode_totals = []
    for episode in range(num_episodes):
        print(f"evaluation {episode=}")
        obs = eval_env.reset()
        total_rewards = 0
        for _ in range(max_steps):
            action, _states = model.predict(np.reshape(obs, obs_shape), deterministic=True)
            obs, reward, done, _ = eval_env.step(action)
            total_rewards += reward
            if done:
                break
        episode_totals.append(total_rewards)

    return sum(episode_totals) / num_episodes


# Training

In [None]:
%%time   
from stable_baselines3 import PPO,DQN, A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.logger import configure
from stable_baselines3.common.utils import get_schedule_fn

# Two environment copies are used for collecting training rollouts; a
# separate single environment is kept for evaluation.
train_env = make_vec_env(lambda: SnakeEnv(), n_envs=2)
eval_env = SnakeEnv()


# Use the CNN features extractor defined in this notebook (the previously
# referenced `LinearQNet` is not defined anywhere in it).
policy_kwargs = dict(
    features_extractor_class=CustomCNN,
)

learning_rate_schedule = get_schedule_fn(0.0003)
model = PPO("CnnPolicy", train_env, policy_kwargs=policy_kwargs, verbose=2, learning_rate=learning_rate_schedule)

# new_logger = configure("path_to_save_logs", ["stdout", "tensorboard"])
# model.set_logger(new_logger) # Run TensorBoard in a terminal: tensorboard --logdir=path_to_save_logs


total_timesteps = 200_000
eval_interval = 50_000
# Episodes per evaluation round. This keeps the 10 episodes the loop
# actually ran before; the old value of 100 was never passed to evaluate_model.
num_eval_episodes = 10

# Training loop with periodic evaluation
for _ in range(0, total_timesteps, eval_interval):
    # reset_num_timesteps=False keeps the timestep counter (and logging)
    # running across chunks instead of restarting from zero every call.
    model.learn(total_timesteps=eval_interval, reset_num_timesteps=False)
    avg_reward = evaluate_model(model, eval_env, num_episodes=num_eval_episodes)
    print(f"Evaluation average reward: {avg_reward}")

model.save("ppo_snake")

# Testing

In [None]:
from stable_baselines3 import PPO 


# Play one episode interactively: the board is printed after every move and
# the loop waits for Enter before asking the model for the next action.
env = SnakeEnv()
# model = PPO.load("ppo_snake", env=env)

obs = env.reset()
env.render()
done = False
while True:
    input("press enter to continue")
    # The model is queried with a (1, N, N) view of the current observation.
    action, _info = model.predict(np.reshape(obs, (1, N, N)), deterministic=True)
    obs, reward, done, _ = env.step(action)
    env.render()
    if done:
        break
    