<a href="https://colab.research.google.com/github/Eminent01/AMMI-RL/blob/main/%5BAMMI_2022%5D_Introduction_to_Haiku.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Haiku

In [1]:
#@title Installations  { form-width: "30%" }

!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[jax]
!pip install dm-acme[tf]
!pip install dm-acme[envs]
!pip install dm-env
!pip install dm-haiku
!pip install chex
!pip install imageio
!pip install gym

from IPython.display import clear_output
clear_output()

In [2]:
#@title Imports  { form-width: "30%" }
from typing import *
import haiku as hk
import jax
import jax.numpy as jnp
import chex
import matplotlib.pyplot as plt
%matplotlib inline



## A few reminders about jax

In the last practical session we learned the basics of JAX. Here are a few reminders of the JAX functions we will need here:
- `jax.grad`: computes the gradient of a function with respect to its first argument.
- `jax.vmap`: applies a function to all the elements along the first axis of a given group of tensors.
- `jax.jit`: compiles a given function to make it faster.

In [None]:
# Example of jax.grad
def square_sum_x(x : chex.Array) -> chex.Array:
  return jnp.sum(x**2)

# The expected gradient at x = [x0, x1, ..., xn] is [2*x0, 2*x1, ..., 2*xn]

print(jax.grad(square_sum_x)(jnp.asarray([1., 2., 3., 4.])))

# It also works with dictionaries!
def mul_sum(params : Mapping[str, chex.Array]) -> chex.Array:
  x = params['x']
  y = params['y']

  return jnp.sum(x*(y**2))

my_dict = {'x' : jnp.asarray([1., 2., 3., 4.]), 'y': jnp.asarray([1., 2., 3., 4.])}
print(jax.grad(mul_sum)(my_dict))

In [None]:
## Example of jax.vmap
def mini_conv(x : chex.Array, index : chex.Array) -> chex.Array:
  return x[index] + x[index+1]

x = jnp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]])
index = jnp.asarray([0, 1])

# It should return [x[0, index[0]] + x[0, index[0] + 1],  x[1, index[1]] + x[1, index[1] + 1]]
print(jax.vmap(mini_conv)(x, index))

In [None]:
## Example of jax.jit
import time
def polynom(x : chex.Array,
            y : chex.Array) -> chex.Array:

    a = x + 1
    b = y[:, 1]**2 
    c = x[0]
    return x**2 + c[None]*a+ y + x + y**3 + a*b[:, None]

n_iters = 1000
rng = jax.random.PRNGKey(0)

start_time = time.time()
for _ in range(n_iters):
  # Use distinct keys so that x and y are independent samples
  x_rng, y_rng, rng = jax.random.split(rng, 3)
  # block_until_ready() waits for JAX's asynchronous dispatch before we read the clock
  polynom(jax.random.normal(key=x_rng, shape=(128, 12)),
          jax.random.normal(key=y_rng, shape=(128, 12))).block_until_ready()
print(f"Running time without jitting: {time.time() - start_time}")

start_time = time.time()
jitted_polynom = jax.jit(polynom)
for _ in range(n_iters):
  x_rng, y_rng, rng = jax.random.split(rng, 3)
  # The first call also includes the compilation time
  jitted_polynom(jax.random.normal(key=x_rng, shape=(128, 12)),
                 jax.random.normal(key=y_rng, shape=(128, 12))).block_until_ready()
print(f"Running time with jitting: {time.time() - start_time}")

## Introduction to haiku

Haiku is to JAX what Sonnet was to TensorFlow. It allows you to simply and cleanly define (deep) neural network architectures of all kinds. If you know Sonnet, Keras, or the standard PyTorch network-building facilities, you won't be too surprised by how Haiku defines networks.

The main difference between Haiku and those other NN libraries is the transform mechanism. As mentioned earlier, JAX works with pure functions. However, the way most NN libraries build NN architectures is impure (in a functional sense): for instance, for each `torch.nn.Module`, PyTorch defines a `forward` function that takes the inputs of the network and outputs its final activations. This `forward` function is impure: it relies on the parameters of the network, which are attributes of the enclosing object and are not passed as arguments each time you call it. You need to know the state of the enclosing object to get the output of the function.
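To make the distinction concrete, here is a minimal sketch in plain JAX. The names `ImpureLinear` and `pure_linear` are made up for this illustration and are not part of PyTorch, JAX or Haiku.

```python
import jax
import jax.numpy as jnp

# Impure, object-style layer (illustrative only): the weights live on the
# object, so the output of __call__ depends on hidden state.
class ImpureLinear:
  def __init__(self, w, b):
    self.w = w
    self.b = b

  def __call__(self, x):
    return x @ self.w + self.b

# Pure version: everything the output depends on is an explicit argument.
def pure_linear(params, x):
  return x @ params['w'] + params['b']

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
x = jnp.ones((1, 3))

layer = ImpureLinear(params['w'], params['b'])
same_output = jnp.allclose(layer(x), pure_linear(params, x))

# Because params is an explicit argument, jax.grad can differentiate
# with respect to it directly.
grads = jax.grad(lambda p, inputs: pure_linear(p, inputs).sum())(params, x)
print(same_output, grads)
```

Both versions compute the same output, but only the pure one exposes the parameters as something `jax.grad` can differentiate with respect to.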

On the other hand, to compute the gradient of a function, JAX requires a pure function, i.e. a function that takes both the _inputs_ of the network and its _parameters_, and outputs the resulting activations. If you have such a function `f(params, inputs)`, it is straightforward to compute its gradient w.r.t. `params`, using `grad(f)(params, inputs)`. However, passing parameters around when defining network architectures would require a lot of boilerplate. To relieve users of those considerations, Haiku lets them describe the architecture of the network with impure functions that only pass inputs around, and then turn those impure functions into pure ones with `hk.transform`. Let's look at an example:


In [3]:
def easy_linear_net(x: chex.Array) -> chex.Array:
  out = hk.Linear(12)(x)
  return out

This defines a simple linear layer with 12 outputs. For PyTorch users, please note that with Haiku, neural networks are not *objects* but *functions*. With JAX and Haiku, **we will always try to handle functions and limit the use of objects**. This has many consequences for the way you define your network, but we won't dwell on them in this session.

Similar to Keras and Sonnet, Haiku does not require you to specify the size of your input when defining linear or convolutional layers. This function is impure, as `hk.Linear` implicitly defines some parameters. Calling it directly won't work: Haiku will raise an exception saying that you did not transform the function first. To be able to apply the module you have created, you first need to wrap it inside `hk.transform` in the following way:

In [None]:
ez_linear = hk.transform(easy_linear_net)

`hk.transform` takes a **function** as input and builds an `hk.Transformed` object, which will be our neural network. This object has two important methods:
- `init` which initializes the parameters of your network
- `apply` which, given input data and network parameters will compute the output of the network

In [None]:
# Initializing the network
# Please note that we need a random key for the initialization
ez_linear_params = ez_linear.init(rng=jax.random.PRNGKey(0), x=jnp.zeros((1, 6)))
print("Params", ez_linear_params)

# Applying the network to a given input
# We also need a random key here !
outputs = ez_linear.apply(params=ez_linear_params, rng=jax.random.PRNGKey(0), x=jnp.ones((1, 6)))
print("Outputs", outputs)

As you can see, the parameters are a dictionary of tensors. We also need to provide a random key for both initialization and inference.

**Question**: why do we need a random key to initialize the parameters of the network ?

---
The reason why we need a random key for inference is trickier. Some neural networks run random operations when they are applied, for example networks using random subsampling or dropout, and therefore need a random key. In this session, *we won't use such networks*, so dragging that rng key around is just a burden. For this reason we use `hk.without_apply_rng` to get rid of it.

In [None]:
ez_linear = hk.without_apply_rng(hk.transform(easy_linear_net))

outputs = ez_linear.apply(params=ez_linear_params, x=jnp.ones((1, 6)))
print(outputs)

From now on, we will **always** call `hk.without_apply_rng` when defining a new neural network.

**Exercise**:

Define a network function that takes a tensor `x` as input, applies to it a Linear layer of hidden dimension `5`, then a [ReLU](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.relu.html) activation function, and finally another Linear layer of hidden dimension `5`.

You will call your haiku network `my_haiku_net` and its parameters `my_haiku_params`.

Then initialize your function and test it on a ones tensor of shape (3, 12).

In [5]:
def my_haiku_net_func(x: chex.Array) -> chex.Array:
  ## Your code here !
  out = hk.Linear(5)(x)
  # Keep the result of the activation: jax.nn.relu does not modify `out` in place
  out = jax.nn.relu(out)
  out2 = hk.Linear(5)(out)
  return out2

my_haiku_net = hk.without_apply_rng(hk.transform(my_haiku_net_func))
my_haiku_params = my_haiku_net.init(rng=jax.random.PRNGKey(0), x=jnp.zeros((3, 12)))
print(my_haiku_params)

{'linear': {'w': DeviceArray([[ 0.00448234, -0.11408426, -0.23024954,  0.17270288,
               0.15407835],
             [-0.41392472,  0.14234851,  0.5522218 ,  0.43113503,
               0.18783994],
             [-0.36085713,  0.29878154, -0.03435664, -0.3096484 ,
               0.2447363 ],
             [-0.32117477,  0.2818347 ,  0.13971953,  0.3778855 ,
              -0.27359253],
             [-0.33711928, -0.29828262, -0.11816742,  0.39114642,
               0.08640154],
             [ 0.3216892 ,  0.1056358 , -0.02083999,  0.24832848,
              -0.14799379],
             [ 0.4858457 ,  0.03296737,  0.35295808, -0.28210652,
              -0.24311773],
             [-0.14897682, -0.16198616, -0.42050296,  0.07294818,
              -0.26604193],
             [-0.314112  ,  0.18078145,  0.30027768,  0.5027087 ,
               0.10974491],
             [ 0.14390136, -0.01611666, -0.14650457, -0.20843306,
              -0.25107452],
             [ 0.14964017, -0.1024896 , ... (output truncated)



We have our neural network; now we need a gradient. Indeed, in deep learning, we often update our network parameters to minimize a given loss, and we do so by **gradient descent**.

This loss has to be a function, and its first argument must be the network parameters. Here is an example:

In [6]:
def loss_fn(params : hk.Params, 
            net: hk.Transformed,
            input : chex.Array,
            ) ->chex.Array:
  return net.apply(params, input).sum()

**Question**: why is it important that the first argument of the loss is the network parameters ?

**Question**: what is the dimension of the output of the loss function ? Why is it important ?

Now let's try out this loss function on the network we defined before:

In [None]:
my_loss = loss_fn(ez_linear_params, ez_linear, jnp.ones((1, 6)))
print("Loss", my_loss)

my_grad = jax.grad(loss_fn)(ez_linear_params, ez_linear, jnp.ones((1, 6)))
print("Gradient", my_grad)

To get both the loss and the gradient with a single call, you can use [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html).

In [None]:
my_loss, my_grad = jax.value_and_grad(loss_fn)(ez_linear_params, ez_linear, jnp.ones((1, 6)))
print("Loss", my_loss)
print("Gradient", my_grad)

Finally we need to define a way to update our parameters from the loss. To do so, we define an *update function*: this function will compute the gradient of the loss and update the parameters.

In [None]:
def update_fn(params : hk.Params, 
              net: hk.Transformed,
              input: chex.Array,
              learning_rate: float,
              ) -> Tuple[chex.Array, hk.Params]:

    loss, grad = jax.value_and_grad(loss_fn)(params, net, input)
    next_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grad)

    return loss, next_params

my_loss, new_params = update_fn(ez_linear_params, ez_linear, jnp.ones((1, 6)), 0.1)
print("Loss", my_loss)
print("Old network parameters", ez_linear_params)
print("New network parameters", new_params)

Note that we have introduced a new function: [`jax.tree_util.tree_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html). Where `jax.vmap` maps a function over an axis of its array arguments, `tree_map` applies a function to every leaf of a list, a dictionary or any other nested structure, and returns a result with the same structure. See for yourself in this example:

In [None]:
a = jnp.ones((2,2))
grad_a = jnp.asarray([[0.1, 0.], [0.2, 0.3]])

b = jnp.zeros((4))
grad_b = jnp.asarray([0.1, 0., 0. , 0.])

dict_vals = {'a': a, 'b': b}
dict_grad = {'a': grad_a, 'b': grad_b}
print(jax.tree_util.tree_map(lambda p, g: p - g, dict_vals, dict_grad))
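Haiku parameters are nested dictionaries (module name, then parameter name), and `tree_map` handles that nesting as well. The sketch below uses a hand-written dictionary with made-up layer names rather than real Haiku parameters, and also shows `jax.tree_util.tree_leaves`, which flattens the structure into a list of arrays.

```python
import jax
import jax.numpy as jnp

# Hand-written stand-in for network parameters (layer names are hypothetical).
nested_params = {
    'layer_1': {'w': jnp.ones((6, 12)), 'b': jnp.zeros((12,))},
    'layer_2': {'w': jnp.ones((12, 1)), 'b': jnp.zeros((1,))},
}

# tree_map visits every leaf (array) and keeps the nesting intact.
halved = jax.tree_util.tree_map(lambda p: 0.5 * p, nested_params)

# tree_leaves returns the arrays as a flat list, handy for counting parameters.
num_params = sum(p.size for p in jax.tree_util.tree_leaves(nested_params))
print(num_params)  # 72 + 12 + 12 + 1 = 97
```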

**Exercise:**

Define a function `square_loss` that, given an input tensor `x`, a target tensor `target`, a network `net` and its corresponding parameters `params`, computes the $L_2$ norm of the difference between `net.apply(params, x)` and `target`.

Then, define an update function `my_update` which returns the loss and the new parameters.

Test this function with the network `my_haiku_net` you defined before and `x` a ones tensor of shape (3, 12).

In [None]:
def square_loss(params : hk.Params, 
                net: hk.Transformed,
                input : chex.Array,
                target : chex.Array,
                )-> chex.Array:
  ## Your code here
  return ...

#Define the update function
def my_update(params : hk.Params, 
              net: hk.Transformed,
              input : chex.Array,
              target : chex.Array,
              learning_rate: float,
              ) -> Tuple[chex.Array, hk.Params]:
  ### Your code here
  return ...

test_value, test_grad = ...
print("Value", test_value)
print("Gradient", test_grad)

Now let's implement a full linear regression. Complete the following code using the functions you defined before: `my_haiku_net`, `square_loss` and so on.

In [None]:
from typing import *

# Create random dataset (note that the fixed seed makes it deterministic)
inputs = jax.random.normal(key=jax.random.PRNGKey(0), shape=(128, 6))
outputs = 12 * jnp.concatenate([inputs, inputs], axis=-1) + 6 + jax.random.normal(key=jax.random.PRNGKey(0), shape=(128, 12))

# Learning params
learning_rate = 1e-2
num_iterations = 5000

# Init Network
# Use the network my_haiku_net you defined before
my_haiku_params = ...


for i in range(num_iterations):

  if i % 100 == 99:
    print(f'Loss at iteration {i}: {loss}')

  # Implement: Call the update function
  loss, my_haiku_params = ...
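As a point of comparison that leaves the Haiku exercise open, here is a sketch of the same gradient descent loop for a hand-written linear model in plain JAX. The model, the toy target `3 * sum(x) + 1`, the learning rate and the number of steps are all assumptions made for this illustration.

```python
import jax
import jax.numpy as jnp

def linear_model(params, x):
  return x @ params['w'] + params['b']

def mse_loss(params, x, y):
  # Mean squared error; params comes first so jax.grad differentiates w.r.t. it.
  return jnp.mean((linear_model(params, x) - y) ** 2)

def sgd_step(params, x, y, learning_rate):
  loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
  new_params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return loss, new_params

# Toy, noise-free dataset (hypothetical target function).
x = jax.random.normal(jax.random.PRNGKey(0), (128, 6))
y = 3.0 * x.sum(axis=1, keepdims=True) + 1.0

params = {'w': jnp.zeros((6, 1)), 'b': jnp.zeros((1,))}
initial_loss = mse_loss(params, x, y)
for _ in range(200):
  loss, params = sgd_step(params, x, y, 0.1)
final_loss = mse_loss(params, x, y)
print(initial_loss, final_loss)
```

The structure is the one the exercise asks for: a loss whose first argument is the parameters, an update function built on `jax.value_and_grad` and `tree_map`, and a loop that feeds the new parameters back in.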

Finally, we can make things faster with jitting. Note that trying to jit your update function directly won't work: the network, which is of type `hk.Transformed`, is not an acceptable argument for `jax.jit`. Let's give it a try with `ez_linear`.

In [None]:
def loss_fn(params : hk.Params, 
            net: hk.Transformed,
            input : chex.Array,
            ) ->chex.Array:
  return net.apply(params, input).sum()

def update_fn(params : hk.Params, 
              net: hk.Transformed,
              input: chex.Array,
              learning_rate: float,
              ) -> Tuple[chex.Array, hk.Params]:

    loss, grad = jax.value_and_grad(loss_fn)(params, net, input)
    next_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grad)

    return loss, next_params

my_loss, my_params = jax.jit(update_fn)(ez_linear_params, ez_linear, jnp.ones((1, 6)), 0.1)

As you can see, it does not work! To work around this, you need to give `jax.jit` a function that does not take the network as an argument. For example, you can close over the network with a lambda function:

In [None]:
update_fn_ez_linear = lambda params, x, lr : update_fn(params,ez_linear,x, lr)
my_loss, my_params = jax.jit(update_fn_ez_linear)(ez_linear_params, jnp.ones((1, 6)), 0.1)
print("Value", my_loss)
print("Parameters", my_params)
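An equivalent way to bind the network before jitting is `functools.partial`. The sketch below is written in plain JAX, with a hand-written `linear_apply` function standing in for the network's `apply` method; that stand-in and the argument name `apply_fn` are assumptions for this example.

```python
import functools

import jax
import jax.numpy as jnp

def linear_apply(params, x):
  return x @ params['w'] + params['b']

def loss_fn(params, apply_fn, x):
  return apply_fn(params, x).sum()

def update_fn(params, apply_fn, x, learning_rate):
  loss, grads = jax.value_and_grad(loss_fn)(params, apply_fn, x)
  new_params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return loss, new_params

# Bind the non-array argument first, then jit what remains.
jitted_update = jax.jit(functools.partial(update_fn, apply_fn=linear_apply))

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
# apply_fn was bound by keyword, so the remaining arguments after params
# are passed by keyword as well.
loss, new_params = jitted_update(params, x=jnp.ones((1, 3)), learning_rate=0.1)
print(loss, new_params)
```

Either way, the function handed to `jax.jit` only receives arrays and dictionaries of arrays.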

It's working ! Now redefine the training loop you made before with a **jitted** version of the `my_update` function you used.

In [None]:
from typing import *

# Create random dataset (note that the fixed seed makes it deterministic)
inputs = jax.random.normal(key=jax.random.PRNGKey(0), shape=(128, 6))
outputs = 12 * jnp.concatenate([inputs, inputs], axis=-1) + 6 + jax.random.normal(key=jax.random.PRNGKey(0), shape=(128, 12))

# Learning params
learning_rate = 1e-2
num_iterations = 5000

# Init Network
# Use the network my_haiku_net you defined before
my_haiku_params = ...

#Define the update function
def my_update(params : hk.Params, 
              net: hk.Transformed,
              input : chex.Array,
              target : chex.Array,
              learning_rate: float,) -> Tuple[chex.Array, hk.Params]:
  ### Your code here
  return ...

my_jitted_update = ...

for i in range(num_iterations):

  if i % 100 == 99:
    print(f'Loss at iteration {i}: {loss}')

  # Implement: Call the update function
  loss, my_haiku_params = ...

## Application

Now let's apply everything you learned to a full scale exercise. You are going to train a *classifier network*. We start by building a dataset which associates to each input tensor `x` a label `y` which will be either $0$ or $1$.

In [None]:
rng = jax.random.PRNGKey(0)
def build_input_output(rng):
  # Use separate keys so that the noise is independent of the inputs
  input_rng, noise_rng = jax.random.split(rng)
  inputs = jax.random.normal(key=input_rng, shape=(8, 36))
  noisy_inputs = 12 * inputs + jax.random.normal(key=noise_rng, shape=inputs.shape)
  outputs = noisy_inputs.sum(axis=1) > 6
  return inputs, outputs.astype(int)[:, None]

sample_input, sample_output = build_input_output(rng)
print("Sample input", sample_input)
print("Sample output", sample_output)

Your network will take an input vector $x$ and network parameters $\theta$ and will output a real number $f(x, \theta)$. 

From this number, we will make the following prediction:
- if $f(x, \theta) \geq 0$ then the label associated to $x$ must be $1$
- else it should be $0$.

\\

To do so, we want to define a loss $L$ so that, for a given batch $(x, y)$ of input data and labels:
- when $y_i = 1$, we want $f(x_i, \theta)$ to be as big as possible, so increasing it should make $L$ smaller.
- when $y_i = 0$, we want $f(x_i, \theta)$ to be as small as possible, so decreasing it should make $L$ smaller.

The following function would do the job:

$ L(x, y, \theta) = \sum_{i} f(x_i, \theta) (1 - 2y_i) $

However, as it stands, training won't be stable: $L$ doesn't have a minimum, so gradient descent would push $L$ down to $-\infty$. To avoid that, we add a **regularization term** to $L$:

$ L(x, y, \theta) = \sum_{i} f(x_i, \theta) (1 - 2y_i) + 0.01 \sum_{i} f(x_i, \theta)^2 $

With this term, we are now sure that $L$ actually has a minimum.
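To see where that minimum lies, consider a single sample and treat the network output $f$ as a free real number (an assumption made only for this check). The per-sample term $f (1 - 2y) + 0.01 f^2$ is a parabola in $f$; setting its derivative $(1 - 2y) + 0.02 f$ to zero gives $f^* = 50 (2y - 1)$, i.e. $+50$ when $y = 1$ and $-50$ when $y = 0$. The helper names below are made up for this sketch.

```python
import jax
import jax.numpy as jnp

def per_sample_loss(f, y):
  # One term of the regularized loss, as a function of the network output f.
  return f * (1 - 2 * y) + 0.01 * f ** 2

def minimizer(y):
  # Closed-form solution of d(loss)/df = (1 - 2y) + 0.02 f = 0.
  return 50.0 * (2 * y - 1)

grad_wrt_f = jax.grad(per_sample_loss)  # derivative w.r.t. the first argument, f

grad_at_min_y1 = grad_wrt_f(minimizer(1.0), 1.0)
grad_at_min_y0 = grad_wrt_f(minimizer(0.0), 0.0)
loss_at_min_y1 = per_sample_loss(minimizer(1.0), 1.0)
print(grad_at_min_y1, grad_at_min_y0, loss_at_min_y1)
```

Both gradients vanish (up to floating-point error), and the minimal per-sample loss is $-25$, so the regularized loss is bounded below.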

\\

Your function $f$ will be a neural network which applies to $x$, in the following order:
- a linear layer with output dimension 12
- a ReLU
- another linear layer with output dimension 12
- a ReLU
- a linear layer with output dimension 1

Now you can complete the following code !

In [None]:
# Initialize an rng key
rng = jax.random.PRNGKey(0)

def my_net_func(x : chex.Array):
  ## Your code here
  return ...


## Initialize your network and its parameters here
## Use sample from the dataset for the initialization
sample_input, sample_output = build_input_output(rng)
my_net = ...
my_params = ...

## Now define your loss
def my_loss(params, net, x, labels):
  ## Your code here
  return ...

## And your update function
def my_update_function(params, net, x, labels):
  ## Your code here
  return ...

## We will use an accuracy function to see how fast the network is learning
def my_accuracy(params, net, x, labels):
  predictions = net.apply(params, x)
  # Predict label 1 when f(x, theta) >= 0, and 0 otherwise
  predicted_labels = (predictions >= 0).astype(int)
  return jnp.mean(predicted_labels == labels)

############################
# Training loop
############################

# Learning params
learning_rate = 1e-2
num_iterations = 1000
accuracy = 0

# Jit your update function
my_jitted_update = ...

for i in range(num_iterations):

  # Reminder: jax is pure, so we need a way to update the rng seed to get
  # a truly random sequence
  cur_rng, rng = jax.random.split(rng)

  input, labels = build_input_output(cur_rng)

  # Implement: Call the update function
  loss, my_params = ...

  accuracy = my_accuracy(my_params, my_net, input, labels)

  if i % 100 == 0:
    print(f'Accuracy at iteration {i}: {accuracy}')
    print(f'Loss at iteration {i}: {loss}')
