<a href="https://colab.research.google.com/github/AbhiJ2706/snake_game_reinforcement_learning/blob/main/SnakeRL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SNAKE GAME (ft. reinforcement learning)

This is a reinforcement learning agent designed to play the popular [snake game](https://g.co/kgs/HGzzUC). It uses the Q-learning algorithm to learn the game. In Q-learning, the problem is split into two pieces: the agent, which is the "character" or "game piece" that plays the game (here, the snake), and the environment in which the agent exists (the grid, its walls, and the apple). After each move, the environment gives the agent a reward for the action it took. The algorithm keeps a table of weights (Q-values) with one entry for each action in each possible state, estimating how much reward that action will eventually lead to. Each reward is used to update the weight of the action just taken, and the agent plays by choosing the action with the highest weight in each state.

How the algorithm works:

1. At each state, we either choose the best action according to the weights or, occasionally, a random action. We then update the weight for that state and action based on the outcome, using the Q-learning formula.
2. If the game ends, it is reset and the next iteration of learning begins with the updated weights. The positions of the snake and the apple are randomized. The weights keep getting tuned over time this way.
3. If the game does not end, the next turn begins and step 1 is repeated until step 2 occurs.
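For reference, the standard tabular Q-learning update that step 1 refers to can be written as follows, where $s$ is the current state, $a$ the action taken, $r$ the reward, $s'$ the resulting state, $\alpha$ the learning rate (`lr` below) and $\gamma$ the discount factor (`discount` below). This is the textbook form; in the code below, `Q(s, a)` computes the $\alpha$-scaled bracketed term.

$$
Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]
$$

When a move ends the game there is no next state, and the target inside the brackets is usually taken to be just $r$.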

Step 1: import libraries

In [None]:
from random import randint, choice, uniform, shuffle
from itertools import combinations
from copy import deepcopy

Step 2: configure hyperparameters

There are several hyperparameters needed to configure the model.

* n: the side length of the grid, including the border walls, so the playable area is (n - 2) × (n - 2). This determines the number of states in the game, so the other parameters need to be tuned around it.
* epsilon: the probability that a random action is taken instead of the best action. This exploration helps the agent learn, especially at the beginning of training, when all weights are 0.
* lr: the learning rate (α), which controls how far each update moves a weight toward its new estimate.
* discount: the discount factor (γ), i.e. how much the value of the best action in the next state counts in the update.
* epochs: the number of training games to play.
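To make the roles of `epsilon`, `lr`, and `discount` concrete, here is a minimal standalone sketch of one epsilon-greedy action choice and one tabular Q-learning update. The helper names `choose_action` and `q_step` are hypothetical and are not used elsewhere in the notebook; a list of Q-values holds one weight per direction, in the L U R D order used below.

```python
import random

def choose_action(q_values, epsilon, rng=random):
    """Epsilon-greedy: with probability epsilon pick a random action, else the best one."""
    if rng.random() < epsilon:
        # Explore: any action index, uniformly at random
        return rng.randrange(len(q_values))
    # Exploit: highest weight; max() keeps the first of any tied actions
    return max(range(len(q_values)), key=lambda i: q_values[i])

def q_step(q_sa, reward, next_q_values, lr, discount, terminal=False):
    """Return Q(s, a) after one tabular Q-learning update."""
    # Terminal moves have no next state, so the target is just the reward
    target = reward if terminal else reward + discount * max(next_q_values)
    return q_sa + lr * (target - q_sa)

print(choose_action([0, 5, 2, 5], epsilon=0.0))  # 1
print(q_step(0.0, 100, [0, 0, 0, 0], lr=0.13, discount=0.6))
```

With the values from the next cell (`lr = 0.13`, `discount = 0.6`), a first move onto the apple (reward 100, all next-state weights still 0) raises that weight from 0 to about 13.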

In [None]:
n = 7

epsilon = 0.2

lr = 0.13

discount = 0.6

epochs = 250000

In [None]:
# all 4 directions for moving the snake
# CONVENTION: L U R D

dirs = ((0, -1), (-1, 0), (0, 1), (1, 0))

all_apples = [(x, y) for x in range(1, n - 1) for y in range(1, n - 1)]

Step 3: grid setup

In [None]:
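# n x n grid: a border of walls (-1) around an empty interior (0)
grid = [[-1 for i in range(n)]] + [[-1] + [0 for i in range(n - 2)] + [-1] for i in range(n - 2)] + [[-1 for i in range(n)]]

# Place the apple on a random interior square
shuffle(all_apples)
apple_pos = all_apples[0]
# Start the snake on a row and a column that both differ from the apple's
# (these ranges skip the apple's row/column and the one just before it)
radius = ([*range(1, apple_pos[0] - 1), *range(apple_pos[0] + 1, n - 1)],
          [*range(1, apple_pos[1] - 1), *range(apple_pos[1] + 1, n - 1)])
snake_pos = (choice(radius[0]), choice(radius[1]))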

grid[apple_pos[0]][apple_pos[1]] = 1
grid[snake_pos[0]][snake_pos[1]] = -1

original_snake_pos = deepcopy(snake_pos)
original_apple_pos = deepcopy(apple_pos)

snake = [snake_pos]

Step 4: building states

In the case of the snake game, we use the following convention for grid squares:

* 1 represents the apple
* -1 represents a square that would end the game, e.g. a wall or part of the snake's tail
* 0 is an empty space

Since the snake can only move left, up, right, or down, at each state we only care about the squares immediately to the left of, above, to the right of, and below the snake's head. We also need to know where the head and the apple are, so that the agent can learn to steer toward the apple. Thus each state is constructed as:

(square left, square above, square right, square below, apple row, apple column, head row, head column)

The number of states is the number of possible (left, up, right, down) combinations × the number of apple positions × the number of head positions. There are 48 neighborhood combinations (16 made only of -1 and 0, plus 32 with the apple in exactly one of the four squares), and both the apple and the head can be on any of the (n - 2)² interior squares, giving 48 × (n - 2)⁴ states: 30,000 for n = 7. This grows very quickly with n.
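The state layout and the count above can be checked with a small standalone sketch. The helpers `make_grid`, `observe`, and `count_states` are hypothetical and independent of the notebook's globals (the demo uses its own `demo_grid` so it does not overwrite `grid`); like the notebook, it indexes squares as (row, column).

```python
from itertools import product

def make_grid(n):
    """n x n grid with a border of walls (-1) around an empty interior (0)."""
    return [[-1 if i in (0, n - 1) or j in (0, n - 1) else 0 for j in range(n)]
            for i in range(n)]

def observe(grid, head, apple):
    """(left, up, right, down, apple row, apple col, head row, head col)."""
    r, c = head
    return (grid[r][c - 1], grid[r - 1][c], grid[r][c + 1], grid[r + 1][c], *apple, *head)

def count_states(n):
    """Distinct states: neighborhood patterns x apple squares x head squares."""
    no_apple = list(product((0, -1), repeat=4))  # 16 blocked/empty patterns
    with_apple = {tuple(1 if i == j else s[i] for i in range(4))
                  for s in no_apple for j in range(4)}  # 32 patterns with one apple
    interior = (n - 2) ** 2
    return (len(no_apple) + len(with_apple)) * interior * interior

demo_grid = make_grid(7)
demo_grid[1][3] = 1  # apple just to the right of the head
print(observe(demo_grid, (1, 2), (1, 3)))  # (0, -1, 1, 0, 1, 3, 1, 2)
print(count_states(7))                     # 30000
```

The quartic term is why increasing `n` from 7 to 10 already raises the table from 30,000 to 196,608 states.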

In [None]:
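from itertools import product

# All 16 (left, up, right, down) patterns made of blocked (-1) and empty (0) squares
states = list(product((0, -1), repeat=4))

# The same patterns with the apple (1) in exactly one of the four squares
apple_states = []

for s in states:
    apple_states += [tuple(1 if j == i else s[i] for i in range(len(s))) for j in range(len(s))]

states += apple_states

# Extend every pattern with each possible apple position (first pass),
# then with each possible head position (second pass)
for i in range(2):
    all_states = list(map(lambda x: [(*x, *a) for a in all_apples], states))

    states = []

    for s in all_states:
        states += s

# Q table: each state maps to one weight per direction (L, U, R, D), all starting at 0;
# building the dict also drops duplicate patterns
states = {s: [0 for i in range(4)] for s in states}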

Step 5: reinforcement learning methods

The following methods compute the rewards and update the Q table (the `states` dictionary).

Essentially, a positive reward is given for reaching the apple (+100) or moving closer to it (+1), and a negative reward is given for moving farther from it (-10) or hitting a wall or the snake's tail (-100).
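Here is a minimal standalone sketch of that shaping rule, using the same reward values as the `reward` function below but measuring closeness with the Manhattan distance |Δrow| + |Δcol|; the distance choice and the helper names `manhattan` and `shaped_reward` are this sketch's assumptions. Because every move changes the Manhattan distance by exactly 1, a move onto an empty square is always either strictly closer or strictly farther, whichever side of the head the apple is on.

```python
def manhattan(p, q):
    """Grid distance between two (row, col) squares."""
    return abs(p[0] - q[0]) + abs(p[1] - q[1])

def shaped_reward(target_cell, head, new_head, apple):
    """Reward for moving the head to new_head; target_cell is that square's grid value."""
    if target_cell == 1:   # moves onto the apple
        return 100
    if target_cell == -1:  # wall or tail: game over
        return -100
    # Empty square: small bonus for closing in, penalty for moving away
    return 1 if manhattan(new_head, apple) < manhattan(head, apple) else -10

print(shaped_reward(0, (3, 3), (2, 3), (1, 3)))  # 1   (up, toward an apple above)
print(shaped_reward(0, (3, 3), (4, 3), (1, 3)))  # -10 (down, away from it)
```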

In [None]:
def reward(s, a):
    global snake, apple_pos
    head = snake[0]
    if s[a] == 1:
        # Moving onto the apple
        return 100
    elif s[a] == 0:
        # Moving onto an empty square: compare (absolute) Manhattan distances to the apple,
        # so moving toward an apple above or to the left is rewarded too
        new_head = (head[0] + dirs[a][0], head[1] + dirs[a][1])
        old_apple_dist = abs(apple_pos[0] - head[0]) + abs(apple_pos[1] - head[1])
        new_apple_dist = abs(apple_pos[0] - new_head[0]) + abs(apple_pos[1] - new_head[1])
        if new_apple_dist > old_apple_dist:
            return -10
        else:
            return 1
    else:
        # Moving into a wall or the tail ends the game
        return -100

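def Q(s, a):
    global snake, grid, states, lr, discount, apple_pos
    if type(a) is range:
        # Convenience: evaluate several actions at once
        return [Q(s, x) for x in a]
    r = reward(s, a)
    if s[a] == -1:
        # Moving into a wall or the tail ends the game, so there is no next state
        # to bootstrap from: the target is just the reward
        return lr * (r - states[s][a])
    # Otherwise the new head is an interior square, so all its neighbors can be indexed
    head = (snake[0][0] + dirs[a][0], snake[0][1] + dirs[a][1])
    new_state = (grid[head[0]][head[1] - 1],
                 grid[head[0] - 1][head[1]],
                 grid[head[0]][head[1] + 1],
                 grid[head[0] + 1][head[1]], *apple_pos, *head)
    # Increment lr * (r + discount * max_a' Q(s', a') - Q(s, a))
    return lr * (r + discount * max(states[new_state]) - states[s][a])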

def try_reference_spot(x, y):
    global grid
    # True if grid[x][y] can be indexed (note: negative indices wrap around in Python)
    try:
        grid[x][y]
        return True
    except IndexError:
        return False

def update_q_table(random=True):
    global snake, grid, states, apple_pos, epsilon
    head = snake[0]
    # Current state: (left, up, right, down, apple row, apple col, head row, head col)
    state = (grid[head[0]][head[1] - 1],
             grid[head[0] - 1][head[1]],
             grid[head[0]][head[1] + 1],
             grid[head[0] + 1][head[1]], *apple_pos, *head)
    if random and uniform(0, 1) < epsilon:
        # Explore: take a random direction
        move = randint(0, 3)
        while not try_reference_spot(snake[0][0] + dirs[move][0], snake[0][1] + dirs[move][1]):
            move = randint(0, 3)
    else:
        # Exploit: take the direction with the highest weight (ties go to the first, i.e. L)
        move = max(range(4), key=lambda i: states[state][i])
    # Q() returns lr * (target - Q(s, a)), so adding it applies the standard update
    # Q(s, a) + lr * (target - Q(s, a))
    states[state][move] += Q(state, move)
    return move
    

Step 6: game function

This is the main loop for the game which handles all the game mechanics.
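The core movement rule can be isolated in a hypothetical standalone helper: the snake is a list of (row, col) squares with the head first; each turn a new head is added in the chosen direction and the last segment is dropped, unless the snake is growing. Keeping the old tail when growing is a simplification assumed for this sketch; `play_game` below instead attaches the new segment to a free square next to the tail.

```python
def step(snake, direction, grow=False):
    """Return the snake after one move; the input list is not modified."""
    head = (snake[0][0] + direction[0], snake[0][1] + direction[1])
    # Drop the last segment unless the snake has just eaten
    body = snake if grow else snake[:-1]
    return [head] + body

def collides(snake, n):
    """True if the head is on the border wall of an n x n grid or on another segment."""
    r, c = snake[0]
    return r in (0, n - 1) or c in (0, n - 1) or snake[0] in snake[1:]

demo_snake = [(3, 3), (3, 4)]
print(step(demo_snake, (0, -1)))             # [(3, 2), (3, 3)]
print(step(demo_snake, (0, -1), grow=True))  # [(3, 2), (3, 3), (3, 4)]
print(collides([(0, 3)], 7))                 # True
```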

In [None]:
def move_snake(snake, direction):
    # New head one square in `direction`, followed by every segment except the last
    return [(snake[0][0] + direction[0], snake[0][1] + direction[1])] + snake[:-1]

def pretty_print_grid(grid):
    global n, snake
    for i, y in enumerate(grid, 0):
        for j, xy in enumerate(y, 0):
            if i == 0 or i == n - 1:
                print("w\t", end="")
            elif j == 0 or j == n - 1:
                print("w\t", end="")
            elif xy == -1 and i == snake[0][0] and j == snake[0][1]:
                print("h\t", end="")
            elif xy == 1:
                print("A\t", end="")
            elif xy == -1:
                print("s\t", end="")
            else:
                print(f" \t", end="")
        print()

def play_game(print_out=False, learning_phase=True):
    global snake, grid, apple_pos, original_apple_pos, original_snake_pos, dirs
    alive = True
    game_length = 0
    snake_length = 0

    while alive:
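        # Choose a move (updating its weight), then advance the snake one square
        move = update_q_table(random=learning_phase)
        snake = move_snake(snake, dirs[move])
        # Head hit a wall or the tail: record the score and respawn the snake and apple
        if grid[snake[0][0]][snake[0][1]] == -1: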
            shuffle(all_apples)
            apple_pos = all_apples[0]
            snake_length = len(snake)
            radius = ([*range(1, apple_pos[0] - 1), *range(apple_pos[0] + 1, n - 1)], 
                    [*range(1, apple_pos[1] - 1), *range(apple_pos[1] + 1, n - 1)])
            snake_pos = (choice(radius[0]), choice(radius[1]))
            snake = [snake_pos]
            alive = False
        # Head reached the apple: place a new apple on an empty square and grow the snake
        elif grid[snake[0][0]][snake[0][1]] == 1:
            radius = ([*range(1, snake[0][0] - 1), *range(snake[0][0] + 1, n - 1)], 
                        [*range(1, snake[0][1] - 1), *range(snake[0][1] + 1, n - 1)])
            new_x = choice(radius[0])
            new_y = choice(radius[1])
            while grid[new_x][new_y] != 0:
                new_x = choice(radius[0])
                new_y = choice(radius[1])
            apple_pos = (new_x, new_y)
            tail = snake[-1]
            state = (grid[tail[0]][tail[1] - 1], grid[tail[0] - 1][tail[1]], grid[tail[0]][tail[1] + 1], grid[tail[0] + 1][tail[1]])
            for i in range(4):
                if state[i] == 0:
                    grid[tail[0] + dirs[i][0]][tail[1] + dirs[i][1]] = -1
                    new_state = (grid[snake[0][0]][snake[0][1] - 1], grid[snake[0][0] - 1][snake[0][1]], grid[snake[0][0]][snake[0][1] + 1], grid[snake[0][0] + 1][snake[0][1]])
                    if new_state == (-1, -1, -1, -1): continue
                    snake.append((tail[0] + dirs[i][0], tail[1] + dirs[i][1]))
                    break
        # Redraw the grid from scratch: walls, apple, then (if still alive) the snake
        grid = [[-1 for i in range(n)]] + [[-1] + [0 for i in range(n - 2)] + [-1] for i in range(n - 2)] + [[-1 for i in range(n)]]
        grid[apple_pos[0]][apple_pos[1]] = 1
        if not learning_phase: print(snake)
        if alive:
            game_length += 1
            for s in snake:
                grid[s[0]][s[1]] = -1
            if print_out: 
                pretty_print_grid(grid)
        else:
            if learning_phase: 
                print("Length of game:", game_length, "| Length of snake:", snake_length, "| Apples found:", snake_length - 1)


Step 7: training

In [None]:
for i in range(epochs):
    print(i, end=" ")
    play_game(learning_phase=True)

Step 8: play the game!

The board is pretty printed for convenience.

* A: apple
* h: snake head
* s: snake tail (body segments behind the head)
* w: wall

In [None]:
play_game(print_out=True, learning_phase=False)
print("-" * 100)