In [10]:
import pygame, random, sys
from pygame import Vector2
# initialize
pygame.init()

(5, 0)

In [11]:
class Fruit:
    def __init__(self, cell_size, cell_num):
        self.cell_size = cell_size
        self.cell_num = cell_num
        # x, y pos
        self.x = random.randint(0, self.cell_num-1)
        self.y = random.randint(0, self.cell_num-1)
        # vector
        self.pos = Vector2(self.x, self.y)
        
    def draw_fruit(self, screen):
        # creat rect
        fruit_rect = pygame.Rect(self.pos.x * self.cell_size, self.pos.y * self.cell_size, self.cell_size, self.cell_size)
        # draw rect
        pygame.draw.ellipse(screen, pygame.Color('red'), fruit_rect)

    def move_position(self):
        # x, y pos
        self.x = random.randint(0, self.cell_num-1)
        self.y = random.randint(0, self.cell_num-1)
        # vector
        self.pos = Vector2(self.x, self.y)

In [12]:
class Snake:
    def __init__(self, cell_size, cell_num):
        self.cell_size = cell_size
        self.cell_num = cell_num
        # inital position of 3 snake parts
        x_init = self.cell_num/4
        y_init = self.cell_num/2
        # head tail tail
        self.body=[Vector2(x_init, y_init), Vector2(x_init-1, y_init), Vector2(x_init-2, y_init)]
        # movement direction
        self.direction = Vector2(1, 0)
        
    def draw_snake(self, screen):
        f_block = self.body[0]
        # create rect
        f_block_rect = pygame.Rect(f_block.x * self.cell_size, f_block.y * self.cell_size, self.cell_size, self.cell_size)
        # draw rect
        pygame.draw.rect(screen, pygame.Color(50, 168, 82, 255), f_block_rect)
        for block in self.body[1:]:
            # create rect
            block_rect = pygame.Rect(block.x * self.cell_size, block.y * self.cell_size, self.cell_size, self.cell_size)
            # draw rect
            pygame.draw.rect(screen, pygame.Color('green'), block_rect)
            
    def move_snake(self):
        # create body without last block, add new block to beginning
        body_copy = self.body[:-1]
        body_copy.insert(0, body_copy[0] + self.direction)
        self.body = body_copy

    def eat_apple(self):
        # create body with last block, add new block to beginning
        body_copy = self.body
        body_copy.insert(0, body_copy[0] + self.direction)
        self.body = body_copy

In [13]:
import numpy as np

In [14]:
import gymnasium
from gymnasium import Env

In [20]:
class SnakeEnv(Env):
    def __init__(self, cell_size=40, cell_num=20, render_mode = 'human'):
        self.render_mode = render_mode
        pygame.init()
        # game dimensions
        self.cell_size = cell_size
        self.cell_num = cell_num
        
        # initialize fruit and snake
        self.fruit = Fruit(self.cell_size, self.cell_num)
        self.snake = Snake(self.cell_size, self.cell_num)
        
        # env dimensions
        self.action_space = gymnasium.spaces.MultiBinary(2)
        # fruit x y, snake head x y
        self.observation_space = gymnasium.spaces.Box(low=0, high=self.cell_num, shape=(4,), dtype=np.float64)
        self.state = np.array([self.fruit.x, self.fruit.y, self.snake.body[0].x, self.snake.body[0].y])

        # score
        self.score = 0
        self.info = {}
        self.screen = None
        self.clock = None
        self.framerate = None
        # game over
        self.done = False
        
    
    # set new dimensions for env game
    def set_dim(self, cell_size, cell_num):
        self.cell_size = cell_size
        self.cell_num = cell_num
    
    # take action
    def step(self, action):
        if self.done:
            self.done = False
            self.close()
        # move snake in direction
        # action = [1,0] turn left, [0,1] turn right, [0,0] go straight
        self.snake.direction = self.set_movement(self.snake.direction, action)
        
        # game rules
        next_pos = self.snake.body[0] + self.snake.direction
        # if snake eat apple
        if next_pos == self.fruit.pos:
            self.score += 1
            self.snake.eat_apple()
            self.fruit.move_position()
            # make sure fruit is not in snake
            while self.fruit.pos in self.snake.body:
                self.fruit.move_position()
        # if snake eat self
        elif next_pos in self.snake.body:
            self.score -= 1
            self.done = True
        # if snake hit wall
        elif not 0 <= next_pos.x < self.cell_num or not 0 <= next_pos.y < self.cell_num:
            self.score -= 1
            self.done = True
        # regular movement
        else:
            self.snake.move_snake()
        # reset direction
        self.snake.direction = self.set_movement(self.snake.direction,[0,0])
        
        # state: fruit x y, snake head x y
        self.state = np.array([self.fruit.x, self.fruit.y, self.snake.body[0].x, self.snake.body[0].y])
        
        if self.render_mode == 'human':
            self.render()
        
        return self.state, self.score, self.done, self.done, self.info
    
    # turn action into direction vector
    def set_movement(self, current_dir, cd):
        if (np.array(cd) == np.array([1, 1])).all():
            cd = [0, 0]
        # current_dir = 1,0 || -1,0
        nx = current_dir.x
        ny = current_dir.y
        # if moving right or left
        if nx:
            #turn
            if 1 in cd:
                # get new directions
                ny = (cd[1] - cd[0]) * nx
                nx = 0
            return Vector2(nx, ny)
        # up or down
        else:
            if 1 in cd:
                # get new directions
                nx = (cd[0] - cd[1]) * ny
                ny = 0
            # return vector direction
            return Vector2(nx, ny)

    def draw_pattern(self, screen):
        # draw checkerboard pattern
        p_color = (101, 158, 219)
        for row in range(self.cell_num):
            if row % 2 == 0:
                for col in range(self.cell_num):
                    if col % 2 == 0:
                        p_rect = pygame.Rect(col* self.cell_size,row* self.cell_size,self.cell_size,self.cell_size)
                        pygame.draw.rect(screen,p_color, p_rect)
            if row % 2 == 1:
                for col in range(self.cell_num):
                    if col % 2 == 1:
                        p_rect = pygame.Rect(col* self.cell_size,row* self.cell_size,self.cell_size,self.cell_size)
                        pygame.draw.rect(screen,p_color, p_rect)

    def draw_score(self, screen):
        pygame.init()
        # draw score text to screen
        font = pygame.font.Font(None, 32)
        score_text = "Score: {}".format(self.score)
        score_surface = font.render(score_text, True, pygame.Color('red'))
        # score position: bottom left (with 60 pixels padding)
        score_x = 60
        score_y = int(self.cell_size * self.cell_num - 60)
        score_rect = score_surface.get_rect(center = (score_x,score_y))
        screen.blit(score_surface, score_rect)
    
    # show game visual
    def render(self, render_mode='human'):
        if self.screen is None:
            self.reset()
        
        for event in pygame.event.get():
            # if x out, exit game
            if event.type == pygame.QUIT:
                self.close()
        # draw screen
        self.screen.fill((104, 166, 232))
        # draw pattern
        self.draw_pattern(self.screen)
        # draw score
        self.draw_score(self.screen)
        # draw fruit and snake
        self.fruit.draw_fruit(self.screen)
        self.snake.draw_snake(self.screen)
        # display
        pygame.display.update()
        self.clock.tick(self.framerate)

    # reset fruit and snake positions
    def reset(self, seed=0):
        self.done = False
        self.fruit = Fruit(self.cell_size, self.cell_num)
        self.snake = Snake(self.cell_size, self.cell_num)
        self.score = 0
        self.info = {}
        # state: fruit x y, snake head x y
        self.state = np.array([self.fruit.x, self.fruit.y, self.snake.body[0].x, self.snake.body[0].y])


        if self.render_mode == 'human':
            pygame.init()
            # while True:
            pygame.display.set_caption('Snake Game')

            # display and speed
            self.screen = pygame.display.set_mode((self.cell_size * self.cell_num, self.cell_size * self.cell_num))
            self.clock = pygame.time.Clock()
            self.framerate = 60
        
            self.render()
        
        return self.state, self.info
        
    
    # close game (if rendered)
    def close(self):
        # pygame.quit()
        # sys.exit()
        pygame.quit()
        pygame.display.quit()
        self.screen = None
        self.clock = None

In [16]:
env = SnakeEnv(render_mode='human')

In [17]:
from stable_baselines3 import PPO

In [18]:
model = PPO('MlpPolicy', env, verbose=1)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [19]:
model.learn(total_timesteps=2500)
env.close()

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 43.1     |
|    ep_rew_mean     | 1.85     |
| time/              |          |
|    fps             | 58       |
|    iterations      | 1        |
|    time_elapsed    | 34       |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 45.5        |
|    ep_rew_mean          | 1.13        |
| time/                   |             |
|    fps                  | 56          |
|    iterations           | 2           |
|    time_elapsed         | 71          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.003493443 |
|    clip_fraction        | 0.00244     |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.39       |
|    explained_variance   | -0.06       |
|    learning_rate        | 0.

In [None]:
# model play for i games (1)
obs, info = env.reset()
term, trunc = False, False
for i in range(1):
    while not term or not trunc:
        action, _states = model.predict(obs)
        obs, rewards, term, trunc, info = env.step(action)
        env.render("human")
    print(rewards)
    env.close()
    term, trunc = False, False