In [120]:
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import random
from collections import OrderedDict
from copy import deepcopy

TODO:
    - Add a bias term
    - Custom weights
    - Drop connections (set their weight to 0)
    - Custom activation per node

# Activation functions

In [None]:
def leaky_relu(x):
    """Element-wise leaky ReLU with PyTorch's default negative slope of 0.01."""
    slope = 0.01
    return F.leaky_relu(x, negative_slope=slope)

def tanh(x):
    """Element-wise hyperbolic tangent.

    ``torch.nn.functional`` has no ``Tanh`` attribute (``Tanh`` is the module
    class ``nn.Tanh``), so the previous ``F.Tanh(x)`` raised AttributeError.
    """
    return torch.tanh(x)

def relu(x):
    """Rectified linear unit, max(x, 0) element-wise; never modifies ``x``."""
    return F.relu(x, inplace=False)

def sigmoid(x):
    """Element-wise logistic sigmoid, 1 / (1 + exp(-x)).

    ``torch.nn.functional`` has no ``Sigmoid`` attribute (``Sigmoid`` is the
    module class ``nn.Sigmoid``), so the previous ``F.Sigmoid(x)`` raised
    AttributeError.
    """
    return torch.sigmoid(x)

# Lookup from activation name (e.g. the strings listed in `nonlinearities`)
# to the wrapper functions defined in this cell.
string_to_activation = dict(
    leaky_relu=leaky_relu,
    relu=relu,
    sigmoid=sigmoid,
    tanh=tanh,
)

# Model

In [88]:
class Model(nn.Module):
    """Stack of fully connected ``nn.Linear`` layers built from a list of widths.

    Args:
        layer_sizes: Sequence ``[n_in, h_1, ..., n_out]``; one ``nn.Linear`` is
            created for each consecutive pair and registered under the names
            ``"0"``, ``"1"``, ... inside ``self.layers``.

    Raises:
        ValueError: if ``layer_sizes`` is empty.

    NOTE(review): no activation is applied between layers, so the whole stack
    is equivalent to a single affine map.
    """

    def __init__(self, layer_sizes):
        super(Model, self).__init__()
        if len(layer_sizes) == 0:
            raise ValueError("layer_sizes must contain at least the input size")

        layers = OrderedDict()
        # Pair each width with the next one: (n_in, h_1), (h_1, h_2), ...
        for idx, (in_size, out_size) in enumerate(zip(layer_sizes, layer_sizes[1:])):
            layers[str(idx)] = nn.Linear(in_size, out_size)

        self.layers = nn.Sequential(layers)

    def forward(self, x):
        # The submodule is registered as ``self.layers``; the previous
        # ``self.model`` attribute never existed and raised AttributeError.
        return self.layers(x)

# Genotype

In [124]:
class Genotype(object):
    """Genome of neuron genes, connection genes and mutation settings.

    Neurons live in ``self.neuron_genes`` and are referred to by their list
    index. Connections live in ``self.connection_genes``, a dict keyed by
    ``(input_index, output_index)``.

    Args:
        inputs: Number of input neurons (list indices ``[0, inputs)``).
        outputs: Number of output neurons (list indices
            ``[inputs, inputs + outputs)``).
        nonlinearities: Activation names; one is drawn uniformly at random
            for every neuron.
        p_add_neuron, p_add_connection, p_mutate_weight, p_reenable_connection,
        p_disable_connection, p_mutate_bias: Mutation probabilities. Only
            ``p_add_neuron`` is used by ``mutate`` so far.
        distance_excess_weight, distance_disjoint_weight, distance_weight:
            Coefficients stored for ``distance`` (not implemented yet).
    """

    # Spacing between the ids of the initial neurons, leaving room for the
    # midpoint ids that the add-neuron mutation assigns to new hidden neurons.
    NEURON_ID_SPACING = 2048

    def __init__(self, inputs, outputs, nonlinearities,
                 p_add_neuron, p_add_connection, p_mutate_weight, p_reenable_connection,
                 p_disable_connection, p_mutate_bias,
                 distance_excess_weight, distance_disjoint_weight, distance_weight):

        self.inputs = inputs
        self.outputs = outputs
        self.nonlinearities = nonlinearities

        # Mutation probabilities
        self.p_add_neuron = p_add_neuron
        self.p_add_connection = p_add_connection
        self.p_mutate_weight = p_mutate_weight
        self.p_reenable_connection = p_reenable_connection
        self.p_disable_connection = p_disable_connection
        self.p_mutate_bias = p_mutate_bias

        # Distance weights
        self.distance_excess_weight = distance_excess_weight
        self.distance_disjoint_weight = distance_disjoint_weight
        self.distance_weight = distance_weight

        # Neuron genes, indexed by position: [id, nonlinearity, layer]
        self.neuron_genes = []
        # Connection genes keyed by (input_index, output_index):
        # [innovation number, input index, output index, weight, enabled]
        self.connection_genes = {}
        # Hyperparameter genes
        self.hyperparameter_genes = []

        self.initialise_topology()

    def initialise_topology(self):
        """Create input/output neurons and connect every input to every output."""
        # Counter handed out by _next_innovation_number.
        self.next_innovation_number = 0

        # Input neurons: indices [0, inputs), layer 0.
        for i in range(self.inputs):
            self.neuron_genes.append(
                [i * self.NEURON_ID_SPACING, random.choice(self.nonlinearities), 0])

        # Output neurons: indices [inputs, inputs + outputs).
        # NOTE(review): unlike inputs, output neurons carry no layer entry,
        # so reading neuron_genes[k][2] for an output index raises IndexError.
        for i in range(self.outputs):
            self.neuron_genes.append(
                [(self.inputs + i) * self.NEURON_ID_SPACING, random.choice(self.nonlinearities)])

        # Fully connect inputs to outputs. The output range must start at
        # self.inputs (the first output index), not at self.outputs.
        for i in range(self.inputs):
            for j in range(self.inputs, self.inputs + self.outputs):
                weight = self.initialise_weight(self.inputs, self.outputs)
                self.connection_genes[(i, j)] = [self._next_innovation_number(), i, j, weight, True]

    def _next_innovation_number(self):
        """Return a fresh innovation number for a new connection gene.

        TODO: innovation numbers are meant to identify the same structural
        change across genomes; this counter is per-genome for now and should
        move to the population level.
        """
        number = self.next_innovation_number
        self.next_innovation_number += 1
        return number

    def initialise_weight(self, input_neurons, output_neurons):
        """Sample an initial weight as ``U[0, 1) * sqrt(1 / (input_neurons + output_neurons))``.

        NOTE(review): ``np.random.rand()`` lies in [0, 1), so every initial
        weight is non-negative; Xavier/Glorot initialisation samples
        symmetrically around 0.
        """
        return np.random.rand() * np.sqrt(1 / (input_neurons + output_neurons))

    # def translate_to_pytorch(self):
    #     self.model = Model([self.inputs, 5, self.outputs])
    #
    #     print(self.model.layers[0].weight)
    #     raise NotImplementedError

    def get_weight_matrix(self, layer_1, layer_2):
        raise NotImplementedError

    def recombinate(self, other):
        """Cross this genome with ``other`` to produce a child genome.

        Raises:
            NotImplementedError: always. Previously this built an empty copy and
                silently returned None.
        """
        # TODO: start from deepcopy(self) with cleared neuron/connection genes
        # and inherit genes from both parents.
        raise NotImplementedError

    def mutate(self):
        """Apply mutations to this genome in place and return ``self``.

        Only the add-neuron mutation is implemented. With probability
        ``p_add_neuron``, an enabled connection ``a -> b`` is disabled and
        replaced by ``a -> new`` (weight 1.0) and ``new -> b`` (the old weight).
        """
        # TODO: implement the remaining mutations (add connection, weights, ...).
        if np.random.rand() < self.p_add_neuron:
            self._add_neuron()

        return self

    def _add_neuron(self):
        """Split a random enabled connection by inserting a new hidden neuron."""
        # dict views are not indexable in Python 3, so sample from a list.
        enabled_keys = [key for key, gene in self.connection_genes.items() if gene[4]]
        if not enabled_keys:
            return

        split_connection = self.connection_genes[random.choice(enabled_keys)]
        # Disable (but keep) the connection being split.
        split_connection[4] = False

        input_neuron, output_neuron, weight = split_connection[1:4]
        # The new id is the midpoint between the ids of the two endpoints.
        neuron_id = (self.neuron_genes[input_neuron][0] + self.neuron_genes[output_neuron][0]) * 0.5
        nonlinearity = random.choice(self.nonlinearities)
        layer = self.neuron_genes[input_neuron][2] + 1
        self.neuron_genes.append([neuron_id, nonlinearity, layer])

        # Connections refer to neurons by list index: the one just appended.
        neuron_index = len(self.neuron_genes) - 1

        # 1.0 to initialise_weight?
        self.connection_genes[(input_neuron, neuron_index)] = [
            self._next_innovation_number(), input_neuron, neuron_index, 1.0, True]
        self.connection_genes[(neuron_index, output_neuron)] = [
            self._next_innovation_number(), neuron_index, output_neuron, weight, True]

    def distance(self, other):
        raise NotImplementedError

# Species

In [116]:
class Species:
    """Placeholder class; instantiating it raises NotImplementedError."""

    def __init__(self):
        raise NotImplementedError()

# Population

In [117]:
class Population:
    """Placeholder class; instantiating it raises NotImplementedError."""

    def __init__(self):
        raise NotImplementedError()

# Test

In [118]:
# Seed both RNGs used by Genotype (random.choice for nonlinearities,
# np.random for weights) so the constructed genome is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

inputs = 3
outputs = 4
nonlinearities = ['relu', 'sigmoid']

# Mutation probabilities
p_add_node = 0.03
p_add_connection = 0.3
p_mutate_weight = 0.8
p_reenable_connection = 0.01
p_disable_connection = 0.01
p_mutate_bias = 0.2

# Distance coefficients
distance_excess_weight = 1.0
distance_disjoint_weight = 1.0
distance_weight = 0.4

# Keyword arguments make it impossible to silently swap two of the many
# same-typed positional parameters.
genotype = Genotype(
    inputs=inputs,
    outputs=outputs,
    nonlinearities=nonlinearities,
    p_add_neuron=p_add_node,
    p_add_connection=p_add_connection,
    p_mutate_weight=p_mutate_weight,
    p_reenable_connection=p_reenable_connection,
    p_disable_connection=p_disable_connection,
    p_mutate_bias=p_mutate_bias,
    distance_excess_weight=distance_excess_weight,
    distance_disjoint_weight=distance_disjoint_weight,
    distance_weight=distance_weight,
)

In [14]:
# Experiment: Xavier-uniform initialisation of a 3x5 tensor using the ReLU gain.
relu_gain = nn.init.calculate_gain('relu')
w = torch.empty(3, 5)
nn.init.xavier_uniform_(w, gain=relu_gain)

tensor([[-0.2442, -0.6035,  0.8889,  0.5022, -0.4859],
        [-0.2921,  1.2088,  0.4895, -0.8608,  0.1795],
        [ 1.1595, -0.7377, -0.1698,  0.0373,  1.0230]])