In [2]:
from collections import deque
import random
import time
import cv2
import numpy as np
import gymnasium as gym
from gymnasium import spaces

from snake import collision_with_apple, collision_with_boundaries, collision_with_self

SNAKE_LEN_GOAL = 30
N_DISCRETE_ACTIONS = 4

class SnakeEnv(gym.Env):
    """Custom Environment that follows gym interface"""

    def __init__(self):
        super(SnakeEnv, self).__init__()
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
        # Example for using image as input (channel-first; channel-last also works):
        self.observation_space = spaces.Box(low=-500, high=500,
                                            shape=(5+SNAKE_LEN_GOAL,), dtype=np.float32)

    def step(self, action):
        self.prev_actions.append(action)
        cv2.imshow('a',img)
        cv2.waitKey(1)
        img = np.zeros((500,500,3),dtype='uint8')
        # Display Apple
        cv2.rectangle(img,(self.apple_position[0],self.apple_position[1]),(self.apple_position[0]+10,self.apple_position[1]+10),(0,0,255),3)
        # Display Snake
        for position in self.snake_position:
            cv2.rectangle(self.img,(position[0],position[1]),(position[0]+10,position[1]+10),(0,255,0),3)
        
        # Takes step after fixed time
        t_end = time.time() + 0.2
        k = -1
        while time.time() < t_end:
            if k == -1:
                k = cv2.waitKey(125)
            else:
                continue
            
        # 0-Left, 1-Right, 3-Up, 2-Down, q-Break
        # a-Left, d-Right, w-Up, s-Down

        if k == ord('a') and prev_button_direction != 1:
            button_direction = 0
        elif k == ord('d') and prev_button_direction != 0:
            button_direction = 1
        elif k == ord('w') and prev_button_direction != 2:
            button_direction = 3
        elif k == ord('s') and prev_button_direction != 3:
            button_direction = 2
        elif k == ord('q'):
            break
        else:
            button_direction = button_direction
        prev_button_direction = button_direction

        # Change the head position based on the button direction
        if action == 1:
            self.snake_head[0] += 10
        elif action == 0:
            self.snake_head[0] -= 10
        elif action == 2:
            self.snake_head[1] += 10
        elif action == 3:
            self.snake_head[1] -= 10

        # Increase Snake length on eating apple
        if self.snake_head == apple_position:
            apple_position, score = collision_with_apple(apple_position, score)
            self.snake_position.insert(0,list(self.snake_head))

        else:
            self.snake_position.insert(0,list(self.snake_head))
            self.snake_position.pop()
            
        # On collision kill the snake and print the score
        if collision_with_boundaries(self.snake_head) == 1 or collision_with_self(self.snake_position) == 1:
            font = cv2.FONT_HERSHEY_SIMPLEX
            img = np.zeros((500,500,3),dtype='uint8')
            cv2.putText(self.img,'Your Score is {}'.format(self.score),(140,250), font, 1,(255,255,255),2,cv2.LINE_AA)
            cv2.imshow('a',self.img)
            self.done = True

        if self.done:
            self.reward = -10
        else: 
            self.reward = self.score


        info = {}
        return self.observation, self.reward, self.done, info

    def reset(self):
        self.done = False
        self.img = np.zeros((500,500,3),dtype='uint8')
        # Initial Snake and Apple position
        self.snake_position = [[250,250],[240,250],[230,250]]
        self.apple_position = [random.randrange(1,50)*10,random.randrange(1,50)*10]
        self.score = 0
        self.prev_button_direction = 1
        self.button_direction = 1
        self.snake_head = [250,250]

        head_x = self.snake_head[0]
        head_y = self.snake_head[1]
        apple_delta_x = head_x - self.apple_position[0] 
        apple_delta_y = head_y - self.apple_position[1]
        snake_length = len(self.snake_position)
        self.prev_actions = deque(maxLen=SNAKE_LEN_GOAL)
        for _ in range(SNAKE_LEN_GOAL):
            self.prev_actions.append(-1)
        
        self.observation = [head_x,head_y,apple_delta_x,apple_delta_y,snake_length] + list(self.prev_actions)
        self.observation = np.array(self.observation)

        return self.observation  # reward, done, info can't be included  

In [None]:
from stable_baselines3.common.env_checker import check_env

env = SnakeEnv()

check_env(env)

NameError: name 'SnakeEnv' is not defined