# Introduction to ML Using JAX

<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="60%" />


<a href="https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/prac1-intro-ml-using-jax/prac1/prac_0_intro_to_jax_using_ml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

© Deep Learning Indaba 2022. Apache License 2.0.

**Authors:**

**Introduction:** 

JAX is a Python library for writing composable numerical function transformations [[1]](https://jax.readthedocs.io/en/latest/index.html). It leverages [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla) (Accelerated Linear Algebra) to provide useful functionality such as automatic differentiation (`grad`), parallelization (`pmap`), vectorization (`vmap`), just-in-time compilation (`jit`) and more. 

JAX is **not** a replacement for PyTorch or TensorFlow, but a lower-level library that is commonly used with higher-level neural network libraries such as [Haiku](https://github.com/deepmind/dm-haiku) or [Flax](https://github.com/google/flax).  

**Topics:** 

Content: <font color='orange'>`Supervised Learning`</font>  
Level: <font color='grey'>`Beginner`</font>


**Aims/Learning Objectives:**

- Learn the basics of JAX and what differentiates it from NumPy.
- Learn how to use JAX transforms - jit, grad, vmap and pmap.  
- Learn how to build simple classifiers using JAX. 

**Prerequisites:**

- Basic knowledge of numpy.

**Outline:** 

[Points that link to each section.]

**Before you start:**

For this practical, you will need to use a GPU to speed up training. To do this, go to the "Runtime" menu in Colab, select "Change runtime type" and then in the popup menu, choose "GPU" in the "Hardware accelerator" box.


## Installation and Imports

In [None]:
%%capture
#@title Install and import required packages. (Run Cell)
# Install and import anything required. %%capture hides the output from the cell
# and has to be the first line of the cell.

# Section 1
import random

import numpy as np
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

In [None]:
#@title Helper Functions. (Run Cell)
import copy
from typing import Dict 
def plot_performance(data:Dict, title: str):
  runs = list(data.keys())
  time = list(data.values())
  
  # creating the bar plot
  plt.bar(runs, time, width = 0.35)
  
  plt.xlabel("Framework")
  plt.ylabel("Average time taken (in s)")
  # plt.title("Average time taken per framework to run dot product")
  plt.title(title)
  plt.show()

  best_perf_key = min(data, key=data.get)
  all_runs_key = copy.copy(runs)

  # all_runs_key_except_best
  all_runs_key.remove(best_perf_key)

  for k in all_runs_key:
    print(f"{best_perf_key} was {round((data[k]/data[best_perf_key]),2)} times faster than {k} !!!")

# Part 1 - Basics of JAX

## 1.1 From Numpy ➡ Jax - <font color='blue'>`ALL`</font>

 

### JAX and NumPy have a similar interface 🤝



Let's plot the sine function using NumPy.

In [None]:
# 100 linearly spaced numbers
x = np.linspace(-np.pi,np.pi,100)

# the function, which is y = sin(x) here
y = np.sin(x)

# plot the functions
plt.plot(x,y, 'b', label='y=sin(x)')

plt.legend(loc='upper left')

# show the plot
plt.show()

Now using jax.numpy - `jnp` (`import jax.numpy as jnp`)

In [None]:
# 100 linearly spaced numbers
x = jnp.linspace(-np.pi,np.pi,100)

# the function, which is y = sin(x) here
y = jnp.sin(x)

# plot the functions
plt.plot(x,y, 'b', label='y=sin(x)')

plt.legend(loc='upper left')

# show the plot
plt.show()

**Code Task:** Can you plot cosine using `jnp`?

In [None]:
#@title Plot Cosine using jnp. (UPDATE ME)

# 100 linearly spaced numbers
x = jnp.linspace(-np.pi,np.pi,100)

# UPDATE ME
y = x  

if (y == x).all():
  raise Exception("Update ME!")

# plot the functions
plt.plot(x,y, 'b', label='y=cos(x)')

plt.legend(loc='upper left')

# show the plot
plt.show()

### JAX arrays are immutable.

JAX and NumPy arrays are often interchangeable, **but** JAX arrays are **immutable** (they can't be modified after they are created).

In [None]:
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)

In [None]:
# JAX: immutable arrays
x = jnp.arange(10)
try:  # Should raise a TypeError.
  x[0] = 10
except TypeError as e:
  print(f"Exception {e}")

Instead, we need to use the [`.at` helper functions](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html), which return an updated copy of the JAX array. 

`x[idx] = y` -> `x = x.at[idx].set(y)`

In [None]:
x = jnp.arange(10)
x = x.at[0].set(10)
print(x)
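
The `.at` property supports more than `set`. As a small sketch (the array values here are made up for illustration), the cell below shows that the original array is left untouched and that a slice can be updated with `add`:

```python
import jax.numpy as jnp

x = jnp.arange(5)

# .at[...].set returns a new array; x itself is not modified.
y = x.at[0].set(10)

# .at[...].add adds a value to the selected slice of a copy.
z = x.at[1:3].add(5)

print("x =", x)
print("y =", y)
print("z =", z)
```

Because each update returns a fresh array, the result has to be assigned to a name (as with `y` and `z` above) or it is lost.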

### Randomness in JAX.

JAX is more explicit in Pseudo Random Number Generation (PRNG) and doesn't rely on a global state. 

#### In Numpy, PRNG is based on a global `state`.

In [None]:
# Set global state 
np.random.seed(42)

Let's take a few samples from a Gaussian Distribution:

In [None]:
print(f"sample 1 = {np.random.normal()}")
print(f"sample 2 = {np.random.normal()}")
print(f"sample 3 = {np.random.normal()}")

NumPy's global random state is updated every time a random number is generated, so sample 1 != sample 2 != sample 3. 

Having the state updated automatically makes it difficult to handle randomness in a **reproducible** way across threads, processes and devices. 

#### In JAX, PRNG is more explicit.

For each random number generation, you need to explicitly pass in a random key/state.

Passing the same state/key results in the same number being generated. This is generally undesirable.

In [None]:
from jax import random
key = random.PRNGKey(42)
print(f"sample 1 = {random.normal(key)}")
print(f"sample 2 = {random.normal(key)}")
print(f"sample 3 = {random.normal(key)}")

To generate different and independent samples, you need to manually **split** the keys. 

In [None]:
from jax import random
key = random.PRNGKey(42)
print(f"sample 1 = {random.normal(key)}")

# We split the key -> key and subkey
key, subkey = random.split(key)

# We use the subkey immediately and keep the key for future splits. It doesn't really matter which key we keep and which one we use immediately. 
print(f"sample 2 = {random.normal(subkey)}")

# We split the new key -> key and subkey
key, subkey = random.split(key)
print(f"sample 3 = {random.normal(subkey)}")

If we keep track of the random key/state, we can reproduce JAX random number generation in parallel across threads, processes or even devices. For more details on PRNG in JAX, you can read more [here](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html). 
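
When several samples are needed at once, the key can be split into more than two parts in a single call. The sketch below assumes the `num` argument of `jax.random.split`; the seed and the number of subkeys are arbitrary choices for illustration.

```python
import jax

key = jax.random.PRNGKey(0)

# Split once into 4 keys: keep the first for later, use the other 3 now.
key, *subkeys = jax.random.split(key, num=4)
samples = [jax.random.normal(k) for k in subkeys]

# Repeating the same splits from the same seed reproduces the same samples.
key_again = jax.random.PRNGKey(0)
key_again, *subkeys_again = jax.random.split(key_again, num=4)
samples_again = [jax.random.normal(k) for k in subkeys_again]

for i, s in enumerate(samples):
    print(f"sample {i + 1} = {s}")
```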

## 1.2 Acceleration in JAX 🚀 

### JAX is backend Agnostic - <font color='blue'>`ALL`</font>

Using JAX, you can run the same code on different backends/AI accelerators (e.g. CPU/GPU/TPU), **with no changes in code** (no more `.to(device)`, as in frameworks like PyTorch). This means we can easily run linear algebra operations directly on a GPU/TPU.
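
To check which backend the current runtime is using, a quick sketch with `jax.default_backend()` and `jax.devices()` is shown below. The output depends on the runtime selected in Colab, so no particular result is assumed here.

```python
import jax
import jax.numpy as jnp

# Name of the default backend for this runtime (for example "cpu", "gpu" or "tpu").
backend = jax.default_backend()
# List of devices JAX can see on that backend.
devices = jax.devices()

print(f"Default backend: {backend}")
print(f"Devices: {devices}")

# The same line of code runs unchanged on whichever backend is active.
total = jnp.ones((3, 3)).sum()
print(total)
```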

#### Speed Improvement - Multiplying Matrices

Dot products are a common operation in numerical computing and a central part of modern deep learning. They are defined over [vectors](https://en.wikipedia.org/wiki/Coordinate_vector), which can loosely be thought of as a list of multiple scalars (single values). For 2-D arrays, `np.dot` and `jnp.dot` compute a matrix product, in which every entry is the dot product of a row with a column. 

Formally, given two vectors $\boldsymbol{x}, \boldsymbol{y} \in \mathbb{R}^n$, their dot product is defined as:

<center>$\boldsymbol{x}^{\top} \boldsymbol{y}=\sum_{i=1}^{n} x_{i} y_{i}$</center>
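
As a small worked example of the formula above (the two vectors are made up for illustration), the sum of element-wise products matches `jnp.dot`:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# Sum of element-wise products: 1*4 + 2*5 + 3*6 = 32
manual = jnp.sum(x * y)
builtin = jnp.dot(x, y)

print(manual, builtin)
```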

Dot Product in Numpy (will run on cpu)

In [None]:
size=1000
x = np.random.normal(size=(size, size))
y = np.random.normal(size=(size, size))
numpy_time = %timeit -o -n 10 a_np = np.dot(y,x.T)

Dot Product using JAX (will run on current runtime - e.g. GPU).

In [None]:
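size=1000
key = jax.random.PRNGKey(42)
# Split the key so that x and y are independent samples (reusing the key would make them identical).
x_key, y_key = jax.random.split(key)
x = jax.random.normal(x_key,shape=(size, size))
y = jax.random.normal(y_key,shape=(size, size))
jax_time = %timeit -o -n 10 jnp.dot(y, x.T).block_until_ready()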

How much faster was the dot product in JAX (Using GPU)?

In [None]:
# .average is the mean time per loop; .all_runs holds the total time of each run of 10 loops.
np_average_time=numpy_time.average
jax_average_time=jax_time.average
data = {'numpy':np_average_time, 'jax':jax_average_time}

plot_performance(data,title="Average time taken per framework to run dot product")

### Basic JAX Transformations - `JIT` and `GRAD` - <font color='blue'>`ALL`</font>

JAX transforms first convert Python functions into a simple intermediate language called jaxpr. Transforms (e.g. `jit`, `grad`, `vmap`) are then applied to this jaxpr representation.

JAX generates jaxpr in a process known as *tracing*. During tracing, function inputs are wrapped in tracer objects, and JAX records every JAX operation performed on them while the function runs as regular Python code. These recorded operations are used to reconstruct the function. Python side effects (such as `print` calls) are not recorded during tracing, so they do not appear in the jaxpr.
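
To see what tracing produces, `jax.make_jaxpr` can be used to print the jaxpr of a function. This is a minimal sketch; `scaled_sin` is a hypothetical function chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def scaled_sin(x):
    # Python side effect: runs while tracing, but is not recorded in the jaxpr.
    print("tracing scaled_sin")
    return jnp.sin(x) * 2.0

# make_jaxpr traces the function and returns its jaxpr representation.
jaxpr_text = str(jax.make_jaxpr(scaled_sin)(1.0))
print(jaxpr_text)
```

The printed jaxpr lists the `sin` and `mul` operations, while the `print` call leaves no trace in it.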

In this section, we will explore two basic JAX transforms: 
- JIT (Just-in-time compilation) - compile JAX Python functions with XLA so that they run efficiently - `speed up functions`.
- Grad - Automatically compute gradients - `automatic differentiation`.

#### JIT

By default, JAX dispatches operations to the accelerator one at a time. When a function is called repeatedly, we can use `jit` to compile it the first time it is called; the compiled version is cached and reused on subsequent calls. 

Let's compile [ReLU (Rectified Linear Unit)](https://arxiv.org/abs/1803.08375), a popular activation in deep learning. 

ReLU is defined as follows:
<center>$f(x)=\max(0,x)$</center>

It can be visualized as follows:

<center>
<img src="https://machinelearningmastery.com/wp-content/uploads/2018/10/Line-Plot-of-Rectified-Linear-Activation-for-Negative-and-Positive-Inputs.png" width="35%" />
</center>

Here, $x$ is the input to the function and $y$ is the output of ReLU.


$$f(x)=\max (0, x)=\left\{\begin{array}{ll}x & \text { if } x>0 \\ 0 & \text { if } x \leq 0\end{array}\right.$$

**Code Task:** Complete the ReLU implementation below.

In [None]:
#@title Implement ReLU.
def relu(x):
  if x > 0:
    return
    # TODO Implement me! 
  else:
    return
    # TODO Implement me! 

In [None]:
#@title Run to test your ReLU function.

def plot_relu(relu_function):
  max_int = 5
  # Generate 1000 evenly spaced points from -max_int to max_int 
  x = np.linspace(-max_int,max_int,1000)
  y = np.array([relu_function(xi) for xi in x])
  plt.plot(x, y,label='ReLU')
  plt.legend(loc="upper left")
  plt.xticks(np.arange(min(x), max(x)+1, 1))
  plt.show()

def check_relu_function(relu_function):
  # Generate 100 evenly spaced points from -100 to -1
  x = np.linspace(-100,-1,100)
  y = np.array([relu_function(xi) for xi in x])
  assert (y == 0).all()

  # Check if x == 0
  x = 0
  y = relu_function(x)
  assert y == 0

  # Generate 100 evenly spaced points from 0 to 100
  x = np.linspace(0,100,100)
  y = np.array([relu_function(xi) for xi in x])
  assert np.allclose(x, y)

  print("Your ReLU function is correct!")

check_relu_function(relu)
plot_relu(relu)

Let's try to `jit` this function to speed up execution, and then try to call it.

In [None]:
relu_jit = jax.jit(relu)

num_random_numbers=1000000
x = jax.random.normal(key, (num_random_numbers,))

try:  # Should raise an error.
  relu_jit(x) 
except Exception as e:
  print("Exception {}".format(e))

**Why does this fail?**

As mentioned above, JAX transforms first convert Python functions into an intermediate language called jaxpr. A jaxpr only captures the operations that are executed on the inputs during tracing, so for a Python conditional it would contain only the branch that was taken. 

When jit-compiling a function, we want to compile and cache a version of the function that can handle many different argument values (so we don't have to recompile for each function evaluation). For example, when we compile a function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we would likely also want to use the compiled function for `jnp.array([4., 5., 6.], jnp.float32)`. 

To achieve this, JAX traces your code using abstract values. The default abstraction level is a `ShapedArray`, which represents an array with a fixed shape and dtype. For example, if we trace a function using `ShapedArray((3,), jnp.float32)`, the result can be reused for any concrete array of shape `(3,)` and dtype `float32`. 
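
The caching behaviour described above can be observed with a Python side effect, which only runs while the function is being traced. In this sketch, `add_one` and `trace_log` are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # shapes seen each time the function is traced

@jax.jit
def add_one(x):
    # Python side effect: executed only during tracing, not on cached calls.
    trace_log.append(x.shape)
    return x + 1

add_one(jnp.array([1., 2., 3.], jnp.float32))
add_one(jnp.array([4., 5., 6.], jnp.float32))  # same shape and dtype: reuses the cache
add_one(jnp.ones(5, jnp.float32))              # new shape: traced and compiled again

print(trace_log)
```

The log ends up with one entry per distinct shape rather than one entry per call.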

This does come with some challenges. Tracing code that relies on concrete values becomes tricky and sometimes results in a `ConcretizationTypeError`, as in the ReLU function above. Furthermore, when tracing a function with conditional statements (`if ...`), JAX doesn't know which branch to take, because the condition depends on a value that is unknown at trace time, so tracing fails.

To solve this, we have two options:
- Use static arguments to make sure JAX traces on concrete values. This is not ideal if the values change often, since every new value triggers a retrace. Example - bottom of this [section](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit).
- Use built-in JAX control flow primitives such as [`lax.cond`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html) or [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html).  
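
The two options above can be sketched as follows. The functions `scale` and `abs_via_cond` are hypothetical examples (not part of the practical), chosen so that the ReLU task below is left for you to solve.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Option 1: mark the argument used in the Python `if` as static.
# JAX then traces with its concrete value and retraces for every new value.
@partial(jax.jit, static_argnums=1)
def scale(x, double):
    if double:  # `double` is a plain Python bool while tracing
        return 2 * x
    return x

# Option 2: keep the condition traced and use a control flow primitive.
@jax.jit
def abs_via_cond(x):
    # lax.cond expects a scalar predicate and two branch functions.
    return jax.lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

scaled = scale(jnp.arange(3), True)
absolute = abs_via_cond(jnp.float32(-2.5))
print(scaled, absolute)
```

Note that `lax.cond` works on a single scalar condition, whereas `jnp.where` selects element-wise over whole arrays.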


**Code Task** : Let's convert our ReLU function above to use [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) (you can also use `jnp.maximum`, if you prefer.) 

In [None]:
#@title Implement ReLU using jnp.where.
def relu(x):
  # TODO Implement ME! 
  return

In [None]:
# Check ReLU function
check_relu_function(relu)

Now let's see the performance benefit of using jit!

In [None]:
relu_jit = jax.jit(relu)
key = jax.random.PRNGKey(42)

num_random_numbers=1000000
x = jax.random.normal(key, (num_random_numbers,))


jax_time = %timeit -o -n 10 relu(x).block_until_ready()

# Warm up/Compile - first run 
relu_jit(x).block_until_ready()
jax_jit_time = %timeit -o -n 10 relu_jit(x).block_until_ready()

# Let's plot the performance difference
# .average is the mean time per loop; .all_runs holds the total time of each run of 10 loops.
jax_avg_time=jax_time.average
jax_jit_avg_time=jax_jit_time.average
data = {'JAX (no jit)':jax_avg_time, 'JAX (with jit)':jax_jit_avg_time}

plot_performance(data,title="Average time taken for ReLU function")

#### Grad
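
As a minimal sketch of the `grad` transform introduced above: `jax.grad` takes a function that returns a scalar and returns a new function that computes its derivative with respect to the first argument. The function `f` below is a hypothetical example chosen for illustration.

```python
import jax

def f(x):
    # f(x) = x^2 + 3x, so f'(x) = 2x + 3
    return x ** 2 + 3.0 * x

df = jax.grad(f)

derivative_at_2 = df(2.0)  # analytically: 2 * 2 + 3 = 7
print(derivative_at_2)
```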

### 2.2 Advanced JAX Transforms - `VMAP` and `PMAP` - <font color='purple'>`Optional`</font> 

# Part 2 - From Linear to Non-Linear Regression

[Background/content for the section.]

## Conclusion
**Summary:**

[Summary of the main points/takeaways from the prac.]

**Next Steps:** 

[Next steps for people who have completed the prac, like optional reading (e.g. blogs, papers, courses, youtube videos). This could also link to other pracs.]

**Appendix:** 

[Anything (probably math heavy stuff) we don't have space for in the main practical sections.]

**References:** 

- https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
- https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
- https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html

For other practicals from the Deep Learning Indaba, please visit [here](https://github.com/deep-learning-indaba/indaba-pracs-2022).

## Feedback

Please provide feedback that we can use to improve our practicals in the future.

In [None]:
#@title Generate Feedback Form. (Run Cell)
from IPython.display import HTML
HTML(
"""
<iframe 
	src="https://forms.gle/bvLLPX74LMGrFefo9"
  width="80%" 
	height="1200px" >
	Loading...
</iframe>
"""
)

<img src="https://baobab.deeplearningindaba.com/static/media/indaba-logo-dark.d5a6196d.png" width="50%" />