In [1]:
import pygame
import random

# Window geometry in pixels; one grid cell is a pixel_size x pixel_size square.
width, height = 600, 600
pixel_size = 20

# Bring up pygame and create the shared display surface and frame clock
# that the rendering helpers and the game loop use.
pygame.init()
screen = pygame.display.set_mode((width, height))
clock = pygame.time.Clock()

# Loop flag read by the game loop's `while running:` condition.
running = True

pygame 2.5.2 (SDL 2.28.3, Python 3.10.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
def render_background(screen):
    """Clear the whole drawing surface by filling it with black.

    Args:
        screen: Surface-like object exposing ``fill(color)``.
    """
    background_colour = "black"
    screen.fill(background_colour)

In [3]:
def render_border(screen):
    """Draw the green frame, one cell thick, along all four window edges.

    Geometry comes from the module-level ``width``, ``height`` and
    ``pixel_size``. The right edge is positioned from ``width`` and the
    bottom edge from ``height``, so the frame also hugs the window when it
    is not square.

    Args:
        screen: Surface the four rectangles are drawn on.
    """
    edges = (
        pygame.Rect(0, 0, width, pixel_size),                     # top
        pygame.Rect(0, 0, pixel_size, height),                    # left
        pygame.Rect(width - pixel_size, 0, pixel_size, height),   # right
        pygame.Rect(0, height - pixel_size, width, pixel_size),   # bottom
    )
    for edge in edges:
        pygame.draw.rect(screen, "green", edge)

In [4]:
def render_apple(screen, position):
    """Draw the apple as a red square covering one grid cell.

    Args:
        screen: Surface to draw on.
        position: Object with ``x`` and ``y`` attributes; both are
            multiplied by the module-level ``pixel_size`` to get pixels.
    """
    left = position.x * pixel_size
    top = position.y * pixel_size
    pygame.draw.rect(screen, "red", pygame.Rect(left, top, pixel_size, pixel_size))

In [5]:
def render_snake(screen, positions):
    """Draw every snake segment as a light-blue square, in list order.

    Args:
        screen: Surface to draw on.
        positions: Iterable of objects with ``x`` and ``y`` attributes;
            both are multiplied by the module-level ``pixel_size``.
    """
    side = pixel_size
    for segment in positions:
        cell = pygame.Rect(segment.x * side, segment.y * side, side, side)
        pygame.draw.rect(screen, "lightblue", cell)

In [6]:
class Position:
    """Plain mutable holder for an ``x`` / ``y`` coordinate pair."""

    def __init__(self, x, y):
        self.x, self.y = x, y

In [7]:
from enum import Enum

class Turn(Enum):
    """Direction the snake's head moves on the next update."""
    RIGHT = 0  # x increases
    LEFT = 1   # x decreases
    UP = 2     # y decreases (State.update moves the head to y-1)
    DOWN = 3   # y increases

In [8]:
class State:
    """Mutable state of one snake game, expressed in grid-cell coordinates.

    Everything lives in the ``state`` dict:

    * ``"apple"``           - Position of the current apple.
    * ``"apple_rendered"``  - False when a new apple must be placed.
    * ``"snake_positions"`` - list of Position, tail first, head last.
    * ``"die"``             - True once the head hit the border or the body.
    * ``"turn"``            - current heading (a ``Turn`` member).
    * ``"score"``           - number of apples eaten.

    Board geometry is read from the module-level ``width``, ``height`` and
    ``pixel_size``.
    """

    def __init__(self):
        # Start in the middle of the playable area (the grid minus the two
        # border cells). Floor division keeps the cell indices integral.
        start_x = (width - pixel_size * 2) // pixel_size // 2
        start_y = (height - pixel_size * 2) // pixel_size // 2
        self.state = {
            "apple": Position(0, 0),
            "apple_rendered": False,
            "snake_positions": [Position(start_x, start_y)],
            "die": False,
            "turn": Turn.RIGHT,
            "score": 0,
        }
        self.set_apple_position()

    def get_state(self):
        """Return the live state dict (not a copy)."""
        return self.state

    def set_apple_position(self):
        """Place a new apple on a random free cell if none is on the board.

        Does nothing while ``apple_rendered`` is True. Otherwise it keeps
        sampling cells inside the border until one is not occupied by the
        snake, then stores it and sets ``apple_rendered``.
        """
        if self.state['apple_rendered']:
            return
        snake = self.state['snake_positions']
        while True:
            random_x = random.randrange(1, (width - pixel_size * 2) // pixel_size + 1)
            random_y = random.randrange(1, (height - pixel_size * 2) // pixel_size + 1)
            if not any(pos.x == random_x and pos.y == random_y for pos in snake):
                break
        self.state['apple'] = Position(random_x, random_y)
        self.state['apple_rendered'] = True

    def set_turn(self, turn):
        """Change heading, accepting only a switch between the two axes.

        A request along the current axis (same direction or a 180-degree
        reversal) is ignored.
        """
        horizontal = (Turn.RIGHT, Turn.LEFT)
        vertical = (Turn.UP, Turn.DOWN)
        current_turn = self.state['turn']
        if (turn in horizontal and current_turn in vertical) or \
                (turn in vertical and current_turn in horizontal):
            self.state['turn'] = turn

    def update(self):
        """Advance the game by one step.

        Respawns the apple if needed, moves the head one cell in the
        current heading, drops the tail, flags death on border or body
        collision, and on eating the apple bumps the score and grows the
        snake by putting the dropped tail cell back.
        """
        # Must act on this instance; a module-level `game_state` may not exist.
        self.set_apple_position()

        snake = self.state['snake_positions']
        head = snake[-1]

        # Screen coordinates: y grows downward, so UP is -1 on y.
        step = {
            Turn.RIGHT: (1, 0),
            Turn.LEFT: (-1, 0),
            Turn.UP: (0, -1),
            Turn.DOWN: (0, 1),
        }
        dx, dy = step[self.state['turn']]
        snake.append(Position(head.x + dx, head.y + dy))

        # Drop the tail now, but remember it in case the snake grows below.
        to_remove = snake.pop(0)

        # Check if head hits border
        updated_head = snake[-1]
        if (updated_head.x >= ((width - pixel_size) / pixel_size) or updated_head.x < 1
                or updated_head.y >= ((height - pixel_size) / pixel_size) or updated_head.y < 1):
            self.state['die'] = True

        # Check if head hits any other segment of the snake
        if any(pos.x == updated_head.x and pos.y == updated_head.y for pos in snake[:-1]):
            self.state['die'] = True

        # Check if head eats apple
        apple_position = self.state["apple"]
        if updated_head.x == apple_position.x and updated_head.y == apple_position.y:
            self.state['score'] = self.state['score'] + 1
            self.state['apple_rendered'] = False  # respawned on the next update()
            snake.insert(0, to_remove)


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

class Reinforcement_Learning():
    """Plays the snake game with a small Keras network, one episode at a time.

    Each frame the game state is encoded into a 7-number vector, the network
    picks one of four turns, the game is advanced and drawn. Learning itself
    is not implemented yet: ``update_model`` is a no-op placeholder, so the
    network's weights never change.

    Relies on names defined elsewhere in the notebook: ``State``, ``Turn``,
    the render helpers, ``screen``, ``clock``, ``running``, ``width``,
    ``height``, ``pixel_size``, and the Keras ``layers`` / ``models`` imports.
    """

    STATE_SIZE = 7    # length of the vector built by get_encoded_state
    NUM_ACTIONS = 4   # one network output per Turn member

    def __init__(self, num_episodes):
        """Store the episode budget and build the network.

        Args:
            num_episodes: Number of games ``train_loop`` plays.
        """
        self.num_episodes = num_episodes
        self.create_model()

    def create_model(self):
        """Build and compile the network and store it on ``self.model``."""
        model = models.Sequential([
            layers.Flatten(input_shape=(self.STATE_SIZE,)),
            layers.Dense(128, activation='relu'),
            layers.Dense(self.NUM_ACTIONS, activation='softmax')
        ])
        model.compile(optimizer='adam', loss='categorical_crossentropy')
        self.model = model

    def get_encoded_state(self, state):
        """Turn the game-state dict into the network's input vector.

        Args:
            state: Dict as returned by ``State.get_state()``; the keys
                ``"turn"``, ``"apple"``, ``"snake_positions"`` and
                ``"score"`` are read.

        Returns:
            List of length 7: [snake direction, apple direction, cells to the
            left edge, cells to the top edge, cells to the right edge, cells
            to the bottom edge, score]. Both directions use the code
            0=right, 1=left, 2=up, 3=down.
        """
        direction_code = {Turn.RIGHT: 0, Turn.LEFT: 1, Turn.UP: 2, Turn.DOWN: 3}
        head = state['snake_positions'][-1]
        apple = state['apple']

        encoded_state = [direction_code[state['turn']]]

        # Offset from apple to head; the two diagonals split the plane into
        # four quadrants that decide the apple's dominant direction.
        dx = head.x - apple.x
        dy = head.y - apple.y
        if dx > dy:
            if dx + dy >= 0:
                encoded_state.append(1)  # Apple is to the left
            else:
                encoded_state.append(3)  # Apple is below
        else:
            if dx + dy >= 0:
                encoded_state.append(2)  # Apple is above
            else:
                encoded_state.append(0)  # Apple is to the right

        encoded_state.append(head.x - 0)                       # cells to the left edge
        encoded_state.append(head.y - 0)                       # cells to the top edge
        encoded_state.append((width / pixel_size) - head.x)    # cells to the right edge
        encoded_state.append((height / pixel_size) - head.y)   # cells to the bottom edge
        encoded_state.append(state['score'])                   # game score
        return encoded_state

    def get_action(self, state):
        """Return the index (0-3) of the network's highest output.

        Args:
            state: Encoded state, as a list or array of ``STATE_SIZE`` numbers.
        """
        # Local import: the notebook does not import numpy at module level.
        import numpy as np

        batch = np.asarray(state, dtype=np.float32).reshape(1, -1)
        # verbose=0 keeps Keras from printing a progress bar every frame.
        q_values = self.model.predict(batch, verbose=0)
        return int(np.argmax(q_values))

    def update_model(self, state, next_state, action, reward, discount_factor=0.95, learning_rate=0.001):
        """Placeholder for the learning step; currently does nothing.

        TODO: implement the actual update from (state, action, reward,
        next_state).
        """
        return

    def train_loop(self):
        """Play ``num_episodes`` games, rendering every frame at 6 FPS.

        An episode ends when the snake dies. Closing the window stops the
        whole loop. pygame is shut down once, after the last episode (or on
        an early exit), so the display stays usable across episodes.
        """
        action_to_turn = {0: Turn.RIGHT, 1: Turn.LEFT, 2: Turn.UP, 3: Turn.DOWN}
        try:
            for _ in range(self.num_episodes):
                game_state = State()
                # Game Loop
                while running:
                    # Drain the event queue every frame so the window stays
                    # responsive, and honour the close button.
                    if any(event.type == pygame.QUIT for event in pygame.event.get()):
                        return

                    if game_state.get_state()['die']:
                        print(f"Score: {game_state.get_state()['score']}")
                        break

                    state = self.get_encoded_state(game_state.get_state())
                    action = self.get_action(state)
                    game_state.set_turn(action_to_turn[action])

                    # render game here
                    render_background(screen)
                    render_border(screen)
                    render_snake(screen, game_state.get_state()["snake_positions"])
                    if game_state.get_state()['apple_rendered']:
                        render_apple(screen, game_state.get_state()['apple'])

                    # update game here
                    game_state.update()

                    next_state = self.get_encoded_state(game_state.get_state())
                    self.update_model(state, next_state, action, game_state.get_state()['score'])

                    pygame.display.flip()
                    clock.tick(6)
        finally:
            pygame.quit()

In [None]:
# Build the agent with a budget of 10 episodes, then play them.
# NOTE(review): train_loop draws on the module-level `screen` and calls
# pygame.quit(); presumably the setup cell has to be re-run before this
# cell can be executed a second time - confirm.
rl = Reinforcement_Learning(num_episodes=10)
rl.train_loop()