In [None]:
from enum import Enum
import math
import tensorflow as tf
from tensorflow.keras import layers, models
import pygame
import random
import random
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from collections import deque
from pygame.locals import QUIT, MOUSEBUTTONDOWN
import numpy as np
from sklearn.preprocessing import StandardScaler

# Adjustable hyperparameters
# Window geometry in pixels; the board is divided into pixel_size x pixel_size cells.
height = 400
width = 400
pixel_size = 20
# Network input/output sizes: encoded-state feature count and number of discrete actions.
state_size = 11
action_size = 5
# Training schedule.
memory_length = 100_000   # replay-buffer capacity (deque maxlen)
episodes = 150            # number of training episodes
steps = 5_000             # max steps per episode
batch_size = 1_000        # replay minibatch size

class Position:
    """A grid cell coordinate: x is the column, y is the row (in cells, not pixels)."""

    def __init__(self, x, y):
        self.x, self.y = x, y

class Turn(Enum):
    """Heading of the snake's head on the grid (rows grow downward on screen)."""

    RIGHT = 0  # moves +x
    LEFT = 1   # moves -x
    UP = 2     # moves -y
    DOWN = 3   # moves +y

#defines game state methods
class State:
    """Game state for one snake episode, stored in the ``self.state`` dict.

    Coordinates are grid cells (pixel position = cell * ``pixel_size``).
    x grows to the right and y grows downward: ``Turn.UP`` decrements y in
    ``update``. A cell is out of bounds when x < 1, y < 1,
    x >= (width - pixel_size) / pixel_size or y >= (height - pixel_size) / pixel_size.

    Keys of ``self.state``:
        apple: Position of the current apple.
        apple_rendered: True once an apple has been placed; cleared when eaten.
        snake_positions: list of Position, tail first and head last.
        die: set to True once the head hits the border or the body.
        turn: current Turn heading.
        score: number of apples eaten.
    """

    def __init__(self):
        # Start near the centre. Floor division keeps the start on a whole
        # cell so the head can land exactly on the (integer) apple cells.
        start_x = ((width - pixel_size * 2) // pixel_size) // 2
        start_y = ((height - pixel_size * 2) // pixel_size) // 2
        self.state = {
            "apple": Position(0, 0),
            "apple_rendered": False,
            "snake_positions": [Position(start_x, start_y)],
            "die": False,
            "turn": Turn.RIGHT,
            "score": 0
        }
        self.set_apple_position()

    def get_state(self):
        """Return the underlying state dict (not a copy)."""
        return self.state

    def set_apple_position(self):
        """Place a new apple on a random cell not occupied by the snake.

        Does nothing while an apple is already placed (``apple_rendered``).
        NOTE(review): this loops forever if the snake ever fills every cell.
        """
        if not self.state['apple_rendered']:
            while True:
                random_x = random.randrange(1, (width - pixel_size * 2) // pixel_size + 1)
                random_y = random.randrange(1, (height - pixel_size * 2) // pixel_size + 1)
                occupied = any(pos.x == random_x and pos.y == random_y for pos in self.state['snake_positions'])
                if not occupied:
                    break
            self.state['apple'] = Position(random_x, random_y)
            self.state['apple_rendered'] = True

    def set_turn(self, turn):
        """Change heading to ``turn`` only if it is perpendicular to the current one.

        Repeating or reversing the current heading is ignored, so the snake
        can never turn straight back into its own body.
        """
        horizontal = (Turn.RIGHT, Turn.LEFT)
        vertical = (Turn.UP, Turn.DOWN)
        current_turn = self.state['turn']
        if (turn in horizontal and current_turn in vertical) or (turn in vertical and current_turn in horizontal):
            self.state['turn'] = turn

    def is_collision(self, point=None):
        """Return True if ``point`` is out of bounds or on the snake's body.

        Args:
            point: Position to test. Defaults to the current head, which is
                how ``update`` checks whether the last move was fatal.
        """
        if point is None:
            point = self.state['snake_positions'][-1]
        # Check if the point is on/over the border
        if (point.x >= (width - pixel_size) / pixel_size or point.x < 1
                or point.y >= (height - pixel_size) / pixel_size or point.y < 1):
            return True
        # Check if the point hits the body. The head is the last element and is
        # excluded so the head never "collides" with itself.
        return any(pos.x == point.x and pos.y == point.y
                   for pos in self.state['snake_positions'][:-1])

    def get_encoded_state(self):
        """Return the 11 boolean features fed to the agent.

        Order: danger straight, danger right, danger left (relative to the
        current heading), heading one-hot (left, right, up, down), and apple
        direction (left, right, up, down) relative to the head.
        """
        state = self.state
        head = state['snake_positions'][-1]
        turn = state["turn"]
        dir_l = turn == Turn.LEFT
        dir_r = turn == Turn.RIGHT
        dir_u = turn == Turn.UP
        dir_d = turn == Turn.DOWN
        apple = state['apple']
        food_l = apple.x < head.x
        food_r = apple.x > head.x
        food_u = apple.y < head.y
        food_d = apple.y > head.y
        # Neighbouring cells of the head; y grows downward, matching update().
        point_l = Position(head.x - 1, head.y)
        point_r = Position(head.x + 1, head.y)
        point_u = Position(head.x, head.y - 1)
        point_d = Position(head.x, head.y + 1)
        danger_s = ((dir_r and self.is_collision(point_r)) or (dir_l and self.is_collision(point_l))
                    or (dir_u and self.is_collision(point_u)) or (dir_d and self.is_collision(point_d)))
        # Clockwise of the heading: up->right, down->left, left->up, right->down.
        danger_r = ((dir_u and self.is_collision(point_r)) or (dir_d and self.is_collision(point_l))
                    or (dir_l and self.is_collision(point_u)) or (dir_r and self.is_collision(point_d)))
        # Counter-clockwise of the heading: down->right, up->left, right->up, left->down.
        danger_l = ((dir_d and self.is_collision(point_r)) or (dir_u and self.is_collision(point_l))
                    or (dir_r and self.is_collision(point_u)) or (dir_l and self.is_collision(point_d)))
        return [danger_s, danger_r, danger_l, dir_l, dir_r, dir_u, dir_d, food_l, food_r, food_u, food_d]

    def update(self):
        """Advance the snake one cell along its heading and return the reward.

        Returns:
            -10 if the new head collides (``die`` is set), 10 if it lands on
            the apple (the snake grows by one, ``score`` increments; this check
            runs last), otherwise 0.
        """
        positions = self.state['snake_positions']
        head = positions[-1]
        turn = self.state["turn"]
        if turn == Turn.RIGHT:
            positions.append(Position(head.x + 1, head.y))
        elif turn == Turn.LEFT:
            positions.append(Position(head.x - 1, head.y))
        elif turn == Turn.UP:
            positions.append(Position(head.x, head.y - 1))
        elif turn == Turn.DOWN:
            positions.append(Position(head.x, head.y + 1))
        # Drop the tail; it is put back below if the apple is eaten (growth).
        to_remove = positions.pop(0)
        updated_head = positions[-1]
        reward = 0
        # Check if head hits tail or wall reward -10
        if self.is_collision(updated_head):
            self.state['die'] = True
            reward = -10
        #Check if head eats apple
        apple_position = self.state["apple"]
        if updated_head.x == apple_position.x and updated_head.y == apple_position.y:
            reward = 10 # If Snake eats apple reward +10
            self.state['apple_rendered'] = False
            positions.insert(0, to_remove)
            #Increase game score
            self.state['score'] += 1
        self.set_apple_position()
        return reward

    # Actions: 0 = No action, 1 = Right, 2 = Left, 3 = Up, 4 = Down
    def step(self, action=None):
        """Apply ``action``, advance one tick, and return the transition.

        Returns:
            (encoded_state, reward, done, score) where ``done`` is 1 once the
            snake has died and 0 otherwise.
        """
        if action == 1:
            self.set_turn(Turn.RIGHT)
        elif action == 2:
            self.set_turn(Turn.LEFT)
        elif action == 3:
            self.set_turn(Turn.UP)
        elif action == 4:
            self.set_turn(Turn.DOWN)
        reward = self.update()
        new_state = self.get_encoded_state()
        done = 1 if self.state["die"] else 0
        score = self.state["score"]
        return new_state, reward, done, score

#main snake game class
class Snake_Environment():
    """Pygame front-end around a ``State``: owns the window, font, clock and drawing."""

    def __init__(self):
        self.state = State()
        pygame.init()
        self.screen = pygame.display.set_mode((width, height))
        self.font = pygame.font.Font(None, 24)
        self.text_color = (0,0,0)
        self.clock = pygame.time.Clock()

    def reset(self):
        """Start a fresh episode by replacing the game state."""
        self.state = State()

    def render_background(self):
        """Clear the whole window to black."""
        self.screen.fill("black")

    def render_snake(self):
        """Draw every snake segment as one light-blue cell."""
        for pos in self.state.get_state()['snake_positions']:
            snake_chunk = pygame.Rect(pos.x * pixel_size, pos.y * pixel_size, pixel_size, pixel_size)
            pygame.draw.rect(self.screen, "lightblue", snake_chunk)

    def render_border(self):
        """Draw the one-cell-thick green frame around the play area."""
        top = pygame.Rect(0, 0, width, pixel_size)
        left = pygame.Rect(0, 0, pixel_size, height)
        # The right column starts at x = width - pixel_size and the bottom row at
        # y = height - pixel_size (these were swapped, which only worked for square windows).
        right = pygame.Rect(width - pixel_size, 0, pixel_size, height)
        bottom = pygame.Rect(0, height - pixel_size, width, pixel_size)
        for rect in (top, left, right, bottom):
            pygame.draw.rect(self.screen, "green", rect)

    def render_apple(self):
        """Draw the current apple as one red cell."""
        apple_pos = self.state.get_state()['apple']
        apple = pygame.Rect(apple_pos.x * pixel_size, apple_pos.y * pixel_size, pixel_size, pixel_size)
        pygame.draw.rect(self.screen, "red", apple)

    def render_summary(self, episode, step):
        """Draw the episode number and current score in the top-left corner.

        ``step`` is accepted to mirror ``render`` but is not displayed.
        """
        text_surface = self.font.render(f"Episode: {episode}, Score: {self.state.get_state()['score']}", True, self.text_color)
        text_rect = text_surface.get_rect()
        text_rect.center = (100, 10)
        self.screen.blit(text_surface, text_rect)

    def stop(self):
        """Shut down pygame (closes the window)."""
        pygame.quit()

    def render(self, episode, step):
        """Draw one full frame and throttle the loop to at most 6 frames per second."""
        self.render_background()
        self.render_border()
        self.render_snake()
        if self.state.get_state()['apple_rendered']:
            self.render_apple()
        self.render_summary(episode, step)
        pygame.display.flip()
        self.clock.tick(6)
    
#Agent class
class DeepQNetAgent():
    """Epsilon-greedy DQN agent with a single Q-network and an experience-replay buffer."""

    def __init__(self, epsilon, epsilon_min, epsilon_decay, learning_rate, gamma):
        """
        Args:
            epsilon: initial exploration probability.
            epsilon_min: floor below which epsilon stops decaying.
            epsilon_decay: multiplicative decay applied after each replay update.
            learning_rate: Adam learning rate for the Q-network.
            gamma: discount factor for bootstrapped targets.
        """
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.memory = deque(maxlen=memory_length)
        self.model = self.build_model()

    def build_model(self):
        """Build and compile the MLP mapping a state to one Q-value per action."""
        model = Sequential()
        model.add(Input(shape=[state_size,]))
        model.add(Dense(512, activation='relu'))
        model.add(Dense(action_size, activation='linear'))
        # Pass an Adam instance so the configured learning rate is actually used;
        # optimizer='adam' would silently fall back to Keras' default rate.
        model.compile(loss='mse', optimizer=Adam(learning_rate=self.learning_rate))
        return model

    def act(self, state):
        """Pick an action: greedy w.r.t. the Q-network, or uniformly random with probability epsilon.

        ``state`` is presumably a batch of one scaled state, e.g. the
        ``scaler.transform([state])`` output used by the training loop.
        """
        if random.uniform(0, 1) > self.epsilon:
            q_values = self.model.predict(state, verbose=0)
            return np.argmax(q_values[0])
        return random.randrange(action_size)

    def remember(self, state, action, reward, new_state, done):
        """Append one transition to the replay buffer (oldest entries drop out at maxlen)."""
        self.memory.append((state, action, reward, new_state, done))

    def replay(self, batch_size):
        """Fit the Q-network on a random minibatch and decay epsilon.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        """
        if len(self.memory) < batch_size:
            return
        minibatch = random.sample(self.memory, batch_size)
        states = np.array([sample[0] for sample in minibatch])
        actions = np.array([sample[1] for sample in minibatch])
        rewards = np.array([sample[2] for sample in minibatch])
        new_states = np.array([sample[3] for sample in minibatch])
        done = np.array([sample[4] for sample in minibatch])
        # Stored states carry a leading batch-of-one axis; drop it to get (batch, features).
        states = tf.squeeze(states, axis=1)
        new_states = tf.squeeze(new_states, axis=1)
        # Bellman target; terminal transitions (done == 1) get no bootstrap term.
        target = rewards + (1 - done) * self.gamma * np.amax(self.model.predict(new_states, verbose=0), axis=1)
        # Only the taken action's Q-value is moved toward the target; the others
        # keep the network's own prediction so their error is zero.
        target_full = self.model.predict(states, verbose=0)
        target_full[np.arange(batch_size), actions] = target
        self.model.fit(states, target_full, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

class Break(Exception):
    """Raised inside the training loop to unwind it when the pygame window is closed."""

def get_scaler(env):
    """Fit a StandardScaler on encoded states gathered by a random-action rollout.

    Collects ``steps`` observations. Whenever the snake dies the environment is
    reset, so the sample spans many episodes instead of only the first one,
    which under random play is usually just a handful of steps. The
    environment is reset again before returning.
    """
    states = []
    for _ in range(steps):
        action = random.randrange(action_size)
        new_state, _reward, done, _score = env.state.step(action)
        states.append(new_state)
        if done:
            env.reset()
    env.reset()
    scaler = StandardScaler()
    scaler.fit(states)
    return scaler
    
agent = DeepQNetAgent(epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995, learning_rate=0.005, gamma = 0.95)
env = Snake_Environment()
scaler = get_scaler(env)
#Training loop
try:
    for episode in range(episodes):
        env.reset()
        for step in range(steps):
            state = scaler.transform([env.state.get_encoded_state()])
            action = agent.act(state)
            new_state, reward, done, score = env.state.step(action)
            new_state = scaler.transform([new_state])
            agent.remember(state, action, reward, new_state, done)
            agent.replay(batch_size=batch_size)
            if done:
                print("Episode {}/{}, Score {}".format(episode, episodes, score))
                break
            # Drain the pygame event queue so the window stays responsive; honour the close button.
            if any(event.type == QUIT for event in pygame.event.get()):
                raise Break
            #uncomment below to render game if necessary
            #env.render(episode, step)
        else:
            # The episode ran out of steps without the snake dying; report it anyway.
            print("Episode {}/{}, Score {} (step limit reached)".format(
                episode, episodes, env.state.get_state()['score']))
except Break:
    print('Close game')
finally:
    # Always release the pygame window, even if training raises unexpectedly.
    env.stop()

pygame 2.5.2 (SDL 2.28.3, Python 3.10.9)
Hello from the pygame community. https://www.pygame.org/contribute.html
Episode 0/150, Score 0
Episode 1/150, Score 0
Episode 2/150, Score 0
Episode 3/150, Score 0
Episode 4/150, Score 0
Episode 5/150, Score 0
Episode 6/150, Score 0
Episode 7/150, Score 0
Episode 8/150, Score 0
Episode 9/150, Score 0
Episode 10/150, Score 0
Episode 11/150, Score 0
Episode 12/150, Score 0
Episode 13/150, Score 0
Episode 14/150, Score 0
Episode 15/150, Score 0
Episode 16/150, Score 1
Episode 17/150, Score 0
Episode 18/150, Score 0
Episode 19/150, Score 0
Episode 20/150, Score 0
Episode 21/150, Score 0
Episode 22/150, Score 0
Episode 23/150, Score 0
Episode 24/150, Score 0
Episode 25/150, Score 0
Episode 26/150, Score 0
Episode 27/150, Score 0
Episode 28/150, Score 0
Episode 29/150, Score 0
Episode 30/150, Score 0
Episode 31/150, Score 0
Episode 32/150, Score 0
Episode 33/150, Score 6
Episode 34/150, Score 4
Episode 35/150, Score 5
Episode 36/150, Score 4
Episode 3