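Why learn JAX? It can be extremely fast, especially on accelerators.

How can JAX be faster than PyTorch? JAX traces a function once and compiles it just in time (JIT) with XLA, whereas PyTorch by default builds its computation graph dynamically and executes operations one at a time.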

1. JAX functions should not have side effects, i.e. they must not modify any variable outside their own scope.
2. JAX compiles functions based on the anticipated shapes of all arrays/tensors in the function. This is problematic when shapes or program flow within the function depend on the values of the tensors (e.g. `y = x[x >= 2]`, where the shape of `y` is determined by how many values in `x` are greater than or equal to 2).
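The shape constraint in point 2 can be sketched as follows. This is a minimal illustration, not a recipe: the names `select_large` and `sum_large` are made up for this example, and replacing the mask with `jnp.where` is just one possible shape-stable workaround.

```python
import jax
import jax.numpy as jnp

def select_large(x):
    # The output shape depends on the *values* in x, not only on its shape.
    return x[x >= 2]

x = jnp.arange(5)
eager_result = select_large(x)  # works when run eagerly (no jit)

try:
    jax.jit(select_large)(x)
    jit_failed = False
except Exception as err:  # under jit the boolean mask is an abstract tracer
    jit_failed = True
    print("jit failed with:", type(err).__name__)

@jax.jit
def sum_large(x):
    # Shape-stable alternative: keep the full shape and zero out rejected entries.
    return jnp.where(x >= 2, x, 0).sum()

print(eager_result, sum_large(x))
```

The eager call works because concrete values are available, while the jitted version only knows shapes and dtypes at trace time.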

More resources:
1. JAX 101
2. JAX - The Sharp Bits
3. JAX for the Impatient
4. Flax Basics

In [6]:
# import standard libraries
import os
import math
import numpy as np  
import time

# imports for plotting.
import matplotlib.pyplot as plt 
%matplotlib inline
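# set_matplotlib_formats is deprecated in IPython.display; use matplotlib_inline instead
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # for exports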
from matplotlib.colors import to_rgba
import seaborn as sns
sns.set()

## progress bar
from tqdm.auto import tqdm



## JAX as NumPy on accelerators

In [7]:
import jax
import jax.numpy as jnp
print('Using jax version', jax.__version__)

Using jax version 0.4.35


In [8]:
a = jnp.zeros((2, 3), dtype= jnp.float32)
print(a)

[[0. 0. 0.]
 [0. 0. 0.]]


In [9]:
b = jnp.arange(6)
print(b)

[0 1 2 3 4 5]


In [10]:
b.__class__

jaxlib.xla_extension.ArrayImpl

Unlike NumPy arrays, JAX arrays can run the same code on different backends such as CPUs, GPUs, or TPUs. A JAX array therefore represents an array that lives on a device of one of these backends. We can check the device of an array using:

In [11]:
b.device

CpuDevice(id=0)

If a GPU were available, the array would have been placed on it by default. To transfer the array to host (CPU) memory as a NumPy array, we can use:

In [12]:
b_cpu = jax.device_get(b) 
print(b_cpu.__class__)

<class 'numpy.ndarray'>


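To explicitly push a NumPy array to an accelerator, use `jax.device_put()`. Without a device argument it places the array on the default device (here the CPU, since no GPU is available).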

In [13]:
b_gpu = jax.device_put(b_cpu)
print(b_gpu.__class__)

<class 'jaxlib.xla_extension.ArrayImpl'>
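To choose the target device explicitly, a device object can be passed as the second argument of `jax.device_put()`. The sketch below simply picks the first entry of `jax.devices()`; the names `host_array`, `target`, and `on_device` are illustrative only.

```python
import numpy as np
import jax

host_array = np.arange(6)

# Pick a device explicitly; on a machine with a GPU, jax.devices() lists it first.
target = jax.devices()[0]
on_device = jax.device_put(host_array, target)

# Array.devices() returns the set of devices that hold the array.
print(on_device.devices())
```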


JAX handles the mismatch itself: when a NumPy array and a JAX array are combined, the NumPy array is converted automatically.

In [14]:
b_cpu + b_gpu 

Array([ 0,  2,  4,  6,  8, 10], dtype=int32)

Print all the available devices

In [15]:
jax.devices()

[CpuDevice(id=0)]

## Immutable tensors

NumPy arrays are mutable, meaning the values of array elements can be changed in place, for example `b[0] = 2`. JAX arrays are immutable: no in-place operations are possible (JAX requires programs to be pure functions). Instead we use `b.at[0].set(1)`, which returns a new array and leaves the original unchanged.

In [16]:
b_new = b.at[0].set(1)
print('original array', b)
print('new array', b_new)

original array [0 1 2 3 4 5]
new array [1 1 2 3 4 5]
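Besides `set`, the `.at[...]` property offers other out-of-place updates such as `add`, `multiply`, and `min`. A short sketch, with variable names chosen only for illustration:

```python
import jax.numpy as jnp

b = jnp.arange(6)

# Direct item assignment is rejected because JAX arrays are immutable.
try:
    b[0] = 2
    assignment_failed = False
except TypeError:
    assignment_failed = True

added = b.at[0].add(10)          # new array with b[0] + 10
scaled = b.at[1:3].multiply(2)   # new array with the slice b[1:3] doubled
clipped = b.at[-1].min(3)        # new array with the last element min(b[-1], 3)

print(assignment_failed, added, scaled, clipped, b)
```

Each call returns a fresh array, so `b` still holds its original values afterwards.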


## Pseudo Random numbers in Jax

NumPy's random functions silently update a global generator state, which is a side effect. To avoid this, JAX requires us to pass the state (a PRNG key) explicitly to every function that generates random numbers.

In [17]:
rng = jax.random.PRNGKey(42)

In [18]:
## not an ideal way of generating random numbers
jax_random_number_1 = jax.random.normal(rng)
jax_random_number_2 = jax.random.normal(rng)
print('jax_random_number_1', jax_random_number_1)
print('jax_random_number_2', jax_random_number_2)


jax_random_number_1 -0.18471177
jax_random_number_2 -0.18471177


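Both calls produced the same number because they received the same key: a JAX random function is deterministic given its key. New keys can always be obtained by splitting an old one. For comparison, NumPy behaves as follows: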

In [19]:
np.random.seed(42) # set the seed
np_random_number_1 = np.random.normal()
np_random_number_2 = np.random.normal()

print('numpy_random_number_1: ',np_random_number_1) 
print('numpy_random_number_2: ',np_random_number_2) 



numpy_random_number_1:  0.4967141530112327
numpy_random_number_2:  -0.13826430117118466


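The two numbers differ because the hidden state of NumPy's random number generator changes after every call. JAX does not allow this, since updating hidden global state is a side effect. Instead, we split the key explicitly: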

In [20]:
rng, subkey1, subkey2 = jax.random.split(rng, num= 3)
jax_random_number_1 = jax.random.normal(subkey1)
jax_random_number_2 = jax.random.normal(subkey2)
print("jax_random_number_1 ", jax_random_number_1)
print("jax_random_number_2 ", jax_random_number_2)

jax_random_number_1  0.107961535
jax_random_number_2  -1.2226542
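A common pattern that follows from this is to split the key once per draw and carry the fresh key forward. The helper `draw_samples` below is a hypothetical name used only for this sketch; it also shows that the whole sequence is reproducible from the seed.

```python
import jax

def draw_samples(seed, n):
    """Draw n standard-normal samples, splitting the key before every draw."""
    key = jax.random.PRNGKey(seed)
    samples = []
    for _ in range(n):
        # Keep `key` for future splits and consume `subkey` exactly once.
        key, subkey = jax.random.split(key)
        samples.append(float(jax.random.normal(subkey)))
    return samples

first_run = draw_samples(42, 3)
second_run = draw_samples(42, 3)
print(first_run)
print(first_run == second_run)
```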


## Function transformations with Jaxpr

1. Always write the code as pure functions, i.e. without side effects.
2. Jaxpr: JAX first converts a given function into a small, well-behaved intermediate representation called a jaxpr. In it we can check which operations are performed on which arrays and what shapes the arrays have.
3. Based on the jaxpr, JAX then interprets the function with transformation-specific interpretation rules, for example automatic differentiation or compilation with XLA to use the accelerator efficiently.

Let's use the function below to understand the jaxpr transformation, where $|x|$ is the number of elements in $x$:
$$y = \frac{1}{|x|} \sum_i \left[(x_i + 2)^2 + 3\right]$$

In [21]:
def simple_graph(x):
    x = x + 2
    x = x**2
    x = x + 3
    y = x.mean()
    return y

inp = jnp.arange(3, dtype= jnp.float32)
print("input: ", inp)
print("output: ", simple_graph(inp))

input:  [0. 1. 2.]
output:  12.666667


In [22]:
jax.make_jaxpr(simple_graph)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = add c 3.0
    e:f32[] = reduce_sum[axes=(0,)] d
    f:f32[] = div e 3.0
  in (f,) }

Let's look at the jaxpr representation of a function with side effects:

In [23]:
global_list = []

def norm(x):
    global_list.append(x) # making changes to a global variable --> not a pure function.
    x = x**2
    n = x.sum()
    n = jnp.sqrt(n)
    return n

jax.make_jaxpr(norm)(inp)


{ lambda ; a:f32[3]. let
    b:f32[3] = integer_pow[y=2] a
    c:f32[] = reduce_sum[axes=(0,)] b
    d:f32[] = sqrt c
  in (d,) }

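The jaxpr does not contain the operation with side effects! The `append` runs once while JAX traces the function, but it is not recorded in the jaxpr, so it will not run again when the transformed function is reused.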
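One way to observe this behaviour is with `jax.jit`. In the sketch below (the names `trace_log` and `norm_with_side_effect` are made up for illustration), the Python side effect only runs when the function is traced, which happens once per new input shape rather than once per call.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def norm_with_side_effect(x):
    trace_log.append("traced")  # Python side effect: runs only during tracing
    return jnp.sqrt((x ** 2).sum())

x3 = jnp.arange(3, dtype=jnp.float32)
norm_with_side_effect(x3)
norm_with_side_effect(x3)  # same shape and dtype: the compiled version is reused
entries_after_same_shape = len(trace_log)

x4 = jnp.arange(4, dtype=jnp.float32)
norm_with_side_effect(x4)  # new shape: the function is traced again
entries_after_new_shape = len(trace_log)

print(entries_after_same_shape, entries_after_new_shape)
```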

## Automatic differentiation

JAX takes a function and returns another function that computes its gradient, by default with respect to the first argument.

In [24]:
# jax.grad
grad_function = jax.grad(simple_graph)
gradients = grad_function(inp)
print('Gradients', gradients)

Gradients [1.3333334 2.        2.6666667]
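These values can be checked by hand. Differentiating $y = \frac{1}{|x|} \sum_i \left[(x_i + 2)^2 + 3\right]$ with respect to one element gives $\partial y / \partial x_i = 2(x_i + 2)/|x|$, which for the input `[0, 1, 2]` is `[4/3, 2, 8/3]`. A small sketch comparing the two (the function is redefined in compact form so the block runs on its own):

```python
import jax
import jax.numpy as jnp

def simple_graph(x):
    # Same computation as above, written in one line.
    return ((x + 2) ** 2 + 3).mean()

inp = jnp.arange(3, dtype=jnp.float32)

autodiff_grad = jax.grad(simple_graph)(inp)
manual_grad = 2 * (inp + 2) / inp.size  # hand-derived gradient

print(autodiff_grad, manual_grad)
```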


In [25]:
jax.make_jaxpr(grad_function)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = integer_pow[y=1] b
    e:f32[3] = mul 2.0 d
    f:f32[3] = add c 3.0
    g:f32[] = reduce_sum[axes=(0,)] f
    _:f32[] = div g 3.0
    h:f32[] = div 1.0 3.0
    i:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] h
    j:f32[3] = mul i e
  in (j,) }

In [26]:
# both value and gradient of a function for an input.
val_grad_funct = jax.value_and_grad(simple_graph)
val_grad_funct(inp)

(Array(12.666667, dtype=float32),
 Array([1.3333334, 2.       , 2.6666667], dtype=float32))

# Pytree

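A pytree is a nested structure of containers (such as lists, tuples, and dicts) whose leaves are arrays or other objects. Pytrees let JAX functions operate on nested structures such as model parameters.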

In [28]:
# example:
import jax
import jax.numpy as jnp

example_tree = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1':2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]) # itself is a leaf
]

# print how many leaves the pytrees have
for pytree in example_tree:
    # this 'jax.tree.leaves()' method extracts the flattened leaves from pytree
    leaves = jax.tree.leaves(pytree)
    print(f'{repr(pytree): <45} has length {len(leaves)} leaves: {leaves}')


[1, 'a', <object object at 0x117e4ba30>]      has length 3 leaves: [1, 'a', <object object at 0x117e4ba30>]
(1, (2, 3), ())                               has length 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has length 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has length 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has length 1 leaves: [Array([1, 2, 3], dtype=int32)]


1. Lists, tuples, and dicts are treated as pytree nodes (containers).
2. Any object whose type is not in the pytree registry is treated as a leaf.
3. The pytree registry can be extended to user-defined container classes.
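Point 3 can be sketched with `jax.tree_util.register_pytree_node`. The class `LayerParams` and its helper functions are hypothetical and exist only for this example: the flatten function returns the children plus static auxiliary data, and the unflatten function rebuilds the object from them.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class LayerParams:
    """Hypothetical user-defined container holding the parameters of one layer."""
    def __init__(self, weights, biases, name):
        self.weights = weights
        self.biases = biases
        self.name = name

def flatten_layer(layer):
    # children (traversed as pytrees) and auxiliary data (kept as static metadata)
    return (layer.weights, layer.biases), layer.name

def unflatten_layer(name, children):
    weights, biases = children
    return LayerParams(weights, biases, name)

tree_util.register_pytree_node(LayerParams, flatten_layer, unflatten_layer)

layer = LayerParams(jnp.ones((2, 3)), jnp.zeros(3), "dense_0")
leaves = jax.tree.leaves(layer)                # the two arrays; the name is not a leaf
doubled = jax.tree.map(lambda x: 2 * x, layer)  # returns a new LayerParams
print(len(leaves), doubled.name, doubled.weights)
```

Without the registration call, the whole `LayerParams` object would be treated as a single leaf.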


# Common Pytree functions.

1. jax.tree.map: applies a function to every leaf of a pytree and returns a new pytree with the same structure.

In [29]:
list_of_lists = [
    [1, 2, 3],
    [1,2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

`jax.tree.map` also allows mapping an N-ary function over multiple pytrees.

In [30]:
another_list_of_lists  = list_of_lists
jax.tree.map(lambda x, y: x + y, list_of_lists, another_list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

When passing multiple pytrees to `jax.tree.map`, the structures of the inputs must match.
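A quick sketch of what happens when the structures differ. The trees `tree_a` and `tree_b` are made-up inputs; in the JAX versions I am aware of the mismatch is reported as a `ValueError`.

```python
import jax

tree_a = [1, [2, 3]]
tree_b = [1, [2]]  # the inner list has a different length

try:
    jax.tree.map(lambda x, y: x + y, tree_a, tree_b)
    mismatch_detected = False
except ValueError as err:
    mismatch_detected = True
    print("structure mismatch:", err)

# With matching structures the leaves are combined pairwise.
summed = jax.tree.map(lambda x, y: x + y, tree_a, tree_a)
print(summed)
```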

## jax.tree.map with ML model parameters.

Pytrees are useful when training an MLP: the parameters of all layers can be stored in a single nested structure.


In [31]:
import numpy as np
def init_mlp_params(layer_widths):
    params = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        params.append(
            dict(
                weights = np.random.normal(size= (n_in, n_out))*np.sqrt(2/n_in),
                biases = np.ones(shape= (n_out,))
            )
        )

    return params

params = init_mlp_params([1, 128, 128, 1])


In [32]:
# use jax.tree.map to check the shapes of the initial parameters

jax.tree.map(lambda x: x.shape, params)

[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]
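The same idea extends to whole-model operations. The sketch below counts parameters with `jax.tree.leaves` and applies a gradient-descent style update with `jax.tree.map`. Everything in it is illustrative: `toy_params` is a small stand-in with the same layout as `params`, `toy_grads` holds placeholder values rather than real gradients (which would normally come from `jax.grad`), and `learning_rate` is an arbitrary choice.

```python
import numpy as np
import jax

# Hypothetical small parameter pytree with the same layout as `params` above.
toy_params = [
    dict(weights=np.ones((1, 4)), biases=np.ones((4,))),
    dict(weights=np.ones((4, 1)), biases=np.ones((1,))),
]

# Placeholder gradients (all 0.5) with the same structure as the parameters.
toy_grads = jax.tree.map(lambda p: 0.5 * np.ones_like(p), toy_params)

# Total number of scalar parameters across all leaves.
n_params = sum(leaf.size for leaf in jax.tree.leaves(toy_params))

# One update step applied leaf by leaf: p <- p - learning_rate * g
learning_rate = 0.1
updated = jax.tree.map(lambda p, g: p - learning_rate * g, toy_params, toy_grads)

print(n_params, updated[0]["biases"])
```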

In [None]:
# define a function for training the mlp model:
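# define forward pass
def forward(params, x):
    # all layers except the last are hidden layers with a ReLU activation
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
    # the last layer is linear; return only after all hidden layers have been applied
    return x @ last['weights'] + last['biases']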

# define the loss function:
