# Prototyping

> Proof-of-concept exploration of backpropagation training of an MLP whose number of hidden nodes changes during training.

In [None]:
# | default_exp prototyping

In [None]:
# | export

import jax
import jax.numpy as jnp
import optax
import equinox as eqx
from jax import random
from jax.random import PRNGKey, split

In [None]:
# | export


class Model(eqx.Module):
    """Feed-forward stack in which every layer except the last is followed by a ReLU.

    Each entry of ``layers`` must accept the previous entry's output; ``MLP``
    fills it with ``eqx.nn.Linear`` layers.
    """

    layers: list  # ordered layers, applied first to last

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        hidden_layers, output_layer = self.layers[:-1], self.layers[-1]
        h = x
        for layer in hidden_layers:
            h = jax.nn.relu(layer(h))
        # The final layer's output is returned without an activation.
        return output_layer(h)

In [None]:
# | export


def MLP(layer_sizes, key):
    """Build a ``Model`` of ``eqx.nn.Linear`` layers joining consecutive sizes.

    ``layer_sizes = [d0, d1, ..., dn]`` yields n layers. The i-th layer maps
    ``d_i -> d_{i+1}`` and is initialised with its own subkey split from ``key``.
    """
    subkeys = split(key, len(layer_sizes) - 1)
    dims = list(zip(layer_sizes[:-1], layer_sizes[1:]))
    linears = []
    for (fan_in, fan_out), subkey in zip(dims, subkeys):
        linears.append(eqx.nn.Linear(fan_in, fan_out, key=subkey))
    return Model(linears)

In [None]:
# | test


def test_mlp_initialization():
    """MLP([2, 3, 1]) should contain two Linear layers with matching weight/bias shapes."""
    mlp = MLP([2, 3, 1], random.PRNGKey(0))
    assert len(mlp.layers) == 2
    expected_shapes = [((3, 2), (3,)), ((1, 3), (1,))]
    for layer, (weight_shape, bias_shape) in zip(mlp.layers, expected_shapes):
        assert layer.weight.shape == weight_shape
        assert layer.bias.shape == bias_shape


def test_mlp_forward_pass():
    """A vmapped forward pass through MLP([2, 3, 1]) should give one output row per input row."""
    mlp = MLP([2, 3, 1], random.PRNGKey(0))
    batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])
    outputs = jax.vmap(mlp)(batch)
    assert outputs.shape == (2, 1)

In [None]:
# | test
# Run the MLP construction test when this cell executes.
test_mlp_initialization()

In [None]:
# | test
# Run the batched forward-pass shape test when this cell executes.
test_mlp_forward_pass()

In [None]:
# | export
# getters for pytree manipulation
def where_weight(linear):
    """Selector for ``eqx.tree_at``: pick the ``weight`` attribute of a layer."""
    selected = linear.weight
    return selected


def where_bias(linear):
    """Selector for ``eqx.tree_at``: pick the ``bias`` attribute of a layer."""
    selected = linear.bias
    return selected

In [None]:
# | export


def _grow_linear(layer, extra_out, extra_in, key):
    """Return a copy of ``layer`` with ``extra_out`` more outputs and ``extra_in`` more inputs.

    The existing weights and bias are copied unchanged into the top-left block,
    so every existing connection keeps its (row, column) position. New output
    rows and their bias entries take the values of a fresh ``eqx.nn.Linear``
    initialised with ``key``. New input columns are zero, so the layer ignores
    the newly added upstream unit until training changes those weights.
    """
    out_features, in_features = layer.weight.shape
    fresh = eqx.nn.Linear(in_features + extra_in, out_features + extra_out, key=key)
    # Zero the columns that read from the new upstream unit.
    weight = fresh.weight.at[:, in_features:].set(0.0)
    # Keep the trained weights exactly where they were.
    weight = weight.at[:out_features, :in_features].set(layer.weight)
    bias = fresh.bias.at[:out_features].set(layer.bias)
    grown = eqx.tree_at(where_weight, fresh, weight)
    return eqx.tree_at(where_bias, grown, bias)


def add_node(mlp, key):
    """Return a new ``Model`` in which every hidden layer has one more unit.

    Shape changes:
    - The first layer gains one output.
    - Each intermediate layer gains one output and one input.
    - The final layer gains one input.

    Existing weights stay in place. ``jnp.resize`` would instead refill the
    flattened array cyclically, shuffling weights between rows once the column
    count changes.

    The new unit gets freshly initialised incoming weights and bias, using a
    separate subkey of ``key`` per layer, and zero outgoing weights. For a model
    with two or more layers, the returned model therefore computes the same
    outputs as ``mlp``.

    NOTE: a single-layer model is treated as both first and final layer, so the
    result has two layers: (out+1, in) followed by (out, in+1). This only composes
    when the layer is square, and the output is not preserved in that case.

    Args:
        mlp: a ``Model`` whose layers are ``eqx.nn.Linear`` with biases.
        key: PRNG key used to initialise the new unit's incoming weights.

    Returns:
        A new ``Model``; ``mlp`` itself is not modified.
    """
    layers = mlp.layers
    # (layer, extra outputs, extra inputs) for each layer of the new model.
    plan = [(layers[0], 1, 0)]
    plan += [(layer, 1, 1) for layer in layers[1:-1]]
    plan.append((layers[-1], 0, 1))
    keys = split(key, len(plan))
    return Model(
        [_grow_linear(layer, d_out, d_in, k) for (layer, d_out, d_in), k in zip(plan, keys)]
    )

In [None]:
# | test
def test_add_node():
    """add_node on a [2, 3, 1] MLP should widen the hidden layer from 3 to 4 units."""
    root_key = random.PRNGKey(0)
    mlp = MLP([2, 3, 1], root_key)
    node_key = random.split(root_key)[1]
    grown = add_node(mlp, node_key)
    assert len(grown.layers) == 2
    hidden, output = grown.layers
    assert hidden.weight.shape == (4, 2)
    assert hidden.bias.shape == (4,)
    assert output.weight.shape == (1, 4)
    assert output.bias.shape == (1,)

In [None]:
# | test
# Run the add_node shape test when this cell executes.
test_add_node()

In [None]:
# | export
def _shrink_linear(layer, fewer_out, fewer_in, key):
    """Return a copy of ``layer`` without its last ``fewer_out`` outputs and last ``fewer_in`` inputs.

    The surviving weights and bias entries keep their (row, column) positions.
    """
    out_features, in_features = layer.weight.shape
    new_out, new_in = out_features - fewer_out, in_features - fewer_in
    # Construct a Linear for the new sizes rather than swapping arrays into the
    # old one; its random initial values are overwritten just below.
    shrunk = eqx.nn.Linear(new_in, new_out, key=key)
    shrunk = eqx.tree_at(where_weight, shrunk, layer.weight[:new_out, :new_in])
    return eqx.tree_at(where_bias, shrunk, layer.bias[:new_out])


def remove_node(mlp, key):
    """Return a new ``Model`` in which every hidden layer has one fewer unit.

    The last unit of each hidden layer is dropped:
    - The first layer loses its last output.
    - Each intermediate layer loses its last output and last input.
    - The final layer loses its last input.

    All remaining weights keep their positions. ``jnp.resize`` would instead
    refill the flattened array and shuffle weights between rows whenever the
    column count changes.

    Args:
        mlp: a ``Model`` whose layers are ``eqx.nn.Linear`` with biases.
        key: PRNG key passed to the ``eqx.nn.Linear`` constructors. Their random
            values are overwritten, so it does not affect the result.

    Returns:
        A new ``Model``; ``mlp`` itself is not modified.

    Raises:
        ValueError: if removing a unit would leave a layer with zero inputs or outputs.
    """
    layers = mlp.layers
    # (layer, outputs to drop, inputs to drop) for each layer of the new model.
    plan = [(layers[0], 1, 0)]
    plan += [(layer, 1, 1) for layer in layers[1:-1]]
    plan.append((layers[-1], 0, 1))
    for layer, d_out, d_in in plan:
        out_features, in_features = layer.weight.shape
        if out_features - d_out < 1 or in_features - d_in < 1:
            raise ValueError(
                f"cannot remove a node: a layer of shape {(out_features, in_features)} "
                f"would shrink to {(out_features - d_out, in_features - d_in)}"
            )
    return Model([_shrink_linear(layer, d_out, d_in, key) for layer, d_out, d_in in plan])

In [None]:
# | test
def test_remove_node():
    """remove_node on a [2, 3, 1] MLP should narrow the hidden layer from 3 to 2 units."""
    root_key = random.PRNGKey(0)
    mlp = MLP([2, 3, 1], root_key)
    node_key = random.split(root_key)[1]
    shrunk = remove_node(mlp, node_key)
    assert len(shrunk.layers) == 2
    hidden, output = shrunk.layers
    assert hidden.weight.shape == (2, 2)
    assert hidden.bias.shape == (2,)
    assert output.weight.shape == (1, 2)
    assert output.bias.shape == (1,)

In [None]:
# | test
# Run the remove_node shape test when this cell executes.
test_remove_node()

In [None]:
# | export


@eqx.filter_value_and_grad
def compute_loss(params, x, y):
    """Mean-squared error of ``params`` on a batch, together with its gradient.

    Because of ``eqx.filter_value_and_grad``, calling this returns
    ``(loss, grads)``. The gradients are taken with respect to the array leaves
    of ``params``.

    Args:
        params: the model, a callable pytree applied per sample via ``jax.vmap``.
        x: batch of inputs with the batch on the leading axis.
        y: targets. They must contain exactly one value per prediction element;
            otherwise the reshape below raises.

    Returns:
        ``(loss, grads)``.
    """
    preds = jax.vmap(params)(x)
    # Align predictions with the targets before subtracting. The model emits
    # (batch, out) while targets may be (batch,). Without this, (batch, 1) - (batch,)
    # broadcasts to (batch, batch), comparing every prediction with every target.
    preds = jnp.reshape(preds, jnp.shape(y))
    return jnp.mean((preds - y) ** 2)

In [None]:
# | export


@eqx.filter_jit
def train_step(params, x, y, opt_state, opt_update):
    """Apply one optimizer step to ``params`` on the batch ``(x, y)``.

    ``opt_update`` is called as ``opt_update(grads, opt_state)`` and must return
    ``(updates, new_opt_state)``.

    Returns:
        ``(loss, new_params, new_opt_state)``. The loss is measured before the update.
    """
    value, gradients = compute_loss(params, x, y)
    step_updates, new_opt_state = opt_update(gradients, opt_state)
    new_params = eqx.apply_updates(params, step_updates)
    return value, new_params, new_opt_state

In [None]:
# Initialize training data: y = x**2 sampled at x = 0, 1, 2, 3.
x = jnp.array([[0.0], [1.0], [2.0], [3.0]])  # shape (4, 1): one feature per sample
y = jnp.square(x[:, 0])  # shape (4,): one target per sample

In [None]:
# Initialize model and optimizer
opt = optax.sgd(learning_rate=0.01)  # plain SGD with a fixed step size
key = PRNGKey(0)
layer_sizes = [1, 1]  # starts as a single Linear(1 -> 1) layer with no hidden units
model = MLP(layer_sizes, key)
opt_state = opt.init(model)  # re-created by the training loop whenever the model is resized

In [None]:
# Training loop schedule, in epochs: total length, then the period of node
# additions and the period of node removals.
num_epochs, add_node_every, remove_node_every = 500, 100, 150

In [None]:
for epoch in range(num_epochs):
    loss, model, opt_state = train_step(model, x, y, opt_state, opt.update)
    step = epoch + 1  # 1-based epoch count used by the schedule and the logs

    # Structural changes happen after this epoch's update. When both schedules
    # fire on the same epoch (e.g. 300), the node is added first, then removed.
    if step % add_node_every == 0:
        key, subkey = split(key)
        model = add_node(model, subkey)
        print(f"Added node at epoch {step}")
        # The parameter pytree changed shape, so the old optimizer state no longer
        # matches it. Only the optimizer state is re-created here; the resized
        # model returned above is kept as-is.
        opt_state = opt.init(model)
        print("Optimizer state reinitialized")

    if step % remove_node_every == 0:
        key, subkey = split(key)
        model = remove_node(model, subkey)
        print(f"Removed node at epoch {step}")
        opt_state = opt.init(model)
        print("Optimizer state reinitialized")

    if step % 50 == 0:
        # `loss` was computed by train_step before any resize in this epoch.
        print(f"Epoch {step}, Loss: {loss}")

Epoch 50, Loss: 14.493556022644043
Added node at epoch 100
Model reinitialized
Epoch 100, Loss: 13.484792709350586
Removed node at epoch 150
Model reinitialized
Epoch 150, Loss: 12.44898796081543
Added node at epoch 200
Model reinitialized
Epoch 200, Loss: 12.36666202545166
Epoch 250, Loss: 12.303114891052246
Added node at epoch 300
Model reinitialized
Removed node at epoch 300
Model reinitialized
Epoch 300, Loss: 12.278339385986328
Epoch 350, Loss: 12.263911247253418
Added node at epoch 400
Model reinitialized
Epoch 400, Loss: 12.256312370300293
Removed node at epoch 450
Model reinitialized
Epoch 450, Loss: 12.250673294067383
Added node at epoch 500
Model reinitialized
Epoch 500, Loss: 12.25202465057373


In [None]:
# | hide
# Export step (nbdev): presumably writes this notebook's `# | export` cells to
# the `prototyping` module named by `default_exp`. Confirm against the nbdev config.
import nbdev

nbdev.nbdev_export()