# **Intro to JAX for ML**

<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="60%" />


<a href="https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/prac1/prac_0_intro_to_jax_using_ml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

© Deep Learning Indaba 2022. Apache License 2.0.

**Authors:** Kale-ab Tessera

**Introduction:** 

JAX is a Python package for writing composable numerical transformations [[1]](https://jax.readthedocs.io/en/latest/index.html). It leverages [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla) (Accelerated Linear Algebra) to achieve high-performance numerical computing, which is particularly relevant in machine learning. It provides transforms such as automatic differentiation (`grad`), parallelization (`pmap`), vectorization (`vmap`), just-in-time compilation (`jit`), and more.  
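To make "composable" concrete, here is a minimal sketch that stacks three transforms on top of each other. The function `f` is a made-up example, not part of this practical.

```python
import jax
import jax.numpy as jnp

# A made-up scalar function of a scalar input: f(x) = x^2 * sin(x).
def f(x):
  return jnp.sin(x) * x ** 2

# Compose transforms: grad differentiates f, vmap maps the derivative
# over a batch of inputs, and jit compiles the whole pipeline with XLA.
batched_df = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(-1.0, 1.0, 5)
# Analytically, f'(x) = cos(x) * x**2 + 2 * x * sin(x), evaluated at each x.
print(batched_df(xs))
```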

JAX is **not** a replacement for PyTorch or TensorFlow, but a lower-level library commonly used with higher-level neural network libraries such as [Haiku](https://github.com/deepmind/dm-haiku) or [Flax](https://github.com/google/flax).  

**Topics:** 

Content: <font color='orange'>`Numerical Computing`</font>  
Level: <font color='grey'>`Beginner`</font>


**Aims/Learning Objectives:**

- Learn the basics of JAX and its similarities and differences with NumPy.
- Learn how to use JAX transforms - `jit`, `grad`, `vmap` and `pmap`.  
- Learn how to build simple classifiers using JAX. 

**Prerequisites:**

- Basic knowledge of [NumPy](https://github.com/numpy/numpy).
- Basic knowledge of [functional programming](https://en.wikipedia.org/wiki/Functional_programming). 

**Outline:** 

>[Part 1 - Basics of JAX](#scrollTo=Enx0WUr8tIPf)

>>[1.1 From NumPy ➡ Jax - Beginner](#scrollTo=-ZUp8i37dFbU)

>>>[JAX and NumPy - Similarities  🤝](#scrollTo=CbOEYsWQ6tHv)

>>>[JAX and NumPy - Differences ❌](#scrollTo=lg4__l4A7yqc)


>>[1.2 Acceleration in JAX 🚀](#scrollTo=TSj972IWxTo2)

>>>[JAX is backend Agnostic - Beginner](#scrollTo=_bQ9QqT-yKbs)

>>>[JAX Transformations](#scrollTo=JM_08mXEBRIK)

>>>>[Basic JAX Transformations - JIT and GRAD - Beginner](#scrollTo=cOGuGWtLmP7n)

>>>>[More Advanced Transforms - VMAP and PMAP - Intermediate, Advanced](#scrollTo=tvBzh8wiGuLf)

>>[Section Quiz](#scrollTo=WILOYJH4gCnD)

>[Part 2 - From Linear to Non-Linear Regression - WIP](#scrollTo=aB0503xgmSFh)

>>[Linear Regression](#scrollTo=XrWSN-zaWAhJ)

>[Conclusion](#scrollTo=fV3YG7QOZD-B)

>>[Feedback](#scrollTo=o1ndpYE50BpG)


**Before you start:**

For this practical, you will need to use a GPU to speed up training. To do this, go to the "Runtime" menu in Colab, select "Change runtime type" and then in the popup menu, choose "GPU" in the "Hardware accelerator" box.



## Installation and Imports

In [None]:
## Install and import anything required. Capture hides the output from the cell. 
#@title Install and import required packages. (Run Cell)

import os 

# https://stackoverflow.com/questions/68340858/in-google-colab-is-there-a-programing-way-to-check-which-runtime-like-gpu-or-tpu
if int(os.environ.get("COLAB_GPU", "0")) > 0:
  print("A GPU is connected.")
elif "COLAB_TPU_ADDR" in os.environ and os.environ["COLAB_TPU_ADDR"]:
  print("A TPU is connected.")
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
else:
  print("Only CPU accelerator is connected.")
  # x8 cpu devices - number of (emulated) host devices
  os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap

import matplotlib.pyplot as plt
import numpy as np

In [None]:
#@title Helper Functions. (Run Cell)
import copy
from typing import Dict 
def plot_performance(data:Dict, title: str):
  runs = list(data.keys())
  time = list(data.values())
  
  # creating the bar plot
  plt.bar(runs, time, width = 0.35)
  
  plt.xlabel("Implementation")
  plt.ylabel("Average time taken (in s)")
  plt.title(title)
  plt.show()

  best_perf_key = min(data, key=data.get)
  all_runs_key = copy.copy(runs)

  # all_runs_key_except_best
  all_runs_key.remove(best_perf_key)

  for k in all_runs_key:
    print(f"{best_perf_key} was {round((data[k]/data[best_perf_key]),2)} times faster than {k} !!!")

In [None]:
#@title Check the device you are using (Run Cell)
print(f"Num devices: {jax.device_count()}")
print(f" Devices: {jax.devices()}")

# Part 1 - Basics of JAX

## 1.1 From NumPy ➡ Jax - <font color='blue'>`Beginner`</font>

 

### JAX and NumPy - Similarities  🤝

The main similarity between JAX and NumPy is that they share a similar interface, and oftentimes JAX and NumPy arrays can be used interchangeably. 

Let's plot the sine function using NumPy.

In [None]:
# 100 linearly spaced numbers from -np.pi to np.pi
x = np.linspace(-np.pi,np.pi,100)

# the function, which is y = sin(x) here
y = np.sin(x)

# plot the functions
plt.plot(x,y, 'b', label='y=sin(x)')

plt.legend(loc='upper left')

# show the plot
plt.show()

Now let's do the same using `jax.numpy` (`jnp`).

(We already imported this in the first cell as follows: `import jax.numpy as jnp`.)

In [None]:
# 100 linearly spaced numbers from -np.pi to np.pi
x = jnp.linspace(-np.pi,np.pi,100)

# the function, which is y = sin(x) here
y = jnp.sin(x)

# plot the functions
plt.plot(x,y, 'b', label='y=sin(x)')

plt.legend(loc='upper left')

# show the plot
plt.show()

**Code Task:** Can you plot the cosine function using `jnp`?

In [None]:
#Plot Cosine using jnp. (UPDATE ME)

# 100 linearly spaced numbers
# UPDATE ME
x = ...

# UPDATE ME
y = ...  

if x is ... or y is ...:
  raise Exception("Update ME!")

# plot the functions
plt.plot(x,y, 'b', label='y=cos(x)')

plt.legend(loc='upper left')

# show the plot
plt.show()

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!') 
# 100 linearly spaced numbers
x = jnp.linspace(-np.pi,np.pi,100)

y = jnp.cos(x)  

# plot the functions
plt.plot(x,y, 'b', label='y=cos(x)')

plt.legend(loc='upper left')

# show the plot
plt.show()

### JAX and NumPy - Differences ❌

Although JAX and NumPy have some similarities, they do have some important differences:
- JAX arrays are **immutable** (they can't be modified after they are created).
- The way they handle **randomness**.

#### JAX arrays are immutable, while NumPy arrays are not.

JAX and NumPy arrays are often interchangeable, **but** JAX arrays are **immutable** (they can't be modified after they are created). 

Let's see this in practice by updating the first element of an array. 

In [None]:
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)

In [None]:
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10

So it fails! We can't mutate a JAX array once it has been created. To update JAX arrays, we need to use [helper functions](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) that return an updated copy of the JAX array. 

Instead of doing this `x[idx] = y`, we need to do this `x = x.at[idx].set(y)`. 

In [None]:
x = jnp.arange(10)
x = x.at[0].set(10)
print(x)
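The `.at` property supports more than `.set`. As a small sketch (the array and indices are arbitrary), here are a few other out-of-place updates; each call returns a new array and leaves `x` untouched:

```python
import jax.numpy as jnp

x = jnp.arange(5)

# Each .at[...] method returns a NEW array; x itself is never modified.
x_set = x.at[0].set(10)          # replace element 0      -> 10, 1, 2, 3, 4
x_add = x.at[1].add(100)         # add to element 1       -> 0, 101, 2, 3, 4
x_mul = x.at[2:4].multiply(3)    # multiply a slice       -> 0, 1, 6, 9, 4
x_neg_idx = x.at[-1].set(-1)     # negative indices work  -> 0, 1, 2, 3, -1

print(x)  # still the original 0..4
```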

#### Randomness in NumPy vs JAX 

JAX is more explicit in Pseudo Random Number Generation (PRNG) than NumPy and other libraries (such as tensorflow or pytorch). [PRNG](https://en.wikipedia.org/wiki/Pseudorandom_number_generator) is the process of algorithmically generating a sequence of numbers, which *approximate* the properties of a sequence of random numbers.  

Let's see the differences in how JAX and NumPy generate random numbers.

##### In NumPy, PRNG is based on a global `state`.

Let's set the initial seed.

In [None]:
# Set random seed
np.random.seed(42)
prng_state = np.random.get_state()

In [None]:
#@title Helper function to compare prng states (Run Cell)
def is_prng_state_the_same(prng_1,prng_2):
  """Helper function to compare two NumPy prng states."""
  # compare the prng state tuples element by element
  list_prng_data_equal = [(a == b) for a, b in zip(prng_1,prng_2)]
  # stack all elements together
  list_prng_data_equal = np.hstack(list_prng_data_equal)
  # check if all elements are the same
  is_prng_equal=all(list_prng_data_equal)
  return is_prng_equal

Let's take a few samples from a Gaussian (normal) distribution and check whether the PRNG state changes.

In [None]:
print(f"sample 1 = {np.random.normal()} Did prng state change: {not is_prng_state_the_same(prng_state,np.random.get_state())}")
prng_state = np.random.get_state()
print(f"sample 2 = {np.random.normal()} Did prng state change: {not is_prng_state_the_same(prng_state,np.random.get_state())}")
prng_state = np.random.get_state()
print(f"sample 3 = {np.random.normal()} Did prng state change: {not is_prng_state_the_same(prng_state,np.random.get_state())}")

NumPy's global random state is updated every time a random number is generated, so sample 1 != sample 2 != sample 3. 

Having the state updated automatically makes it difficult to handle randomness in a **reproducible** way across threads, processes and devices. 
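As a minimal sketch of the problem, assume a hypothetical library function, `helper_that_uses_randomness`, that quietly draws from NumPy's global generator. Seeding alone is then not enough to reproduce your own samples, because the helper advances the shared state in between:

```python
import numpy as np

def helper_that_uses_randomness():
  # Hypothetical helper that silently draws from NumPy's global RNG.
  return np.random.normal()

np.random.seed(42)
a = np.random.normal()

np.random.seed(42)
helper_that_uses_randomness()   # consumes one draw from the shared global state
b = np.random.normal()

# Same seed, but b != a, because the helper advanced the hidden global state.
print(a, b, a == b)
```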

##### In JAX, PRNG is more explicit.

In JAX, for each random number generation, you need to explicitly pass in a random key/state.

Passing the same state/key results in the same number being generated. This is generally undesirable.

In [None]:
from jax import random
key = random.PRNGKey(42)
print(f"sample 1 = {random.normal(key)}")
print(f"sample 2 = {random.normal(key)}")
print(f"sample 3 = {random.normal(key)}")

To generate different and independent samples, you need to manually **split** the keys. 

In [None]:
from jax import random
key = random.PRNGKey(42)
print(f"sample 1 = {random.normal(key)}")

# We split the key -> key and subkey
key, subkey = random.split(key)

# We use the subkey immediately and keep the key for future splits. It doesn't really matter which key we keep and which one we use immediately. 
print(f"sample 2 = {random.normal(subkey)}")

# We split the new key -> key and subkey
key, subkey = random.split(key)
print(f"sample 3 = {random.normal(subkey)}")

By using JAX, we can more easily reproduce random number generation in parallel across threads, processes or even devices by explicitly passing and keeping track of the prng key (without relying on a global state that automatically gets updated). For more details on PRNG in JAX, you can read more [here](https://github.com/google/jax/blob/main/docs/design_notes/prng.md). 
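As a sketch of the points above (the seed and the number of splits are arbitrary), `random.split` can produce several subkeys at once via its `num` argument, and rebuilding the key from the same seed reproduces the samples exactly:

```python
from jax import random

key = random.PRNGKey(0)

# Split one key into several subkeys in a single call:
# 1 key to keep for future splits + 3 subkeys to use now.
key, *subkeys = random.split(key, num=4)
samples = [random.normal(k) for k in subkeys]

# Re-creating the key from the same seed reproduces the exact same samples,
# with no hidden global state involved.
key2 = random.PRNGKey(0)
key2, *subkeys2 = random.split(key2, num=4)
samples2 = [random.normal(k) for k in subkeys2]

print(samples)
print(samples2)
```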

## 1.2 Acceleration in JAX 🚀 

JAX leverages Autograd and XLA for accelerating numerical computation. The use of Autograd allows for automatic differentiation (`grad`), while XLA allows JAX to run on multiple accelerators/backends and run transforms like `jit`, `vmap` and `pmap`.  

### JAX is backend Agnostic - <font color='blue'>`Beginner`</font>

Using JAX, you can run the same code on different backends/AI accelerators (e.g. CPU/GPU/TPU), **with no changes in code** (no more `.to(device)` - from frameworks like PyTorch). This means we can easily run linear algebra operations directly on GPU/TPU.

**Multiplying Matrices**

Dot products are a common operation in numerical computing and a central part of modern deep learning. They are defined over [vectors](https://en.wikipedia.org/wiki/Coordinate_vector), which can loosely be thought of as lists of scalars (single values). Matrix multiplication is built from them: each entry of the product $XY$ is the dot product of a row of $X$ with a column of $Y$.

Formally, given two vectors $\boldsymbol{x}, \boldsymbol{y} \in \mathbb{R}^n$, their dot product is defined as:

<center>$\boldsymbol{x}^{\top} \boldsymbol{y}=\sum_{i=1}^{n} x_{i} y_{i}$</center>
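A small worked example (with made-up vectors and matrices) checking the formula against `jnp.dot`, and showing how each entry of a matrix product is itself a dot product:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# The definition: multiply element-wise, then sum.
manual = jnp.sum(x * y)          # 1*4 + 2*5 + 3*6 = 32
builtin = jnp.dot(x, y)

# A matrix product is a grid of dot products: entry (i, j) of A @ B
# is the dot product of row i of A with column j of B.
A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
B = jnp.array([[5.0, 6.0], [7.0, 8.0]])
C = jnp.dot(A, B)
entry_01 = jnp.dot(A[0, :], B[:, 1])   # row 0 of A, column 1 of B: 1*6 + 2*8 = 22

print(manual, builtin, C[0, 1], entry_01)
```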

Dot Product in NumPy (will run on CPU).

In [None]:
size=1000
x = np.random.normal(size=(size, size))
y = np.random.normal(size=(size, size))
numpy_time = %timeit -o -n 10 a_np = np.dot(y,x.T)

Dot Product using JAX (will run on current runtime - e.g. GPU).

In [None]:
size=1000
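key = jax.random.PRNGKey(42)
# Split the key so x and y are independent (reusing one key would make x == y)
key_x, key_y = jax.random.split(key)
x = jax.random.normal(key_x,shape=(size, size))
y = jax.random.normal(key_y,shape=(size, size))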
jax_time = %timeit -o -n 10 jnp.dot(y, x.T).block_until_ready()

When timing JAX functions, we use `.block_until_ready()` because JAX uses [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#async-dispatch). This means JAX doesn't wait for the operation to complete before returning control to your code. To fairly compute the time taken for JAX operations, we therefore block until the operation is done.
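If you want to time JAX code outside of IPython's `%timeit`, here is a minimal sketch of a hypothetical helper, `time_jax_call`, that follows the same rules: warm up first, then block on every call (the matrix size is arbitrary):

```python
import time

import jax
import jax.numpy as jnp

def time_jax_call(fn, *args, repeats=10):
  """Average wall-clock seconds per call, waiting for the device to finish."""
  fn(*args).block_until_ready()  # warm-up run, excluded from the timing
  start = time.perf_counter()
  for _ in range(repeats):
    # Without block_until_ready(), we might only measure how long it takes
    # to *dispatch* the work, not how long the work itself takes.
    fn(*args).block_until_ready()
  return (time.perf_counter() - start) / repeats

x = jax.random.normal(jax.random.PRNGKey(0), (500, 500))
avg = time_jax_call(lambda a: jnp.dot(a, a.T), x)
print(f"average time per call: {avg:.6f} s")
```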

How much faster was the dot product in JAX (Using GPU)?

In [None]:
np_average_time=np.mean(numpy_time.all_runs)
jax_average_time=np.mean(jax_time.all_runs)
data = {'numpy':np_average_time, 'jax':jax_average_time}

plot_performance(data,title="Average time taken per framework to run dot product")

If you are running on an accelerator, you should see a considerable performance benefit from using JAX, without making any changes to your code! 

### JAX Transformations 

JAX transforms (e.g. `jit`, `grad`, `vmap`, `pmap`) first convert Python functions into an intermediate language called jaxpr. Transforms are then applied to this jaxpr representation.

JAX generates jaxpr in a process known as **tracing**. During tracing, function inputs are wrapped in tracer objects, and JAX records the operations applied to them while the function runs. These recorded operations are used to reconstruct the function. Regular Python code still executes during tracing, but only operations on traced values end up in the jaxpr; Python side effects (such as `print` calls) are not recorded. For more on tracing and jaxpr, you can read [here](https://jax.readthedocs.io/en/latest/jaxpr.html).
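To see tracing in action, `jax.make_jaxpr` returns the jaxpr JAX records for a function. In this sketch (the function `f` is a made-up example), the `print` runs while the function is being traced, but it does not appear in the printed jaxpr; only the `sin` and `mul` operations do:

```python
import jax
import jax.numpy as jnp

def f(x):
  print("Python print runs during tracing")   # a Python side effect
  return jnp.sin(x) * 2.0

# make_jaxpr traces f with an abstract input and returns the recorded jaxpr.
print(jax.make_jaxpr(f)(1.0))
```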



#### Basic JAX Transformations - `JIT` and `GRAD` - <font color='blue'>`Beginner`</font>

In this section, we will explore two basic JAX transforms: 
- JIT (Just-in-time compilation) - compiles and caches JAX Python functions so that they can be run efficiently on XLA - `speed up functions`.
- Grad - Automatically compute gradients - `automatic differentiation`.

##### JIT

JAX dispatches operations to the accelerator one at a time. If we call a function repeatedly, we can use `jit` to compile it with XLA the first time it is called; subsequent calls then reuse the cached, compiled version. 

Let's compile [ReLU (Rectified Linear Unit)](https://arxiv.org/abs/1803.08375), a popular activation function in deep learning. 

ReLU is defined as follows:
<center>$f(x)=\max(0,x)$</center>

It can be visualized as follows:

<center>
<img src="https://machinelearningmastery.com/wp-content/uploads/2018/10/Line-Plot-of-Rectified-Linear-Activation-for-Negative-and-Positive-Inputs.png" width="35%" />
</center>,

where $x$ is the input to the function and $y$ is the output of ReLU.


$$f(x)=\max(0, x)=\begin{cases} x & \text{if } x>0 \\ 0 & \text{if } x \leq 0 \end{cases}$$

**Code Task:** Complete the ReLU implementation below.

In [None]:
#Implement ReLU.
def relu(x):
  if x > 0:
    return
    # TODO Implement me! 
  else:
    return
    # TODO Implement me! 

In [None]:
#@title Run to test your ReLU function.

def plot_relu(relu_function):
  max_int = 5
  # Generate 1000 evenly spaced points from -max_int to max_int 
  x = np.linspace(-max_int,max_int,1000)
  y = np.array([relu_function(xi) for xi in x])
  plt.plot(x, y,label='ReLU')
  plt.legend(loc="upper left")
  plt.xticks(np.arange(min(x), max(x)+1, 1))
  plt.show()

def check_relu_function(relu_function):
  # Generate 100 evenly spaced points from -100 to -1
  x = np.linspace(-100,-1,100)
  y = np.array([relu_function(xi) for xi in x])
  assert (y == 0).all()

  # Check if x == 0
  x = 0
  y = relu_function(x)
  assert y == 0

  # Generate 100 evenly spaced points from 0 to 100
  x = np.linspace(0,100,100)
  y = np.array([relu_function(xi) for xi in x])
  assert np.allclose(x, y)

  print("Your ReLU function is correct!")

check_relu_function(relu)
plot_relu(relu)

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!') 
def relu(x):
  if x > 0:
    return x
  else:
    return 0

check_relu_function(relu)
plot_relu(relu)

Let's try to `jit` this function to speed up its execution, and then call it.

In [None]:
relu_jit = jax.jit(relu)

# Gen 1000000 random numbers and pass them to relu
num_random_numbers=1000000
x = jax.random.normal(key, (num_random_numbers,))

try:# Should raise an error. 
  relu_jit(x) 
except Exception as e:
  print("Exception {}".format(e))

**Why does this fail?**

As mentioned above, JAX transforms first convert python functions into an intermediate language called jaxpr. Jaxpr only captures what is executed on the parameters given to it during tracing, so this means during conditional calls, jaxpr only considers the branch taken. 

When jit-compiling a function, we want to compile and cache a version of the function that can handle many different argument values (so we don't have to recompile for each function evaluation). For example, when we compile a function on an array `jnp.array([1., 2., 3.], jnp.float32)`, we would likely also want to use the compiled function for `jnp.array([4., 5., 6.], jnp.float32)`. 

To achieve this, JAX traces your code based on abstract values. The default abstraction level is a ShapedArray: an array with a fixed shape and dtype. For example, if we trace a function using `ShapedArray((3,), jnp.float32)`, it can be reused for any concrete array of shape `(3,)` and `float32` dtype. 
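A minimal sketch of this caching behaviour, using a made-up function `scale` and a Python counter `trace_count` that exists purely to observe tracing. The counter only increases when JAX traces, so a second float32 array of shape `(3,)` reuses the cached version, while a new shape triggers a retrace:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scale(x):
  global trace_count
  trace_count += 1   # Python side effect: only runs while JAX is tracing
  return 2.0 * x

scale(jnp.array([1.0, 2.0, 3.0], jnp.float32))   # traces for ShapedArray((3,), float32)
scale(jnp.array([4.0, 5.0, 6.0], jnp.float32))   # same shape/dtype: cached, no retrace
scale(jnp.array([1.0, 2.0], jnp.float32))        # new shape (2,): traces again
print(trace_count)
```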

This does come with some challenges. Tracing code that relies on concrete values becomes tricky and sometimes results in a `ConcretizationTypeError`, as in the relu function above. In particular, when tracing a function with conditional statements (`if ...`) on traced values, JAX doesn't know which branch to take, so tracing fails.



To solve this, we have two options:
- Use static arguments to make sure JAX traces on concrete values. This is not ideal if you need to retrace a lot. Example: bottom of this [section](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit).
- Use JAX's built-in control flow primitives such as [`lax.cond`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html), or array operations such as [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html).  
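Both options can be sketched as follows. The function names (`apply_activation`, `scalar_relu`) are hypothetical examples rather than part of the practical; `static_argnums` and `lax.cond` are the JAX APIs referenced above.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

# Option 1: mark the argument used in Python control flow as static.
# JAX then traces with its concrete value (and recompiles for each new value).
@partial(jax.jit, static_argnums=1)
def apply_activation(x, use_relu):
  if use_relu:          # fine: use_relu is a concrete Python bool here
    return jnp.maximum(x, 0.0)
  return x

# Option 2: keep the condition traced and let lax.cond pick a branch at run time.
# lax.cond needs a scalar predicate, and both branches must return the same shape/dtype.
@jax.jit
def scalar_relu(x):
  return lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

print(apply_activation(jnp.array([-1.0, 2.0]), True))
print(scalar_relu(jnp.float32(-3.0)))
```

Note that `lax.cond` works on a scalar condition, so for element-wise ReLU over a whole array, `jnp.where` or `jnp.maximum` remains the simpler choice.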

**Code Task** : Let's convert our ReLU function above to use [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) (you can also use `jnp.maximum`, if you prefer.) 

In [None]:
# Implement ReLU using jnp.where.
def relu(x):
  # TODO Implement ME! 
  return

In [None]:
# Check ReLU function
check_relu_function(relu)

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!') 
def relu(x):
  # jnp.maximum(x, 0) is equivalent to jnp.where(x > 0, x, 0)
  return jnp.maximum(x,0)

check_relu_function(relu)


Now let's see the performance benefit of using jit!

In [None]:
relu_jit = jax.jit(relu)
key = jax.random.PRNGKey(42)

num_random_numbers=1000000
x = jax.random.normal(key, (num_random_numbers,))


jax_time = %timeit -o -n 10 relu(x).block_until_ready()

# Warm up/Compile - first run 
relu_jit(x).block_until_ready()
jax_jit_time = %timeit -o -n 10 relu_jit(x).block_until_ready()

# Let's plot the performance difference
jax_avg_time=np.mean(jax_time.all_runs)
jax_jit_avg_time=np.mean(jax_jit_time.all_runs)
data = {'JAX (no jit)':jax_avg_time, 'JAX (with jit)':jax_jit_avg_time}

plot_performance(data,title="Average time taken for ReLU function")

##### Grad

`grad` is used to automatically compute the gradient of a function in JAX. It can be applied to Python and NumPy functions, which means you can differentiate through loops, branches, recursion and closures.  

`grad` takes in a function `f` and returns a function. If `f` is a mathematical function $f$, then `grad(f)` corresponds to its derivative $f'$, and `grad(f)(x)` evaluates that derivative at $x$, giving $f'(x)$ (for functions of a vector input, the gradient $\nabla f(x)$).
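For a function with vector input, `grad` returns the gradient vector. As a small sketch with a made-up function $g(\boldsymbol{w})=\sum_i w_i^2$, whose gradient is $2\boldsymbol{w}$ (note that `grad` expects the function to return a scalar):

```python
import jax
import jax.numpy as jnp

# A made-up scalar-valued function of a vector: g(w) = sum_i w_i^2.
def g(w):
  return jnp.sum(w ** 2)

w = jnp.array([1.0, -2.0, 3.0])
grad_g = jax.grad(g)      # grad(g) is itself a function ...
print(grad_g(w))          # ... which we evaluate at w; analytically dg/dw = 2w
```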


Let's take a simple function $f(x)=6x^4-9x+4$

In [None]:
f = lambda x: 6*x**4 - 9*x + 4

We can compute the derivative of this function, $f'(x)$, and evaluate it at $x=3$.

In [None]:
dfdx = grad(f)
dfdx_3 = dfdx(3.0)
print(f"f'(3) = {dfdx_3}")

**Math Task**: Can you calculate $f'(2)$ by hand?

In [None]:
answer =  0#@param {type:"integer"}

dfdx_2 = dfdx(2.0)

assert answer == dfdx_2, "Incorrect answer, hint ..."

print("Nice, you got the correct answer!")

We can also chain `grad` to calculate higher-order derivatives. 

We can calculate $f'''(x)$ as follows:

In [None]:
d3dx = grad(grad(grad(f)))

**Math Task**: How about $f'''(2)$ by hand?

In [None]:
answer =  0#@param {type:"integer"}

d3dx_2 = d3dx(2.0)

assert answer == d3dx_2, "Incorrect answer, hint ..."

print("Nice, you got the correct answer!")

Another useful method is `value_and_grad`, where we can get the value ($f(x)$) and gradient ($f'(x)$). 

In [None]:
from jax import value_and_grad
f_x,dy_dx = value_and_grad(f)(2.0)
print(f"f(x): {f_x} f′(x): {dy_dx} ")

**Group Task:** Chat with your neighbour/think about how JAX's automatic differentiation compares to other libraries such as PyTorch or TensorFlow. 

#### More Advanced Transforms - `VMAP` and `PMAP` - <font color='orange'>`Intermediate`</font>, <font color='green'>`Advanced`</font>

JAX also provides transforms that allow you to automatically vectorize (`vmap`) and parallelize (`pmap`) your code. 

##### VMAP - <font color='orange'>`Intermediate`</font>

VMAP (Vectorizing map) automatically vectorizes your python functions. 

Let's define a simple function that calculates the min and max of an input.

In [None]:
def min_max(x):
  return jnp.array([jnp.min(x), jnp.max(x)])

We can apply this function to the vector - `[0, 1, 2, 3, 4]` and get the min and max values.

In [None]:
x = jnp.arange(5)
min_max(x)

What about if we want to apply this to a batch/list of vectors (i.e. calculate the min and max independently across multiple batches)? 

*We will see in part 2 why this is relevant to ML (hint: batched gradient descent)!*

Let's create our batch - 3 vectors of size 5.

In [None]:
batched_x = np.arange(15).reshape((3, 5))
print(batched_x)

In [None]:
min_max(batched_x)

**Question**: What do you think would be the result if we passed `batched_x` into `min_max`?

In [None]:
batch_min_max_output = [0,14] #@param ["[[0,4],[5,9],[10,14]]", "[[0,10],[1,11],[2,12],[3,13],[4,14]]", "[0,14]"] {type:"raw"}

assert (batch_min_max_output == np.array(min_max(batched_x))).all(), "Incorrect answer, hint ..."

print("Nice, you got the correct answer!")

So the above is not what we want. The `min` and `max` are applied across the entire batch, when we want the min and max per vector. 

We can manually batch this using `jnp.stack` and a for loop, as follows:

In [None]:
@jit
def manual_batch_min_max_loop(batched_x):
  return jnp.stack([min_max(x) for x in batched_x])

print(manual_batch_min_max_loop(batched_x))

Alternatively, we can set the `axis` argument of `jnp.min` and `jnp.max` directly. 

In [None]:
@jit
def manual_batch_min_max_axis(batched_x):
  return jnp.array([jnp.min(batched_x,axis=1), jnp.max(batched_x,axis=1)]).T

print(manual_batch_min_max_axis(batched_x))

These approaches both work, but we need to change our function to work with batches. We can't just run the same code across a batch of data.

This is where `vmap` becomes really useful! With `vmap`, we can write a function once, as if it works on a single element, and then use `vmap` to automatically vectorize it! 

In [None]:
# define our vmap function using our original single vector function
@jit
def min_max_vmap(batched_x):
  return vmap(min_max)(batched_x)

# Run it on a single vector
## We add an extra dimension to the single vector: its shape changes from (5,) to (1,5), which makes the vmapping possible
x_with_leading_dim = jax.numpy.expand_dims(x,axis=0)
print(f"Single vector: {min_max_vmap(x_with_leading_dim)}")

# Run it on batch of vectors
print(f"Batch/list of vector:{min_max_vmap(batched_x)}")

So this is really convenient, but what about performance? 

In [None]:
batched_x = np.arange(50000).reshape((500, 100))

# Trace the functions with first call
manual_batch_min_max_loop(batched_x).block_until_ready()
manual_batch_min_max_axis(batched_x).block_until_ready()
min_max_vmap(batched_x).block_until_ready()

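min_max_forloop_time = %timeit -o -n 10 manual_batch_min_max_loop(batched_x).block_until_ready()
min_max_axis_time = %timeit -o -n 10 manual_batch_min_max_axis(batched_x).block_until_ready()
min_max_vmap_time = %timeit -o -n 10 min_max_vmap(batched_x).block_until_ready()

# Let's plot the performance difference
data = {'for loop': np.mean(min_max_forloop_time.all_runs),
        'axis': np.mean(min_max_axis_time.all_runs),
        'vmap': np.mean(min_max_vmap_time.all_runs)}
plot_performance(data, title="Average time taken for batched min/max")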

So `vmap` should be similar in performance to manually vectorized code (if everything is implemented well), and typically much better than naively batched code (i.e. for loops). 
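`vmap` can also leave some arguments unbatched through its `in_axes` argument. In this sketch, `scaled_min_max` is a hypothetical variant of `min_max` that takes an extra scalar; `in_axes=(0, None)` maps over the rows of the first argument and passes the same `scale` to every call:

```python
import jax
import jax.numpy as jnp

def scaled_min_max(x, scale):
  # Written for a single vector x and a single scalar scale.
  return jnp.array([jnp.min(x), jnp.max(x)]) * scale

batched_x = jnp.arange(15).reshape((3, 5))

# in_axes=(0, None): map over the rows of batched_x (axis 0),
# but pass the same `scale` to every call instead of batching it.
batched_fn = jax.vmap(scaled_min_max, in_axes=(0, None))
out = batched_fn(batched_x, 10)
print(out)   # one [min, max] * 10 pair per row
```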

##### PMAP - <font color='green'>`Advanced`</font>

**For this subsection, please ensure that colab is using a `TPU` runtime.**

Another JAX transform is `pmap`. `pmap` transforms a function written for one device into a function that runs in parallel across many devices. 

**Difference between `vmap` and `pmap`**:

Both `pmap` and `vmap` transform a function to work over an array, but they differ in implementation: `vmap` adds an extra batch dimension to all the operations in a function, while `pmap` replicates the function and executes each replica on its own XLA device in parallel.

Let's try and `pmap` a batch of dot products.

Here is an illustration of how we would typically do this sequentially: 

[Source](https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/)

In [None]:
#@title Illustration of Sequential Dot Product (Run me)
from IPython.display import HTML

HTML('<iframe width="560" height="315" src="https://www.assemblyai.com/blog/content/media/2022/02/not_parallel-2.mp4" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>')



Here is a code implementation of this:

In [None]:
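# Let's generate a batch of size 8, each with a matrix of size (5000, 6000)
keys = jax.random.split(jax.random.PRNGKey(0), 8)
mats = jnp.stack([jax.random.normal(k, (5000, 6000)) for k in keys])

# Define and jit the function once, outside the timed code, so that the timing
# below measures execution rather than re-tracing/re-compiling on every call.
@jit
def batched_dot_prod(mats):
  result = []
  # Loop through the batch and compute each matrix product one after the other
  for mat in mats:
    result.append(jnp.dot(mat, mat.T))
  return jnp.stack(result)

# Warm up/Compile - first run
batched_dot_prod(mats).block_until_ready()

def dot_product_sequential():
  batched_dot_prod(mats).block_until_ready()

run_sequential = %timeit -o -n 5 dot_product_sequential()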

Here is an illustration of how we would do this in parallel:

[Source](https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/)

In [None]:
#@title Illustration of Parallel Dot Product (Run me)
from IPython.display import HTML

HTML('<iframe width="560" height="315" src="https://www.assemblyai.com/blog/content/media/2022/02/parallelized.mp4" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>')



Here is a code implementation of batched dot products:

First we will create `8` large random matrices, one for each of the 8 TPU [devices](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) available on a Colab TPU runtime.

In [None]:
keys = jax.random.split(jax.random.PRNGKey(0), 8)
mats = pmap(lambda key: jax.random.normal(key, (5000, 6000)))(keys)

The leading dimension here is 8, one matrix per device (since we are sending one batch element to each device); `pmap` requires this mapped dimension to be no larger than the number of available devices.

In [None]:
print(mats.shape)

Using `pmap` to generate the batches ensures they are stored as a `ShardedDeviceArray` (in newer JAX versions, this type has been replaced by `jax.Array`). This is similar to an ndarray, except that each shard is stored in the memory of a different device, so the shards can be used in subsequent `pmap` operations without moving data between devices (GPU/TPU) and the host (CPU). 

In [None]:
print(type(mats))

In [None]:
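# Define the pmapped function once, so repeated calls reuse the compiled version
parallel_dot_prod = pmap(lambda x: jnp.dot(x, x.T))

# Warm up/Compile - first run
parallel_dot_prod(mats).block_until_ready()

def dot_product_parallel():
  # Run a local matmul on each device in parallel (no data transfer)
  result = parallel_dot_prod(mats).block_until_ready()  # result.shape is (8, 5000, 5000)

run_parallel = %timeit -o -n  5 dot_product_parallel()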

It is as simple as that! Our dot product now runs in parallel across the available devices (CPUs, GPUs or TPUs). As we add more cores/devices, this code will automatically scale. 
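The example above hard-codes a batch of 8 for Colab's TPU. A minimal sketch of a device-count-agnostic version (the shapes are arbitrary) sizes the leading axis with `jax.local_device_count()`, so the same code also runs on a single CPU or GPU:

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# One 2x2 shard per available device, so the same code runs on 1 CPU,
# several GPUs or 8 TPU cores without changes.
batched = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape((n_devices, 2, 2))

# pmap maps over the leading axis: each device computes x @ x.T on its own shard.
per_device = jax.pmap(lambda x: jnp.dot(x, x.T))(batched)
print(per_device.shape)   # (n_devices, 2, 2)
```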

In [None]:
# Let's plot the performance difference
jax_parallel_time=np.mean(run_parallel.all_runs)
jax_seq_time=np.mean(run_sequential.all_runs)


data = {'JAX (seq)':jax_seq_time, 'JAX (parallel)':jax_parallel_time}

plot_performance(data,title="Average time taken for Seq vs Parallel Dot Product")

We showed an example of using `pmap` for *pure* parallelism, where there is no communication between devices. JAX also has various operations for communication across distributed devices (more on this [here](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#communication-between-devices)).
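As a small sketch of such communication (the axis name `"devices"` and the input values are arbitrary choices), `pmap` can name the mapped axis with `axis_name`, and `lax.psum` then sums a value across all devices:

```python
import jax
import jax.numpy as jnp
from jax import lax

n_devices = jax.local_device_count()
values = jnp.arange(1, n_devices + 1, dtype=jnp.float32)   # one number per device

# axis_name names the mapped axis so collectives like lax.psum can refer to it.
# Each device divides its value by the sum over ALL devices (a cross-device reduction).
normalize = jax.pmap(lambda v: v / lax.psum(v, axis_name="devices"), axis_name="devices")
normalized = normalize(values)
print(normalized, normalized.sum())   # the per-device results sum to 1
```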

## Section Quiz 

Optional end of section quiz. Below is an example of an assessment.

In [None]:
#@title Generate Quiz Form. (Run Cell)
from IPython.display import HTML
HTML(
"""
<iframe 
	src="https://forms.gle/zbJoTSz3nfYq1VrY6"
  width="80%" 
	height="1200px" >
	Loading...
</iframe>
"""
)

# Part 2 - From Linear to Non-Linear Regression - `WIP`

Now that we know some basics of JAX, we can build some simple models!

Parts of this section are adapted from [DeepMind's Regression Tutorial](https://github.com/deepmind/educational/blob/master/colabs/summer_schools/intro_to_regression.ipynb). 

## Linear Regression

In regression, we aim to find a function $f$ that maps inputs $x \in \mathbb{R}^D$ to corresponding outputs $f(x) \in \mathbb{R}$ [(source)](https://mml-book.github.io/). 

Put simply, we are trying to learn the relationship between our inputs and outputs.  

Let's build a simple dataset, with 5 elements. Each element has a single input and single output.  

In [None]:
x_data_list = [1, 2, 3, 4, 5]
y_data_list = [3, 2, 3, 1, 0]

We can plot this dataset:

In [None]:
def plot_basic_data(parameters_list=None, title="Observed data"):
  xlim = [-1, 7]
  fig, ax = plt.subplots()
  
  if parameters_list is not None:
    x_pred = np.linspace(xlim[0], xlim[1], 100)
    for parameters in parameters_list:
      y_pred = parameters[0] + parameters[1] * x_pred
      ax.plot(x_pred, y_pred, ':', color=[1, 0.7, 0.6])

    parameters = parameters_list[-1]
    y_pred = parameters[0] + parameters[1] * x_pred
    ax.plot(x_pred, y_pred, "-", color=[1, 0, 0], lw=2)

  ax.plot(x_data_list, y_data_list, "ob")
  ax.set(xlabel="Input x", ylabel="Output y",
         title=title,
         xlim=xlim, ylim=[-2, 5])
  ax.grid()

plot_basic_data()

Let's say we would like to predict these $y$ values (outputs) given the $x$ values (inputs). 

We can start modelling this by using a simple linear function: 
<center> 
$f(x) = \color{red}{w} x + \color{red}{b}$
</center>,

where $x$ is our input and  $\color{red}{b}$ and $\color{red}{w}$ are our model parameters.

Usually, we learn the model parameters, but let's try to find these parameters by hand!
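Before tuning the sliders, here is a minimal sketch of the same model in `jnp`. The names `linear_model` and `mean_abs_error` are hypothetical, and the mean absolute error is just one assumed way to score how well a guess fits the five data points:

```python
import jax.numpy as jnp

x_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_data = jnp.array([3.0, 2.0, 3.0, 1.0, 0.0])

def linear_model(x, w, b):
  # f(x) = w * x + b, applied to every input at once via broadcasting.
  return w * x + b

def mean_abs_error(w, b):
  # Average distance between the line and the observed points (smaller is better).
  return jnp.mean(jnp.abs(linear_model(x_data, w, b) - y_data))

print(mean_abs_error(w=-1.0, b=5.0))
```

For example, $w=-1$, $b=5$ predicts $4, 3, 2, 1, 0$, which is off by $1, 1, 1, 0, 0$, giving a mean absolute error of $0.6$.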



**Code Task:** 
1. Move the two sliders below to set $\color{red}{b}$ and $\color{red}{w}$, and press "Run cell" on the code cell below. 
2. Is your $f(x)$ close to the blue data points? Can you find a better fit?
3. Repeat 1-2 until convergence :D 

In [None]:
parameters_list = [] # Used to track which parameters were tried. 

In [None]:
b = 5.14 #@param {type:"slider", min:-1, max:8, step:0.01}
w = -1.18 #@param {type:"slider", min:-3, max:3, step:0.01}
print("Plotting line", w, "* x +", b)
parameters = [b, w]
parameters_list.append(parameters)
plot_basic_data(parameters_list,
                title="Observed data and my first predictions")

**Weights and Bias**

What happened to the function when you changed $\color{red}{b}$ and $\color{red}{w}$?

- $\color{red}{w}$ is our weight. It represents the slope of our function and determines the influence of the feature $x$.
- $\color{red}{b}$ is our bias (also called the *intercept*). It shifts the line without changing the slope.

**You're a born optimizer!**

Let's plot the optimization trajectory you took.

In [None]:
fig, ax = plt.subplots()
opt = {"head_width": 0.2, "head_length": 0.2,
       "length_includes_head": True, "color": "r"}
if parameters_list:  # skip plotting the trajectory if no parameters were tried yet
  b_old = parameters_list[0][0]
  w_old = parameters_list[0][1]
  for i in range(1, len(parameters_list)):
    b_next = parameters_list[i][0]
    w_next = parameters_list[i][1]
    ax.arrow(b_old, w_old, b_next - b_old, w_next - w_old, **opt)
    b_old, w_old = b_next, w_next

  ax.scatter(b_old, w_old, s=200, marker="o", color="y")
  bs = [parameters[0] for parameters in parameters_list]
  ws = [parameters[1] for parameters in parameters_list]
  ax.scatter(bs, ws, s=40, marker='o', color='k')

ax.set(xlabel="Bias b", ylabel="Weight w",
       title="My sequence of b\'s and w\'s",
       xlim=[-1, 8], ylim=[-3, 3])
plt.show()

**Group Task**:

*How did your neighbour do?*
- Did they change $\color{red}{b}$ and $\color{red}{w}$ with big steps or small steps each time?
- Did they start with small steps and then progress to bigger steps, or the other way round? What about you?
- Did the magnitude of your previous steps influence your next choice? Why? Or why not?
- Did you all converge to roughly the same endpoint for $\color{red}{b}$ and $\color{red}{w}$, or did your sequences end up in different places?

In [None]:
from matplotlib import cm

def l1_loss(b, w):
  loss = 0 * b
  for x, y in zip(x_data_list, y_data_list):
    f = w * x + b
    loss += np.abs(f - y)
  return loss / len(x_data_list)

bs, ws = np.linspace(-1, 8, num=25), np.linspace(-3, 3, num=25)
b_grid, w_grid = np.meshgrid(bs, ws)
loss_grid = l1_loss(b_grid, w_grid)

def plot_loss(parameters_list, title, show_stops=False):
  fig, ax = plt.subplots(1, 2, figsize=(18, 8),
                         subplot_kw={"projection": "3d"})
  ax[0].view_init(10, -30)
  ax[1].view_init(30, -30)

  if parameters_list:  # Handles both None and an empty list.
    b_old = parameters_list[0][0]
    w_old = parameters_list[0][1]
    loss_old = l1_loss(b_old, w_old)
    ls = [loss_old]

    for i in range(1, len(parameters_list)):
      b_next = parameters_list[i][0]
      w_next = parameters_list[i][1]
      loss_next = l1_loss(b_next, w_next)
      ls.append(loss_next)

      ax[0].plot([b_old, b_next], [w_old, w_next], [loss_old, loss_next],
                color="red", alpha=0.8, lw=2)
      ax[1].plot([b_old, b_next], [w_old, w_next], [loss_old, loss_next],
                color="red", alpha=0.8, lw=2)
      b_old, w_old, loss_old = b_next, w_next, loss_next

    if show_stops:
      ax[0].scatter(b_old, w_old, loss_old, s=100, marker="o", color="y")
      ax[1].scatter(b_old, w_old, loss_old, s=100, marker="o", color="y")
      bs = [parameters[0] for parameters in parameters_list]
      ws = [parameters[1] for parameters in parameters_list]
      ax[0].scatter(bs, ws, ls, s=40, marker="o", color="k")
      ax[1].scatter(bs, ws, ls, s=40, marker="o", color="k")
    else:
      ax[0].scatter(b_old, w_old, loss_old, s=40, marker='o', color='k')
      ax[1].scatter(b_old, w_old, loss_old, s=40, marker='o', color='k')

  ax[0].plot_surface(b_grid, w_grid, loss_grid, cmap=cm.coolwarm,
                     linewidth=0, alpha=0.4, antialiased=False)
  ax[1].plot_surface(b_grid, w_grid, loss_grid, cmap=cm.coolwarm,
                     linewidth=0, alpha=0.4, antialiased=False)
  ax[0].set(xlabel="Bias b", ylabel="Weight w", zlabel="Loss", title=title)
  ax[1].set(xlabel="Bias b", ylabel="Weight w", zlabel="Loss", title=title)
  plt.show()

plot_loss(parameters_list,
          "An example loss function and my sequence of b\'s and w\'s",
          show_stops=True)
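
Instead of moving sliders by hand, an optimizer can walk down this loss surface automatically. The sketch below is one way to do that, under our own assumptions: a hypothetical toy dataset generated from $y = 2x + 1$ (not the prac's data), an arbitrary learning rate of `0.05` and 200 steps. It writes the same L1 (mean absolute error) loss with `jax.numpy` and uses `jax.grad` to take gradient-descent steps on $(\color{red}{b}, \color{red}{w})$. The names `l1_loss_jax`, `x_toy`, `y_toy`, `b_gd` and `w_gd` are hypothetical. Each step moves the parameters against the gradient, so the step size is set by the learning rate and the gradient rather than by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data generated from y = 2x + 1 (not the prac's dataset).
x_toy = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0])
y_toy = 2.0 * x_toy + 1.0

def l1_loss_jax(b, w, x, y):
  """Mean absolute error of the linear model w * x + b."""
  return jnp.mean(jnp.abs(w * x + b - y))

# Gradient with respect to b (argument 0) and w (argument 1),
# compiled with XLA via jax.jit.
grad_fn = jax.jit(jax.grad(l1_loss_jax, argnums=(0, 1)))

b_gd, w_gd = 0.0, 0.0   # arbitrary starting point
learning_rate = 0.05    # assumed step size, chosen for illustration
initial_loss = l1_loss_jax(b_gd, w_gd, x_toy, y_toy)

for step in range(200):
  db, dw = grad_fn(b_gd, w_gd, x_toy, y_toy)
  # Step against the gradient to decrease the loss.
  b_gd = b_gd - learning_rate * db
  w_gd = w_gd - learning_rate * dw

final_loss = l1_loss_jax(b_gd, w_gd, x_toy, y_toy)
print(f"b={float(b_gd):.2f}, w={float(w_gd):.2f}, "
      f"loss {float(initial_loss):.2f} -> {float(final_loss):.2f}")
```

With a constant learning rate the parameters end up hovering close to $b = 1, w = 2$ rather than landing on it exactly, which echoes the step-size questions in the group task above.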

# Conclusion
**Summary:**
- JAX combines Autograd and XLA to perform accelerated numerical computations. These computations are achieved using transforms such as `jit`, `grad`, `vmap` and `pmap`.
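
As a compact illustration of how these transforms compose, here is a sketch on hypothetical toy data (the names `loss`, `x_toy`, `y_toy`, `batched_grad` and `candidates` are our own). `grad` differentiates an L1 loss with respect to the parameters `[b, w]`, `vmap` maps that gradient over a batch of candidate parameter pairs, and `jit` compiles the result. `pmap` is left out because it only pays off with several devices.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data generated from y = 2x + 1 (not the prac's dataset).
x_toy = jnp.array([0.0, 1.0, 2.0, 3.0])
y_toy = 2.0 * x_toy + 1.0

def loss(params, x, y):
  """Mean absolute error of w * x + b, with params = [b, w]."""
  b, w = params[0], params[1]
  return jnp.mean(jnp.abs(w * x + b - y))

# grad: derivative w.r.t. params; vmap: map over a batch of params
# (x and y are shared, hence None); jit: compile with XLA.
batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(0, None, None)))

# Three candidate (b, w) pairs, evaluated in a single call.
candidates = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 3.0]])
grads = batched_grad(candidates, x_toy, y_toy)
print(grads.shape)  # (3, 2)
```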



**References:**
1. Various JAX [docs](https://jax.readthedocs.io/en/latest/), specifically [quickstart](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html), [common gotchas](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html), [jitting](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#), [random numbers](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html) and [pmap](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?highlight=pmap#).
2. <http://matpalm.com/blog/ymxb_pod_slice/>
3. <https://roberttlange.github.io/posts/2020/03/blog-post-10/>
4. [Machine Learning with JAX - From Zero to Hero | Tutorial #1](https://www.youtube.com/watch?v=SstuvS-tVc0).


For other practicals from the Deep Learning Indaba, please visit [here](https://github.com/deep-learning-indaba/indaba-pracs-2022).

## Feedback

Please provide feedback that we can use to improve our practicals in the future.

In [None]:
#@title Generate Feedback Form. (Run Cell)
from IPython.display import HTML
HTML(
"""
<iframe
  src="https://forms.gle/bvLLPX74LMGrFefo9"
  width="80%"
  height="1200px">
  Loading...
</iframe>
"""
)

<img src="https://baobab.deeplearningindaba.com/static/media/indaba-logo-dark.d5a6196d.png" width="50%" />