In [None]:
import jax
import numpy as np
import jax.numpy as jnp
import equinox as eqx
import optax

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.style as mplstyle

import seaborn as sns

In [None]:
# Start from matplotlib's stock style so the theme below is applied on a clean slate.
plt.style.use('default')

# Seaborn theme options, collected in one place and forwarded as keyword arguments.
theme_options = dict(
    context='paper',
    style='white',
    palette='icefire',
    font='serif',
    font_scale=2,
    color_codes=True,
    rc={'text.usetex': True},
)
sns.set_theme(**theme_options)

# Apply matplotlib's 'fast' style last, on top of the seaborn theme.
mplstyle.use('fast')

In [None]:
def identity(x):
    """Return ``x`` unchanged; used as the no-op (linear) activation."""
    return x


In [None]:
class Neuron(eqx.Module):
    """A single neuron computing ``activation(dot(weight, x) + bias)``.

    Attributes:
        weight: 1-D array with one entry per input feature.
        bias: scalar array added to the weighted sum.
        activation: callable applied to the pre-activation value.
    """
    weight: jax.Array
    bias: jax.Array
    activation: callable

    def __init__(self, in_features, activation=identity, key=None):
        """Draw the weight vector and the bias from a standard normal.

        Args:
            in_features: length of the weight vector.
            activation: callable applied to the weighted sum (identity by default).
            key: JAX PRNG key; ``jax.random.PRNGKey(0)`` is used when omitted.
        """
        key = jax.random.PRNGKey(0) if key is None else key
        # Independent sub-keys so that weight and bias are not correlated.
        weight_key, bias_key = jax.random.split(key)
        self.weight = jax.random.normal(weight_key, shape=(in_features,))
        self.bias = jax.random.normal(bias_key, shape=())
        self.activation = activation

    def __call__(self, x):
        """Apply the neuron to ``x`` and return the activated value."""
        pre_activation = jnp.dot(self.weight, x) + self.bias
        return self.activation(pre_activation)

In [None]:
class CustomMLP(eqx.Module):
    """Multi-layer perceptron stored as explicit lists of ``Neuron`` objects.

    ``layers[i]`` is the list of neurons forming layer ``i``; the last entry is
    the output layer, whose neurons use the identity activation. Keeping every
    neuron as its own object is what allows ``add_neuron`` / ``remove_neuron``
    to grow or shrink a layer between training steps.
    """
    layers: list

    def __init__(self, input_size, hidden_sizes, output_size, activations, key=None):
        """Build the network.

        Args:
            input_size: number of input features.
            hidden_sizes: number of neurons in each hidden layer.
            output_size: number of neurons in the output layer.
            activations: one activation callable per hidden layer.
            key: JAX PRNG key; ``jax.random.PRNGKey(0)`` is used when omitted.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        layer_keys = jax.random.split(key, len(hidden_sizes) + 1)

        layers = []
        in_features = input_size

        for i, out_features in enumerate(hidden_sizes):
            # One key per neuron: neurons that share a key start with identical
            # parameters and, having the same activation, would stay identical
            # under gradient descent.
            neuron_keys = jax.random.split(layer_keys[i], out_features)
            layers.append([Neuron(in_features, activations[i], key=k) for k in neuron_keys])
            in_features = out_features

        output_keys = jax.random.split(layer_keys[-1], output_size)
        layers.append([Neuron(in_features, activation=identity, key=k) for k in output_keys])

        self.layers = layers

    def _forward_single(self, x):
        """Run one sample of shape ``(in_features,)`` through every layer."""
        for layer in self.layers:
            # Each neuron maps the full feature vector to a scalar; stacking
            # the scalars gives the feature vector for the next layer.
            x = jnp.stack([neuron(x) for neuron in layer])
        return x

    def __call__(self, x):
        """Evaluate the network.

        Args:
            x: a single sample of shape ``(in_features,)`` or a batch of shape
                ``(batch, in_features)``.

        Returns:
            An array of shape ``(n_out,)`` for a single sample or
            ``(batch, n_out)`` for a batch, where ``n_out`` is the number of
            neurons in the output layer.

        Raises:
            ValueError: if ``x`` is neither 1-D nor 2-D.
        """
        x = jnp.asarray(x)
        if x.ndim == 1:
            return self._forward_single(x)
        if x.ndim == 2:
            # vmap over the batch axis only. Mapping over the list of neurons
            # is not possible: vmap would try to map every leaf of every
            # neuron, including the rank-0 bias. The outputs are also not
            # averaged over axis 0, because for a batch that axis indexes
            # samples and averaging would collapse all predictions into one.
            return jax.vmap(self._forward_single)(x)
        raise ValueError(
            f"expected input of shape (in_features,) or (batch, in_features), got {x.shape}"
        )

    def add_neuron(self, layer_index, activation=identity, key=None):
        """Append a freshly initialised neuron to layer ``layer_index`` in place.

        Every neuron of the following layer (if any) receives one extra input
        weight so that shapes stay consistent.

        Args:
            layer_index: index of the layer to grow (negative values count from the end).
            activation: activation callable of the new neuron.
            key: JAX PRNG key; ``jax.random.PRNGKey(0)`` is used when omitted.

        Raises:
            IndexError: if ``layer_index`` is out of range or the layer is
                empty (its input width cannot be inferred).
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        n_layers = len(self.layers)
        layer_index = int(layer_index)
        if layer_index < 0:
            # Normalise so that the "is there a next layer" test below is
            # correct for negative indices as well.
            layer_index += n_layers
        if not 0 <= layer_index < n_layers:
            raise IndexError(f"layer_index out of range for {n_layers} layers")

        layer = self.layers[layer_index]
        if not layer:
            raise IndexError("cannot infer in_features from an empty layer")

        # Separate keys for the neuron itself and for the new downstream weights.
        neuron_key, fan_out_key = jax.random.split(key)
        in_features = layer[0].weight.shape[0]
        layer.append(Neuron(in_features, activation, neuron_key))

        # Adjust the next layer's weights to include the new neuron
        if layer_index + 1 < n_layers:
            next_layer = self.layers[layer_index + 1]
            # One independent draw per downstream neuron.
            new_weights = jax.random.normal(fan_out_key, (len(next_layer),))
            for i, next_neuron in enumerate(next_layer):
                updated_weights = jnp.append(next_neuron.weight, new_weights[i])
                next_layer[i] = eqx.tree_at(lambda n: n.weight, next_neuron, updated_weights)

    def remove_neuron(self, layer_index, neuron_index):
        """Delete neuron ``neuron_index`` from layer ``layer_index`` in place.

        The matching input weight is removed from every neuron of the following
        layer (if any). Nothing happens when the layer has one neuron or fewer:
        emptying a layer would leave the next layer without inputs.

        Args:
            layer_index: index of the layer to shrink (negative values count from the end).
            neuron_index: position of the neuron inside that layer.

        Raises:
            IndexError: if either index is out of range.
        """
        n_layers = len(self.layers)
        layer_index = int(layer_index)
        if layer_index < 0:
            layer_index += n_layers
        if not 0 <= layer_index < n_layers:
            raise IndexError(f"layer_index out of range for {n_layers} layers")

        layer = self.layers[layer_index]
        if len(layer) <= 1:
            return

        # Plain int so that list deletion and jnp.delete see the same static index.
        neuron_index = int(neuron_index)
        del layer[neuron_index]

        # Adjust the next layer's weights to drop the corresponding input
        if layer_index + 1 < n_layers:
            next_layer = self.layers[layer_index + 1]
            for i, next_neuron in enumerate(next_layer):
                updated_weights = jnp.delete(next_neuron.weight, neuron_index)
                next_layer[i] = eqx.tree_at(lambda n: n.weight, next_neuron, updated_weights)

    def get_shape(self):
        """Return the number of neurons in each layer, output layer included."""
        return [len(layer) for layer in self.layers]


In [None]:
# Example usage: a one-input / one-output network with two hidden layers.
input_size = 1
output_size = 1
hidden_sizes = [4, 5]  # 4 neurons in the first hidden layer, 5 in the second
activations = [jax.nn.relu, jax.nn.sigmoid]  # one activation per hidden layer

# Fixed seed for the network initialisation.
key = jax.random.PRNGKey(42)
mlp = CustomMLP(
    input_size=input_size,
    hidden_sizes=hidden_sizes,
    output_size=output_size,
    activations=activations,
    key=key,
)

# Adam optimizer; its state is built from the inexact-array leaves of the
# model only, so non-array fields such as the activation callables are left out.
opt = optax.adam(learning_rate=0.01)
trainable_leaves = eqx.filter(mlp, eqx.is_inexact_array)
opt_state = opt.init(trainable_leaves)

In [None]:
@eqx.filter_value_and_grad()
def compute_loss(mlp, x, y):
    """Mean squared error between ``mlp(x)`` and ``y``.

    Because of the decorator, calling this returns ``(loss, grads)`` with the
    gradient taken with respect to ``mlp``.
    """
    residual = mlp(x) - y
    return jnp.mean(residual ** 2)

In [None]:
@eqx.filter_jit()
def train_step(mlp, x, y, opt_state, opt_update):
    """Perform one optimisation step.

    Args:
        mlp: the model to update.
        x: model inputs.
        y: regression targets.
        opt_state: current optimizer state.
        opt_update: the optimizer's update function, called as
            ``opt_update(grads, opt_state)``.

    Returns:
        ``(loss, mlp, opt_state)``: the loss evaluated before the update, the
        updated model and the new optimizer state.
    """
    loss_value, gradients = compute_loss(mlp, x, y)
    param_updates, next_opt_state = opt_update(gradients, opt_state)
    updated_mlp = eqx.apply_updates(mlp, param_updates)
    return loss_value, updated_mlp, next_opt_state

In [None]:
# Toy regression data: 128 evenly spaced points on [-pi, pi] stored as a
# column (one feature per row), with sin(x) as the target.
x = jnp.linspace(-jnp.pi, jnp.pi, num=128)[:, None]
y = jnp.sin(x)


In [None]:
def initialize_optimizer_state(mlp, optimizer):
    """Create a fresh optimizer state for the current structure of ``mlp``.

    Only the inexact-array leaves of the model are handed to
    ``optimizer.init``.
    """
    trainable_leaves = eqx.filter(mlp, eqx.is_inexact_array)
    return optimizer.init(trainable_leaves)

In [None]:
# Scratch check: draw a random index into the list of candidate activations.
activation_list = [jax.nn.relu, jax.nn.sigmoid, jax.nn.tanh]
activation_indices = jnp.arange(len(activation_list))
jax.random.choice(key, activation_indices)

Array(2, dtype=int32)

In [None]:
# Candidate activations for neurons added during training.
activation_list = [jax.nn.relu, jax.nn.sigmoid, jax.nn.tanh]

# Training length and how often (in epochs) an add / remove event may occur.
num_epochs = 100
add_node_every = remove_node_every = 1

# Per-epoch records of the loss and of the total neuron count.
Loss_history, Node_history = [], []

In [None]:
# Train while randomly growing / shrinking the hidden layers.
for epoch in range(num_epochs):
    loss, mlp, opt_state = train_step(mlp, x, y, opt_state, opt.update)

    # One dedicated key per random decision, so no key is consumed twice:
    # using a key for a draw and then splitting it again, or drawing with the
    # carried `key`, correlates the samples.
    key, add_gate_key, add_key, remove_gate_key, sub_key = jax.random.split(key, 5)

    n_neurons = sum(mlp.get_shape())
    Loss_history.append(loss)
    Node_history.append(n_neurons)

    # Only hidden layers are resized; the last entry of the shape is the output layer.
    n_hidden_layers = len(mlp.get_shape()) - 1
    can_resize = n_hidden_layers > 0

    # Dynamically add or remove neurons
    if can_resize and (epoch + 1) % add_node_every == 0 and jax.random.uniform(add_gate_key) < 0.05:
        neuron_key, act_key, layer_key = jax.random.split(add_key, 3)
        activation_idx = int(jax.random.choice(act_key, jnp.arange(len(activation_list))))
        activation = activation_list[activation_idx]
        layer = int(jax.random.randint(layer_key, (1,), 0, n_hidden_layers)[0])
        # Grow the layer that was actually drawn (and that the message reports).
        mlp.add_neuron(layer_index=layer, activation=activation, key=neuron_key)
        # The parameter tree changed shape, so the optimizer state must be rebuilt.
        opt_state = initialize_optimizer_state(mlp, opt)
        print(f"Added neuron to hidden layer {layer+1} with activation {activation.__name__}")
        print(mlp.get_shape())

    elif can_resize and (epoch + 1) % remove_node_every == 0 and jax.random.uniform(remove_gate_key) < 0.05:
        layer_key, neuron_key = jax.random.split(sub_key)
        layer = int(jax.random.randint(layer_key, (1,), 0, n_hidden_layers)[0])
        layer_neurons = len(mlp.layers[layer])
        # Never remove the last neuron of a layer: an empty layer would leave
        # the following layer without inputs.
        if layer_neurons > 1:
            neuron_idx = int(jax.random.randint(neuron_key, (1,), 0, layer_neurons)[0])
            mlp.remove_neuron(layer_index=layer, neuron_index=neuron_idx)
            opt_state = initialize_optimizer_state(mlp, opt)
            print(f"Removed neuron from hidden layer {layer+1} at index {neuron_idx}")
            print(mlp.get_shape())

    print(f"Epoch {epoch}, Prediction: {mlp(x)}, Loss: {loss}")

print("Final Prediction:", mlp(x))

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())