# JAX

https://jax.readthedocs.io/en/latest/index.html

JAX is a newer machine learning and general accelerated computation framework, built by Google on top of XLA. JAX takes many of the nicest features of TensorFlow and PyTorch and makes them easy to use:
- `numpy` interface, including a lot of `scipy` functions
- JIT Compilation with XLA for great performance.
- Automatic differentiation
- Easy GPU / Other device compilation

On top of that, JAX extends into new directions that are really powerful:
- Automatic vectorization
- Function composition, including for vectorization/gradient calculation
- Forward and Reverse mode AD
    - The JAX autodiff cookbook is one of the best references on the web for learning how AD actually works and for figuring out how to do new things: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
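
As a minimal sketch of how these transformations compose: the derivative of a scalar function can be vectorized and compiled in one line, and forward- and reverse-mode Jacobians can be compared directly. The functions `f` and `g` are made up for illustration, and `jnp` is just a local alias for `jax.numpy`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar function, chosen only for illustration.
    return jnp.sin(x) * x**2

# grad: derivative for one scalar; vmap: map over a batch; jit: compile.
batched_df = jax.jit(jax.vmap(jax.grad(f)))
xs = jnp.linspace(0.0, 1.0, 5)
derivs = batched_df(xs)

def g(x):
    # Hypothetical vector-valued function of a 2-vector.
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x0 = jnp.array([1.0, 2.0])
jac_fwd = jax.jacfwd(g)(x0)  # Jacobian via forward-mode AD
jac_rev = jax.jacrev(g)(x0)  # Jacobian via reverse-mode AD
```

Both Jacobians should agree up to floating-point error. Forward mode tends to suit functions with few inputs, and reverse mode suits functions with few outputs, such as a scalar loss.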

JAX has some key differences from TF/Torch/Numpy, too:
- Pure functional.
   - In particular, random numbers work differently! (https://jax.readthedocs.io/en/latest/jax.random.html)
- Tracing and static variables can cause you unexpected issues if you aren't paying attention.
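
To illustrate the point about random numbers, here is a small sketch of explicit key handling. The variable names are illustrative only. There is no hidden global random state: the same key always produces the same numbers, so a key is split before each use.

```python
import jax

demo_key = jax.random.PRNGKey(0)

# Split the key before each use; a consumed subkey should not be reused.
demo_key, subkey = jax.random.split(demo_key)
a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))  # same key: identical draws

demo_key, subkey = jax.random.split(demo_key)
c = jax.random.normal(subkey, (3,))  # fresh key: different draws
```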

Read more about the "Sharp Bits" here: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
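
A short sketch of the tracing issue mentioned above, using hypothetical helper functions. Python control flow on a traced value fails under `jit`. Two common workarounds are to express the branch with array operations, or to mark an argument as static.

```python
from functools import partial

import jax
import jax.numpy as jnp

def absolute(x):
    # A Python `if` needs a concrete bool, but under jit `x` is a tracer.
    if x > 0:
        return x
    return -x

try:
    jax.jit(absolute)(1.0)
    tracing_failed = False
except TypeError:
    # JAX raises a TypeError subclass for control flow on traced values.
    tracing_failed = True

@jax.jit
def absolute_where(x):
    # Expressing the branch with array operations keeps it traceable.
    return jnp.where(x > 0, x, -x)

@partial(jax.jit, static_argnums=1)
def double_n_times(x, n):
    # `n` is static: the loop unrolls at trace time, and every new value
    # of `n` triggers a recompilation.
    for _ in range(n):
        x = x * 2.0
    return x
```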



In [1]:
# Let's get cifar10 again:
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
del tf



In [2]:
import jax.numpy as numpy
import jax.random as random

from jax import tree_util
from jax import jit, grad, vmap

JAX uses `DeviceArrays` as portable data types:

In [3]:
x_train = numpy.asarray(x_train)
y_train = numpy.asarray(y_train)

In [4]:
x_train.device()

GpuDevice(id=0, process_index=0)

Device arrays are *immutable*, so updates in place are impossible:

In [5]:
# This is meant to throw an error!
x_train[:,:,:,0] = 5

TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [6]:
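# Works: .at[] returns a new array and leaves x_train untouched (note that this adds 5 rather than assigning 5):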
x_train_changed = x_train.at[:,:,:,0].add(5)
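
A minimal sketch of the functional update pattern on a small, made-up array. It shows that `.at[...].set` and `.at[...].add` return new arrays and leave the original unchanged:

```python
import jax.numpy as jnp

original = jnp.zeros((2, 3))

assigned = original.at[:, 0].set(5.0)      # functional form of x[:, 0] = 5
incremented = assigned.at[:, 0].add(1.0)   # functional form of x[:, 0] += 1

# `original` still holds zeros; each update produced a new array.
```

Inside a `jit`-compiled function, XLA can often turn these updates into in-place operations, so the functional style does not necessarily mean an extra copy.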

## Neural networks in Flax

Neural networks in Flax are great, especially if you need any sort of unusual gradient computation. Here, I'll include a calculation of the gradients _per example_, which is not trivial to do in the other frameworks.

In [7]:
import optax
import flax.linen as nn
# Flax is just one package for neural networks in JAX.  Optax includes lots of optimizers and other utilities.

In [8]:

class ResidualBlock(nn.Module):
    # This isn't really a callable class: it's a template for creating a function.
    # It's designed to look just like TensorFlow and PyTorch modules.
    def setup(self):
        
        self.conv1  = nn.Conv(features=16, kernel_size=[3,3], padding="same")
        self.conv2  = nn.Conv(features=16, kernel_size=[3,3], padding="same")

    def __call__(self, x):
        out = self.conv1(x)
        out = nn.relu(out)
        out = self.conv2(out)
        
        x = x + out
        
        return nn.relu(x)



In [9]:
class MyModel(nn.Module):
    
    def setup(self):
        
        self.conv_init = nn.Conv(features=16, kernel_size=[1,1])
        
        self.res1 = ResidualBlock()
        
        self.res2 = ResidualBlock()
        
        # 10 filters for each class:
        self.conv_final = nn.Conv(features=10, kernel_size=[1,1])
        
                
    def __call__(self, inputs):
        
        x = self.conv_init(inputs)
        
        x = self.res1(x)
        
        x = self.res2(x)
        
        x = self.conv_final(x)
        x = nn.avg_pool(x, (32,32))
        return x.reshape((10))

In [10]:
# Model can be initialized as a class, but it's not callable:
model = MyModel()

In [11]:
#This should fail!
model(x_train[0:5000])

AttributeError: "MyModel" object has no attribute "conv_init"

To initialize the network, we need to provide a random number key as well as example data. The data could be random, but it defines the shapes of the weights:

In [12]:
key = random.PRNGKey(0)
print(key)

[0 0]


In [13]:
# Pass in one image as an example, get the network parameters back:
params = model.init(key, x_train[0])

In [14]:
# We can look at the parameters by using the tree_util component of jax:
print(type(params))
tree_util.tree_map(lambda x: x.shape, params)

<class 'flax.core.frozen_dict.FrozenDict'>


FrozenDict({
    params: {
        conv_final: {
            bias: (10,),
            kernel: (1, 1, 16, 10),
        },
        conv_init: {
            bias: (16,),
            kernel: (1, 1, 3, 16),
        },
        res1: {
            conv1: {
                bias: (16,),
                kernel: (3, 3, 16, 16),
            },
            conv2: {
                bias: (16,),
                kernel: (3, 3, 16, 16),
            },
        },
        res2: {
            conv1: {
                bias: (16,),
                kernel: (3, 3, 16, 16),
            },
            conv2: {
                bias: (16,),
                kernel: (3, 3, 16, 16),
            },
        },
    },
})
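
The same `tree_util` functions work on any nested container of arrays. As a small sketch, here is a way to inspect shapes and count parameters. The dictionary `toy_params` is a hypothetical stand-in, not the model above.

```python
import jax.numpy as jnp
from jax import tree_util

# Hypothetical nested parameter dictionary, for illustration only.
toy_params = {
    "conv": {"kernel": jnp.zeros((3, 3, 16, 16)), "bias": jnp.zeros((16,))},
    "dense": {"kernel": jnp.zeros((16, 10)), "bias": jnp.zeros((10,))},
}

# tree_map applies a function to every leaf and keeps the nesting.
shapes = tree_util.tree_map(lambda p: p.shape, toy_params)

# tree_leaves flattens the structure into a plain list of arrays.
n_params = sum(p.size for p in tree_util.tree_leaves(toy_params))
```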

In [15]:
# Use the apply function to call the network:
results = model.apply(params, x_train[0])
print(results.shape)

(10,)


To apply to a whole batch, just use `vmap`:

In [16]:
model_fn = vmap(model.apply, in_axes=(None, 0))
results = model_fn(params, x_train[0:5000])

In [17]:
print(results.shape)

(5000, 10)
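
To make the `in_axes=(None, 0)` argument concrete, here is a sketch with a tiny hypothetical single-example function. `None` means "do not map over this argument" (the weights are shared), and `0` means "map over the leading axis" (the batch).

```python
import jax
import jax.numpy as jnp

def predict_single(weights, x):
    # Hypothetical model for ONE example: x has shape (4,), output (3,).
    return jnp.tanh(x @ weights)

toy_weights = jnp.ones((4, 3))
toy_batch = jnp.arange(20.0).reshape(5, 4)

# None: share `weights` across the batch; 0: map over axis 0 of `x`.
predict_batch = jax.vmap(predict_single, in_axes=(None, 0))
batched = predict_batch(toy_weights, toy_batch)

# Reference: an explicit Python loop over the batch.
looped = jnp.stack([predict_single(toy_weights, x) for x in toy_batch])
```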


How is the performance of the model?

In [18]:
%timeit model_fn(params, x_train[0:5000])

69.3 ms ± 732 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


We can make it faster with JIT:

In [19]:
model_fn = jit(model_fn)
%timeit model_fn(params, x_train[0:5000])

254 µs ± 44.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
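
One caveat when reading timings like these: JAX dispatches work asynchronously, so a timing loop can return before the device has finished unless the result is explicitly waited on. A hedged sketch of a safer measurement, using a made-up function `toy_fn` rather than the model above:

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def toy_fn(x):
    # Hypothetical workload, for illustration only.
    return jnp.tanh(x @ x.T).sum()

toy_x = jnp.ones((512, 512))

# Warm-up call: the first call includes tracing and compilation.
toy_fn(toy_x).block_until_ready()

start = time.perf_counter()
result = toy_fn(toy_x).block_until_ready()  # wait for the device to finish
elapsed = time.perf_counter() - start
```

In a notebook, the equivalent is to append `.block_until_ready()` to the expression passed to `%timeit`.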


Writing the function to compute the gradients is similarly easy:

In [20]:
from jax.nn import one_hot

This will work on the entire batch:

In [42]:
@jit 
def compute_loss(params, inputs, labels):
    labels_onehot = one_hot(labels, num_classes=10)
    logits = model_fn(params,inputs)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()


We can get to the same result by writing the loss for a single example and applying `vmap`, too:

In [150]:
@jit # You can also JIT directly
def compute_loss_single(params, inputs, labels):
    labels_onehot = one_hot(labels, num_classes=10)
    # Calling the original function!
    logits = model.apply(params, inputs)
    # Don't apply mean here, and we can get the jacobian:
    return optax.softmax_cross_entropy(logits, labels_onehot).reshape()

grad_fn_single = grad(compute_loss_single)

In [44]:
compute_loss_single(params, x_train[0], y_train[0])

DeviceArray(226.38292, dtype=float32)

In [45]:
compute_loss_2 = jit(vmap(compute_loss_single, in_axes=(None, 0,0)))

In [46]:
compute_loss_2(params, x_train[0:5000], y_train[0:5000] ).mean()

DeviceArray(173.78802, dtype=float32)

Here's the intrinsically vectorized version:

In [47]:
compute_loss(params, x_train[0:5000], y_train[0:5000]) 

DeviceArray(175.85536, dtype=float32)

Which is faster?  

In [48]:
%timeit compute_loss_2(params, x_train[0:5000], y_train[0:5000] ).mean()

9.81 ms ± 906 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [49]:
%timeit compute_loss(params, x_train[0:5000], y_train[0:5000]) 

9.98 ms ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


The two functions come out approximately the same, but both are surprisingly slow.

In [54]:
print(type(x_train))
print(type(y_train))
print(type(params))

<class 'jaxlib.xla_extension.DeviceArray'>
<class 'jaxlib.xla_extension.DeviceArray'>
<class 'flax.core.frozen_dict.FrozenDict'>


## Getting gradients

Computing gradients is as simple as calling `grad` on the loss function.

Note that, like `jit` and `vmap`, `grad` operates on functions, not data. It returns a new function that takes the same inputs as the function it wraps and, by default, differentiates with respect to the first argument:

In [103]:

grad_fn = grad(compute_loss)

In [57]:
gradients = grad_fn(params, x_train[0:5000], y_train[0:5000])

In [58]:
print(gradients)

FrozenDict({
    params: {
        conv_final: {
            bias: DeviceArray([-0.10100005, -0.09199999, -0.10380004, -0.09719993,
                          0.89579934, -0.09760001, -0.10380004, -0.09719993,
                         -0.10359982, -0.09959987], dtype=float32),
            kernel: DeviceArray([[[[-4.75847149e+00, -4.33445406e+00, -4.89039516e+00,
                            -4.57946348e+00,  4.22055359e+01, -4.59829617e+00,
                            -4.89039516e+00, -4.57946348e+00, -4.88200617e+00,
                            -4.69253111e+00],
                           [-2.10035934e+01, -1.91320553e+01, -2.15859604e+01,
                            -2.02134438e+01,  1.86326111e+02, -2.02965679e+01,
                            -2.15859604e+01, -2.02134438e+01, -2.15829220e+01,
                            -2.07125072e+01],
                           [-2.63003986e-02, -2.39567850e-02, -2.70295125e-02,
                            -2.53108162e-02,  2.33239129e-01, -2.54150

In [59]:
%timeit grad_fn(params, x_train[0:5000], y_train[0:5000])

39.8 ms ± 8.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


That timing is comparable to TensorFlow before `tf.function`. As before, we need to JIT here too:

In [60]:
grad_fn = jit(grad_fn)
%timeit grad_fn(params, x_train[0:5000], y_train[0:5000])

540 µs ± 64.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


That timing is again comparable to tensorflow + JIT!
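
Per-example gradients follow from the same composition: take `grad` of a single-example loss, then `vmap` it over the batch. The sketch below uses a small hypothetical linear classifier in plain JAX instead of the Flax model, so every name in it is illustrative.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax.nn import log_softmax, one_hot

def loss_single(p, x, label):
    # Cross-entropy for ONE example of a hypothetical linear classifier.
    logits = x @ p["w"] + p["b"]
    return -(one_hot(label, 3) * log_softmax(logits)).sum()

# grad differentiates with respect to `p`; vmap maps over examples only.
per_example_grad_fn = jit(vmap(grad(loss_single), in_axes=(None, 0, 0)))

toy_p = {"w": jax.random.normal(jax.random.PRNGKey(0), (4, 3)),
         "b": jnp.zeros((3,))}
toy_x = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
toy_labels = jnp.arange(8) % 3

# Every leaf gains a leading batch axis of size 8.
per_example = per_example_grad_fn(toy_p, toy_x, toy_labels)

def loss_batch(p, xs, labels):
    return vmap(loss_single, in_axes=(None, 0, 0))(p, xs, labels).mean()

# Averaging per-example gradients should match the mean-loss gradient.
batch_grads = grad(loss_batch)(toy_p, toy_x, toy_labels)
```

With the Flax model, the analogous call would presumably be `vmap(grad(compute_loss_single), in_axes=(None, 0, 0))`, at the cost of storing one full set of gradients per example.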


## Reduced Precision in JAX

Using reduced precision in JAX is easy. Since everything is functional, you can typically cast your inputs to a lower precision and it will work:

In [110]:
x_train_reduced = x_train[0:5000].astype("bfloat16")
y_reduced = y_train[0:5000].astype("bfloat16")
params_reduced = tree_util.tree_map(lambda x : x.astype("bfloat16"), params)

In [111]:
compute_loss(params_reduced, x_train_reduced, y_reduced)

DeviceArray(176.05377, dtype=float32)

In [112]:
%timeit grad_fn(params_reduced, x_train_reduced, y_reduced)

41 ms ± 2.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
