In [None]:
import jax
import numpy as np
import jax.numpy as jnp
import equinox as eqx
import optax

In [None]:
class Neuron(eqx.Module):
    """A single artificial neuron: ``activation(weight . x + bias)``.

    Attributes:
        weight: 1-D array with one entry per input feature.
        bias: Scalar (0-d) array added to the weighted sum.
        activation: Callable applied to the pre-activation value.
    """

    weight: jax.Array
    bias: jax.Array
    activation: callable

    def __init__(self, in_features, activation=jax.nn.relu, key=None):
        """Draw the weight vector and the bias from a standard normal.

        Args:
            in_features: Number of inputs, i.e. the length of ``weight``.
            activation: Function applied to the weighted sum (ReLU by default).
            key: JAX PRNG key; when omitted a fixed key derived from seed 0
                is used, so the default initialisation is deterministic.
        """
        if key is None:
            # Same default as before: the first half of a split of PRNGKey(0).
            key = jax.random.split(jax.random.PRNGKey(0))[0]

        weight_key, bias_key = jax.random.split(key)

        self.activation = activation
        self.weight = jax.random.normal(weight_key, (in_features,))
        self.bias = jax.random.normal(bias_key, ())

    def __call__(self, x):
        """Return the activated weighted sum of ``x``."""
        pre_activation = jnp.dot(self.weight, x) + self.bias
        return self.activation(pre_activation)


In [None]:
def identity(x):
    """Return ``x`` unchanged (a no-op activation function)."""
    return x


In [None]:
class CustomMLP(eqx.Module):
    """Multi-layer perceptron stored neuron-by-neuron so layers can be resized.

    ``layers`` is a list of layers; each layer is a plain Python list of
    ``Neuron`` modules. Hidden layers use ReLU, the output layer uses
    ``identity``. ``add_neuron`` / ``remove_neuron`` edit these lists in place
    and keep the following layer's weight vectors the right length.
    """

    layers: list

    def __init__(self, input_size, hidden_sizes, output_size, key=None):
        """Build the network.

        Args:
            input_size: Number of input features.
            hidden_sizes: Iterable with the width of each hidden layer.
            output_size: Number of neurons in the output layer.
            key: JAX PRNG key; defaults to ``jax.random.PRNGKey(0)``.
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        layer_sizes = [*hidden_sizes, output_size]
        layer_keys = jax.random.split(key, len(layer_sizes))
        output_position = len(layer_sizes) - 1

        layers = []
        in_features = input_size
        for position, out_features in enumerate(layer_sizes):
            activation = identity if position == output_position else jax.nn.relu
            # One key per neuron: sharing a key would make every neuron in the
            # layer start with identical parameters (and stay identical, since
            # they would also receive identical gradients).
            neuron_keys = jax.random.split(layer_keys[position], out_features)
            layers.append(
                [Neuron(in_features, activation, key=neuron_key) for neuron_key in neuron_keys]
            )
            in_features = out_features

        self.layers = layers

    def __call__(self, x):
        """Run a forward pass and return the first output neuron's value.

        Note: only ``x[0]`` of the final layer is returned, so additional
        output neurons (``output_size > 1``) are computed but not returned.
        """
        for layer in self.layers:
            x = jnp.array([neuron(x) for neuron in layer])
        return x[0]  # Since output layer is a single neuron

    def _resolve_layer_index(self, layer_index):
        """Map a possibly negative layer index to its non-negative position.

        Raises:
            IndexError: If ``layer_index`` is out of range.
        """
        return range(len(self.layers))[layer_index]

    def add_neuron(self, layer_index, activation=jax.nn.relu, key=None):
        """Append a freshly initialised neuron to a layer, in place.

        Every neuron of the following layer (if any) gets one extra random
        input weight for the new neuron.

        Args:
            layer_index: Layer to grow (negative indices count from the end).
            activation: Activation of the new neuron.
            key: JAX PRNG key; defaults to ``jax.random.PRNGKey(0)``.

        Raises:
            IndexError: If ``layer_index`` is out of range, or the first layer
                is empty so its input size cannot be inferred.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        # Independent randomness for the neuron itself and for its fan-out.
        neuron_key, fan_out_key = jax.random.split(key)

        layer_index = self._resolve_layer_index(layer_index)
        layer = self.layers[layer_index]

        if layer:
            in_features = layer[0].weight.shape[0]
        elif layer_index > 0:
            # Layer was emptied earlier: its inputs are the previous layer's outputs.
            in_features = len(self.layers[layer_index - 1])
        else:
            raise IndexError("cannot infer the input size of an empty first layer")

        layer.append(Neuron(in_features, activation, neuron_key))

        # Adjust the next layer's weight vectors to include the new neuron
        if layer_index + 1 < len(self.layers):
            next_layer = self.layers[layer_index + 1]
            # A distinct new weight per downstream neuron.
            new_weights = jax.random.normal(fan_out_key, (len(next_layer),))
            for i, next_neuron in enumerate(next_layer):
                updated_weights = jnp.append(next_neuron.weight, new_weights[i])
                next_layer[i] = eqx.tree_at(lambda n: n.weight, next_neuron, updated_weights)

    def remove_neuron(self, layer_index, neuron_index):
        """Delete one neuron from a layer, in place.

        The matching input weight is removed from every neuron of the
        following layer (if any). Removing from an empty layer is a no-op.

        Args:
            layer_index: Layer to shrink (negative indices count from the end).
            neuron_index: Position of the neuron inside that layer (negative
                indices count from the end).

        Raises:
            IndexError: If either index is out of range.
        """
        layer_index = self._resolve_layer_index(layer_index)
        layer = self.layers[layer_index]
        if not layer:
            # Nothing was removed, so the next layer must stay untouched too.
            return

        # Normalise so the list deletion and the weight deletion hit the same slot.
        neuron_index = range(len(layer))[neuron_index]
        del layer[neuron_index]

        # Adjust the next layer's weight vectors to remove the corresponding weight
        if layer_index + 1 < len(self.layers):
            next_layer = self.layers[layer_index + 1]
            for i, next_neuron in enumerate(next_layer):
                updated_weights = jnp.delete(next_neuron.weight, neuron_index)
                next_layer[i] = eqx.tree_at(lambda n: n.weight, next_neuron, updated_weights)

    def get_shape(self):
        """Return the number of neurons in each layer, in layer order."""
        return [len(layer) for layer in self.layers]


In [None]:
def initialize_optimizer_state(mlp, optimizer):
    """Create a fresh optimizer state for the floating-point arrays of ``mlp``.

    Args:
        mlp: Model whose inexact (floating-point) array leaves are optimised.
        optimizer: Object exposing ``init(params)``.

    Returns:
        Whatever ``optimizer.init`` returns for the filtered model.
    """
    trainable_params = eqx.filter(mlp, eqx.is_inexact_array)
    return optimizer.init(trainable_params)


In [None]:
# Example usage: a 3-input network with two hidden layers and one output,
# trained with Adam.
input_size = 3
output_size = 1
hidden_sizes = [4, 5]  # Two hidden layers with 4 and 5 neurons respectively

key = jax.random.PRNGKey(42)
opt = optax.adam(learning_rate=1e-2)

mlp = CustomMLP(input_size, hidden_sizes, output_size, key=key)
opt_state = initialize_optimizer_state(mlp, opt)


In [None]:
# Neurons per layer: the hidden layers in order, followed by the output layer.
mlp.get_shape()

[4, 5, 1]

In [None]:
# A single training example: three input features and a one-element target.
x = jnp.asarray([1.0, 2.0, 3.0])
y = jnp.asarray([1.0])

In [None]:
@eqx.filter_value_and_grad
def compute_loss(mlp, x, y):
    """Mean squared error between ``mlp(x)`` and ``y``.

    Because of the decorator, calling this returns ``(loss, grads)`` where
    ``grads`` is taken with respect to the first argument, ``mlp``.
    """
    prediction = mlp(x)
    squared_error = (prediction - y) ** 2
    return jnp.mean(squared_error)

In [None]:
@eqx.filter_jit
def train_step(mlp, x, y, opt_state, opt_update):
    """Perform one optimisation step.

    Args:
        mlp: Current model.
        x: Input passed to the model.
        y: Target the prediction is compared against.
        opt_state: Current optimizer state.
        opt_update: Callable ``(grads, opt_state) -> (updates, new_opt_state)``.

    Returns:
        ``(loss, mlp, opt_state)``: the loss measured before the update, the
        updated model and the updated optimizer state.
    """
    loss_value, gradients = compute_loss(mlp, x, y)
    param_updates, next_opt_state = opt_update(gradients, opt_state)
    updated_mlp = eqx.apply_updates(mlp, param_updates)
    return loss_value, updated_mlp, next_opt_state

In [None]:
# Start training from a freshly initialised model. The optimizer state is
# rebuilt together with it so that re-running this cell never pairs a new
# model with Adam moments left over from a previous model or training run.
mlp = CustomMLP(input_size, hidden_sizes, output_size, key)
opt_state = initialize_optimizer_state(mlp, opt)

In [None]:
for epoch in range(100):
    loss, mlp, opt_state = train_step(mlp, x, y, opt_state, opt.update)

    # Dynamically add or remove neurons
    if epoch == 5:
        layer = 1
        # Use `layer` (not a hard-coded index) so the call and the message agree.
        mlp.add_neuron(layer_index=layer, activation=jax.nn.tanh)
        # Parameter shapes changed, so the optimizer state must be rebuilt.
        opt_state = initialize_optimizer_state(mlp, opt)
        print(f"Added neuron to hidden layer {layer+1}")
        print(mlp.get_shape())
    elif epoch == 10:
        layer = 1
        neuron_idx = 0
        mlp.remove_neuron(layer_index=layer, neuron_index=neuron_idx)
        # Parameter shapes changed, so the optimizer state must be rebuilt.
        opt_state = initialize_optimizer_state(mlp, opt)
        print(f"Removed neuron from hidden layer {layer+1} at index {neuron_idx}")
        print(mlp.get_shape())

    # `loss` was measured inside train_step, before the parameter update and
    # before any structural change above; the prediction is evaluated after both.
    print(f"Epoch {epoch}, Prediction: {mlp(x)}, Loss: {loss}")

print("Final Prediction:", mlp(x))

Epoch 0, Prediction: 1.0107297897338867, Loss: 0.0001336823625024408
Epoch 1, Prediction: 1.0099060535430908, Loss: 0.00011512838682392612
Epoch 2, Prediction: 1.0090949535369873, Loss: 9.812989446800202e-05
Epoch 3, Prediction: 1.0083004236221313, Loss: 8.271817932836711e-05
Epoch 4, Prediction: 1.00752592086792, Loss: 6.889703217893839e-05
Added neuron to hidden layer 2
[4, 6, 1]
Epoch 5, Prediction: 1.1814756393432617, Loss: 5.6639484682818875e-05
Epoch 6, Prediction: 1.152264952659607, Loss: 0.03293340653181076
Epoch 7, Prediction: 1.1233584880828857, Loss: 0.02318461611866951
Epoch 8, Prediction: 1.09499990940094, Loss: 0.015217316336929798
Epoch 9, Prediction: 1.0675103664398193, Loss: 0.00902498234063387
Removed neuron from hidden layer 2 at index 0
[4, 5, 1]
Epoch 10, Prediction: 1.0412989854812622, Loss: 0.004557649604976177
Epoch 11, Prediction: 1.0123785734176636, Loss: 0.001705606235191226
Epoch 12, Prediction: 0.9876964092254639, Loss: 0.0001532290771137923
Epoch 13, Predi