### A simple Neural Network, using toylib!
We repeat the same exercise we did previously while learning jax, but this time we build `toylib` along the way!

```python
import jax

from toylib.nn import layers
```

```python
key = jax.random.PRNGKey(seed=10)
```

We set up a simple regression problem. Solving it with a model like this is not very meaningful in itself, but it lets us get all the pieces in place.

```python
import numpy as np

# problem setup
n = 120  # examples
d = 10  # dimension

# dummy data: noisy linear targets
xs = np.random.normal(size=(n, d))
weights_true = np.random.randint(0, 10, size=(d,))
ys = np.dot(xs, weights_true) + np.random.normal(size=(n,))

# first 100 examples for training, remaining 20 for testing
xs_train, xs_test = xs[:100], xs[100:]
ys_train, ys_test = ys[:100], ys[100:]
print(weights_true)
```

Our dummy `ys` are a linear transformation of the inputs plus Gaussian noise. Let's fit a model on the train set and use it to predict the values in the test set.
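Written out, the data-generating process and the objective look as follows, with $x_i$ the $i$-th row of `xs`, $w^\ast$ standing for `weights_true`, and $n$ the number of examples being fit. $L(w)$ is the mean squared error that `loss_function` computes; since the noise has unit variance, the best achievable loss should be roughly 1.

$$
y_i = x_i^\top w^\ast + \varepsilon_i, \qquad x_i \sim \mathcal{N}(0, I_d), \quad \varepsilon_i \sim \mathcal{N}(0, 1), \quad w^\ast_j \in \{0, \dots, 9\}
$$

$$
L(w) = \frac{1}{n} \sum_{i=1}^{n} \left(y_i - x_i^\top w\right)^2, \qquad \nabla_w L(w) = -\frac{2}{n} \sum_{i=1}^{n} \left(y_i - x_i^\top w\right) x_i
$$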

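```python
# @jax.jit
def loss_function(model, xs, ys):
    # apply the model to each example independently, then drop the size-1 output axis
    preds = jax.numpy.squeeze(jax.vmap(model)(xs))
    return jax.numpy.mean((ys - preds) ** 2)  # mean squared error (L2 loss)
```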

```python
model = layers.Linear(d, 1, use_bias=False, key=key)
```

We hit another jax-specific error here. jax transformations such as `jax.jit` and `jax.grad` work on pytrees (nested containers whose leaves are arrays), but jax does not recognize the class we defined as one.
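To see the problem concretely, here is a minimal sketch using a hypothetical `PlainLinear` class (a stand-in, not toylib's `Linear`). jax treats any unregistered object as a single opaque leaf, so it cannot see the weight array inside, and `jax.grad` rejects the argument with a `TypeError`.

```python
import jax
import jax.numpy as jnp


class PlainLinear:
    """Hypothetical plain Python class that jax knows nothing about."""

    def __init__(self, weights):
        self.weights = weights

    def __call__(self, x):
        return x @ self.weights


def loss(model, x):
    return jnp.sum(model(x) ** 2)


model = PlainLinear(jnp.ones((3,)))
x = jnp.arange(3.0)

# Unregistered objects are opaque: the whole object is one leaf.
leaves = jax.tree_util.tree_leaves(model)
print(leaves)

# jax.grad needs array (or pytree-of-array) arguments, so this fails.
try:
    jax.grad(loss)(model, x)
    error = None
except TypeError as e:
    error = e
print(type(error).__name__, error)
```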

We need to fix this.

We follow the strategy described in https://jax.readthedocs.io/en/latest/faq.html#strategy-3-making-customclass-a-pytree. There are other ways to achieve this, and altogether different designs are possible; we choose this one because it is a simple and flexible way to make progress.


We rely on the abstract class `Module` and register it as a pytree node class.

The class needs to distinguish between the `dynamic` elements (which need to interact with jax via `jit` and `grad`) and the `static` elements:
- Hyperparameters (like layer sizes) will be static
- The actual weight arrays will be dynamic

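For a single class, this is easy to do by hand: `tree_flatten` splits an instance into its dynamic children and its static aux data, and `tree_unflatten` rebuilds the instance from them.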
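A minimal sketch of the manual approach from the linked FAQ, using a hypothetical `ManualLinear` class (not toylib's implementation). It uses the real `jax.tree_util.register_pytree_node_class` decorator, which expects a `tree_flatten` method and a `tree_unflatten` classmethod.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ManualLinear:
    """Hypothetical single class with hand-written flatten/unflatten."""

    def __init__(self, in_features, out_features, weights):
        self.in_features = in_features    # static: hyperparameter
        self.out_features = out_features  # static: hyperparameter
        self.weights = weights            # dynamic: parameter array

    def __call__(self, x):
        return x @ self.weights

    def tree_flatten(self):
        children = (self.weights,)                      # traced / differentiated
        aux_data = (self.in_features, self.out_features)  # static, hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        in_features, out_features = aux_data
        (weights,) = children
        return cls(in_features, out_features, weights)


model = ManualLinear(3, 1, jnp.ones((3, 1)))
x = jnp.arange(3.0)


def loss(m):
    return jnp.sum(m(x) ** 2)


# grad now returns a ManualLinear holding d(loss)/d(weights)
grads = jax.grad(loss)(model)
print(type(grads).__name__, grads.weights.ravel())
```

Here `tree_unflatten` simply calls `__init__`, which is fine because `__init__` only stores its arguments; the FAQ warns that this breaks if `__init__` validates or transforms its inputs.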

To make this more generally useful, we define some helper functions that turn any module class into a pytree that jax understands.

We make some simplifying assumptions here:
- All jax or numpy arrays in the modules will be parameters
- Everything else is a hyperparameter to be treated as aux data.
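One way such helpers could look, as a sketch under the two assumptions above; `Module`, `_is_param`, and `TinyLinear` are hypothetical names, not toylib's actual code. jax's pytree registry is keyed on the exact type, so registering only the base class would not cover subclasses; this sketch therefore registers every subclass in `__init_subclass__`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _is_param(value):
    # Assumption from the text: every jax/numpy array is a parameter.
    return isinstance(value, (jax.Array, np.ndarray))


class Module:
    """Hypothetical base class; each subclass becomes a pytree node."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # registration is per exact type, so register every subclass
        jax.tree_util.register_pytree_node_class(cls)

    def tree_flatten(self):
        attrs = vars(self)
        param_names = tuple(sorted(k for k, v in attrs.items() if _is_param(v)))
        # everything else is static aux data (must be hashable for jax.jit)
        static = tuple(
            sorted(((k, v) for k, v in attrs.items() if not _is_param(v)),
                   key=lambda kv: kv[0])
        )
        children = tuple(attrs[k] for k in param_names)
        return children, (param_names, static)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        param_names, static = aux_data
        obj = object.__new__(cls)  # skip __init__: children may be tracers
        for name, value in static:
            setattr(obj, name, value)
        for name, value in zip(param_names, children):
            setattr(obj, name, value)
        return obj


class TinyLinear(Module):
    def __init__(self, in_features, out_features, *, key):
        self.in_features = in_features
        self.out_features = out_features
        self.weights = jax.random.normal(key, (in_features, out_features))

    def __call__(self, x):
        return x @ self.weights


model = TinyLinear(3, 1, key=jax.random.PRNGKey(0))
print(jax.tree_util.tree_leaves(model))  # only the weight array
grads = jax.grad(lambda m: jnp.sum(m(jnp.ones(3))))(model)
print(type(grads).__name__, grads.in_features, grads.weights.shape)
```

Under these rules a nested `Module` attribute is not an array, so it would land in the static aux data and its weights would be invisible to jax.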

Some things are still unclear: what happens with nested modules?
We shall deal with these at a later point.


Now that we have a basic `Linear` module, we can define our first forward-backward pass using the `jax.value_and_grad` function.


```python
# forward + backward pass on the training split
loss, grads = jax.value_and_grad(loss_function)(model, xs_train, ys_train)
```

`grads` is also an object of type `Linear`. This is because jax now treats `Linear` objects as pytree nodes: it computes a gradient for each dynamic child (the weight arrays) and reassembles the results into a `Linear` with the same static hyperparameters.
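For a linear model with a mean-squared-error loss, the gradient has a closed form, $\nabla_w L = -\tfrac{2}{n} X^\top (y - Xw)$, so jax's output can be sanity-checked by hand. This sketch uses a bare weight vector instead of a module and random data of the same shape as ours; both are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def mse(w, xs, ys):
    preds = xs @ w
    return jnp.mean((ys - preds) ** 2)


key_x, key_w, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
n, d = 100, 10
xs = jax.random.normal(key_x, (n, d))
ys = jax.random.normal(key_y, (n,))
w = jax.random.normal(key_w, (d,))

# jax: loss value and gradient in a single call
loss, grad_auto = jax.value_and_grad(mse)(w, xs, ys)

# by hand: dL/dw = -(2/n) * X^T (y - X w)
grad_manual = -(2.0 / n) * xs.T @ (ys - xs @ w)

print(loss, jnp.max(jnp.abs(grad_auto - grad_manual)))
```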

Once we have the grads, we still need to update the original model parameters.

Here, we use plain gradient descent, `theta_new = theta - alpha * grads`, where `alpha` is the learning rate.
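The update rule is applied leaf by leaf: `jax.tree_util.tree_map` walks two pytrees with the same structure in parallel and applies the function to each pair of leaves. A small sketch with a nested dict (the parameter names are made up for illustration) shows that the same one-liner works for any pytree, not only for a `Linear` object.

```python
import jax
import jax.numpy as jnp

# hypothetical nested parameters and matching gradients
params = {"linear": {"weights": jnp.array([1.0, 2.0]), "bias": jnp.array(0.5)}}
grads = {"linear": {"weights": jnp.array([10.0, 20.0]), "bias": jnp.array(1.0)}}
learning_rate = 0.1

# theta_new = theta - alpha * grads, applied to every (param, grad) pair
new_params = jax.tree_util.tree_map(
    lambda theta, g: theta - learning_rate * g, params, grads
)
print(new_params)
```

Both trees must have the same structure, which is why it matters that `grads` comes back as the same type as the model.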

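```python
def apply_update(model, grads, learning_rate):
    # theta_new = theta - alpha * grads for every matching pair of leaves;
    # jax.tree_map is deprecated, so use jax.tree_util.tree_map instead
    return jax.tree_util.tree_map(lambda x, y: x - learning_rate * y, model, grads)
```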

```python
weights_true
```

```python
model.tree_flatten()
```

```python
grads.weights
```

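```python
## training loop
# hyperparameters
max_steps = 100
learning_rate = 0.1
eps = 1e-6  # stop early once the loss changes by less than this

prev_loss = float("inf")
for step in range(max_steps):
    # forward + backward pass on the training split only
    loss, grads = jax.value_and_grad(loss_function)(model, xs_train, ys_train)
    # gradient descent step
    model = apply_update(model, grads, learning_rate)

    if step % 10 == 0:
        print(f"step {step:3d}  train loss {float(loss):.4f}")
    # stop when we reach max_steps or the loss changes by less than eps
    if abs(prev_loss - float(loss)) < eps:
        print(f"converged at step {step}")
        break
    prev_loss = float(loss)

# evaluate on the held-out test split
print("test loss:", float(loss_function(model, xs_test, ys_test)))
```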
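Making the model a pytree is what lets `jax.jit` work on it, so the commented-out `# @jax.jit` can come back at the level of a whole training step. This is a sketch in which a plain dict of arrays stands in for the model; applying the same pattern to a toylib module assumes its static aux data is hashable, which `jax.jit` requires. The synthetic data and names (`loss_fn`, `train_step`) are illustrative.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, xs, ys):
    preds = xs @ params["weights"]
    return jnp.mean((ys - preds) ** 2)


@jax.jit
def train_step(params, xs, ys, learning_rate):
    # forward + backward pass and update, compiled together
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


key_x, key_n = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (100, 10))
weights_true = jnp.arange(10.0)
ys = xs @ weights_true + 0.1 * jax.random.normal(key_n, (100,))

params = {"weights": jnp.zeros(10)}
for step in range(200):
    params, loss = train_step(params, xs, ys, 0.1)
print(float(loss), params["weights"])
```

The first call traces and compiles `train_step`; later calls with the same shapes reuse the compiled version.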

```python
weights_true
```

```python
model.weights
```

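```python
import matplotlib.pyplot as plt

# one point per weight: a good fit puts the points on the diagonal
plt.scatter(weights_true, model.weights.ravel())
plt.xlabel("true weight")
plt.ylabel("learned weight")
plt.show()
```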