This is the practical session of the presentation
"JAX/Flax for AI Residents"
from the AI Resident onboarding.

Please see go/flax-air for **slides** and **solutions**.

You probably first want to **make a copy** so your changes are not lost:

![save a copy](https://screenshot.googleplex.com/34YXmEa2Lb2ipNL.png)

### JAX Fundamentals

Now is a good moment to open the JAX documentation in a separate tab:

https://jax.readthedocs.io

#### Initialization

In [None]:
import jax

import jax.numpy as jnp
import numpy as np

from matplotlib import pyplot as plt

In [None]:
# Check connected accelerators. Depending on what runtime you're connected to,
# this will show a single CPU/GPU, or 8 TPU cores (jf_2x2 aka JellyDonut).

# You can start a TPU runtime via : "Connect to a runtime" -> "Start" ->
# "Borg Runtime" -> "Brain Frameworks JellyDonut (go/ml-colab)"
# https://screenshot.googleplex.com/87HTCpQNhBKUZUp
# See also http://go/research-workflow-intro-deck#colab
jax.devices()

In [None]:
# Local devices: In this case it's the same as all devices, but if you run JAX
# in a multi host setup, then local_devices will only show the devices connected
# to the host running the program.
jax.local_devices()

In [None]:
# Alternatively, you can connect to a GPU runtime and inspect it with nvidia-smi:
!nvidia-smi

#### Randomness

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers


<center>
<img src="https://live.staticflickr.com/3127/2875827736_2224e426c6_w.jpg" width="400" height="300" alt="Green Tree Python"><br>
<i>CC BY Image by <a href="https://www.flickr.com/photos/tedmurphy/">Ted Murphy</a></i>
</center>


In [None]:
# YOUR ACTION REQUIRED:
# Your task is to use JAX to generate 5 uniform random numbers and 5 normally
# distributed random numbers.

# Check out the following JAX API calls:
# - jax.random.PRNGKey()
# - jax.random.split()
# - jax.random.uniform()
# - jax.random.normal()

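JAX has no global random state: every draw is fully determined by an explicit key. The following is a minimal sketch of that behavior. It uses `jax.random.randint()` rather than the functions from the exercise, and the seed 42 is an arbitrary choice. Reusing a key reproduces the same numbers, while splitting it yields independent subkeys.

```python
import jax

key = jax.random.PRNGKey(42)  # seed chosen arbitrarily for illustration

# The same key always produces the same numbers.
a = jax.random.randint(key, (5,), minval=0, maxval=100)
b = jax.random.randint(key, (5,), minval=0, maxval=100)
print("same key:  ", a, b)

# To get new numbers, split the key and consume each subkey exactly once.
key, subkey1, subkey2 = jax.random.split(key, 3)
c = jax.random.randint(subkey1, (5,), minval=0, maxval=100)
d = jax.random.randint(subkey2, (5,), minval=0, maxval=100)
print("split keys:", c, d)
```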
#### `jnp` vs. `np`

<center>
<img src="https://live.staticflickr.com/2828/9578749884_d93a4a1315_w.jpg" width="400" height="255" alt="steam forever"><br>
<i>CC BY Image by <a href="https://www.flickr.com/photos/h-studio/">targut</a></i>
</center>


In [None]:
# Let's do some semi-serious matrix multiplication:
k = 3_000
x = np.random.normal(size=[k, k])
# ~3.4s
%time x @ x

In [None]:
# YOUR ACTION REQUIRED: Do the same computation using JAX!
# You should use result.block_until_ready() for a fair comparison.

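JAX dispatches work asynchronously. A call returns as soon as the computation is enqueued, before the result is actually ready, so timing it without waiting would be misleading. Below is a small helper sketch that waits with `block_until_ready()`. `time_jax` is a hypothetical name, not a JAX API, and the 1,000×1,000 matrix size is arbitrary.

```python
import time

import jax
import jax.numpy as jnp


def time_jax(fn, *args):
    """Hypothetical helper: returns (result, seconds), waiting for the device."""
    start = time.perf_counter()
    result = fn(*args)
    # Without this call we would only measure how long it took to *enqueue*
    # the computation, not how long it took to run.
    result.block_until_ready()
    return result, time.perf_counter() - start


mat = jax.random.normal(jax.random.PRNGKey(0), (1_000, 1_000))
result, seconds = time_jax(jnp.matmul, mat, mat)
print(f"{seconds:.4f}s", result.shape)
```

The first call also includes compilation time, so repeat it (or use `%timeit`) to get steady-state numbers.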
In [None]:
# Note the different class of the JAX array. It has additional API, e.g.
# `x.device_buffer` to determine on which device the data is stored.

In [None]:
%%time
# Combining jnp & np: the array initialization below is rather slow because we
# create a lot of jnp arrays. Replace jnp with np and observe the speedup!
# GPU : 1.79s
# CPU : 1.04s
x = jnp.array([jnp.arange(100) for _ in range(10000)])
print(repr(x))
# YOUR ACTION REQUIRED:
# In this situation we would want to create the array in np and then convert it
# to a jnp array using jnp.array() or jax.device_put().
# (Note that we could use np.tile() here, but that's not the point)

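The general pattern behind this exercise is to do many small, host-side array manipulations in NumPy and transfer the finished array to the accelerator once. Here is a sketch of that pattern; `rows_np` and `rows_jax` are illustrative names.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Many small NumPy operations on the host are cheap...
rows_np = np.stack([np.arange(100) for _ in range(10_000)])

# ...and the finished array is transferred to the default device in one go.
rows_jax = jax.device_put(rows_np)  # jnp.asarray(rows_np) works as well

print(type(rows_np).__name__, type(rows_jax).__name__, rows_jax.shape)
```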
#### `grad()`

<center>
<img src="https://live.staticflickr.com/8573/15246394073_0cfdcc458b_w.jpg" width="400" height="221" alt="Gradient"><br>
<i>CC BY Image by <a href="https://www.flickr.com/photos/60506610@N08/">Manel Torralba</a></i>
</center>

https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation

In [None]:
def sigmoid(x):
    return 0.5 * (1 + jnp.tanh(x))


# YOUR ACTION REQUIRED:
# Use grad() to create a new function that computes the gradient of `sigmoid`.
# Verify the output of the new function at some points.

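A quick way to sanity-check a `grad()` result is to compare it to a derivative you know analytically. The sketch below uses `jnp.sin` (whose derivative is `jnp.cos`) instead of `sigmoid`, evaluated at a few arbitrary points.

```python
import jax
import jax.numpy as jnp

dsin = jax.grad(jnp.sin)  # a new function computing d/dx sin(x)

for x in [0.0, 0.5, 2.0]:
    # grad() expects a scalar-output function and floating-point inputs.
    print(x, dsin(x), jnp.cos(x))
```

Note that `grad()` raises an error for integer inputs, so pass e.g. `2.0` rather than `2`.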
In [None]:
def f(x, y):
    return 2 * x * y ** 2


# YOUR ACTION REQUIRED:
# Compute df/dx and df/dy with grad()

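By default `grad()` differentiates with respect to the first positional argument only, and `argnums=` selects other arguments. The sketch below uses a different function, g(x, y) = x² + 3y, whose partial derivatives are 2x and 3.

```python
import jax


def g(x, y):
    return x ** 2 + 3 * y


dg_dx = jax.grad(g)  # argnums=0 is the default
dg_dy = jax.grad(g, argnums=1)
both = jax.grad(g, argnums=(0, 1))  # returns a tuple of gradients

print(dg_dx(2.0, 5.0), dg_dy(2.0, 5.0), both(2.0, 5.0))
```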
#### `vmap()`

<center>
<img src="https://live.staticflickr.com/65535/49164406707_a954dc465f_w.jpg" width="400" height="225" alt="Les Tanji, éléments majeurs du paysage urbain coréen (Daejeon, Corée du sud)"><br>
<i>CC BY Image by <a href="https://www.flickr.com/photos/dalbera/">Jean-Pierre Dalbéra</a></i>
</center>

https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap

In [None]:
# Now let's plot the gradient of the sigmoid function in the range [-5, 5]
xs = jnp.linspace(-5, 5, 100)
# We can of course evaluate the gradient at every position separately:
grads = [jax.grad(sigmoid)(x) for x in xs]
plt.plot(xs, grads)
# But JAX can "vectorize" our gradient function for us automatically.
# YOUR ACTION REQUIRED:
# Read the documentation about `vmap` and reimplement the plot without a Python
# loop.

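`jax.grad()` only accepts scalar-output functions, so it cannot be applied to a whole vector of inputs at once. Composing it with `vmap()` gives an elementwise derivative. The sketch below uses f(x) = x³, whose derivative is 3x².

```python
import jax
import jax.numpy as jnp


def cube(x):
    return x ** 3


xs = jnp.linspace(-2.0, 2.0, 5)
# grad(cube) works on one scalar; vmap maps it over the leading axis of xs.
dcube = jax.vmap(jax.grad(cube))
print(dcube(xs), 3 * xs ** 2)  # the two should match
```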
In [None]:
# Another vmap() example : Let's re-implement matmul using vector dot product:
vdp = lambda v1, v2: v1.dot(v2)
# Vector dot product:
vdp(jnp.arange(1, 4), jnp.arange(1, 4))
# Matrix vector product:
mvp = jax.vmap(vdp, in_axes=(0, None), out_axes=0)
# Matrix matrix product:
mmp = jax.vmap(mvp, in_axes=(None, 1), out_axes=1)

# Verify result.
m1 = jnp.arange(12).reshape((3, 4))
m2 = m1.reshape((4, 3))
# In case you were wondering: since Python 3.5, `@` is the matrix
# multiplication operator (`__matmul__()`). It happens to use the same character
# as decorators (cf. `@jit` below).
mmp(m1, m2) - m1 @ m2

# YOUR ACTION REQUIRED:
# It's curry time!
# Try re-implementing mvp(), but this time without using `in_axes=` and
# `out_axes=`. Instead, use lambda expressions to (un)curry the arguments so
# that vmap()'s defaults in_axes=0 and out_axes=0 do the job.
# (You can also re-implement mmp() this way, but it involves transposing.)

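Here is the idea behind the currying exercise, sketched on a different function. A closure fixes the argument that should *not* be batched, so `vmap()`'s default `in_axes=0` only sees the batched one. `dist()` is a hypothetical helper that computes the Euclidean distance of each row of `points` to one fixed `center`.

```python
import jax
import jax.numpy as jnp


def dist(x, c):
    """Euclidean distance between two vectors."""
    return jnp.sqrt(jnp.sum((x - c) ** 2))


points = jnp.arange(12.0).reshape((4, 3))
center = jnp.array([1.0, 1.0, 1.0])

# The lambda closes over `center`, leaving a one-argument function of a row.
dists = jax.vmap(lambda row: dist(row, center))(points)
# Equivalent explicit version: jax.vmap(dist, in_axes=(0, None))(points, center)
print(dists)
```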
#### `jit()`

<center>
<img src="https://live.staticflickr.com/3803/9540184355_0dee2f496a_w.jpg" width="400" height="267" alt="Silicon Village"><br>
<i>CC BY Image by <a href="https://www.flickr.com/photos/jackofspades/">Jack Spades</a></i>
</center>

https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-*jit*

In [None]:
# JAX would not have the final X in its name if it were not for XLA, the
# magic sauce that somehow takes computation defined in a function as input
# and produces a much faster version of it.

# @jax.jit
def f(x):
    y = x
    for _ in range(10):
        y = y - 0.1 * y + 3.0
    return y[:100, :100]


x = jax.random.normal(jax.random.PRNGKey(0), (5000, 5000))
%timeit f(x).block_until_ready()

# YOUR ACTION REQUIRED:
# Wave your magic JAX wand and cast a spell by removing a single character from
# the example above, drastically speeding up the computation!
# Note: JIT unrolls the for loop and converts all computations to XLA
# primitives. XLA is then smart enough to fuse kernels for multiplication and
# addition, and optimize the program to only compute those parts that are
# actually needed for the function result...

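You can check the claim that JIT unrolls the Python loop by printing the traced program with `jax.make_jaxpr()`. The sketch below uses a shortened version of the loop above (3 iterations on a small array). The resulting jaxpr should contain one multiplication per iteration and no loop primitive.

```python
import jax
import jax.numpy as jnp


def g(x):
    y = x
    for _ in range(3):  # a plain Python loop: unrolled during tracing
        y = y - 0.1 * y + 3.0
    return y


jaxpr = jax.make_jaxpr(g)(jnp.ones((4, 4)))
print(jaxpr)

primitives = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
print(primitives.count("mul"), "multiplications")
```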
In [None]:
# Just to be clear : `@jit` is Python's decorator syntax [1], you can also use
# jit() like the other function transformations.
# [1] https://www.python.org/dev/peps/pep-0318


@jax.jit
def f1_jit(x):
    return x ** 0.5


def f2(x):
    return x ** 0.5


# It's really the same.
f2_jit = jax.jit(f2)

f1_jit(2) - f2_jit(2)

In [None]:
# What you need to understand about JIT (1/3): When a function is traced.


@jax.jit
def noop(x):
    # This statement only gets executed when the function is traced, i.e. every
    # time you execute the JIT-ted version with a new ShapedArray (different dtype
    # and/or different shape).
    print("Tracing noop:", x)
    return x


noop(jnp.arange(3))  # Tracing.
noop(jnp.arange(3) + 1)  # Using trace from cache.
noop(jnp.arange(4))  # Tracing.
noop(jnp.arange(4.0))  # Tracing.
noop(jnp.arange(1.0, 5.0))  # Using trace from cache.

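Because Python side effects only run while tracing, recording them is an easy way to watch the cache at work. The sketch below mirrors the calls above. `traced_shapes` is an illustrative name, and it only grows when JAX traces the function.

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # filled only while tracing


@jax.jit
def double(x):
    traced_shapes.append(x.shape)  # Python side effect: runs only during tracing
    return 2 * x


double(jnp.arange(3))      # new shape/dtype -> traced
double(jnp.arange(3) + 1)  # same shape/dtype -> cached
double(jnp.arange(4))      # new shape -> traced again
print(traced_shapes)
```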
In [None]:
# What you need to understand about JIT (2/3): Baking in environment.
magic_number = 13


@jax.jit
def add_magic(x):
    return x + magic_number


print(add_magic(np.array([0])))
magic_number = 42
print(add_magic(np.array([0])))
print(add_magic(np.array([0.0])))

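If a value is meant to change between calls, the usual fix is to pass it in as an argument. It then becomes one of the traced inputs instead of a constant baked in at trace time. Here is a sketch of that fix:

```python
import jax
import numpy as np


@jax.jit
def add_number(x, number):
    # `number` is a traced argument, not a global captured at trace time.
    return x + number


print(add_number(np.array([0]), 13))
print(add_number(np.array([0]), 42))  # same trace, new value
```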
In [None]:
# What you need to understand about JIT (3/3): Value-dependent flow.
def mult(x, n):
    print("Tracing mult:", x, n)
    tot = 0
    while n > 0:
        tot += x
        n -= 1
    return tot

In [None]:
# The problem:

# The following statement fails because JIT generates the function's XLA code
# by tracing it with `ShapedArray`s. These arrays have only their shape
# and datatype defined. Hence, if there are any statements involving the actual
# *values* of the parameters, JIT does not know what to do and raises an
# exception.
# (Note that if mult were traced with `ConcreteArray`s then the trace would work
#  just fine; you can see that when executing `grad(mult)(3., 2.)`)
try:
    jax.jit(mult)(3, 2)
except Exception as e:
    print(f"\n### FAILED WITH : {e}")

# How can we fix this ??

In [None]:
# Solution 1 : static_argnums
jax.jit(mult, static_argnums=1)(3, 4)
jax.jit(mult, static_argnums=1)(3, 5)
jax.jit(mult, static_argnums=1)(3, 6)

# By the way: did you notice how the function is traced exactly three times the
# first time this cell is executed, but not when you re-execute the same cell?
# That's because JIT-ted functions are cached. If you want to observe the
# tracing a second time, you first need to re-execute the cell that defines
# `mult`, so that `mult` gets redefined and the cached traces no longer apply.

In [None]:
# Solution 2 : (un)currying

# YOUR ACTION REQUIRED:
# Use jit() without `static_argnums=`, but (un)curry the function mult instead.

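Here is the (un)currying trick, sketched on a different function. `power()` is a hypothetical helper that, like `mult`, loops over the *value* of its second argument. A lambda closes over a concrete `n`, so only the array argument is handed to `jit()`. Each distinct `n` needs its own jitted function, whereas `static_argnums=` lets JAX compile once per distinct value behind a single jitted function.

```python
import jax
import jax.numpy as jnp


def power(x, n):
    result = jnp.ones_like(x)
    while n > 0:  # Python control flow on the *value* of n
        result = result * x
        n -= 1
    return result


# n is a plain Python int inside the closure, so tracing only sees x.
cube_jit = jax.jit(lambda x: power(x, 3))
print(cube_jit(jnp.array([1.0, 2.0, 3.0])))
```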
In [None]:
# Solution 3 : Use XLA primitives for control flow.

# Remember: you can inspect the `jax.lax.while_loop()` docs by either:
# - Going to https://jax.readthedocs.io
# - Executing a cell containing `?jax.lax.while_loop`
# - Hovering your mouse over `while_loop` and waiting two seconds


def mult_(x, n):
    print("Tracing mult_:", x, n)

    def cond_fun(n_tot):
        n, tot = n_tot
        return n > 0

    def body_fun(n_tot):
        n, tot = n_tot
        return (n - 1, tot + x)

    # while_loop() returns the final carry (n, tot); we only need tot.
    _, tot = jax.lax.while_loop(cond_fun, body_fun, (n, 0))
    return tot


jax.jit(mult_)(3, 4)
jax.jit(mult_)(3, 5)
jax.jit(mult_)(3, 6)

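For loops with a counted number of iterations, `jax.lax.fori_loop()` is a more compact alternative to `while_loop()`. Below is a sketch of `mult` written with it. When the loop bounds are traced values, as here under `jit()`, JAX implements `fori_loop` with a `while_loop`.

```python
import jax


def mult_fori(x, n):
    # body_fun receives the loop index and the carried value.
    return jax.lax.fori_loop(0, n, lambda i, tot: tot + x, 0 * x)


mult_fori_jit = jax.jit(mult_fori)
print(mult_fori_jit(3, 4), mult_fori_jit(3, 5))
```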
In [None]:
%%time
# Woah! Wasn't JAX supposed to be fast!? What is going on here??
# Also note that increasing the second number significantly will crash
# your runtime...
jax.jit(mult, static_argnums=1)(3, 5000)

In [None]:
%%time
# Does this function have the same problems? Why not?
jax.jit(mult_)(3, 5000)

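One way to see the difference between the two cells above is to compare the size of the traced programs. The sketch below uses three hypothetical helpers. `mult_unrolled` is a Python loop (like `mult`, without the print), and `mult_while` uses `while_loop()` (like `mult_`). `count_adds` counts `add` equations in a jaxpr. The unrolled program gains one `add` per iteration, while the `while_loop` program does not grow with `n`.

```python
import functools

import jax


def mult_unrolled(x, n):
    tot = 0
    while n > 0:  # unrolled at trace time: one add per iteration
        tot += x
        n -= 1
    return tot


def mult_while(x, n):
    def body_fun(n_tot):
        n, tot = n_tot
        return n - 1, tot + x

    _, tot = jax.lax.while_loop(lambda n_tot: n_tot[0] > 0, body_fun, (n, 0 * x))
    return tot


def count_adds(fn, *args):
    eqns = jax.make_jaxpr(fn)(*args).jaxpr.eqns
    return sum(eqn.primitive.name == "add" for eqn in eqns)


for n in (10, 20):
    unrolled = count_adds(functools.partial(mult_unrolled, n=n), 3.0)
    n_while_eqns = len(jax.make_jaxpr(mult_while)(3.0, n).jaxpr.eqns)
    print(f"n={n}: {unrolled} adds unrolled, {n_while_eqns} eqns with while_loop")
```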
#### `pmap()`

<center>
<a href="https://storage.googleapis.com/gweb-cloudblog-publish/original_images/TPU_V3_POD_FULLFRONT_FORWEBONLY_FINAL.jpg" target="_blank"><img src="https://storage.googleapis.com/gweb-cloudblog-publish/original_images/TPU_V3_POD_FULLFRONT_FORWEBONLY_FINAL.jpg" width="800"  alt="Full TPUv3 (DragonFish) pod"></a><br>
<i>Full TPUv3 (DragonFish) pod (from <a href="https://cloud.google.com/blog/products/ai-machine-learning/cloud-tpu-pods-break-ai-training-records?hl=fr_ca&skip_cache=true">Google Cloud Blog</a>) - click to enlarge.</i>
</center>

https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap

In [None]:
# Parallel computing is more fun with multiple devices :-)
# Go back to "Initialization" and connect to a different runtime if you're
# running on a single device.
assert jax.device_count() == 8, "Please connect to a JellyDonut runtime!"

In [None]:
# By default in_axes=0, so pmap() will split every incoming tensor along its
# first axis, which should be sized jax.local_device_count().
# The computations are then performed in parallel and the results are returned
# as a sharded device array. The data remains on the individual accelerators.
# Note that pmap() also XLA-compiles the function, so no need to call jit().

# Generate 8 different random seeds.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
# Generate 8 different random matrices. Data remains on devices.
mats = jax.pmap(lambda key: jax.random.normal(key, (8_000, 8_000)))(keys)
# Perform 8 matmuls in parallel.
results = jax.pmap(lambda m1, m2: m1 @ m2)(mats, mats)

# YOUR ACTION REQUIRED:
# Fetch the mean of the result matrix from every device and print it out here.

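Converting a pmapped result with `np.asarray()` (or indexing into it) copies the data back to the host. The sketch below runs on whatever devices are attached, even a single CPU, because the leading axis is sized `jax.local_device_count()` as `pmap()` requires. The 100×100 matrix size is arbitrary.

```python
import jax
import numpy as np

n_dev = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)

# One small random matrix per device; the reduction also runs on the devices.
mats = jax.pmap(lambda key: jax.random.normal(key, (100, 100)))(keys)
per_device_means = jax.pmap(lambda m: m.mean())(mats)

# np.asarray() gathers the per-device scalars into one host-side array.
print(np.asarray(per_device_means))
```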
In [None]:
import functools

# Here we use jax.lax.psum() to do computations across devices. Note that these
# operations can incur significant communication costs. Below we split our 8
# devices along two axes (4x2).

# Note in particular that parallel operators work across hosts! We can't
# demonstrate this in a Colab, but you will encounter it later in the Flax
# examples and brain templates.

# You can read more about parallel operators here:
# https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators

# axis 0 : rows
@functools.partial(jax.pmap, axis_name="rows")
# axis 1 : columns
@functools.partial(jax.pmap, axis_name="cols")
def f(x):
    # across the rows (= column sum)
    row_sum = jax.lax.psum(x, "rows")
    # across the cols (= row sum)
    col_sum = jax.lax.psum(x, "cols")
    total_sum = jax.lax.psum(x, ("rows", "cols"))
    return row_sum, col_sum, total_sum


# YOUR ACTION REQUIRED:
# Create an array, feed it to f() and verify the correctness of the results

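Here is a smaller collective sketch with a single named axis; the axis name `"i"` is arbitrary. Every device divides its value by the sum over all devices, so the results add up to 1. The sketch runs on any number of local devices.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
values = jnp.arange(1.0, n_dev + 1.0)  # one scalar per device

# psum() needs the name of the mapped axis it should reduce over.
normalized = jax.pmap(lambda v: v / jax.lax.psum(v, "i"), axis_name="i")(values)
print(normalized, normalized.sum())
```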
#### pytrees

<center>
<img src="https://live.staticflickr.com/4695/38641518410_53da16c2a9_w.jpg" width="400" height="300" alt="Green Tree Python"><br>
<i>CC BY Image by <a href="https://www.flickr.com/photos/markgillow/">Mark Gillow</a></i>
</center>

https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html

In [None]:
# Whenever we encounter a function argument, e.g. the `params` of a model, or
# the first argument of a function passed to `grad()` (the one we differentiate
# with respect to), it can be a "pytree" of `jnp.ndarray`s. A pytree is an
# arbitrary nesting of Python dicts/lists/tuples, which allows us to structure
# our data hierarchically.

# This is a pytree:
data = dict(
    array_3x2=jnp.arange(6.0).reshape((3, 2)),
    mixed_tuple=(0.1, 0.2, 0.3, [1.0, 2.0, 3.0]),
    subdict=dict(
        array_3x4=jnp.arange(12.0).reshape((3, 4)),
        array_4x3=jnp.arange(12.0).reshape((4, 3)),
    ),
)

In [None]:
# Call a function over all values, output resulting tree:
jax.tree_map(jnp.shape, data)

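Conceptually, `jax.tree_map()` is built on flattening. A pytree is split into a flat list of leaves plus a tree structure, and it can be rebuilt from both. The sketch below shows this on a small, arbitrary pytree using the `jax.tree_util` module.

```python
import jax.numpy as jnp
from jax import tree_util

tree = dict(weights=jnp.ones((2, 3)), extra=(0.5, [1.0, 2.0]))

leaves, treedef = tree_util.tree_flatten(tree)
print(len(leaves), treedef)  # dict keys are visited in sorted order

# Rebuild a tree with the same structure from new leaves.
doubled = tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])
print(doubled)
```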
In [None]:
# Define a function that does some computation with the values:
def sumsquares(x):
    value_flat, value_tree = jax.tree_flatten(x)
    del value_tree  # not needed.
    tot = 0
    for value in value_flat:
        if isinstance(value, jnp.ndarray):
            value = value.sum()
        tot += value ** 2
    return tot


sumsquares(data)

In [None]:
# Compute gradients. Remember that grad() computes gradients wrt the first
# argument, but that first argument can be an arbitrarily complex pytree (like
# all the weights in your hierarchical model).
grads = jax.grad(sumsquares)(data)
grads

In [None]:
# YOUR ACTION REQUIRED:
# Take a step against the gradients using `jax.tree_multimap()`

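In current JAX versions, `jax.tree_util.tree_map()` accepts several trees with the same structure and replaces the deprecated `jax.tree_multimap()`. The sketch below takes one gradient step on a toy pytree. The learning rate of 0.1 and the quadratic `quad_loss` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

toy_params = dict(w=jnp.array([1.0, -2.0]), b=jnp.array(0.5))


def quad_loss(params):
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2


grads = jax.grad(quad_loss)(toy_params)  # same structure as toy_params
learning_rate = 0.1
# tree_map() calls the lambda with matching leaves from both trees.
new_params = tree_util.tree_map(
    lambda p, g: p - learning_rate * g, toy_params, grads
)
print(new_params)
```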
### JAX Linear Classifier

#### Fetch data

In [None]:
# TFDS is our one-stop shop for datasets. Any dataset preprocessing is
# performed as a TensorFlow graph, but we don't need to understand the details:
# we just use the API to stream through the dataset and then use JAX for the
# computations.
import tensorflow_datasets as tfds

In [None]:
# Don't like fashion? Go check out the other image classification datasets:
# https://www.tensorflow.org/datasets/catalog/overview#image_classification
# (actually, go and check them out, even if you like fashion...)
ds, ds_info = tfds.load("fashion_mnist", with_info=True)

In [None]:
tfds.show_examples(ds["train"], ds_info, rows=4, cols=6);

In [None]:
# We're not really interested in tf.data preprocessing here, so let's just fetch
# all the data as a jax.ndarray...


def ds_get_all(ds, *keys):
    """Returns jnp.array() for specified `keys` from entire dataset `ds`."""
    d = next(iter(ds.batch(ds.cardinality())))
    return tuple(jnp.array(d[key].numpy()) for key in keys)


train_images, train_labels = ds_get_all(ds["train"], "image", "label")
train_images /= 255.0
test_images, test_labels = ds_get_all(ds["test"], "image", "label")
test_images /= 255.0

train_images.shape, train_labels.shape  # labels as indices, not one-hot

#### Step 1 : Define a model

In [None]:
# YOUR ACTION REQUIRED:
# Implement the body of this function.
def linear_init(key, input_shape, n_classes):
    """Initializes parameters for a linear classifier.

    Args:
      key: a PRNGKey used as the random key.
      input_shape: Shape of a single input example.
      n_classes: Number of output classes.

    Returns:
      A pytree to be used as a first argument with `linear_apply()`.
    """
    pass


# YOUR ACTION REQUIRED:
# Implement the body of this function.
def linear_apply(params, inp):
    """Computes logits for a SINGLE EXAMPLE.

    Args:
      params: A pytree as returned by `linear_init()`.
      inp: A single input example.

    Returns:
      Logits (i.e. values that should be normalized by `jax.nn.softmax()` to get a
      valid probability distribution over the output classes).
    """
    pass

In [None]:
# Initialize classifier & run on a single example.

params = linear_init(
    key=jax.random.PRNGKey(0),
    input_shape=train_images[0].shape,
    n_classes=ds_info.features["label"].num_classes,
)
print(jax.tree_map(jnp.shape, params))
linear_apply(params, train_images[0])

#### Step 2 : Define a loss

In [None]:
def loss_fun(params, inputs, targets):
    """Compute x-entropy loss for a batch of images.

    Args:
      params: a pytree as returned by `linear_init()`.
      inputs: batch of images
      targets: batch of target labels (indices)

    Returns:
      The loss value.
    """
    # Note that we defined linear_apply() for a single example, and that we use
    # `vmap()` here to vectorize it over the batch.
    logits = jax.vmap(linear_apply, in_axes=(None, 0))(params, inputs)
    # We go from logits directly to log(probs):
    logprobs = logits - jax.scipy.special.logsumexp(
        logits, axis=-1, keepdims=True
    )
    # Note: targets are indices.
    return -logprobs[jnp.arange(len(targets)), targets].mean()


loss_fun(params, train_images[:2], train_labels[:2])

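In `loss_fun()`, the expression `logits - logsumexp(logits)` is a numerically stable log-softmax, and the fancy indexing picks each example's log-probability for its target class. The sketch below uses made-up logits and checks the result against `jax.nn.log_softmax()`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([[2.0, 1.0, 0.1], [0.5, 2.5, -1.0]])  # 2 examples, 3 classes
targets = jnp.array([0, 1])  # class indices, not one-hot

logprobs = logits - logsumexp(logits, axis=-1, keepdims=True)
# Row i, column targets[i]: the log-probability of the correct class.
picked = logprobs[jnp.arange(len(targets)), targets]
loss = -picked.mean()
print(loss)
```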
#### Step 3 : `update_step()`

In [None]:
# This is a good moment to compile our computations using `jit()` !
# REMEMBER: Since all globals are "baked in" when the function is traced, you
# will need to re-execute this cell every time you change some code that
# `update_step()` depends on (e.g. `loss_fun()` or `linear_apply()`).
@jax.jit
def update_step(params, inputs, targets):
    """Take a single optimization step.

    Args:
      params: A pytree as returned by `linear_init()`.
      inputs: batch of images
      targets: batch of target labels (indices)

    Returns:
      A tuple (updated_params, loss).
    """
    loss, grads = jax.value_and_grad(loss_fun)(params, inputs, targets)
    # Optimize using SGD with a fixed learning rate of 0.05.
    updated_params = jax.tree_multimap(
        lambda param, grad: param - 0.05 * grad, params, grads
    )
    return updated_params, loss


update_step(params, train_images[:2], train_labels[:2])

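`jax.value_and_grad()` returns the same values as calling the function and `jax.grad()` separately, without running the forward pass a second time. The sketch below uses an arbitrary toy quadratic.

```python
import jax
import jax.numpy as jnp


def toy_loss(w):
    return jnp.sum((w - 3.0) ** 2)


w = jnp.array([1.0, 2.0])
loss, grads = jax.value_and_grad(toy_loss)(w)
print(loss, grads)  # same as toy_loss(w) and jax.grad(toy_loss)(w)
```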
#### Step 4: Train the model

In [None]:
# Step 4 : Do the training by calling `update_step()` repeatedly.


def train(params, steps, batch_size=128):
    losses = []
    steps_per_epoch = len(train_images) // batch_size
    for step in range(steps):
        i0 = (step % steps_per_epoch) * batch_size
        # Training is simply done by calling `update_step()` repeatedly and
        # replacing `params` with `updated_params` returned by `update_step()`.
        params, loss = update_step(
            params,
            train_images[i0 : i0 + batch_size],
            train_labels[i0 : i0 + batch_size],
        )
        losses.append(float(loss))
    return params, jnp.array(losses)


learnt_params, losses = train(params, steps=1_000)
plt.plot(losses)
print("final loss (mean over last 100 steps):", np.mean(losses[-100:]))

In [None]:
# Compute accuracy of linear model.


def accuracy(params, inputs, targets):
    logits = jax.vmap(linear_apply, in_axes=(None, 0))(params, inputs)
    return (targets == logits.argmax(axis=-1)).mean()


accuracy(learnt_params, test_images, test_labels)

### Flax

You probably want to keep the Flax documentation ready in another tab:

https://flax.readthedocs.io/

In [None]:
import flax
from flax import linen as nn

#### Functional core

In [None]:
# Simple module with matmul layer. Note that we could build this in many
# different ways using the `scope` for parameter handling.


class Matmul:
    def __init__(self, features):
        self.features = features

    def kernel_init(self, key, shape):
        return jax.random.normal(key, shape)

    def __call__(self, scope, x):
        kernel = scope.param(
            "kernel", self.kernel_init, (x.shape[1], self.features)
        )
        return x @ kernel


class Model:
    def __init__(self, features):
        self.matmuls = [Matmul(f) for f in features]

    def __call__(self, scope, x):
        x = x.reshape([len(x), -1])
        for i, matmul in enumerate(self.matmuls):
            x = scope.child(matmul, f"matmul_{i + 1}")(x)
            if i < len(self.matmuls) - 1:
                x = jax.nn.relu(x)
        x = jax.nn.log_softmax(x)
        return x


model = Model([ds_info.features["label"].num_classes])
y, variables = flax.core.init(model)(jax.random.PRNGKey(0), train_images[:1])
assert (y == flax.core.apply(model)(variables, train_images[:1])).all()

# YOUR ACTION REQUIRED:
# Check out the parameter structure, try adding/removing "layers" and see how it
# changes

In [None]:
# YOUR ACTION REQUIRED:
# Redefine loss_fun(), update_step(), and train() from above to train the new
# model.

#### Stateless Linen module

In [None]:
# Reimplementation of above model using the Linen API.


class Model(nn.Module):
    num_classes: int

    def setup(self):
        self.dense = nn.Dense(self.num_classes)

    def __call__(self, x):
        x = x.reshape([len(x), -1])
        x = self.dense(x)
        x = nn.log_softmax(x)
        return x


model = Model(num_classes=ds_info.features["label"].num_classes)
variables = model.init(jax.random.PRNGKey(0), train_images[:1])
jax.tree_map(jnp.shape, variables)

In [None]:
# YOUR ACTION REQUIRED:
# 1. Rewrite above model using the @nn.compact notation.
# 2. Extend the model to use additional layers, see e.g.
#    convolutions in
#    http://google3/third_party/py/flax/linen/linear.py

In [None]:
model = Model(ds_info.features["label"].num_classes)
variables = model.init(jax.random.PRNGKey(0), train_images[:1])
jax.tree_map(jnp.shape, variables)

In [None]:
# Reimplementation of training loop using a Flax optimizer.


@jax.jit
def update_step_optim(optim, inputs, targets):
    def loss_fun(params):
        logits = model.apply(dict(params=params), inputs)
        logprobs = logits - jax.scipy.special.logsumexp(
            logits, axis=-1, keepdims=True
        )
        return -logprobs[jnp.arange(len(targets)), targets].mean()

    loss, grads = jax.value_and_grad(loss_fun)(optim.target)
    return optim.apply_gradient(grads), loss


def train_optim(optim, steps, batch_size=128):
    losses = []
    steps_per_epoch = len(train_images) // batch_size
    for step in range(steps):
        i0 = (step % steps_per_epoch) * batch_size
        optim, loss = update_step_optim(
            optim,
            train_images[i0 : i0 + batch_size],
            train_labels[i0 : i0 + batch_size],
        )
        losses.append(float(loss))
    return optim, jnp.array(losses)


optim = flax.optim.adam.Adam(learning_rate=0.01).create(variables["params"])
learnt_optim, losses = train_optim(optim, steps=1_000)
plt.plot(losses)
print("final loss (mean over last 100 steps):", np.mean(losses[-100:]))

In [None]:
# Re-evaluate accuracy.
(
    model.apply(dict(params=learnt_optim.target), test_images).argmax(axis=-1)
    == test_labels
).mean()

#### Linen module with state

In [None]:
# Let's add batch norm!
# I'm not saying it's a good idea here, but it will allow us study the changes
# we need to make for models that have state.


class Model(nn.Module):
    num_classes: int

    @nn.compact
    def __call__(self, x, *, train):
        x = x.reshape([len(x), -1])
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Dense(self.num_classes)(x)
        x = nn.log_softmax(x)
        return x


model = Model(num_classes=ds_info.features["label"].num_classes)
variables = model.init(jax.random.PRNGKey(0), train_images[:1], train=True)
jax.tree_map(jnp.shape, variables)

# Note the new "batch_stats" collection !

In [None]:
# YOUR ACTION REQUIRED:
# Check the code below and add comments for every change compared to the model
# above without state.


@jax.jit
def update_step_optim(optim, batch_stats, inputs, targets):
    def loss_fun(params):
        logits, mutated_state = model.apply(
            dict(params=params, batch_stats=batch_stats),
            inputs,
            mutable="batch_stats",
            train=True,
        )
        logprobs = logits - jax.scipy.special.logsumexp(
            logits, axis=-1, keepdims=True
        )
        return (
            -logprobs[jnp.arange(len(targets)), targets].mean(),
            mutated_state["batch_stats"],
        )

    (loss, new_batch_stats), grads = jax.value_and_grad(
        loss_fun, has_aux=True
    )(optim.target)
    return optim.apply_gradient(grads), new_batch_stats, loss


def train_optim(optim, batch_stats, steps, batch_size=128):
    losses = []
    steps_per_epoch = len(train_images) // batch_size
    for step in range(steps):
        i0 = (step % steps_per_epoch) * batch_size
        optim, batch_stats, loss = update_step_optim(
            optim,
            batch_stats,
            train_images[i0 : i0 + batch_size],
            train_labels[i0 : i0 + batch_size],
        )
        losses.append(float(loss))
    return optim, batch_stats, jnp.array(losses)


optim = flax.optim.adam.Adam(learning_rate=0.01).create(variables["params"])
learnt_optim, batch_stats, losses = train_optim(
    optim, variables["batch_stats"], steps=1_000
)
plt.plot(losses)
print("final loss (mean over last 100 steps):", np.mean(losses[-100:]))

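The stateful update relies on `has_aux=True`. The differentiated function returns a `(loss, aux)` pair, only `loss` is differentiated, and `aux` (in the cell above, the updated batch statistics) is passed through unchanged. The sketch below uses a toy function with a hypothetical auxiliary output `mean_pred`.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(w, batch):
    pred = batch @ w
    loss = jnp.mean(pred ** 2)
    aux = dict(mean_pred=pred.mean())  # not differentiated, just returned
    return loss, aux


w = jnp.array([1.0, -1.0])
batch = jnp.array([[1.0, 2.0], [3.0, 5.0]])
(loss, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(w, batch)
print(loss, aux, grads)
```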
In [None]:
# YOUR ACTION REQUIRED:
# Make predictions with the stateful model above.

#### Modify MNIST example

Check out the Flax MNIST example Colab; you can find a link on GitHub:

https://github.com/google/flax/tree/master/linen_examples/mnist

In [None]:
# YOUR ACTION REQUIRED:
# Store the Colab in your personal Drive and modify it to use the dataset from
# above.
# While this might sound boring, you will learn the following things:
# - how to load files from GitHub into public Colab, modify them in the UI and
#   optionally store them on your personal Google Drive.
# - how to use inline TensorBoard on public Colab and export it to tensorboard.dev

### Brain templates

code : go/brain-templates

documentation : go/brain-templates-doc

These are more open-ended exercises, but they may well pay off the most in terms
of time saved in your own projects.

In [None]:
# YOUR ACTION REQUIRED:
# 1. Fork the MNIST example.
# 2. Launch on XManager.
# 3. Check out the Colab.
# 4. Replace the dataset with the Fashion-MNIST dataset from above.
# 5. Re-run all tests and fix any failures.
# 6. Launch the modified version on XManager.
# 7. Run the Colab again with your updated code.

In [None]:
# YOUR ACTION REQUIRED:
# Check out the code of the MNIST and imagenet examples.
# What differences do you see?

### Mini project ?

You might want to use brain templates and Flax examples for your AIR mini
project.

Suggestions:

- The go/bt-imagenet template incorporates many best practices and might be a
  good candidate to start your project from.
- You might also want to have a look at the current Flax examples at
  https://github.com/google/flax/tree/master/linen_examples
  as a starting point for your project. You could try extracting the model code
  from an example there and merging it into the brain template.

If you encounter any problems on the way, you can reach us via go/flaxers-chat.

### end