# EVAC Assessment 1 - Evolve a Player for the Video Game Snake


In [None]:
import random
import time
import turtle
import numpy as np
from deap import base
from deap import creator
from deap import tools
import logging
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Board dimensions in grid cells; the board is square (do not change this)
XSIZE = 16
YSIZE = XSIZE
# True runs without the graphical interface, False shows the game while it runs
HEADLESS = True

# Sets the logging level used for console output
logging.basicConfig(level=logging.INFO)

In [None]:
class DisplayGame:
    """Turtle-graphics view of the game, used when HEADLESS is set to False.

    Grid cells are [row, column] pairs and are drawn 20 pixels apart.
    """

    def __init__(self, XSIZE, YSIZE):
        """Creates the window plus the turtles for the snake head and the food.

        Args:
            XSIZE: Number of grid columns; determines the window width.
            YSIZE: Number of grid rows; determines the window height.
        """
        # SCREEN
        self.win = turtle.Screen()
        self.win.title("EVAC Snake game")
        self.win.bgcolor("grey")
        self.win.setup(width=(XSIZE * 20) + 40, height=(YSIZE * 20) + 40)
        self.win.tracer(0)  # automatic redraws off; callers refresh with win.update()

        # Snake Head
        self.head = turtle.Turtle()
        self.head.shape("square")
        self.head.color("black")

        # Snake food
        self.food = turtle.Turtle()
        self.food.shape("circle")
        self.food.color("red")
        self.food.penup()
        self.food.shapesize(0.55, 0.55)

        # One turtle per body segment (the head is not included)
        self.segments = []

    @staticmethod
    def _to_screen(cell):
        """Converts a [row, column] grid cell into (x, y) turtle coordinates."""
        row, col = cell[0], cell[1]
        return ((col - 9) * 20) + 20, (((9 - row) * 20) - 10) - 20

    def reset(self, snake):
        """Resets the display for a new game.

        Args:
            snake: List of [row, column] cells, head first.
        """
        # Hide the previous game's body turtles; merely dropping the references
        # would leave them drawn on the screen.
        for stale_segment in self.segments:
            stale_segment.hideturtle()
        self.segments = []

        self.head.penup()
        # Park food and head off-screen until their real positions are drawn
        self.food.goto(-500, -500)
        self.head.goto(-500, -500)
        for _ in range(len(snake) - 1):
            self.add_snake_segment()
        self.update_segment_positions(snake)

    def update_food(self, new_food):
        """Moves the food turtle to the [row, column] cell `new_food`."""
        self.food.goto(*self._to_screen(new_food))

    def update_segment_positions(self, snake):
        """Moves the head and every body segment to the cells listed in `snake`.

        Args:
            snake: List of [row, column] cells, head first.
        """
        self.head.goto(*self._to_screen(snake[0]))
        # segments[i] draws snake[i + 1]; snake[0] is the head
        for segment, cell in zip(self.segments, snake[1:]):
            segment.goto(*self._to_screen(cell))

    def add_snake_segment(self):
        """Creates a new body-segment turtle and appends it to the display."""
        segment = turtle.Turtle()
        segment.speed(0)
        segment.shape("square")
        segment.color("green")  # TODO: Change back to random colour generation before submission
        segment.penup()
        self.new_segment = segment  # attribute kept for backward compatibility
        self.segments.append(segment)

In [None]:
class Snake:
    """Game logic for Snake.

    Coordinates are [row, column] lists ([ypos, xpos]). The outermost ring of
    cells (row/column 0 and the last row/column) is treated as wall.
    """

    # [row, column] step taken for each movement direction
    _DIR_OFFSETS = {"up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}

    def __init__(self, _XSIZE, _YSIZE):
        """Stores the board dimensions and sets up a fresh game.

        Args:
            _XSIZE: Number of grid columns.
            _YSIZE: Number of grid rows.
        """
        self.XSIZE = _XSIZE
        self.YSIZE = _YSIZE
        self.reset()

    def reset(self):
        """Resets the game after a run has finished"""
        # Initial snake co-ordinates [ypos,xpos], head first
        self.snake = [[8, col] for col in range(10, -1, -1)]
        self.food = self.place_food()
        self.snake_direction = "right"
        self.time_alive = 0  # TODO: 1. Check if this modification is allowed

    def place_food(self):
        """Randomly picks a non-wall cell for the food, retrying while it lands on the snake.

        Returns:
            The new food location as [row, column]; it is also stored in self.food.
        """
        while True:
            candidate = [random.randint(1, self.YSIZE - 2),
                         random.randint(1, self.XSIZE - 2)]
            if candidate not in self.snake:
                self.food = candidate
                return candidate

    def update_snake_position(self):
        """Adds the new coordinate of the snakes head to the front of the snake coordinate list."""
        # An unrecognised direction leaves the head where it is (offset 0, 0)
        d_row, d_col = self._DIR_OFFSETS.get(self.snake_direction, (0, 0))
        head = self.snake[0]
        self.snake.insert(0, [head[0] + d_row, head[1] + d_col])

    def food_eaten(self):
        """Returns True if the head is on the food; otherwise drops the oldest tail cell
        (the head cell was already added by the move) and returns False."""
        if self.snake[0] == self.food:
            return True
        self.snake.pop()  # snake moves forward and so last tail item is removed
        return False

    def snake_turns_into_self(self):
        """Returns True if new snakes head coordinate is already in the body, otherwise False"""
        return self.snake[0] in self.snake[1:]

    def snake_hit_wall(self):
        """Returns True if new snakes head coordinate is on a wall cell, otherwise False"""
        return self.sense_wall(self.snake[0])

    # Sensor Functions
    def get_adj_coords(self):
        """Returns a dictionary mapping "up"/"down"/"left"/"right" to the [row, column]
        cell adjacent to the snakes head in that direction"""
        head = self.snake[0]
        # Values stay lists so they compare equal to self.food and snake cells
        return {name: [head[0] + d_row, head[1] + d_col]
                for name, (d_row, d_col) in self._DIR_OFFSETS.items()}

    def sense_wall(self, coord):
        """Returns True if the coordinate is on the boundary ring or outside this
        instance's board, otherwise False"""
        inside_rows = 0 < coord[0] < self.YSIZE - 1
        inside_cols = 0 < coord[1] < self.XSIZE - 1
        return not (inside_rows and inside_cols)

    def sense_food(self, coord):
        """True if food is at provided coordinate, otherwise False"""
        return self.food == coord

    def sense_tail(self, coord):
        """Returns True if coordinate is a part of the snake (head included), otherwise False"""
        return coord in self.snake

    def dist_to_food(self):
        """Calculates the manhattan distance from the head to the food"""
        head = self.snake[0]
        return abs(self.food[0] - head[0]) + abs(self.food[1] - head[1])

    def obstacle_check(self, coord):
        """Returns 0 if a wall or snake cell is at the given coordinate, otherwise 1"""
        if self.sense_wall(coord) or self.sense_tail(coord):
            return 0
        return 1

In [None]:
class NeuralNetwork(object):
    '''Creates a fully connected/dense neural network with 2 hidden layers.

    A constant bias input of 1 is appended to the input layer and to the first
    hidden layer; the second hidden layer and the output layer have no bias.
    '''

    def __init__(self, numInput, numHidden1, numHidden2, numOutput):
        '''Initializes the neural network with random (standard normal) weights.

        Args:
            numInput: Number of input values (bias excluded).
            numHidden1: Number of neurons in the first hidden layer (bias excluded).
            numHidden2: Number of neurons in the second hidden layer.
            numOutput: Number of output values.
        '''
        self.numInput = numInput + 1 # Add bias node for first hidden layer
        self.numHidden1 = numHidden1 + 1 # Adds bias node for second hidden layer
        self.numHidden2 = numHidden2
        self.numOutput = numOutput

        # The bias unit of hidden layer 1 has no incoming weights, so this matrix
        # has one row per real neuron. This matches the shape setWeightsLinear
        # produces and the vector feedForward builds.
        self.w_i_h1 = np.random.randn(self.numHidden1 - 1, self.numInput)
        self.w_h1_h2 = np.random.randn(self.numHidden2, self.numHidden1)
        self.w_h2_o = np.random.randn(self.numOutput, self.numHidden2)

        # Scalar ReLU, kept as a public attribute for existing callers
        self.ReLU = lambda x : max(0,x)

    def softmax(self, x):
        '''Returns elements from last layer of network as a probability distribution which adds up to 1'''
        e_x = np.exp(x - np.max(x))  # subtracting the max avoids overflow in exp
        return e_x / e_x.sum()

    def feedForward(self, inputs):
        '''Runs the inputs through the network and returns the softmax of the output layer.

        Args:
            inputs: Sequence of numInput (bias excluded) numeric/boolean values.
                It is not modified.

        Returns:
            NumPy array of numOutput probabilities.
        '''
        # Copy the inputs and add the bias value for hidden layer 1
        inputsBias = np.append(np.asarray(inputs, dtype=float), 1.0)

        # Hidden layer 1: weighted sum, ReLU, then the bias value for hidden layer 2
        h1 = np.maximum(0.0, np.dot(self.w_i_h1, inputsBias))
        h1 = np.append(h1, 1.0)

        # Hidden layer 2: weighted sum and ReLU
        h2 = np.maximum(0.0, np.dot(self.w_h1_h2, h1))

        # Output layer
        return self.softmax(np.dot(self.w_h2_o, h2))

    def getWeightsLinear(self):
        '''Returns the current weights as one flat list (input->h1, h1->h2, h2->output)'''
        flat_w_i_h1 = list(self.w_i_h1.flatten())
        flat_w_h1_h2 = list(self.w_h1_h2.flatten())
        flat_w_h2_o = list(self.w_h2_o.flatten())
        return( flat_w_i_h1 + flat_w_h1_h2 + flat_w_h2_o)

    def setWeightsLinear(self, Wgenome):
        '''Sets the weights for the network from a flat sequence.

        Args:
            Wgenome: Flat sequence of weights in the order returned by
                getWeightsLinear.

        Raises:
            ValueError: If Wgenome does not hold exactly the number of weights
                the network needs.
        '''
        numWeights_I_H1 = (self.numHidden1-1) * self.numInput
        numWeights_H1_H2 = (self.numHidden2) * self.numHidden1
        numWeights_H2_O = self.numOutput * self.numHidden2
        expected = numWeights_I_H1 + numWeights_H1_H2 + numWeights_H2_O

        # Check up front so a bad genome cannot leave the network half-updated
        if len(Wgenome) != expected:
            raise ValueError("expected %d weights, got %d" % (expected, len(Wgenome)))

        split1 = numWeights_I_H1
        split2 = numWeights_I_H1 + numWeights_H1_H2

        self.w_i_h1 = np.array(Wgenome[:split1])
        self.w_i_h1 = self.w_i_h1.reshape((self.numHidden1-1, self.numInput))

        self.w_h1_h2 = np.array(Wgenome[split1:split2])
        self.w_h1_h2 = self.w_h1_h2.reshape((self.numHidden2, self.numHidden1))

        self.w_h2_o = np.array(Wgenome[split2:])
        self.w_h2_o = self.w_h2_o.reshape((self.numOutput, self.numHidden2))

In [None]:
def run_game(display, snake_game, headless, network, max_ticks=None):
    '''Runs through a game simulation, using the neural network to make decisions on the snakes movement.

    Args:
        display: DisplayGame used for drawing; only accessed when headless is False.
        snake_game: Snake instance holding the game state; it is reset first.
        headless: True to simulate without drawing, False to draw every tick.
        network: Object whose feedForward(list) returns one score per direction
            in the order up, down, left, right.
        max_ticks: Optional cap on the number of game ticks. None (the default)
            means no cap, in which case a snake that circles forever without
            colliding never ends the game.

    Returns:
        The final score (number of food items eaten) when the game ended.
    '''
    # Order must match the network's output order
    directions = ("up", "down", "left", "right")

    # Resets the score, game & display
    score = 0
    snake_game.reset()
    if not headless:
        display.reset(snake_game.snake)
        display.win.update()
    snake_game.place_food()
    game_over = False

    while not game_over:
        snake_game.time_alive += 1     # Increments time alive on each game tick

        adj_coords = snake_game.get_adj_coords()

        # Sensor vector: 4 "direction is free" flags followed by 4 "food is adjacent" flags
        sensors = [snake_game.obstacle_check(adj_coords[name]) for name in directions]
        sensors += [snake_game.sense_food(adj_coords[name]) for name in directions]

        # Gets the network decision and moves in the highest scoring direction
        decision = network.feedForward(sensors)
        snake_game.snake_direction = directions[int(np.argmax(decision))]

        snake_game.update_snake_position()

        # Checks if food is eaten and replaces food + increments score
        if snake_game.food_eaten():
            snake_game.place_food()
            score += 1
            if not headless:
                display.add_snake_segment()

        # Ends game if the snake runs into itself or hits a wall
        if snake_game.snake_turns_into_self() or snake_game.snake_hit_wall():
            game_over = True

        # Ends game once the optional tick budget is used up
        if max_ticks is not None and snake_game.time_alive >= max_ticks:
            game_over = True

        # Updates display when not running in headless mode
        if not headless:
            display.update_food(snake_game.food)
            display.update_segment_positions(snake_game.snake)
            display.win.update()

            time.sleep(0.01)     # Change to change update rate of the game

    # Enters the turtle main loop so the final frame stays on screen
    if not headless:
        turtle.done()

    return score


In [None]:
# Initializes game (and display if not running in headless).
# `display` is always bound (None when headless) because evaluate() passes it
# to run_game on every call; leaving it undefined raised NameError when HEADLESS was True.
display = None if HEADLESS else DisplayGame(XSIZE, YSIZE)
snake_game = Snake(XSIZE, YSIZE)

# Initializes neural network
numInputNodes = 8
numHiddenNodes1 = 16
numHiddenNodes2 = 8
numOutputNodes = 4
network = NeuralNetwork(numInputNodes, numHiddenNodes1, numHiddenNodes2, numOutputNodes)

# Calculates the size of the individual using input, output and hidden layer neuron counts (accounting for bias nodes for hidden layers)
IND_SIZE = ((numInputNodes+1) * numHiddenNodes1) + ((numHiddenNodes1+1) * numHiddenNodes2) + (numHiddenNodes2 * numOutputNodes)

# Creates single objective maximizing fitness named FitnessMax
creator.create("FitnessMax", base.Fitness, weights=(1.0,))

# Creates an individual with a list of attributes using previously created FitnessMax
creator.create("Individual", list, fitness=creator.FitnessMax)

def evaluate(individual, myNet, snake_game):
    '''Returns the fitness of the individual after evaluating performance from game simulation.

    Args:
        individual: Flat sequence of IND_SIZE network weights.
        myNet: NeuralNetwork that receives the weights and plays the game.
        snake_game: Snake instance used for the simulation.

    Returns:
        One-element tuple holding the game score, as DEAP expects fitness tuples.
    '''
    myNet.setWeightsLinear(individual)   # Load the individual's weights into the neural network
    fitness = run_game(display, snake_game, HEADLESS, myNet) # Evaluate the individual by running the game (discuss)
    return fitness,

# Defines functions to create individuals whos genes are random float values (uniformly distributed between -1 and 1)
toolbox = base.Toolbox()
toolbox.register("attr_float", random.uniform, -1.0, 1.0)
toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_float, n=IND_SIZE)

# Defines functions to evaluate and mutate individuals
toolbox.register("evaluate", evaluate)
toolbox.register("mutate", tools.mutGaussian, mu=0.0, sigma=0.2, indpb=0.1)

# Defines function to select best individuals
toolbox.register("select", tools.selTournament, tournsize=5) # TODO: Implement crossover

# Defines function to generate initial population
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

# Defines the statistics & logbook that will be logged during the GA
stats = tools.Statistics(key=lambda ind: ind.fitness.values)
stats.register("avg", np.mean)
stats.register("std", np.std)
stats.register("min", np.min)
stats.register("max", np.max)
logbook = tools.Logbook()

In [None]:
# Initializes population
population = toolbox.population(n=5)

# Calculates the initial fitness values for each individual and sets them
fitnesses = [toolbox.evaluate(individual, network, snake_game) for individual in population]
for ind, fit in zip(population, fitnesses):
    ind.fitness.values = fit

# Number of generations the GA will compute
NGEN = 5

# Genetic Algorithm (mutation only): select -> clone -> mutate -> re-evaluate -> replace
for g in range(NGEN):
    logging.info("> Running generation %s", g)

    # Selects number of individuals equal to population length
    offspring = toolbox.select(population, len(population))
    # Selection can return the same individual several times, so clone each one
    # to make sure mutations do not affect shared objects
    offspring = list(map(toolbox.clone, offspring))

    # Mutates offspring based on previously defined probabilities #TODO: Modify probability/algorithm type?
    for mutant in offspring:
        toolbox.mutate(mutant)
        del mutant.fitness.values   # Deletes old fitness values

    # Recalculates fitness values for mutated offspring
    invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
    fitnesses = [toolbox.evaluate(individual, network, snake_game) for individual in invalid_ind]
    for ind, fit in zip(invalid_ind, fitnesses):
        ind.fitness.values = fit

    population[:] = offspring   # Replaces old population with new mutated offspring

    # Compiles & records the statistics for the new generation; "evals" is the
    # number of individuals evaluated this generation
    record = stats.compile(population)
    logbook.record(gen=g, evals=len(invalid_ind), **record)

In [None]:
# Column order used when the logbook is printed
logbook.header = "gen", "avg", "evals", "std", "min", "max"
# Extracts each recorded statistic as its own per-generation series
gen, _min, _max, average, standard_dev = logbook.select("gen", "min", "max", "avg", "std")
print(logbook)

# Applies the same font size to axis labels, tick labels and the legend
for rc_group, rc_option in (('axes', 'labelsize'), ('xtick', 'labelsize'),
                            ('ytick', 'labelsize'), ('legend', 'fontsize')):
    plt.rc(rc_group, **{rc_option: 14})

fig, ax = plt.subplots()

ax.set_xlabel("Generation")
ax.set_ylabel("Fitness")

# Plots average, minimum and maximum fitness against generation (in that order)
line1, line2, line3 = [ax.plot(gen, series) for series in (average, _min, _max)]

# Adds key to graph, placed outside the axes on the right-hand side
legend_labels = ["Average Fitness", "Minimum Fitness", "Maximum Fitness"]
ax.legend(legend_labels, loc='center left', bbox_to_anchor=(1, 0.5))

In [None]:
# Takes weights of the best individual and runs in non-headless mode
bestInd = tools.selBest(population, 1)[0]
network.setWeightsLinear(bestInd)
# weights of decent snake from big run (every line of the list is commented out;
# to replay it, uncomment the block and call network.setWeightsLinear(weights))
# weights = [-3.3786318224249774, 1.264706228369047, -0.40942301323576036, 0.8228627733512951, 1.2952849514161144, -2.161447552842574, 2.459304604445068, 0.751959693871249, 0.2646987606130422, 0.44054769982886344, 2.963415085453179, -1.4734280754553195, 1.9762563570191047, -1.3980643317127417, 1.565652804914855, -1.3436918453278417, 3.2884625073498817, -0.07067514264084776, 0.16989418906325993, -1.3506191153128582, -0.9577401209982983, -0.6868746770233645, 0.41317660463818046, -1.17013356033095, -1.6907316878337038, -0.7788865033646469, 1.5457942459648732, 0.7252479850234728, -0.5003985396381967, 1.3279832030358585, -0.46304647770929447, -0.8531771278308292, 0.8411415781462619, -0.5037411765896797, -2.4331300129641327, 0.7215281054941072, 0.3149669933639485, -1.6099786395606295, -1.122140275792487, 1.1926320696927895, -1.2613199315404577, 0.7667372744315759, 0.40390578170828906, -1.5931007716499335, 0.5809589765334454, -0.7390772483839543, 1.1664280651907746, -1.7328894002113566, 0.24755009309092743, -2.4901934640909222, -1.4150369875600253, -0.03748343761100241, 2.765686616714359, 0.8762267072228441, -1.4213291568860333, 3.5029434678489677, 0.09738587437680091, -2.3196384005626065, 0.8214933026594639, 1.409549625326027, -2.8920796535874858, -1.8250013504147193, 0.175446620348488, -0.9249621509086291, -0.5790439502407716, -1.6974727829158998, -0.2750526824653321, 1.977987168747408, -0.8937438988830115, -0.4315991416671394, 0.025314561048932427, 1.162685851111196, 0.48690643714042237, -0.5540900055111977, -1.8285726130963185, -0.11830786709438537, -2.073279629019206, -0.24078189503546168, 0.18988356617971125, -2.7061506128721833, 1.54096148072729, -1.6797703537519966, 1.7952608810601667, 1.0405412899357096, -0.8299709727753755, 1.1732086636728498, -2.0563440331152365, 1.8928593971714958, 0.3016272459407501, -0.21737236147597444, 2.780865459721678, -1.698104067812116, 0.07212796237875763, -0.6270463436257026, 2.134186523898742, 2.3755549136388714, 
# -0.9885587957329299, 2.626775872225696, 1.2980598589378585, -0.3922580636680093, 0.19963460387160675, -0.9757090918322109, -1.345411808022169, -0.6679567219504612, 0.6327451931594831, -0.5717976253665514, -1.9469242806155815, -2.385350170221946, -2.3017829074581617, 2.28285019896275, 1.693795645655068, 0.03289688815866365, 0.013906957581289445, 0.25736693214505857, 1.1178970011086455, 1.913850154888052, 0.3899407778733879, 1.1987144944005508, 0.6068250996781694, 0.07595431461180863, -0.30755572429590405, 2.016148507616494, -0.19879195748695275, 0.01666275978799664, -0.7957781939212187, 1.3115189557370153, 0.286758035272611, -1.1243784045476288, -0.7506689784844444, 0.18619825439443896, -2.495485320926148, -0.8315403136953066, -1.6403088104311934, -1.6262008408799984, 2.2436087651709054, -0.6572400275541929, -2.2765733522893448, 0.7782739357757492, -0.30973410897232906, 1.1788342043967561, 2.9279597081311413, 0.6008268849732972, -0.35101848094385146, 1.4656611216952873, 0.10420528691858252, 2.3384091806706815, 1.7074779302191798, 0.3416018337245642, 1.3827825827583993, -2.875759178814153, 1.438384957885212, 1.1027367924491582, 0.9993624576536544, 0.26083251593376755, 0.054701320621820454, 1.4655376304104868, 0.35023670109319144, -1.5659338117116186, 3.1079242952812236, -0.44526535649690535, -0.39533167403988445, 0.6611150644093164, -0.5389959954640773, -0.4182004391286605, 2.955947131021676, 0.3064480522887835, -1.7117124375560564, -0.5603912146026477, 0.00785738974827982, 1.5392347912772273, 1.633768473137616, -0.9695656601737563, -1.7235285270958494, 0.26860457026710083, -1.60069371993584, -1.0563595859268577, 2.6338738144421763, -0.47869633315876126, -0.8625043965138284, 1.003024320293217, -0.7329651782946017, 0.40166933527133575, -0.9305524074729066, 0.3617315187585298, 0.9935343713239062, -1.2712403993634183, -1.5332277218629815, 0.3318209295407418, -0.4072495674771841, 0.20748709783664726, 0.2826262877094916, -0.103355920901315, -0.3408923624099704, 
# 0.2541048963551895, 1.26182027075251, 0.3354822262508063, 2.4380963314587785, 1.0637524808109353, -1.796672777709277, -0.015629020707459043, -0.07424017006499205, -0.9038316919718722, -0.1827430646643367, -1.2913391261454543, -1.2379288263663581, -0.7487917045399035, 0.9769312631345152, 0.07558491189238548, -1.273553471360424, 0.7302827255358089, 0.5811170784730877, -0.6289949068938911, -3.8870496423644054, -0.368553989918404, -1.095522846316345, -1.0582601756751056, -3.0776000943075146, 1.2410361991161125, 0.7614820660148116, -2.0115772990078544, 1.2517335430894785, -0.5470880540146368, -0.27635297357270516, -1.4357963522325188, 1.1491892797514458, -0.7583889978868158, 2.0933905780342696, -2.4225188846307653, 1.9034975242608294, 0.24965543046521244, -1.0172814243967296, -1.237400942655199, 1.7736358537498247, 0.3454665583689715, -3.082502515149337, 0.012885844550694383, 0.3294242844258688, -2.406155957010417, 0.7386854729906833, -0.17109848129854108, 0.4106663915363994, -2.898160947438728, 1.2978371795543842, -1.5318492894265647, -1.0638426962667027, -0.046965017241237965, -0.09582964593453899, 0.08285151947081876, 0.9859427107618864, 0.4606465736620381, 1.421534861815946, 1.7931002891968812, -3.4588769277993894, 0.18651211328753767, 1.8866480654929214, -1.4501116185459169, 1.7622605912779543, -1.1709581137304979, -0.3229764295885535, -0.9572365343703699, -0.8075790254877449, -1.2359866287092056, -0.10360954244547194, -0.03058811270930377, -3.099540754444425, 2.1660109127929177, 1.3512202043320023, -2.8114098868745256, -0.4447292428425811, 1.188511748895626, 0.6783979832714906, -0.317051585775403, -0.10924801409813643, 1.8539060720780678, 0.7588061628484052, -1.1268542654085967, -0.2888938610335625, 0.9309236210923635, 0.7821511403377567, 2.7907710204784952, 1.1696090640145176, -1.3366780152148885, 1.0506836982345986, -0.23926185015379092, -1.5530935514625401, 2.050760216281348, -0.06849773655372443, 1.550647026113014, 2.3071199299430063, -1.1748096897901315, 
# 0.9651293368915005, -0.23503703495422576, 3.2832712811259546, -0.47699291331564486, -2.764534616190909, -0.659049453400496, 1.8657276293350016, 1.8615638475458378, -1.1825990521860548, -1.5762237948816398, -1.1377258235995373, -0.264087089171003, -3.370727452998959, 1.0030994638539028, 2.2389755052495177, -3.319809284042486, -1.0104396750187883, -0.37777789881967405, 3.160653489380219, 0.7482208631679532, 0.9105473005189015, -1.2071592931617545]
# A display is needed here regardless of the HEADLESS setting used during evolution
display = DisplayGame(XSIZE,YSIZE)
run_game(display, snake_game, False, network)