## What is JAX?

[JAX](https://github.com/google/jax) = [autograd](https://github.com/HIPS/autograd) + [XLA](https://www.tensorflow.org/xla) = NumPy + Autodiff + XLA (Accelerated Linear Algebra)

JAX is a Python library that augments NumPy and Python code with function transformations, which make it easy to perform operations common in machine learning programs. Concretely, it makes it simple to write standard Python/NumPy code and then immediately:


*   Compute the derivative of a function via a successor to autograd (`jax.grad`)
*   Compile and run your numpy programs on GPUs and TPUs via XLA by default. Just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API (`jax.jit`)
*   Automagically vectorize a function, so that e.g. you can process a “batch” of data in parallel (`jax.vmap`), and parallelize computation across multiple devices (`jax.pmap`)


**Disclaimer: I'm by no means an expert in JAX.**
- "A good programmer is typically the one who is good at googling".




## Why JAX?
* Clean and unified API
* Powerful functional transformations
* Compatibility

In [1]:
import random
import itertools

import jax
import jax.numpy as np
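# This notebook imports jax.numpy as "np" and the original numpy as "onp";
# current JAX documentation instead uses `import jax.numpy as jnp` with `import numpy as np`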
import numpy as onp

In [2]:
def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return np.sum((preds - targets)**2)

grad_fun = jax.jit(jax.grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jax.jit(jax.vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads
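To see what `perex_grads` actually returns, here is a minimal sketch that calls it on concrete data. The layer sizes (3 inputs, 4 hidden units, 2 outputs), the batch of 5 examples, and the use of `jax.random` to generate parameters and data are arbitrary assumptions for illustration. The point is that every gradient array gains a leading per-example axis, and summing over that axis recovers the gradient of the summed batch loss.

```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def logprob_fun(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

grad_fun = jax.jit(jax.grad(logprob_fun))
perex_grads = jax.jit(jax.vmap(grad_fun, in_axes=(None, 0, 0)))

# Hypothetical sizes: 3 inputs -> 4 hidden -> 2 outputs, batch of 5 examples
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = [
    (jax.random.normal(k1, (3, 4)), jnp.zeros(4)),
    (jax.random.normal(k2, (4, 2)), jnp.zeros(2)),
]
batch_inputs = jax.random.normal(k3, (5, 3))
batch_targets = jax.random.normal(k4, (5, 2))

per_example = perex_grads(params, batch_inputs, batch_targets)
print(per_example[0][0].shape)  # (5, 3, 4): one W1-gradient per example

# Cross-check: the gradient of the summed loss equals the sum of per-example gradients
def batch_loss(params, xs, ys):
    return jnp.sum(jax.vmap(logprob_fun, in_axes=(None, 0, 0))(params, xs, ys))

full = jax.grad(batch_loss)(params, batch_inputs, batch_targets)
summed = jax.tree_util.tree_map(lambda g: g.sum(axis=0), per_example)
```

Per-example gradients like these are what you would need for, e.g., per-example clipping; for ordinary training, differentiating the batch loss directly is cheaper.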

## Example (credit to Colin Raffel)

We will be learning the XOR function with a small neural network. The XOR function takes as input two binary numbers and outputs a binary number, like so:

In 1 | In 2 | Out
---- | ---- | ---
0    | 0    | 0   
0    | 1    | 1  
1    | 0    | 1   
1    | 1    | 0  

We'll use a neural network with a single hidden layer with 3 neurons and a hyperbolic tangent nonlinearity, trained with the cross-entropy loss via stochastic gradient descent. Let's implement this model and loss function. Note that the code is exactly as you'd write in standard `numpy`.
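Written out, the model and per-example loss implemented in the next cell are as follows, where $\sigma$ is the sigmoid and the shapes match the parameters created by `initial_params` ($W_1 \in \mathbb{R}^{3\times 2}$, $b_1, w_2 \in \mathbb{R}^{3}$, $b_2 \in \mathbb{R}$):

$$
\hat{y} = \sigma\!\left(w_2^\top \tanh(W_1 x + b_1) + b_2\right), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}
$$

$$
\ell(\hat{y}, y) = -y \log \hat{y} - (1 - y) \log(1 - \hat{y})
$$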

In [3]:
# Sigmoid nonlinearity
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Computes our network's output
def net(params, x):
    w1, b1, w2, b2 = params
    hidden = np.tanh(np.dot(w1, x) + b1)
    return sigmoid(np.dot(w2, hidden) + b2)

# Cross-entropy loss
def loss(params, x, y):
    out = net(params, x)
    cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out)
    return cross_entropy

# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])

There are some places where we want to use standard NumPy rather than `jax.numpy`. One of those places is parameter initialization. We'd like to initialize our parameters randomly before we train our network, which is not an operation for which we need derivatives or compilation. JAX uses its own `jax.random` module instead of `numpy.random`; it requires you to pass explicit PRNG keys rather than relying on hidden global state, which provides better support for reproducibility (seeding) across different transformations. Since we don't need to transform the initialization of parameters in any way, it's simplest just to use standard `numpy.random` instead of `jax.random` here.

In [4]:
def initial_params():
    return [
        onp.random.randn(3, 2),  # w1
        onp.random.randn(3),  # b1
        onp.random.randn(3),  # w2
        onp.random.randn(),  # b2
    ]
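For comparison, here is a minimal sketch of the same initialization written with `jax.random`. The helper `initial_params_jax` is hypothetical, not part of the original notebook. It shows the explicit-key style: every random call consumes a PRNG key, so the same seed reproduces the parameters exactly, with no hidden global state.

```python
import jax
import jax.numpy as jnp

def initial_params_jax(key):
    # Hypothetical helper mirroring initial_params(), but with explicit keys
    k1, k2, k3, k4 = jax.random.split(key, 4)  # one independent key per tensor
    return [
        jax.random.normal(k1, (3, 2)),  # w1
        jax.random.normal(k2, (3,)),    # b1
        jax.random.normal(k3, (3,)),    # w2
        jax.random.normal(k4, ()),      # b2 (scalar)
    ]

params_a = initial_params_jax(jax.random.PRNGKey(0))
params_b = initial_params_jax(jax.random.PRNGKey(0))  # same seed -> identical params
params_c = initial_params_jax(jax.random.PRNGKey(1))  # different seed -> different params
```

The trade-off is bookkeeping: you must split and thread keys yourself. In exchange, randomness behaves predictably under `jax.jit` and `jax.vmap`.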

## `jax.grad`

The first transformation we'll use is `jax.grad`. `jax.grad` takes a function and returns a new function which computes the gradient of the original function. By default, the gradient is taken with respect to the first argument; this can be controlled via the `argnums` argument to `jax.grad`. To use gradient descent, we want to be able to compute the gradient of our loss function with respect to our neural network's parameters. For this, we'll simply use `jax.grad(loss)` which will give us a function we can call to get these gradients.
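A minimal sketch of how `argnums` selects the argument to differentiate, on a toy function `f(x, y) = x**2 * y` chosen for illustration (not from the original notebook). Inputs are passed as floats because `jax.grad` requires floating-point arguments.

```python
import jax

def f(x, y):
    return x ** 2 * y

df_dx = jax.grad(f)                     # default argnums=0: df/dx = 2*x*y
df_dy = jax.grad(f, argnums=1)          # df/dy = x**2
df_both = jax.grad(f, argnums=(0, 1))   # tuple (df/dx, df/dy)

gx = df_dx(3.0, 2.0)       # 2 * 3 * 2 = 12
gy = df_dy(3.0, 2.0)       # 3 ** 2 = 9
gx2, gy2 = df_both(3.0, 2.0)
```

Because `params` is the first argument of `loss`, the default `argnums=0` is exactly what gradient descent on the parameters needs.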

In [5]:
loss_grad = jax.grad(loss)

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Initialize parameters randomly
params = initial_params()

for n in itertools.count():
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Compute the target output
    y = onp.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


## `jax.jit`

While carefully written `numpy` code can be reasonably performant, for modern machine learning we want our code to run as fast as possible. JAX provides a JIT (just-in-time) compiler which takes a standard Python/`numpy` function and compiles it to run efficiently on an accelerator. Compiling a function also avoids the overhead of the Python interpreter, which helps whether or not you're using an accelerator. In total, `jax.jit` can dramatically speed up your code with essentially no coding overhead: you just ask JAX to compile the function for you. Even our tiny neural network can see a dramatic speedup when using `jax.jit`:

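**Warning**: not all functions can be jit'ed; in particular, Python control flow that depends on the values of traced arguments fails under `jax.jit` (see the "Python control flow + JIT" section of [Common Gotchas in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)).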
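A minimal sketch of the most common failure, assuming a reasonably recent JAX version. Under `jax.jit`, arguments become abstract tracers, so a Python `if` on an argument's value cannot be decided at trace time. Two standard workarounds are shown: express the branch with `jnp.where`, or mark the argument as static with `static_argnums` so it is treated as a compile-time constant (each distinct static value triggers a recompilation). The function names are illustrative.

```python
import jax
import jax.numpy as jnp

def relu_python(x):
    # Python control flow that depends on the *value* of x
    if x > 0:
        return x
    return 0.0

try:
    jax.jit(relu_python)(1.0)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    # `x > 0` is an abstract tracer under jit, not a concrete Python bool
    traced_ok = False

@jax.jit
def relu_where(x):
    # Branch-free version: both branches are computed, then selected elementwise
    return jnp.where(x > 0, x, 0.0)

def scale(x, mode):
    # `mode` is an ordinary Python string, so it must be static under jit
    if mode == "double":
        return 2 * x
    return x

scale_jit = jax.jit(scale, static_argnums=1)

print(traced_ok)                           # False
print(relu_where(jnp.array([-1.0, 2.0])))  # [0. 2.]
print(scale_jit(3.0, "double"))            # 6.0
```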

In [6]:
# Time the original gradient function
%timeit loss_grad(params, x, y)
loss_grad = jax.jit(jax.grad(loss))
# Run once to trigger JIT compilation
loss_grad(params, x, y)
%timeit loss_grad(params, x, y)

2.27 ms ± 35.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.01 μs ± 44.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
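One caveat when reading numbers like these: JAX dispatches work asynchronously, so a call can return before the computation has finished, and the JAX documentation recommends calling `.block_until_ready()` on results when benchmarking. Below is a minimal sketch of a more careful comparison using the standard-library `time` module. It times a made-up function `f` rather than the notebook's `loss_grad`, and the array size and repetition count are arbitrary.

```python
import time

import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical workload: a matrix product followed by a reduction
    return jnp.tanh(x @ x.T).sum()

f_jit = jax.jit(f)
x = jnp.ones((200, 200))

# Warm-up: the first jitted call traces and compiles, so exclude it from timing
f_jit(x).block_until_ready()

def bench(fn, arg, n=50):
    """Average wall-clock seconds per call, waiting for each result to finish."""
    start = time.perf_counter()
    for _ in range(n):
        fn(arg).block_until_ready()
    return (time.perf_counter() - start) / n

t_eager = bench(f, x)
t_jit = bench(f_jit, x)
print(f"eager: {t_eager * 1e6:.1f} us/call, jit: {t_jit * 1e6:.1f} us/call")
```

For a function that returns a list of arrays, such as `loss_grad`, you would call `.block_until_ready()` on each returned array before stopping the clock.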


## `jax.vmap`

An astute reader may have noticed that we have been training our neural network on a single example at a time. This is "true" stochastic gradient descent; in practice, when training modern machine learning models we perform "minibatch" gradient descent, where we average the loss gradients over a minibatch of examples at each step. JAX provides `jax.vmap`, a transformation which automatically "vectorizes" a function: it lets you compute the output of a function in parallel over some axis of the input. For us, this means we can apply `jax.vmap` and immediately get a version of our loss-gradient function that works on a minibatch of examples.

`jax.vmap` takes additional arguments:
- `in_axes` is a tuple or integer which tells JAX over which axes the function's arguments should be parallelized. A tuple should have one entry per argument of the function being `vmap`'d, with `None` marking arguments that should not be mapped; a single integer applies that same axis to every argument. In our example, we'll use `(None, 0, 0)`, meaning "don't parallelize over the first argument (`params`), and parallelize over the first (zeroth) dimension of the second and third arguments (`x` and `y`)".
- `out_axes` is analogous to `in_axes`, except it specifies where the mapped (batch) axis should appear in the function's output. In our case, we'll use `0` (also the default), meaning the batch dimension becomes the first (zeroth) dimension of every array in the output (the per-example loss gradients).
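The sketch below exercises both arguments on a toy function. The names `weighted_sum` and `scale_elementwise` and the data are illustrative, not from the original notebook: `None` shares an argument across the batch, an integer picks the mapped axis, and `out_axes` moves the batch axis in the result.

```python
import jax
import jax.numpy as jnp

def weighted_sum(w, x):
    # Written for a single example: w and x both have shape (3,)
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.arange(12.0).reshape(4, 3)  # 4 examples stacked along axis 0

# Share `w` across the batch (None) and map over axis 0 of `xs`
batched = jax.vmap(weighted_sum, in_axes=(None, 0))
dots = batched(w, xs)  # values 8, 26, 44, 62

# If examples are stored as columns instead, map over axis 1
dots_cols = jax.vmap(weighted_sum, in_axes=(None, 1))(w, xs.T)  # same values

# A single integer applies to every argument: row-wise dot products of xs with itself
self_dots = jax.vmap(jnp.dot, in_axes=0)(xs, xs)  # values 5, 50, 149, 302

def scale_elementwise(w, x):
    return w * x  # shape (3,) per example

# out_axes controls where the batch axis lands in the output
rows = jax.vmap(scale_elementwise, in_axes=(None, 0), out_axes=0)(w, xs)  # (4, 3)
cols = jax.vmap(scale_elementwise, in_axes=(None, 0), out_axes=1)(w, xs)  # (3, 4)
```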

Note that we will have to change the training code a little: we need to grab a batch of data instead of a single example at a time, and we need to average the gradients over the batch before applying them to update the parameters.

In [7]:
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))

params = initial_params()

batch_size = 100

for n in itertools.count():
    # Generate a batch of inputs
    x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    # The call to loss_grad remains the same!
    grads = loss_grad(params, x, y)
    # Note that we now need to average gradients over the batch
    params = [param - learning_rate * np.mean(grad, axis=0)
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


## SPMD programming with `jax.pmap`

For parallel programming across multiple accelerators, such as multiple GPUs, use `jax.pmap`. With `pmap` you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. See the MNIST [example](https://github.com/google/jax/blob/master/examples/spmd_mnist_classifier_fromscratch.py).
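A minimal `jax.pmap` sketch, assuming only that at least one device is visible (it also runs on a single CPU). Each device receives one slice of the leading axis, and `jax.lax.psum` is a collective that sums a value across all devices participating in the named axis. The axis name `"devices"` and the normalization task are illustrative choices.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def normalize(x):
    # Collective: add x across every device mapped over the "devices" axis
    total = jax.lax.psum(x, axis_name="devices")
    return x / total

# One scalar per device: the leading axis size must equal the number of devices
xs = jnp.arange(1.0, n_devices + 1.0)
out = jax.pmap(normalize, axis_name="devices")(xs)
print(n_devices, out)  # the outputs sum to 1 regardless of the device count
```

Each device only sees its own scalar, yet the result depends on all of them; that cross-device communication is what distinguishes `pmap` from `vmap`.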

## Neural Network Libraries

There are several neural net libraries built on top of JAX. Depending on what you're trying to do, you have a few options:

- For toy functions and simple architectures (e.g. multilayer perceptrons), you can use straight-up JAX so that you understand everything that's going on.
- `Stax` (included with JAX as `jax.example_libraries.stax`) is a very lightweight neural net package with easy-to-follow source code. It's good for implementing simpler architectures like CIFAR conv nets, and has the advantage that you can understand the whole control flow of the code.
- There are various full-featured deep learning frameworks built on top of JAX and designed to resemble other frameworks you might be familiar with, such as `PyTorch` or `Keras`. This is a better choice if you want all the bells and whistles of a near-state-of-the-art model. The main choices are [`Flax`](https://github.com/google/flax), [`Haiku`](https://github.com/deepmind/dm-haiku), and [`Objax`](https://github.com/google/objax), and the choice between them might come down to which ones already have a public implementation of something you need.
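To give a flavour of `Stax`, here is a minimal sketch of the XOR network from earlier (2 inputs, 3 tanh hidden units, one sigmoid output) built with `stax.serial`, assuming a JAX version that ships `jax.example_libraries.stax` (older releases had it under `jax.experimental.stax`). Each layer is an `(init_fun, apply_fun)` pair: `init_fun` creates parameters for a given input shape, and `apply_fun` is a pure function of those parameters, so it composes directly with `jax.grad`, `jax.jit`, and `jax.vmap`. It uses Stax's default initializers rather than the `numpy.random.randn` initialization used above.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# 2 -> 3 (tanh) -> 1 (sigmoid), mirroring the XOR network above
init_fun, apply_fun = stax.serial(
    stax.Dense(3), stax.Tanh,
    stax.Dense(1), stax.Sigmoid,
)

# -1 stands for an unspecified batch dimension
out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 2))

inputs = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
preds = apply_fun(params, inputs)  # shape (4, 1), values in (0, 1)

def xent(params, x, y):
    # Mean binary cross-entropy over the batch
    p = apply_fun(params, x)[:, 0]
    return jnp.mean(-y * jnp.log(p) - (1 - y) * jnp.log(1 - p))

targets = jnp.array([0.0, 1.0, 1.0, 0.0])
grads = jax.grad(xent)(params, inputs, targets)  # same nested structure as params
```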


In [None]:
!pip install objax
import random

import numpy as np
import tensorflow as tf

import objax
from objax.zoo.wide_resnet import WideResNet

Collecting objax
  Downloading https://files.pythonhosted.org/packages/5a/e0/54503c60e3a04c23600dd9cf099c5adfda7c76d1240fd3bff4e71e3164b7/objax-1.2.0.tar.gz (41kB)
Collecting parameterized
  Downloading https://files.pythonhosted.org/packages/31/13/fe468c8c7400a8eca204e6e160a29bf7dcd45a76e20f1c030f3eaa690d93/parameterized-0.8.1-py2.py3-none-any.whl
Building wheels for collected packages: objax
  Building wheel for objax (setup.py) ... done
  Created wheel for objax: filename=objax-1.2.0-cp36-none-any.whl size=65294 sha256=c994aa3985a43f9432d7f7e1b6ef3b8d302956aa7177638db7b8351a7a7c1fd4
  Stored in directory: /root/.cache/pip/wheels/8e/97/e4

In [None]:
# Data: convert images from NHWC to NCHW and scale pixel values to [0, 1]
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.transpose(0, 3, 1, 2) / 255.0
X_test = X_test.transpose(0, 3, 1, 2) / 255.0

# Model
model = WideResNet(nin=3, nclass=10, depth=28, width=2)
opt = objax.optimizer.Adam(model.vars())

# Losses
@objax.Function.with_vars(model.vars())
def loss(x, label):
    logit = model(x, training=True)
    return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()

gv = objax.GradValues(loss, model.vars())

@objax.Function.with_vars(model.vars() + opt.vars())
def train_op(x, y, lr):
    g, v = gv(x, y)
    opt(lr=lr, grads=g)
    return v


train_op = objax.Jit(train_op)
predict = objax.Jit(objax.nn.Sequential([
    objax.ForceArgs(model, training=False), objax.functional.softmax
]))


def augment(x):
    if random.random() < .5:
        x = x[:, :, :, ::-1]  # Mirror the batch left-to-right (reverse the width axis)
    # Pixel-shift all images in the batch by up to 4 pixels in any direction.
    x_pad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'reflect')
    rx, ry = np.random.randint(0, 8), np.random.randint(0, 8)
    x = x_pad[:, :, rx:rx + 32, ry:ry + 32]
    return x


# Training
# print(model.vars())
for epoch in range(30):
    # Train
    epoch_losses = []  # distinct name so the `loss` function above is not shadowed
    sel = np.arange(len(X_train))
    np.random.shuffle(sel)
    for it in range(0, X_train.shape[0], 64):
        epoch_losses.append(train_op(augment(X_train[sel[it:it + 64]]), Y_train[sel[it:it + 64]].flatten(),
                                     4e-3 if epoch < 20 else 4e-4))

    # Eval
    test_predictions = [predict(x_batch).argmax(1) for x_batch in X_test.reshape((50, -1) + X_test.shape[1:])]
    accuracy = np.array(test_predictions).flatten() == Y_test.flatten()
    print(f'Epoch {epoch + 1:4d}  Loss {np.mean(epoch_losses):.2f}  Accuracy {100 * np.mean(accuracy):.2f}')

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Epoch    1  Loss 1.53  Accuracy 58.20
Epoch    2  Loss 1.06  Accuracy 63.95
Epoch    3  Loss 0.85  Accuracy 69.92
Epoch    4  Loss 0.70  Accuracy 74.81
Epoch    5  Loss 0.60  Accuracy 74.73
Epoch    6  Loss 0.53  Accuracy 80.52
Epoch    7  Loss 0.48  Accuracy 82.04
Epoch    8  Loss 0.44  Accuracy 82.75
Epoch    9  Loss 0.41  Accuracy 85.02
Epoch   10  Loss 0.38  Accuracy 85.09
Epoch   11  Loss 0.35  Accuracy 86.25
Epoch   12  Loss 0.32  Accuracy 87.10
Epoch   13  Loss 0.31  Accuracy 86.81
Epoch   14  Loss 0.29  Accuracy 84.76
Epoch   15  Loss 0.26  Accuracy 86.53
Epoch   16  Loss 0.25  Accuracy 88.06
Epoch   17  Loss 0.24  Accuracy 88.58
Epoch   18  Loss 0.22  Accuracy 88.91
Epoch   19  Loss 0.22  Accuracy 89.05
Epoch   20  Loss 0.21  Accuracy 88.96
Epoch   21  Loss 0.13  Accuracy 91.51
Epoch   22  Loss 0.11  Accuracy 91.67
Epoch   23  Loss 0.10  Accuracy 91.32
Epoch   24  Loss 0.09  Accuracy 91.81
Epoch   25