In [1]:
# Imports
import random

## Helper functions

In [2]:
def ReLU(x):
    """Rectified linear unit: return ``x`` unless it is negative, in which case return 0."""
    return 0 if x < 0 else x


## NEAT Algorithm

Original paper: https://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf


#### Network Encoding

In NEAT, a network's genotype (its genome) is encoded as a list of node genes and a list of connection genes:

![image](assets/genotype.png)

*Source: [Evolving Neural Networks through Augmenting Topologies](https://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf)*


#### Mutation


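NEAT uses two types of structural mutation:
- **Add connection**: a new connection gene with a random weight links two previously unconnected nodes.
- **Add node**: an existing connection is split. It is disabled, and a new node is inserted with two new connections in its place.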

![image](assets/mutation.png)

*Source: [Evolving Neural Networks through Augmenting Topologies](https://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf)*


In [None]:
class InnovationNumber:
    """Counter that hands out consecutive innovation numbers, starting at 0."""

    def __init__(self):
        # The number that the next call to pop() will return.
        self.value = 0

    def pop(self):
        """Return the current innovation number and advance the counter by one."""
        current, self.value = self.value, self.value + 1
        return current

class NodeGene:
    """A node (neuron) gene of a NEAT genome.

    Attributes:
        layer: Label of the layer the node belongs to (e.g. ``"hid"`` for
            nodes created by the add-node mutation).
        activation: The node's activation function (e.g. ``ReLU``).
        bias: The node's bias value.
    """

    def __init__(self, layer, activation, bias):
        self.layer, self.activation, self.bias = layer, activation, bias

class ConncectionGene:
    """A connection gene: a weighted link from ``in_node`` to ``out_node``.

    NOTE(review): the class name is misspelled ("Conncection"); it is kept
    unchanged because other code in this notebook constructs it by this name.
    """

    def __init__(self, in_node: "NodeGene", out_node: "NodeGene", weight: float, innov: int, enabled: bool = True):
        self.in_node = in_node
        self.out_node = out_node
        self.weight = weight
        # Disabled genes stay in the genome (the add-node mutation switches them off
        # rather than removing them).
        self.enabled = enabled
        # Innovation number; crossover aligns the two parents' genes on this key.
        self.innov = innov

    def copy(self):
        """Return a new gene with the same endpoints, weight, innovation number and enabled flag.

        The endpoint node objects are shared with the original gene, not copied.
        """
        # ``innov`` and ``enabled`` are the last two parameters and easy to swap
        # positionally, so pass them by keyword.
        return ConncectionGene(
            self.in_node,
            self.out_node,
            self.weight,
            innov=self.innov,
            enabled=self.enabled,
        )

class Genome:
    """A NEAT genome: a list of node genes plus a list of connection genes.

    The ``mutate_*`` methods modify both lists in place.
    """

    def __init__(self, nodes, connections):
        self.nodes = nodes  # node genes (the add-node mutation appends NodeGene objects)
        self.connections = connections  # connection genes (ConncectionGene objects)

    def mutate_add_connection(self, innov: "InnovationNumber"):
        """Add a connection with a random weight in [-1, 1] between two random nodes.

        The genome is left unchanged if it has fewer than two nodes, or if a
        connection ``node1 -> node2`` already exists (enabled or disabled).

        Args:
            innov: Counter supplying the new gene's innovation number.
        """
        if len(self.nodes) < 2:
            # random.sample(population, 2) would raise ValueError here.
            return
        node1, node2 = random.sample(self.nodes, 2)
        if node1 == node2:
            # Only possible if the same node appears twice in self.nodes.
            return
        if any(c.in_node == node1 and c.out_node == node2 for c in self.connections):
            return

        # NOTE(review): every call draws a fresh innovation number, even when the
        # same structural mutation happened elsewhere in the population.
        new_conn = ConncectionGene(node1, node2, random.uniform(-1, 1), innov.pop(), True)
        self.connections.append(new_conn)

    def mutate_add_node(self, innov: "InnovationNumber"):
        """Split a random enabled connection ``a -> b`` by inserting a new hidden node.

        The chosen connection is disabled. A new ``"hid"`` node with a ReLU
        activation and a random bias in [-1, 1] is added, together with two new
        connections: ``a -> new`` with weight 1 and ``new -> b`` with the old
        connection's weight. If the genome has no enabled connection, it is left
        unchanged.

        Args:
            innov: Counter supplying the two new genes' innovation numbers.
        """
        # Splitting a disabled connection would re-activate a path that was switched off.
        candidates = [c for c in self.connections if c.enabled]
        if not candidates:
            return
        connection = random.choice(candidates)
        connection.enabled = False

        new_node = NodeGene("hid", ReLU, random.uniform(-1, 1))

        conn1 = ConncectionGene(connection.in_node, new_node, 1, innov.pop(), True)
        conn2 = ConncectionGene(new_node, connection.out_node, connection.weight, innov.pop(), True)

        self.nodes.append(new_node)
        self.connections.extend([conn1, conn2])


def find_disjoint(set1, set2):
    """Return the genes whose innovation number appears in only one of the two collections.

    Genes are keyed by ``.innov``. If a collection holds several genes with the
    same innovation number, the last one wins. The result lists the genes unique
    to ``set1`` first, then those unique to ``set2``, each in first-seen
    innovation order. Both disjoint and excess genes are included.
    """
    by_innov_1 = {gene.innov: gene for gene in set1}
    by_innov_2 = {gene.innov: gene for gene in set2}

    only_in_1 = [gene for key, gene in by_innov_1.items() if key not in by_innov_2]
    only_in_2 = [gene for key, gene in by_innov_2.items() if key not in by_innov_1]
    return only_in_1 + only_in_2


def crossover(parent1: Genome, parent2: Genome) -> Genome:
    """Create an offspring genome from two parents, assuming parent1 is the fitter one.

    Connection genes are aligned by innovation number and processed in
    ascending order:

    * matching genes (present in both parents) are inherited from a randomly
      chosen parent;
    * disjoint/excess genes are inherited only when they belong to parent1;
      those that exist only in parent2 are dropped.

    Every inherited gene is a ``copy()`` of the parent's gene. The offspring's
    node list holds exactly the endpoints of the inherited connections, in
    first-seen order, so the result is reproducible under a fixed ``random``
    seed. NOTE(review): nodes with no inherited connection (e.g. isolated
    input/output nodes) are therefore not carried over.

    Args:
        parent1: The fitter parent.
        parent2: The less fit parent.

    Returns:
        A new Genome holding the inherited connection genes and their nodes.
    """
    # Index each parent's connection genes by innovation number.
    genes1 = {g.innov: g for g in parent1.connections}
    genes2 = {g.innov: g for g in parent2.connections}

    offspring_connections = []
    endpoints = []

    for innov in sorted(genes1.keys() | genes2.keys()):
        gene1 = genes1.get(innov)
        gene2 = genes2.get(innov)

        if gene1 is None:
            # Disjoint/excess gene of the less fit parent: not inherited.
            continue

        # Matching gene: pick either parent at random. Otherwise the gene is
        # disjoint/excess in the fitter parent and is inherited from it.
        selected = random.choice([gene1, gene2]) if gene2 is not None else gene1
        gene_copy = selected.copy()

        offspring_connections.append(gene_copy)
        endpoints.extend((gene_copy.in_node, gene_copy.out_node))

    # dict.fromkeys removes duplicates but keeps first-seen order. A set would
    # order the nodes by their identity-based hash, which varies between runs.
    offspring_nodes = list(dict.fromkeys(endpoints))
    return Genome(offspring_nodes, offspring_connections)
    

