In [1]:
### Imports
###########

import time
import random
import torch
import numpy as np
from snake import Snake 

In [2]:
### Random Seeds
################

# One shared seed for every random number generator the notebook touches,
# so that a fresh "Restart & Run All" reproduces the same run.
SEED = 123

torch.manual_seed(SEED)   # PyTorch (weight initialisation, torch.randn)
random.seed(SEED)         # Python's built-in `random`
np.random.seed(SEED)      # NumPy's legacy global generator

In [3]:
### Bot/AI Classes
##################

class RandomBot:
    """Baseline player that picks every move uniformly at random.

    Attributes:
        fitness: Score slot, initialised to 0. This class never updates it.
    """

    def __init__(self):
        self.fitness = 0

    def get_next_move(self, game):
        """Return a random move code in ``{1, 2, 3, 4}``.

        Args:
            game: Current game object; ignored by this bot.

        Returns:
            int: ``1 + floor(u * 4)`` for a fresh ``u`` drawn from
            ``random.random()`` (i.e. ``u`` in ``[0, 1)``).
        """
        roll = random.random()
        return 1 + int(roll * 4)

class SnakeAI:
    """Player whose moves are chosen by a neural-network model.

    Attributes:
        model: Callable (e.g. a ``torch.nn.Module``) that maps the tensor
            produced by ``_create_model_input`` to a ``(1, n_moves)`` tensor
            of scores.
        fitness: Score slot, initialised to 0. This class does not update it
            yet (see the TODO in ``get_next_move``).
    """

    def __init__(self, model):
        self.model = model
        self.fitness = 0

    def get_next_move(self, game):
        """Ask the model for the next move.

        Args:
            game: Current game object, forwarded to ``_create_model_input``.

        Returns:
            int: 1-based index of the highest-scoring model output.
        """
        # TODO: update self.fitness here.
        # TODO: the move returned by get_next_move() is not always executed
        #       by the game.
        x = self._create_model_input(game)
        # Use this instance's own model. The previous code referenced a
        # global named `model`, which only worked when the surrounding
        # notebook happened to define one.
        with torch.no_grad():  # inference only; no autograd graph needed
            scores = self.model(x)
        return self._process_model_output(scores)

    def _create_model_input(self, game):
        """Build the model input for ``game``.

        Placeholder: currently ignores ``game`` and returns random noise of
        shape ``(1, 4)``.
        """
        return torch.randn(1, 4)

    def _process_model_output(self, output):
        """Convert a ``(1, n_moves)`` score tensor into a move code.

        Returns:
            int: Index of the largest score in the first row, plus one. A
            plain Python ``int`` is returned, matching ``RandomBot``.
        """
        scores = output.detach().numpy()
        best_index = int(np.argmax(scores, axis=1)[0])
        return best_index + 1

In [4]:
### Helper
##########

def init_model(state_dict=None):
    """Create the 4 -> 10 -> 4 fully connected network used by the snake AI.

    Args:
        state_dict: Optional parameters to load into the new network. When
            ``None`` the network keeps PyTorch's default random initialisation.

    Returns:
        torch.nn.Sequential: ``Linear(4, 10)``, ``ReLU``, ``Linear(10, 4)``.
    """
    layers = [
        torch.nn.Linear(4, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 4),
    ]
    network = torch.nn.Sequential(*layers)

    if state_dict is None:
        return network

    network.load_state_dict(state_dict)
    return network

def get_random_state_dicts(population):
    """Pick the state dicts of two distinct, randomly chosen population entries.

    Args:
        population: Sequence of ``(fitness, state_dict)`` pairs with at least
            two entries.

    Returns:
        tuple: ``(state_dict_a, state_dict_b)`` taken from two different
        positions in ``population``.

    Raises:
        ValueError: If ``population`` holds fewer than two entries. Without
            this guard the retry loop below could never find a second,
            different index and would spin forever.
    """
    len_pop = len(population)
    if len_pop < 2:
        raise ValueError(
            "population must contain at least 2 entries, got %d" % len_pop
        )

    idx1 = int(random.random() * len_pop)
    idx2 = int(random.random() * len_pop)
    # Redraw the second index until it differs from the first.
    while idx1 == idx2:
        idx2 = int(random.random() * len_pop)

    return population[idx1][1], population[idx2][1]

def merge_state_dicts(state_dict1, state_dict2):
    """Combine two parent state dicts into one child state dict.

    Placeholder: no merging is implemented yet. ``state_dict2`` is ignored
    and ``state_dict1`` is handed back unchanged.

    Note:
        The returned object is ``state_dict1`` itself, not a copy.
    """
    # TODO: build and return a real merged set of parameters.
    merged = state_dict1
    return merged

def mutate_state_dict(state_dict):
    """Mutation step of the evolutionary algorithm.

    Placeholder: no mutation is implemented yet, so ``state_dict`` is left
    untouched and ``None`` is returned.
    """
    return None

In [5]:
### Evolutionary Algorithm
##########################

# ---- settings of the evolutionary search ----
population_max = 100  # breeding stops once the population holds this many sets
num_initial_pop = 10  # size of the starting population and number of survivors per generation
generations = 10      # evolution ends after this many generations
mean_over = 25        # games played by every new set; its fitness is the mean over them

# ---- initial population: (fitness, state_dict) pairs, fitness starts at 0 ----
population = [(0, init_model().state_dict()) for _ in range(num_initial_pop)]

# ---- evolution ----
for generation in range(generations):

    # breed new sets until the population is full again
    while len(population) < population_max:

        parent1, parent2 = get_random_state_dicts(population)
        child_state_dict = merge_state_dicts(parent1, parent2)
        mutate_state_dict(child_state_dict)

        model = init_model(child_state_dict)

        # play `mean_over` games, each with a fresh SnakeAI, and record its fitness
        game_fitnesses = []
        for _game in range(mean_over):
            snakeAI = SnakeAI(model)
            snake = Snake(snakeAI, human=False, verbose=False, move_at_ticks=1)
            snake.start()
            game_fitnesses.append(snakeAI.fitness)

        population.append((sum(game_fitnesses) / mean_over, model.state_dict()))

    # keep only the `num_initial_pop` fittest sets, best first
    population = sorted(population, key=lambda entry: entry[0], reverse=True)[:num_initial_pop]
    
# TODO save best model to a file