# Getting started

To get started, we import `jax` and `pax`.

In [None]:
import jax
import jax.numpy as jnp
import pax

We start by defining a simple ``Linear`` module.

In [None]:
class Linear(pax.Module):
    """A Linear module has two real parameters, ``weight`` and ``bias``."""

    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self):
        super().__init__()  # the parent class constructor must be called.

        self.register_parameter('weight', jnp.array(1.0))
        self.register_parameter('bias', jnp.array(0.0))
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.weight * x + self.bias

net = Linear()
print(net.summary())

**Note**: we call the `self.register_parameter` method to register `weight` and `bias` as __trainable parameters__ of ``Linear``.

Next, we create a simple fake dataset and use our `Linear` module to fit the data. Note that ``a = -3.0`` and ``b = 1.5`` are the ground-truth weight and bias.

In [None]:
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (3,2)


pax.seed_rng_key(42) # seeding pax random key

def create_data(a=-3.0, b=1.5):
    x = jax.random.uniform(pax.next_rng_key(), (128, 1))
    noise = jax.random.normal(pax.next_rng_key(), x.shape) * 0.2
    y = a * x + b + noise
    plt.scatter(x, y)
    plt.grid(True)
    plt.legend(["data"])
    plt.show()
    return x, y

x, y = create_data()

We now plot the initial _predictions_ of our linear module.

In [None]:
def plot_prediction(net, data):
    x, y = data
    y_hat = net(x)

    plt.scatter(x, y)
    plt.scatter(x, y_hat)
    plt.legend(['data', 'prediction'])
    plt.xlabel('x')
    plt.ylabel('y')
    plt.grid(True)
    plt.show()

plot_prediction(net, (x, y))

To fit ``net`` with our data, we need to define a _mean squared error_ loss function which measures the current prediction errors. Our goal is to minimize the loss function using its gradient function and the _stochastic gradient descent_ algorithm.
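
As a sketch of what is being minimized, write $w$ for ``weight``, $b$ for ``bias``, $N$ for the number of samples, and $\eta$ for the learning rate. This notation is assumed here for illustration and is not defined elsewhere in the tutorial. For this one-dimensional model, the mean squared error and its gradients are:

$$
L(w, b) = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - (w x_i + b) \right)^2
$$

$$
\frac{\partial L}{\partial w} = -\frac{2}{N} \sum_{i=1}^{N} x_i \left( y_i - (w x_i + b) \right),
\qquad
\frac{\partial L}{\partial b} = -\frac{2}{N} \sum_{i=1}^{N} \left( y_i - (w x_i + b) \right)
$$

Each gradient descent step then moves the parameters against the gradient:

$$
w \leftarrow w - \eta \frac{\partial L}{\partial w},
\qquad
b \leftarrow b - \eta \frac{\partial L}{\partial b}
$$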

In [None]:
from typing import Tuple
Batch = Tuple[jnp.ndarray, jnp.ndarray]

def mse_loss(params: Linear, model: Linear, inputs: Batch):
    model = model.update(params)
    x, y = inputs
    y_hat = model(x)

    # mse
    loss = jnp.mean(jnp.square(y - y_hat))
    return loss, (loss, model)

gradient_fn = jax.grad(mse_loss, has_aux=True)

params = net.parameters()
grad, (loss, net) = gradient_fn(params, net, (x, y))
print(grad.weight, grad.bias, loss)

There are a few _interesting_ points in the ``mse_loss`` function:

1. ``params`` is a pytree of trainable parameters. 
2. ``model.update(params)`` returns a new model that uses ``params`` as its trainable parameters.
3. ``mse_loss`` returns ``(loss, model)`` as an auxiliary output.

These points are all related to the gradient transformation:

```python
gradient_fn = jax.grad(mse_loss, has_aux=True)
```

1. By default, ``jax.grad`` returns a function that computes gradients with respect to the first argument of ``mse_loss``. Therefore, we pass the trainable parameters ``params`` as a separate argument.
2. ``model = model.update(params)`` makes ``model`` use ``params`` in its forward computation, so the returned ``loss`` depends on ``params``.
3. ``has_aux=True`` tells ``jax.grad`` that ``mse_loss`` returns a pair: the first element is the value to differentiate, and the second element, ``(loss, model)``, is auxiliary data. The gradient function then returns this auxiliary data alongside the gradients.

__Note__: we return ``model`` in the output of ``mse_loss`` so that any changes made to ``model`` inside ``mse_loss`` are passed back to the caller.
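
To see the ``has_aux`` mechanics in isolation, here is a minimal sketch that replaces the pax module with a plain dict of parameters. The names ``loss_fn`` and ``grad_fn`` and the three data points are made up for illustration. The sketch demonstrates ``jax.grad`` only and says nothing about pax's own API.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, inputs):
    """MSE loss for y_hat = weight * x + bias; params is a plain dict pytree."""
    x, y = inputs
    y_hat = params["weight"] * x + params["bias"]
    loss = jnp.mean(jnp.square(y - y_hat))
    # The second element is auxiliary output: returned as-is, not differentiated.
    return loss, {"loss": loss}


# Gradients are taken with respect to the first argument (params).
grad_fn = jax.grad(loss_fn, has_aux=True)

params = {"weight": jnp.array(1.0), "bias": jnp.array(0.0)}
x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([1.0, 3.0, 5.0])  # noise-free samples of y = 2x + 1

# With has_aux=True the result is a pair: (gradients, auxiliary output).
grads, aux = grad_fn(params, (x, y))
print(grads["weight"], grads["bias"], aux["loss"])
```

The gradients come back as a dict with the same keys as ``params``. In the tutorial code, the gradient has the same structure as the parameters in the same way, which is why ``grad.weight`` and ``grad.bias`` are available there.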

In [None]:
def sgd(params: Linear, gradient: Linear, lr: float = 1e-1):
    # jax.tree_util.tree_map applies the update to every leaf of the params pytree.
    updated_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, gradient)
    return updated_params

The inputs to the ``sgd`` function are the trainable parameters ``params`` and the gradient of the loss function with respect to ``params``.
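
The leaf-wise update can be tried on a toy pytree. The sketch below uses ``jax.tree_util.tree_map`` on plain dicts of floats. ``sgd_step`` and the numbers are hypothetical, chosen only so the arithmetic is easy to check by hand.

```python
import jax


def sgd_step(params, grads, lr=0.1):
    """Return a new pytree with every leaf moved against its gradient."""
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


params = {"weight": 1.0, "bias": 0.0}
grads = {"weight": 2.0, "bias": -1.0}

new_params = sgd_step(params, grads)
# weight: 1.0 - 0.1 * 2.0, about 0.8; bias: 0.0 - 0.1 * (-1.0), about 0.1
print(new_params)
```

The update returns a new pytree and leaves ``params`` untouched, so the caller decides when to replace the old parameters.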

Apply ``sgd`` iteratively to further improve our model.

In [None]:
losses = []
for step in range(500):
    grad, (loss, net) = gradient_fn(params, net, (x, y))
    params = sgd(params, grad)
    losses.append(loss)


plt.plot(losses)
plt.xlabel('step')
plt.ylabel('loss')
plt.show()

Finally, let's plot our predictions.

In [None]:
net = net.update(params)
plot_prediction(net, (x, y))