### Defining the backbone of our toylib NN library!
As we embark on the journey of creating more complex models, it would be good to have some reusable pieces.

We start with the foundation of our NN library: an abstract base class, `Module`, from which all our other modules will inherit.

In [18]:
from abc import ABC


class Module(ABC):
    pass


In [60]:
# Our first Module is the `Linear` class, which implements a simple feedforward layer.
import math
import jax

from typing import Any, Optional
from jaxtyping import Array, PRNGKeyArray


class Linear(Module):
    """Defines a simple feedforward layer: which is a linear transformation."""

    weights: Array
    bias: Optional[Array]

    in_features: int
    out_features: int
    use_bias: bool

    def __init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: PRNGKeyArray) -> None:
        w_key, b_key = jax.random.split(key, 2)
        lim = 1 / math.sqrt(in_features)
        self.weights = jax.random.uniform(w_key, (out_features, in_features), minval=-lim, maxval=lim)
        if use_bias:
            self.bias = jax.random.uniform(b_key, (out_features,), minval=-lim, maxval=lim)
        else:
            # Keep the attribute defined so that `bias: Optional[Array]` always holds.
            self.bias = None

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias

    def __call__(self, x: Array) -> Array:
        x = jax.numpy.dot(self.weights, x)
        if self.use_bias:
            x = x + self.bias
        return x

The `jaxtyping` library provides array type annotations (such as `Array` and `PRNGKeyArray`) that complement Python's standard `typing` module.

In [61]:
key = jax.random.PRNGKey(seed=10)

In [62]:
layer1 = Linear(10, 2, key=key)

In [74]:
x = jax.random.uniform(key, (10,))
print(x)
print(x.shape)

[0.6799785  0.3947947  0.87870073 0.31470668 0.07135046 0.65858626
 0.44972312 0.4361874  0.809513   0.40791905]
(10,)


In [75]:
layer1(x)

Array([-0.74184865, -0.12828125], dtype=float32)

Et voilà! We have a simple linear layer defined. In a real network, we would apply the layer to a batch of examples at a time.

In [79]:
x_batch = jax.random.normal(key, (8, 10))  # batch_size = 8
layer1(x_batch)

TypeError: dot_general requires contracting dimensions to have the same shape, got (10,) and (8,).

We hit our first issue here: the forward pass is defined to work on a single input example, so `jax.numpy.dot` tries to contract the weights' `in_features` axis (size 10) with the batch axis (size 8).

That's where JAX's handy auto-vectorization function, `vmap`, comes in.

In [82]:
jax.vmap(layer1)(x_batch).shape  # works as expected

(8, 2)
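To see what `vmap` is doing here, the following sketch compares it with an explicit batched matrix multiply. It uses a hypothetical standalone function `forward` and freshly sampled parameters rather than the `Linear` instance above, so the numbers will not match the earlier cells.

```python
import jax
import jax.numpy as jnp


# Hypothetical single-example forward pass (a stand-in for Linear.__call__).
def forward(weights, bias, x):
    return jnp.dot(weights, x) + bias


key = jax.random.PRNGKey(0)
w_key, b_key, x_key = jax.random.split(key, 3)
weights = jax.random.normal(w_key, (2, 10))
bias = jax.random.normal(b_key, (2,))
x_batch = jax.random.normal(x_key, (8, 10))  # batch_size = 8

# in_axes=(None, None, 0): share the parameters across the batch and
# map only over axis 0 of the input.
batched_forward = jax.vmap(forward, in_axes=(None, None, 0))
out_vmap = batched_forward(weights, bias, x_batch)

# The same computation written by hand as a batched matmul.
out_matmul = x_batch @ weights.T + bias

print(out_vmap.shape)
print(jnp.allclose(out_vmap, out_matmul, atol=1e-5))
```

Writing the layer for a single example and letting `vmap` add the batch axis keeps the forward pass simple, at no cost to the batched result.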

At this point, we're ready to create a skeleton of our toy NN library: `toylib`!