<a href="https://colab.research.google.com/github/SimonKoop/jax_tutorial_trial/blob/main/practical_session_1_answers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
# JAX Tutorial
## 1. What's JAX?
* Accellerated Array computations (like Numpy but on GPU)
* JIT (just-in-time) compiled to XLA (Accellerated Linear Algebra)
* Autograd and other transformations


jax.numpy: stand-in replacement for numpy

NB some caveats apply! More on that later, but also, see [this page](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) for a more comprehensive guide.

In [1]:
# this should install all requirements if you're running this in colab
#%pip install git+https://github.com/SimonKoop/common_jax_utils
# You might get an error for some typeguard version not being correct, you can safely ignore this error.

In [2]:
import jax
from jax import numpy as jnp

def sigmoid(x:jax.Array)->jax.Array:  # many commonly used activation functions can be found in jax.nn (e.g. jax.nn.sigmoid)
    return 1 / (1 + jnp.exp(-x))

some_vector = jnp.array([-1., 0., 1., 2.])
print(sigmoid(some_vector))

# we can JIT compile the functions
sigmoid_jitted = jax.jit(sigmoid)
print(sigmoid_jitted(some_vector))  # first time calling this will take long

2024-09-09 19:29:27.620572: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


[0.26894143 0.5        0.73105854 0.8807971 ]
[0.26894143 0.5        0.73105854 0.8807971 ]


In [3]:
%%timeit
sigmoid(some_vector).block_until_ready()  # if you want to know why we use block_until_ready(), read https://jax.readthedocs.io/en/latest/async_dispatch.html

163 μs ± 9.83 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [4]:
%%timeit
sigmoid_jitted(some_vector).block_until_ready()

36.5 μs ± 1.26 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### Exercise 1
Use the tools from jax.numpy to compute the following:
$$
\mathrm{sigmoid}\left(\left(\begin{matrix}1. & 0.5 & 0.25\\ 0. & 2. & 0.25 \\ 0. & 0. & 1.\end{matrix}\right)\left(\begin{matrix}1. \\-1. \\ .3\end{matrix}\right)\right)
$$

In [5]:
# put your answer here
matrix = jnp.array([[1., 0.5, 0.25],[0., 2., 0.25], [0., 0., 1.]], dtype=jnp.float32)
vector = jnp.array([1., -1., .3], dtype=jnp.float32)
sigmoid(matrix@vector)

Array([0.6399161 , 0.12730505, 0.5744425 ], dtype=float32)

## 2. First quirk: pure functions
In order to make a.o. JIT compilation and automatic differentiation easier, the creators of JAX opted for a more functional style of programming. For many of the useful transforms of JAX to work, you need to use **pure functions** i.e. functions without side-effects.

Because of this, **jax arrays are immutable**

In [6]:
# jax arrays vs numpy arrays example
import numpy as np

some_numpy_array = np.zeros((3, 3))
some_numpy_array[1] = np.linspace(0., 1., 3)
print(f"{some_numpy_array=}")

some_jax_array = jnp.zeros((3, 3))
# some_jax_array[1] = jnp.linspace(0., 1., 3)  # this will give an error, uncomment if you want to try
# print(f"{some_jax_array=}")

some_numpy_array=array([[0. , 0. , 0. ],
       [0. , 0.5, 1. ],
       [0. , 0. , 0. ]])


In [7]:
# how to do this instead
some_new_jax_array = some_jax_array.at[1].set(jnp.linspace(0., 1., 3))
print(f"{some_jax_array=}")  # also, note that the default dtype is jnp.float32. 
print(f"{some_new_jax_array=}")  # To be able to use float64, you'd have to enable it at startup, but maybe just don't unless you really need to.

# Note that the original array remains unmodified. 

some_jax_array=Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
some_new_jax_array=Array([[0. , 0. , 0. ],
       [0. , 0.5, 1. ],
       [0. , 0. , 0. ]], dtype=float32)


As mentioned before, if we want to use function transformations like `jax.jit` or `jax.grad` (more on this later), we will have to use pure functions. To gain some understanding of what this means, let's look at something that is *not* a pure function

In [8]:
# the following works absolutely fine untill we try to JIT things (or use autograd, or any of the other useful function transformations)
class SelfChangingLayer:

    def __init__(self, initial_weights:jax.Array):
        """ 
        Initialize the SelfChangingLayer
        :parameter initial_weights: weights matrix, is.e. jax.Array with shape (N, N)
        """
        self.weights = initial_weights

    def __call__(self, x:jax.Array):
        """ 
        forward pass
        :parameter x: vector i.e. jax.Array with shape (N,)
        """
        output = self.weights@x
        # and now for the side-effect:
        self.weights = .9*self.weights + .1*(output[:, None]@output[None, :])  # output[:, None] has shape (N, 1) and output[None, :] has shape (1, N)
        return output
    
my_layer = SelfChangingLayer(jnp.eye(3))
some_x1 = jnp.array([1., 0., 0.])
some_x2 = jnp.array([2., -1., 2.])

print(f"{my_layer(some_x1)=}")
print(f"{my_layer(some_x2)=}")
print(f"{my_layer(some_x1)=}")

# As you can see, my_layer has side-effects. Due to this, the same input results in two different outputs.
# That means we can't jit compile it, or use automatic differentiation with it, without getting wrong answers

my_layer_jitted = jax.jit(SelfChangingLayer(jnp.eye(3)))
print(f"{my_layer_jitted(some_x1)=}")
print(f"{my_layer_jitted(some_x2)=}")
print(f"{my_layer_jitted(some_x1)=}")  # as you can see, these outcomes are not correct!

my_layer(some_x1)=Array([1., 0., 0.], dtype=float32)
my_layer(some_x2)=Array([ 2. , -0.9,  1.8], dtype=float32)
my_layer(some_x1)=Array([ 1.3       , -0.17999999,  0.35999998], dtype=float32)
my_layer_jitted(some_x1)=Array([1., 0., 0.], dtype=float32)
my_layer_jitted(some_x2)=Array([ 2., -1.,  2.], dtype=float32)
my_layer_jitted(some_x1)=Array([1., 0., 0.], dtype=float32)


### Exercise 2
The way we typically deal with this, is by making the state an argument to the function and returning an updated state as an output of the function. Write a pure function analogue of the `SelfChangingLayer` above.


In [9]:
def self_changing_layer(x:jax.Array, weights:jax.Array) -> tuple[jax.Array, jax.Array]:
    output = weights @ x 
    weights = .9*weights +.1*(output[:, None]@output[None, :])
    return output, weights

# uncomment the following to test your code

weights = jnp.eye(3)
output_1, weights = self_changing_layer(some_x1, weights)
output_2, weights = self_changing_layer(some_x2, weights)
output_3, weights = self_changing_layer(some_x1, weights)
print(f"{output_1=}\n{output_2=}\n{output_3=}")

self_changing_layer_jitted = jax.jit(self_changing_layer)
weights = jnp.eye(3)
output_1, weights = self_changing_layer_jitted(some_x1, weights)
output_2, weights = self_changing_layer_jitted(some_x2, weights)
output_3, weights = self_changing_layer_jitted(some_x1, weights)
print(f"{output_1=}\n{output_2=}\n{output_3=}")

output_1=Array([1., 0., 0.], dtype=float32)
output_2=Array([ 2. , -0.9,  1.8], dtype=float32)
output_3=Array([ 1.3       , -0.17999999,  0.35999998], dtype=float32)
output_1=Array([1., 0., 0.], dtype=float32)
output_2=Array([ 2. , -0.9,  1.8], dtype=float32)
output_3=Array([ 1.3       , -0.17999999,  0.35999998], dtype=float32)


Does this mean we can't use object-oriented programming with JAX? No, later in this tutorial we'll look at a package called Equinox that will help us with this.

## 3. Higher order functions
Another aspect of functional programming that JAX makes heavy use of, is higher order functions. That means: functions that either take another function as an argument, or return another function as their output (or both). One example we have already seen of a higher order function is `jax.jit`: it takes a (pure) function as its input, and returns a jit-compiled function as its output. We also alluded to the the existence of `jax.grad`, let's now look at what it does:

In [10]:
sigmoid_grad = jax.grad(sigmoid)  # this gives the gradient of sigmoid with respect to its first (and only) argument. 
def sigmoid_grad_manual(x):
    sigmoid_x = sigmoid(x)
    return sigmoid_x*(1-sigmoid_x)

print(f"0.: {sigmoid_grad(0.)}, 1.:{sigmoid_grad(1.)}, -1.:{sigmoid_grad(-1.)}")
print(f"0.: {sigmoid_grad_manual(0.)}, 1.:{sigmoid_grad_manual(1.)}, -1:{sigmoid_grad_manual(-1.)}")

0.: 0.25, 1.:0.1966119408607483, -1.:0.1966119408607483
0.: 0.25, 1.:0.19661195576190948, -1:0.1966119408607483


One salient detail of `jax.grad` is that the function of which we compute the gradient must have scalar output. 
When we feed `sigmoid` a vector, it returns a vector, so the following gives `TypeError: Gradient only defined for scalar-output functions. Output had shape: (4,).`

In [11]:
# sigmoid_grad(some_vector) # uncomment if you want to try 

If we want to use `sigmoid_grad` to compute the gradient of the *scalar* function sigmoid on a batch (vector) of scalars, we can use another function transformation: `jax.vmap`.

What `jax.vmap` does, is it takes a function f, and it returns a version of f that works on a batch of inputs (this is a vectorized function, so the computations happen in parallel). Let's see how this works:

In [12]:
print(f"{jax.vmap(sigmoid_grad)(some_vector)=}")
print(f"{sigmoid_grad_manual(some_vector)=}")

jax.vmap(sigmoid_grad)(some_vector)=Array([0.19661194, 0.25      , 0.19661194, 0.10499357], dtype=float32)
sigmoid_grad_manual(some_vector)=Array([0.19661194, 0.25      , 0.19661196, 0.10499357], dtype=float32)


Now for many functions (such as sigmoid and the derivative of a sigmoid), it is easy to write a function that automatically gets vectorized when applied to higher-dimensional input. But when writing more complicated functions, `jax.vmap` allows us to write the function for inputs for which we easily understand what it should look like, and then automatically apply it correctly to higher dimensional inputs without worrying about this. 

### Exercise 3
Use the documentation of [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) to create the following function:
$$
\mathrm{target\_function\_1}(\mathrm{value}, \mathrm{weights}) = \left(\left(\nabla_x \mathrm{self\_changing\_layer\_square}\right)(\mathrm{value}, \mathrm{weights})[0],\  \mathrm{self\_changing\_layer}(\mathrm{value},\, \mathrm{weights}) [1]\right)
$$
where $f()[i]$ is the element at index $i$ of the tuple resulting from $f()$.  

In [13]:
def self_changing_layer_square(x, weights):
    out, new_weights = self_changing_layer(x, weights)
    return jnp.square(out).sum(), new_weights

target_function = jax.grad(self_changing_layer_square, has_aux=True)
target_function(some_x1, weights)

(Array([ 3.9295735, -0.9079778,  1.8159556], dtype=float32),
 Array([[ 1.384392  , -0.19168504,  0.38337007],
        [-0.19168504,  0.7280632 , -0.1439266 ],
        [ 0.38337007, -0.1439266 ,  0.94395316]], dtype=float32))

In [14]:
type(target_function)

function

In [15]:
# check that the resulting weights are correct:
self_changing_layer(some_x1, weights)[1]

Array([[ 1.384392  , -0.19168504,  0.38337007],
       [-0.19168504,  0.7280632 , -0.1439266 ],
       [ 0.38337007, -0.1439266 ,  0.94395316]], dtype=float32)

### Exercise 4
In the following code block, you are given one array containing 20 different $3$-vectors, and an array containing 30 different $3$-vectors. Use `vmap` (twice) with the given outer product function to create a $20\times 30\times 3\times 3$ array with in location $(i, j)$ a $3\times 3$-matrix containing the result of taking the outer product of the $i^{\text{th}}$ vector in the first array with the $j^{\text{th}}$ vector in the second array. 

Hint: if you want vmap to not vmap over any dimension in some array, you can provide `None` for that array in `in_axes`

In [16]:
vectors_1 = jnp.linspace(0., 1., 60).reshape((20, 3))
vectors_2 = jnp.linspace(0., 1., 90).reshape((30, 3))

def outer_product(u: jax.Array, v: jax.Array)->jax.Array:
    """ 
    Vector outer product
    :parameter u: jax.Array of shape (n,)
    :parameter v: jax.Array of shape (n,)
    :return: jax.Array of shape (n, n) representing the outer product u@v^T
    """
    return u[:, None]@v[None, :]


# your code goes here
jax.vmap(jax.vmap(outer_product, (None, 0)), (0, None))(vectors_1, vectors_2).shape

(20, 30, 3, 3)

## 4. Random numbers
When we use 'random' numbers in machine learning, they typically aren't truely random. Instead, we are using a [pseudo random number generator](https://en.wikipedia.org/wiki/Pseudorandom_number_generator) (**PRNGs**) to create (deterministic) sequences of outputs that the right statistical properties. These PRNGs typically have a *seed*, which determines what sequence is going to be generated, and some internal state. In e.g. numpy and pytorch, you can set this seed, and the (underlying) PRNG that is used then automatically updates its state every time you draw a random number.

Now, as we have discussed above, JAX doesn't play too nicely with global states. Instead, we have to keep track of the PRNG state, typically refered to as a **key**, ourselves. If we pass the same key to the same random function twice, we get the same result twice:

In [17]:
from jax import random
key = random.key(123)
print(key)

print(f"\n{random.normal(key, shape=(3,))=}")
print(f"{random.normal(key, shape=(3,))=}\n")
print(key)

Array((), dtype=key<fry>) overlaying:
[  0 123]

random.normal(key, shape=(3,))=Array([-0.1470326,  0.5524756,  1.648498 ], dtype=float32)
random.normal(key, shape=(3,))=Array([-0.1470326,  0.5524756,  1.648498 ], dtype=float32)

Array((), dtype=key<fry>) overlaying:
[  0 123]


In order to get new pseudo random numbers, we need to split the key:

In [18]:
key, subkey = random.split(key)
print(key)
print(subkey)
print(f"\n{random.normal(subkey, shape=(3,))=}")

Array((), dtype=key<fry>) overlaying:
[1896456402   17229315]
Array((), dtype=key<fry>) overlaying:
[4081828428 1707601653]

random.normal(subkey, shape=(3,))=Array([-0.56996626, -0.6440589 ,  0.28660855], dtype=float32)


When doing this, **make sure not to re-use used keys!** A good habit is to only ever use subkey to pass to functions, and use key to keep track of your random state. So:

In [19]:
key, subkey = random.split(key)
print(random.normal(subkey))
key, subkey = random.split(key)
print(random.normal(subkey))

1.1124845
-0.70996124


Note that key and subkey are just `jax.Array`s, so we can e.g. `vmap` over subkey if need be:

In [20]:
def my_complicated_random_function(key):
    # imagine that this is some complicated code that is hard to manually write in a vectorized way
    return random.normal(key)

key, subkey = random.split(key)

subkey_array = random.split(subkey, 5)  # if we wanted say a grid instead, we could instead do random.split(subkey, (5, 5)) where (5, 5) would be the shape of the grid of subkeys
jax.vmap(my_complicated_random_function)(subkey_array)

Array([ 1.0297086 ,  2.3845856 , -1.8143567 , -0.80407333, -0.8641458 ],      dtype=float32)

### Exercise 5:
Look at the documentation of [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) and [`jax.lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html#jax.lax.select) to write a function that performs the following random walk by scanning over an array of prng keys:

$$
x^0 = (0, 0)\\
x^i = \begin{cases}x^{i-1}+\epsilon^i&\text{with probability }\frac{1}{2}\\(x^{i-1}_1, x^{i-1}_0) + \epsilon^i&\text{with probability }\frac{1}{2}\end{cases}\\
\text{where}\\
\epsilon^i \sim \mathcal{N}(\underline{0}, I)
$$

Then use [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) to simulate 10 random walks of length 20 in parallel.

In [21]:
def random_flip(x:jax.Array, key:jax.Array)->jax.Array:
    """ 
    Flips the entries in x randomly with probability 1/2
    :parameter x: (2,)-array of floats
    :parameter key: prng key for deciding whether to flip or not
    :return: array of same shape and type as x
    """
    random_float = random.uniform(key)
    return jax.lax.select(random_float<.5, x, x[::-1])

def step_function(x_old:jax.Array, key:jax.Array)->tuple[jax.Array, jax.Array]:
    """ 
    Step function for the jax.lax.scan that simulates the random walk
    :parameter x_old:  (2,)-array of floats representing x^{i-1}
    :parameter key: prng key for all the randomness in this step
    :return: (x_new, x_new)  where x_new has the same shape as x_old and represents x^i
    """  # why do we want to return x_new twice?
    key_1, key_2 = random.split(key)  # Don't forget to split the keys!
    epsilon = random.normal(key_2, (2,))
    x_new = random_flip(x_old, key_1) + epsilon
    return x_new, x_new

def simulate_random_walk(key:jax.Array, length:int)->jax.Array:
    """ 
    Simulate one random walk.
    :parameter key: prng key for all randomness in this random walk
    :parameter length: length of the random walk
    :return: (length, 2)-array of floats representing the random walk
    """
    keys = random.split(key, length-1) # split the key
    x_0 = jnp.zeros((2,))
    _, random_walk = jax.lax.scan(
        step_function, 
        x_0, 
        keys
    )
    return jnp.concatenate([x_0[None, :], random_walk])

key, subkey = random.split(key)
subkeys = random.split(subkey, 10)
result = jax.vmap(simulate_random_walk, in_axes=(0, None))(subkeys, 20)
print(result.shape)

(10, 20, 2)


On a side-note, although there are good reasons for JAX to take this approach to PRNGs, it can be a bit cumbersome to do this by hand all the time, especially when you're e.g. writing neural network architectures, where each layer will need a prng key for initialization and you might not know by heart beforehand how many keys you will need. I've written a package with convenience utilities called `common_jax_utils` that includes a generator `key_generator` that you can use in such situations.

I would recommend against using it in any code that needs to be jitted or vmapped, although, as long as you write all of your functions in a way that they do not have side-effects (**so don't write functions that expect a key_generator as an argument**), you should be fine.

In [22]:
import common_jax_utils as cju
key, subkey = random.split(key)
key_gen = cju.key_generator(subkey)

print(f"{random.normal(next(key_gen))=}")
print(f"{random.normal(next(key_gen))=}")
print(f"{random.normal(next(key_gen))=}")

random.normal(next(key_gen))=Array(1.1852474, dtype=float32)
random.normal(next(key_gen))=Array(0.54125774, dtype=float32)
random.normal(next(key_gen))=Array(-0.83737814, dtype=float32)


## 5. Control flow and function transformations
### 5.1 Jitting / vmapping
Sometimes, we need to use control flow, such as `if`-`else` statements or `for` loops. This doesn't always play nice with `jax.jit` and `jax.vmap`. In some of the above exercises, we already saw some `jax.lax` control flow alternatives such as `jax.lax.select` and `jax.lax.scan`. The question then remains: when should we use these control flow functions from `jax.lax`, and when is it okay to just use python control statements?

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow provides a good, elaborate explanation of this, and you should read this. Nonetheless, I'd like to provide a quick rule of thumb with a rationale behind it here.

Basically, `jax.jit` keeps track of the shape and dtype of arguments to a jitted function. When a combination of argument shapes/dtypes is encountered that hasn't been encountered before, the function gets compiled *for this specific combination* and the compiled code gets stored, so that when this combination is encountered again, the compiled code can be used again and no re-compilation is necessary. If you play around with this, you'll find that the first time you use a jitted function on a specific array shape, it takes *much* longer than each next time you use it on the same shape.

This brings us to the **first rule of thumb** when it comes to control flow and jitting functions: *when the control flow depends **only on the shapes/dtypes** of arguments, you can safely use python control flow statements*. Some examples are:
`if len(input_array.shape) == 3:...` or `if input_array.shape[3] % 2 == 0:...` or `if input_array.dtype == jnp.int16:...`.
In principle the same holds for e.g. loops, but *if you use long loops (i.e. with many iterations) within a jitted function, the compilation step will likely take very long!*

The first step towards jit-compiling a function is called tracing. The function is run on *tracer* objects in order to record the sequence of operations specified by the function. These tracer objects encode the shape and dtype of what they are representing, **but not the actual values**. That means that if you use regular Python control flow (like `if` and `else`, or `while`) based on the values, that you will either get errors, or wrong results. 

This brings us to the **second rule of thumb** when it comes to control flow and jitting functions: *when the control flow depends on the **values** of arguments, you should use [control flow primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators).* For example `jax.lax.select(r<.5, x_i, x_i[::-1])`. 


Another useful resource on this, that you should definitely read, is https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html


As a final **TL;DR on jitting and control flow**:
* values depending on values **requires control flow operators**
* values depending on shapes is fine (you can use python control flow statements)
* shapes depending on shapes is fine, (as long as you don't use jax to make intermediate computations, so `jnp.zeros((np.prod(x.shape),))` is fine but `jnp.zeros((jnp.prod(x.shape)))` wil cause problems when jitted)
* shapes depending on values is **not fine** (although in some cases you can get around this by indicating to `jax.jit` that some arguments should not be traced by marking them as *static*).





In [23]:
# example making tracing explicit
@jax.jit  # decorator syntax, see e.g https://pythonbasics.org/decorators/
def f(x, y):
    print("Running f():")
    print(f"  x = {x}")
    print(f"  y = {y}")
    result = jnp.dot(x + 1, y + 1)
    print(f"  result = {result}")
    return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([9.449684 , 6.446216 , 7.7138443], dtype=float32)

On a side-note, if you want to print values for e.g. debugging purposes, just using the Python `print` function will not give you the desired result when working with jitted functions. Read the material in https://jax.readthedocs.io/en/latest/debugging/index.html and in https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html for more on the topic of **debugging transformed functions**. 

### Exercise 6

Complete the following python functions in a way that they can be correctly jit-compiled and then test the jit-compiled versions:

In [24]:

def f1(x:jax.Array):
    """ 
    :parameter x: jax.Array (of any shape)
    :return: an array with elements y_i = x_i**2 if x_i <0 else x_i **3
    """
    return jnp.where(x<0, x**2, x**3)

def f2(a: jax.Array, b: jax.Array):
    """ 
    :parameter a: jax.Array of shape (n_0, ..., n_k)
    :parameter b: jax.Array of shape (m_0, ..., m_k)
    :return: a jax.Array c of shape (max(n_0, m_0), ..., max(n_k, m_k))
        with c[i_0..., i_k] = aa[i_0, ..., i_k] + bb[i_0, ..., i_k]
        where aa[i_0, ..., i_k] = a[i_0, ..., i_k] if i_0<n_0, ..., i_k < n_k else max(n_0, ..., n_k)
        and b[i_0, ..., i_k] = b[i_0, ..., i_k] if i_0<m_0, ..., i_k < m_k else max(m_0, ..., m_k)
    """
    # your code goes here
    size = [max(n, m) for n, m in zip(a.shape, b.shape)]

    aa = jnp.pad(a, pad_width=[(0, s-n) for s, n in zip(size, a.shape)], mode='constant', constant_values=max(a.shape))
    bb = jnp.pad(b, pad_width=[(0, s-m) for s, m in zip(size, b.shape)], mode='constant', constant_values=max(b.shape))
    
    return aa+bb

print(f1(jnp.array([-1., 1., -2., 2.])))
print(f2(jnp.array([[0., 1.]]), jnp.array([[0.], [1.]])))

[1. 1. 4. 8.]
[[0. 3.]
 [3. 4.]]


### 5.2 Control flow and grad
Mostly there is no problem here (until we want to also `jit` the `grad` of a function). When using `lax.while_loop` or `lax.fori_loop`, automatic differentiation is restricted to [forward mode differentiation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobian-vector-products-jvps-aka-forward-mode-autodiff) (see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#summary). But you're unlikely to run into this during this course.

One salient side-note about control-flow and automatic differentiation, is that if you have a function that is only partly defined, such as `jnp.log`, and you want to make it safe using `jnp.where`, you should put the `jnp.where` **inside** the `jnp.log` so as to prevent *NaN*s from occuring: see https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

## 6. Pytrees and related utilities
Before we really get into defining and training Neural Networks in JAX, there is one more very important concept in JAX: pytrees. Before we get to what they really are, you can informally think of pytrees as any container (of containers of...) of `jax.Array`s. This really isn't quite accurate, but for now it's good enough.

Why are these relevant? Well, remember that because our functions have to be pure, any state that we might want to have needs to be passed around between them? That state will often not be just at single array. Think of e.g. all the weights and biases of a Neural Network, and think of the momentum for all those weights and biases that some optimizers need to keep track of. For an MLP it might be reasonable to just put all of them into a tuple, but when we work with more complicated networks, it might be helpful for the structure in which we store all of these weights and biases to resemble the architecture of the network.

Fortunately, many jax functions, such as `jax.jit`, `jax.grad`, `jax.vmap`, and `jax.lax.scan` play really nicely with these pytrees. And for functions that don't, there are utilities such as `jax.tree.map` and `jax.tree_util.tree_map_with_path`. Let's look at some examples:

In [25]:
from common_jax_utils.debug_utils import summary

def some_function(a:jax.Array, b:tuple[jax.Array, jax.Array]):
    """ 
    :parameter a: array of shape (in_channels,)
    :parameter b: tuple of two arrays
        first array of shape (out_channels, in_channels)
        second array of shape (out_channels,)
    :return: array of shape (out_channels,)
    """
    return (b[0]@a)/(1+jnp.abs(b[1]))

a = random.normal(next(key_gen), (10, 3))
b = [
    random.normal(next(key_gen), (2, 3)),
    random.normal(next(key_gen), (2,))
]

print(f"{summary(jax.vmap(some_function, in_axes=(0, None))(a, b))=}") # vmap over the 0-axis of a and over none of the axes in any element in b

b_alt = (
    random.normal(next(key_gen), (2, 3)),
    random.normal(next(key_gen), (2, 10))
)

print(f"{summary(jax.vmap(some_function, in_axes=(0, (None, 1)))(a, b_alt))=}") # vmap over the 0-axis of a and over none of the axes in the first element of b_alt and over the 1-axis of the second element of b_alt

def another_function(a:jax.Array, b:tuple[jax.Array, jax.Array]):
    """ 
    :parameter a: array of shape (batch, in_channels)
    :parameter b: tuple of two arrays
        first array of shape (out_channels, in_channels)
        second array of shape (out_channels,)
    :return: array of shape (out_channels,)
    """
    return jnp.mean(
        jnp.square(
            jax.vmap(some_function, in_axes=(0, None))(a, b)
            )
        )

print(f"{summary(jax.value_and_grad(another_function, argnums=1)(a, b))=}")  # value and gradient w.r.t b, where gradient has the same shape as b, including the same pytree structure
summary(
    jax.vmap(jax.value_and_grad(another_function, argnums=1), in_axes=(None, (None, 1)))(a, b_alt)  # can you explain why every array in the result has 10 for its 0-axis?
)

summary(jax.vmap(some_function, in_axes=(0, None))(a, b))='array(10, 2)'
summary(jax.vmap(some_function, in_axes=(0, (None, 1)))(a, b_alt))='array(10, 2)'
summary(jax.value_and_grad(another_function, argnums=1)(a, b))="('array()', ['array(2, 3)', 'array(2,)'])"


"('array(10,)', ('array(10, 2, 3)', 'array(10, 2)'))"

### Exercise 7
Complete the following code:

In [26]:
# let's look at these concepts in action by writing a simple implementation of an mlp and its training loop
from typing import Sequence, Union  # for Sequence: think list or tuple

def simple_mlp(x:jax.Array, weights:Union[Sequence[jax.Array], jax.Array], biases:Union[Sequence[jax.Array], jax.Array]):
    """simple_mlp 
    Apply an MLP with weights and biases dicated by the weights and biases parameters, and ReLU activations, to the datapoint x.
    The final layer of the MLP is a linear one.

    :parameter x: jax.Array of shape (in_channels,)
    :parameter weights: Either a jax.Array of shape (out_channels, in_channels), 
        or a Sequence of jax.Arrays, the first one of shape (hidden_channels_0, in_channels) 
        and the last one of shape (out_channels, hidden_channels_n)
    :parameter biases: Either a jax.Array of shape (out_channels,),
        or a Sequence of jax.Arrays, the first one of shape (hidden_channels_0,)
        and the last one of shape (out_channels,)
    :raises ValueError: in case the number of weights arrays isn't equal to the numer of biases arrays
    :return: MLP applied to x
    """
    # first do some case handling for the different input types (jax.Array vs Sequence[jax.Array])
    if not isinstance(weights, Sequence):
        weights = [weights]
    if not isinstance(biases, Sequence):
        biases = [biases]

    # then check if we need to raise a ValueError
    if len(weights) != len(biases):
        raise ValueError(f"weights and biases should have equal length. Got {len(weights)=} but {len(biases)=}.")

    # finally implement the logic of the MLP
    # hint for the ReLU activations you can use jax.nn.relu
    h = x 
    for w, b in zip(weights[:-1], biases[:-1]):
        h = jax.nn.relu(w@h+b)
    
    w, b = weights[-1], biases[-1]
    return w@h + b


weights = (random.normal(next(key_gen), (16, 3)), random.normal(next(key_gen), (16, 16)), random.normal(next(key_gen), (1, 16)))  # NB this is really bad initialization
biases = (random.normal(next(key_gen), (16,)), random.normal(next(key_gen), (16,)), random.normal(next(key_gen), (1,)))
batch = random.normal(next(key_gen), (10, 3))

# uncomment to test your code
print(f"{cju.debug_utils.summary(jax.vmap(simple_mlp, in_axes=(0, None, None))(batch, weights, biases))=}")  # apply simple_mlp to a batch of inputs

def calculate_loss(xs:jax.Array, ys:jax.Array, weights:tuple[jax.Array, ...], biases:tuple[jax.Array, ...]):
    """
    apply simple_mlp on xs and compute the mean square error
    :parameter xs: jax.Array with input data
    :parameter ys: target values to be approximated by the mlp
    :parameter weights: weigths of the MLP
    :parameter biases: biases of the MLP
    :return: scalar representing the mean squared error
    """
    predictions = jax.vmap(simple_mlp, in_axes=(0, None, None))(xs, weights, biases)
    return jnp.mean(jnp.square(ys-predictions))

def simple_train_step(
        weights_and_biases: tuple[Sequence[jax.Array], Sequence[jax.Array]], 
        data: tuple[jax.Array, jax.Array]
        ) -> tuple[tuple[Sequence[jax.Array], Sequence[jax.Array]], jax.Array]:
    """ 
    Perform a train step with a fixed learning rate of .01
    The signature of this function (what kind of arguments it takes and what it returns) makes it so that it can be used with jax.lax.scan
    :parameter weights_and_biases: 2-tuple with as its first element the weights of an MLP and as its second element the biases of an MLP (both compatible with)
    :parameter data: 2-tuple with as its first element a jax.Array representing a batch of input data, and as its second element a jax.Array containing the corresponding labels
    :return: 2-tuple, the first element of which contains a new value for weights_and_biases, and the second element of which contains the computed loss
    """
    xs, ys = data
    weights, biases = weights_and_biases
    # hint: use jax.value_and_grad to get the loss and gradients, and use jax.tree.map for getting the updated weights and biases
    loss, grads = jax.value_and_grad(
        calculate_loss, 
        argnums=(2, 3)  # weights, biases
    )(xs, ys, weights, biases)
    new_weights, new_biases = jax.tree.map(
        lambda w, g: w-.01*g,
        (weights, biases),
        grads
    )
    return (new_weights, new_biases), loss

ys = random.normal(next(key_gen), (10, 1))

training_data = random.normal(next(key_gen), shape=(20, 100, 3))  # 20 batches of size 100
training_labels = jnp.sin(training_data)

# uncomment the following to test your implementation
print(f"{cju.debug_utils.summary(simple_train_step((weights, biases), (batch, ys)))=}")

final_weights_and_biases, losses = jax.lax.scan(simple_train_step, init=(weights, biases), xs=(training_data, training_labels))
print(cju.debug_utils.summary(final_weights_and_biases))
print(losses)


cju.debug_utils.summary(jax.vmap(simple_mlp, in_axes=(0, None, None))(batch, weights, biases))='array(10, 1)'
cju.debug_utils.summary(simple_train_step((weights, biases), (batch, ys)))="((('array(16, 3)', 'array(16, 16)', 'array(1, 16)'), ('array(16,)', 'array(16,)', 'array(1,)')), 'array()')"
(('array(16, 3)', 'array(16, 16)', 'array(1, 16)'), ('array(16,)', 'array(16,)', 'array(1,)'))
[303.88705   837.9702     43.880123    5.9955125   4.6102185   3.1268833
   2.637646    2.569746    2.2857032   1.6449605   1.8573601   1.9932003
   1.6272165   1.5321568   1.71918     1.4262049   1.386418    1.213356
   1.2622874   1.2384579]


As mentioned, thinking of Pytrees just as compositions of Python containers and `jax.Array`s isn't quite right. The containers needn't be just standard Python containers, and the leafs of the tree needn't just be `jax.Array`s. Instead, what counts as a container for a pytree is any object whose type is registered in JAX's pytree container registry. And any object whose type is not in that registry is considered a leaf of a pytree. So, the definition given by the [JAX documentation](https://jax.readthedocs.io/en/latest/pytrees.html#what-is-a-pytree) is:
1. any object whose type is *not* in the pytree container registry is considered a *leaf* pytree;
1. any object whose type is in the pytree container registry, and which contains pytrees, is considered a pytree.

[If you want, you can write and register your own pytree container types](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees). But likely you won't need to, as any subclass of `equinox.Module`, our next topic of interest, is automatically registered as a pytree container. Another source of pytrees that you will encounter is `optax` with its optimizer states and gradient updates.

Utilities for working with pytrees can be found in [`jax.tree`](https://jax.readthedocs.io/en/latest/jax.tree.html), [`jax.tree_util`](https://jax.readthedocs.io/en/latest/jax.tree_util.html#module-jax.tree_util), [`common_jax_utils.tree_utils`](https://github.com/SimonKoop/common_jax_utils/blob/main/src/common_jax_utils/tree_utils.py), [`optax.tree_utils`](https://optax.readthedocs.io/en/latest/api/utilities.html#tree), and [`equinox`](https://docs.kidger.site/equinox/api/manipulation/).

Equinox is particularly useful when some of the leafs in your pytree are not actually `jax.Array`s.

# [Equinox](https://docs.kidger.site/equinox/) tutorial

Unlinke Pytorch and Tensorflow, JAX itself isn't a deeplearning library. There are various Neural Network libraries built ontop of JAX such as Flas, Haiku, and Equinox. In this course, we'll be using Equinox because it is conceptually simple (and frankly because I had a bunch of code for an old research project lying around that can make things significantly easier for you, and that code was written using Equinox). Importantly, Equinox makes it easy to do object oriented programming very similarly to Pytorch, but resulting in pytrees, so that the result is compatible with anything else in the JAX ecosystem.

## 1. Equinox Modules
Writing (neural network) models in Equinox is done by subclassing [`equinox.Module`](https://docs.kidger.site/equinox/api/module/module/). The resulting code looks quite similar to Pytorch code (where writing models happens through subclassing `torch.nn.Module`), but there are some key differences:
* `equinox.Module`s are **static** (remember, to jit things, our functions need to be pure, i.e. without side-effects)
* `equinox.Module`s are pytrees (they are registered as pytree containers automatically)
* `equinox.Module`s are [*dataclasses*](https://docs.python.org/3/library/dataclasses.html) and all of their attributes need to be registered as *fields*. This also means they come with a default `__init__` method.
* `equinox.Module`s will typically require a PRNG key for initializing weights and biases with random numbers
* the forward pass of your model can be defined in its `__call__` method (or you can just use any other method you want).

One particularly nice property of `equinox.Module`s, is that because they are pytrees, and because everything in JAX plays well with pytrees, you can have an `equinox.Module` that returns another `equinox.Module` in its forward pass, and everything will just work nicely the way you would want it to. We'll see this in the next practical session where we get more accustomed to the course code-base.

An example:

In [27]:
import equinox as eqx
from typing import Optional
import warnings
import math

# let's first re-define simple_mlp from above as a eqx.Module
# eqx.nn actually provides an MLP class, and moreover, it provides layers such as nn.Linear that you should ideally use for defining neural networks
# but this does provide a good, very basic example for now
class SimpleMLP(eqx.Module):
    weights: Sequence[jax.Array]  # we need to register the attributes as fields
    biases: Sequence[jax.Array]  # their values can only be set during initialization of the model

    def __init__(self, in_size:int, out_size:int, hidden_size:Union[int, Sequence[int]], key:jax.Array, depth:Optional[int]=None):
        """ 
        :parameter in_size: number of input features
        :parameter out_size: number of output features
        :parameter hidden_size: either an integer or a sequence of integers, the number of hidden units in each layer
        :parameter key: jax.Array, random key for initialization
        :parameter depth: number of hidden layers (ignored if hidden_size is a sequence of integers)
        """
        key_gen = cju.key_generator(key)
        if isinstance(hidden_size, int) and depth is not None:
            hidden_size = depth*(hidden_size,)
        elif isinstance(hidden_size, int):  # depth is None
            raise ValueError("If hidden size is an integer, depth should be provided do determine the number of hidden layers")
        elif depth is not None: # hidden size is not an integer
            warnings.warn(f"The value provided for depth ({depth}) will be ignored because hidden_size is provided as a sequence ({hidden_size})")

        output_sizes = tuple(hidden_size) + (out_size,)
        input_sizes = (in_size, ) + tuple(hidden_size)

        self.weights = []
        self.biases = []
        for out_size, in_size in zip(output_sizes, input_sizes):
            lim = 1/math.sqrt(in_size)
            self.weights.append(random.uniform(
                next(key_gen), 
                (out_size, in_size), 
                minval=-lim,
                maxval=lim
                ))
            self.biases.append(random.uniform(
                next(key_gen),
                (out_size,),
                minval=-lim,
                maxval=lim
            ))
    
    def __call__(self, x:jax.Array)->jax.Array:
        """ 
        Forward pass
        :parameter x: input data, jax.Array of shape (in_size,)
        :return: output of the MLP, jax.Array of shape (out_size,)
        """
        h = x
        for w, b in zip(self.weights[:-1], self.biases[:-1]):
            h = jax.nn.relu(w @ h + b)
        w, b = self.weights[-1], self.biases[-1]
        return w @ h + b
    
example_mlp = SimpleMLP(3, 1, (16, 17, 16), key=next(key_gen))
print(example_mlp)

x = random.normal(next(key_gen), (100, 3))
print(f"{summary(jax.vmap(example_mlp)(x))=}")





SimpleMLP(
  weights=[f32[16,3], f32[17,16], f32[16,17], f32[1,16]],
  biases=[f32[16], f32[17], f32[16], f32[1]]
)
summary(jax.vmap(example_mlp)(x))='array(100, 1)'


For another example, let's now write a model that actually uses the compositional nature of `eqx.Module`s and uses some of the pre-defined layers from `eqx.nn`:

In [28]:
class FCResBlock(eqx.Module):
    norm_layer_0: eqx.Module
    norm_layer_1: eqx.Module
    linear_layer_0: eqx.Module
    linear_layer_1: eqx.Module

    def __init__(self, feature_size, key):
        """
        :parameter feature_size: number of features
        :parameter key: jax.Array, random key for initialization
        """
        key_gen = cju.key_generator(key)
        self.norm_layer_0 = eqx.nn.LayerNorm(  # we are using LayerNorm out of convenience here
            shape=(feature_size,)
        )
        self.norm_layer_1 = eqx.nn.LayerNorm(  # it's possible to use BatchNorm, but due to the stateful nature of BatchNorm, that would make this example slightly more involved
            shape=(feature_size,)
        )
        self.linear_layer_0 = eqx.nn.Linear(
            feature_size, feature_size,
            key=next(key_gen)
        )
        self.linear_layer_1 = eqx.nn.Linear(
            feature_size, feature_size,
            key=next(key_gen)
        )
    
    def __call__(self, x:jax.Array):
        """
        Forward pass
        :parameter x: input data, jax.Array of shape (feature_size,)
        :return: output of the residual block, jax.Array of shape (feature_size,)
        """
        h = self.norm_layer_0(x)
        h = jax.nn.relu(h)
        h = self.linear_layer_0(h)
        h = self.norm_layer_1(h)
        h = jax.nn.relu(h)
        h = self.linear_layer_1(h)
        return x + h
    

class FCResnet(eqx.Module):
    input_layer: eqx.nn.Linear
    res_blocks: list[FCResBlock]
    output_layer: eqx.nn.Linear

    def __init__(self, in_features:int, out_features:int, hidden_features:int, num_resblocks:int, key:jax.Array):
        """
        :parameter in_features: number of input features
        :parameter out_features: number of output features
        :parameter hidden_features: number of features in the hidden layers
        :parameter num_resblocks: number of residual blocks
        :parameter key: jax.Array, random key for initialization
        """
        key_gen = cju.key_generator(key)
        self.input_layer = eqx.nn.Linear(in_features, hidden_features, key=next(key_gen))
        self.res_blocks = [
            FCResBlock(hidden_features, key=next(key_gen))
            for _ in range(num_resblocks)
        ]
        self.output_layer = eqx.nn.Linear(hidden_features, out_features, key=next(key_gen))

    def __call__(self, x):
        """
        Forward pass
        :parameter x: input data, jax.Array of shape (in_features,)
        :return: output of the MLP, jax.Array of shape (out_features,)
        """
        h = self.input_layer(x)
        for block in self.res_blocks:
            h = block(h)
        return self.output_layer(h)
    
my_resnet = FCResnet(3, 1, 32, 2, key=next(key_gen))
print(my_resnet)
print(f"{summary(jax.vmap(my_resnet)(x))=}")

FCResnet(
  input_layer=Linear(
    weight=f32[32,3],
    bias=f32[32],
    in_features=3,
    out_features=32,
    use_bias=True
  ),
  res_blocks=[
    FCResBlock(
      norm_layer_0=LayerNorm(
        shape=(32,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[32],
        bias=f32[32]
      ),
      norm_layer_1=LayerNorm(
        shape=(32,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[32],
        bias=f32[32]
      ),
      linear_layer_0=Linear(
        weight=f32[32,32],
        bias=f32[32],
        in_features=32,
        out_features=32,
        use_bias=True
      ),
      linear_layer_1=Linear(
        weight=f32[32,32],
        bias=f32[32],
        in_features=32,
        out_features=32,
        use_bias=True
      )
    ),
    FCResBlock(
      norm_layer_0=LayerNorm(
        shape=(32,),
        eps=1e-05,
        use_weight=True,
        use_bias=True,
        weight=f32[32],
        bi

### Exercise 8
Write a ResNet class that uses [convolutional blocks](https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.Conv) and [batch normalization](https://docs.kidger.site/equinox/api/nn/normalisation/#equinox.nn.BatchNorm). Test it on some random data.

In [29]:
# Your code goes here
class ResBlock(eqx.Module):
    norm_layer_0: eqx.nn.BatchNorm
    norm_layer_1: eqx.nn.BatchNorm
    conv_layer_0: eqx.Module
    conv_layer_1: eqx.Module

    def __init__(self, feature_size:int, kernel_size:int, num_spatial_dims:int, key:jax.Array):
        """
        :parameter feature_size: number of features
        :parameter kernal_size: size of the convolutional kernel
        :num_spatial_dims: how many dimensions to convolve over. E.g. 1 leads to Conv1d, 2 to Conv2d etc.
        :parameter key: jax.Array, random key for initialization
        """
        key_gen = cju.key_generator(key)
        self.norm_layer_0 = eqx.nn.BatchNorm(
            input_size=feature_size,
            axis_name='batch'
        )
        self.norm_layer_1 = eqx.nn.BatchNorm(
            input_size=feature_size,
            axis_name='batch'
        )
        self.conv_layer_0 = eqx.nn.Conv(
            num_spatial_dims,
            feature_size, feature_size,
            kernel_size,
            padding='same',
            key=next(key_gen)
        )
        self.conv_layer_1 = eqx.nn.Conv(
            num_spatial_dims,
            feature_size, feature_size,
            kernel_size,
            padding='same',
            key=next(key_gen)
        )
    
    def __call__(self, x:jax.Array, state:eqx.nn.State)->tuple[jax.Array, eqx.nn.State]:
        """
        Forward pass
        :parameter x: input data, jax.Array of shape (feature_size,)
        :parameter state: eqx.nn.State containing the statistics for batch normalization
        :return: a 2-tuple, the first element of which is the output of the residual block, 
            and the second element is the updated state 
        """
        h, state = self.norm_layer_0(x, state)
        h = jax.nn.relu(h)
        h = self.conv_layer_0(h)
        h, state = self.norm_layer_1(h, state)
        h = jax.nn.relu(h)
        h = self.conv_layer_1(h)
        return x + h, state
    

class ResNet(eqx.Module):
    input_layer: eqx.nn.Linear
    res_blocks: list[FCResBlock]
    output_layer: eqx.nn.Linear
    pooling_dims: tuple[int]

    def __init__(self, in_features:int, out_features:int, hidden_features:int, num_resblocks:int, kernel_size:int, num_spatial_dims:int, key:jax.Array):
        """
        :parameter in_features: number of input features
        :parameter out_features: number of output features
        :parameter hidden_features: number of features in the hidden layers
        :parameter num_resblocks: number of residual blocks
        :parameter kernel_size: size of the convolutional kernels
        :parameter num_spatial_dims: number of spatial dimensions (e.g 1 for using Conv1d, 2 for Conv2d, etc.)
        :parameter key: jax.Array, random key for initialization
        """
        key_gen = cju.key_generator(key)
        self.input_layer = eqx.nn.Conv(num_spatial_dims, in_features, hidden_features, kernel_size, padding='same', key=next(key_gen))
        self.res_blocks = [
            ResBlock(hidden_features, kernel_size, num_spatial_dims, key=next(key_gen))
            for _ in range(num_resblocks)
        ]
        self.output_layer = eqx.nn.Linear(hidden_features, out_features, key=next(key_gen))

        # channels are on axis 0, the rest are spatial/temporal channels over which to pool before using the linear layer
        self.pooling_dims = tuple(range(1, num_spatial_dims+1))  

        # NB channels are on axis 0 because we vmap over the batch dimension, so we write our code for a single data point instead of for a batch
        

    def __call__(self, x:jax.Array, state:eqx.nn.State)->tuple[jax.Array, eqx.nn.State]:
        """
        Forward pass
        :parameter x: input data, jax.Array of shape (in_features,)
        :parameter state: eqx.nn.State containing the statistics for batch normalization
        :return: 2-tuple, the first element of which is the output of the residual network, i.e. an (out_features,) shaped jax.Array, 
            and the second element is the updated state.
        """
        h = self.input_layer(x)
        for block in self.res_blocks:
            h, state = block(h, state)
        h_averaged = h.mean(axis=self.pooling_dims)
        return self.output_layer(h_averaged), state
    

example_model, state = eqx.nn.make_with_state(ResNet)(
    in_features=3,
    out_features=9,
    hidden_features=128,
    num_resblocks=3,
    kernel_size=3,
    num_spatial_dims=1,
    key=next(key_gen)
)

In [30]:
example_data = random.normal(next(key_gen), (20, 3, 32))

print(summary(jax.vmap(example_model, in_axes=(0, None), out_axes=(0, None), axis_name='batch')(example_data, state)))  # Don't forget the `out_axes=(0, None)`!


('array(20, 9)', State(
  0x7b04d61000d0='array()',
  0x7b04d61000f0=('array(128,)', 'array(128,)'),
  0x7b04d6100110='array()',
  0x7b04d6100130=('array(128,)', 'array(128,)'),
  0x7b04d6100150='array()',
  0x7b04d6100170=('array(128,)', 'array(128,)'),
  0x7b04d6100190='array()',
  0x7b04d61001b0=('array(128,)', 'array(128,)'),
  0x7b04d61001d0='array()',
  0x7b04d61001f0=('array(128,)', 'array(128,)'),
  0x7b04d6100210='array()',
  0x7b04d6100230=('array(128,)', 'array(128,)')
))


## 2. Filtered transformations
When the only leafs of our `eqx.Module`s are `jax.Array`s, we can use JAX's built in function transformations like `vmap`, `grad`, and `jit` without any problems, because, as we saw before, those play nice with pytrees. You can see this in the following code block:

In [31]:
def loss(network, x, y):
    return jnp.mean(jnp.square(jax.vmap(network)(x)-y))

jitted_value_and_grad_loss = jax.jit(jax.value_and_grad(loss))
jitted_value_and_grad_loss(example_mlp, x, random.normal(next(key_gen), (100, 1)))  # note that the grad has the same structure as the network, so it's a SimpleMLP itself!
# (albeit not one that we would want to apply to data)

(Array(0.93261236, dtype=float32),
 SimpleMLP(
   weights=[f32[16,3], f32[17,16], f32[16,17], f32[1,16]],
   biases=[f32[16], f32[17], f32[16], f32[1]]
 ))

However, for many models, we might want to store more information in the model than just arrays containing weights and biases. E.G. maybe we want to have some freedom in which activation function is used, and we want to store it as an attribute. This can cause some problems if we just use `jax.grad` or `jax.jit`:

In [32]:
from typing import Callable

class NNLayer(eqx.Module):
    weights: jax.Array
    biases: jax.Array
    activation_function: Callable

    def __init__(self, in_features:int, out_features:int, key:jax.Array, activation_function:Callable=jax.nn.relu):
        """
        :parameter in_features: number of input features
        :parameter out_features: number of output features
        :parameter key: jax.Array, random key for initialization
        :parameter activation_function: activation function to be used
        """
        lim = 1/math.sqrt(in_features)
        self.weights = random.uniform(
            key, 
            (out_features, in_features), 
            minval=-lim,
            maxval=lim
            )
        self.biases = random.uniform(
            key,
            (out_features,),
            minval=-lim,
            maxval=lim
        )
        self.activation_function = activation_function
    
    def __call__(self, x:jax.Array, *, key:Optional[jax.Array]=None):
        """
        Forward pass
        :parameter x: input data, jax.Array of shape (in_features,)
        :parameter key: ignored, provided for compatibility with the Equinox API (so we can use eqx.nn.Sequential later on)
        :return: output of the layer, jax.Array of shape (out_features,)
        """
        return self.activation_function(self.weights@x+self.biases)

my_layer = NNLayer(3, 1, key=next(key_gen))
# uncomment the following line to try it out and get the TypeError:
# jitted_value_and_grad_loss(my_layer, x, random.normal(next(key_gen), (100, 1)))  # TypeError: Cannot interpret value of type <class 'jax._src.custom_derivatives.custom_jvp'> as an abstract array; it does not have a dtype attribute

One way to deal with this, would be to manually indicate whicharguments should be treated statically. Doing so would however be rather inconvenient. Fortunately, equinox provides tools to deal with this instead: [filtered versions of many jax transformations](https://docs.kidger.site/equinox/api/transformations/). For example:

In [33]:
filtered_jvag_loss = eqx.filter_jit(eqx.filter_value_and_grad(loss))
filtered_jvag_loss(my_layer, x, random.normal(next(key_gen), (100, 1)))  
# question: my_layer works on 3-vectors, but x has shape (100, 3), why shouldn't we put jax.vmap(my_layer) instead of my_layer in filtered_jvag_loss?

(Array(0.9786879, dtype=float32),
 NNLayer(weights=f32[1,3], biases=f32[1], activation_function=None))

# [Optax](https://optax.readthedocs.io/en/latest/index.html) tutorial
Earlier in this notebook, we trained a neural network (our `simple_mlp`) by applyin a gradient update sort of manually using `jax.tree.map`:
```python
new_weights, new_biases = jax.tree.map(
    lambda w, g: w-.01*g,  # gradient descent with learning rate 0.01
    (weights, biases),
    grads
)
```
In general however, we might want to modify the gradients slightly, e.g. by clipping them, or we might want to use optimization techniques such as momentum. 

In Pytorch, we have `Optimizer` objects that track state like momentum, and apply the correct updates to the model. In JAX however, we want operations to have no side-effects so that we can JIT them. That means we'll have to keep track of any state (such as momentum) explicitly, and we will have to combine our model with any updates into a new model since the model can't be changed (as that would be a side-effect).

Nonetheless, we don't have to do all of the work of transforming the gradients in combination with a state into updates to our network and a new state manually every time. Optax is an optimization library for JAX that implements many commonly used optimizers, such as the Adam optimizer.

The first core concept of Optax is that of the `GradientTransformation`. A `GradientTransformation` object contains an `update` method that takes gradients, a state, and (optionally) the current parameters that are to be updated, and transforms them into updates for the model and a new state. One example of such a transformation is gradient clipping ([`optax.clip`](https://optax.readthedocs.io/en/latest/api/transformations.html#optax.clip)): it takes a pytree of gradients and a state (and optionally the parameters), and returns a clipped version of the pytree of gradients and the same state.

Another example is ['optax.adam'](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam): the resulting `GradientTransformation` takes a pytree of gradients, an *appropriate* state (a.o. containing a $1^{\text{st}}$ moment pytree and a $2^{\text{nd}}$ moment pytree), and optionally a pytree of parameters, and computes new moments for the state and returns updates based on those new moments and the new state. 

These gradient transformations are composable, so that we can combine the transformation from `optax.clip` with that from `optax.adam` into a `GradientTransformation` that first clips the gradients and then applies the updaterules from Adam.

Moreover, they don't only implement their update rule (in their `update` method), but they also implement an `init` method that takes a pytree of model parameters (e.g. an `eqx.Module`) and returns an appropriate optimizer state (an instance of `OptState`, which makes it a pytree as well).

**In summary**, to use optimizers from Optax, what we need to do is
1. define a model, say `model` (e.g. an `eqx.Module`)
1. create a `GradientTransform`, say `optimizer` (e.g. by calling `optax.adam(learning_rate=.01)`)
1. create an appropriate `OptState` by setting `opt_state = optimizer.init(model)`

and then during our training loop, at each step we
1. compute the gradients of the loss (and the loss itself) using `jax.value_and_grad` or `eqx.filter_value_and_grad`
1. we combine the gradients and the existing optimizer state into updates and a new optimizer state (think `updates, opt_state = optimizer.update(grads, opt_state, model)`)
1. we combine our existing model and the updates into a new model

This last step can be done either by using `optax.apply_updates` or [`eqx.apply_updates`](https://docs.kidger.site/equinox/api/manipulation/#equinox.apply_updates) (or by using `jax.tree.map` to apply the updates manually). When working with `eqx.Module`s, it's typically best to use `eqx.apply_updates` for basically the same reason as why we want to use `eqx.filter_grad` and `eqx.filter_jit`.

In [34]:
# let's train example_mlp from earlier 
import optax 


optimizer = optax.adam(.01)
initial_state = optimizer.init(eqx.filter(example_mlp, eqx.is_array))
# the filter isn't really necessary for example_mlp, but it is for any model that contains
# non-array attributes, such as the NNLayer class above

def loss_func(model, xs, ys):
    predictions = jax.vmap(model)(xs)
    return jnp.mean(jnp.square(predictions-ys))

#@eqx.filter_jit  # see text after this cell
def train_step(model_and_state, data):
    model, opt_state = model_and_state
    xs, ys = data
    loss, grads = eqx.filter_value_and_grad(loss_func)(model, xs, ys)
    updates, new_opt_state = optimizer.update(grads, opt_state, model)
    new_model = eqx.apply_updates(model, updates)
    return (new_model, new_opt_state), loss


# now we train the model
example_mlp_trained, losses = jax.lax.scan(train_step, init=(example_mlp, initial_state), xs=(training_data, training_labels))
print(losses)


[0.43557435 0.4525172  0.40632796 0.43212205 0.443473   0.3904306
 0.37801644 0.37055156 0.34726283 0.3081857  0.35607505 0.31400898
 0.33122024 0.2920717  0.29022804 0.32155767 0.3180481  0.29130617
 0.2987887  0.28051388]


In [35]:
training_labels.shape

(20, 100, 3)

A quick note about the commented-out `@eqx.filter_jit` above the definition of `train_step` in the above example: in this example we use `jax.lax.scan` for the training loop. As mentioned in the "Note" in the [documentation of `jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), `scan()` compiles the function that is scanned, so it is unnecessary to manually jit `train_step` in this case.

However, in practice, we often want to do more than just train a model during our training loop: we might want to give live updates about the loss and about various metrics to ["weights and biases" (wandb)](https://wandb.ai), and we might want to run a validation loop after every `N` steps, and maybe store some intermediate results, etc. This is much easier when we just run a Python `for` loop, calling the `train_step` function at every iteration. 

You may wonder: do we lose anything by running a Python `for` loop instead of using `jax.lax.scan`? Afterall, Python loops are slow, right? In all honesty, I don't know for sure. You may want to experiment with this. If I had to give an educated guess however, I would say: as long as your model is large enough, you shouldn't miss anything. Or more precisely: as long as a single train step takes longer for the GPU to compute than that the python interpreter takes to interpret the (hopefully hand full of) lines in the training loop, [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#async-dispatch) should mean that there is no real overhead of the python interpreter.

So then in the usual case, where your train loop is just a Python `for` loop, **what functions should you jit?** Basically, only the `train_step` (and maybe e.g. a `validation_step` if you have implemented it). *As a rule of thumb, you want to jit the largest possible function for the best result.* So **don't** jit your model and your loss function etc. separately, passing a jitted model to the jitted `value_and_grad` of a jitted loss function. But instead **only jit the function that combines all of them** (i.e. the `train_step`).

### Exercise 9
Define a model using `equinox.nn.Sequential` and using the either `NNLayer`s defined earlier in this notebook or your own adaptation thereof. Define a `train_step` function, and train your model using [`optax.amsgrad`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.amsgrad) for 20 000 steps on the data provided in the next cell, and print the loss every 400 steps. After every 2 000 steps, compute the loss of your model on the validation data provided.

In [36]:
training_data = random.normal(next(key_gen), shape=(40000, 3))
training_labels = jnp.sin(training_data[:, 0]) - jnp.cos(training_data[:, 1]) + jnp.sin(training_data[:, 2] - 0.1*training_data[:, 0] + 0.1*training_data[:, 1])
training_labels = training_labels[:, None]

validation_data = random.normal(next(key_gen), shape=(2000, 3))
validation_labels = jnp.sin(validation_data[:, 0]) - jnp.cos(validation_data[:, 1]) + jnp.sin(validation_data[:, 2] - 0.1*validation_data[:, 0] + 0.1*validation_data[:, 1])
validation_labels = validation_labels[:, None]

# your code goes here

# first we define our model, optimizer, and initial state
my_model = eqx.nn.Sequential([
    NNLayer(3, 256, next(key_gen), jax.nn.celu),
    NNLayer(256, 256, next(key_gen), jax.nn.celu),
    NNLayer(256, 256, next(key_gen), jax.nn.celu),
    NNLayer(256, 1, next(key_gen))
])

optimizer = optax.amsgrad(2e-6)

optimizer_state = optimizer.init(eqx.filter(my_model, eqx.is_array))

# then we define our train step and validation step
@eqx.filter_jit
def train_step(model, xs, ys, opt_state):
    loss, grads = eqx.filter_value_and_grad(loss_func)(model, xs, ys)
    updates, new_opt_state = optimizer.update(grads, opt_state, model)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, loss

@eqx.filter_jit
def validation_step(model, xs, ys):
    loss = loss_func(model, xs, ys)
    return loss

# we define an iterator that just keeps iterating over batches of training data (shuffeling them each time we've had the whole data set)
def my_iterator(data:jax.Array, labels:jax.Array, batch_size:int, key:jax.Array):
    """ 
    :parameter data: data to iterate over
    :parameter labels: corresponding labels
    :parameter batch_size: size of the batches that are to be created
    :parameter key: jax prng key

    :yields: tuples of two `jax.Array`s, the first of which is a batch of data, and the second of which is the corresponding batch of labels
    """
    key_gen = cju.key_generator(key)
    num_batches = data.shape[0]//batch_size
    num_data_points_per_epoch = num_batches*batch_size

    while True:
        index_permutation = random.permutation(next(key_gen), data.shape[0])  # we to re-shuffle the data every time we loop over it
        data_batches = data[index_permutation[:num_data_points_per_epoch]].reshape((num_batches, batch_size, )+data.shape[1:])
        label_batches = labels[index_permutation[:num_data_points_per_epoch]].reshape((num_batches, batch_size, )+labels.shape[1:])
        yield from zip(data_batches, label_batches)

# define the batch size and number of training steps
batch_size = 100
num_steps = 20_000
num_val_batches = validation_data.shape[0]//batch_size

# and now for the training loop
train_losses = []
for step_index, (train_batch, train_labels) in zip(
    range(1, num_steps+1), # start counting at 1
    my_iterator(training_data, training_labels, batch_size, next(key_gen))  # it's okay that this is infinite, because the range isn't
    ):
    my_model, optimizer_state, loss = train_step(my_model, train_batch, train_labels, optimizer_state)
    train_losses.append(loss)

    if step_index % 400 == 0:  # here it helps that we started counting at 1
        print(f"Step {step_index}: Average loss over past 400 steps is {np.mean(train_losses):.6f}, most recent loss was {loss:.6f}")
        train_losses = []
    
    if step_index % 2_000 == 0:
        print("\nStarting validation loop")
        validation_losses = []
        for index in range(num_val_batches):
            val_batch = validation_data[index*batch_size : (index+1)*batch_size]
            val_labels = validation_labels[index*batch_size : (index+1)*batch_size]
            loss = validation_step(my_model, val_batch, val_labels)
            validation_losses.append(loss)
        print(f"Validation loss: {np.mean(validation_losses):.6f}\n")


Step 400: Average loss over past 400 steps is 1.308192, most recent loss was 1.143559
Step 800: Average loss over past 400 steps is 1.288161, most recent loss was 1.004046
Step 1200: Average loss over past 400 steps is 1.277691, most recent loss was 1.197715
Step 1600: Average loss over past 400 steps is 1.270541, most recent loss was 1.186588
Step 2000: Average loss over past 400 steps is 1.264957, most recent loss was 1.117419

Starting validation loop
Validation loss: 1.276634

Step 2400: Average loss over past 400 steps is 1.260207, most recent loss was 0.988892
Step 2800: Average loss over past 400 steps is 1.255895, most recent loss was 1.213949
Step 3200: Average loss over past 400 steps is 1.251733, most recent loss was 1.019403
Step 3600: Average loss over past 400 steps is 1.247568, most recent loss was 0.966704
Step 4000: Average loss over past 400 steps is 1.243258, most recent loss was 1.208685

Starting validation loop
Validation loss: 1.256798

Step 4400: Average loss ov

# List of additional resources
Here is a (far from comprehensive) list of additional material to look at if you get stuck with anything JAX related:
* https://docs.kidger.site/equinox/faq/  answers to questions you'll likely at some point have (e.g. "How to mark arrays as non-trainable?")
* https://docs.kidger.site/equinox/api/serialisation/  for storing your Equinox models
* https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html for more automatic differentiation than just `jax.grad`
* https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html for information on problems you're likely to run into
* https://jax.readthedocs.io/en/latest/faq.html for more answers to questions you might have at some point.