While PyTorch seems easier to use, JAX can JIT-compile Python functions, so it might be faster than PyTorch. This notebook attempts to check that. NOTE: this notebook only covers the SGD part of the algorithm.

In theory, this should work just fine. Since JAX recompiles a jitted function whenever one of its static arguments changes, passing the network graph as a static argument should cause no issues. As before, keeping the params in a separate variable, so that the network graph does not interact with the computation graph, should be helpful.

The exact implementation, however, is the purpose of this notebook.

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import haiku as hk
import optax
import networkx as nx
import random
import pandas as pd
from datetime import datetime
from tqdm import tqdm

%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

# Functional Approach

Maybe using graphs and pytrees is the way to go? Use the methods from before to manage the graphs and for evaluation, convert the weights to pytree weights and keep the graph as a static argument. The weights can be optimized for the pytree.

The graph construction and mutation functions are as follows:

1. `create_graph(num_inputs: int, num_outputs: int) -> nx.DiGraph`
2. `vertex(graph: nx.DiGraph) -> nx.DiGraph`
3. `connection(graph: nx.DiGraph) -> nx.DiGraph`

A utility plot function is also below which plots the DAG.

In [None]:
def create_graph(num_inputs, num_outputs):
    """Build a fully connected input -> output network with random weights.

    Nodes are numbered ``1..num_inputs`` for inputs, followed by
    ``num_inputs + 1..num_inputs + num_outputs`` for outputs. Every node gets
    a ``"node"`` kind and a random ``"weight"`` (its bias); every edge gets
    ``enabled=True`` and a random ``"weight"``.

    Args:
        num_inputs: Number of input nodes.
        num_outputs: Number of output nodes.

    Returns:
        A new ``nx.DiGraph`` with one enabled edge per (input, output) pair.
    """
    graph = nx.DiGraph()
    inputs = range(1, num_inputs + 1)
    outputs = range(num_inputs + 1, num_inputs + num_outputs + 1)

    # Add every node exactly once, independently of the other layer's size, so
    # inputs still exist when there are no outputs and each bias is drawn once.
    for i in inputs:
        graph.add_node(i, node="input", weight=np.random.normal())
    for o in outputs:
        graph.add_node(o, node="output", weight=np.random.normal())

    for i in inputs:
        for o in outputs:
            graph.add_edge(i, o, enabled=True, weight=np.random.normal())
    return graph

def vertex(graph):
    """Return a copy of ``graph`` with a new hidden node splitting one edge.

    A random edge ``(i, o)`` is marked disabled and replaced by the path
    ``i -> new_node -> o``, where both new edges are enabled and carry random
    weights. The input graph is not modified.

    Args:
        graph: Network graph to mutate.

    Returns:
        The mutated copy, or an unchanged copy when ``graph`` has no edges to
        split.
    """
    graph = nx.DiGraph(graph)
    edges = list(graph.edges)
    if not edges:
        # Nothing to split; mirror `connection`, which also returns the copy
        # unchanged when no mutation is possible.
        return graph

    i, o = random.choice(edges)
    graph.edges[(i, o)]["enabled"] = False
    # New node ids continue after the largest existing id.
    new_node = max(graph.nodes) + 1
    graph.add_node(new_node, node="hidden", weight=np.random.normal())
    graph.add_edge(i, new_node, enabled=True, weight=np.random.normal())
    graph.add_edge(new_node, o, enabled=True, weight=np.random.normal())
    return graph

def connection(graph):
    """Return a copy of ``graph`` with one extra (or re-enabled) connection.

    A source node is drawn from the non-output nodes and a target from the
    nodes that are neither inputs, the source itself, nor ancestors of the
    source (which keeps the graph acyclic). Sources with no valid target are
    excluded and another source is drawn. If no source is left, the copy is
    returned unchanged.
    """
    mutated = nx.DiGraph(graph)
    nodes = set(mutated.nodes)
    inputs = {n for n in mutated.nodes if mutated.nodes[n]["node"] == "input"}
    outputs = {n for n in mutated.nodes if mutated.nodes[n]["node"] == "output"}

    # Sources already found to have no admissible target.
    exhausted = set()
    while True:
        sources = list(nodes - outputs - exhausted)
        if len(sources) == 0:
            return mutated
        source = random.choice(sources)
        targets = nodes - inputs - {source} - set(nx.ancestors(mutated, source))
        if len(targets) != 0:
            target = random.choice(list(targets))
            break
        exhausted.add(source)

    edge = (source, target)
    if edge in mutated.edges:
        # The connection already exists: make sure it is switched on.
        mutated.edges[edge]["enabled"] = True
    else:
        mutated.add_edge(source, target, enabled=True, weight=np.random.normal())
    return mutated

def random_graph(inputs, outputs, mutations):
    """Create a fresh graph and apply ``mutations`` rounds of mutation.

    Each round first splits an edge with ``vertex`` and then adds a
    connection with ``connection``.
    """
    graph = create_graph(inputs, outputs)
    for _ in range(mutations):
        with_new_node = vertex(graph)
        graph = connection(with_new_node)
    return graph

# Sample network used by the cells below: 4 inputs, 3 outputs and 5 rounds of
# vertex + connection mutation.
graph = random_graph(4, 3, 5)

In [None]:
def plot_graph(graph, ax):
    """Draw ``graph`` on ``ax`` with one row per topological generation.

    Non-output nodes are placed at ``(index within generation, generation)``.
    All output nodes are placed together on one extra row after the last
    generation. Nothing is drawn when ``graph`` is not a DAG.

    Args:
        graph: Network graph whose nodes carry a ``"node"`` kind attribute.
        ax: Matplotlib axes forwarded to ``nx.draw``.
    """
    if not nx.is_directed_acyclic_graph(graph):
        return

    pos = {}
    y = 0
    for generation in nx.topological_generations(graph):
        for x, node in enumerate(generation):
            # Outputs are positioned separately on the final row.
            if graph.nodes[node]["node"] != "output":
                pos[node] = (x, y)
        y += 1

    outputs = [n for n in graph.nodes if graph.nodes[n]["node"] == "output"]
    for x, node in enumerate(outputs):
        pos[node] = (x, y)
    nx.draw(graph, pos=pos, with_labels=True, ax=ax)

# Evaluation and SGD

This part of the notebook focuses on the evaluation and SGD algorithm used to train each graph. The graph first needs to be broken down into a pytree of the weights and the graph itself.

Params have to be converted to an immutable dict since a regular dictionary isn't hashable, just as a regular list isn't hashable.

In [None]:
def get_params(graph):
    """Extract the graph's weights into an immutable parameter dict.

    Node weights are stored under ``"<node>"`` and edge weights under
    ``"<src>_<dst>"``, each as a float32 JAX array.
    """
    params = {}
    for node, attrs in graph.nodes.items():
        params[f"{node}"] = jnp.array(attrs["weight"], dtype=jnp.float32)
    for (i, o), attrs in graph.edges.items():
        params[f"{i}_{o}"] = jnp.array(attrs["weight"], dtype=jnp.float32)
    return hk.data_structures.to_immutable_dict(params)

def set_params(graph, params):
    """Return a copy of ``graph`` whose weights are taken from ``params``.

    Node weights are read from ``params["<node>"]`` and edge weights from
    ``params["<src>_<dst>"]``. The input graph is left untouched.
    """
    updated = nx.DiGraph(graph)
    for node, attrs in updated.nodes.items():
        attrs["weight"] = params[f"{node}"]
    for (i, o), attrs in updated.edges.items():
        attrs["weight"] = params[f"{i}_{o}"]
    return updated

# Snapshot the sample graph's node and edge weights as a parameter dict.
params = get_params(graph)

In [None]:
def evaluate(graph, params, x):
    """Run a forward pass of the network described by ``graph``.

    Every evaluated non-input node computes
    ``relu(bias + sum(weight * parent_value))`` over its enabled incoming
    edges. The output-node activations are then stacked and passed through a
    softmax.

    Args:
        graph: Directed graph whose nodes carry a ``"node"`` kind attribute
            and whose edges carry an ``"enabled"`` flag.
        params: Mapping holding a bias under ``"<node>"`` and an edge weight
            under ``"<src>_<dst>"``.
        x: Array indexed as ``x[:, k]``; column ``k`` feeds the ``k``-th input
            node in ``graph.nodes`` order.

    Returns:
        The softmax over the output nodes along the last axis, with one column
        per output node in ``graph.nodes`` order.

    Raises:
        ValueError: If an output node was never reached during evaluation.
    """
    inputs = [node for node in graph.nodes if graph.nodes[node]["node"] == "input"]
    outputs = [node for node in graph.nodes if graph.nodes[node]["node"] == "output"]
    values = {node: x[:, idx] for idx, node in enumerate(inputs)}

    # Seed the work queue with everything directly fed by an input.
    eval_stack = []
    for i in inputs:
        for s in graph.successors(i):
            if graph.edges[(i, s)]["enabled"] and s not in eval_stack:
                eval_stack.append(s)

    while eval_stack:
        node = eval_stack.pop(0)

        # Don't evaluate the node if it has already been evaluated.
        if node in values:
            continue

        # Enabled parents: neighbours that are not successors of the node.
        parents = {
            parent
            for parent in set(nx.all_neighbors(graph, node)) - set(graph.successors(node))
            if graph.edges[(parent, node)]["enabled"]
        }
        pending = parents - set(values)

        if not pending:
            value = 0
            for parent in parents:
                value += params[f"{parent}_{node}"] * values[parent]
            value += params[f"{node}"]
            values[node] = jax.nn.relu(value)
            # Queue successors that are not already waiting.
            eval_stack.extend(set(graph.successors(node)) - set(eval_stack))
        else:
            # Some parents are missing: queue them, then retry this node.
            eval_stack.extend((pending | {node}) - set(eval_stack))

    missing = [node for node in outputs if node not in values]
    if missing:
        raise ValueError(f"output nodes {missing} were never evaluated")

    # Stack by output-node order, NOT by evaluation order: an output that had
    # to wait for a hidden parent is evaluated late, and must still keep its
    # own column so class indices stay aligned with the labels.
    output = jnp.stack([values[node] for node in outputs], axis=-1)
    return jax.nn.softmax(output, axis=-1)

# Compile `evaluate` with the graph (argument 0) treated as a static argument.
# NOTE(review): presumably nx.DiGraph hashes by identity, so every new graph
# object would trigger a fresh trace/compile — confirm.
evaluate_jit = jax.jit(evaluate, static_argnums=(0,))
# Smoke test on a random batch of 16 samples with 4 features each.
x = np.random.normal(size=(16, 4))
evaluate(graph, params, x).shape

Test stuff about the efficiency of using `JIT` in this context.

In [None]:
# Benchmark eager vs. jitted evaluation on a heavily mutated graph.
graph = random_graph(4, 3, 200)
params = get_params(graph)

n = 10  # number of timed repetitions
regular_evaluate_times = []
jit_evaluate_times = []

for _ in range(n):
    x = jnp.asarray(np.random.normal(size=(16, 4)))

    # Block on the result in BOTH cases so the timer measures the finished
    # computation rather than just how long it took to dispatch it.
    start = datetime.now()
    y = evaluate(graph, params, x).block_until_ready()
    end = datetime.now()
    regular_evaluate_times.append((end - start).total_seconds())

    # NOTE(review): the first jitted call presumably includes trace/compile
    # time for this graph — confirm before comparing the first data point.
    start = datetime.now()
    y = evaluate_jit(graph, params, x).block_until_ready()
    end = datetime.now()
    jit_evaluate_times.append((end - start).total_seconds())

plt.plot(regular_evaluate_times, label="evaluate")
plt.plot(jit_evaluate_times, label="evaluate_jit")
plt.legend()

Check that the values are consistent between the JIT-compiled function and the non-JIT version.

# Only SGD Optimization

Test the optimization capabilities of the evaluate function. It should be fairly straightforward.

In [None]:
# @title The Iris Flowers Dataset
class IrisFlowers:
    """In-memory sampler over a CSV of flower measurements.

    The CSV must contain a ``"variety"`` column, which is replaced by integer
    class indices. ``batched_sample`` additionally assumes the first four
    columns are the features and the fifth is the (mapped) variety.
    """

    def __init__(self, filename):
        """Load ``filename``, encode the varieties and shuffle the rows."""
        self.file = filename
        df = pd.read_csv(self.file)

        # Sort the class names so the label <-> index mapping is identical on
        # every run; iterating a bare set of strings is not order-stable.
        varieties = sorted(set(df["variety"]))
        self.variety_to_index = {word: index for index, word in enumerate(varieties)}
        self.index_to_variety = dict(enumerate(varieties))

        df["variety"] = df["variety"].map(self.variety_to_index)
        df = df.sample(frac=1)

        self.df = df
        self.df_np = self.df.to_numpy()
        self.length = len(df)

    def single_sample(self) -> np.ndarray:
        """Return one uniformly chosen row (features followed by the label)."""
        random_selection = random.randint(0, self.length - 1)
        return self.df_np[random_selection]

    def batched_sample(self, batch_size: int) -> "tuple[jnp.ndarray, jnp.ndarray]":
        """Return ``(x, y)`` for ``batch_size`` rows drawn without replacement.

        ``x`` holds the first four columns and ``y`` the fifth column cast to
        int32. At most ``self.length`` rows are returned.
        """
        # Shuffle a copy so the stored dataset order is left untouched.
        random_shuffled_df = np.copy(self.df_np)
        np.random.shuffle(random_shuffled_df)
        sample = random_shuffled_df[:batch_size]
        x, y = jnp.asarray(sample[:, :4]), jnp.asarray(sample[:, 4])
        y = y.astype(jnp.int32)
        return x, y

# Load the dataset from iris.csv, resolved relative to the current working directory.
dataset = IrisFlowers("./iris.csv")

In [None]:
def cross_entropy(y_hat, y, labels):
    """Mean per-sample sum of element-wise binary cross-entropies.

    Args:
        y_hat: Predicted probabilities with the classes on the last axis.
        y: Integer class indices.
        labels: Number of classes used for the one-hot encoding.

    Returns:
        Scalar loss: the cross-entropy summed over the last axis and averaged
        over the remaining entries.
    """
    # Keep probabilities strictly inside (0, 1): a saturated softmax would
    # otherwise produce 0 * log(0) = NaN and poison the gradients.
    eps = 1e-7
    probs = jnp.clip(y_hat, eps, 1.0 - eps)
    targets = jax.nn.one_hot(y, labels)
    entropy = -targets * jnp.log(probs) - (1.0 - targets) * jnp.log(1.0 - probs)
    # Reduce with jax.numpy so the function stays traceable under grad/jit
    # without relying on NumPy's duck-typed fallbacks.
    return jnp.mean(jnp.sum(entropy, axis=-1))

def loss_fn(params, graph, x, y):
    """Cross-entropy loss of the network's predictions for ``x`` against ``y``.

    ``params`` comes first so that it is the argument differentiated by
    default when this function is wrapped in a JAX gradient transform.
    """
    y_hat = evaluate(graph, params, x)
    # Take the class count from the network's output width instead of a
    # hard-coded 3, so graphs with a different number of outputs also work.
    return cross_entropy(y_hat, y, y_hat.shape[-1])


# Rebind `loss_fn` to a version returning (loss, grads w.r.t. params).
# NOTE: this rebinding is not idempotent — re-running this statement without
# re-running the `def loss_fn` above wraps the already-wrapped function.
loss_fn = jax.value_and_grad(loss_fn)

# Smoke test on one batch of 16 samples.
x, y = dataset.batched_sample(16)
loss, grads = loss_fn(params, graph, x, y)

print(loss)

In [None]:
def accuracy_fn(graph, params, x, y):
    """Fraction of samples in ``x`` whose arg-max prediction equals ``y``."""
    probabilities = evaluate(graph, params, x)
    predicted = jnp.argmax(probabilities, axis=-1)
    hits = y == predicted
    # Each hit contributes 1 / batch_size to the total.
    return jnp.sum(hits / y.shape[0])

def step(graph, params, opt_state, x, y):
    """Perform one optimisation step with the module-level ``optimizer``.

    Args:
        graph: Network graph passed through to ``loss_fn``.
        params: Current parameters.
        opt_state: Current optimizer state.
        x: Batch of inputs.
        y: Batch of integer labels.

    Returns:
        ``(loss, new_params, new_opt_state)``.
    """
    loss, grads = loss_fn(params, graph, x, y)
    # Pass the current params as well: transforms such as weight decay need
    # them, and transforms that do not simply ignore the argument.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return loss, params, opt_state

# Jitted variants; the graph (argument 0) is a static argument in both.
step_jit = jax.jit(step, static_argnums=(0,))
accuracy_fn_jit = jax.jit(accuracy_fn, static_argnums=(0,))

# Fresh graph, parameters and optimizer state for a single smoke-test step.
graph = random_graph(4, 3, 10)
params = get_params(graph)
# `step` reads this module-level `optimizer`, so it must exist before the call.
optimizer = optax.adam(1e-4)
opt_state = optimizer.init(params)
x, y = dataset.batched_sample(16)

# One un-jitted step, then write the updated weights back into the graph.
loss, params, opt_state = step(graph, params, opt_state, x, y)
graph = set_params(graph, params)

Test everything out in a loop.

In [None]:
# Train a randomly mutated graph with Adam and record loss/accuracy/runtime.
graph = random_graph(4, 3, 30)
params = get_params(graph)
optimizer = optax.adam(1e-4)
opt_state = optimizer.init(params)

losses = []
accuracies = []
runtimes = []

for _ in tqdm(range(5000)):
    x, y = dataset.batched_sample(64)
    s = datetime.now()
    loss, params, opt_state = step_jit(graph, params, opt_state, x, y)
    # Wait for the step to finish before stopping the clock; otherwise only
    # the time needed to dispatch the computation is measured.
    loss.block_until_ready()
    e = datetime.now()

    # Accuracy is measured on a separate random batch from the same dataset.
    x, y = dataset.batched_sample(128)
    accuracy = accuracy_fn_jit(graph, params, x, y)

    # Store plain Python floats rather than device arrays.
    losses.append(float(loss))
    accuracies.append(float(accuracy))
    runtimes.append((e - s).total_seconds())

# Plot every 10th recorded value to keep the curves readable.
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)
ax1.plot(losses[::10])
ax1.set_title("Loss Plot")
ax2.plot(accuracies[::10])
ax2.set_title("Accuracy Plot")
plt.tight_layout()

In [None]:
# Display the losses recorded during the training loop.
losses