# Licence


In [None]:
#Licence
#Copyright 2021 Google LLC.
#SPDX-License-Identifier: Apache-2.0

# JAX Tutorial
This first practical aims at getting familiar with some of the tools we will be using in the next practicals.

## Installation

In [2]:
#@title Installations  { form-width: "30%" }

%pip install git+https://github.com/deepmind/acme.git#egg=dm-acme[jax,tf,envs]

from IPython.display import clear_output
clear_output()

In [3]:
#@title Imports  { form-width: "30%" }

from typing import *
import IPython

import base64
import chex
import collections
from collections import namedtuple
import dm_env
import enum
import functools
import gym
import haiku as hk
import io
import itertools
import jax
from jax import tree_util
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import reverb
import rlax
import time
from bsuite import environments
import bsuite.environments.catch as dm_catch

import warnings

np.set_printoptions(precision=3, suppress=1)

%matplotlib inline

# Introduction to JAX

To implement our RL algorithms, we will use neural networks as function approximators and train them with gradient-based optimizers. To make that easy, we will use <a href="https://github.com/google/jax">JAX</a>, which provides easy gradient computations with a NumPy-like flavor, and <a href="https://github.com/deepmind/dm-haiku">Haiku</a>, to easily define our neural network architectures. If you are familiar with other frameworks, JAX/Haiku roughly correspond to TensorFlow/Keras or PyTorch/`torch.nn`.

For a more detailed tutorial, see: https://jax.readthedocs.io/en/latest/jax-101/index.html

### 1. JAX Basics

JAX is a numerical computation library, very close to NumPy in its basic usage, that lets you easily execute operations on GPU and gives access to <a href="https://github.com/google/jax#transformations">numerical function transformations</a>, which allow for gradient computation, automatic vectorization, and jitting (just-in-time compilation).

JAX has a _functional_ flavor: to use numerical function transformations, you have to define _pure_ functions, i.e. mathematical functions whose output depends only on their inputs (and not on the context in which they are called), and which have no side effects.

For instance, the following function is pure:

In [None]:
def pure_function(x: chex.Array) -> chex.Array:
  return 3 * x + jnp.tanh(2 * x) / (x ** 2 + 1)

The following method is not pure, since its output depends on (and modifies) the hidden attribute `self._i`:

In [None]:
class Counter:
  def __init__(self) -> None:
    self._i = 0.

  def unpure_function(self, x: chex.Array) -> chex.Array:
    self._i = self._i + 1.
    return self._i * x + jnp.tanh(x)
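
To see the impurity concretely, call the method twice with the same input: the output changes because the method reads and mutates `self._i`. A minimal sketch, redefining the `Counter` class from the cell above so that it runs on its own:

```python
import jax.numpy as jnp

class Counter:
  def __init__(self) -> None:
    self._i = 0.

  def unpure_function(self, x):
    self._i = self._i + 1.  # Hidden state, mutated on every call.
    return self._i * x + jnp.tanh(x)

counter = Counter()
x = jnp.asarray(0.5)
first = counter.unpure_function(x)   # Computed with self._i == 1.
second = counter.unpure_function(x)  # Same input, but now self._i == 2.
print(first, second)
```

Same input, two different outputs: this is exactly what JAX transformations cannot reason about.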

Given a pure function, you can easily obtain the associated gradient function:

In [None]:
grad_pure = jax.grad(pure_function)
x = 3.
print(f'Value at point x={x}, f(x)={pure_function(x)}, grad_f(x)={grad_pure(x)}')

In addition to `jax.grad`, JAX provides `jax.vmap` for automatic vectorization, `jax.jit` for just-in-time compilation (to make full use of specialized hardware), and `jax.pmap` to automatically distribute functions across devices.
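
Each of these transformations takes a function and returns a new function. A small sketch, redefining `pure_function` from above so the cell is self-contained: jitting does not change the result, and `jax.pmap` maps over a leading axis whose size must equal the number of local devices (`jax.local_device_count()`, which is 1 on a plain CPU runtime).

```python
import jax
import jax.numpy as jnp

def pure_function(x):
  return 3 * x + jnp.tanh(2 * x) / (x ** 2 + 1)

# jax.jit returns a compiled version of the same function.
jitted = jax.jit(pure_function)
print(pure_function(3.), jitted(3.))

# jax.pmap runs one copy of the function per device; the leading axis
# of the input must have size jax.local_device_count().
n_devices = jax.local_device_count()
xs = jnp.arange(n_devices, dtype=jnp.float32)
print(jax.pmap(pure_function)(xs))
```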

For instance, to get a batched version of matrix multiplication, you can write the usual matrix multiplication and directly `vmap` it:

In [None]:
batch_matrix_multiply = jax.vmap(lambda a, b: a @ b)
rng = jax.random.PRNGKey(0)
rng_a, rng_b = jax.random.split(rng)
a = jax.random.normal(key=rng_a, shape=(12, 5, 7))
b = jax.random.normal(key=rng_b, shape=(12, 7, 9))
print(batch_matrix_multiply(a, b).shape)

This example illustrates one of the differences between JAX NumPy and NumPy. NumPy handles random seeds _implicitly_: when you want a random number, you simply call one of NumPy's functions, and the number depends on NumPy's global seed. With JAX, random seeds are handled _explicitly_, and each function that needs to generate random numbers takes a random key as an additional input. This follows from JAX's functional paradigm: if the random key were not handled explicitly, each call to a random function would return a different result, breaking the pure function assumption. By passing the random key explicitly, we ensure that the same random function, called with the same random key, always produces the same result. As a side benefit, this also makes results produced with JAX easily reproducible, as it is easy to trace which random keys have been used where. To learn more about how JAX handles randomness, you can read <a href="https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#rngs-and-state">this page</a>.
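
A short sketch of these rules: the same key always produces the same numbers, and `jax.random.split` derives fresh keys whenever you need new randomness.

```python
import jax

key = jax.random.PRNGKey(42)

# Same key, same call: identical samples (there is no hidden global state).
a1 = jax.random.normal(key, shape=(3,))
a2 = jax.random.normal(key, shape=(3,))

# To get new random numbers, split the key and use the new subkey.
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, shape=(3,))

print(a1, a2, b)
```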

### ***Exercises***

As a first exercise, consider the following function:

In [None]:
def func(x: chex.Array) -> chex.Array:
  return x ** 2

which simply computes the square of an array. Using simple JAX transformations, can you build a function that takes a batch of scalars and outputs the gradient of the square function for each element of the batch?

**Hint:** `jax.grad` only accepts functions that output a single scalar, so calling `jax.grad` directly on `func` and applying it to a vector won't work.

Can you make this function run faster?

In [None]:
#@title **[Implement]** Batched gradients { form-width: "30%" }
batched_grad = None
fast_batched_grad = None

In [None]:
#@title **[Solution]** Batched gradients { form-width: "30%" }
solution_batched_grad = jax.vmap(jax.grad(func))
jitted_solution = jax.jit(solution_batched_grad)

You can test your solution by running the cell below.

In [None]:
#@title **[Test]** Batched gradients (Uncomment to run){ form-width: "30%" }
key = jax.random.PRNGKey(0)
normal = jax.random.normal(key=key, shape=(3,))
if (fast_batched_grad(normal) == 2 * normal).all():
  print('Probably correct.')
else:
  print('Probably incorrect.')

Can you do the same for a batch of batches, without flattening your input? (i.e. you have a matrix of numbers, and you want a matrix containing the gradient for each of the numbers in the matrix.)

In [None]:
#@title **[Implement]** Matrix gradients { form-width: "30%" }
fast_matrix_grad = None

In [None]:
#@title **[Solution]** Matrix gradients { form-width: "30%" }
jitted_solution_matrix = jax.jit(jax.vmap(jitted_solution))

In [None]:
#@title **[Test]** Matrix gradients (Uncomment to run){ form-width: "30%" }
key = jax.random.PRNGKey(0)
normal = jax.random.normal(key=key, shape=(3, 3,))
if (fast_matrix_grad(normal) == 2 * normal).all():
  print('Probably correct.')
else:
  print('Probably incorrect.')

Another very useful application of `vmap` is batched indexing. Assume you have a `[B1, B2, ..., BN]` tensor of indices `idx` and a `[B1, B2, ..., BN, F]` tensor of features `features`, and for each multi-index `i1, ..., iN` you would like to retrieve the element `features[i1, ..., iN, idx[i1, ..., iN]]` from the feature tensor. Can you do this easily using `vmap`? (Maybe start with a fixed `N`, then generalize to any `N`.)

In [None]:
#@title **[Implement]** Batched indexing { form-width: "30%" }
def batched_indexing(idxs: chex.Array, features: chex.Array) -> chex.Array:
  ##### IMPLEMENT #####
  pass

In [None]:
#@title **[Solution]** Batched indexing { form-width: "30%" }
@jax.jit
def solution_batched_indexing(idxs: chex.Array, features: chex.Array) -> chex.Array:
  def simple_index(idx, feature):
    return feature[idx]
  batched_index = simple_index
  for _ in range(idxs.ndim):
    batched_index = jax.vmap(batched_index)
  return batched_index(idxs, features)
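
For comparison, here is a sketch of an alternative that does not use `vmap` at all: `jnp.take_along_axis` gathers values along one axis using an index array with the same number of dimensions as the source, so adding a trailing axis to the indices (and removing it afterwards) handles any `N`. The function name `take_along_indexing` is ours.

```python
import jax.numpy as jnp

def take_along_indexing(idxs, features):
  # idxs: [B1, ..., BN] integers, features: [B1, ..., BN, F].
  gathered = jnp.take_along_axis(features, idxs[..., None], axis=-1)
  return gathered[..., 0]  # Drop the trailing singleton axis.

example_features = jnp.arange(12.).reshape(2, 2, 3)
example_idxs = jnp.array([[0, 2], [1, 0]])
print(take_along_indexing(example_idxs, example_features))  # [[0. 5.] [7. 9.]]
```

The `vmap` version is more general (any per-element function can be batched this way), while `take_along_axis` is specific to gathering.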

In [None]:
#@title **[Test]** Batched indexing (Uncomment to run){ form-width: "30%" }
inputs = jnp.array([[-0.196,  0.255,  0.573,  0.441, -0.847,  0.318,  0.646],
 [ 0.034, -0.889, -0.266, -1.561, -0.638, -0.442,  0.91 ],
 [-0.017,  0.758,  1.089,  0.299,  1.491,  0.079, -1.222],
 [ 0.952,  0.21,   1.386, -0.338,  2.952, -0.995, -0.516],
 [ 0.292, -0.143,  1.614,  1.643,  0.114,  0.254, -1.306],])
outputs = jnp.array([ 0.255, -1.561,  0.079,  1.386, -1.306])
idxs = jnp.array([1, 3, 5, 2, 6], dtype=jnp.int32)
if (batched_indexing(idxs, inputs) == outputs).all():
  print('Probably correct.')
else:
  print('Probably incorrect.')

### 2. Let's talk about gradients

Manipulating gradients is at the core of deep learning, so you will use them a lot in this class.

In the previous section, you used the function [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which evaluates the gradient of a function. By default, `jax.grad` differentiates only with respect to the first argument. For example, have a look at the following:

In [None]:
def my_func(x: chex.Array, y: chex.Array) -> chex.Array:
  return (x + y*y).sum()

grad_my_func = jax.grad(my_func)

# The gradient of this function with respect to x is a vector with the same
# shape as x but filled with ones.

test_x = jnp.asarray([1., 1.])
test_y = jnp.asarray([2., 2.])
print(grad_my_func(test_x, test_y))
print(grad_my_func(test_x, 2*test_y))

However, in deep learning we rarely want to compute the gradient with respect to a single vector only. Fortunately, JAX gives us several ways to do that.

For example, the `argnums` argument tells JAX to compute the gradient with respect to the argument at the given position. Let's try it with our function:

In [None]:
def my_func(x: chex.Array, y: chex.Array) -> chex.Array:
  return (x + y*y).sum()

# Now we are computing the gradient with respect to the second argument y
grad_my_func = jax.grad(my_func, argnums=1)

# This gradient is equal to 2*y

test_x = jnp.asarray([1., 1.])
test_y = jnp.asarray([2., 2.])
print(grad_my_func(test_x, test_y))
print(grad_my_func(test_x, 2*test_y))

You can also give `argnums` a tuple of integers instead of a single value; in this case, the function returns one gradient for each index you gave:

In [None]:
def my_func(x: chex.Array, y: chex.Array) -> chex.Array:
  return (x + y*y).sum()

# Now we are computing the gradient with respect to both arguments
grad_my_func = jax.grad(my_func, argnums=(0, 1))

# This gradient is equal to (1, 2*y)

test_x = jnp.asarray([1., 1.])
test_y = jnp.asarray([2., 2.])

# Notice that the function now outputs two values
print(grad_my_func(test_x, test_y))
print(grad_my_func(test_x, 2*test_y))

***Exercise***

Let's consider a function $f$ which, given a set of parameters $(a, b, \theta) \in \mathbb{R}^N \times \mathbb{R}^N \times \mathbb{R}^M$, computes the value:

$$f(a, b, \theta) = \sum_{i=1}^N a_i b_i + \sum_{j=1}^M \theta_j$$

Write a JAX function that outputs the pair $(\nabla f_a, \nabla f_\theta)$.

In [None]:
# Your code here !

## Test it with the following values
test_a = jnp.asarray([1., 1.])
test_b = jnp.asarray([2., 2.])
test_theta = jnp.asarray([2., 2., 2.])
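
One possible solution, as a sketch: the gradient with respect to $a$ is $b$ and the gradient with respect to $\theta$ is a vector of ones, and `argnums=(0, 2)` selects exactly those two arguments. The names `f` and `grad_f` are ours.

```python
import jax
import jax.numpy as jnp

def f(a, b, theta):
  return jnp.sum(a * b) + jnp.sum(theta)

# argnums=(0, 2) differentiates with respect to a and theta, skipping b.
grad_f = jax.grad(f, argnums=(0, 2))

test_a = jnp.asarray([1., 1.])
test_b = jnp.asarray([2., 2.])
test_theta = jnp.asarray([2., 2., 2.])

grad_a, grad_theta = grad_f(test_a, test_b, test_theta)
print(grad_a)      # Equal to b: [2. 2.]
print(grad_theta)  # A vector of ones: [1. 1. 1.]
```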

This is better, but in deep learning we usually don't handle just a pair of vectors: we deal with hundreds of them (and their gradients) at the same time. So how are we going to pass all these gradients around? It turns out that JAX also works with dictionaries of parameters, or any nested structure of dicts, tuples or lists (which JAX calls *pytrees*).

In [None]:
def my_fun_with_lots_of_params(params : Mapping[str, chex.Array]) -> float:
  x = params['x']
  y = params['y']
  z = params['z']
  return  (x**2 + 2*y + x*z).sum()

# Now let's compute the gradient
grad_my_func = jax.grad(my_fun_with_lots_of_params)

test_x = jnp.asarray([1., 1.])
test_y = jnp.asarray([2., 2.])
test_z = jnp.asarray([3., 3.])

params_dict = {'x' : test_x, 'y': test_y, 'z' : test_z}

# Now my gradient outputs a dictionary of values
print(grad_my_func(params_dict))
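
Because the gradient has exactly the same nested structure as the parameters, a gradient-descent step can be applied leaf by leaf with `jax.tree_util.tree_map`. A minimal sketch (the learning rate of 0.1 and the name `example_params` are arbitrary choices of ours):

```python
import jax
import jax.numpy as jnp

def my_fun_with_lots_of_params(params):
  x, y, z = params['x'], params['y'], params['z']
  return (x**2 + 2*y + x*z).sum()

example_params = {'x': jnp.asarray([1., 1.]),
                  'y': jnp.asarray([2., 2.]),
                  'z': jnp.asarray([3., 3.])}
grads = jax.grad(my_fun_with_lots_of_params)(example_params)

# tree_map walks both dicts in parallel and updates each array with its gradient.
learning_rate = 0.1
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, example_params, grads)
print(new_params)
```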

***Exercise***

Compute the gradients $(\nabla f_a, \nabla f_\theta)$ where $f$ is the function defined in the previous exercise but this time use a dictionary containing the parameters $(a, \theta)$.

In [None]:
### Your code here

## Test it with the following values
test_a = jnp.asarray([1., 1.])
test_b = jnp.asarray([2., 2.])
test_theta = jnp.asarray([2., 2., 2.])
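
One possible solution, as a sketch: put $a$ and $\theta$ in a dict and pass $b$ as a separate argument, so that `jax.grad` (which differentiates with respect to the first argument by default) returns a dict of gradients. The names `f_from_dict` and `dict_params` are ours.

```python
import jax
import jax.numpy as jnp

def f_from_dict(params, b):
  # Only a and theta live in the dict, so only they get gradients.
  return jnp.sum(params['a'] * b) + jnp.sum(params['theta'])

dict_params = {'a': jnp.asarray([1., 1.]), 'theta': jnp.asarray([2., 2., 2.])}
test_b = jnp.asarray([2., 2.])

dict_grads = jax.grad(f_from_dict)(dict_params, test_b)
print(dict_grads['a'], dict_grads['theta'])
```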

### 3. Just in time compilation (jit)


JAX allows you to compile portions of your Python code as you run it. This process, called just-in-time compilation or [jitting](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html), can substantially accelerate your code if used properly.

In [9]:
# Example of jitting
import time

def add_and_square(x: chex.Array, y: chex.Array) -> chex.Array:
  out = x**2
  out = out + y * 12
  out = out**3
  return out

jitted_add = jax.jit(add_and_square)

n_runs = 1000
x = jnp.zeros((3, 14), dtype=float)
y = 2 + jnp.zeros((3, 14), dtype=float)

# JAX dispatches computations asynchronously: block_until_ready() makes sure
# we time the computation itself, not just its dispatch.
t_start = time.time()
for _ in range(n_runs):
  add_and_square(x, y).block_until_ready()
duration = time.time() - t_start
print(f"Without jitting, ran {n_runs} in {duration} seconds.")

# Warm-up call, so that the one-off compilation time is not measured.
jitted_add(x, y).block_until_ready()
t_start = time.time()
for _ in range(n_runs):
  jitted_add(x, y).block_until_ready()
duration = time.time() - t_start
print(f"With jitting, ran {n_runs} in {duration} seconds.")


Without jitting, ran 1000 in 1.1659579277038574 seconds.
With jitting, ran 1000 in 0.08092808723449707 seconds.


Jitting itself takes time, so you should follow these good practices when using `jax.jit`:
- Never jit inside a loop: jit your function ahead of time or use the `@jax.jit` decorator.
- Try to jit chunks of code that are as big and complex as possible.

In [6]:
# Do not do
def add(x: chex.Array, y: chex.Array) -> chex.Array:
  out = x + y
  return out

def square(x : chex.Array) -> chex.Array:
  out = x**2
  return out

def log(x : chex.Array) -> chex.Array:
  out = jnp.log(x)
  return out

n_runs = 100
x = jnp.zeros((3,14), dtype=float)
y = 2 + jnp.zeros((3,14), dtype=float)
t_start = time.time()
for _ in range(n_runs):
  jit_add = jax.jit(add)
  jit_square = jax.jit(square)
  jit_log = jax.jit(log)
  # block_until_ready() waits for the asynchronous computation to finish.
  jit_log(jit_square(jit_add(x, y))).block_until_ready()
duration = time.time() - t_start
print(f"Poorly used jitting: {n_runs} runs in {duration} seconds.")

# But do
def add_square_log(x: chex.Array, y: chex.Array) -> chex.Array:
  out = x + y
  out = out**2
  return jnp.log(out)

jit_add_square_log = jax.jit(add_square_log)
t_start = time.time()
for _ in range(n_runs):
  jit_add_square_log(x, y).block_until_ready()
duration = time.time() - t_start
print(f"Proper jitting: {n_runs} runs in {duration} seconds.")

Poorly used jitting: 100 runs in 0.3196744918823242 seconds.
Proper jitting: 100 runs in 0.05981945991516113 seconds.


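Finally, remember that `jax.jit` compiles a function by **tracing** it: it runs your Python code with abstract placeholder values that only carry the shape and dtype of the inputs, not their actual values. This means that Python control flow, such as an `if`, whose condition depends on the *value* of an input cannot be evaluated while tracing, and `jax.jit` raises an error. For example, have a look at the following code: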

In [None]:
# This raises an error: the `if` condition depends on the value of a traced array.
def my_func(x: chex.Array):
  if jnp.sum(x) > 0:
    return x
  else:
    return 2 * x

jitted_my_func = jax.jit(my_func)
jitted_my_func(jnp.zeros((12, 3)))
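
When the branch really must depend on the value of an array, one option (not otherwise covered in this notebook) is JAX's traceable control flow: `jax.lax.cond` picks one of two functions based on a traced boolean, and `jnp.where` computes both branches and selects the result. A sketch of `my_func` written both ways (the function names are ours):

```python
import jax
import jax.numpy as jnp

@jax.jit
def my_func_cond(x):
  # lax.cond(pred, true_fun, false_fun, operand) accepts a traced predicate.
  return jax.lax.cond(jnp.sum(x) > 0, lambda v: v, lambda v: 2 * v, x)

@jax.jit
def my_func_where(x):
  # jnp.where evaluates both branches, then selects one with the predicate.
  return jnp.where(jnp.sum(x) > 0, x, 2 * x)

print(my_func_cond(jnp.ones((2, 3))))    # Sum > 0: returns x
print(my_func_where(-jnp.ones((2, 3))))  # Sum <= 0: returns 2 * x
```

`jnp.where` is simple but always pays for both branches; `lax.cond` only executes the selected one.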

But what if you actually need conditions or branches? If the condition depends on an argument that is not an array (for example a Python flag), you can tell `jax.jit` that this argument is **static**. In practice, JAX then compiles your function separately for each distinct value of the static argument it encounters: the first call with a new value triggers a (re)compilation, and later calls with an already-seen value reuse the cached compiled version.

In [13]:
def dummy_forward(x : chex.Array, is_training : bool) -> chex.Array:
  if is_training:
    return x**2
  else:
    return -x

f_jit_correct = jax.jit(dummy_forward, static_argnums=1)

# The function is compiled a first time here
x = jnp.zeros((12, 3))
f_jit_correct(x, True)

# No further compilation here
f_jit_correct(x + 1, True)

# The function is compiled again here
f_jit_correct(x +1, False)

DeviceArray([[-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.],
             [-1., -1., -1.]], dtype=float32)
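
To check when recompilation actually happens, you can count how many times JAX traces the function. The sketch below deliberately uses a Python side effect (incrementing a global counter), which only runs while JAX traces the function; section 4 explains why such side effects are dangerous in general. The names `trace_count` and `counting_forward` are ours.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def counting_forward(x, is_training):
  global trace_count
  # This Python statement only runs while JAX traces (compiles) the function.
  trace_count += 1
  if is_training:
    return x**2
  return -x

f_jit = jax.jit(counting_forward, static_argnums=1)
x = jnp.zeros((12, 3))
f_jit(x, True)       # Traced and compiled for is_training=True.
f_jit(x + 1, True)   # Cache hit: no tracing.
f_jit(x + 1, False)  # New static value: traced and compiled again.
f_jit(x, True)       # Value already seen: cache hit.
print(trace_count)   # 2
```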

For this reason, **you shouldn't use static arguments that you know will take many different values in your code**. It would make your code painfully slow, as JAX would spend its time compiling rather than executing your code.

### 4. A common JAX caveat

One mistake that is common in newcomers' JAX code, and that you will probably make at some point, is forgetting that functions transformed by JAX should remain pure: JAX handles (impure) side effects badly. Here is a simple case where this mistake is made and the resulting code does not behave as one might expect.

In [None]:
class CounterAndJax:
  def __init__(self) -> None:
    self._counter = 0

  def increment(self) -> None:
    self._counter += 1

  def apply(self, x: chex.Array) -> chex.Array:
    return x + self._counter

The class `CounterAndJax` maintains a counter and has an `apply` method that can be applied to a `chex.Array`. We can already see that `apply` is not a pure function: it internally uses the `CounterAndJax` attribute `_counter`, which is not provided as an explicit argument. Can you guess the result of the following computation?

In [None]:
caj = CounterAndJax()
caj.increment()
caj.apply(jnp.zeros((3,)))

It did the _correct_ thing: it incremented the counter, then used the new value when applied. So why do we even bother with pure functions? Let's try something else. Can you guess what the following code will print out?

In [None]:
caj = CounterAndJax()
caj.increment()
apply = jax.jit(caj.apply)
apply(jnp.zeros((3,)))
caj.increment()
caj.increment()
apply(jnp.zeros((3,)))

The `apply` function returned ones instead of threes: the last two increments were not taken into account. This is because `caj.apply` was jitted, so everything inside this function other than its arguments was treated as a constant at compile time and frozen into the `jaxpr` (JAX's intermediate representation of the traced function). After the jitting pass, `caj.increment` no longer affects the `jaxpr` of `apply`, and thus the result of the computation. Now for some even trickier behavior: can you guess what the following cell prints out?

In [None]:
caj = CounterAndJax()
caj.increment()
apply = jax.jit(caj.apply)
caj.increment()
apply(jnp.zeros((3,)))
caj.increment()
apply(jnp.zeros((3,)))

That's a strange one. To answer this question properly, remember that JAX does not compile your function when `jax.jit` is called, but when the resulting function is first applied, **because it needs to know the shapes and dtypes of the arguments you are passing in**. Here, `apply` is compiled when it is called on the first `jnp.zeros((3,))` tensor, after two increments have been done, so both calls return twos. Let's finish with the trickiest case of all: can you guess what the following code will produce?

In [None]:
caj = CounterAndJax()
caj.increment()
apply = jax.jit(caj.apply)
caj.increment()
apply(jnp.zeros((3,)))
caj.increment()
apply(jnp.zeros((4,)))
caj.increment()
apply(jnp.zeros((4,)))

Again, you must predict when the `apply` function will be compiled for the specific argument shape you are using. The call `apply(jnp.zeros((3,)))` compiles the function for 1D tensors of size 3, after two increments. The first call `apply(jnp.zeros((4,)))` then compiles it again, for 1D tensors of size 4, after three counter increments, so it returns threes. The last counter increment has no effect, and the final call also returns threes.

As you may have noticed, predicting JAX's behavior when side effects are involved is extremely complicated (close to impossible in very complex cases). This is why you should only use pure functions when working with JAX.
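
The way out is to make the state explicit: pass the counter as an argument and return its new value instead of mutating an attribute. A minimal sketch of a pure version of `CounterAndJax` (the functions `pure_increment` and `pure_apply` are hypothetical names of ours):

```python
import jax
import jax.numpy as jnp

def pure_increment(counter):
  # Returns a new counter instead of mutating state in place.
  return counter + 1

@jax.jit
def pure_apply(counter, x):
  # The counter is an explicit (traced) input, so it is never frozen into the jaxpr.
  return x + counter

counter = 0
counter = pure_increment(counter)
out1 = pure_apply(counter, jnp.zeros((3,)))  # [1. 1. 1.]
counter = pure_increment(pure_increment(counter))
out2 = pure_apply(counter, jnp.zeros((3,)))  # [3. 3. 3.]
print(out1, out2)
```

Now every increment is taken into account, no matter when the function is compiled.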

### 5. A simple linear regression in JAX

To test our newly acquired skills, we are going to implement a simple linear regression on a fixed dataset. You are provided with an input-output dataset, with inputs $X \in \mathbb{R}^{d \times 6}$ and outputs $Y \in \mathbb{R}^{d \times 12}$, where $d$ is the dataset size.

We are going to optimize two sets of parameters, $W \in \mathbb{R}^{6 \times 12}$, some weights, and $b \in \mathbb{R}^{12}$, some biases, to minimize the mean squared error
$$\mathcal{L}_{W, b}(\mathbf{y}) = \frac{1}{d}\sum\limits_i \|x_i W + b - y_i\|^2$$
by [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent).

Gradient descent is the basic training method used in Deep Learning. The idea is very simple:

- Initialize $W_0$ and $b_0$ randomly
- For each step $t \in [0 ; T-1]$:
  - Sample a batch of examples $\mathbf{y} = (y_i)_{1 \leq i \leq N}$
  - Compute $\nabla\mathcal{L}_{W_t}(\mathbf{y})$ and $\nabla\mathcal{L}_{b_t}(\mathbf{y})$
  - Update $W_{t+1} := W_t - \alpha * \nabla\mathcal{L}_{W_t}(\mathbf{y})$ and $b_{t+1} := b_t - \alpha * \nabla\mathcal{L}_{b_t}(\mathbf{y})$

Here, $T$ is the number of iterations and $\alpha$ is the **learning rate**. This algorithm is called a **training loop**.

In [None]:
#@title Creating the dataset { form-width: "30%" }

# Create a random dataset (note that the fixed seed makes it deterministic).
# The key is split so that the inputs and the noise use independent keys.
rng_x, rng_noise = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key=rng_x, shape=(128, 6))
Y = 12 * jnp.concatenate([X, X], axis=-1) + 6 + jax.random.normal(key=rng_noise, shape=(128, 12))

First, implement the prediction function which, given inputs $X$, weights $W$ and biases $b$, produces the output of the linear model $XW + b$.

In [None]:
#@title **[Implement]** Linear prediction { form-width: "30%" }


In [None]:
#@title **[Solution]** Linear prediction { form-width: "30%" }
def predict(W: chex.Array, b: chex.Array, X: chex.Array) -> chex.Array:
  return X @ W + b[None]

Next implement the loss function that takes in the weights, biases, inputs and outputs, and produces the mean squared error.

In [None]:
#@title **[Implement]** Loss function { form-width: "30%" }


In [None]:
#@title **[Solution]** Loss function { form-width: "30%" }
def loss_fn(W: chex.Array, b: chex.Array, X: chex.Array, Y: chex.Array) -> chex.Array:
  return jnp.mean(jnp.sum(jnp.square(predict(W, b, X) - Y), axis=-1))

Implement an update function that takes the current parameters, all inputs and outputs, and a learning rate, and returns the parameters updated by one step of gradient descent. This function should also return the loss incurred with the current parameters.

In [None]:
#@title **[Implement]** Update function { form-width: "30%" }
def update_fn(W: chex.Array, b: chex.Array, X: chex.Array, Y: chex.Array, learning_rate: float) -> Tuple[chex.Array, chex.Array, chex.Array]:
  #### IMPLEMENT ####
  return loss, W, b

In [None]:
#@title **[Solution]** Update function { form-width: "30%" }
def update_fn(W: chex.Array, b: chex.Array, X: chex.Array, Y: chex.Array, learning_rate: float) -> Tuple[chex.Array, chex.Array, chex.Array]:
  loss, (dW, db) = jax.value_and_grad(loss_fn, argnums=(0, 1))(W, b, X, Y)
  new_W = W - learning_rate * dW
  new_b = b - learning_rate * db
  return loss, new_W, new_b

JIT your function, so that it runs faster.

In [None]:
#@title **[Implement]** JIT { form-width: "30%" }
jitted_update_fn = ...

In [None]:
#@title **[Solution]** JIT { form-width: "30%" }
jitted_update_fn = jax.jit(update_fn, static_argnums=4)

Initialize $W$ and $b$ either randomly or to constants (in the linear regression case, even constant initialization will do, since our loss function is convex).

In [None]:
#@title Initialize { form-width: "30%" }
W = jnp.zeros((6, 12))
b = jnp.zeros((12,))

Finally, check that your implementation is correct by running the following training loop. (The default hyperparameters should work; you can play with different hyperparameters to see the difference.) Since $Y$ is $12\,[X, X] + 6$ plus noise, you should end up with $W$ close to $12\,[I_6, I_6]$ (two $6 \times 6$ identity blocks side by side, scaled by $12$) and $b$ close to a constant vector with all entries equal to $6$.

In [None]:
#@title **[Run]** Training loop { form-width: "30%" }
num_iterations = 10_000 # @param
learning_rate = .001 # @param

for i in range(num_iterations):
  loss, W, b = jitted_update_fn(W, b, X, Y, learning_rate)
  if i % 100 == 0:
    print(f'At step {i},\t loss: {loss}')
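
To double-check the trained parameters, you can compare them with the closed-form least-squares solution, which gradient descent should approach. The sketch below is self-contained: it regenerates a dataset with the same recipe (the exact numbers may differ from the notebook's `X` and `Y`), appends a column of ones to the inputs so that the bias is solved for jointly, and calls `jnp.linalg.lstsq`. The `_check` variable names are ours.

```python
import jax
import jax.numpy as jnp

rng_x, rng_noise = jax.random.split(jax.random.PRNGKey(0))
X_check = jax.random.normal(rng_x, shape=(128, 6))
Y_check = (12 * jnp.concatenate([X_check, X_check], axis=-1) + 6
           + jax.random.normal(rng_noise, shape=(128, 12)))

# Append a column of ones so that the last row of the solution is the bias.
X_aug = jnp.concatenate([X_check, jnp.ones((X_check.shape[0], 1))], axis=-1)
solution, _, _, _ = jnp.linalg.lstsq(X_aug, Y_check)
W_ls, b_ls = solution[:6], solution[6]

# True parameters: two 6x6 identity blocks scaled by 12, and a bias of 6.
W_true = 12 * jnp.concatenate([jnp.eye(6), jnp.eye(6)], axis=-1)
print(jnp.max(jnp.abs(W_ls - W_true)), jnp.max(jnp.abs(b_ls - 6)))
```

Both errors should be small (the noise has unit variance and there are 128 samples), and your gradient-descent parameters should be close to these values.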

As you might have noticed, handling parameters in plain JAX is painful, as you need to keep track of every single parameter (for a big network, this quickly becomes unsustainable). For this reason, in one of the next practicals we will learn to use Haiku, which tackles exactly this issue.

### 6. **[Bonus]** A simple Multi Layer Perceptron in JAX

To get more practice with JAX, you can now try to replicate the previous exercise, this time using a two-layer MLP with a ReLU activation in the middle instead of a linear function approximator. Recall that the output of a two-layer MLP is
$$y = \mathrm{relu}(x W_1 + b_1) W_2 + b_2.$$

In [None]:
#@title Creating the dataset { form-width: "30%" }

# Create a random dataset (note that the fixed seed makes it deterministic).
rng = jax.random.PRNGKey(0)
rng, *rngs = jax.random.split(rng, 5)
# Train set
X = jax.random.normal(key=rngs[0], shape=(128, 6))
Y = 12 * jnp.concatenate([X, X], axis=-1) + 6 + jax.random.normal(key=rngs[1], shape=(128, 12))

# Eval set
X_eval = jax.random.normal(key=rngs[2], shape=(128, 6))
Y_eval = 12 * jnp.concatenate([X_eval, X_eval], axis=-1) + 6 + jax.random.normal(key=rngs[3], shape=(128, 12))

In [None]:
#@title **[Implement]** MLP regression{ form-width: "30%" }
# Implement all the steps that you implemented for the linear regression for the MLP model.
# Try to print the loss both on the train and eval sets, and see what happens for both losses
num_hiddens = 32 # @param

In [None]:
#@title **[Solution]** MLP regression { form-width: "30%" }
Params = Mapping[str, chex.Array]

def predict(params: Params, X: chex.Array) -> chex.Array:
  h = jax.nn.relu(X @ params['W1'] + params['b1'][None])
  return h @ params['W2'] + params['b2'][None]

def loss_fn(params: Params, X: chex.Array, Y: chex.Array) -> chex.Array:
  return jnp.mean(jnp.sum(jnp.square(predict(params, X) - Y), axis=-1))

def update_fn(params: Params, X: chex.Array, Y: chex.Array, learning_rate: float) -> Tuple[chex.Array, Params]:
  loss, dparams = jax.value_and_grad(loss_fn, argnums=0)(params, X, Y)
  new_params = jax.tree_util.tree_map(lambda x, dx: x - learning_rate * dx, params, dparams)
  return loss, new_params

def initialize(rng: chex.PRNGKey) -> Params:
  rng1, rng2 = jax.random.split(rng)
  W1 = jax.random.normal(rng1, (6, num_hiddens)) / jnp.sqrt(6)
  W2 = jax.random.normal(rng2, (num_hiddens, 12)) / jnp.sqrt(num_hiddens)
  b1 = jnp.zeros((num_hiddens,))
  b2 = jnp.zeros((12,))
  return dict(W1=W1, W2=W2, b1=b1, b2=b2)

# JIT the update and evaluation functions.
jitted_update_fn = jax.jit(update_fn, static_argnums=3)
jitted_loss = jax.jit(loss_fn)

params = initialize(jax.random.PRNGKey(0))
num_iterations = 100_000 # @param
learning_rate = .001 # @param

for i in range(num_iterations):
  loss, params = jitted_update_fn(params, X, Y, learning_rate)
  eval_loss = jitted_loss(params, X_eval, Y_eval)
  if i % 100 == 0:
    print(f'At step {i},\t train loss: {loss}\t eval loss: {eval_loss}')