### A simple Neural Network, using toylib!
We repeat the same exercise we did previously while learning JAX, but this time we build `toylib` along the way!

In [1]:
import jax

from toylib.nn import layers

In [2]:
key = jax.random.PRNGKey(seed=10)

In [3]:
layer1 = layers.Linear(10, 2, key=key)
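The same `key` is passed here and again when we build `model` further down. JAX random keys are deterministic, so reusing a key reproduces the same random bits, and two layers initialized from it are not independent. A common pattern is to split the key once per consumer with `jax.random.split`. A minimal sketch, independent of `toylib` (`key_layer1` and `key_model` are illustrative names):

```python
import jax

key = jax.random.PRNGKey(seed=10)

# Split once per consumer instead of reusing `key` everywhere.
key_layer1, key_model = jax.random.split(key)

# Reusing a key gives identical draws; split keys give independent ones.
same_a = jax.random.normal(key, (3,))
same_b = jax.random.normal(key, (3,))
diff_a = jax.random.normal(key_layer1, (3,))
diff_b = jax.random.normal(key_model, (3,))

print(bool((same_a == same_b).all()))  # True
print(bool((diff_a == diff_b).all()))  # False
```

With `toylib`, this would mean passing `key=key_layer1` and `key=key_model` to the two `layers.Linear` calls.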

We set up a simple regression problem. Solving it with such a model is not very meaningful in itself, but it lets us get all the pieces in place.

In [4]:
import numpy as np
# problem setup
n = 120  # examples
d = 10  # dimension

# dummy data
xs = np.random.normal(size=(n, d))
ys = np.dot(xs, np.random.randint(0, 10, size=(d,))) + np.random.normal(size=(n,))

xs_train, xs_test = xs[:100], xs[100:]
ys_train, ys_test = ys[:100], ys[100:]

Our dummy `ys` are a noisy linear transformation of the inputs. Let's try to fit a model on the train set and use it to predict the values in the test set.
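Since `ys` is linear in `xs` plus unit-variance noise, ordinary least squares gives a useful reference point: a well-fit model should reach a test MSE close to the noise variance of 1. A sketch using `np.linalg.lstsq`; it regenerates the data with a seeded `np.random.default_rng(0)` (an assumption, so the numbers will not match the cells above) and keeps the true weights as `true_w`:

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded so the sketch is reproducible
n, d = 120, 10

xs = rng.normal(size=(n, d))
true_w = rng.integers(0, 10, size=(d,))
ys = xs @ true_w + rng.normal(size=(n,))  # linear signal + unit-variance noise

xs_train, xs_test = xs[:100], xs[100:]
ys_train, ys_test = ys[:100], ys[100:]

# Ordinary least squares on the train split only.
w_hat, *_ = np.linalg.lstsq(xs_train, ys_train, rcond=None)

test_mse = np.mean((ys_test - xs_test @ w_hat) ** 2)
print(np.round(w_hat, 2))
print(true_w)
print(test_mse)  # expected to land near the noise variance of 1
```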

In [5]:
@jax.jit
def loss(model, xs, ys):
    preds = jax.vmap(model)(xs)
    return jax.numpy.mean((ys - preds) ** 2)  # L2 Loss
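`jax.vmap(model)` turns the per-example `model` into a function over a batch, mapping over the leading axis of `xs`. One shape detail is worth checking: if a `Linear(d, 1)` layer returns an array of shape `(1,)` per example (an assumption; `toylib`'s output shape is not shown here), then `preds` has shape `(n, 1)` and `ys - preds` silently broadcasts to `(n, n)` instead of raising an error. A small sketch with a plain function, where `predict` and `w` are illustrative stand-ins:

```python
import jax
import jax.numpy as jnp

n, d = 4, 3
xs = jnp.ones((n, d))
ys = jnp.zeros((n,))
w = jnp.ones((1, d))  # stand-in for a Linear(d, 1) weight, no bias


def predict(x):
    # per-example model: (d,) -> (1,)
    return w @ x


preds = jax.vmap(predict)(xs)
print(preds.shape)         # (4, 1)
print((ys - preds).shape)  # (4, 4): broadcasting, not an elementwise difference

# Matching the shapes first gives the intended per-example residuals.
residuals = ys - preds.reshape(ys.shape)
print(residuals.shape)           # (4,)
print(jnp.mean(residuals ** 2))  # 9.0
```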

In [6]:
model = layers.Linear(d, 1, use_bias=False, key=key)

In [10]:
loss(model, xs, ys)

Array(237.21623, dtype=float32)

With `Linear` defined as a plain Python class, this cell fails with another JAX-specific error: `jax.jit` only accepts arguments that are arrays or pytrees of arrays, and the class we defined is not interpretable as a pytree.
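The failure is easy to reproduce outside `toylib`. JAX treats an unregistered object as a single leaf, and since it is not an array it cannot be traced. A sketch where `PlainLinear` and `apply` are hypothetical stand-ins for a layer and the jitted loss:

```python
import jax
import jax.numpy as jnp


class PlainLinear:
    # an ordinary Python class, NOT registered as a pytree
    def __init__(self, weights):
        self.weights = weights

    def __call__(self, x):
        return self.weights @ x


@jax.jit
def apply(model, x):
    return model(x)


model = PlainLinear(jnp.ones((1, 3)))

raised = False
try:
    apply(model, jnp.ones(3))
except TypeError as err:
    # jit cannot turn the unregistered object into a traced array
    raised = True
    print(type(err).__name__, err)
```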

We need to fix this.

We follow the strategy described in https://jax.readthedocs.io/en/latest/faq.html#strategy-3-making-customclass-a-pytree. There are other ways to achieve this, and altogether different designs are possible; we choose this one as a simple and flexible way to make progress.
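Strategy 3 from the FAQ amounts to giving the class `tree_flatten` and `tree_unflatten` methods and decorating it with `jax.tree_util.register_pytree_node_class`. A minimal sketch on a hypothetical `MyLinear` class (not `toylib`'s `Linear`), with the weights as children and the hyperparameter as auxiliary data:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class MyLinear:
    def __init__(self, weights, out_features):
        self.weights = weights            # dynamic: traced by jit/grad
        self.out_features = out_features  # static: a plain Python hyperparameter

    def __call__(self, x):
        return self.weights @ x

    def tree_flatten(self):
        children = (self.weights,)       # arrays JAX should trace
        aux_data = (self.out_features,)  # hashable metadata kept as-is
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


@jax.jit
def apply(model, x):
    return model(x)


model = MyLinear(jnp.ones((1, 3)), out_features=1)
print(apply(model, jnp.ones(3)))  # [3.]

# grad returns a MyLinear whose `weights` hold the gradient.
grads = jax.grad(lambda m, x: apply(m, x).sum())(model, jnp.ones(3))
print(type(grads).__name__, grads.weights)  # MyLinear [[1. 1. 1.]]
```

Because `aux_data` is part of the structure `jax.jit` caches on, it should be hashable: changing `out_features` triggers retracing, while new `weights` of the same shape do not. The FAQ also cautions that `tree_unflatten` may be called with non-array placeholders, so `__init__` should avoid validating or transforming its inputs.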


We rely on the abstract class `Module` and register it as a pytree node class.

The class needs to distinguish between the `dynamic` elements (which need to interact with JAX via `jit` and `grad`) and the `static` elements:
- Hyperparameters (like layer sizes) will be static
- The actual weight arrays will be dynamic

This is easy to do for a single class by manually defining each of these.
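Writing the flatten/unflatten pair by hand for every layer gets repetitive. One way to push the dynamic/static split into a shared base class is sketched below. It assumes a hypothetical convention: a class attribute `static_fields` names the hyperparameters, and every other instance attribute is treated as dynamic. JAX's pytree registry matches exact types, so registering `Module` alone would not cover its subclasses; this sketch registers each subclass from `__init_subclass__`. It is not necessarily how `toylib`'s `Module` works, and `ToyLinear` is illustrative.

```python
import jax
import jax.numpy as jnp


class Module:
    # Hypothetical convention: names listed here are static metadata;
    # all other instance attributes are dynamic children.
    static_fields: tuple = ()

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # The pytree registry matches exact types, so register every subclass.
        jax.tree_util.register_pytree_node(cls, cls._flatten, cls._unflatten)

    def _flatten(self):
        dynamic = sorted(k for k in vars(self) if k not in self.static_fields)
        children = tuple(getattr(self, k) for k in dynamic)
        static = tuple((k, getattr(self, k)) for k in self.static_fields)
        return children, (tuple(dynamic), static)

    @classmethod
    def _unflatten(cls, aux_data, children):
        dynamic, static = aux_data
        obj = object.__new__(cls)  # bypass __init__: no re-initialization
        for k, v in zip(dynamic, children):
            setattr(obj, k, v)
        for k, v in static:
            setattr(obj, k, v)
        return obj


class ToyLinear(Module):
    static_fields = ("in_features", "out_features")

    def __init__(self, in_features, out_features, *, key):
        self.in_features = in_features
        self.out_features = out_features
        self.weights = jax.random.normal(key, (out_features, in_features))

    def __call__(self, x):
        return self.weights @ x


layer = ToyLinear(3, 1, key=jax.random.PRNGKey(0))
leaves = jax.tree_util.tree_leaves(layer)
print(len(leaves), leaves[0].shape)  # 1 (1, 3): only the weights are leaves

y = jax.jit(lambda m, x: m(x))(layer, jnp.ones(3))
print(y.shape)  # (1,)
```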

For a



In [None]:
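## training loop
# A sketch that assumes `layers.Linear` is registered as a pytree (the fix
# discussed above), so jax.value_and_grad and jax.tree_util work on it directly.
model = layers.Linear(d, 1, use_bias=False, key=key)

# hyperparameters
max_steps = 100
eps = 1e-4
learning_rate = 0.1  # assumed value, not tuned


def update(model, grads, lr=learning_rate):
    # one gradient-descent step applied to every dynamic leaf of the model
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, model, grads)


step = 0

# until we reach max_steps or the parameters change by less than <eps>
while step < max_steps:
    # loss value and gradients w.r.t. the model; named `loss_value` so the
    # `loss` function defined above is not shadowed
    loss_value, grads = jax.value_and_grad(loss)(model, xs_train, ys_train)
    new_model = update(model, grads)

    # largest mean absolute change across the model's parameter arrays
    change = max(
        float(jax.numpy.mean(jax.numpy.abs(a - b)))
        for a, b in zip(jax.tree_util.tree_leaves(model), jax.tree_util.tree_leaves(new_model))
    )
    model = new_model
    step += 1
    if change < eps:
        break

print(step, loss_value, loss(model, xs_test, ys_test))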