# AI plays snake game
In this Python Reinforcement Learning course you will learn how to teach an AI to play Snake! We build everything from scratch using Flask websocket and PyTorch. Sources: 
- [Video](https://ripper.linq-it.com/#/player;type=video;uid=c3be4300-94c6-4303-b8f3-2383bd4794ae) and [https://github.com/python-engineer/snake-ai-pytorch](https://github.com/python-engineer/snake-ai-pytorch)
- [https://marketsplash.com/tutorials/flask/how-to-use-flask-with-websockets/](https://marketsplash.com/tutorials/flask/how-to-use-flask-with-websockets/)

## Setup conda environment
Actions:
- conda create -n snake python=3.12
- conda activate snake
- pip install -U Flask flask-socketio eventlet
- pip install opencv-python
- pip install torch torchvision

In [12]:
from flask import Flask, render_template, request
from flask_socketio import SocketIO
import json

import random
from enum import Enum
from collections import namedtuple, deque

import torch
import random
import numpy as np
from skimage.morphology import flood_fill

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

## Agent - Game - Model
<img title="Agent Game Model" src="images/agent-game-model.png" width="600px">

## Game
<img title="Reward" src="images/reward.png" width="400px">

In [13]:
class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4

class Action(Enum):
    STRAIGHT = 1
    LEFT = 2
    RIGHT = 3

Point = namedtuple('Point', 'x, y')

In [14]:
SKULL = 9
JOINT = 8
FOOD = 5
FLOOD = 1

class SnakeGameAI:

    def __init__(self, w, h):
        self.w = w
        self.h = h
        self.game_area = np.zeros((self.w, self.h), dtype=int) 
        self.reset()


    def reset(self):
        # init game state
        self.direction = Direction.RIGHT

        self.head = Point(round(self.w/2), round(self.h/2))
        self.snake = [self.head,
                      Point(self.head.x-1, self.head.y),
                      Point(self.head.x-2, self.head.y)]

        self.score = 0
        self.food = None
        self._place_food()
        self.frame_iteration = 0
        self.interrupted = False


    def _place_food(self):
        x = random.randint(0, self.w - 1)
        y = random.randint(0, self.h - 1)
        self.food = Point(x, y)
        if self.food in self.snake:
            self._place_food()


    def play_step(self, action):
        self.frame_iteration += 1

        # 1. move
        self._move(action) # update the head
        self.snake.insert(0, self.head)
        
        # 2. check if game over
        reward = 0
        game_over = False
        if self.is_collision() or self.frame_iteration > 100*len(self.snake):
            game_over = True
            reward = -10
            return reward, game_over, self.score

        # 3. place new food or just move
        if self.head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
        else:
            self.snake.pop()
        
        # 4. update ui 
        self._update_game_area()

        # 5. return game over and score
        return reward, game_over, self.score

    def is_collision(self, pt=None):
        if pt is None:
            pt = self.head
        # hits boundary
        if pt.x > self.w - 1 or pt.x < 0 or pt.y > self.h - 1 or pt.y < 0:
            return True
        # hits itself
        if pt in self.snake[1:]:
            return True

        return False


    def _update_game_area(self):
        self.game_area = np.zeros((self.w, self.h), dtype=np.uint8) 

        for i, pt in enumerate(self.snake):
            if (i == 0): # The head is on position 0
                self.game_area[pt.x, pt.y] = SKULL
            else:
                self.game_area[pt.x, pt.y] = JOINT
            
        self.game_area[self.food.x, self.food.y] = FOOD 

    def _move(self, action):
        # [straight, right, left]

        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        idx = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_dir = clock_wise[idx] # no change because direction is straight
        elif np.array_equal(action, [0, 1, 0]):
            next_idx = (idx + 1) % 4
            new_dir = clock_wise[next_idx] # right turn r -> d -> l -> u
        else: # [0, 0, 1]
            next_idx = (idx - 1) % 4
            new_dir = clock_wise[next_idx] # left turn r -> u -> l -> d

        self.direction = new_dir

        x = self.head.x
        y = self.head.y
        if self.direction == Direction.RIGHT:
            x += 1
        elif self.direction == Direction.LEFT:
            x -= 1
        elif self.direction == Direction.DOWN:
            y += 1
        elif self.direction == Direction.UP:
            y -= 1

        self.head = Point(x, y)

## Model
<img title="Model" src="images/model.png" width="400px">

In [15]:
class Linear_QNet(nn.Module):
    def __init__(self, inputSize, hiddenSize, outputSize):
        super().__init__()
        self.linear1 = nn.Linear(inputSize, hiddenSize)
        self.linear2 = nn.Linear(hiddenSize, outputSize)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

<img title="Learning procedure" src="images/procedure.png" width="400px">

<img title="Bellman equation" src="images/bellman-equation.png" width="800px">

<img title="Q rule simplified" src="images/q-rule-simplified.png" width="400px">

<img title="Loss function" src="images/loss-function.png" width="400px">

## Training
<img title="Action" src="images/action.png" width="400px">

<img title="State" src="images/state.png" width="800px">

In [16]:
class QTrainer:
    def __init__(self, model, lr, lrDecayRate, gamma):
        self.lr = lr
        self.lrDecayRate = lrDecayRate
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, game_over):
        state = torch.tensor(state, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)
        # (n, x)

        if len(state.shape) == 1:
            # (1, x)
            state = torch.unsqueeze(state, 0)
            next_state = torch.unsqueeze(next_state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            game_over = (game_over, )

        # 1: predicted Q values with current state
        pred = self.model(state)

        target = pred.clone()
        for idx in range(len(game_over)):
            Q_new = reward[idx]
            if not game_over[idx]:
                Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

            target[idx][torch.argmax(action[idx]).item()] = Q_new
    
        # 2: Q_new = r + y * max(next_predicted Q value) -> only do this if not game_over
        # pred.clone()
        # preds[argmax(action)] = Q_new
        self.optimizer.zero_grad()
        loss = self.criterion(target, pred)
        loss.backward()

        self.optimizer.step()

        self.lr = self.lr * self.lrDecayRate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr


## Agent

The snake can get trapped in itself, here 3 examples (D=danger):

<img title="Trapped" src="images/trapped.png" width="600px">

To solve this, I am going to use the flood function to calculate the area which is run after the danger_... calculations. 
The [skimage.segmentation.flood_fill](https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.flood_fill) is used on a copy of game.area to calculate the cells available for the adjacent cells not indicated a dangerous yet. If the area is smaller than the length of the snake then this direction is marked as dangerous too.

In [17]:
class Agent:

    def __init__(self, hiddenLayerSize, modelWeights, gamma, epsilon, epsilonDecayRate, lr, lrDecayRate, maxMemorySize, batchSize):
        self.n_games = 0
        self.epsilon = epsilon # randomness
        self.epsilonDecayRate = epsilonDecayRate
        self.gamma = gamma # discount rate
        self.memory = deque(maxlen=maxMemorySize) # popleft()
        self.model = Linear_QNet(inputSize=11, hiddenSize=hiddenLayerSize, outputSize=3)
        if modelWeights is not None:
            self.model.load_state_dict(modelWeights)
        self.batch_size = batchSize
        self.trainer = QTrainer(self.model, lr=lr, lrDecayRate=lrDecayRate, gamma=self.gamma)

    def _trapped_danger(self, game, direction, adjecent_position):
        head = game.snake[0]
        area = np.copy(game.game_area)
        if direction == [1, 0 ,0 ,0]: #up
            if adjecent_position == Action.LEFT:
                if (game.game_area[head.x - 1, head.y] == FOOD):
                    return False
                area = flood_fill(area, (head.x - 1, head.y), FLOOD, connectivity=1)
            elif adjecent_position == Action.STRAIGHT:
                if (game.game_area[head.x, head.y - 1] == FOOD):
                    return False
                area = flood_fill(area, (head.x, head.y - 1), FLOOD, connectivity=1)
            elif adjecent_position == Action.RIGHT:
                if (game.game_area[head.x + 1, head.y] == FOOD):
                    return False
                area = flood_fill(area, (head.x + 1, head.y), FLOOD, connectivity=1)
            pass
        elif direction == [0, 1, 0, 0]: #right
            if adjecent_position == Action.LEFT:
                if (game.game_area[head.x, head.y - 1] == FOOD):
                    return False
                area = flood_fill(area, (head.x, head.y - 1), FLOOD, connectivity=1)
            elif adjecent_position == Action.STRAIGHT:
                if (game.game_area[head.x + 1, head.y] == FOOD):
                    return False
                area = flood_fill(area, (head.x + 1, head.y), FLOOD, connectivity=1)          
            elif adjecent_position == Action.RIGHT:
                if (game.game_area[head.x, head.y + 1] == FOOD):
                    return False
                area = flood_fill(area, (head.x, head.y + 1), FLOOD, connectivity=1)
            pass
        elif direction == [0, 0, 1, 0]: #down
            if adjecent_position == Action.LEFT:
                if (game.game_area[head.x + 1, head.y] == FOOD):
                    return False
                area = flood_fill(area, (head.x + 1, head.y), FLOOD, connectivity=1)
            elif adjecent_position == Action.STRAIGHT:
                if (game.game_area[head.x, head.y + 1] == FOOD):
                    return False
                area = flood_fill(area, (head.x, head.y + 1), FLOOD, connectivity=1)
            elif adjecent_position == Action.RIGHT:
                if (game.game_area[head.x - 1, head.y] == FOOD):
                    return False
                area = flood_fill(area, (head.x - 1, head.y), FLOOD, connectivity=1)
            pass
        elif direction == [0, 0, 0, 1]: #left
            if adjecent_position == Action.LEFT:
                if (game.game_area[head.x, head.y + 1] == FOOD):
                    return False
                area = flood_fill(area, (head.x, head.y + 1), FLOOD, connectivity=1)
            elif adjecent_position == Action.STRAIGHT:
                if (game.game_area[head.x - 1, head.y] == FOOD):
                    return False
                area = flood_fill(area, (head.x - 1, head.y), FLOOD, connectivity=1)
            elif adjecent_position == Action.RIGHT:
                if (game.game_area[head.x, head.y - 1] == FOOD):
                    return False
                area = flood_fill(area, (head.x, head.y - 1), FLOOD, connectivity=1)
            pass
        enclosed_cells = np.count_nonzero(area == FLOOD)
        snake_length = np.count_nonzero(area == FLOOD) - 1
        if (enclosed_cells < snake_length):
            return True
        else:
            return False
    
    def get_state(self, game):
        head = game.snake[0]
        point_l = Point(head.x - 1, head.y)
        point_r = Point(head.x + 1, head.y)
        point_u = Point(head.x, head.y - 1)
        point_d = Point(head.x, head.y + 1)
        
        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        danger_straight = \
            (dir_r and game.is_collision(point_r)) or \
            (dir_l and game.is_collision(point_l)) or \
            (dir_u and game.is_collision(point_u)) or \
            (dir_d and game.is_collision(point_d))
        
        danger_right = \
            (dir_u and game.is_collision(point_r)) or \
            (dir_d and game.is_collision(point_l)) or \
            (dir_l and game.is_collision(point_u)) or \
            (dir_r and game.is_collision(point_d))
        
        danger_left = \
            (dir_d and game.is_collision(point_r)) or \
            (dir_u and game.is_collision(point_l)) or \
            (dir_r and game.is_collision(point_u)) or \
            (dir_l and game.is_collision(point_d))
        
        if (not danger_straight):
            danger_straight = self._trapped_danger(game, [dir_u, dir_r, dir_d, dir_l], Action.STRAIGHT)

        if (not danger_right):
            danger_right = self._trapped_danger(game, [dir_u, dir_r, dir_d, dir_l], Action.RIGHT)

        if (not danger_left):
            danger_left = self._trapped_danger(game, [dir_u, dir_r, dir_d, dir_l], Action.LEFT)

        state = [
            # Danger around head
            danger_straight,
            danger_right,
            danger_left,
            
            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,
            
            # Food location 
            game.food.x < game.head.x,  # food left
            game.food.x > game.head.x,  # food right
            game.food.y < game.head.y,  # food up
            game.food.y > game.head.y  # food down
            ]

        return np.array(state, dtype=int)

    def remember(self, state, action, reward, next_state, game_over):
        self.memory.append((state, action, reward, next_state, game_over)) # popleft if MAX_MEMORY is reached

    def train_long_memory(self):
        if len(self.memory) > self.batch_size:
            mini_sample = random.sample(self.memory, self.batch_size) # list of tuples
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, game_overs = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, game_overs)
        #for state, action, reward, nexrt_state, game_over in mini_sample:
        #    self.trainer.train_step(state, action, reward, next_state, game_over)

    def train_short_memory(self, state, action, reward, next_state, game_over):
        self.trainer.train_step(state, action, reward, next_state, game_over)

    def get_action(self, state):
        # random moves: tradeoff exploration / exploitation
        final_move = [0,0,0]
        if random.uniform(0, 1) < self.epsilon:
            move = random.randint(0, 2)
            final_move[move] = 1
        else:
            state0 = torch.tensor(np.array(state), dtype=torch.float)
            prediction = self.model(state0)
            move = torch.argmax(prediction).item()
            final_move[move] = 1
        return final_move
    
    def adjust_epsilon(self):
        #self.epsilon = self.epsilon * self.epsilonDecayRate
        self.epsilon = 0.4 - self.n_games / 200

# Run the trainer

In [18]:
app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*", manage_session=False)
session_id = ''

def reset(modelName, modelWeights, width, height, hiddenLayerSize, gamma, epsilon, epsilonDecayRate, lr, lrDecayRate, maxMemorySize, batchSize):
    global model_name, scores, mean_scores, total_score, record, game, agent, model_name, epsilonDecay, lrDecay
    print('################################')
    if modelWeights is not None:
        print('Model Weights are provided')
    print('Model Name:', modelName)
    print('Width:', width)
    print('Height:', height)
    print('Hidden Layer Size:', hiddenLayerSize)
    print('Gamma:', gamma)
    print('Epsilon:', epsilon)
    print('Epsilon Decay rate:', epsilonDecayRate)
    print('Max Memory Size:', maxMemorySize)
    print('Learning Rate:', lr)
    print('Learning Rate Decay rate:', lrDecayRate)
    print('Batch Size:', batchSize)
    model_name = modelName
    scores = []
    mean_scores = []
    epsilonDecay = []
    lrDecay = []
    total_score = 0
    record = 0
    game = SnakeGameAI(width, height)
    game.reset()
    agent = Agent(hiddenLayerSize=hiddenLayerSize, modelWeights=modelWeights, gamma=gamma, epsilon=epsilon, epsilonDecayRate=epsilonDecayRate, lr=lr, lrDecayRate=lrDecayRate, maxMemorySize=maxMemorySize, batchSize=batchSize)
    send_model_loaded()
    send_game_state()
    send_progress()

def send_models():
    data = {
            'models': [os.path.splitext(model)[0] for model in os.listdir('./model')],
        }
    json_str = json.dumps(data)
    socketio.emit('models', json_str)

def send_model_loaded():
    data = {
            'model': model_name,
            'hiddenLayerSize': agent.model.linear1.out_features,
            'gamma': agent.gamma,
            'epsilon': agent.epsilon,
            'maxMemorySize': agent.memory.maxlen,
            'learningRate': agent.trainer.lr,
            'batchSize': agent.batch_size,
        }
    json_str = json.dumps(data)
    socketio.emit('model_loaded', json_str)

def send_game_state():
    data = {
            'score': game.score,
            'snake': [list(pt) for pt in game.snake],  # Assuming 'snake' is a list of tuples
            'food': list(game.food)  # Assuming 'food' is a tuple
        }
    json_str = json.dumps(data)
    socketio.emit('state', json_str)

def send_progress():
    data = {
            'record': record,
            'scores': scores,
            'meanScores': mean_scores,
            'epsilonDecay': epsilonDecay,
            'learningRateDecay': lrDecay
        }
    json_str = json.dumps(data)
    socketio.emit('progress', json_str)


@app.route('/')
def index():
    return render_template('index-websocket.html')


@socketio.on('connect')
def handle_connect():
    global session_id
    if session_id == '':
        session_id = request.sid
    send_models()


@socketio.on('disconnect')
def handle_disconnect():
    global session_id
    session_id = ''


@socketio.on('load')
def handle_model_load(json_str):
    global scores, mean_scores, total_score, record, game, agent
    print('handle_model_load', json_str)
    json_data = json.loads(json_str)
    
    model_name = json_data['modelName']
    width = json_data['width'] if 'width' in json_data else 36
    height = json_data['height'] if 'height' in json_data else 24
    model_path = os.path.join('./model', model_name + '.pth')
    if os.path.exists(model_path):
        modelWeights, epsilonDecay, lrDecay, epsilonDecayRate, lrDecayRate, gamma, batchSize, maxMemorySize, hiddenLayerSize = torch.load(model_path)
        lr = lrDecay[-1]
        epsilon = epsilonDecay[-1]
        reset(modelName=model_name, modelWeights=modelWeights, width=width, height=height, hiddenLayerSize=hiddenLayerSize, gamma=gamma,
              epsilon=epsilon, epsilonDecayRate=epsilonDecayRate, lr=lr, lrDecayRate=lrDecayRate, maxMemorySize=maxMemorySize, batchSize=batchSize)


@socketio.on('init')
def handle_init(json_str):
    json_data = json.loads(json_str)
    modelName = json_data['modelName']
    width = json_data['width']
    height = json_data['height']
    hiddenLayerSize = json_data['hiddenLayerSize']
    gamma = json_data['gamma']
    epsilon = json_data['epsilon']
    epsilonDecayRate = json_data['epsilonDecayRate']
    maxMemorySize = json_data['maxMemorySize']
    lr = json_data['lr']
    lrDecayRate = json_data['lrDecayRate']
    batchSize = json_data['batchSize']

    reset(modelName=modelName, modelWeights=None, width=width, height=height, hiddenLayerSize=hiddenLayerSize, gamma=gamma,
            epsilon=epsilon, epsilonDecayRate=epsilonDecayRate, lr=lr, lrDecayRate=lrDecayRate, maxMemorySize=maxMemorySize, batchSize=batchSize)


@socketio.on('step')
def handle_step():
            global scores, mean_scores, total_score, record, game, agent

            state_old = agent.get_state(game)

            # get move
            final_move = agent.get_action(state_old)

            # perform move and get new state
            reward, game_over, score = game.play_step(final_move)

            state_new = agent.get_state(game)

            # train short memory
            agent.train_short_memory(state_old, final_move, reward, state_new, game_over)

            # remember
            agent.remember(state_old, final_move, reward, state_new, game_over)

            send_game_state()

            if game_over:
                # train long memory, plot result
                game.reset()
                agent.n_games += 1
                agent.train_long_memory()
                agent.adjust_epsilon()
                epsilonDecay.append(agent.epsilon)
                lrDecay.append(agent.trainer.lr)

                if score > record:
                    record = score
                    model_folder_path = './model'
                    if not os.path.exists(model_folder_path):
                        os.makedirs(model_folder_path)

                    object_to_save = (agent.model.state_dict(), epsilonDecay, lrDecay, agent.epsilonDecayRate, agent.trainer.lrDecayRate, agent.gamma, agent.batch_size, agent.memory.maxlen, agent.model.linear1.out_features)
                    file_name = os.path.join(model_folder_path, model_name + '.pth')
                    torch.save(object_to_save, file_name)
                    agent.adjust_epsilon()

                scores.append(score)
                total_score += score
                mean_score = total_score / agent.n_games
                mean_scores.append(mean_score)
                send_progress()

In [19]:
if __name__ == '__main__':
    socketio.run(app, debug=False, port=5001)

Traceback (most recent call last):
  File "/usr/local/Caskroom/miniforge/base/envs/snake/lib/python3.12/site-packages/eventlet/hubs/hub.py", line 471, in fire_timers
    timer()
  File "/usr/local/Caskroom/miniforge/base/envs/snake/lib/python3.12/site-packages/eventlet/hubs/timer.py", line 59, in __call__
    cb(*args, **kw)
  File "/usr/local/Caskroom/miniforge/base/envs/snake/lib/python3.12/site-packages/eventlet/greenthread.py", line 265, in main
    result = function(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Caskroom/miniforge/base/envs/snake/lib/python3.12/site-packages/socketio/server.py", line 586, in _handle_event_internal
    r = server._trigger_event(data[0], namespace, sid, *data[1:])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Caskroom/miniforge/base/envs/snake/lib/python3.12/site-packages/socketio/server.py", line 611, in _trigger_event
    return handler(*args)
           ^^^^^^^^^^^^^^
  File "/usr

################################
Model Name: model_2
Width: 36
Height: 24
Hidden Layer Size: 256
Gamma: 0.9
Epsilon: 0.5
Epsilon Decay rate: 0.97
Max Memory Size: 100000
Learning Rate: 0.001
Learning Rate Decay rate: 1
Batch Size: 1000
handle_model_load {"modelName":"model_2","width":36,"height":24}
################################
Model Weights are provided
Model Name: model_2
Width: 36
Height: 24
Hidden Layer Size: 256
Gamma: 0.9
Epsilon: 0.26
Epsilon Decay rate: 0.97
Max Memory Size: 100000
Learning Rate: 0.001
Learning Rate Decay rate: 1
Batch Size: 1000
handle_model_load {"modelName":"model_2","width":36,"height":24}
################################
Model Weights are provided
Model Name: model_2
Width: 36
Height: 24
Hidden Layer Size: 256
Gamma: 0.9
Epsilon: -0.255
Epsilon Decay rate: 0.97
Max Memory Size: 100000
Learning Rate: 0.001
Learning Rate Decay rate: 1
Batch Size: 1000
