# an intro to jax

jax is a high-performance python library for ML research. It includes an updated version of autograd and a jit compiler (built on XLA) that targets CPUs, GPUs and TPUs, and it makes vectorising code easy. It's not an official Google product, but it is maintained by them, so it's probably not gonna go anywhere soon. Generally, you can use a numpy-esque syntax to write your code, so it should be quite familiar.

we are gonna use two references for this .ipynb:

 * https://github.com/google/jax

 * https://colinraffel.com/blog/you-don-t-know-jax.html

In [1]:
import random
import itertools

In [3]:
import jax
# referring to numpy as np is so last year, now np is jax.numpy
import jax.numpy as np

In [4]:
# The convention followed here is to import original numpy as "onp"
# (newer jax code usually keeps numpy as np and imports jax.numpy as jnp)
import numpy as onp

we're gonna build an example (just as in one of the above links): a small network that learns an xor gate

 * 0 0 -> 0
 
 * 0 1 -> 1
 
 * 1 0 -> 1
 
 * 1 1 -> 0

what's interesting is that we can build up an entire NN using just jax syntax (well, I suppose you could do that with numpy too?)
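
To make the "numpy too?" point concrete, here is a minimal sketch that runs one layer through both libraries. The helper `affine_tanh` and its `xp` argument are made up for illustration; the import names follow this notebook's convention.

```python
import numpy as onp      # original numpy, as in this notebook
import jax.numpy as np   # jax's numpy-like module, as in this notebook


def affine_tanh(xp, w, b, x):
    # xp is a hypothetical "array module" argument: either onp or np.
    # Both expose dot and tanh under the same names.
    return xp.tanh(xp.dot(w, x) + b)


w = onp.array([[1.0, -1.0], [0.5, 0.5]])
b = onp.array([0.0, 0.1])
x = onp.array([1.0, 0.0])

out_numpy = affine_tanh(onp, w, b, x)  # a numpy array
out_jax = affine_tanh(np, w, b, x)     # a jax array with the same values
print(out_numpy, out_jax)
```

The numbers agree up to floating-point precision (jax defaults to 32-bit floats). What plain numpy can't do is differentiate, jit or vectorise the function for you.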

In [6]:
# Sigmoid nonlinearity
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

In [7]:
# Computes our network's output
def net(params, x):
    w1, b1, w2, b2 = params
    # so this is a hidden layer
    # with a tanh activation
    hidden = np.tanh(np.dot(w1, x) + b1)
    # we then have our final layer
    # we are using our sigmoid here
    return sigmoid(np.dot(w2, hidden) + b2)

In [8]:
# Cross-entropy loss
# standard stuff
def loss(params, x, y):
    out = net(params, x)
    cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out)
    return cross_entropy
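
To get a feel for the numbers, here is a small worked example of the same cross-entropy formula. The function name `binary_cross_entropy` and the prediction 0.9 are made up for illustration.

```python
import jax.numpy as np


def binary_cross_entropy(out, y):
    # Same formula as the loss above, applied directly to a prediction
    return -y * np.log(out) - (1 - y) * np.log(1 - out)


# The net predicts 0.9 and the target is 1: small loss, -log(0.9)
confident_right = binary_cross_entropy(0.9, 1.0)
# The net predicts 0.9 but the target is 0: large loss, -log(0.1)
confident_wrong = binary_cross_entropy(0.9, 0.0)
print(confident_right, confident_wrong)
```

A confident wrong answer costs far more than a confident right one, which is what pushes the weights in the right direction.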

In [10]:
# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    # jax.numpy has bitwise_xor too, but original numpy works fine here
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])

In [23]:
# this creates some random numbers to fill our weights with

# w1 has shape (3, 2): 2 inputs feeding a hidden layer of 3 nodes
def initial_params():
    return [
        onp.random.randn(3, 2),  # w1
        onp.random.randn(3),  # b1
        onp.random.randn(3),  # w2
        onp.random.randn(),  # b2
    ]
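
A quick shape check helps here. This is a sketch that repeats the layer arithmetic from `net` with the same parameter shapes (the sigmoid is left out because it does not change shapes).

```python
import numpy as onp
import jax.numpy as np

# Same shapes as initial_params() above
w1 = onp.random.randn(3, 2)
b1 = onp.random.randn(3)
w2 = onp.random.randn(3)
b2 = onp.random.randn()

x = np.array([0, 1])                  # one xor input, shape (2,)
hidden = np.tanh(np.dot(w1, x) + b1)  # (3, 2) dot (2,) gives (3,)
out = np.dot(w2, hidden) + b2         # (3,) dot (3,) gives a scalar

print(hidden.shape)  # (3,)
print(out.shape)     # ()
```

So the network maps 2 inputs to 3 hidden units to 1 output.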

## finding a gradient

In [13]:
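# need to be able to find the gradient of our loss:
# jax.grad(loss) returns a new function that computes the gradient
# of loss with respect to its first argument (params)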
loss_function = jax.grad(loss)
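
Here is a minimal sketch of what `jax.grad` does, on toy functions. The names `square` and `weighted_sum` are made up for illustration and are not part of the notebook's network.

```python
import jax
import jax.numpy as np


def square(x):
    return x ** 2


# jax.grad returns a new function: the derivative of square
d_square = jax.grad(square)
g = d_square(3.0)
print(g)  # 6.0, since d/dx x^2 = 2x


def weighted_sum(params, x):
    w, b = params
    return np.dot(w, x) + b


# By default the gradient is taken with respect to the first argument.
# When that argument is a list, the result is a list of the same structure.
grads = jax.grad(weighted_sum)([np.array([1.0, 2.0]), 0.5], np.array([3.0, 4.0]))
print(grads)  # gradient w.r.t. w is x, gradient w.r.t. b is 1
```

That second case is why the training loop can zip `params` with `grads`: each gradient lines up with the parameter it belongs to.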

In [14]:
# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])



In [15]:
# Initialize parameters randomly
params = initial_params()

In [17]:
print(params)

[array([[ 0.63373614, -0.96423199],
       [ 0.59448257,  0.0934048 ],
       [-0.59233837,  1.82237451]]), array([ 0.37087304, -1.55432476, -1.78767255]), array([0.94221594, 1.16798729, 1.49574142]), -0.5188286175826627]


In [21]:
# itertools.count() counts up forever, so this loop only
# stops at the break below, once the net gets all four
# xor cases right
for n in itertools.count():
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Compute the target output
    y = onp.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_function(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


look, we solved it!

## jit

jax is nice because you can jit-compile code for GPUs and TPUs. Obviously, my mac doesn't have those, but jit works on the CPU too, so let's have a look anyway

In [27]:
# Time the original gradient function

# recall, loss_function is the non-jit'd gradient of the loss
%timeit loss_function(params, x, y)

16.4 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
# jit-compile the gradient function
loss_grad = jax.jit(jax.grad(loss))
# Run once to trigger JIT compilation
loss_grad(params, x, y)
%timeit loss_grad(params, x, y)

445 µs ± 54.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


wow, that's a big speedup: roughly 37x (16.4 ms down to 445 µs)
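
Outside a notebook there is no `%timeit`, so here is a rough sketch of the same comparison with plain Python timing. The function `sum_of_squared_tanh` and the helper `seconds_per_call` are made up for illustration, and the timings will vary by machine.

```python
import time

import jax
import jax.numpy as np


def sum_of_squared_tanh(x):
    return np.sum(np.tanh(x) ** 2)


grad_fn = jax.grad(sum_of_squared_tanh)
jit_grad_fn = jax.jit(grad_fn)

x = np.linspace(-1.0, 1.0, 1000)

# The first call traces and compiles the function; keep it out of the timing
jit_grad_fn(x).block_until_ready()


def seconds_per_call(fn, arg, repeats=100):
    start = time.perf_counter()
    for _ in range(repeats):
        # jax dispatches work asynchronously, so wait for the result
        fn(arg).block_until_ready()
    return (time.perf_counter() - start) / repeats


print('plain grad:', seconds_per_call(grad_fn, x))
print('jit grad:  ', seconds_per_call(jit_grad_fn, x))
```

Both versions return the same gradient; only the time per call differs.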

In [30]:
# retrain:

params = initial_params()

for n in itertools.count():
    x = inputs[onp.random.choice(inputs.shape[0])]
    y = onp.bitwise_xor(*x)
    grads = loss_grad(params, x, y)
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


## vectorising

right now our network takes in one input at a time. That's not good, we want to feed it a whole batch. Well, jax can vectorise things really easily with vmap

In [32]:
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))

in_axes=(None, 0, 0), out_axes=0

this specifies which axes to vectorise over. We are not mapping over the first argument (which for us is params), but we are mapping over the 0th dimension of the 2nd and 3rd (x and y)

out_axes describes the function's output: we are saying that the batch dimension goes on axis 0 of the loss gradients (the sole output)
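
A tiny sketch makes the axes easier to see. The functions `scaled_sum` and `toy_loss` and the data are made up for illustration.

```python
import jax
import jax.numpy as np


def scaled_sum(scale, row):
    return scale * np.sum(row)


rows = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# scale is shared (None); rows is split along axis 0, one row per call
batched = jax.vmap(scaled_sum, in_axes=(None, 0), out_axes=0)
out = batched(2.0, rows)
print(out)  # [ 6. 14. 22.], one result per row


def toy_loss(w, row):
    return np.sum(w * row) ** 2


# Same idea with a gradient: one gradient per row, stacked along axis 0
w = np.array([1.0, 1.0])
per_example_grads = jax.vmap(jax.grad(toy_loss), in_axes=(None, 0))(w, rows)
print(per_example_grads.shape)  # (3, 2)
```

Each gradient has the shape of `w`, and vmap stacks them along a new leading axis of size 3 (the batch size).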

In [35]:
params = initial_params()

batch_size = 100

for n in itertools.count():
    # Generate a batch of inputs
    # see, we are now using an entire batch!
    x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    # The call to loss_grad remains the same!
    grads = loss_grad(params, x, y)
    # Note that we now need to average gradients over the batch
    params = [param - learning_rate * np.mean(grad, axis=0)
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
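
The batched loop averages the per-example gradients over axis 0. As a sanity check, here is a sketch under toy assumptions (a squared-error loss and made-up data, not the notebook's network) showing that this average equals the gradient of the mean loss over the batch.

```python
import jax
import jax.numpy as np


def toy_loss(w, x, y):
    return (np.dot(w, x) - y) ** 2


w = np.array([0.5, -0.25])
xs = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
ys = np.array([1.0, 1.0, 0.0])

# Route 1: one gradient per example, then average over the batch axis
per_example = jax.vmap(jax.grad(toy_loss), in_axes=(None, 0, 0))(w, xs, ys)
mean_of_grads = np.mean(per_example, axis=0)


# Route 2: average the loss over the batch, then take a single gradient
def batch_loss(w, xs, ys):
    return np.mean(jax.vmap(toy_loss, in_axes=(None, 0, 0))(w, xs, ys))


grad_of_mean = jax.grad(batch_loss)(w, xs, ys)

print(mean_of_grads, grad_of_mean)
```

The two routes agree because differentiation is linear, so either one gives the same gradient-descent step.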


## conclusions

it seems you get vectorising and jit-ing for "free". However, I'm not sure what the use cases are over TF or PyTorch; maybe if you want easy control over a GPU or TPU? I would be interested in a speed test between this and other libraries, but maybe that's not the point if we aren't running on a GPU/TPU. I wonder how the updated Autograd compares to older versions? Anyway, it's got a nice numpy-y interface.