### A simple Neural Network, using toylib!
We repeat the exercise we did previously when learning jax, but this time we build `toylib` along the way!

In [1]:
import jax

from toylib.nn import layers

In [2]:
key = jax.random.PRNGKey(seed=10)



We set up a simple regression problem. The problem is not very meaningful to solve by itself using such a model, but it allows us to get all the pieces in place.

In [114]:
import numpy as np
# problem setup
n = 120  # examples
d = 10  # dimension

# dummy data
xs = np.random.normal(size=(n, d))
weights_true = np.random.randint(0, 10, size=(d,))
ys = np.dot(xs, weights_true) + np.random.normal(size=(n,))

xs_train, xs_test = xs[:100], xs[100:]
ys_train, ys_test = ys[:100], ys[100:]
print(weights_true)

[2 8 5 5 2 2 5 2 0 1]


Our dummy `ys` are a linear function of the inputs plus unit-variance Gaussian noise. Let's try to fit a model on the train set to predict the values in the test set.

In [115]:
# @jax.jit
def loss_function(model, xs, ys):
    preds = jax.numpy.squeeze(jax.vmap(model)(xs))
    return jax.numpy.mean((ys - preds) ** 2)  # mean squared error (L2 loss)

In [116]:
model = layers.Linear(d, 1, use_bias=False, key=key)

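We encounter another jax-specific error here. Transformations such as `jax.jit` and `jax.grad` expect their arguments to be pytrees (nested containers whose leaves are arrays), but the `Linear` class we defined is not registered as one, so jax cannot look inside it.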

We need to fix this.
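To see the failure in isolation, here is a minimal reproduction. `PlainLinear` is a hypothetical stand-in, not `toylib`'s `Linear`: because it is not registered as a pytree, jax treats the whole object as a single opaque leaf, and `jax.grad` rejects it as an invalid JAX type (a `TypeError` in the jax versions we are aware of).

```python
import jax
import jax.numpy as jnp


class PlainLinear:
    """Hypothetical layer that is NOT registered as a pytree."""

    def __init__(self, weights):
        self.weights = weights

    def __call__(self, x):
        return self.weights @ x


def loss(model, x, y):
    return jnp.mean((y - model(x)) ** 2)


model = PlainLinear(jnp.ones((1, 3)))
x, y = jnp.ones(3), jnp.array([1.0])

raised = False
try:
    jax.grad(loss)(model, x, y)
except TypeError as err:
    # jax cannot differentiate with respect to an arbitrary Python object
    raised = True
    print(err)
```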

We follow the strategy described in https://jax.readthedocs.io/en/latest/faq.html#strategy-3-making-customclass-a-pytree. There are other ways to achieve this, and altogether different designs are possible; we choose this one because it is simple and flexible enough to make progress.


We rely on the abstract class `Module` and register it as a pytree node class.

The class needs to distinguish between the `dynamic` elements (which need to interact with jax via `jit` & `grad`) vs the `static` elements:
- Hyperparameters (like layer sizes) will be static
- The actual weight arrays will be dynamic

For a single class, this is easy to do by hand: we explicitly define which attributes are flattened as children and which go into the aux data.
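A hand-written version for a single class might look like the sketch below. `ManualLinear` is a hypothetical name, not `toylib`'s implementation. It uses `jax.tree_util.register_pytree_node_class`, which expects a `tree_flatten` method returning `(children, aux_data)` and a `tree_unflatten` classmethod that rebuilds the object.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ManualLinear:
    """Hypothetical hand-registered layer: weights are dynamic, sizes are static."""

    def __init__(self, in_features, out_features, weights):
        self.in_features = in_features    # static hyperparameter
        self.out_features = out_features  # static hyperparameter
        self.weights = weights            # dynamic parameter array

    def __call__(self, x):
        return self.weights @ x

    def tree_flatten(self):
        children = (self.weights,)                        # seen by jit / grad
        aux_data = (self.in_features, self.out_features)  # should be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        in_features, out_features = aux_data
        (weights,) = children
        return cls(in_features, out_features, weights)


layer = ManualLinear(3, 1, jnp.ones((1, 3)))
leaves, treedef = jax.tree_util.tree_flatten(layer)
print(len(leaves))  # prints 1: only the weights array is a leaf
```

The cost of this approach is that every new layer repeats the same boilerplate, which is what the generic helpers are meant to remove.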

To make this more generally useful, we define some helper functions that build a general `pytree` class that jax understands.

We make some simplifying assumptions here:
- All jax or numpy arrays in the modules will be parameters
- Everything else is a hyperparameter to be treated as aux data.
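Under those two assumptions, a base class can partition its attributes automatically. The sketch below is a hypothetical `AutoModule`, not `toylib`'s actual `Module`: array-valued attributes become children, everything else becomes aux data, and `__init_subclass__` registers each subclass as a pytree node so subclasses need no decorator. One consequence of the "all arrays are parameters" rule is that any stored array (for example a PRNG key) would also be treated as a parameter.

```python
import jax
import jax.numpy as jnp
import numpy as np


class AutoModule:
    """Hypothetical base class: arrays are parameters, everything else is static."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register every concrete subclass as a pytree node.
        jax.tree_util.register_pytree_node_class(cls)

    def tree_flatten(self):
        attrs = vars(self)
        # Sort names so the flattening order is deterministic.
        dynamic_keys = tuple(
            sorted(k for k, v in attrs.items() if isinstance(v, (jax.Array, np.ndarray)))
        )
        children = tuple(attrs[k] for k in dynamic_keys)
        static = tuple(sorted((k, v) for k, v in attrs.items() if k not in dynamic_keys))
        return children, (dynamic_keys, static)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        dynamic_keys, static = aux_data
        obj = object.__new__(cls)  # skip __init__ and restore attributes directly
        for k, v in static:
            setattr(obj, k, v)
        for k, v in zip(dynamic_keys, children):
            setattr(obj, k, v)
        return obj


class TinyLinear(AutoModule):
    """Hypothetical layer built on AutoModule."""

    def __init__(self, in_features, out_features, key):
        self.in_features = in_features
        self.out_features = out_features
        self.weights = jax.random.normal(key, (out_features, in_features))

    def __call__(self, x):
        return self.weights @ x


layer = TinyLinear(3, 1, jax.random.PRNGKey(0))
children, aux = layer.tree_flatten()
print(aux[0])  # ('weights',)
```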

Some things are still unclear: what happens with nested modules?
We shall deal with these at a later point.


Now that we have a basic `Linear` module, we can run our first forward and backward pass using the `jax.value_and_grad` function.


In [117]:
loss, grads = jax.value_and_grad(loss_function)(model, xs, ys)

`grads` is now also an object of type `Linear`. This is because jax treats `Linear` objects as pytree nodes: it computes a gradient for each dynamic child of the node (here, `weights`) and reassembles them into a `Linear` with the same static aux data.
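The same structural mirroring holds for any pytree, including built-in containers. In this small sketch a plain dict (with hypothetical keys `"w"` and `"b"`) stands in for the model; the gradients come back as a dict with exactly the same structure.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Linear model with a bias term; parameters live in a dict pytree
    preds = x @ params["w"] + params["b"]
    return jnp.mean((y - preds) ** 2)


params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
x = jnp.array([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])
y = jnp.array([1.0, 2.0])

loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
print(loss)  # 2.5, the mean of the squared residuals [1, 4]
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))  # True
```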

Once we have the grads, we still need to update the original model parameters.

Here we simply use gradient descent: `theta_new = theta - alpha * grads`, where `alpha` is the learning rate.
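For the mean squared error `L(w) = mean((y - X w)^2)`, the gradient has the closed form `-(2/n) X^T (y - X w)`. The sketch below (random data, a plain weight vector rather than a `Linear` module) checks that `jax.grad` agrees with this formula and that one update step with `alpha = 0.1` lowers the loss.

```python
import jax
import jax.numpy as jnp
import numpy as np


def mse(w, X, y):
    return jnp.mean((y - X @ w) ** 2)


rng = np.random.default_rng(0)
X = jnp.asarray(rng.normal(size=(100, 10)), dtype=jnp.float32)
y = jnp.asarray(rng.normal(size=(100,)), dtype=jnp.float32)
w = jnp.zeros(10)

# Autodiff gradient vs the hand-derived formula -(2/n) X^T (y - X w)
g_auto = jax.grad(mse)(w, X, y)
g_closed = -2.0 / X.shape[0] * X.T @ (y - X @ w)

# One plain gradient-descent step: theta_new = theta - alpha * grad
alpha = 0.1
w_new = w - alpha * g_auto
print(mse(w, X, y), mse(w_new, X, y))  # the second value should be smaller
```

The step is guaranteed to decrease this quadratic loss only while `alpha` is small relative to the curvature of the problem; too large a learning rate makes gradient descent diverge.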

In [118]:
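def apply_update(model, grads, learning_rate):
    # leaf-wise theta - alpha * grad; static aux data is carried over unchanged
    # (jax.tree_map is deprecated in recent jax releases, so use jax.tree_util.tree_map)
    return jax.tree_util.tree_map(lambda x, y: x - learning_rate * y, model, grads)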
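Because `apply_update` only relies on `jax.tree_util.tree_map`, it works for any pair of pytrees with matching structure, not just a single `Linear`. A small sketch with a nested dict (hypothetical keys `"layer1"`, `"w"`, `"bias"`):

```python
import jax
import jax.numpy as jnp


def apply_update(model, grads, learning_rate):
    # Leaf-wise theta - alpha * grad over two pytrees with the same structure
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, model, grads)


params = {"layer1": {"w": jnp.array([1.0, 2.0])}, "bias": jnp.array(0.5)}
grads = {"layer1": {"w": jnp.array([10.0, -10.0])}, "bias": jnp.array(1.0)}

new_params = apply_update(params, grads, learning_rate=0.1)
print(new_params["layer1"]["w"])  # [0. 3.]
print(new_params["bias"])         # 0.4
```

If the two trees have different structures, `tree_map` raises an error instead of silently mixing up parameters, which is a useful safety net once modules become nested.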

In [119]:
weights_true

array([2, 8, 5, 5, 2, 2, 5, 2, 0, 1])

In [120]:
model.tree_flatten()

([('weights',
   Array([[ 0.12969103, -0.09365288, -0.2750842 , -0.30919993,  0.22844149,
            0.15211436, -0.11285052,  0.04825569, -0.07129356, -0.1104468 ]],      dtype=float32))],
 {'aux': {'in_features': 10,
   'out_features': 1,
   'use_bias': False,
   'key': Array([ 0, 10], dtype=uint32)},
  'dynamic_keys': ['weights']})

In [121]:
grads.weights

Array([[ -2.6850753, -21.38487  , -13.916252 ,  -8.798282 ,  -4.8938985,
         -4.650757 , -10.49443  ,  -5.7708864,  -1.0235255,  -3.9894276]],      dtype=float32)

In [122]:
## training loop
# the model initialized above provides the initial estimates

# hyperparameters
max_steps = 100
learning_rate = 0.1

step = 0

# run a fixed number of gradient-descent steps (no early stopping on loss change yet)
while step < max_steps:
    # compute the loss and its gradient w.r.t. the model parameters
    loss, grads = jax.value_and_grad(loss_function)(model, xs, ys)
    print(loss)
    print(model.weights)
    print(grads.weights)
    model = apply_update(model, grads, learning_rate)
    print(model.weights)

    step += 1

193.83684
[[ 0.12969103 -0.09365288 -0.2750842  -0.30919993  0.22844149  0.15211436
  -0.11285052  0.04825569 -0.07129356 -0.1104468 ]]
[[ -2.6850753 -21.38487   -13.916252   -8.798282   -4.8938985  -4.650757
  -10.49443    -5.7708864  -1.0235255  -3.9894276]]
[[0.39819857 2.0448341  1.1165411  0.5706282  0.7178314  0.61719006
  0.93659246 0.62534434 0.03105899 0.28849596]]
111.44427
[[0.39819857 2.0448341  1.1165411  0.5706282  0.7178314  0.61719006
  0.93659246 0.62534434 0.03105899 0.28849596]]
[[ -2.329467  -15.6821995 -10.336321   -7.4133377  -3.746536   -3.4573255
   -8.113352   -4.161183   -0.6989823  -2.833923 ]]
[[0.6311453  3.6130543  2.1501732  1.311962   1.092485   0.9629226
  1.7479277  1.0414627  0.10095722 0.57188827]]
64.9169
[[0.6311453  3.6130543  2.1501732  1.311962   1.092485   0.9629226
  1.7479277  1.0414627  0.10095722 0.57188827]]
[[ -2.011976  -11.516059   -7.666044   -6.204724   -2.8485315  -2.5722497
   -6.3181243  -2.9922438  -0.4556164  -1.9848006]]
[[0.832
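The loop above always runs for `max_steps` iterations. A natural refinement is to stop once the loss changes by less than a small tolerance `eps`, and to `jax.jit` the update step. The sketch below assumes a plain weight vector instead of a `toylib` module; `train_step`, `train`, and `eps` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_fn(w, X, y):
    return jnp.mean((y - X @ w) ** 2)


@jax.jit
def train_step(w, X, y, learning_rate):
    # One compiled gradient-descent step; returns the loss *before* the update
    loss, grad = jax.value_and_grad(loss_fn)(w, X, y)
    return w - learning_rate * grad, loss


def train(X, y, learning_rate=0.1, max_steps=1000, eps=1e-6):
    w = jnp.zeros(X.shape[1])
    prev_loss = float("inf")
    for step in range(max_steps):
        w, loss = train_step(w, X, y, learning_rate)
        # Stop once the loss changes by less than eps between steps
        if abs(prev_loss - float(loss)) < eps:
            break
        prev_loss = float(loss)
    return w, step


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 10)).astype(np.float32)
w_true = rng.integers(0, 10, size=10).astype(np.float32)
y = X @ w_true + rng.normal(size=100).astype(np.float32)

w_fit, steps = train(jnp.asarray(X), jnp.asarray(y))
print(steps, np.round(np.asarray(w_fit), 2))
```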

In [123]:
weights_true

array([2, 8, 5, 5, 2, 2, 5, 2, 0, 1])

In [124]:
model.weights

Array([[ 1.9710617 ,  8.022537  ,  4.9767237 ,  4.961315  ,  1.9318951 ,
         1.9906292 ,  5.10355   ,  2.0776217 , -0.01261421,  0.98264873]],      dtype=float32)
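The learned weights are close to `weights_true`. Note that the training loop above calls `loss_function` on the full `xs`, `ys`, so the last 20 rows were also seen during training; a fair test score needs a fit on the training split only. The sketch below does that on freshly generated data of the same shape, using `jnp.linalg.lstsq` as a closed-form least-squares reference instead of gradient descent, and reports train and test MSE.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(1)
n, d = 120, 10
X = rng.normal(size=(n, d)).astype(np.float32)
w_true = rng.integers(0, 10, size=d).astype(np.float32)
y = X @ w_true + rng.normal(size=n).astype(np.float32)

X_train, X_test = jnp.asarray(X[:100]), jnp.asarray(X[100:])
y_train, y_test = jnp.asarray(y[:100]), jnp.asarray(y[100:])

# Closed-form least-squares fit on the training split only
w_ols, *_ = jnp.linalg.lstsq(X_train, y_train)


def mse(w, X, y):
    return float(jnp.mean((y - X @ w) ** 2))


train_mse = mse(w_ols, X_train, y_train)
test_mse = mse(w_ols, X_test, y_test)
# With unit-variance noise, both should land roughly around 1.0
print(round(train_mse, 3), round(test_mse, 3))
```

Since the noise has variance 1, an MSE near 1.0 on held-out data is about the best any linear model can do here; a test MSE far above that would point to overfitting or a bug.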