This notebook builds on the functional-approach notebook and tries to develop an algorithm that combines the SGD computation from that notebook with neuro-evolution. Like that notebook, it uses the Iris flowers dataset.

In [None]:
import jax, jax.numpy as jnp, jax.random as jr
import haiku as hk
import optax
from copy import deepcopy
import random
import networkx as nx
from tqdm import tqdm
import pandas as pd
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

print(f"Devices: {jax.devices()}")

In [None]:
from sklearn.preprocessing import MinMaxScaler

# @title The Iris Flowers Dataset
class IrisFlowers:
    """In-memory Iris dataset with min-max scaled features and integer labels.

    The CSV must contain a ``variety`` column holding the class name.
    NOTE(review): the sampling code assumes the first four columns are the
    features and the fifth column is ``variety`` -- confirm against the file.

    Attributes:
        file: Path the data was read from.
        variety_to_index: Maps class name -> integer label.
        index_to_variety: Maps integer label -> class name.
        df: Shuffled DataFrame with ``variety`` replaced by integer labels.
        df_np: NumPy copy of ``df`` with the first four columns scaled to [0, 1].
        length: Number of rows.
    """

    def __init__(self, filename):
        self.file = filename
        df = pd.read_csv(self.file)

        # Sort the class names so the label <-> index mapping is identical on
        # every run; iterating a plain set of strings is not reproducible.
        varieties = sorted(set(df["variety"]))
        self.variety_to_index = {name: index for index, name in enumerate(varieties)}
        self.index_to_variety = {index: name for index, name in enumerate(varieties)}

        df["variety"] = df["variety"].map(self.variety_to_index)

        # Shuffle the rows once up front.
        df = df.sample(frac=1)

        self.df = df
        # Request an explicit copy: the array is modified in place just below.
        self.df_np = self.df.to_numpy(copy=True)
        scaler = MinMaxScaler(feature_range=(0, 1))
        self.df_np[:, :4] = scaler.fit_transform(self.df_np[:, :4])
        self.length = len(df)

    def single_sample(self) -> np.ndarray:
        """Return one uniformly chosen row (features followed by the label)."""
        random_selection = random.randint(0, self.length - 1)
        return self.df_np[random_selection]

    def batched_sample(self, batch_size: int) -> tuple:
        """Draw a random batch without replacement.

        Args:
            batch_size: Number of rows requested. If it exceeds the dataset
                size, every row is returned (in random order).

        Returns:
            A tuple ``(x, y)`` where ``x`` holds the first four columns of the
            selected rows and ``y`` the fifth column cast to ``int32``.
        """
        # Permute row indices instead of copying and shuffling the whole table.
        picked = np.random.permutation(self.length)[:batch_size]
        sample = self.df_np[picked]
        x = jnp.asarray(sample[:, :4])
        y = jnp.asarray(sample[:, 4]).astype(jnp.int32)
        return x, y

# Load the data once; the training and evaluation cells below read this
# module-level `dataset`.
dataset = IrisFlowers("./iris.csv")

# Neural Network Creation and Mutation

In [None]:
def create_graph(num_inputs, num_outputs):
    """Build the minimal network: every input connected to every output.

    Nodes are numbered ``1..num_inputs`` (inputs) followed by
    ``num_inputs+1..num_inputs+num_outputs`` (outputs). Every node carries a
    ``node`` kind and a ``weight`` attribute; every edge carries ``enabled``
    and ``weight``. All weights are drawn from a standard normal.

    Args:
        num_inputs: Number of input nodes.
        num_outputs: Number of output nodes.

    Returns:
        A new ``nx.DiGraph``.
    """
    graph = nx.DiGraph()
    input_ids = range(1, num_inputs + 1)
    output_ids = range(num_inputs + 1, num_inputs + num_outputs + 1)

    # Create each node exactly once, independently of the other layer's size,
    # so a node is never re-added with a freshly drawn weight.
    for i in input_ids:
        graph.add_node(i, node="input", weight=np.random.normal())
    for o in output_ids:
        graph.add_node(o, node="output", weight=np.random.normal())

    for i in input_ids:
        for o in output_ids:
            graph.add_edge(i, o, enabled=True, weight=np.random.normal())
    return graph

def vertex(graph):
    """Return a copy of ``graph`` with a new hidden node inserted on a random edge.

    For a randomly chosen edge ``i -> o`` a hidden node ``h`` is created and
    the edges ``i -> h`` and ``h -> o`` are added, all with standard-normal
    weights. The chosen edge ``i -> o`` itself is kept and left enabled (this
    implementation does not disable the split edge).

    Args:
        graph: Graph to mutate; it is not modified.

    Returns:
        The mutated copy. If ``graph`` has no edges there is nothing to split
        and the unmodified copy is returned, mirroring how ``connection``
        behaves when it cannot mutate.
    """
    graph = nx.DiGraph(graph)
    edges = list(graph.edges)
    if not edges:
        return graph
    i, o = random.choice(edges)
    # Node ids are integers, so the next free id is max + 1.
    new_node = max(graph.nodes) + 1
    graph.add_node(new_node, node="hidden", weight=np.random.normal())
    graph.add_edge(i, new_node, enabled=True, weight=np.random.normal())
    graph.add_edge(new_node, o, enabled=True, weight=np.random.normal())
    return graph

def connection(graph):
    """Return a copy of ``graph`` with one extra (or re-enabled) connection.

    A source node is picked among the non-output nodes and a target among the
    nodes that are neither inputs, the source itself, nor ancestors of the
    source (which keeps the graph acyclic). If the edge already exists it is
    marked enabled; otherwise it is added with a standard-normal weight. When
    no source has a valid target, the copy is returned unchanged.
    """
    mutated = nx.DiGraph(graph)
    all_nodes = set(mutated.nodes)
    input_nodes = {n for n in mutated.nodes if mutated.nodes[n]["node"] == "input"}
    output_nodes = {n for n in mutated.nodes if mutated.nodes[n]["node"] == "output"}

    # Sources that turned out to have no admissible target.
    exhausted = []
    while True:
        candidates = list(all_nodes - output_nodes - set(exhausted))
        if not candidates:
            return mutated
        source = random.choice(candidates)
        targets = all_nodes - input_nodes - {source} - set(nx.ancestors(mutated, source))
        if targets:
            target = random.choice(list(targets))
            break
        exhausted.append(source)

    if (source, target) not in mutated.edges:
        mutated.add_edge(source, target, enabled=True, weight=np.random.normal())
    else:
        mutated.edges[(source, target)]["enabled"] = True
    return mutated

def random_graph(inputs, outputs, mutations):
    """Create a minimal graph and apply ``mutations`` random structural mutations.

    Args:
        inputs: Number of input nodes.
        outputs: Number of output nodes.
        mutations: Number of mutation steps to apply.

    Returns:
        The mutated graph.
    """
    graph = create_graph(inputs, outputs)
    for i in range(mutations):
        # sin(i) + 1 lies in [0, 2]: whenever it is >= 1 a vertex mutation is
        # certain, otherwise it is the probability of a vertex mutation.
        mutate_prob = np.sin(i) + 1
        if random.random() < mutate_prob:
            graph = vertex(graph)
        else:
            graph = connection(graph)
    return graph

def plot_graph(graph, ax, with_weights=False):
    """Draw ``graph`` layer by layer on the given matplotlib axes.

    Non-output nodes are placed on one row per topological generation; all
    output nodes share a final row below them. Enabled edges are green and
    disabled edges red. Nothing is drawn if the graph is not a DAG.

    Args:
        graph: Network to draw.
        ax: Matplotlib axes to draw on.
        with_weights: If true, label every edge with its weight rounded to
            two decimals.
    """
    if not nx.is_directed_acyclic_graph(graph):
        return

    pos = {}
    y = 0
    for generation in nx.topological_generations(graph):
        for x, node in enumerate(generation):
            if graph.nodes[node]["node"] == "output":
                continue
            pos[node] = (x, y)
        y += 1

    # Outputs are positioned last so they always share one row.
    output_nodes = [n for n in graph.nodes if graph.nodes[n]["node"] == "output"]
    for x, node in enumerate(output_nodes):
        pos[node] = (x, y)

    edge_colors = ["green" if graph.edges[edge]["enabled"] else "red" for edge in graph.edges]
    nx.draw(graph, pos=pos, with_labels=True, ax=ax, edge_color=edge_colors)
    if with_weights:
        # float() so weights stored as array scalars print as plain numbers.
        edge_labels = {edge: round(float(graph.edges[edge]["weight"]), 2) for edge in graph.edges}
        # Pass ax explicitly; otherwise the labels go to the current axes,
        # which need not be the one the graph was drawn on.
        nx.draw_networkx_edge_labels(
            graph, pos=pos, edge_labels=edge_labels, font_color="green", ax=ax
        )

# Example graph: 4 inputs, 3 outputs, 5 random mutations.
graph = random_graph(4, 3, 5)

# SGD Algorithm

In [None]:
def get_params(graph):
    """Collect every node and edge weight of ``graph`` into a parameter mapping.

    Node weights are keyed by ``"<node>"`` and edge weights by
    ``"<source>_<target>"``; every value is a float32 JAX array.
    """
    collected = {}
    for node in graph.nodes:
        collected[f"{node}"] = jnp.array(graph.nodes[node]["weight"], dtype=jnp.float32)
    for i, o in graph.edges:
        collected[f"{i}_{o}"] = jnp.array(graph.edges[(i, o)]["weight"], dtype=jnp.float32)
    return hk.data_structures.to_immutable_dict(collected)

def set_params(graph, params):
    """Return a copy of ``graph`` whose weights are taken from ``params``.

    ``params`` uses the key scheme of ``get_params``: ``"<node>"`` for node
    weights and ``"<source>_<target>"`` for edge weights. The input graph is
    left untouched.
    """
    updated = nx.DiGraph(graph)
    for node, attrs in updated.nodes(data=True):
        attrs["weight"] = params[f"{node}"]
    for i, o, attrs in updated.edges(data=True):
        attrs["weight"] = params[f"{i}_{o}"]
    return updated

In [None]:
def evaluate(graph, params, x):
    """Run a forward pass of the network described by ``graph``.

    Each non-input node computes ``relu(sum(w_edge * parent) + w_node)`` over
    its *enabled* incoming edges. The output-node activations are stacked
    along the last axis and passed through a softmax. (ReLU is applied to the
    output nodes as well, before the softmax.)

    Args:
        graph: Acyclic network graph with ``node``/``enabled`` attributes.
        params: Mapping with keys ``"<node>"`` and ``"<source>_<target>"``.
        x: Batch of inputs; column ``k`` feeds the ``k``-th input node in
            ``graph.nodes`` order.

    Returns:
        Array of shape ``(batch, num_outputs)``; column ``k`` belongs to the
        ``k``-th output node in ``graph.nodes`` order.

    Raises:
        networkx.NetworkXUnfeasible: If ``graph`` contains a cycle.
    """
    inputs = [node for node in graph.nodes if graph.nodes[node]["node"] == "input"]
    outputs = [node for node in graph.nodes if graph.nodes[node]["node"] == "output"]
    values = {node: x[:, idx] for idx, node in enumerate(inputs)}

    # Visiting nodes in topological order guarantees every parent has been
    # evaluated before its children; no work-list bookkeeping is required.
    for node in nx.topological_sort(graph):
        if node in values:
            continue
        # Start from a batch-shaped zero so that a node without enabled
        # parents still yields a per-sample value that can be stacked.
        total = jnp.zeros(x.shape[0], dtype=x.dtype)
        for parent in graph.predecessors(node):
            if graph.edges[(parent, node)]["enabled"]:
                total = total + params[f"{parent}_{node}"] * values[parent]
        total = total + params[f"{node}"]
        values[node] = jax.nn.relu(total)

    # Index by output node explicitly: the order in which nodes happen to be
    # evaluated must not decide which column belongs to which class.
    output = jnp.stack([values[node] for node in outputs], axis=-1)
    return jax.nn.softmax(output, axis=-1)

In [None]:
def cross_entropy(y_pred, y, labels):
    """Mean categorical cross-entropy between predictions and targets.

    Args:
        y_pred: Predicted probabilities; the last axis is the class axis.
        y: Targets with the same shape as ``y_pred`` (e.g. one-hot).
        labels: Unused; kept so existing callers keep working.

    Returns:
        Scalar loss averaged over all leading axes.
    """
    tiny = 1e-9
    # Keep probabilities away from 0 so the logarithm stays finite.
    safe_pred = jnp.clip(y_pred, tiny, 1 - tiny)
    log_likelihood = jnp.sum(y * jnp.log(safe_pred), axis=-1)
    return jnp.mean(-log_likelihood)

@jax.value_and_grad
def loss_fn(params, graph, x, y):
    """Cross-entropy of the network's predictions on ``(x, y)``.

    Because of the decorator, calling ``loss_fn`` returns
    ``(loss, gradients)`` with gradients taken with respect to ``params``.
    """
    return cross_entropy(evaluate(graph, params, x), y, 3)

def accuracy_fn(graph, params, x, y):
    """Fraction of samples whose arg-max prediction matches the label ``y``."""
    predicted = jnp.argmax(evaluate(graph, params, x), axis=-1)
    # A prediction counts as correct when it is within 0.001 of the label.
    is_correct = jnp.abs(y - predicted) < 0.001
    return jnp.mean(is_correct)

# The graph (argument 0) is treated as a static argument by jit.
accuracy_fn = jax.jit(accuracy_fn, static_argnums=(0,))

def step(graph, params, opt_state, x, y):
    """Perform one optimisation step and return ``(loss, params, opt_state)``.

    NOTE: the update rule comes from the module-level ``optimizer``, which
    must therefore be defined before this function is first called.
    """
    loss_value, gradients = loss_fn(params, graph, x, y)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    return loss_value, next_params, next_opt_state

# The graph (argument 0) is treated as a static argument by jit.
step = jax.jit(step, static_argnums=(0,))

Sanity check for the algorithm.

In [None]:
# Train one randomly mutated network with Adam to check that the forward pass
# and the gradient step work end to end.
graph = random_graph(4, 3, 50)
params = get_params(graph)
# `step` reads this module-level `optimizer` when computing updates.
optimizer = optax.adam(1e-4)
opt_state = optimizer.init(params)

losses = []
accuracies = []
# NOTE: `runtimes` is never filled in this cell.
runtimes = []

for _ in tqdm(range(5000)):
    x, y = dataset.batched_sample(64)
    loss, params, opt_state = step(graph, params, opt_state, x, jax.nn.one_hot(y, 3))
    # Accuracy is measured on the same batch, using the already-updated params.
    accuracy = accuracy_fn(graph, params, x, y)
    losses.append(loss)
    accuracies.append(accuracy)


fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)

# Plot every 10th recorded value: loss on top, accuracy below.
ax1.plot(losses[::10])
ax2.plot(accuracies[::10])

# Evolutionary Aspect

While there may be other algorithms for obtaining an optimal structure for neural networks, evolution seems to make sense because its primary advantage is the ability to explore unstructured spaces. In other words, it is difficult to predict how adding a new node might affect the neural network as a whole.

In [None]:
# Hyper-parameters shared by the evolutionary cells below.
config = {
    # Evolution Parameters
    "population_size": 10,  # individuals kept after selection
    "num_offsprings": 10,  # mutated copies created per generation
    "vertex_mutation_prob": 0.5,  # chance of a vertex mutation; otherwise a connection mutation
    # SGD Parameters
    "training_length": 1000,  # optimisation steps per call to learn()
    # NOTE: "learning_rate" is not read by the code visible in this notebook;
    # the rate actually used is the one passed to optax.adam below.
    "learning_rate": 1e-4,
    "batch_size": 64,
    "optimizer": optax.adam(1e-4)
}

In [None]:
def learn(graph, config):
    """Train the weights of ``graph`` and return ``(losses, trained_graph)``.

    Runs ``config["training_length"]`` steps on random batches of size
    ``config["batch_size"]`` drawn from the module-level ``dataset``.

    NOTE: the optimiser state is created from ``config["optimizer"]``, while
    ``step`` applies updates with the module-level ``optimizer``; the two
    must describe the same optimiser.

    Returns:
        A NumPy array with the loss of every step, and a copy of ``graph``
        carrying the trained weights.
    """
    params = get_params(graph)
    opt_state = config["optimizer"].init(params)
    history = []

    for _ in range(config["training_length"]):
        x, labels = dataset.batched_sample(config["batch_size"])
        loss, params, opt_state = step(graph, params, opt_state, x, jax.nn.one_hot(labels, 3))
        history.append(loss)

    trained = set_params(graph, params)
    return np.array(history), trained

# Quick check: train the minimal 4-input / 3-output network once.
graph = create_graph(4, 3)
losses, graph = learn(graph, config)

Create the initial population. Since every individual in the initial population has the same nodes, it makes sense to train only one model and duplicate it for the entire population. Ideally, the nodes would converge to the same weights either way.

In [None]:
# Train a single minimal network and clone it to form the initial population.
graph = create_graph(4, 3)
losses, graph = learn(graph, config)
population = [deepcopy(graph) for _ in range(config['population_size'])]

For each generation, calculate a new population that is the best of the mutated individuals and the previous generation. Then train the combined generation.

In [None]:
# Create offspring: copy a random parent and apply exactly one structural
# mutation (vertex or connection) to the copy.
offsprings = []
for _ in range(config['num_offsprings']):
    random_individual = random.choice(population)
    offspring = deepcopy(random_individual)
    offspring = vertex(offspring) if random.random() < config['vertex_mutation_prob'] else connection(offspring)
    offsprings.append(offspring)

# Parents and offspring compete together in the selection step.
combined_population = population + offsprings

In [None]:
# Train every individual (parents included) and keep its loss history.
combined_losses = []
for i in tqdm(range(len(combined_population))):
    loss, offspring = learn(combined_population[i], config)
    # Replace the individual with its trained copy.
    combined_population[i] = offspring
    combined_losses.append(loss)

Now remove the worst individuals using the fitness score, in this case accuracy.

In [None]:
# Score every individual on the *same* evaluation batch; drawing a fresh batch
# per individual would let sampling noise influence the ranking.
x, y = dataset.batched_sample(128)

fitness_scores = []
for individual in combined_population:
    params = get_params(individual)
    accuracy = accuracy_fn(individual, params, x, y)
    fitness_scores.append(accuracy)

# Indices of the individuals sorted from best to worst accuracy.
indices = np.flip(np.argsort(fitness_scores))

In [None]:
# Keep the `population_size` best individuals together with their scores.
new_population = [combined_population[i] for i in indices[:config['population_size']]]
new_fitness_scores = [fitness_scores[i] for i in indices[:config['population_size']]]

# Complete Algorithm

The below code runs over the given number of generations.

In [None]:
# Hyper-parameters for the complete run (this replaces the earlier `config`).
config = {
    # Evolution Parameters
    "population_size": 10,  # individuals kept after selection
    "num_offsprings": 10,  # mutated copies created per generation
    "vertex_mutation_prob": 0.5,  # chance of a vertex mutation; otherwise a connection mutation
    "num_generations": 30,  # iterations of the evolutionary loop
    # SGD Parameters
    "training_length": 200,  # optimisation steps per call to learn()
    # NOTE: "learning_rate" is not read by the code visible in this notebook;
    # the rate actually used is the one passed to optax.adam below.
    "learning_rate": 1e-4,
    "batch_size": 64,
    "optimizer": optax.adam(1e-4)
}

def learn(graph, config):
    """Fit the weights of ``graph`` with SGD; return ``(losses, trained_graph)``.

    ``config`` supplies ``"training_length"`` (number of steps),
    ``"batch_size"`` and ``"optimizer"`` (used to create the optimiser
    state). Batches are drawn from the module-level ``dataset``.

    NOTE: ``step`` applies updates with the module-level ``optimizer``, so it
    must describe the same optimiser as ``config["optimizer"]``.
    """
    weights = get_params(graph)
    state = config["optimizer"].init(weights)
    num_steps = config["training_length"]
    batch_size = config["batch_size"]

    loss_history = []
    for _ in range(num_steps):
        x, y = dataset.batched_sample(batch_size)
        targets = jax.nn.one_hot(y, 3)
        loss, weights, state = step(graph, weights, state, x, targets)
        loss_history.append(loss)

    return np.array(loss_history), set_params(graph, weights)

# Initial Population
# Every individual starts as a copy of one trained minimal network.
graph = create_graph(4, 3)
losses, graph = learn(graph, config)
population = [deepcopy(graph) for _ in range(config['population_size'])]

for generation in range(config["num_generations"]):
    print(f"Generation: {generation + 1} / {config['num_generations']}")

    # Create offspring: copy a random parent and apply one structural mutation.
    offsprings = []
    for _ in range(config['num_offsprings']):
        random_individual = random.choice(population)
        offspring = deepcopy(random_individual)
        offspring = vertex(offspring) if random.random() < config['vertex_mutation_prob'] else connection(offspring)
        offsprings.append(offspring)

    combined_population = population + offsprings

    # Train the new population (parents are trained further as well).
    for i, individual in enumerate(tqdm(combined_population)):
        _, trained = learn(individual, config)
        combined_population[i] = trained

    # Score every individual on the same evaluation batch so that sampling
    # noise cannot change the ranking between individuals.
    x, y = dataset.batched_sample(128)
    fitness_scores = []
    for individual in combined_population:
        params = get_params(individual)
        accuracy = accuracy_fn(individual, params, x, y)
        fitness_scores.append(accuracy)

    # Best-first ordering; keep the top `population_size` individuals.
    indices = np.flip(np.argsort(fitness_scores))

    population = [combined_population[i] for i in indices[:config['population_size']]]
    fitness_scores = [fitness_scores[i] for i in indices[:config['population_size']]]

    print(f"Min: {np.min(fitness_scores)}\tMean: {np.mean(fitness_scores)}\tMax: {np.max(fitness_scores)}")