**1. Import Dependencies**

In [146]:
# Import Gym Stuff
import gym
from gym import Env
from gym.spaces import Discrete, Box, Dict, Tuple, MultiBinary, MultiDiscrete

# Import helpers
import numpy as np
import random
import os

# Import stable baselines stuff
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.vec_env import DummyVecEnv 
from stable_baselines3.common.evaluation import evaluate_policy

In [147]:
#Import Game Stuff
from snakeGame import SnakeGame

**2. Building the ENV**

In [148]:
class SnakeENV(Env):
    """Gym environment that wraps a ``SnakeGame`` instance.

    Actions are the four integers ``0..3`` and are translated to the direction
    strings ``'UP'``, ``'DOWN'``, ``'LEFT'`` and ``'RIGHT'`` before being handed
    to the game.

    Observations are dictionaries with three float32 entries:

    * ``'head_pos'``  -- taken from ``game.snake.head_pos``, shape ``(2,)``
    * ``'fruit_pos'`` -- taken from ``game.fruit.pos``, shape ``(2,)``
    * ``'length'``    -- ``len(game.snake.body)``, shape ``(1,)``

    An episode ends when the game reports it is done or when more than
    ``MAX_EPISODE_STEPS`` calls to ``step`` have been made since the last
    ``reset``.
    """

    # Number of step() calls after which an episode is cut off.
    MAX_EPISODE_STEPS = 6600

    def __init__(self) -> None:
        self.game = SnakeGame()

        # Positions are bounded by the game window; both position entries
        # share the same bounds.
        pos_low = np.array([0, 0], dtype=np.float32)
        pos_high = np.array(
            [self.game.window.width, self.game.window.height], dtype=np.float32
        )

        self.action_space = Discrete(4)
        # Built from an ordered list of pairs (as before) so the key order of
        # the space stays head_pos, fruit_pos, length.
        self.observation_space = Dict([
            ('head_pos', Box(low=pos_low, high=pos_high, dtype=np.float32)),
            ('fruit_pos', Box(low=pos_low, high=pos_high, dtype=np.float32)),
            ('length', Box(low=np.array([0], dtype=np.float32),
                           high=np.array([np.inf], dtype=np.float32),
                           dtype=np.float32)),
        ])

        self.num_envs = 1

        # Steps taken in the current episode.
        self.time_run = 0

        # Build an initial observation so step() works even if the caller
        # has not invoked reset() yet.
        self.state = self._get_obs()

    def _get_obs(self):
        """Read the current observation dictionary from the wrapped game.

        Values are float32 arrays whose shapes match ``observation_space``;
        ``'length'`` is wrapped in a one-element array to match its
        ``(1,)``-shaped ``Box``.
        """
        return {
            'head_pos': np.asarray(self.game.snake.head_pos, dtype=np.float32),
            'fruit_pos': np.asarray(self.game.fruit.pos, dtype=np.float32),
            'length': np.array([len(self.game.snake.body)], dtype=np.float32),
        }

    def isEpisodeLenght(self):
        """Count one more step and report whether the step budget is exhausted.

        Returns:
            bool: ``True`` once more than ``MAX_EPISODE_STEPS`` steps have been
            counted since the last ``reset``.
        """
        self.time_run += 1
        return self.time_run > self.MAX_EPISODE_STEPS

    def step(self, action):
        """Advance the game by one action.

        Args:
            action: Integer in ``0..3``; see ``action_to_direction``.

        Returns:
            tuple: ``(observation, reward, done, info)``. ``done`` is true when
            the game reports the episode finished *or* the step budget has run
            out. When only the step budget ended the episode,
            ``info['TimeLimit.truncated']`` is set to ``True``.
        """
        direction = self.action_to_direction(action)
        reward, game_done = self.game.main(direction, render=False)

        # Combine both termination sources; previously the time-limit flag
        # was overwritten by the game's flag and therefore never took effect.
        timed_out = self.isEpisodeLenght()
        done = bool(game_done) or timed_out

        # Refresh the observation so callers see the post-action game state
        # rather than the one captured at reset().
        self.state = self._get_obs()

        info = {}
        if timed_out and not game_done:
            info['TimeLimit.truncated'] = True

        return self.state, reward, done, info

    def render(self, mode=None):
        """Delegate drawing to the wrapped game; ``mode`` is accepted but unused."""
        self.game.render()

    def reset(self):
        """Start a new game, clear the step counter and return the first observation."""
        self.game = SnakeGame()
        self.time_run = 0
        self.state = self._get_obs()
        return self.state

    def action_to_direction(self, action):
        """Translate an action id (0-3) to its direction string.

        Returns ``None`` when ``action`` is not one of the four known ids.
        """
        if action == 0:
            return 'UP'
        elif action == 1:
            return 'DOWN'
        elif action == 2:
            return 'LEFT'
        elif action == 3:
            return 'RIGHT'

    def direction_to_action(self, direction):
        """Translate a direction string back to its action id (0-3).

        Returns ``None`` when ``direction`` is not one of the four known strings.
        """
        if direction == 'UP':
            return 0
        elif direction == 'DOWN':
            return 1
        elif direction == 'LEFT':
            return 2
        elif direction == 'RIGHT':
            return 3


In [149]:
# Smoke-test the environment: play a few episodes with actions drawn from
# env.action_space.sample() and print the total reward of each one.
env = SnakeENV()
episodes = 5

for episode in range(1, episodes + 1):
    state = env.reset()
    done, score = False, 0

    # Keep stepping until the environment reports the episode is over.
    while True:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score += reward
        if done:
            break

    print(f'Episode: {episode} Score: {score}')

env.close()



Episode: 1 Score: -1
Episode: 2 Score: -1
Episode: 3 Score: -1
Episode: 4 Score: -1
Episode: 5 Score: -1


**3. Train a PPO model**

In [150]:
# Discard the single environment used for the random rollout and replace it
# with four SnakeENV copies created through make_vec_env for training.
del env

from stable_baselines3.common.env_util import make_vec_env

env = make_vec_env(env_id=SnakeENV, n_envs=4)

In [151]:
# Directory passed to PPO as its TensorBoard log location.
log_path = "./Training/Logs/"

# "MultiInputPolicy" is used because SnakeENV's observation space is a Dict.
model = PPO(
    policy="MultiInputPolicy",
    env=env,
    verbose=1,
    tensorboard_log=log_path,
)

Using cpu device


In [152]:
# Train the PPO agent for 100,000 environment timesteps.
model.learn(total_timesteps=100000)

Logging to ./Training/Logs/PPO_9


In [None]:
# Persist the trained agent to disk under the models directory.
ppo_path = "./Training/Models/PPO_Snake_Model_ExtraPoints"
model.save(path=ppo_path)

In [None]:
# Drop the in-memory model; the evaluation section re-creates it with PPO.load.
del model

**4. Eval and Test**

In [None]:
# Load the checkpoint that the save cell of this notebook writes. The path
# previously pointed at "./Training/Models/PPO_Snake_Model", a file this
# notebook never creates, so a fresh Restart & Run All could not rely on it.
ppo_path = "./Training/Models/PPO_Snake_Model_ExtraPoints"

# Evaluate on a fresh, single SnakeENV instance.
env = SnakeENV()
model = PPO.load(ppo_path, env)

In [None]:
# Run the loaded policy for 10 evaluation episodes with rendering enabled.
# This is the cell's last expression, so the notebook displays whatever
# evaluate_policy returns.
evaluate_policy(model, env, n_eval_episodes=10, render=True)