# JAX

**JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.**


In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

Make sure JAX is using the GPU. First, remember to turn on the GPU runtime in Colab: 

**Click on Runtime -> Click on "runtime type" -> Choose GPU**

then:

In [None]:
import jax
jax.config.update('jax_platform_name', 'gpu')

JAX provides a drop-in replacement for many essential NumPy functions. On a GPU these replacements often run faster than their NumPy counterparts, because the work is parallelized across the accelerator. Let's see an example:

## Multiplying Matrices

We'll be generating random data in the following examples.

In [None]:
key = random.PRNGKey(0)
key, subkey = random.split(key)
x = random.normal(subkey, (10,))
print(x)

[-0.38812608 -0.04487164 -2.0427258   0.07932311  0.33349916  0.7959976
 -1.4411978  -1.6929979  -0.37369204 -1.5401139 ]


**Note**: As you saw, we used `PRNGKey()` to generate our random values. One big difference between NumPy and JAX is how you generate random numbers. In JAX:

> We split the PRNG key to get usable subkeys every time we need a new pseudorandom number. We **propagate the key** and **use the new subkey** whenever we need a new random number.





For more details, see 
[Common Gotchas in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers)
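
The key-splitting rule above can be illustrated with a minimal sketch. The seed `42` and the variable names are arbitrary choices for illustration: reusing a key reproduces the same numbers, while a fresh subkey gives new ones.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)  # arbitrary seed, chosen only for this sketch

# Reusing the same key yields exactly the same "random" numbers.
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# Splitting gives a new key to propagate and a subkey to consume.
key, subkey = random.split(key)
c = random.normal(subkey, (3,))

print(jnp.array_equal(a, b))  # same key, same draw
print(jnp.array_equal(a, c))  # different key, different draw
```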

In [None]:
size = 3000
key, subkey = random.split(key)
x = random.normal(subkey, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

The slowest run took 83.64 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 28.5 ms per loop


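Running the same `jnp.dot` on a regular NumPy array takes about 3 times longer. JAX functions accept NumPy arrays, but the data lives in host memory and has to be transferred to the GPU on every call: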

In [None]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

10 loops, best of 5: 92.2 ms per loop


It is important to know where your data lives. When using JAX with a GPU, the data lives in VRAM (the graphics card's memory) and not in RAM. To transfer an array from RAM to VRAM, you use the `device_put(x, device=None)` function, which by default copies your data to the default device (here, VRAM, because we chose to work on the GPU earlier).

In [None]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x) # this puts the data in jax.devices()[0] which for us, is the (first) GPU's VRAM. If you don't use GPU, then this will be in RAM.
%timeit jnp.dot(x, x.T).block_until_ready()

10 loops, best of 5: 24.3 ms per loop


The output of `jax.device_put` still acts like an NDArray (a NumPy array), but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc.
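
As a small sketch of this round trip (the array names are illustrative, and the platform printed depends on your runtime), the following copies a NumPy array to the default device, computes there, and explicitly copies the result back to host memory:

```python
import numpy as np
import jax
import jax.numpy as jnp

host = np.arange(6, dtype=np.float32).reshape(2, 3)  # lives in RAM

on_device = jax.device_put(host)           # copy to the default device
result = jnp.dot(on_device, on_device.T)   # computed on that device
back = np.asarray(result)                  # explicit copy back to host memory

print(jax.devices()[0].platform)  # e.g. "gpu" or "cpu", depending on the runtime
print(type(back), back.shape)
```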

JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:

 - `jax.jit`, for speeding up your code
 - `jax.grad`, for taking derivatives
 - `jax.vmap`, for automatic vectorization or batching.

Let's go over these, one-by-one. We'll also end up composing these in interesting ways.

## Using `jax.jit` to speed up functions

You may have heard that Python is "interpreted", i.e., it is run one line at a time. This makes it very convenient to use, especially for data science and machine learning workloads where rapid iteration is highly desirable. But this comes at a performance cost. That is why most performant numerical libraries (like NumPy) are wrappers around code written in lower-level compiled languages (like C). Just-in-time compilation brings the best of both worlds:

"In computing, just-in-time (JIT) compilation (also dynamic translation or run-time compilations) is a way of executing computer code that involves compilation during execution of a program (at run time) rather than before execution...JIT compilation is a combination of the two traditional approaches to translation to machine code (ahead-of-time compilation (AOT), and interpretation)"&mdash;[JIT on Wikipedia](https://en.wikipedia.org/wiki/Just-in-time_compilation)

JIT compilation is now pretty common in numerical computing. For instance, the Julia language is also JIT compiled by default. Besides providing a performant GPU-optimized NumPy alternative, JAX provides a convenient way to perform JIT compilation to speed up computations.

For example, consider the Rectified Linear Unit function, or ReLU: $$\text{relu}(x) = \max\{x, 0\} = (x)^{+}$$

In [None]:
def relu(x):
    return jnp.where(x > 0, x, 0)

key, subkey = random.split(key)
x = random.normal(subkey, (1000000,))
%timeit relu(x).block_until_ready()

The slowest run took 143.12 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 1.19 ms per loop


We can speed it up with `jit`: the function is compiled the first time `relu_jit` is called, and the compiled version is cached and reused on later calls.

In [None]:
relu_jit = jit(relu)
%timeit relu_jit(x).block_until_ready()

The slowest run took 660.66 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 130 µs per loop


That is, we gained roughly a 9 times speedup (1.19 ms versus 130 µs)!
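
To see the caching behaviour, here is a minimal sketch that relies on the fact that Python side effects inside a jitted function run only while JAX traces it. The `trace_log` list and the `relu_logged` name are hypothetical helpers introduced for this illustration.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []  # hypothetical helper: records each time Python runs the body

def relu_logged(x):
    trace_log.append(x.shape)  # Python side effect, executed only while tracing
    return jnp.where(x > 0, x, 0)

relu_logged_jit = jit(relu_logged)

x_demo = jnp.linspace(-1.0, 1.0, 5)
relu_logged_jit(x_demo)  # first call: traces and compiles
relu_logged_jit(x_demo)  # same shape and dtype: reuses the cached compilation
calls_after_same_shape = len(trace_log)

relu_logged_jit(jnp.ones((7,)))  # new input shape: traces and compiles again
calls_after_new_shape = len(trace_log)

print(calls_after_same_shape, calls_after_new_shape)
```

This is also why the first `%timeit` run is reported as much slower than the rest: it includes the compilation.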

## Taking derivatives with `jax.grad`

Perhaps the most important functionality that JAX provides is automatic differentiation (AD).

Consider a sum of logistic (sigmoid) functions:
$$ f(\mathbf{x}) = \sum_i \sigma(x_i) = \sum_i  \frac{1}{1 + e^{-x_i}}$$

We want to calculate $\nabla_\mathbf{x}f$. Let's derive it analytically first:

$$
\frac{\partial}{\partial x_i} f(\mathbf{x})=\frac{e^{x_i} \cdot\left(1+e^{x_i}\right)-e^{x_i} \cdot e^{x_i}}{\left(1+e^{x_i}\right)^{2}}=\frac{e^{x_i}}{\left(1+e^{x_i}\right)^{2}}= \sigma(x_i)\left(1- \sigma(x_i)\right)
$$



In [None]:
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

So, provided your function is built from JAX operations, JAX can trace it and compute its gradient automatically.

In [None]:
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


Let's verify with our analytically derived expression:

In [None]:
def grad_sum_logistic(x):
    return jnp.exp(x) / jnp.power(1.0 + jnp.exp(x), 2)

grad_sum_logistic(x_small)

DeviceArray([0.25      , 0.19661194, 0.10499358], dtype=float32)
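
As a further check, a short sketch using `jax.value_and_grad` (which returns the function value and its gradient in one call) and `jax.nn.sigmoid` compares the automatic gradient with the closed form $\sigma(x_i)(1-\sigma(x_i))$. The tolerance is an arbitrary choice for float32.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)

# Value and gradient in a single pass.
value, gradient = jax.value_and_grad(sum_logistic)(x_small)

# Closed form derived above: sigma(x) * (1 - sigma(x)).
s = jax.nn.sigmoid(x_small)
closed_form = s * (1.0 - s)

print(value)
print(jnp.allclose(gradient, closed_form, atol=1e-6))
```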

We can go further and take the Hessian as well. But there is a technical issue: the gradient is a vector, and although in notation we write the Hessian as the gradient of the gradient, this is not technically true, since only scalar-valued functions have gradients. In fact, when we take the Hessian, we take the gradient of every element of the gradient.

One solution is to introduce an additional vector $\mathbf{v}$ and take its dot product with the gradient, so that we get a scalar again. Now we can use the `grad` function again:

In [None]:
def hessian_fn(x, v):
    intermediate_fn = lambda x : jnp.vdot(derivative_fn(x), v)
    derivative_intermediate_fn = grad(intermediate_fn)
    return derivative_intermediate_fn(x)

Now if we let $\mathbf{v}$  be the unit vectors of this 3-dimensional space, we can recover the Hessian of our function:

In [None]:
print(hessian_fn(x_small, jnp.array([1, 0, 0])))
print(hessian_fn(x_small, jnp.array([0, 1, 0])))
print(hessian_fn(x_small, jnp.array([0, 0, 1])))

[-0. -0. -0.]
[-0.         -0.09085774 -0.        ]
[-0.         -0.         -0.07996248]
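
What `hessian_fn` computes is a Hessian-vector product. As a hedged sketch of an alternative, the same product can be obtained by pushing $\mathbf{v}$ through the gradient function with forward-mode `jax.jvp`. The helper names below are hypothetical and only serve this comparison.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def hvp_forward_over_reverse(f, x, v):
    # Push the tangent v through the gradient function with forward-mode jvp.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hvp_grad_of_dot(f, x, v):
    # Same trick as hessian_fn: differentiate the scalar <grad f(x), v>.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

x_small = jnp.arange(3.)
v = jnp.array([0., 1., 0.])  # float tangent, matching the dtype of x_small

hvp_a = hvp_forward_over_reverse(sum_logistic, x_small, v)
hvp_b = hvp_grad_of_dot(sum_logistic, x_small, v)
print(hvp_a)
print(hvp_b)
```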


But this is cumbersome. We can do this in one go by recognizing that if we stack the $\mathbf{v}$'s above together, they form the $3 \times 3$ identity matrix. Using the `stack` function and a Python list comprehension:

In [None]:
def hessian_fn_stacked(x):
    # one Hessian-vector product per unit vector, stacked into a matrix
    return jnp.stack([hessian_fn(x, unit_vector) for unit_vector in jnp.identity(3)])

hessian_fn_stacked(x_small)

DeviceArray([[-0.        , -0.        , -0.        ],
             [-0.        , -0.09085774, -0.        ],
             [-0.        , -0.        , -0.07996248]], dtype=float32)

This is an example of a function that could really use **vectorization**. Fortunately, JAX provides a very convenient way to vectorize functions (aka maps):

In [None]:
def hessian_fn_vmap(x):
    vectorized_hessian_fn = vmap(lambda v: hessian_fn(x, v))
    return vectorized_hessian_fn(jnp.identity(3))

hessian_fn_vmap(x_small)

DeviceArray([[-0.        , -0.        , -0.        ],
             [-0.        , -0.09085774, -0.        ],
             [-0.        , -0.        , -0.07996248]], dtype=float32)

We could have gotten our Hessian in a much more compact way:

In [None]:
vmap(lambda v: hessian_fn(x_small, v))(jnp.identity(3))

DeviceArray([[-0.        , -0.        , -0.        ],
             [-0.        , -0.09085774, -0.        ],
             [-0.        , -0.        , -0.07996248]], dtype=float32)

## Taking Jacobians with `jax.jacfwd` and `jax.jacrev`

We can write more elegant (and more performant) code by using JAX's Jacobian functions, which extend taking gradients to vector-valued functions.

If we start with a function $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$, then at a point $x \in \mathbb{R}^n$ we expect to get the shapes

$f(x) \in \mathbb{R}^m$, the value of $f$ at $x$,

$\partial f(x) \in \mathbb{R}^{m \times n}$, the Jacobian matrix at $x$,

$\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$, the Hessian at $x$.

In our example above, $m=1$ and $n=3$.
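
To make these shapes concrete, here is a small sketch with a hypothetical function $f: \mathbb{R}^3 \rightarrow \mathbb{R}^2$ (so $n=3$, $m=2$), using the `jax.jacfwd` and `jax.jacrev` functions named in this section's title. The function itself is made up for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical map from R^3 to R^2: (sum of squares, product of entries).
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.array([1.0, 2.0, 3.0])

value = f(x)                                  # shape (m,)
jacobian = jax.jacfwd(f)(x)                   # shape (m, n)
second_order = jax.jacfwd(jax.jacrev(f))(x)   # shape (m, n, n)

print(value.shape, jacobian.shape, second_order.shape)
print(jacobian)
```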

JAX has two implementations of the Jacobian, `jax.jacfwd` and `jax.jacrev`, which correspond to [forward-mode and reverse-mode](https://en.wikipedia.org/wiki/Automatic_differentiation) automatic differentiation, respectively.

In [None]:
print(jax.jacfwd(jax.grad(sum_logistic))(x_small))

[[-0.         -0.         -0.        ]
 [-0.         -0.09085774 -0.        ]
 [-0.         -0.         -0.07996248]]


Since the gradient is the Jacobian of a scalar-valued function ($m=1$), we could have used `jacrev` or `jacfwd` to the same effect. We can also freely compose these, so all of the following forms are equivalent:

In [None]:
print(jax.jacfwd(jax.jacrev(sum_logistic))(x_small))
print(jax.jacrev(jax.jacfwd(sum_logistic))(x_small))
print(jax.jacfwd(jax.jacfwd(sum_logistic))(x_small))
print(jax.jacrev(jax.jacrev(sum_logistic))(x_small))

[[-0.         -0.         -0.        ]
 [-0.         -0.09085774 -0.        ]
 [-0.         -0.         -0.07996248]]
[[-0.         -0.         -0.        ]
 [-0.         -0.09085774 -0.        ]
 [-0.         -0.         -0.07996248]]
[[ 0.          0.          0.        ]
 [ 0.         -0.09085774  0.        ]
 [ 0.          0.         -0.07996248]]
[[-0.         -0.         -0.        ]
 [-0.         -0.09085774 -0.        ]
 [-0.         -0.         -0.07996248]]
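
JAX also provides a convenience wrapper, `jax.hessian`. As a minimal sketch, its output can be checked against the closed-form second derivative of the sigmoid, $\sigma(x_i)(1-\sigma(x_i))(1-2\sigma(x_i))$, placed on the diagonal (the off-diagonal entries vanish because each term of the sum depends on a single $x_i$).

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)

hess = jax.hessian(sum_logistic)(x_small)

# Closed form: the Hessian is diagonal with entries sigma'' = s (1 - s) (1 - 2 s).
s = jax.nn.sigmoid(x_small)
closed_form = jnp.diag(s * (1.0 - s) * (1.0 - 2.0 * s))

print(hess)
print(jnp.allclose(hess, closed_form, atol=1e-6))
```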


The only difference, then, is performance, which is an important concern, especially in neural networks.

In [None]:
%timeit jax.jacfwd(jax.jacrev(sum_logistic))(x_small)

100 loops, best of 5: 15.1 ms per loop


In [None]:
%timeit jax.jacrev(jax.jacfwd(sum_logistic))(x_small)

100 loops, best of 5: 15.5 ms per loop


You can read more about how these differ in JAX's [autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), but all we need for now is:

> `jacfwd` is more efficient for “tall” Jacobian matrices, while `jacrev` is more efficient for “wide” Jacobian matrices. For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`.

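The Hessian is the Jacobian of the gradient function, which here is a square $3 \times 3$ matrix, so following the note above we use `jacfwd` for the outer derivative: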

In [None]:
%timeit jax.jacfwd(jax.grad(sum_logistic))(x_small)

100 loops, best of 5: 11.6 ms per loop


We can make the code more performant by compiling it just in time (*jitting* it):

In [None]:
hessian_jitted  = jit(jax.jacfwd(jit(jax.grad(sum_logistic))))
hessian_jitted(x_small)

DeviceArray([[-0.        , -0.        , -0.        ],
             [-0.        , -0.09085774, -0.        ],
             [-0.        , -0.        , -0.07996248]], dtype=float32)

In [None]:
%timeit hessian_jitted(x_small)

The slowest run took 5.72 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 111 µs per loop


Let's compare this to our slowest implementation:

In [None]:
%%timeit
hessian_fn(x_small, jnp.array([1, 0, 0]))
hessian_fn(x_small, jnp.array([0, 1, 0]))
hessian_fn(x_small, jnp.array([0, 0, 1]))

10 loops, best of 5: 47.1 ms per loop


That's more than 400 times faster (47.1 ms versus 111 µs)!

It's important to note that in machine learning, writing performant code is not a nicety; it's what enables training models on millions of samples in days rather than years. 

## Auto-vectorization with `jax.vmap`

Vectorization means applying the same function (a *map*) to a whole vector of values at once, instead of one value at a time (sequentially).

This is an important concept in parallel data processing, and it allows us to gain significant speedups on massively parallel hardware such as GPUs or on distributed clusters.

Python provides a simple built-in `map` that applies a function to every item of an iterable (such as a list or a generator):

In [None]:
map(lambda x: x*2, range(1, 10))

<map at 0x7f2835a828d0>
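
Note that `map` is lazy: it returns an iterator and computes nothing until it is consumed, which is why the cell above shows a `map` object rather than numbers. A minimal sketch that materializes the result (one element at a time, sequentially):

```python
# Consuming the iterator applies the function to each item in turn.
doubled = list(map(lambda x: x * 2, range(1, 10)))
print(doubled)  # [2, 4, 6, 8, 10, 12, 14, 16, 18]
```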

For example, imagine we want to multiply a matrix of weights `w` with `samples`. We can do this reasonably fast with the NumPy function or its JAX equivalent:

In [None]:
key, *subkeys = random.split(key, 3) # split the key into 3 keys, we propagate one and use the other two
w = random.normal(subkeys[0], (150, 100))
samples = random.normal(subkeys[1], (10, 100))

%timeit jnp.dot(w, samples.T)

The slowest run took 200.61 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 414 µs per loop


In [None]:
def apply_matrix(v):
    return jnp.dot(w, v)

In [None]:
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(samples).block_until_ready()

Naively batched
The slowest run took 33.59 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 5.7 ms per loop


In [None]:
@jit
def batched_apply_matrix(v_batched):
    return jnp.dot(v_batched, w.T)

print('Manually batched')
%timeit batched_apply_matrix(samples).block_until_ready()

Manually batched
The slowest run took 698.96 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 116 µs per loop


In [None]:
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(samples).block_until_ready()

Auto-vectorized with vmap
The slowest run took 642.32 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 135 µs per loop
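
`vmap` also accepts an `in_axes` argument that says which axis of each argument to batch over. As a sketch (the `_demo` names and the seed are illustrative, mirroring the shapes used above), `in_axes=(None, 0)` keeps the weight matrix fixed and maps over the rows of the samples, without closing over a global `w`:

```python
import jax.numpy as jnp
from jax import random, vmap

key_w, key_s = random.split(random.PRNGKey(0))
w_demo = random.normal(key_w, (150, 100))
samples_demo = random.normal(key_s, (10, 100))

# in_axes=(None, 0): do not batch over the matrix, batch over axis 0 of the samples.
batched_dot = vmap(jnp.dot, in_axes=(None, 0))
out = batched_dot(w_demo, samples_demo)

print(out.shape)  # one 150-dimensional output per sample
```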


# K-Means in JAX

Consider the following randomly generated dataset:

In [None]:
key, *subkeys = random.split(key, 4)
points = jnp.concatenate([
    jax.random.normal(subkeys[0], (400, 2)) + jnp.array([4, 0]),
    jax.random.normal(subkeys[1], (200, 2)) + jnp.array([.5, 1]),
    jax.random.normal(subkeys[2], (200, 2)) + jnp.array([-.5, -1]),
])
points

DeviceArray([[ 4.8073673 ,  0.69552976],
             [ 2.8321383 ,  0.50319654],
             [ 6.1076746 ,  0.8052495 ],
             ...,
             [ 1.2582104 , -1.5633142 ],
             [ 1.091379  ,  0.04237545],
             [ 0.5922272 ,  0.6641729 ]], dtype=float32)

### Part 1: Initialize your centroids and distortions


In [None]:
num_clusters = 4
key, subkey = random.split(key)
# TODO

### Part 2: Find new assignment and calculate new distortions

**Hint:** Use `jnp.argmin`, `jax.vmap` and `jnp.linalg.norm`

In [None]:
def update_assignment(samples, centroids):
    # TODO
    return assignment, distortions

### Part 3: Find new centroids

**Hint 1:** Use `jax.vmap`. 

**Hint 2:** What does this line do? 

```
points_assignments[:, jnp.newaxis] == cluster_id[jnp.newaxis, jnp.newaxis]
```
**Hint 3:** You can use `jnp.where`


In [None]:
# TODO

### Part 4: Iterate!

First wrap Part 2 and Part 3 in a function: 

In [None]:
def improve_centroids(values, k):
    centroids, distortions, _ = values
    # TODO
    return new_centroids, new_distortions.mean(), jnp.mean(distortions)

# let's test it
improve_centroids((initial_centroids, initial_distortion, jnp.inf), 4)

Use `jax.lax.while_loop`. It has [several benefits](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html) over the normal while loop.

We use `partial` to supply the last argument of the above function (`k`), so that `while_loop` can iterate on the `values` tuple in a consistent manner.
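
To see the pattern in isolation, here is a toy sketch that is deliberately unrelated to K-Means. The `halve` function and its `floor` argument are hypothetical: the carried tuple holds the current and previous value, `partial` fixes the extra argument, and the loop stops once the improvement drops below a threshold.

```python
from functools import partial

import jax
import jax.numpy as jnp

def halve(values, floor):
    # Carry is (current, previous); the output must keep the same structure and dtypes.
    current, _ = values
    new = jnp.maximum(current / 2.0, floor)
    return new, current

start = (jnp.array(16.0, dtype=jnp.float32), jnp.array(jnp.inf, dtype=jnp.float32))

final, previous = jax.lax.while_loop(
    lambda values: (values[1] - values[0]) > 1e-5,  # keep going while still improving
    partial(halve, floor=1.0),                      # body takes only the carried tuple
    start,
)
print(final, previous)
```

The values go 16, 8, 4, 2, 1 and then stay at the floor, at which point the difference is zero and the loop ends.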

In [None]:
from functools import partial

In [None]:
thresh=1e-5
centroids, distortion, _ = jax.lax.while_loop(
        lambda values: (values[2] - values[1]) > thresh,
        partial(improve_centroids, k = num_clusters),
        (initial_centroids, initial_distortion, jnp.inf),
)
centroids

### Part 5: Visualize the clusters

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Perform the final assignment
final_assignments, _ = update_assignment(points, centroids)

# TODO