# Check GPU availability

In [1]:
import tensorflow as tf

In [2]:
tf.test.is_built_with_cuda()
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

# Imports

In [3]:
from snake_model import *

from tensorflow.keras import layers
import random as rd

from icecream import ic
import logging

# Constants & Helper functions

In [4]:
LEARNING_RATE = 0.001
GAMMA = 0.9
EPSILON_DECAY_FACTOR = 0.999
REPLAY_BUFFER_SIZE = 1000
INIT_REPLAY_COUNT = REPLAY_BUFFER_SIZE // 2
BATCH_SIZE = 64

def log(text):
    logging.basicConfig(filename="log.txt", level=logging.DEBUG)
    logging.debug(text)

def head_to_one_hot(head, gridSize):
    one_hot = np.zeros((gridSize, gridSize))
    one_hot[head.x][head.y] = 1
    return one_hot

def body_to_one_hot(bodyBlocks, gridSize):
    one_hot = np.zeros((gridSize, gridSize))
    for bodyBlock in bodyBlocks:
        one_hot[bodyBlock.x][bodyBlock.y] = 1
    return one_hot

def food_to_one_hot(foods, gridSize):
    one_hot = np.zeros((gridSize, gridSize))
    for food in foods:
        one_hot[food.x][food.y] = 1
    return one_hot

def action_to_direction(currentDirection, chosenAction): # 0up 1down 2left 3right : 0left 1stay 2right
    if currentDirection == 0:
        if chosenAction == 0:
            return 2
        if chosenAction == 2:
            return 3
    if currentDirection == 1:
        if chosenAction == 0:
            return 3
        if chosenAction == 2:
            return 2
    if currentDirection == 2:
        if chosenAction == 0:
            return 1
        if chosenAction == 2:
            return 0
    if currentDirection == 3:
        if chosenAction == 0:
            return 0
        if chosenAction == 2:
            return 1
    return currentDirection

def get_projected_coodinates(x, y, direction):
    if currentDirection == 0:
        return (x-1, y)
    if currentDirection == 1:
        return (x+1, y)
    if currentDirection == 2:
        return (x, y-1)
    if currentDirection == 3:
        return (x, y+1)

def get_mini_batch(replay):
    mini_batch = rd.sample(replay, BATCH_SIZE) 
    col_indices = [0,1,2,3]
    result = [list(column) for column in zip(*mini_batch)][col_indices[0]:col_indices[-1]+1]
    return result[0], result[1], result[2], result[3]

# Q Network

In [5]:
class DQN(tf.keras.Model):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        
        l1 = state_size
        l2 = 128
        l3 = 64
        l4 = action_size
        
        self.model = tf.keras.Sequential([
            layers.Dense(l2, activation='relu', input_shape=(l1,)),
            layers.Dense(l3, activation='relu'),
            layers.Dense(l4)
        ])

        self.model2 = tf.keras.models.clone_model(self.model)
        self.model2.set_weights(self.model.get_weights())
        
        self.loss_fn = tf.keras.losses.MeanSquaredError()
        self.learning_rate = 0.001
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

    def update_target(self):
        self.model2.set_weights(self.model.get_weights())

    def get_qvals(self, state):
        state = tf.convert_to_tensor(state, dtype=tf.float32)
        q_values = self.model(state)
        return q_values.numpy()

    def get_maxQ(self, state):
        state = tf.convert_to_tensor(state, dtype=tf.float32)
        q_values = self.model2(state)
        return tf.reduce_max(q_values, axis=1).numpy()

    def train_one_step(self, states, actions, targets):
        targets_reply = []
        state1_batch = tf.convert_to_tensor(states, dtype=tf.float32)
        action_batch = tf.convert_to_tensor(actions, dtype=tf.int32)
        with tf.GradientTape() as tape:
            Q1 = self.model(state1_batch)
            X = tf.gather(Q1, action_batch, axis=1, batch_dims=1)
            Y = tf.convert_to_tensor(targets, dtype=tf.float32)
            loss = self.loss_fn(X, Y)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        return loss.numpy()

# Deep Q Learning Code

In [8]:
class Environment:
    def __init__(self):
        self.world = World()
        self.state = self.get_state()
        self.prevState = None
        self.score = self.world.score
        self.prevScore = self.score

    def get_state(self):
        headArray = head_to_one_hot(self.world.snake.head, self.world.size).flatten()
        bodyArray = body_to_one_hot(self.world.snake.body, self.world.size).flatten()
        foodArray = food_to_one_hot(self.world.foods, self.world.size).flatten()
        state = np.array([headArray, bodyArray, foodArray]).flatten()
        state = np.append(state, self.world.snake.direction)
        return state

    def step(self):
        print(self.world.snake.head.x, self.world.snake.head.y)
        self.prevState = self.get_state()
        self.prevScore = self.world.score
        self.world.step()
        if not self.world.isCollide:
            self.state = self.get_state()
            self.score = self.world.score

class Agent:
    def __init__(self):
        self.environment = Environment()
        self.network = DQN((self.environment.world.size**2) * 3 + 1, 3)
        self.replay = []
        self.epsilon = 1

        self.cumulativeScore = 0

    def append_replay_table(self, action):
        prevState = self.environment.prevState
        reward = self.get_reward()
        newState = self.environment.state
        self.replay.append([prevState, action, reward, newState])
        if len(self.replay) > REPLAY_BUFFER_SIZE:
            self.replay.pop()

    def get_reward(self):
        reward = -1
        if self.environment.score > self.environment.prevScore:
            reward = 5
            self.cumulativeScore += 1
        if self.environment.world.isCollide:
            reward = -5
        return reward

    def e_greedy(self, epsilon):
        if np.random.rand() <= epsilon:
            chosenAction = np.random.choice([0,1,2])
            return chosenAction
        else:
            currentState = self.environment.get_state()
            currentQVals = self.network.get_qvals([currentState])
            return np.argmax(currentQVals)

    def step(self):
        chosenAction = self.e_greedy(self.epsilon)
        newDirection = action_to_direction(self.environment.world.snake.direction, chosenAction)
        self.environment.world.snake.change_direction(newDirection)
        self.environment.step()
        self.append_replay_table(chosenAction)
        

class DeepQLearning:
    def __init__(self):
        self.time = 0
        self.agent = Agent()

        self.init_replay_table()
    
    def init_replay_table(self):
        for i in range (INIT_REPLAY_COUNT):
            self.agent.step()
            if self.agent.environment.world.isCollide:
                self.agent.environment.world.__init__()

        self.agent.epsilon = 1        

    def train_network(self):
        prevStates, actions, rewards, newStates = get_mini_batch(self.agent.replay)
        maxQValues = self.agent.network.get_maxQ(newStates)
        targets = []
        for i in range (len(rewards)):
            targets.append(rewards[i] + (GAMMA * maxQValues[i]))

        self.agent.network.train_one_step(prevStates, actions, targets)
        
    def step(self):
        self.agent.step()

        if self.time % 10 == 0:
            self.train_network()
            self.agent.epsilon *= EPSILON_DECAY_FACTOR

        if self.time % 1000 == 0:
            self.agent.network.update_target()


In [9]:
# Test
test = DeepQLearning()

epsilon = test.agent.epsilon
while epsilon >= 0.1:
    test.step()
    print(f"Time :{test.time} | Cumulative Score : {test.agent.cumulativeScore} | Current Score : {test.agent.environment.world.score}", end='\r')
# for i in range(1000):
#     test.step()

5 5
4 5
4 6
5 6
6 6
7 6
8 6
8 7
7 7
7 6
6 6
6 5
7 5
8 5
8 4
7 4
6 4
5 4
5 5
4 5
3 5
2 5
1 5
0 5
0 4
1 4
1 3
2 3
2 2
3 2
4 2
4 3
4 4
3 4
2 4
2 3
2 2
3 2
3 3
2 3
2 2
3 2
3 1
3 0
2 0
1 0
0 0
5 5
4 5
4 6
5 6
5 7
6 7
7 7
7 8
6 8
6 7
5 7
5 8
4 8
4 7
3 7
3 8
2 8
1 8
1 9
0 9
5 5
4 5
4 6
3 6
3 5
2 5
1 5
1 6
1 7
2 7
3 7
3 6
4 6
4 7
3 7
3 8
2 8
2 9
3 9
3 8
2 8
2 7
1 7
1 6
0 6
0 7
1 7
1 6
1 5
1 4
2 4
2 3
2 2
3 2
3 3
2 3
1 3
0 3
0 4
1 4
1 5
2 5
2 4
2 3
2 2
1 2
1 3
2 3
3 3
3 4
2 4
2 3
2 2
3 2
3 3
4 3
5 3
6 3
6 2
7 2
7 1
8 1
8 2
9 2
9 3
9 4
8 4
8 3
9 3
9 2
8 2
8 3
9 3
9 4
5 5
4 5
3 5
2 5
2 4
1 4
1 3
0 3
5 5
5 4
4 4
4 3
3 3
3 2
3 1
3 0
4 0
5 0
5 5
4 5
3 5
3 6
3 7
3 8
4 8
5 8
5 7
5 6
5 5
6 5
6 6
6 7
6 8
7 8
8 8
9 8
9 9
5 5
5 6
5 7
5 8
4 8
4 9
5 9
5 5
5 4
6 4
6 5
6 6
5 6
5 5
5 4
4 4
4 3
4 2
5 2
5 1
6 1
7 1
8 1
9 1
5 5
5 4
4 4
4 5
4 6
4 7
4 8
3 8
3 7
4 7
4 8
5 8
5 7
6 7
6 8
6 9
5 5
5 4
6 4
6 3
6 2
7 2
8 2
8 3
8 4
7 4
7 5
8 5
8 6
7 6
7 7
7 8
6 8
6 7
6 6
6 5
6 4
5 4
5 5
6 5
6 4
5 4
4 4
3 4
3 3
3 2
2 2
2 1


IndexError: index 10 is out of bounds for axis 0 with size 10

# Simulation

In [None]:
import pygame
from sys import exit
from snake_view import *

TICK_RATE = 500

pygame.init()
pygame.display.set_caption("Snake")

class Game:
    def __init__(self):
        self.clock = pygame.time.Clock()
        self.dqnDriver = DeepQLearning()
        self.agent = self.dqnDriver.agent
        self.gameWorld = self.agent.environment.world
        self.UI = UI(self.gameWorld)

    def game_loop(self):
        while True:
            self.handle_player_input()
            self.agent.step()

            if self.gameWorld.isCollide:
                self.gameWorld.__init__()
                
            self.UI.draw_hud()
            self.UI.draw_blocks()
            pygame.display.update()
            self.clock.tick(TICK_RATE)

    def handle_player_input(self):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                exit()

            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_RETURN or event.key == pygame.K_SPACE:
                    self.gameWorld.__init__()

                elif event.key == pygame.K_ESCAPE:
                    pygame.quit()
                    exit()


In [None]:
myGame = Game()
myGame.game_loop()