In [1]:
import random
import numpy as np


In [2]:


class QLearningSnake:
    """Snake game on a ``width`` x ``height`` grid driven by tabular Q-learning.

    The learning state is the pair ``(head cell index, apple cell index)``;
    the snake's body and current heading are not part of the state.  Actions
    are indices into ``ACTIONS``.  A requested 180-degree turn is ignored by
    ``change_direction``, in which case the snake keeps its current heading.
    """

    # Action index -> direction name; shared by train() and play().
    ACTIONS = ('up', 'down', 'left', 'right')
    # Direction name -> (dx, dy) applied to the head; 'up' decreases y.
    _DELTAS = {'up': (0, -1), 'down': (0, 1), 'left': (-1, 0), 'right': (1, 0)}
    # Direction name -> the heading from which that turn would be a reversal.
    _OPPOSITE = {'up': 'down', 'down': 'up', 'left': 'right', 'right': 'left'}

    def __init__(self, width, height):
        """Create the board, the Q-table and the learning hyper-parameters.

        Args:
            width: Number of columns (x range is ``0 .. width - 1``).
            height: Number of rows (y range is ``0 .. height - 1``).

        Raises:
            ValueError: If the board has fewer than two cells, because the
                snake's head and an apple could not both be placed.
        """
        if width < 1 or height < 1 or width * height < 2:
            raise ValueError(
                "board must have at least two cells to hold the snake and an apple")

        self.width = width
        self.height = height

        # Q-Learning parameters
        self.alpha = 0.1  # Learning rate
        self.gamma = 0.9  # Discount factor
        self.epsilon = 1.0  # Exploration rate
        self.min_epsilon = 0.01
        self.epsilon_decay = 0.99

        # Q-table indexed as [head cell, apple cell, action], initialised to 0.
        self.q_table = np.zeros(
            (width * height, width * height, len(self.ACTIONS)))

        # Sets snake, direction, apple, score and game_over.
        self.reset()

    def generate_apple(self):
        """Return a random cell not occupied by the snake.

        Returns:
            An ``(x, y)`` tuple, or ``None`` when the snake covers every cell
            (rejection sampling would otherwise never terminate).
        """
        if len(set(self.snake)) >= self.width * self.height:
            return None
        while True:
            apple = (random.randint(0, self.width - 1),
                     random.randint(0, self.height - 1))
            if apple not in self.snake:
                return apple

    def change_direction(self, direction):
        """Set the heading to ``direction`` unless it is a 180-degree reversal.

        Unknown direction names and reversals leave the heading unchanged.
        """
        if (direction in self._OPPOSITE
                and self.direction != self._OPPOSITE[direction]):
            self.direction = direction

    def move(self):
        """Advance the snake one cell along its current heading.

        Returns:
            ``(game_over, found_apple)``.  ``game_over`` is True when the move
            hits a wall or the snake's body (the snake is then left unmoved),
            or when eating the apple filled the board completely.
        """
        dx, dy = self._DELTAS[self.direction]
        head_x, head_y = self.snake[0]
        new_head = (head_x + dx, head_y + dy)

        inside = (0 <= new_head[0] < self.width
                  and 0 <= new_head[1] < self.height)
        if not inside or new_head in self.snake:
            self.game_over = True
            return True, False

        self.snake.insert(0, new_head)
        if new_head != self.apple:
            # Plain step: drop the tail so the length stays the same.
            self.snake.pop()
            return False, False

        # Apple eaten: the tail is kept, so the snake grows by one.
        self.score += 1
        self.apple = self.generate_apple()
        board_full = self.apple is None
        if board_full:
            # No free cell is left for a new apple, so the episode ends here.
            self.game_over = True
        return board_full, True

    def get_state_index(self):
        """Return ``(head_index, apple_index)`` with ``index = x * height + y``.

        Requires ``self.apple`` to be a cell; it is ``None`` only after the
        board has been filled, at which point the game is already over.
        """
        head_x, head_y = self.snake[0]
        apple_x, apple_y = self.apple
        return head_x * self.height + head_y, apple_x * self.height + apple_y

    def select_action(self, state_index):
        """Pick an action index epsilon-greedily for ``state_index``.

        With probability ``epsilon`` a uniformly random action is returned,
        otherwise the action with the highest Q-value (first one on ties).
        """
        if random.random() < self.epsilon:
            return random.randint(0, len(self.ACTIONS) - 1)
        return int(np.argmax(self.q_table[state_index]))

    def update_q_table(self, state_index, action_index, reward,
                       next_state_index, done=False):
        """Apply one Q-learning update to ``Q[state_index][action_index]``.

        Args:
            state_index: ``(head, apple)`` index pair before the move.
            action_index: Index of the action that was selected.
            reward: Reward observed for the transition.
            next_state_index: ``(head, apple)`` index pair after the move.
            done: True when the transition ended the episode.  The future
                value is then taken as 0 instead of ``max Q(next_state)``, so
                a terminal reward is not mixed with a stale bootstrap value.
        """
        current_q_value = self.q_table[state_index][action_index]
        future_value = 0.0 if done else np.max(self.q_table[next_state_index])
        td_target = reward + self.gamma * future_value
        self.q_table[state_index][action_index] = (
            current_q_value + self.alpha * (td_target - current_q_value))

    def train(self, num_episodes, log_every=1):
        """Run ``num_episodes`` training games, updating the Q-table in place.

        Rewards: +10 for eating an apple, -10 for an episode-ending crash,
        -1 for any other step.  ``epsilon`` is decayed once per episode and
        is not restored between calls.

        Args:
            num_episodes: Number of games to play.
            log_every: Print the score every ``log_every`` episodes; the
                default of 1 prints every episode, 0 or None disables it.
        """
        for episode in range(num_episodes):
            self.reset()

            while not self.game_over:
                state_index = self.get_state_index()
                action_index = self.select_action(state_index)

                self.change_direction(self.ACTIONS[action_index])
                game_over, found_apple = self.move()

                if found_apple:
                    reward = 10
                elif game_over:
                    reward = -10
                else:
                    reward = -1

                # A terminal transition has no successor to evaluate (and the
                # apple may be None after a board-filling move), so the
                # placeholder passed here is ignored because done=True.
                next_state_index = (
                    state_index if game_over else self.get_state_index())
                self.update_q_table(state_index, action_index,
                                    reward, next_state_index, done=game_over)

            self.epsilon = max(
                self.min_epsilon, self.epsilon * self.epsilon_decay)

            if log_every and (episode + 1) % log_every == 0:
                print(f"Episode: {episode+1}/{num_episodes}, Score: {self.score}")

    def play(self, max_idle_steps=None):
        """Play one game with the greedy (argmax) policy and print the score.

        Args:
            max_idle_steps: Stop after this many consecutive steps without
                eating an apple.  Defaults to ``4 * width * height``.  The
                greedy policy has no randomness between apples, so a long
                apple-less stretch most likely means the snake is circling;
                the cap prevents the loop from running forever.
        """
        self.reset()
        if max_idle_steps is None:
            max_idle_steps = 4 * self.width * self.height

        idle_steps = 0
        while not self.game_over and idle_steps < max_idle_steps:
            state_index = self.get_state_index()
            action_index = int(np.argmax(self.q_table[state_index]))

            self.change_direction(self.ACTIONS[action_index])
            _, found_apple = self.move()
            idle_steps = 0 if found_apple else idle_steps + 1

        print(f"Game Over! Score: {self.score}")

    def reset(self):
        """Start a new game: one-cell snake in the centre, heading up.

        The Q-table and ``epsilon`` are left untouched.
        """
        self.snake = [(self.width // 2, self.height // 2)]
        self.direction = 'up'
        self.apple = self.generate_apple()
        self.score = 0
        self.game_over = False





In [3]:
# Train a Q-learning agent on a square board, then replay its greedy policy.
BOARD_WIDTH = 20
BOARD_HEIGHT = 20
NUM_EPISODES = 100000
SEED = 42

# The game draws apple positions and exploration moves from the stdlib
# `random` module; seeding it makes a Restart & Run All reproduce this run.
random.seed(SEED)

snake_game = QLearningSnake(width=BOARD_WIDTH, height=BOARD_HEIGHT)
snake_game.train(num_episodes=NUM_EPISODES)
snake_game.play()


Episode: 1/1000000, Score: 0
Episode: 2/1000000, Score: 0
Episode: 3/1000000, Score: 0
Episode: 4/1000000, Score: 1
Episode: 5/1000000, Score: 1
Episode: 6/1000000, Score: 1
Episode: 7/1000000, Score: 0
Episode: 8/1000000, Score: 0
Episode: 9/1000000, Score: 0
Episode: 10/1000000, Score: 0
Episode: 11/1000000, Score: 0
Episode: 12/1000000, Score: 0
Episode: 13/1000000, Score: 0
Episode: 14/1000000, Score: 0
Episode: 15/1000000, Score: 0
Episode: 16/1000000, Score: 0
Episode: 17/1000000, Score: 0
Episode: 18/1000000, Score: 0
Episode: 19/1000000, Score: 0
Episode: 20/1000000, Score: 0
Episode: 21/1000000, Score: 0
Episode: 22/1000000, Score: 0
Episode: 23/1000000, Score: 0
Episode: 24/1000000, Score: 0
Episode: 25/1000000, Score: 0
Episode: 26/1000000, Score: 0
Episode: 27/1000000, Score: 1
Episode: 28/1000000, Score: 0
Episode: 29/1000000, Score: 0
Episode: 30/1000000, Score: 0
Episode: 31/1000000, Score: 0
Episode: 32/1000000, Score: 0
Episode: 33/1000000, Score: 0
Episode: 34/1000000

KeyboardInterrupt: 