
# JAX <a class="anchor" id="jax"></a>

[JAX](https://github.com/google/jax) is a version of NumPy that runs fast on CPU, GPU and TPU.
In addition to having a fast backend, JAX supports
several useful Python-level program transformations:

* [vmap](#vmap), a vectorized map operator for automatic vectorization or batching.
* [autograd](#AD), for automatic differentiation.
* [jit](#JIT), a just-in-time compiler for speeding up your code (even on a CPU!).

We illustrate these below.
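As a small preview of how these transformations compose, the sketch below computes per-example gradients of a toy scalar function by stacking `grad`, `vmap` and `jit`. The function `f`, the weights `w` and the data `X` are made up for illustration.

```python
import jax
import jax.numpy as jnp

def f(w, x):
    # Toy scalar function of parameters w for a single input x (illustrative only).
    return jnp.dot(w, x) ** 2

# grad differentiates w.r.t. w, vmap maps over the rows of X (w is shared),
# and jit compiles the whole pipeline.
per_example_grads = jax.jit(jax.vmap(jax.grad(f), in_axes=(None, 0)))

w = jnp.array([1.0, -2.0])
X = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])
G = per_example_grads(w, X)

# Analytic per-example gradient: 2 (w . x) x for each row x of X.
G_expected = 2 * (X @ w)[:, None] * X
print(G)
```

Each transformation returns an ordinary Python function, which is why they can be nested in any order that makes sense.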

More details can be found at:
* the [official JAX quickstart page](https://github.com/google/jax#quickstart-colab-in-the-cloud).
* the [official JAX quickstart colab](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/quickstart.ipynb#scrollTo=SY8mDvEvCGqk).




In [1]:
# Standard Python libraries

import os
import time
import numpy as np
np.set_printoptions(precision=3)
import glob
import matplotlib.pyplot as plt
import PIL
import imageio

from IPython import display
%matplotlib inline

import sklearn


In [2]:

# Load JAX
import jax
import jax.numpy as np
import numpy as onp # original numpy
from jax import grad, hessian, jit, vmap
print("jax version {}".format(jax.__version__))


jax version 0.2.7


In [3]:
# Check if GPU is available
!nvidia-smi


Thu Dec 31 05:41:36 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.27.04    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8     9W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+

In [4]:

# Check if JAX is using GPU
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))

jax backend gpu


# Vmap <a class="anchor" id="vmap"></a>


To illustrate vmap, consider a binary logistic regression model.

In [None]:
def sigmoid(x): return 0.5 * (np.tanh(x / 2.) + 1)

def predict_single(w, x):
    return sigmoid(np.dot(w, x)) # (D,) . (D,) -> scalar  # inner product

def predict_batch(w, X):
    return sigmoid(np.dot(X, w)) # (N,D) @ (D,) -> (N,)  # matrix-vector multiply


D = 2
N = 3
onp.random.seed(42)
w = onp.random.randn(D)
X = onp.random.randn(N, D)
y = onp.random.randint(0, 2, N)

# We can apply predict_batch to a matrix of data, but not predict_single:
# np.dot(w, X) tries to contract w (shape (D,)) with the first axis of X (shape (N, D)),
# and these sizes do not match.

p1 = predict_batch(w, X)
try:
    p2 = predict_single(w, X)
except Exception:
    print('cannot apply to batch')

cannot apply to batch


To avoid having to think about batch shape, it is often easier to write a function that works on single
input vectors. We can then apply this in a loop.

In [None]:
p3 = [predict_single(w, x) for x in X]
assert np.allclose(p1, p3)

Unfortunately, looping over the rows in Python is slow.
Fortunately, JAX provides `vmap`, which has the same effect but vectorizes the computation, so it can run in parallel.

We first use `partial` to fix the first argument, `w`, of `predict_single`, which gives a function that depends only on `x`. We then vectorize this function with `vmap`, which by default maps it along the rows (dimension 0) of the data matrix. Equivalently, `in_axes=(None, 0)` tells `vmap` not to map over `w` and to map over dimension 0 of `X`.

In [None]:
from functools import partial

predict_single_w = partial(predict_single, w)
predict_batch_w = vmap(predict_single_w)
p4 = predict_batch_w(X)
p5 = vmap(predict_single, in_axes=(None, 0))(w, X)

assert np.allclose(p1, p4)
assert np.allclose(p1, p5)
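`vmap` calls can also be nested. As a sketch (the `K` weight vectors stacked in `W` are made up for illustration, e.g. members of an ensemble), the outer `vmap` below maps over weight vectors and the inner one over data rows, producing a `(K, N)` matrix of predictions. For this linear model the result should match `sigmoid(W @ X.T)`.

```python
import numpy as onp
import jax.numpy as jnp
from jax import vmap

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.0) + 1)

def predict_single(w, x):
    return sigmoid(jnp.dot(w, x))

K, N, D = 4, 3, 2
rng = onp.random.RandomState(0)
W = rng.randn(K, D)  # hypothetical stack of K weight vectors
X = rng.randn(N, D)

# Inner vmap: map over rows of X with w fixed -> shape (N,).
# Outer vmap: map over rows of W with X fixed -> shape (K, N).
predict_all = vmap(vmap(predict_single, in_axes=(None, 0)), in_axes=(0, None))
P = predict_all(W, X)
print(P.shape)
```

The single-example function never has to know about either batch dimension.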


# Looping constructs

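Since JAX is functional, loops need some care. Ordinary Python loops still work, but under `jit` they are unrolled during tracing, which can make compilation slow, and a loop whose trip count or stopping condition depends on traced values cannot be traced at all. JAX therefore provides structured loop primitives in `jax.lax`, which compile to a single loop, as we illustrate below.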

## For loops

The semantics of `jax.lax.fori_loop` are equivalent to the following Python code:
```
def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
```
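We see that `val` (the loop carry) is used to accumulate the results across iterations. `body_fun` must return a value with the same structure, shapes and dtypes as `init_val`.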

Below is an example.

In [24]:
# sum from 1 to N = N*(N+1)/2
N = 10
s = 0
for i in range(1,N+1):
  s += i
print(s)
expected = int(N*(N+1)/2)
assert s==expected

def body_fun(i, val):
  return i + val
s2 = jax.lax.fori_loop(1, N+1, body_fun, 0)
assert s2==expected

s3 = jax.lax.fori_loop(1, N+1, lambda i,val: i+val, 0)
assert s3==expected

55
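The loop carry does not have to be a single number; any pytree (such as a tuple) works, as long as `body_fun` preserves its structure, shapes and dtypes. A sketch computing Fibonacci numbers inside a jitted function, where the upper bound `n` is a traced value (something a plain Python `range(n)` could not handle under `jit`):

```python
import jax
import jax.numpy as jnp

@jax.jit
def fib(n):
    # Carry is the pair (F_i, F_{i+1}); body_fun keeps its structure and dtypes fixed.
    def body_fun(i, carry):
        a, b = carry
        return (b, a + b)

    init_val = (jnp.int32(0), jnp.int32(1))
    a, _ = jax.lax.fori_loop(0, n, body_fun, init_val)
    return a

print(fib(10))  # 55
```

Because `n` is traced rather than static, calling `fib` with a different integer reuses the same compiled code.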


## While loops

The semantics of `jax.lax.while_loop` are equivalent to the following Python code:


```
def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
```

Below is an example.

In [35]:
N = 10
s = 0
i = 0
while (i <= N):
  s += i
  i += 1
print(s)
expected = int(N*(N+1)/2)
assert s==expected  

init_val = (0,0)
def cond_fun(val):
  s,i = val
  return i<=N
def body_fun(val):
  s,i = val
  s += i
  i += 1
  return (s,i)
val = jax.lax.while_loop(cond_fun, body_fun, init_val)
s2 = val[0]
print(s2)
assert s2==expected 

55
55
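The advantage of `while_loop` is that the number of iterations can depend on values computed inside the loop. As a sketch, here is Newton's method for square roots, x ← (x + a/x)/2, which stops once x² is within a relative tolerance of a. The tolerance `1e-5` and the 100-iteration cap are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def newton_sqrt(a):
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond_fun(carry):
        x, it = carry
        # Keep iterating while x*x is far from a, with a safety cap of 100 iterations.
        return (jnp.abs(x * x - a) > 1e-5 * a) & (it < 100)

    def body_fun(carry):
        x, it = carry
        # Newton update for g(x) = x^2 - a.
        return (0.5 * (x + a / x), it + 1)

    x, it = jax.lax.while_loop(cond_fun, body_fun, (a, jnp.int32(0)))
    return x, it

root, n_iter = newton_sqrt(2.0)
print(root, n_iter)
```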


# Autograd <a class="anchor" id="AD"></a>

In this section, we illustrate automatic differentiation using JAX.



## Simple functions

In [5]:
from jax import grad, hessian, jacfwd, jacrev, vmap, jit

Linear function: multi-input, scalar output.

$$
\begin{align}
f(x; a) &= a^T x\\
\nabla_x f(x;a) &= a
\end{align}
$$

In [14]:
# We construct a single output linear function.
# In this case, the Jacobian and gradient are the same.
Din = 3; Dout = 1;
onp.random.seed(42)
a = onp.random.randn(Dout, Din)
def fun1d(x):
    return np.dot(a, x)[0]
x = onp.random.randn(Din)
g = grad(fun1d)(x)
assert np.allclose(g, a)
J = jacrev(fun1d)(x)
assert np.allclose(J, g)

Linear function: multi-input, multi-output.

$$
\begin{align}
f(x;A) &= A x \\
\nabla_x f(x;A) &= A
\end{align}
$$

In [15]:
# We construct a multi-output linear function.
# We check forward and reverse mode give same Jacobians.
Din = 3; Dout = 4;
A = onp.random.randn(Dout, Din)
def fun(x):
    return np.dot(A, x)
x = onp.random.randn(Din)
Jf = jacfwd(fun)(x)
Jr = jacrev(fun)(x)
assert np.allclose(Jf, Jr)
assert np.allclose(Jf, A)
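Under the hood, `jacfwd` builds the Jacobian from Jacobian-vector products (`jax.jvp`, one per input dimension), while `jacrev` builds it from vector-Jacobian products (`jax.vjp`, one per output dimension). So forward mode tends to be cheaper when there are fewer inputs than outputs, and reverse mode when there are fewer outputs, which is why `grad` of a scalar loss uses reverse mode. The sketch below calls both primitives directly on a linear function like the one above, with freshly drawn `A` and `x` and hypothetical direction vectors `v` and `u`, and checks them against `A v` and `u^T A`.

```python
import jax
import jax.numpy as jnp
import numpy as onp

rng = onp.random.RandomState(0)
Din, Dout = 3, 4
A = jnp.asarray(rng.randn(Dout, Din))
x = jnp.asarray(rng.randn(Din))
v = jnp.asarray(rng.randn(Din))   # hypothetical input-space direction
u = jnp.asarray(rng.randn(Dout))  # hypothetical output-space cotangent

def fun(x):
    return jnp.dot(A, x)

# Forward mode: one JVP returns f(x) and J v.
y, Jv = jax.jvp(fun, (x,), (v,))

# Reverse mode: jax.vjp returns f(x) and a function mapping u to u^T J.
y2, vjp_fun = jax.vjp(fun, x)
(uJ,) = vjp_fun(u)
print(Jv, uJ)
```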

Quadratic form.

$$
\begin{align}
f(x;A) &= x^T A x \\
\nabla_x f(x;A) &= (A+A^T) x \\
\nabla_x^2 f(x;A) &= A + A^T
\end{align}
$$

In [16]:

D = 4
A = onp.random.randn(D, D)
x = onp.random.randn(D)
quadfun = lambda x: np.dot(x, np.dot(A, x))

J = jacfwd(quadfun)(x)
assert np.allclose(J, np.dot(A+A.T, x))

H1 = hessian(quadfun)(x)
assert np.allclose(H1, A+A.T)

def my_hessian(fun):
  return jacfwd(jacrev(fun))
H2 = my_hessian(quadfun)(x)
assert np.allclose(H1, H2)
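When D is large, forming the full Hessian can be too expensive, but Hessian-vector products are cheap. A common sketch composes forward mode over reverse mode: differentiate `grad(f)` along a direction `v` with `jax.jvp`. For the quadratic form this should return (A + A^T) v. The matrix `A` and vectors `x`, `v` below are freshly drawn random values, and `hvp` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
import numpy as onp

rng = onp.random.RandomState(0)
D = 4
A = jnp.asarray(rng.randn(D, D))
x = jnp.asarray(rng.randn(D))
v = jnp.asarray(rng.randn(D))  # direction for the Hessian-vector product

def quadfun(x):
    return jnp.dot(x, jnp.dot(A, x))

def hvp(f, x, v):
    # Forward-over-reverse: the JVP of grad(f) at x in direction v equals H(x) v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

Hv = hvp(quadfun, x, v)
print(Hv)
```

This never materializes the D x D Hessian, which is what makes it useful inside iterative solvers such as conjugate gradient.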

Chain rule applied to sigmoid function.

$$
\begin{align}
f(x;w) &=\sigma(w^T x) \\
\nabla_w f(x;w) &= \sigma'(w^T x) x \\
\sigma'(a) &= \sigma(a) (1-\sigma(a))
\end{align}
$$

In [17]:


onp.random.seed(42)
D = 5
w = onp.random.randn(D)
x = onp.random.randn(D)
y = 0 

def sigmoid(x): return 0.5 * (np.tanh(x / 2.) + 1)
def mu(w): return sigmoid(np.dot(w,x))
def deriv_mu(w): return mu(w) * (1-mu(w)) * x
deriv_mu_jax =  grad(mu)
assert np.allclose(deriv_mu(w), deriv_mu_jax(w))
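`grad` only applies to functions with scalar output. To get the derivative of the scalar sigmoid at many points at once, we can vectorize it with `vmap(grad(...))` and compare against the identity σ'(a) = σ(a)(1 - σ(a)). The grid of points is arbitrary.

```python
import jax
import jax.numpy as jnp

def sigmoid(a):
    return 0.5 * (jnp.tanh(a / 2.0) + 1)

a = jnp.linspace(-5.0, 5.0, 11)

# grad(sigmoid) works on one scalar; vmap applies it to every element of a.
dsigmoid = jax.vmap(jax.grad(sigmoid))
s = sigmoid(a)
print(dsigmoid(a))
```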



## Binary logistic regression

In [None]:


# negative log likelihood
def loss(weights, inputs, targets):
    preds = predict_batch(weights, inputs)
    logprobs = np.log(preds) * targets + np.log(1 - preds) * (1 - targets)
    return -np.sum(logprobs)

print(loss(w, X, y))

# Gradient function
grad_fun = grad(loss)

# Gradient of each example in the batch - 2 different ways
grad_fun_w = partial(grad_fun, w)
grads = vmap(grad_fun_w)(X,y)
print(grads)
assert grads.shape == (N,D)

grads2 = vmap(grad_fun, in_axes=(None, 0, 0))(w, X, y) 
assert np.allclose(grads, grads2)

# Gradient for entire batch
grad_sum = np.sum(grads, axis=0)
assert grad_sum.shape == (D,)
print(grad_sum)

1.7016186
[[-0.306 -0.719]
 [-0.112 -0.112]
 [-0.532 -0.258]]
[-0.95 -1.09]


In [None]:
# Textbook implementation of the gradient: sum_n (mu_n - y_n) x_n = X^T (mu - y)
def NLL_grad(weights, batch):
    X, y = batch
    mu = predict_batch(weights, X)
    g = np.dot(X.T, mu - y)
    return g

grad_sum_batch = NLL_grad(w, (X,y))
print(grad_sum_batch)
assert np.allclose(grad_sum, grad_sum_batch)

[-0.95 -1.09]


In [None]:
# We can also compute Hessians, as we illustrate below.
from jax import hessian

hessian_fun = hessian(loss)

# Hessian on one example
H0 = hessian_fun(w, X[0,:], y[0])
print('Hessian(example 0)\n{}'.format(H0))

# Hessian for batch
Hbatch = vmap(hessian_fun, in_axes=(None, 0, 0))(w, X, y) 
print('Hbatch shape {}'.format(Hbatch.shape))

Hbatch_sum = np.sum(Hbatch, axis=0)
print('Hbatch sum\n {}'.format(Hbatch_sum))

Hessian(example 0)
[[0.105 0.246]
 [0.246 0.578]]
Hbatch shape (3, 2, 2)
Hbatch sum
 [[0.675 0.53 ]
 [0.53  0.723]]


In [None]:
# Textbook implementation of Hessian

def NLL_hessian(weights, batch):
  X, y = batch
  mu = predict_batch(weights, X)
  S = np.diag(mu * (1-mu))
  H = np.dot(np.dot(X.T, S), X)
  return H

H2 = NLL_hessian(w, (X,y) )
assert np.allclose(Hbatch_sum, H2, atol=1e-2)
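The gradient and Hessian can be combined into Newton's method, w ← w - H⁻¹g, which for logistic regression is also known as iteratively reweighted least squares. The sketch below runs a few jitted Newton steps on a synthetic dataset (the true weights, sample size and number of iterations are arbitrary choices) and checks that the gradient ends up close to zero. It assumes the data are not linearly separable; otherwise the maximum likelihood estimate does not exist and the iterates diverge.

```python
import jax
import jax.numpy as jnp
import numpy as onp

def sigmoid(a):
    return 0.5 * (jnp.tanh(a / 2.0) + 1)

def nll(w, X, y):
    # Negative log likelihood of binary logistic regression.
    mu = sigmoid(jnp.dot(X, w))
    return -jnp.sum(y * jnp.log(mu) + (1 - y) * jnp.log(1 - mu))

@jax.jit
def newton_step(w, X, y):
    g = jax.grad(nll)(w, X, y)
    H = jax.hessian(nll)(w, X, y)
    # Solve H d = g instead of forming the inverse explicitly.
    return w - jnp.linalg.solve(H, g)

# Synthetic data drawn from a logistic model with made-up true weights.
rng = onp.random.RandomState(0)
N, D = 100, 2
w_true = onp.array([1.0, -1.0])
X_np = rng.randn(N, D)
p_true = 1.0 / (1.0 + onp.exp(-X_np @ w_true))
y_np = (rng.rand(N) < p_true).astype(onp.float32)
X, y = jnp.asarray(X_np), jnp.asarray(y_np)

w = jnp.zeros(D)
for _ in range(10):
    w = newton_step(w, X, y)
grad_norm = jnp.linalg.norm(jax.grad(nll)(w, X, y))
print(w, grad_norm)
```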


# JIT (just-in-time compilation) <a class="anchor" id="JIT"></a>

In this section, we illustrate how to use the JAX JIT compiler to make code run faster (even on a CPU). However, it does not work on arbitrary Python code, as we explain below.




In [None]:
grad_fun_jit = jit(grad_fun) # speedup gradient function
grads_jit = vmap(partial(grad_fun_jit, w))(X,y)
assert np.allclose(grads, grads_jit)


In [None]:
# We can apply JIT to non ML applications as well.

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = np.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  
%timeit -n10 -r3 slow_f(x)  

The slowest run took 63.95 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 234 µs per loop
The slowest run took 4.80 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 5.04 ms per loop
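These timings should be read with care. JAX dispatches work asynchronously, so a call can return before the computation has finished, and the first call to `fast_f` includes compilation time (which is likely why the slowest run above was so much slower than the fastest). A more careful sketch warms up the compiled function first and calls `block_until_ready()` on each result. The array size, the number of repetitions and the helper name `bench` are arbitrary, and the measured speedup will depend on your hardware.

```python
import time

import jax
import jax.numpy as jnp

def slow_f(x):
    return x * x + x * 2.0

fast_f = jax.jit(slow_f)
x = jnp.ones((1000, 1000))

fast_f(x).block_until_ready()  # warm-up call: triggers compilation

def bench(fn, x, n=10):
    # block_until_ready waits for the asynchronously dispatched computation to finish.
    start = time.perf_counter()
    for _ in range(n):
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / n

t_jit = bench(fast_f, x)
t_eager = bench(slow_f, x)
print(f"jit: {t_jit * 1e6:.1f} us/call, eager: {t_eager * 1e6:.1f} us/call")
```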


We can also apply `jit` as a decorator, by writing `@jit` in front of a function definition.

Note that `jit` traces the function with abstract values that carry only the shape and dtype of each input, so Python control flow inside a jitted function must not depend on the concrete values of its inputs. The function below violates this, since it takes one branch when x < 3 and the other when x >= 3.

In [None]:
@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
try:
  print(f(2))
except Exception as e:
  print("ERROR:", e)
  


ERROR: Abstract tracer value encountered where concrete value is expected.

The problem arose with the `bool` function. 

While tracing the function f at <ipython-input-43-3da05647f18e>:1, this concrete value was not available in Python because it depends on the value of the arguments to f at <ipython-input-43-3da05647f18e>:1 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>


One fix is to mark some arguments as static with `static_argnums`. JAX then uses their concrete values when tracing the function, so the Python control flow can be evaluated, and it compiles a separate version for each distinct value of the static arguments. See below for an example.


In [None]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f2 = jit(f, static_argnums=(0,))

print(f2(5))

-20


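Unfortunately, this does not help with `vmap`. Like `jit`, `vmap` traces the function with abstract values (standing in for a whole batch of different inputs), so the Python `if` statement again cannot be evaluated. Below, we fall back to a Python loop when `vmap` fails.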

In [None]:

xs = np.arange(5)
try:
  ys = vmap(f)(xs)
  print('used vmap')
except Exception:
  ys = np.array([f(x) for x in xs])
  print('did not use vmap')
print(ys)

print(np.sum(ys))

did not use vmap
[  0.   3.  12. -12. -16.]
-13.0
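A way to keep both `jit` and `vmap` is to express the branch with JAX primitives instead of a Python `if`. The sketch below shows two options: `jnp.where` evaluates both branches and selects elementwise, while `jax.lax.cond` takes a scalar predicate and two branch functions (under `vmap`, a batched predicate is converted into a select, so both branches are again evaluated). Both should reproduce the values computed by the loop above.

```python
import jax
import jax.numpy as jnp

def f_where(x):
    # Both branches are computed; jnp.where picks one per element.
    return jnp.where(x < 3, 3.0 * x ** 2, -4.0 * x)

def f_cond(x):
    # lax.cond(pred, true_fun, false_fun, operand) with a scalar boolean pred.
    return jax.lax.cond(x < 3, lambda x: 3.0 * x ** 2, lambda x: -4.0 * x, x)

xs = jnp.arange(5)
ys_where = jax.jit(jax.vmap(f_where))(xs)
ys_cond = jax.jit(jax.vmap(f_cond))(xs)
print(ys_where, ys_cond)
```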


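There are a few other subtleties. Python side effects inside a jitted function run only while JAX traces it, not each time the compiled function is called, which can cause surprising behavior. A common gotcha is trying to print arrays inside jit'd functions: the `print` calls run during tracing and show abstract tracer objects rather than concrete values.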

In [None]:
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
y1 = f(2)
print(y1)

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
y2 = f(2)
print(y2)

2
4
4
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
4
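The output above shows that `print` ran during tracing and saw tracers, not values. Later calls with the same input shapes and dtypes reuse the cached compiled version without re-running the Python code at all. The sketch below makes this visible with a hypothetical global counter. (Newer JAX releases also provide `jax.debug.print` for printing runtime values from inside jitted code.)

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only to observe tracing

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while the function is traced
    return 2 * x

double(jnp.float32(1.0))  # first call: trace and compile
double(jnp.float32(2.0))  # same shape and dtype: cached version, no retrace
double(jnp.ones(3))       # new shape: traced again
print(trace_count)  # 2
```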


# A few differences from NumPy

Below we list a few items where JAX differs from NumPy.
See also the official [list of common gotchas](https://colab.research.google.com/github/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb).

## Random number generation

The API for JAX is largely identical to NumPy's; a notable exception is pseudo random number
generation (PRNG).
This is because JAX does not maintain any global state, i.e., it is purely functional.
This design "provides reproducible results invariant to compilation boundaries and backends,
while also maximizing performance by enabling vectorized generation and parallelization across random calls"
(to quote [the official page](https://github.com/google/jax#a-brief-tour)).

Thus, whenever we do anything stochastic, we need to give it a fresh RNG key; reusing a key reproduces exactly the same numbers. We get new keys by splitting an existing key into pieces, and we can do this indefinitely, as shown below.

In [None]:
import jax.random as random

key = random.PRNGKey(0)
print(random.normal(key, shape=(3,)))  # [ 1.81608593 -0.48262325  0.33988902]
print(random.normal(key, shape=(3,)))  # [ 1.81608593 -0.48262325  0.33988902]  ## identical results

# To make a new key, we split the current key into two pieces.
key, subkey = random.split(key)
print(random.normal(subkey, shape=(3,)))  # [ 1.1378783  -1.22095478 -0.59153646]

# We can continue to split off new pieces from the current key.
key, subkey = random.split(key)
print(random.normal(subkey, shape=(3,)))  # [-0.06607265  0.16676566  1.17800343]

# We can always use original numpy if we like, but its global RNG state is invisible to JAX
# transformations (e.g., inside a jitted function a NumPy random call runs only once, at trace time).
onp.random.seed(42)
print(onp.random.randn(3))

[ 1.816 -0.483  0.34 ]
[ 1.816 -0.483  0.34 ]
[ 1.138 -1.221 -0.592]
[-0.066  0.167  1.178]
[ 0.497 -0.138  0.648]
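When many independent random draws are needed, we can split a key into several subkeys at once with `random.split(key, num)` and map a sampler over them with `vmap`. In this sketch the number of keys (5) and the sample shape are arbitrary.

```python
import jax
import jax.random as random

key = random.PRNGKey(0)
key, subkey = random.split(key)
keys = random.split(subkey, 5)  # five new keys, stacked along the leading axis

# Draw one length-3 standard normal vector per key.
samples = jax.vmap(lambda k: random.normal(k, shape=(3,)))(keys)
print(samples.shape)  # (5, 3)
```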


## Implicitly casting lists to vectors

Unlike NumPy, most `jax.numpy` functions do not accept a Python list of numbers as an array argument. Instead, you must explicitly create the vector using the `np.array()` constructor (recall that `np` is `jax.numpy` here).


In [None]:
# You cannot treat a list of numbers as a vector.
try:
  S = np.diag([1.0, 2.0, 3.0])
except Exception:
  print('must convert list to np.array')

must convert list to np.array


In [None]:
# Instead you should explicitly construct the vector.

S = np.diag(np.array([1.0, 2.0, 3.0]))

## Mutation of arrays 

Since JAX is functional, you cannot mutate arrays in place, because this makes program analysis and transformation very difficult; JAX requires a pure functional expression of a numerical program.
Instead, JAX offers the functional update functions `index_update`, `index_add`, `index_min`, `index_max`, and the `index` helper. These are illustrated below.

Note: If the input values of `index_update` aren't reused, jit-compiled code will perform these operations in-place, rather than making a copy.

In [None]:
# You cannot assign directly to elements of an array.

jax_array = np.zeros((3,3), dtype=np.float32)

# In place update of JAX's array will yield an error!
try:
  jax_array[1, :] = 1.0
except TypeError:
  print('must use index_update')

must use index_update


In [None]:
from jax.ops import index, index_add, index_update

jax_array = np.zeros((3, 3))
print("original array:")
print(jax_array)

new_jax_array = index_update(jax_array, index[1, :], 1.)

new_jax_array2 = index_add(new_jax_array, index[:, 2], 7.)
print("new array post update")
print(new_jax_array)

print("new array post add")
print(new_jax_array2)

print("old array unchanged:")
print(jax_array)

original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
new array post update
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
new array post add
[[0. 0. 7.]
 [1. 1. 8.]
 [0. 0. 7.]]
old array unchanged:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
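In later JAX releases, `index_update`, `index_add` and the other `jax.ops` index functions were deprecated and then removed in favor of the indexed update syntax `x.at[idx].set(v)`, `x.at[idx].add(v)`, and similar methods. A sketch of the same two updates in that style, which should work on recent versions:

```python
import jax.numpy as jnp

jax_array = jnp.zeros((3, 3))

# Functional updates: each call returns a new array and leaves its input untouched.
new_jax_array = jax_array.at[1, :].set(1.0)
new_jax_array2 = new_jax_array.at[:, 2].add(7.0)

print(new_jax_array2)
```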


# JAX neural net libraries

JAX is a purely functional library, which differs from TensorFlow and
PyTorch, which are stateful. The main advantages of functional programming
are that we can safely transform the code, and/or run it in parallel, without worrying about
global state changing behind the scenes. The main disadvantage is that code (especially DNNs) can be harder to write.
To simplify the task, various DNN libraries have been designed, as we list below. In this book, we use Flax.

|Name|Description|
|----|----|
|[Stax](https://github.com/google/jax/blob/master/jax/experimental/stax.py)|Barebones DNN DSL|
|[Flax](https://github.com/google/flax)|DNN library for creating and training models|
|[Haiku](https://github.com/deepmind/dm-haiku)|DNN library for creating models|
|[Trax](https://github.com/google/trax)|DNN library, focus on sequence models|
|[Objax](https://github.com/google/objax)|DNN framework, similar to PyTorch, not compatible with other JAX libraries|
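To see what these libraries wrap, here is a sketch of the purely functional style in plain JAX (it is not the API of Flax or any library above). The parameters are an ordinary pytree (a list of `(W, b)` tuples), the model is a pure function of the parameters and inputs, and each training step returns new parameters instead of mutating old ones. The helper names, layer sizes, learning rate and toy regression target are all made up for illustration.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes):
    # Hypothetical helper: one (W, b) tuple per layer.
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        W = jax.random.normal(subkey, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append((W, jnp.zeros(d_out)))
    return params

def mlp(params, x):
    # tanh hidden layers followed by a linear output layer.
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b

def mse(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

LR = 0.05  # learning rate (arbitrary)

@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(mse)(params, x, y)
    # tree_map applies the update to every leaf of the parameter pytree.
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

key = jax.random.PRNGKey(0)
key, k_init, k_data = jax.random.split(key, 3)
params = init_params(k_init, [1, 16, 1])
x = jax.random.normal(k_data, (64, 1))
y = jnp.sin(3.0 * x)

initial_loss = mse(params, x, y)
for _ in range(300):
    params = sgd_step(params, x, y)
final_loss = mse(params, x, y)
print(initial_loss, final_loss)
```

Libraries such as Flax and Haiku mainly automate the bookkeeping here: creating and naming the parameter pytree, and threading it through the model.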


# Other JAX libraries

There are many other useful JAX libraries, most of which are purely functional, and therefore compose nicely.

|Name|Description|
|----|----|
|[NumPyro](https://github.com/pyro-ppl/numpyro)|Library for (deep) probabilistic modeling|
|[Optax](https://github.com/deepmind/optax)|Library for defining gradient-based optimizers|
|[RLax](https://github.com/deepmind/rlax)|Library for reinforcement learning|
|[Chex](https://github.com/deepmind/chex)|Library for debugging and developing reliable JAX code|