### toylib: A simple Neural Network library in jax!

Let's build yet another jax neural network library. Our primary aim is to keep things (relatively) simple and transparent, building up complexity progressively. Why build a library at all? Reuse: as we build more complex networks, it is important to have building blocks we can rely upon. The goal is not to compete with Flax or Haiku, but to rebuild their core ideas from scratch.

I have been greatly inspired by the design of [equinox](https://github.com/patrick-kidger/equinox). Quoting the feature list from their GitHub page:
- neural networks (or more generally any model), with easy-to-use PyTorch-like syntax;
- filtered APIs for transformations;
- useful PyTree manipulation routines;
- advanced features like runtime errors;
- and best of all, Equinox isn't a framework: everything you write in Equinox is compatible with anything else in JAX or the ecosystem.

This is very close to my ideal library, since it sticks as close as possible to core jax functionality. Having said that, equinox is still a feature-complete library that holds its own against the likes of Flax and PyTorch. We aim for a much simpler version inspired by Equinox, with no bells and whistles.

Jax transforms work very well with [Pytrees](https://docs.jax.dev/en/latest/pytrees.html). It is a natural choice then to represent all our model parameters as a Pytree. An MLP with a couple of linear layers could be represented something like:

```python
params = {
    'layer1': {
        'w': [0, 0, 1, 2],
        'b': [1]
    },
    'layer2': {
        'w': [-1, 2, 0, 3],
        'b': [-1]
    },
}
```

This is simple and works natively with jax transforms, but this will quickly get out of hand as we develop more nested components.
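To make the "works natively" part concrete, here is a minimal sketch that differentiates a toy function of such a nested dict. The function `sum_of_squares` is a hypothetical stand-in for a real loss, and the values are stored as float arrays so that they can be differentiated.

```python
import jax
import jax.numpy as jnp

# A nested dict of arrays is already a pytree; no registration is needed.
params = {
    "layer1": {"w": jnp.array([0.0, 0.0, 1.0, 2.0]), "b": jnp.array([1.0])},
    "layer2": {"w": jnp.array([-1.0, 2.0, 0.0, 3.0]), "b": jnp.array([-1.0])},
}


def sum_of_squares(p):
    # Toy scalar function of every leaf, standing in for a real loss.
    return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(p))


leaves = jax.tree_util.tree_leaves(params)  # flat list of the four arrays
grads = jax.grad(sum_of_squares)(params)    # same nested dict structure as params
```

The gradient comes back as a dict with the same keys as `params`, which is exactly the property we want our modules to have as well.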

So, we try to come up with the `Module` abstraction. The Module is still a Pytree node, with:
- trainable fields for parameters (weights, biases, etc.)
- static fields for metadata (like hidden sizes, name scopes, etc.)
- nested modules to enable composition
- compatibility with jax.jit, jax.grad, etc.

## Setup

In [None]:
import jax
import jaxtyping
import math
import numpy as np
import typing

We set up a simple linear regression in `d` variables. Analytic solutions exist for this problem, but we use it as the first toy example as we build up toylib!
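For reference, the analytic solution mentioned above is ordinary least squares. A minimal sketch on its own small synthetic dataset follows; the names `X`, `w_true`, and `w_hat` and the sizes are illustrative assumptions, separate from the data used in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 3                                # illustrative sizes
X = rng.normal(size=(n, d))
w_true = np.array([2.0, -1.0, 0.5])
y = X @ w_true + 0.01 * rng.normal(size=n)   # small Gaussian noise

# Least-squares solution of min_w ||X w - y||^2.
w_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
```

With this little noise, `w_hat` should land close to `w_true`, which gives us a baseline to compare gradient descent against.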

In [None]:
# Problem setup
n = 120  # examples
d = 10  # dimensions

# Generate some dummy data
np.random.seed(31)
xs = np.random.normal(size=(n, d))
weights_true = np.random.randint(0, 10, size=(d,))
ys = np.dot(xs, weights_true) + np.random.normal(size=(n,))

xs_train, xs_test = xs[:100], xs[100:]
ys_train, ys_test = ys[:100], ys[100:]
print(weights_true)

Let's define a linear layer that does a single matrix multiplication and optionally adds a bias.

In [None]:
class Module:
    pass

class Linear(Module):
    """Defines a simple feedforward layer: a linear transformation with an optional bias."""

    # Trainable parameters
    weights: jaxtyping.Array
    bias: typing.Optional[jaxtyping.Array]

    # Hyperparameters / metadata
    in_features: int
    out_features: int
    use_bias: bool

    def __init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: jaxtyping.PRNGKeyArray) -> None:
        # Split the random key for weights and bias
        w_key, b_key = jax.random.split(key, 2)
        
        # We initialize the weights with a uniform distribution
        lim = 1 / math.sqrt(in_features)
        self.weights = jax.random.uniform(w_key, (in_features, out_features), minval=-lim, maxval=lim)
        if use_bias:
            self.bias = jax.random.uniform(b_key, (out_features,), minval=-lim, maxval=lim)
        else:
            # Keep the attribute defined even when the layer has no bias
            self.bias = None

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.key = key

    def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
        x = jax.numpy.dot(x, self.weights)
        if self.use_bias:
            x = x + self.bias
        return x

Let's initialize a linear layer to match our data and do a simple forward pass.

In [None]:
model = Linear(d, 1, use_bias=True, key=jax.random.PRNGKey(0))

In [None]:
y_pred = model(xs_train)
print(y_pred.shape)

Looks good so far, let's define a loss function. We use the L2 loss (mean squared error) here.

In [None]:
def loss_function(model, xs, ys):
    preds = jax.numpy.squeeze(model(xs))
    return jax.numpy.mean((ys - preds) ** 2)  # L2 Loss

print(loss_function(model, xs_train, ys_train))

To train the model, we use jax's `value_and_grad` function, which returns both the loss and its gradient.

In [None]:
jax.value_and_grad(loss_function)(model, xs_train, ys_train)

We encounter a jax-specific error here: `Linear` is not a valid JAX type. Jax transformations like `jax.grad`, `jax.value_and_grad`, `jax.jit`, etc., operate on JAX-compatible types: mostly `jax.Array`s, Python containers like tuples/lists/dicts of JAX arrays, and custom types that are registered with JAX's PyTree machinery.
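The difference between a built-in container and an unregistered class can be seen in isolation. This is a minimal sketch; `Unregistered`, `loss_from_dict`, and `loss_from_object` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


class Unregistered:
    """Hypothetical plain class that JAX knows nothing about."""

    def __init__(self, w):
        self.w = w


def loss_from_dict(p):
    return jnp.sum(p["w"] ** 2)


def loss_from_object(m):
    return jnp.sum(m.w ** 2)


w = jnp.array([1.0, 2.0])

# A dict of arrays is a built-in pytree, so this works.
dict_grad = jax.grad(loss_from_dict)({"w": w})

# An unregistered object is treated as a single leaf that is not an array.
try:
    jax.grad(loss_from_object)(Unregistered(w))
    error_message = None
except TypeError as err:
    error_message = str(err)
```

Both functions compute the same quantity; only the container around the array decides whether the transform accepts it.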

We need to fix this.

We follow the strategy described in https://jax.readthedocs.io/en/latest/faq.html#strategy-3-making-customclass-a-pytree to define a custom Pytree node.

Our `Module` class needs to distinguish between the `dynamic` elements (which need to interact with jax via `jit` & `grad`) vs the `static` elements:
- Hyperparameters (like layer sizes) will be static
- The actual weight arrays will be dynamic

We make the following changes:
- Use the `@register_pytree_node_class` decorator to tell Jax that Linear is also a Pytree
- Add a `tree_flatten` method that separates the dynamic and static elements
- Add a `tree_unflatten` method that constructs the object back given the dynamic and static elements
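Before applying these three changes to `Linear`, here they are on the smallest class I could think of. This is only a sketch: `Scale` is a hypothetical toy module with one dynamic field and one static field.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Scale:
    """Hypothetical toy module: multiplies its input by a learnable factor."""

    def __init__(self, factor, name):
        self.factor = factor  # dynamic: an array JAX should trace and differentiate
        self.name = name      # static: plain metadata

    def tree_flatten(self):
        # Returns (children, aux_data): children are traced, aux_data is not.
        return (self.factor,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Receives aux_data first, then the children.
        (factor,) = children
        return cls(factor, aux_data)


scale = Scale(jnp.array(3.0), name="scale")
leaves, treedef = jax.tree_util.tree_flatten(scale)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
grad = jax.grad(lambda m, x: m.factor * x)(scale, 2.0)  # a Scale holding d/dfactor
```

Note the argument order: `tree_flatten` returns the children first, while `tree_unflatten` receives the aux data first.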


In [None]:
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Linear(Module):
    """Defines a simple feedforward layer: a linear transformation with an optional bias."""

    # Trainable parameters
    weights: jaxtyping.Array
    bias: typing.Optional[jaxtyping.Array]

    # Hyperparameters / metadata
    in_features: int
    out_features: int
    use_bias: bool

    def __init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: jaxtyping.PRNGKeyArray) -> None:
        # Split the random key for weights and bias
        w_key, b_key = jax.random.split(key, 2)
        
        # We initialize the weights with a uniform distribution
        lim = 1 / math.sqrt(in_features)
        self.weights = jax.random.uniform(w_key, (in_features, out_features), minval=-lim, maxval=lim)
        if use_bias:
            self.bias = jax.random.uniform(b_key, (out_features,), minval=-lim, maxval=lim)
        else:
            # tree_flatten reads self.bias, so it must exist even without a bias
            self.bias = None

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.key = key

    def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
        x = jax.numpy.dot(x, self.weights)
        if self.use_bias:
            x = x + self.bias
        return x

    def tree_flatten(self) -> tuple:
        params = [self.weights, self.bias]
        static = {
            'in_features': self.in_features,
            'out_features': self.out_features,
            'use_bias': self.use_bias,
            'key': self.key,
        }
        return params, static

    @classmethod
    def tree_unflatten(cls, static, dynamic) -> 'Linear':
        weights, bias = dynamic
        in_features = static['in_features']
        out_features = static['out_features']
        use_bias = static['use_bias']
        obj = cls(in_features, out_features, use_bias, key=static['key'])
        obj.weights = weights
        obj.bias = bias
        return obj

In [None]:
model = Linear(d, 1, use_bias=True, key=jax.random.PRNGKey(0))
value, grad = jax.value_and_grad(loss_function)(model, xs_train, ys_train)
print(value, grad)

Great! We are able to get past the error and do a full forward and backward pass on the model.

The `grad` returned here is itself an object of type `Linear`. This is because jax now treats `Linear` objects as pytree nodes and produces a gradient for each child of the node (here, `weights` and `bias`).

How do we apply this gradient to get an update? Since `Linear` is now a PyTree, both `model` and `grad` share the same structure. We can use [jax.tree_util.tree_map](https://docs.jax.dev/en/latest/_autosummary/jax.tree_util.tree_map.html) to walk both trees and apply the update element-wise.

Here, we simply use `theta_new = theta - learning_rate * grad`.
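As a quick sanity check of this rule, here it is on a plain dict pytree with hand-picked numbers. The values below are made up for illustration and are not gradients of any real loss.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([10.0, 10.0]), "b": jnp.array(-5.0)}
learning_rate = 0.1

# tree_map pairs up the leaves of both trees, which must share one structure.
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, params, grads
)
```

Each leaf moves against its own gradient, and the result keeps the structure of `params`.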

In [None]:
def update(model, grad, learning_rate=0.01):
    """Update the model parameters using gradient descent."""
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, model, grad)

In [None]:
updated = update(model, grad, learning_rate=0.01)

In [None]:
print('model')
jax.tree_util.tree_map(lambda x: print(x.shape), model)
print('grad')
jax.tree_util.tree_map(lambda x: print(x.shape), grad)
print('updated')
jax.tree_util.tree_map(lambda x: print(x.shape), updated)

This was relatively straightforward for the `Linear` class. However, it does not scale to hand-write these flatten/unflatten methods for every module we define.

To make this more generally useful, we move these methods to the base `Module` class.

We make some simplifying assumptions here:
- All jax or numpy arrays in the modules will be parameters
- Everything else is a hyperparameter to be treated as aux data

In [None]:
def _is_array(x: typing.Any) -> bool:
    return isinstance(
        x, (jax.Array, np.ndarray, np.generic)
    ) or hasattr(x, "__jax_array__")


def _is_random_key(x: str) -> bool:
    return x == 'key'

@register_pytree_node_class
class Module:
    def tree_flatten(self) -> tuple:
        params = []
        param_keys = []
        aux_data = dict()

        # Look through each attribute in the object
        for k, v in self.__dict__.items():
            if _is_array(v) and not _is_random_key(k):
                # trainable leaf param!
                params.append(v)
                param_keys.append(k)
            else:
                aux_data[k] = v

        aux_data['param_keys'] = param_keys
        return params, aux_data

    @classmethod
    def tree_unflatten(cls, static, dynamic) -> 'Module':
        # Create a new empty object without calling __init__
        obj = object.__new__(cls)

        # Restore the trainable children under their original attribute names
        for k, v in zip(static['param_keys'], dynamic):
            setattr(obj, k, v)

        # Restore the static attributes; 'param_keys' is bookkeeping, not an attribute
        for k, v in static.items():
            if k != 'param_keys':
                setattr(obj, k, v)

        return obj

@register_pytree_node_class
class Linear(Module):
    """Defines a simple feedforward layer: a linear transformation with an optional bias."""

    # Trainable parameters
    weights: jaxtyping.Array
    bias: typing.Optional[jaxtyping.Array]

    # Hyperparameters / metadata
    in_features: int
    out_features: int
    use_bias: bool

    def __init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: jaxtyping.PRNGKeyArray) -> None:
        # Split the random key for weights and bias
        w_key, b_key = jax.random.split(key, 2)
        
        # We initialize the weights with a uniform distribution
        lim = 1 / math.sqrt(in_features)
        self.weights = jax.random.uniform(w_key, (in_features, out_features), minval=-lim, maxval=lim)
        if use_bias:
            self.bias = jax.random.uniform(b_key, (out_features,), minval=-lim, maxval=lim)
        else:
            # Not an array, so Module.tree_flatten will treat it as aux data
            self.bias = None

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.key = key

    def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
        x = jax.numpy.dot(x, self.weights)
        if self.use_bias:
            x = x + self.bias
        return x

In [None]:
model = Linear(d, 1, use_bias=True, key=jax.random.PRNGKey(0))
value, grad = jax.value_and_grad(loss_function)(model, xs_train, ys_train)
print(value, grad)

That works!

The flattening function `tree_flatten` identifies all trainable parameters, which are assumed to be jax arrays (except for the random keys), and sets them as the "dynamic" elements in the pytree. Everything else, including the names of these parameters, is added to the `aux_data` dict.
All jax transforms will serialize our class using this function, operate on the dynamic elements, and then reconstruct the class using the `tree_unflatten` method.
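The flatten, operate, unflatten cycle can also be run by hand. This is a minimal sketch on a plain dict, with "doubling every leaf" standing in for whatever a real transform does to the dynamic elements.

```python
import jax
import jax.numpy as jnp

tree = {"weights": jnp.ones((3, 1)), "bias": jnp.zeros((1,))}

# 1. Flatten: a list of array leaves plus a treedef describing the structure.
leaves, treedef = jax.tree_util.tree_flatten(tree)

# 2. Operate on the leaves only.
doubled = [2 * leaf for leaf in leaves]

# 3. Unflatten: rebuild the original container around the new leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, doubled)
```

For a registered class, steps 1 and 3 are where our `tree_flatten` and `tree_unflatten` methods get called.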

There is still one major feature that we haven't addressed yet: nested modules. 

Let's look at a simple MLP that uses two linear layers.

In [None]:
@register_pytree_node_class
class MLP(Module):
    output_layer: Linear
    layers: typing.List[Module]

    in_features: int
    hidden_dims: list[int]
    out_features: int

    def __init__(self, in_features: int, hidden_dims: list[int], out_features: int, *, key: jaxtyping.PRNGKeyArray) -> None:
        # Split the random key: one key per hidden layer, plus one for the output layer
        keys = jax.random.split(key, len(hidden_dims) + 1)

        # Create the layers
        layers = []
        input_dim = in_features
        for i, hidden_dim in enumerate(hidden_dims):
            layer = Linear(input_dim, hidden_dim, key=keys[i])
            layers.append(layer)
            input_dim = hidden_dim

        # Create the output layer
        output_layer = Linear(input_dim, out_features, key=keys[-1])

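        self.layers = layers
        self.output_layer = output_layer

        # Record the hyperparameters declared on the class
        self.in_features = in_features
        self.hidden_dims = hidden_dims
        self.out_features = out_features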

    def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
        for layer in self.layers:
            x = layer(x)
            x = jax.nn.relu(x)
        return self.output_layer(x)

In [None]:
model = MLP(d, [32, 32], 1, key=jax.random.PRNGKey(0))
print(jax.value_and_grad(loss_function)(model, xs_train, ys_train))
print(jax.tree_util.tree_flatten(model))

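At first glance, this looks okay since the forward pass works fine. However, a closer look at the output of `tree_flatten` reveals that the nested modules are not included in the dynamic elements at all. They are therefore ignored by the JAX transforms, and no gradients are computed for their parameters.

This happens because `tree_flatten` only treats arrays as children and never looks inside nested modules. We fix this by also treating `Module` instances, and lists or tuples of them, as children.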
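The property this relies on is that the children returned from `tree_flatten` may themselves be pytrees, and JAX recurses into them. A minimal sketch with two hypothetical classes, `Leafy` and `Stack`, that are not part of toylib:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Leafy:
    """Hypothetical module holding a single array."""

    def __init__(self, w):
        self.w = w

    def tree_flatten(self):
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@register_pytree_node_class
class Stack:
    """Hypothetical container that returns its list of sub-modules as a child."""

    def __init__(self, blocks):
        self.blocks = blocks

    def tree_flatten(self):
        # The list is itself a pytree, so JAX recurses into each Leafy.
        return (self.blocks,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


stack = Stack([Leafy(jnp.ones(2)), Leafy(jnp.zeros(3))])
leaves = jax.tree_util.tree_leaves(stack)
shapes = [leaf.shape for leaf in leaves]
```

The leaves of `stack` are the arrays inside the nested `Leafy` objects, even though `Stack.tree_flatten` never touches them directly.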

In [None]:
def _is_supported_container(x: typing.Any) -> bool:
    return isinstance(x, (list, tuple))

@register_pytree_node_class
class Module:
    def tree_flatten(self) -> tuple:
        params = []
        param_keys = []
        aux_data = dict()

        # Look through each attribute in the object
        for k, v in self.__dict__.items():
            if (
                (_is_array(v) and not _is_random_key(k))
                or isinstance(v, Module)
                or (_is_supported_container(v) and all(isinstance(elem, Module) for elem in v))
            ):
                # child: a trainable array, a nested module, or a list/tuple of modules
                params.append(v)
                param_keys.append(k)
            else:
                aux_data[k] = v

        aux_data['param_keys'] = param_keys
        return params, aux_data

    @classmethod
    def tree_unflatten(cls, static, dynamic) -> 'Module':
        # Create a new empty object without calling __init__
        obj = object.__new__(cls)

        # Restore the children (arrays and nested modules) under their original names
        for k, v in zip(static['param_keys'], dynamic):
            setattr(obj, k, v)

        # Restore the static attributes; 'param_keys' is bookkeeping, not an attribute
        for k, v in static.items():
            if k != 'param_keys':
                setattr(obj, k, v)

        return obj

@register_pytree_node_class
class Linear(Module):
    """Defines a simple feedforward layer: a linear transformation with an optional bias."""

    # Trainable parameters
    weights: jaxtyping.Array
    bias: typing.Optional[jaxtyping.Array]

    # Hyperparameters / metadata
    in_features: int
    out_features: int
    use_bias: bool

    def __init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: jaxtyping.PRNGKeyArray) -> None:
        # Split the random key for weights and bias
        w_key, b_key = jax.random.split(key, 2)
        
        # We initialize the weights with a uniform distribution
        lim = 1 / math.sqrt(in_features)
        self.weights = jax.random.uniform(w_key, (in_features, out_features), minval=-lim, maxval=lim)
        if use_bias:
            self.bias = jax.random.uniform(b_key, (out_features,), minval=-lim, maxval=lim)
        else:
            # Not an array, so Module.tree_flatten will treat it as aux data
            self.bias = None

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.key = key

    def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
        x = jax.numpy.dot(x, self.weights)
        if self.use_bias:
            x = x + self.bias
        return x

@register_pytree_node_class
class MLP(Module):
    output_layer: Linear
    layers: typing.List[Module]

    in_features: int
    hidden_dims: list[int]
    out_features: int

    def __init__(self, in_features: int, hidden_dims: list[int], out_features: int, *, key: jaxtyping.PRNGKeyArray) -> None:
        # Split the random key: one key per hidden layer, plus one for the output layer
        keys = jax.random.split(key, len(hidden_dims) + 1)

        # Create the layers
        layers = []
        input_dim = in_features
        for i, hidden_dim in enumerate(hidden_dims):
            layer = Linear(input_dim, hidden_dim, key=keys[i])
            layers.append(layer)
            input_dim = hidden_dim

        # Create the output layer
        output_layer = Linear(input_dim, out_features, key=keys[-1])

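        self.layers = layers
        self.output_layer = output_layer

        # Record the hyperparameters declared on the class
        self.in_features = in_features
        self.hidden_dims = hidden_dims
        self.out_features = out_features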

    def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
        for layer in self.layers:
            x = layer(x)
            x = jax.nn.relu(x)
        return self.output_layer(x)

In [None]:
model = MLP(d, [32, 64], 1, key=jax.random.PRNGKey(0))
print(jax.value_and_grad(loss_function)(model, xs_train, ys_train))
jax.tree_util.tree_map(lambda x: print(x.shape), model)

In [None]:
print(jax.tree_util.tree_flatten(model))

## Train a model!

In [None]:
## training loop
# initial estimates: a fresh linear model matching the regression problem
model = Linear(d, 1, use_bias=True, key=jax.random.PRNGKey(0))

# hyperparameters
max_steps = 100
learning_rate = 0.1

step = 0

# run a fixed number of gradient descent steps
while step < max_steps:
    # compute the loss and its gradient with respect to the model parameters
    loss, grads = jax.value_and_grad(loss_function)(model, xs_train, ys_train)
    print(loss)
    # apply the gradient descent update defined earlier
    model = update(model, grads, learning_rate)

    step += 1

In [None]:
weights_true

In [None]:
model.weights