<a href="https://colab.research.google.com/github/AbhiJ2706/snake_game_reinforcement_learning/blob/main/SnakeRL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
import numpy as np
from random import randint
from random import choice
from random import uniform
from random import shuffle
from itertools import combinations
from copy import deepcopy
import math

In [24]:
n = 7

epsilon = 0.2

lr = 0.13

discount = 0.6

epochs = 250000

In [25]:
# all 4 directions for moving the snake
# CONVENTION: L U R D

dirs = ((0, -1), (-1, 0), (0, 1), (1, 0))

all_apples = [(x, y) for x in range(1, n - 1) for y in range(1, n - 1)]

In [26]:
grid = [[-1 for i in range(n)]] + [[-1] + [0 for i in range(n - 2)] + [-1] for i in range(n - 2)] + [[-1 for i in range(n)]]

shuffle(all_apples)
apple_pos = all_apples[0]
radius = ([*range(1, apple_pos[0] - 1), *range(apple_pos[0] + 1, n - 1)], 
          [*range(1, apple_pos[1] - 1), *range(apple_pos[1] + 1, n - 1)])
snake_pos = (choice(radius[0]), choice(radius[1]))

grid[apple_pos[0]][apple_pos[1]] = 1
grid[snake_pos[0]][snake_pos[1]] = -1

original_snake_pos = deepcopy(snake_pos)
original_apple_pos = deepcopy(apple_pos)

snake = [snake_pos]

In [27]:
states = list(set(map(lambda x: tuple(map(lambda y: 0 if y % 2 == 0 else -1, x)), combinations(range(8), 4))))

apple_states = []

for s in states:
    apple_states += [tuple(1 if j == i else s[i] for i in range(len(s))) for j in range(len(s))]

states += apple_states

for i in range(2):
    all_states = list(map(lambda x: [(*x, *a) for a in all_apples], states))

    states = []

    for s in all_states:
        states += s

states = {s: [0 for i in range(4)] for s in states}

In [28]:
def reward(s, a):
    global snake, apple_pos
    head = snake[0]
    if s[a] == 1:
        return 100
    elif s[a] == 0:
        new_head = (head[0] + dirs[a][0], head[1] + dirs[a][1])
        old_apple_dist = ((apple_pos[0] - head[0]) ** 2, (apple_pos[1] - head[1]) ** 2)
        new_apple_dist = ((apple_pos[0] - new_head[0]) ** 2, (apple_pos[1] - new_head[1]) ** 2)
        if old_apple_dist[0] < new_apple_dist[0] or old_apple_dist[1] < new_apple_dist[1]:
            return -10
        else:
            return 1
    else:
        return -100

def Q(s, a):
    global snake, grid, states, lr, discount, apple_pos
    if type(a) is range:
        return [Q(s, x) for x in a]
    else:
        head = (snake[0][0] + dirs[a][0], snake[0][1] + dirs[a][1])
        try:
            new_state = (grid[head[0]][head[1] - 1], grid[head[0] - 1][head[1]], grid[head[0]][head[1] + 1], grid[head[0] + 1][head[1]], *apple_pos, *head)
            return lr * (reward(s, a) + discount * max(states[new_state]) - states[s][a])
        except:
            return 0

def try_reference_spot(x, y):
    global grid
    try:
        z = grid[x][y]
        return True
    except:
        return False

def update_q_table(random=True):
    global snake, grid, states, apple_pos, epsilon
    x = uniform(0, 1) if random else 0
    head = snake[0]
    state = (grid[head[0]][head[1] - 1], grid[head[0] - 1][head[1]], grid[head[0]][head[1] + 1], grid[head[0] + 1][head[1]], *apple_pos, *head)
    if x < epsilon and random:
        move = randint(0, 3)
        while not try_reference_spot(snake[0][0] + dirs[move][0], snake[0][1] + dirs[move][1]):
            move = randint(0, 3)
        states[state][move] = (1 - lr) * states[state][move] + Q(state, move)
        return move
    else:
        move = max(enumerate(states[state], 0), key=lambda x: x[1])[0]
        states[state][move] = (1 - lr) * states[state][move] + Q(state, move)
        return move
    

In [29]:
def move_snake(snake, dir):
    return list(filter(lambda x: x is not None, [(snake[0][0] + dir[0], snake[0][1] + dir[1])] + \
            list(map(lambda x: snake[x[0]] if x[0] < len(snake) - 1 else None, \
                     enumerate(snake, 0)))))

def pretty_print_grid(grid):
    for y in grid:
        for xy in y:
            print(f"{xy}\t", end="")
        print()

def iterate(print_out=False, learning_phase=True):
    global snake, grid, apple_pos, original_apple_pos, original_snake_pos, dirs
    alive = True
    game_length = 0
    snake_length = 0

    while alive:
        move = update_q_table(random=learning_phase)
        snake = move_snake(snake, dirs[move])
        if grid[snake[0][0]][snake[0][1]] == -1:
            shuffle(all_apples)
            apple_pos = all_apples[0]
            snake_length = len(snake)
            radius = ([*range(1, apple_pos[0] - 1), *range(apple_pos[0] + 1, n - 1)], 
                    [*range(1, apple_pos[1] - 1), *range(apple_pos[1] + 1, n - 1)])
            snake_pos = (choice(radius[0]), choice(radius[1]))
            snake = [snake_pos]
            alive = False
        elif grid[snake[0][0]][snake[0][1]] == 1:
            radius = ([*range(1, snake[0][0] - 1), *range(snake[0][0] + 1, n - 1)], 
                        [*range(1, snake[0][1] - 1), *range(snake[0][1] + 1, n - 1)])
            new_x = choice(radius[0])
            new_y = choice(radius[1])
            while grid[new_x][new_y] != 0:
                new_x = choice(radius[0])
                new_y = choice(radius[1])
            apple_pos = (new_x, new_y)
            tail = snake[-1]
            state = (grid[tail[0]][tail[1] - 1], grid[tail[0] - 1][tail[1]], grid[tail[0]][tail[1] + 1], grid[tail[0] + 1][tail[1]])
            for i in range(4):
                if state[i] == 0:
                    grid[tail[0] + dirs[i][0]][tail[1] + dirs[i][1]] = -1
                    new_state = (grid[snake[0][0]][snake[0][1] - 1], grid[snake[0][0] - 1][snake[0][1]], grid[snake[0][0]][snake[0][1] + 1], grid[snake[0][0] + 1][snake[0][1]])
                    if new_state == (-1, -1, -1, -1): continue
                    snake.append((tail[0] + dirs[i][0], tail[1] + dirs[i][1]))
                    break
        grid = [[-1 for i in range(n)]] + [[-1] + [0 for i in range(n - 2)] + [-1] for i in range(n - 2)] + [[-1 for i in range(n)]]
        grid[apple_pos[0]][apple_pos[1]] = 1
        if not learning_phase: print(snake)
        if alive:
            game_length += 1
            for s in snake:
                grid[s[0]][s[1]] = -1
            if print_out: 
                pretty_print_grid(grid)
        else:
            if learning_phase: 
                print("Length of game:", game_length, "| Length of snake:", snake_length, "| Apples found:", snake_length - 1)


In [30]:
for i in range(epochs):
    print(i, end=" ")
    iterate()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
245000 Length of game: 0 | Length of snake: 1 | Apples found: 0
245001 Length of game: 11 | Length of snake: 2 | Apples found: 1
245002 Length of game: 33 | Length of snake: 5 | Apples found: 4
245003 Length of game: 12 | Length of snake: 3 | Apples found: 2
245004 Length of game: 12 | Length of snake: 2 | Apples found: 1
245005 Length of game: 20 | Length of snake: 4 | Apples found: 3
245006 Length of game: 7 | Length of snake: 2 | Apples found: 1
245007 Length of game: 9 | Length of snake: 1 | Apples found: 0
245008 Length of game: 19 | Length of snake: 3 | Apples found: 2
245009 Length of game: 14 | Length of snake: 2 | Apples found: 1
245010 Length of game: 4 | Length of snake: 2 | Apples found: 1
245011 Length of game: 10 | Length of snake: 3 | Apples found: 2
245012 Length of game: 6 | Length of snake: 1 | Apples found: 0
245013 Length of game: 11 | Length of snake: 1 | Apples found: 0
245014 Length of game: 14 | Le

In [31]:
for i in range(2):
    iterate(print_out=True, learning_phase=False)
    print("-" * 100)

[(1, 3)]
-1	-1	-1	-1	-1	-1	-1	
-1	0	0	-1	0	0	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	0	-1	
-1	-1	-1	-1	-1	-1	-1	
[(1, 4)]
-1	-1	-1	-1	-1	-1	-1	
-1	0	0	0	-1	0	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	0	-1	
-1	-1	-1	-1	-1	-1	-1	
[(1, 5)]
-1	-1	-1	-1	-1	-1	-1	
-1	0	0	0	0	-1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	0	-1	
-1	-1	-1	-1	-1	-1	-1	
[(2, 5)]
-1	-1	-1	-1	-1	-1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	-1	-1	
-1	0	0	0	0	1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	0	-1	
-1	-1	-1	-1	-1	-1	-1	
[(3, 5), (3, 4)]
-1	-1	-1	-1	-1	-1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	-1	-1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	1	0	0	-1	
-1	-1	-1	-1	-1	-1	-1	
[(4, 5), (3, 5)]
-1	-1	-1	-1	-1	-1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	-1	-1	
-1	0	0	0	0	-1	-1	
-1	0	0	1	0	0	-1	
-1	-1	-1	-1	-1	-1	-1	
[(5, 5), (4, 5)]
-1	-1	-1	-1	-1	-1	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	0	-1	
-1	0	0	0	0	-1	-1	
-1	0	0	1	0	-1	-1	
-1	-1	-1	-1	-1	-1	-1	
