# Check GPU availability

In [1]:
import tensorflow as tf

In [2]:
# Ask TensorFlow whether this build was compiled with CUDA support.
# NOTE(review): a notebook cell only displays its last expression, so this
# result is computed but not shown in the cell output.
tf.test.is_built_with_cuda()
# List the GPU devices TensorFlow can see; this is the value the cell displays.
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

# Imports

In [3]:
from snake_model import *

from tensorflow.keras import layers
import random as rd

from icecream import ic
import logging

# Constants & Helper functions

In [4]:
# Step size intended for the Q-network's optimizer.
LEARNING_RATE = 0.001
# Discount factor intended for future rewards in the Q-learning target.
GAMMA = 0.9
# Multiplicative decay for the exploration rate.
# NOTE(review): not referenced in the visible cells; presumably applied once
# per step or per episode -- confirm where epsilon is updated.
EPSILON_DECAY_FACTOR = 0.999
# Presumably the capacity of the experience-replay buffer (the buffer itself
# is not defined in the visible cells -- TODO confirm).
REPLAY_BUFFER_SIZE = 1_000
# Half of the buffer capacity; presumably the number of transitions to collect
# before training starts -- TODO confirm against the training loop.
INIT_REPLAY_COUNT = REPLAY_BUFFER_SIZE // 2
# Number of transitions drawn per call to get_mini_batch.
BATCH_SIZE = 64

def log(text):
    """Write ``text`` to ``log.txt`` as a DEBUG record on the root logger."""
    # basicConfig is a no-op once the root logger has a handler, so repeating
    # it on every call only matters the first time.
    logging.basicConfig(level=logging.DEBUG, filename="log.txt")
    logging.log(logging.DEBUG, text)

def head_to_one_hot(head, gridSize):
    """Encode the snake head as a ``gridSize x gridSize`` occupancy grid.

    Args:
        head: object with integer ``x`` and ``y`` attributes; ``x`` indexes
            the first axis and ``y`` the second.
        gridSize: side length of the square grid.

    Returns:
        A float array of zeros with a single 1 at ``[head.x, head.y]``.
    """
    grid = np.zeros((gridSize, gridSize))
    grid[head.x, head.y] = 1
    return grid

def body_to_one_hot(bodyBlocks, gridSize):
    """Encode the snake body as a ``gridSize x gridSize`` occupancy grid.

    Args:
        bodyBlocks: iterable of objects with integer ``x`` and ``y``
            attributes (``x`` -> first axis, ``y`` -> second axis).
        gridSize: side length of the square grid.

    Returns:
        A float array of zeros with a 1 at every body block's cell; an empty
        iterable yields an all-zero grid.
    """
    grid = np.zeros((gridSize, gridSize))
    for segment in bodyBlocks:
        grid[segment.x, segment.y] = 1
    return grid

def food_to_one_hot(foods, gridSize):
    """Encode food positions as a ``gridSize x gridSize`` occupancy grid.

    Args:
        foods: iterable of objects with integer ``x`` and ``y`` attributes
            (``x`` -> first axis, ``y`` -> second axis).
        gridSize: side length of the square grid.

    Returns:
        A float array of zeros with a 1 at every food item's cell; an empty
        iterable yields an all-zero grid.
    """
    grid = np.zeros((gridSize, gridSize))
    for item in foods:
        grid[item.x, item.y] = 1
    return grid

def direction_to_one_hot(direction):
    """Encode a direction index as a length-4 one-hot float vector.

    Args:
        direction: integer index of the slot to set (0-3 for the four
            headings used by ``action_to_direction``).

    Returns:
        A float array of four zeros with a 1 at position ``direction``.
    """
    encoded = np.zeros(4)
    encoded[direction] = 1
    return encoded

# One row per heading: (heading, heading after a left turn, heading after a right turn).
# Headings: 0 = up, 1 = down, 2 = left, 3 = right.
_TURN_TABLE = (
    (0, 2, 3),  # up
    (1, 3, 2),  # down
    (2, 1, 0),  # left
    (3, 0, 1),  # right
)


def action_to_direction(currentDirection, chosenAction):
    """Translate a relative steering action into an absolute heading.

    Args:
        currentDirection: current heading (0 up, 1 down, 2 left, 3 right).
        chosenAction: relative action (0 turn left, 1 keep going, 2 turn right).

    Returns:
        The new heading. Any action other than 0 or 2, or an unrecognised
        heading, leaves ``currentDirection`` unchanged.
    """
    for heading, after_left, after_right in _TURN_TABLE:
        if currentDirection == heading:
            if chosenAction == 0:
                return after_left
            if chosenAction == 2:
                return after_right
    return currentDirection

def get_mini_batch(replay):
    """Draw ``BATCH_SIZE`` random transitions and return them column-wise.

    Args:
        replay: sequence of transitions; each transition is a sequence whose
            first four fields are used (presumably state, action, reward,
            next state -- the order is fixed by whoever fills the buffer).

    Returns:
        A 4-tuple of NumPy arrays, one per field, each stacking that field
        across the sampled transitions. Fields beyond the fourth are dropped.

    Raises:
        ValueError: from ``random.sample`` if ``replay`` holds fewer than
            ``BATCH_SIZE`` transitions.
    """
    transitions = rd.sample(replay, BATCH_SIZE)
    # Transpose rows-of-transitions into per-field columns, keeping the first four.
    fields = [list(field) for field in zip(*transitions)][:4]
    return tuple(np.array(fields[index]) for index in range(4))

# Q Network

In [5]:
class DQN():
    """Online/target pair of dense Q-networks with a one-step TD update.

    Args:
        state_size: length of the flat state vector fed to the first layer.
        action_size: number of discrete actions (width of the output layer).
        learning_rate: Adam step size. Defaults to the notebook-level
            ``LEARNING_RATE`` constant when omitted.
        gamma: discount factor applied to the bootstrapped next-state value.
            Defaults to the notebook-level ``GAMMA`` constant when omitted.
    """

    def __init__(self, state_size, action_size, learning_rate=None, gamma=None):
        # Resolve the defaults here (not in the signature) so the shared
        # constants are looked up when an instance is built and a later edit
        # to the constants cell is picked up without redefining the class.
        self.learning_rate = LEARNING_RATE if learning_rate is None else learning_rate
        self.gamma = GAMMA if gamma is None else gamma
        self.q_net = self.build_dqn_model(state_size, action_size)
        self.target_q_net = self.build_dqn_model(state_size, action_size)

    def build_dqn_model(self, state_size, action_size):
        """Build and compile a ``state_size -> 128 -> 64 -> action_size`` MLP.

        Hidden layers use ReLU, the output layer is linear, and the model is
        compiled with Adam (``self.learning_rate``) and an MSE loss.
        """
        hidden_1 = 128
        hidden_2 = 64

        q_net = tf.keras.Sequential()
        q_net.add(layers.Dense(hidden_1, input_dim=state_size, activation='relu', kernel_initializer='he_uniform'))
        q_net.add(layers.Dense(hidden_2, activation='relu', kernel_initializer='he_uniform'))
        q_net.add(layers.Dense(action_size, activation='linear', kernel_initializer='he_uniform'))
        q_net.compile(optimizer=tf.optimizers.Adam(learning_rate=self.learning_rate), loss='mse')
        return q_net

    def update_target(self):
        """Copy the online network's weights into the target network."""
        self.target_q_net.set_weights(self.q_net.get_weights())

    def get_qvals(self, state):
        """Return the online network's Q-values for ``state`` as a NumPy array.

        ``state`` is converted to a float32 tensor first; it is presumably a
        batch shaped ``(batch, state_size)`` -- TODO confirm against callers.
        """
        state = tf.convert_to_tensor(state, dtype=tf.float32)
        q_values = self.q_net(state)
        return q_values.numpy()

    def train(self, batch):
        """Fit the online network on one batch of transitions.

        Args:
            batch: ``(state_batch, action_batch, reward_batch,
                next_state_batch)`` as produced by ``get_mini_batch``.

        Returns:
            The ``'loss'`` list from the Keras training history.
        """
        state_batch, action_batch, reward_batch, next_state_batch = batch
        current_q = self.q_net(state_batch).numpy()
        # Start from the current predictions so only the taken action's
        # Q-value contributes to the loss.
        target_q = np.copy(current_q)
        next_q = self.target_q_net(next_state_batch).numpy()
        max_next_q = np.amax(next_q, axis=1)

        # Flatten so column vectors of shape (batch, 1) index one cell per row.
        rows = np.arange(target_q.shape[0])
        actions = np.asarray(action_batch).reshape(-1).astype(int)
        rewards = np.asarray(reward_batch).reshape(-1)
        # NOTE(review): transitions carry no terminal flag, so the bootstrap
        # term is also added for transitions that ended an episode.
        target_q[rows, actions] = rewards + self.gamma * max_next_q

        training_history = self.q_net.fit(x=state_batch, y=target_q, verbose=0)
        loss = training_history.history['loss']
        return loss

# Deep Q Learning Code

In [6]:
class Environment:
    """Placeholder for the snake environment; nothing is implemented yet.

    NOTE(review): ``Game`` reaches for ``agent.environment.world``, so this
    class will presumably need a ``world`` attribute -- TODO confirm.
    """

    def __init__(self):
        """Create an environment with no state."""

class Agent:
    """Placeholder for the learning agent; nothing is implemented yet.

    NOTE(review): ``Game`` reaches for ``agent.environment``, so this class
    will presumably need an ``environment`` attribute -- TODO confirm.
    """

    def __init__(self):
        """Create an agent with no state."""

class DeepQLearning:
    """Placeholder for the deep Q-learning driver; nothing is implemented yet.

    NOTE(review): presumably the ``driver`` that ``Game`` consumes, which
    calls ``step()`` on it and reads its ``agent`` -- TODO confirm.
    """

    def __init__(self):
        """Create a driver with no state."""

# Simulation

In [7]:
import pygame
from sys import exit
from snake_view import *

# Value passed to Clock.tick() in Game.game_loop, i.e. the cap on loop
# iterations (frames) per second.
TICK_RATE = 30

# Initialise pygame and set the window title when this cell runs, before any
# Game instance is created.
pygame.init()
pygame.display.set_caption("Snake")

class Game:
    """Pygame loop that renders the snake world while a DQN driver plays it.

    Args:
        driver: object exposing ``step()`` and ``agent.environment.world``.
            The world is handed to ``UI`` for drawing and is re-initialised
            when the player presses Return or Space.
    """

    def __init__(self, driver):
        self.dqnDriver = driver
        # The driver is stored as `dqnDriver`; there is no `self.driver`
        # attribute, so the world must be reached through the stored name.
        # NOTE(review): assumes the world drawn by the UI is also the object
        # the reset key should re-initialise -- confirm.
        self.gameWorld = self.dqnDriver.agent.environment.world
        self.UI = UI(self.gameWorld)
        self.clock = pygame.time.Clock()

    def game_loop(self):
        """Run forever: poll input, advance the driver one step, redraw, tick."""
        while True:
            self.handle_player_input()
            self.dqnDriver.step()

            self.UI.draw_hud()
            self.UI.draw_blocks()
            pygame.display.update()
            self.clock.tick(TICK_RATE)

    def handle_player_input(self):
        """Process pending pygame events.

        Window close or Escape shuts pygame down and exits the process;
        Return or Space resets the world in place.
        """
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self._shutdown()

            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_RETURN or event.key == pygame.K_SPACE:
                    # Re-run the initialiser on the existing object so the UI,
                    # which holds a reference to it, sees the reset state.
                    # Assumes the world's __init__ takes no required arguments.
                    self.gameWorld.__init__()

                elif event.key == pygame.K_ESCAPE:
                    self._shutdown()

    def _shutdown(self):
        """Close pygame and terminate the process."""
        pygame.quit()
        exit()


pygame 2.5.2 (SDL 2.28.3, Python 3.9.18)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [8]:
# myGame = Game(test)
# myGame.game_loop()