This notebook tests how well, and how efficiently, the proposed algorithm performs on the MNIST dataset.

In [None]:
import torch, torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import networkx as nx
import numpy as np
import pandas as pd
import math, random, statistics, itertools
from datetime import datetime
from sklearn.datasets import fetch_openml
from tqdm import tqdm

%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

Create the dataset loader and other utilities for evaluating the model. If the model can learn on this data, push toward other tests.

In [None]:
# @title The MNIST Digits Dataset
# Preprocessing applied to every image: convert it to a tensor, then normalize
# the single channel with mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST specific normalization
])
# Training split of MNIST rooted at ./data (download=True is requested), with
# the preprocessing above attached as the dataset transform.
dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

def sample(sample_size):
    """Draw one random batch of ``sample_size`` examples from ``dataset``.

    Args:
        sample_size: Number of examples in the returned batch.

    Returns:
        The first batch yielded by a shuffled ``DataLoader`` over the
        module-level ``dataset``.
    """
    # shuffle=True is essential: without it the loader always starts at the
    # beginning of the dataset, so every call would return the very same
    # leading ``sample_size`` examples.
    loader = DataLoader(dataset, batch_size=sample_size, shuffle=True)
    return next(iter(loader))

In [None]:
# @title The Innovation Database Class
class InnovationDatabase:
    """Two-way registry between edges and their innovation numbers.

    ``i2n`` maps an integer innovation number to its edge and ``n2i`` maps an
    edge back to its innovation number. The two dictionaries are kept as exact
    inverses of each other.
    """

    def __init__(self):
        self.i2n = {}  # integer key -> registered partner (innovation -> edge)
        self.n2i = {}  # non-integer key -> registered partner (edge -> innovation)
        self.innovation_number = 1  # next number handed out by ``innovation``

    def __setitem__(self, key, value):
        """Register the pair ``key`` <-> ``value`` in both directions.

        Integer keys are stored as ``i2n[key] = value``; any other key is
        stored as ``n2i[key] = value``. The inverse entry is written as well.

        Raises:
            ValueError: If ``key`` or ``value`` is already registered. Existing
                entries are never overwritten, because that would leave the
                two dictionaries out of sync.
        """
        if key in self.i2n or key in self.n2i:
            raise ValueError(f"Not Updating Database. Key {key} already exists.")
        # The original only guarded the key; a reused value would silently
        # overwrite the reverse mapping and break the bijection.
        if value in self.i2n or value in self.n2i:
            raise ValueError(f"Not Updating Database. Value {value} already exists.")

        if isinstance(key, int):
            self.i2n[key] = value
            self.n2i[value] = key
        else:
            self.n2i[key] = value
            self.i2n[value] = key

    def __getitem__(self, key):
        """Look up the partner of ``key`` (``i2n`` for ints, ``n2i`` otherwise).

        Raises:
            KeyError: If ``key`` has not been registered.
        """
        if isinstance(key, int):
            return self.i2n[key]
        return self.n2i[key]

    def __str__(self):
        return str(self.i2n)

    def innovation(self, edge):
        """Return the innovation number of ``edge``, registering it if new."""
        if edge not in self.n2i:
            # Skip numbers that were registered manually through __setitem__.
            while self.innovation_number in self.i2n:
                self.innovation_number += 1
            self[edge] = self.innovation_number
            self.innovation_number += 1
        return self[edge]

    def reverse_innovation(self, innovation_number):
        """Return the edge registered under ``innovation_number``, or ``None``."""
        return self.i2n.get(innovation_number)

# Single registry instance passed to every Network built in this notebook, so
# that an edge seen in several networks maps to one innovation number.
database = InnovationDatabase()

Create the class that defines each individual in a species. The individual class is centered around a `torch.nn.Module` and an `nx.DiGraph`. For evaluation, the network walks the DiGraph; during backpropagation, the computation graph constructed by that walk is used to update the weights of the model. Modifying the DiGraph allows for mutations in the genome of the individual.

The below implementation has the following features:
    - Given a graph, the class adapts itself to that network.
    - Returning a copy of the graph as a numpy array rather than as a torch array (to make sure the parameters are immutable).
    - Evaluation
    - Mutations (new edge, new vertex).

In [None]:
# @title The network class
class Network(nn.Module):
    """Neural network whose topology (genome) is an ``nx.DiGraph``.

    Each node owns one scalar parameter (its bias) stored in ``self.params``
    under the key ``"<node>"``; each edge owns one scalar parameter (its
    weight) under the key ``"<source>_<target>"``. Nodes carry the attribute
    ``node`` (``"input"``, ``"hidden"`` or ``"output"``) and edges carry
    ``enabled`` and ``i_n`` (the innovation number issued by ``database``).

    Args:
        inputs: Number of input nodes; they are numbered ``1 .. inputs``.
        outputs: Number of output nodes; they are numbered
            ``inputs + 1 .. inputs + outputs``.
        database: Object exposing ``innovation(edge)``, used to number edges.
        graph: Optional graph in the format produced by ``get_graph``. When
            omitted, every input is connected to every output and all
            parameters are drawn from ``new_weight``.
    """

    def __init__(self, inputs, outputs, database, graph=None):
        super().__init__()

        self.inputs = inputs
        self.outputs = outputs
        self.database = database
        self.params = nn.ParameterDict()
        self.flatten = nn.Flatten()

        if graph is not None:
            self.set_graph(graph)
            return

        self.graph = nx.DiGraph()
        input_nodes = range(1, inputs + 1)
        output_nodes = range(inputs + 1, inputs + outputs + 1)

        # Register every node (and its bias) exactly once, instead of once per
        # (input, output) pair.
        for node in input_nodes:
            self._new_node(node, "input")
        for node in output_nodes:
            self._new_node(node, "output")
        for i, o in itertools.product(input_nodes, output_nodes):
            self._new_edge(i, o)

    def _new_node(self, node, kind):
        """Add ``node`` of the given kind together with a fresh bias."""
        self.graph.add_node(node, node=kind)
        self.params[f"{node}"] = nn.Parameter(self.new_weight())

    def _new_edge(self, i, o):
        """Add the enabled edge ``i -> o`` together with a fresh weight."""
        self.graph.add_edge(
            i, o, enabled=True, i_n=self.database.innovation((i, o))
        )
        self.params[f"{i}_{o}"] = nn.Parameter(self.new_weight())

    @staticmethod
    def _snapshot(parameter):
        """Return an independent numpy copy of ``parameter``'s value."""
        # .copy() matters: a plain numpy view shares memory with the tensor,
        # so edits to the live parameter would leak into the exported graph.
        return parameter.detach().cpu().numpy().copy()

    @staticmethod
    def _owned_tensor(weight):
        """Return a tensor holding ``weight`` that shares memory with nothing."""
        # torch.as_tensor aliases numpy arrays; clone() breaks that link so a
        # network built from another network's graph trains independently.
        return torch.as_tensor(weight).detach().clone()

    def get_graph(self):
        """Export the genome as a new ``nx.DiGraph`` with numpy weights.

        Returns:
            A graph with the same nodes and edges as ``self.graph``. Every
            node and edge additionally carries ``weight``, a numpy copy of the
            matching parameter, so changing the result never affects this
            network.
        """
        graph = nx.DiGraph()
        for node, data in self.graph.nodes(data=True):
            graph.add_node(
                node,
                weight=self._snapshot(self.params[f"{node}"]),
                node=data["node"],
            )

        for i, o, data in self.graph.edges(data=True):
            graph.add_edge(
                i,
                o,
                weight=self._snapshot(self.params[f"{i}_{o}"]),
                enabled=data["enabled"],
                i_n=data["i_n"],
            )

        return graph

    def set_graph(self, graph):
        """Replace the genome and all parameters with those of ``graph``.

        Args:
            graph: Graph in the format produced by ``get_graph``: every node
                has ``weight`` and ``node`` attributes, every edge has
                ``weight``, ``enabled`` and ``i_n`` attributes.
        """
        self.params.clear()
        self.graph = nx.DiGraph()

        for node, data in graph.nodes(data=True):
            self.params[f"{node}"] = nn.Parameter(
                self._owned_tensor(data["weight"])
            )
            self.graph.add_node(node, node=data["node"])

        for i, o, data in graph.edges(data=True):
            self.params[f"{i}_{o}"] = nn.Parameter(
                self._owned_tensor(data["weight"])
            )
            self.graph.add_edge(i, o, enabled=data["enabled"], i_n=data["i_n"])

    def forward(self, x):
        """Evaluate the genome on a batch.

        Every non-input node computes ``relu(bias + sum(weight * parent))``
        over its enabled incoming edges; the output nodes are then passed
        through a softmax.

        Args:
            x: Batch whose flattened rows supply one value per input node,
                assigned in ascending node-id order.

        Returns:
            Tensor of shape ``(batch, number of output nodes)`` holding class
            probabilities, with columns in ascending output-node-id order.

        Raises:
            networkx.NetworkXUnfeasible: If the genome contains a cycle.
        """
        x = self.flatten(x)
        batch = x.shape[0]
        inputs, outputs = self.get_inputs_and_outputs()

        # Sorted so that column k of x always feeds the same input node.
        values = {node: x[:, idx] for idx, node in enumerate(sorted(inputs))}

        # A topological order of the full graph is also valid for the
        # sub-graph of enabled edges, and it visits every node, including
        # outputs that currently have no enabled incoming edge.
        for node in nx.topological_sort(self.graph):
            if node in values:
                continue
            # Start from a (batch,) tensor so bias-only nodes still stack.
            total = x.new_zeros(batch) + self.params[f"{node}"]
            for parent in self.graph.predecessors(node):
                # Disabled edges must not contribute to the activation.
                if self.graph.edges[parent, node]["enabled"]:
                    total = total + self.params[f"{parent}_{node}"] * values[parent]
            values[node] = torch.relu(total)

        # Sorted so that column k always belongs to the same output node,
        # regardless of the order in which nodes were evaluated.
        output = torch.stack([values[node] for node in sorted(outputs)], dim=-1)
        return torch.softmax(output, dim=-1)

    def mutate_vertex(self):
        """Split a random enabled edge ``i -> o`` into ``i -> new -> o``.

        The chosen edge is disabled, a hidden node is added, and two enabled
        edges with fresh weights route the signal through it. Does nothing
        when the genome has no enabled edge.
        """
        candidates = sorted(self.enabled_edges())
        if not candidates:
            return

        i, o = random.choice(candidates)
        self.graph.edges[i, o]["enabled"] = False

        new_node = max(self.graph.nodes) + 1
        self._new_node(new_node, "hidden")
        self._new_edge(i, new_node)
        # The outgoing edge must reach ``o``; pointing it back at ``i`` would
        # create a two-node cycle and leave ``o`` disconnected from ``i``.
        self._new_edge(new_node, o)

    def mutate_connection(self):
        """Enable one random connection that does not create a cycle.

        Candidates run from any non-output node to any non-input node, skip
        edges that are already enabled, and skip targets that are ancestors
        of the source. A previously disabled edge is re-enabled with its old
        weight; a brand-new edge gets a fresh weight. Does nothing when no
        candidate exists.
        """
        inputs, outputs = self.get_inputs_and_outputs()
        nodes = set(self.graph.nodes)
        already_enabled = self.enabled_edges()

        possible = []
        for source in nodes - outputs:
            # Computed once per source rather than once per candidate pair.
            blocked = nx.ancestors(self.graph, source)
            for target in nodes - inputs:
                if (
                    target != source
                    and target not in blocked
                    and (source, target) not in already_enabled
                ):
                    possible.append((source, target))

        if not possible:
            return

        i, o = random.choice(sorted(possible))
        if self.graph.has_edge(i, o):
            self.graph.edges[i, o]["enabled"] = True
        else:
            self._new_edge(i, o)

    def plot(self, filename=None, with_weights=False):
        """Draw the genome on a circular layout.

        Inputs are green, hidden nodes blue and outputs red; enabled edges
        are green and disabled edges red.

        Args:
            filename: If given, save the figure there instead of showing it.
            with_weights: If true, label each edge with its weight rounded to
                two decimals.
        """
        palette = {"hidden": "blue", "input": "green", "output": "red"}
        node_colors = [palette[kind] for _, kind in self.graph.nodes(data="node")]
        edge_colors = [
            "green" if enabled else "red"
            for _, _, enabled in self.graph.edges(data="enabled")
        ]
        # Define the layout of the drawn genome.
        pos = nx.circular_layout(self.graph)

        nx.draw(
            self.graph,
            pos=pos,
            with_labels=True,
            node_color=node_colors,
            edge_color=edge_colors,
        )
        if with_weights:
            edge_labels = {
                (i, o): round(float(self.params[f"{i}_{o}"].detach()), 2)
                for i, o in self.graph.edges
            }
            nx.draw_networkx_edge_labels(
                self.graph, pos=pos, edge_labels=edge_labels, font_color="green"
            )

        if filename is None:
            plt.show()
        else:
            plt.savefig(filename)
        plt.clf()

    def clone(self):
        """Return a new ``Network`` with the same genome and copied weights."""
        return Network(self.inputs, self.outputs, self.database, self.get_graph())

    def get_inputs_and_outputs(self):
        """Return ``(input node ids, output node ids)`` as two sets."""
        inputs, outputs = set(), set()
        for node, kind in self.graph.nodes(data="node"):
            if kind == "input":
                inputs.add(node)
            elif kind == "output":
                outputs.add(node)
        return inputs, outputs

    def enabled_edges(self):
        """Return the set of ``(source, target)`` edges flagged as enabled."""
        return {
            (i, o)
            for i, o, enabled in self.graph.edges.data("enabled")
            if enabled
        }

    def new_weight(self):
        """Return a fresh scalar tensor drawn from a standard normal."""
        return torch.normal(0.0, 1.0, size=())

Perform tests to see if the graphs work.

In [None]:
# Graph Creation
# One input node per pixel of a 28 x 28 image and ten output nodes.
net = Network(28 * 28, 10, database)

# Graph Mutations
# Apply both mutation kinds repeatedly to grow the initial genome.
for _ in range(10):
    net.mutate_connection()
    net.mutate_vertex()

# Round-trip the genome through its exported form to exercise
# get_graph / set_graph.
net.set_graph(net.get_graph())

# Graph Evaluation
# The labels y are not used in this cell.
x, y = sample(16)

print(f"Input Shape: {x.shape}")

# The timed region covers one forward pass plus printing its output shape.
start_time = datetime.now()
print(f"Output Shape: {net(x).shape}")
end_time = datetime.now()

print(f"Time For Evaluation: {(end_time - start_time).total_seconds()}")

# Plotting
# net.plot()

## Generation

As pointed out before, the process in each generation goes as follows, with hyperparameters:

$$
\begin{aligned}
\mu &= \text{Population size}\\
\Delta &= \text{Maximum age of an individual}\\
\lambda &= \text{Number of offspring}
\end{aligned}
$$

During each generation, the $\mu$ individuals of the population first generate $\lambda$ offspring. The offspring are then mutated. Next, the offspring go through a learning phase in which they are trained using SGD for a maximum of $\Delta$ iterations. Finally, the fitness score is calculated for every individual in the population and for every offspring, and the best $\mu$ are kept for the next generation.

First things first, test the algorithm without crossover.

In [None]:
# Hyperparameters for the single-iteration trial below. They are module-level
# globals read by learn() and generate_offsprings().
mu = 5  # population size
delta = 50  # number of optimisation steps learn() runs per individual
lmbda = 5  # number of offspring produced by generate_offsprings()
batch_size = 64  # batch size learn() requests from sample()
learning_rate = 1e-4  # learning rate given to the optimizer in learn()
generations = 10  # number of generations; reassigned before the full run below

In [None]:
def learn(net):
    """Train ``net`` in place for ``delta`` optimisation steps.

    Reads the module-level hyperparameters ``learning_rate``, ``delta`` and
    ``batch_size`` and draws every batch with ``sample``.

    Args:
        net: Module whose forward pass returns class probabilities (the
            ``Network`` above applies a softmax to its outputs).

    Returns:
        list[float]: The loss recorded at each optimisation step.
    """
    optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)
    # ``net`` already outputs probabilities, so the matching loss is NLL on
    # their logarithm. nn.CrossEntropyLoss expects unnormalised logits and
    # would apply a second softmax, flattening the gradients.
    criterion = nn.NLLLoss()
    # Smallest positive float32; keeps log() finite if a probability
    # underflows to exactly zero.
    floor = torch.finfo(torch.float32).tiny

    losses = []
    for _ in range(delta):
        optimizer.zero_grad()
        x, y = sample(batch_size)
        y = y.to(torch.long)

        probabilities = net(x)
        loss = criterion(torch.log(probabilities.clamp_min(floor)), y)

        loss.backward()
        optimizer.step()

        losses.append(float(loss.item()))

    return losses

def fitness(net, sample_size):
    """Return the classification accuracy of ``net`` on one sampled batch.

    Args:
        net: Module mapping a batch of inputs to per-class scores.
        sample_size: Number of examples requested from ``sample``.

    Returns:
        float: Fraction of the batch whose highest-scoring class equals the
        label.
    """
    x, y = sample(sample_size)
    # Evaluation needs no gradients; skipping autograd avoids building a
    # graph over every edge of the network.
    with torch.no_grad():
        predictions = torch.argmax(net(x), dim=-1)
    # Divide by the number of labels actually returned, which can be smaller
    # than ``sample_size`` if the dataset is smaller than the request.
    return float((predictions == y).sum().item() / y.numel())

def fitness_population(population, sample_size):
    """Score every individual of ``population`` with ``fitness``.

    Args:
        population: Iterable of individuals accepted by ``fitness``.
        sample_size: Batch size forwarded to ``fitness`` for each individual.

    Returns:
        numpy.ndarray: One fitness value per individual, in population order.
    """
    scores = []
    for individual in population:
        scores.append(fitness(individual, sample_size))
    return np.array(scores)

def generate_offsprings(population, mutate_vertex_prob=0.3):
    """Create ``lmbda`` mutated clones of randomly chosen parents.

    Args:
        population: Sequence of individuals exposing ``clone``,
            ``mutate_vertex`` and ``mutate_connection``.
        mutate_vertex_prob: Threshold compared against ``random.random()``;
            below it the child receives a vertex mutation, otherwise a
            connection mutation.

    Returns:
        list: The ``lmbda`` mutated children (``lmbda`` is a module-level
        hyperparameter).
    """

    def spawn():
        child = random.choice(population).clone()
        # Exactly one kind of mutation per child for now; this could be
        # relaxed later.
        if random.random() < mutate_vertex_prob:
            child.mutate_vertex()
        else:
            child.mutate_connection()
        return child

    return [spawn() for _ in range(lmbda)]

In [None]:
# Build mu fresh networks, train each one, and score each on a batch of 64.
population = [Network(28 * 28, 10, database) for _ in range(mu)]
losses = [learn(i) for i in tqdm(population)]
fitness_scores = fitness_population(population, 64)
# Trailing expression: the notebook displays (best, mean, worst) fitness.
np.max(fitness_scores), np.mean(fitness_scores), np.min(fitness_scores)

In [None]:
# @title The process for one iteration.
# Only the offspring are trained here; the existing population is not.
offsprings = generate_offsprings(population)
offspring_losses = [learn(individual) for individual in tqdm(offsprings)]

# Parents occupy indices 0 .. mu - 1 of `combined`; offspring follow.
combined = population + offsprings
combined_fitness = fitness_population(combined, 64)
# argsort is ascending, so flipping it ranks individuals from best to worst.
sorted_fitness_indices = np.flip(np.argsort(combined_fitness))

# Keep the mu fittest individuals and their scores. `population` itself is
# not reassigned in this cell.
new_population = [combined[i] for i in sorted_fitness_indices[:mu]]
fitness_scores = np.flip(np.sort(combined_fitness))[:mu]

# Trailing expression: the notebook displays (best, mean, worst) fitness.
np.max(fitness_scores), np.mean(fitness_scores), np.min(fitness_scores)

## Complete Algorithm

The algorithm below implements all of the ideas above: each next generation is produced by mutation only (no crossover), and SGD is used for learning within each generation. Because the existing population was already trained in a previous generation, it does not need to be trained again; in each generation only the offspring are trained, since they are the individuals that were mutated. If the mutation works, then the offspring should end up with the optimal weight values assigned to it.

In [None]:
# Hyper parameters
mu = 1 # Population Size
delta = 20 # The Gestation Period (learning phase) of each individual.
lmbda = 1 # The number of offsprings generated
batch_size = 64 # The batch size used for SGD and fitness score calculation.
learning_rate = 1e-4 # The learning rate for SGD.
generations = 50 # The number of generations to perform the evolution for.
scaling_factor = 10 # The scaling factor for the sine curve.

# Per-generation probability of choosing a vertex mutation over a connection
# mutation. sin(.) + 1 lies in [0, 2], so it is halved to obtain a valid
# probability in [0, 1]; without the division every value >= 1 forces a
# vertex mutation, which covered most of the schedule.
mutate_vertex_probs = (np.sin(np.arange(generations) / scaling_factor) + 1) / 2

# Keep track of some stuff.
generational_losses = []
generational_fitness = []

population = [Network(28 * 28, 10, database) for _ in range(mu)]
losses = np.array([learn(i) for i in tqdm(population)]) # The initial population has to learn too.
# Score with the same batch size that is documented above for fitness.
fitness_scores = fitness_population(population, batch_size)

generational_losses.append(np.mean(losses))
generational_fitness.append((np.max(fitness_scores), np.mean(fitness_scores), np.min(fitness_scores)))

print(generational_fitness)

In [None]:
# Main evolution loop: mutate, train the offspring, keep the mu fittest.
for g in range(generations):

    offsprings = generate_offsprings(population, mutate_vertex_prob=mutate_vertex_probs[g])
    offspring_losses = [learn(individual) for individual in tqdm(offsprings)]

    # Find the mu best amongst the offspring and the population.
    # Parents occupy indices 0 .. mu - 1 of `combined`; offspring follow.
    combined = population + offsprings
    combined_fitness = fitness_population(combined, batch_size)
    sorted_fitness_indices = np.flip(np.argsort(combined_fitness))
    survivor_indices = sorted_fitness_indices[:mu]
    new_population = [combined[i] for i in survivor_indices]

    # Re-score the survivors for the log, separately from the selection scores.
    fitness_scores = fitness_population(new_population, batch_size)
    generational_losses.append(np.mean(offspring_losses))
    generational_fitness.append((np.max(fitness_scores), np.mean(fitness_scores), np.min(fitness_scores)))

    population = new_population

    # The first offspring sits at index mu, so offspring are the indices
    # ">= mu"; the previous "> mu" test never counted it.
    replacements = int(np.sum(survivor_indices >= mu))
    # True when the top-ranked survivor is an offspring, i.e. the previous
    # best individual was displaced.
    best_replaced = bool(sorted_fitness_indices[0] >= mu)

    print(f"Generation: {g + 1} / {generations}: {generational_fitness[-1]}\tReplacements: {replacements}, Max Replace: {best_replaced}")

In [None]:
# One row per recorded generation (the initial population first), with the
# best, mean and worst fitness of the survivors as columns.
fitness_scores = pd.DataFrame(generational_fitness, columns=["Max", "Mean", "Min"])
ax = fitness_scores.plot()
plt.xlabel("Generation #")
plt.ylabel("Fitness Score (out of 1)")
plt.title("Fitness Scores Over Generations")

# Export the genome of the top-ranked survivor of the final generation.
best_individual = population[0].get_graph()
# population[0].plot()

In [None]:
# Number of edges (enabled and disabled) in the top-ranked survivor's genome.
print(len(population[0].graph.edges))

In [None]:
# NOTE(review): `optax` is never imported and `graph` is never defined in the
# visible part of this notebook, so this cell raises NameError as written.
# Presumably leftover scratch from a JAX/optax experiment — confirm and remove.
opt_state = optax.adam(1e-4)(graph)