**1. Import Dependencies**

In [1]:
# Import Gym Stuff
import gym
from gym import Env
from gym.spaces import Discrete, Box, Dict, Tuple, MultiBinary, MultiDiscrete

# Import helpers
import numpy as np
import random
import os

# Import stable baselines stuff
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

In [2]:
#Import Game Stuff
from snakeGame import SnakeGame

pygame 2.1.2 (SDL 2.0.18, Python 3.8.15)
Hello from the pygame community. https://www.pygame.org/contribute.html


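**2. Build the Snake Gym Environment**

`SnakeENV` wraps `SnakeGame` in the classic Gym API. It has 4 discrete actions (UP/DOWN/LEFT/RIGHT). Each observation is a dict holding the position relative to the fruit plus four free-space flags.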

In [3]:
class SnakeENV(Env):
    """Gym environment that exposes ``SnakeGame`` to RL agents.

    Uses the classic (pre-0.26) Gym API: ``reset()`` returns only the
    observation and ``step()`` returns ``(obs, reward, done, info)``.

    Action space: ``Discrete(4)`` -> 0=UP, 1=DOWN, 2=LEFT, 3=RIGHT.

    Observation space (``Dict``):
        relative_to_fruit: float32 vector of length 2 (declared bounds [-1, 1]),
            taken from ``SnakeGame.relative_to_fruit()``.
        free_space: ``MultiBinary(4)`` flags from ``snake.get_free_space(window)``.

    Reward and ``done`` come directly from ``SnakeGame.main``.
    """

    # Must be a class attribute: gym / SB3 read ``env.metadata`` to discover the
    # supported render modes. A local variable inside __init__ is simply discarded.
    metadata = {'render.modes': ['human']}

    # Position in this tuple == discrete action id.
    _DIRECTIONS = ('UP', 'DOWN', 'LEFT', 'RIGHT')
    _DIRECTION_TO_ACTION = {name: idx for idx, name in enumerate(_DIRECTIONS)}

    # Step cap checked by isEpisodeLenght (not currently called from step()).
    _MAX_EPISODE_STEPS = 6600

    def __init__(self, num_envs=1) -> None:
        """Create the underlying game and declare the action/observation spaces.

        Args:
            num_envs: Stored as ``self.num_envs``; not otherwise used by this
                class (vectorisation is done by SB3 wrappers such as make_vec_env).
        """
        super().__init__()

        self.num_envs = num_envs
        self.game = SnakeGame()

        self.action_space = Discrete(len(self._DIRECTIONS))
        self.observation_space = Dict([
            ('relative_to_fruit', Box(low=np.array([-1, -1]), high=np.array([1, 1]), dtype=np.float32)),
            ('free_space', MultiBinary(4)),
        ])

        self.time_run = 0

    def isEpisodeLenght(self):
        """Advance the step counter and report whether it exceeded the episode cap.

        Not called by ``step()`` at present, so episodes end only when
        ``SnakeGame.main`` reports ``done``.
        """
        self.time_run += 1
        return self.time_run > self._MAX_EPISODE_STEPS

    def step(self, action, render=False):
        """Apply one action and advance the game by one tick.

        Args:
            action: Discrete action id in [0, 3].
            render: Forwarded to ``SnakeGame.main``.

        Returns:
            Tuple ``(observation, reward, done, info)``; ``info`` is always empty.

        Raises:
            ValueError: if ``action`` is not a valid action id.
        """
        direction = self.action_to_direction(action)
        reward, done = self.game.main(direction, render=render, isAI=True)
        self.state = self._get_obs()
        return self.state, reward, done, {}

    def render(self, mode='human'):
        """Draw the current game state by delegating to ``SnakeGame.render``."""
        self.game.render(mode=mode)

    def reset(self):
        """Start a new episode with a fresh ``SnakeGame``.

        The snake is re-positioned via ``snake.randomPos(window)`` and the step
        counter is cleared.

        Returns:
            The initial observation dict.
        """
        self.game = SnakeGame()
        self.game.snake.randomPos(self.game.window)
        self.time_run = 0
        self.state = self._get_obs()
        return self.state

    def _get_obs(self):
        """Build the observation dict that matches ``self.observation_space``."""
        return {
            'relative_to_fruit': np.array(self.game.relative_to_fruit(), dtype=np.float32),
            # MultiBinary spaces use int8; cast so the sample matches the space dtype.
            'free_space': np.asarray(self.game.snake.get_free_space(self.game.window), dtype=np.int8),
        }

    def action_to_direction(self, action):
        """Map a discrete action id (0-3) to the direction string ``SnakeGame`` expects.

        Raises:
            ValueError: if ``action`` is not one of 0, 1, 2, 3.
        """
        # `in range(...)` also accepts numpy integer scalars (e.g. from action_space.sample()).
        if action not in range(len(self._DIRECTIONS)):
            raise ValueError(
                f'Invalid action {action!r}; expected an integer in [0, {len(self._DIRECTIONS) - 1}]'
            )
        return self._DIRECTIONS[int(action)]

    def direction_to_action(self, direction):
        """Inverse of ``action_to_direction``; returns ``None`` for an unknown direction."""
        return self._DIRECTION_TO_ACTION.get(direction)


In [4]:
# Smoke test: play a few episodes with uniformly random actions to check the env end to end.
env = SnakeENV()
episodes = 5

for episode in range(1, episodes + 1):
    state = env.reset()
    score = 0
    while True:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score += reward
        if done:
            break
    print(f'Episode: {episode} Score: {score}')

env.close()



Episode: 1 Score: -500
Episode: 2 Score: -500
Episode: 3 Score: -500
Episode: 4 Score: -500
Episode: 5 Score: -500


**3. Train a PPO model**

In [5]:
# Vectorised training env: make_vec_env builds `cpu_cores` independent SnakeENV
# instances behind one VecEnv (DummyVecEnv by default, i.e. stepped sequentially
# in this process). Rebinding `env` replaces the smoke-test env, so no `del` is
# needed and this cell also runs when the smoke-test cell was skipped.
cpu_cores = 4
env = make_vec_env(SnakeENV, n_envs=cpu_cores)

In [6]:
log_path = "./Training/Logs/"
model = PPO(
    "MultiInputPolicy",  # the observation space is a Dict, which needs the multi-input policy
    env,
    verbose=1,
    tensorboard_log=log_path,
)

Using cpu device


In [7]:
#model.learn(total_timesteps=100000)

In [8]:
timesteps = 500_000

# Train in four chunks and checkpoint after each one.
# Only the first chunk resets the step counter; later chunks pass
# reset_num_timesteps=False so SB3 keeps counting from where it stopped and
# keeps logging to the same TensorBoard run instead of starting a new PPO_<n>
# run at step 0. File names use the nominal cumulative count; SB3 collects whole
# rollouts (8192 steps per iteration here), so the true count can be slightly higher.
for i in range(4):
    model.learn(total_timesteps=timesteps, reset_num_timesteps=(i == 0))
    model.save(f"./Training/Models/PPO_SnakeModelJustFruit_{(i+1)*timesteps}")

Logging to ./Training/Logs/PPO_17
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3.84     |
|    ep_rew_mean     | -500     |
| time/              |          |
|    fps             | 56       |
|    iterations      | 1        |
|    time_elapsed    | 145      |
|    total_timesteps | 8192     |
---------------------------------


In [None]:
#ppo_path = "./Training/Models/PPO_Snake_Model_RightObs"
#model.save(ppo_path)

In [None]:
# Drop the in-memory model; the evaluation section below reloads a checkpoint from disk.
del model

**4. Evaluate and Test**

In [None]:
# Load a checkpoint written by the training loop above (prefix PPO_SnakeModelJustFruit_).
# The saved model's observation space must match the current SnakeENV
# (relative_to_fruit + free_space), otherwise PPO.load rejects the env.
# NOTE(review): this previously pointed at PPO_SnakeModelBin_500000, which this
# notebook never writes; swap the path back if that older model is really intended.
ppo_path = "./Training/Models/PPO_SnakeModelJustFruit_500000"

env = DummyVecEnv([lambda: SnakeENV()])
model = PPO.load(ppo_path, env)

In [None]:
mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=10,
    render=True,
    deterministic=True,
)
mean_reward, std_reward