In [None]:
import pygame
import random
from enum import Enum
from collections import namedtuple
import time
import numpy as np 
from collections import deque
import keras
from keras.models import Sequential
from keras.layers import Input, Dense

In [None]:
# Initialise all pygame modules once, before any display or font object is
# created further down.
pygame.init()
# Single font handle; only one SysFont lookup is needed because `font` is a
# plain module-level name and a second assignment would just overwrite it.
font = pygame.font.SysFont('arial', 25)

class Direction(Enum):
    """Absolute heading of the snake on the grid (values 1 through 4)."""

    # Unpacking keeps declaration order RIGHT, LEFT, UP, DOWN with values 1..4.
    RIGHT, LEFT, UP, DOWN = range(1, 5)
    
# Immutable (x, y) pair used for grid coordinates.
Point = namedtuple('Point', ['x', 'y'])

# Colour palette, each entry an (R, G, B) tuple
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
RED1 = (255, 0, 0)
RED2 = (255, 100, 0)
BLUE1 = (0, 0, 255)
BLUE2 = (0, 100, 255)

# Edge length of one grid cell, in pixels
BLOCK_SIZE = 20
# NOTE(review): SPEED is not referenced anywhere in the visible code
SPEED = 20

# Minimum time between two snake steps, in milliseconds
DELAY = 100

# Factor that converts seconds (as returned by time.time()) to milliseconds
MILISECOND = 1000

# Board dimensions, counted in grid cells
NUM_ROW = NUM_COL = 30

# Window dimensions, in pixels
WIDTH, HEIGHT = NUM_COL * BLOCK_SIZE, NUM_ROW * BLOCK_SIZE

class SnakeGame:
    """Snake on a NUM_COL x NUM_ROW grid, drawn with pygame, steered by relative actions.

    All positions (``head``, the segments in ``snake``, ``food``) are ``Point``
    values measured in grid cells; they are scaled by ``BLOCK_SIZE`` only when
    drawn.  An action is a length-3 sequence whose largest entry selects the
    move relative to the current heading: index 0 keeps going straight,
    index 1 turns right (clockwise on screen), index 2 turns left.
    """

    # Rewards reported by move() for eating food and for dying.
    REWARD_FOOD = 10
    REWARD_DEATH = -10
    # Action that keeps the current heading (index 0 is "straight").
    STRAIGHT = (1, 0, 0)

    def __init__(self):
        """Open the game window and put the game into its initial state."""
        self.width = NUM_COL
        self.height = NUM_ROW
        self.display = pygame.display.set_mode((WIDTH, HEIGHT))
        pygame.display.set_caption('Snake')
        self.reset()

    def reset(self):
        """Start a new round: 4-segment snake heading right from (15, 15), score 0, fresh food."""
        self.direction = Direction.RIGHT
        self.gameOver = False
        self.lastTime = time.time() * MILISECOND

        self.head = Point(15, 15)
        # Head first, then three body segments trailing to the left.
        self.snake = [self.head]
        for i in range(1, 4):
            self.snake.append(Point(15 - i, 15))

        self.score = 0
        self.food = None
        self.setNewFood()

    def setNewFood(self):
        """Place the food on a uniformly random cell not covered by the snake.

        If the snake covers the whole board there is nowhere left to put
        food; the round is then ended (``gameOver`` is set) and ``food`` is
        left unchanged instead of retrying forever.
        """
        occupied = set(self.snake)
        # Point is a namedtuple, so a plain (x, y) tuple hashes and compares
        # equal to the Point with the same coordinates.
        freeCells = [
            Point(x, y)
            for x in range(self.width)
            for y in range(self.height)
            if (x, y) not in occupied
        ]
        if not freeCells:
            self.gameOver = True
            return
        self.food = random.choice(freeCells)

    def checkInput(self):
        """Drain the pygame event queue; closing the window ends the game.

        On a QUIT event pygame is shut down and ``gameOver`` is set so that
        loops driving the game stop before touching the closed display.
        """
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                self.gameOver = True
                return

    def isCollision(self, head):
        """Return True if ``head`` lies outside the board or on a body segment.

        The first element of ``self.snake`` (the current head) is excluded
        from the body check.
        """
        # hits boundary
        if head.x < 0 or head.x >= self.width:
            return True
        if head.y < 0 or head.y >= self.height:
            return True

        # hits itself
        return head in self.snake[1:]

    def _drawCell(self, cell, outerColor, innerColor):
        """Draw one grid cell as a full square with a smaller inset square."""
        left = cell.x * BLOCK_SIZE
        top = cell.y * BLOCK_SIZE
        pygame.draw.rect(
            self.display, outerColor,
            pygame.Rect(left, top, BLOCK_SIZE, BLOCK_SIZE),
        )
        pygame.draw.rect(
            self.display, innerColor,
            pygame.Rect(left + 4, top + 4, 12, 12),
        )

    def draw(self):
        """Repaint the window: black background, blue snake, red food."""
        self.display.fill(BLACK)
        for point in self.snake:
            self._drawCell(point, BLUE1, BLUE2)
        self._drawCell(self.food, RED1, RED2)
        pygame.display.update()

    def getDirection(self, action):
        """Translate a relative action into the resulting absolute heading.

        Args:
            action: sequence of three numbers; the index of the largest one
                selects straight (0), right turn (1) or left turn (2).

        Returns:
            The ``Direction`` the snake faces after applying the action.
        """
        # For each current heading: (straight, right turn, left turn).
        turns = {
            Direction.RIGHT: (Direction.RIGHT, Direction.DOWN, Direction.UP),
            Direction.LEFT: (Direction.LEFT, Direction.UP, Direction.DOWN),
            Direction.UP: (Direction.UP, Direction.RIGHT, Direction.LEFT),
            Direction.DOWN: (Direction.DOWN, Direction.LEFT, Direction.RIGHT),
        }
        options = turns.get(self.direction)
        if options is None:
            # Unknown heading: leave it untouched.
            return self.direction
        return options[np.argmax(action)]

    def move(self, action):
        """Advance the snake by one cell according to ``action``.

        A step into a wall or into the snake's own body sets ``gameOver`` and
        leaves the snake where it was.  Reaching the food grows the snake by
        one segment, increments ``score`` and places new food.

        Args:
            action: relative action as described in ``getDirection``.

        Returns:
            ``(reward, done)`` where ``reward`` is ``REWARD_DEATH`` on a
            collision, ``REWARD_FOOD`` when food was eaten and 0 otherwise,
            and ``done`` tells whether the round is over.
        """
        self.direction = self.getDirection(action)
        x = self.head.x
        y = self.head.y
        if self.direction == Direction.RIGHT:
            x += 1
        elif self.direction == Direction.LEFT:
            x -= 1
        elif self.direction == Direction.DOWN:
            y += 1
        elif self.direction == Direction.UP:
            y -= 1
        newHead = Point(x, y)

        if self.isCollision(newHead):
            self.gameOver = True
            return self.REWARD_DEATH, True

        self.head = newHead
        self.snake.insert(0, newHead)

        if newHead == self.food:
            # Tail is kept, so the snake grows by one segment.
            self.score += 1
            self.setNewFood()
            # setNewFood ends the round when the board is completely full.
            return self.REWARD_FOOD, self.gameOver

        # Ordinary step: drop the tail so the length stays the same.
        self.snake.pop()
        return 0, False

    def hasTimeElapsed(self):
        """Return True, and restart the timer, once more than DELAY ms have passed since the last step."""
        timeNow = time.time() * MILISECOND
        if timeNow - self.lastTime > DELAY:
            self.lastTime = timeNow
            return True
        return False

    def run(self):
        """Run the game without steering until it ends or the window is closed.

        The snake keeps its heading and advances one cell every DELAY ms.
        """
        while not self.gameOver:
            self.checkInput()
            if self.gameOver:
                # Window was closed; the display can no longer be drawn to.
                break
            if self.hasTimeElapsed():
                self.move(self.STRAIGHT)
            self.draw()

    def getState(self):
        """Encode the situation around the head as 11 binary features.

        Returns:
            ``numpy`` int array laid out as
            ``[danger straight, danger right, danger left,
            heading left, heading right, heading up, heading down,
            food left, food right, food above, food below]``.
        """
        head = self.head

        # Would stepping onto each neighbouring cell be fatal?
        hitL = self.isCollision(Point(head.x - 1, head.y))
        hitR = self.isCollision(Point(head.x + 1, head.y))
        hitU = self.isCollision(Point(head.x, head.y - 1))
        hitD = self.isCollision(Point(head.x, head.y + 1))

        dirL = self.direction == Direction.LEFT
        dirR = self.direction == Direction.RIGHT
        dirU = self.direction == Direction.UP
        dirD = self.direction == Direction.DOWN

        state = [
            # danger straight
            (dirR and hitR) or (dirL and hitL) or (dirU and hitU) or (dirD and hitD),

            # danger right (clockwise from the heading)
            (dirU and hitR) or (dirD and hitL) or (dirL and hitU) or (dirR and hitD),

            # danger left (counter-clockwise from the heading)
            (dirU and hitL) or (dirD and hitR) or (dirL and hitD) or (dirR and hitU),

            # move direction
            dirL, dirR, dirU, dirD,

            # food location relative to the head
            self.food.x < head.x,
            self.food.x > head.x,
            self.food.y < head.y,
            self.food.y > head.y,
        ]
        return np.array(state, dtype=int)

In [None]:
# Creating the model
class Network:
    """Fully connected Q-network: 11 state features in, one value per action (3) out.

    The output layer is linear and the loss is mean squared error because the
    training targets are real-valued action values (rewards may be negative),
    not class labels.
    """

    def __init__(self, gamma=0.9):
        """Build and compile the model.

        Args:
            gamma: discount factor applied to the best next-state value when
                building training targets in ``trainSingle``.
        """
        self.gamma = gamma
        self.model = Sequential()
        self.model.add(Input(shape=(11,)))
        self.model.add(Dense(50, activation='relu'))
        self.model.add(Dense(50, activation='relu'))
        self.model.add(Dense(50, activation='relu'))
        self.model.add(Dense(3, activation='linear'))
        self.model.compile(
            loss='mean_squared_error',
            optimizer='adam'
        )

    def predictSingle(self, input):
        """Return the 3 action values predicted for a single state vector."""
        # The model expects a batch, so wrap the state in a batch of one.
        predict = self.model.predict(np.array([input]), verbose=0)
        return predict[0]

    def predictMany(self, input):
        """Return the predicted action values for a batch of state vectors."""
        return self.model.predict(input, verbose=0)

    def trainSingle(self, state, action, reward, newState, done):
        """Run one training step on a single transition.

        The target equals the current prediction for ``state`` except at the
        action that was taken, where it becomes ``reward`` plus, unless the
        episode ended, ``gamma`` times the best predicted value of
        ``newState``.

        Args:
            state: state vector before the action.
            action: one-hot style sequence; its argmax is the action taken.
            reward: reward received for the action.
            newState: state vector after the action.
            done: True when the action ended the episode.
        """
        # Copy so that editing the target never touches the prediction array.
        target = np.array(self.predictSingle(state), dtype=float)
        value = reward
        if not done:
            value = reward + self.gamma * np.max(self.predictSingle(newState))
        target[np.argmax(action)] = value
        # Fit on the state (network input), not on the prediction.
        self.model.fit(np.array([state]), np.array([target]), verbose=0)

    def trainMany(self, inputs, predicts):
        """Fit the model on a batch of states and their target action values."""
        self.model.fit(inputs, predicts, verbose=0)

In [None]:
# Smoke test: build a game and a freshly constructed network, then print the
# network's raw output for the game's initial state vector.
snake = SnakeGame()
model = Network()
predict = model.predictSingle(snake.getState())
print(predict)

In [None]:
# Implementing the agent
# Upper bound on the number of entries kept in the agent's memory deque
MAX_MEMORY = 100_000
# Presumably the lower and upper bounds of the exploration rate epsilon
MIN_EPSILON, MAX_EPSILON = 0.1, 1.0
class Agent:
    """Epsilon-greedy agent that plays SnakeGame and trains a Network after every step."""

    # Factor applied to epsilon after each finished game.
    EPSILON_DECAY = 0.995
    # Rewards handed to the network for eating food and for dying.
    REWARD_FOOD = 10
    REWARD_DEATH = -10

    def __init__(self) -> None:
        """Create the game, the network and the bookkeeping state."""
        self.nGames = 0
        self.epsilon = 0.9
        self.gamma = 0.9
        self.memory = deque(maxlen=MAX_MEMORY)
        self.model = Network()
        self.snake = SnakeGame()

    def getAction(self, state):
        """Choose an action for ``state``.

        With probability ``epsilon`` a random action is taken (exploration);
        otherwise the action with the highest predicted value is used.

        Returns:
            One-hot list of length 3 marking the chosen action.
        """
        # random moves: tradeoff exploration / exploitation
        finalMove = [0, 0, 0]
        if random.uniform(0, 1) < self.epsilon:
            move = random.randint(0, 2)
        else:
            prediction = self.model.predictSingle(state)
            move = np.argmax(prediction)
        finalMove[move] = 1
        return finalMove

    def run(self):
        """Play and train until the game window is closed.

        Each iteration pumps window events, takes one step, trains on the
        resulting transition and redraws.  When a game ends the game counter
        is incremented, epsilon is decayed (never below ``MIN_EPSILON``) and
        the game is reset.
        """
        while True:
            # Keep the window responsive; stop once pygame has been shut
            # down by a window-close event.
            self.snake.checkInput()
            if not pygame.get_init():
                return

            state = self.snake.getState()
            action = self.getAction(state)

            # Reward and termination are derived from the game's own state,
            # so this does not rely on what move() returns.
            scoreBefore = self.snake.score
            self.snake.move(action)
            done = self.snake.gameOver
            if done:
                reward = self.REWARD_DEATH
            elif self.snake.score > scoreBefore:
                reward = self.REWARD_FOOD
            else:
                reward = 0

            newState = self.snake.getState()
            self.model.trainSingle(state, action, reward, newState, done)
            self.snake.draw()

            if done:
                self.nGames += 1
                # Shift gradually from exploration towards exploitation.
                self.epsilon = max(MIN_EPSILON, self.epsilon * self.EPSILON_DECAY)
                self.snake.reset()


In [None]:
# Scratch check of np.argmax: y is the index of the largest element of x,
# so this prints that largest element (6).
x = [1, 3, 6]
y = np.argmax(x)
print(x[y])