

<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="300" height="300" align="center"/>

I hope you have been enjoying the tutorials so far. This is the fourth tutorial in the series, and today we start with **JAX**. If you haven't looked at the previous tutorials, I highly suggest going through them first. Here are the links:

1. [TF_JAX_Tutorials - Part 1](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part1)
2. [TF_JAX_Tutorials - Part 2](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part2)
3. [TF_JAX_Tutorials - Part 3](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3)

# What is JAX anyway?

**JAX** is a framework that is particularly well suited to Machine Learning research (more on this later). A few points about JAX:
1. Its API closely mirrors `numpy`, but it uses a compiler (XLA) to compile Numpy-style code and run it on accelerators (GPU/TPU)
2. For automatic differentiation, JAX builds on an updated version of `Autograd`. It can automatically differentiate native Python and Numpy code.
3. JAX is used to express numerical programs as compositions of transformations, but with certain constraints: JAX transformation and compilation are designed to work only on Python functions that are functionally pure. A function is pure if it always returns the same value when invoked with the same arguments and has no side effects, e.g. changing the state of a non-local variable
4. In terms of syntax, JAX is very similar to numpy, but there are subtle differences that you should be aware of. We will look into them in a bit.
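
To make the purity constraint in point 3 concrete, here is a minimal sketch, assuming `jax.jit` as the compilation transformation. The names `scale`, `impure_scale`, and `pure_scale` are hypothetical and exist only for illustration. The impure version reads a global variable, so the compiled function keeps using the value it saw when it was first traced.

```python
import jax
import jax.numpy as jnp

scale = 2.0  # non-local state (hypothetical, for illustration only)

def impure_scale(x):
    # Impure: the result depends on a global, not just on the arguments
    return x * scale

def pure_scale(x, scale):
    # Pure: everything the function needs is passed in explicitly
    return x * scale

x = jnp.arange(3.0)

jitted_impure = jax.jit(impure_scale)
first = jitted_impure(x)    # traced and compiled while scale == 2.0

scale = 10.0
second = jitted_impure(x)   # cached compilation is reused; the new global is ignored

third = jax.jit(pure_scale)(x, 10.0)  # the pure version sees the new value

print("Impure, before changing the global:", first)
print("Impure, after changing the global: ", second)
print("Pure, with scale passed explicitly:", third)
```

Passing state in as an argument, as `pure_scale` does, is the usual way to keep a function safe for JAX transformations.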

Let's take a few examples to see JAX in action!

In [1]:
import time

import jax
import numpy as np
import jax.numpy as jnp
from jax import random

%config IPCompleter.use_jedi = False

In [2]:
# We will create two arrays, one with numpy and other with jax
# to check the common things and the differences

array_numpy = np.arange(10, dtype=np.int32)
array_jax = jnp.arange(10, dtype=jnp.int32)

print("Array created using numpy: ", array_numpy)
print("Array created using JAX: ", array_jax)

Array created using numpy:  [0 1 2 3 4 5 6 7 8 9]
Array created using JAX:  [0 1 2 3 4 5 6 7 8 9]


In [3]:
# What types of array are these?
print(f"array_numpy is of type : {type(array_numpy)}")
print(f"array_jax is of type : {type(array_jax)}")

array_numpy is of type : <class 'numpy.ndarray'>
array_jax is of type : <class 'jaxlib.xla_extension.DeviceArray'>


So, `array_numpy` is an object of **`ndarray`** while `array_jax` is an object of **`DeviceArray`**. Before discussing anything else, let's dive into **`DeviceArray`** and see what makes **DeviceArray** so special.


# DeviceArray

Following are the points that you should know about **`DeviceArray`**:
1. It is the core underlying JAX array object, similar to `ndarray` but with subtle differences (more on this in the examples below)
2. Unlike `ndarray`, `DeviceArray` is backed by a memory buffer on a single device (CPU/GPU/TPU)
3. It is **device-agnostic** i.e. JAX doesn't need to track the device on which the array is present, and can avoid data transfers
4. Because it is device agnostic, this makes it easy to run the same JAX code on CPU, GPU, or TPU with no code changes
5. `DeviceArray` is **lazy** i.e. the value of a JAX `DeviceArray` isn't immediately available and is only pulled when requested.
6. Even though `DeviceArray` is lazy, you can still inspect the shape or type of a `DeviceArray` without waiting for the computation that produced it to complete. You can even pass it to another JAX computation. (The examples will make this clearer.)
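
The points above can be sketched in a few lines. This is only an illustration under assumptions: the array sizes are arbitrary, and the sketch deliberately avoids depending on the class name, since newer JAX releases report the array type as `jax.Array` rather than `DeviceArray`.

```python
import jax
import jax.numpy as jnp

# Devices visible to JAX on this machine (CPU, GPU, or TPU)
devices = jax.devices()
print("Available devices:", devices)

# The same code runs unchanged on whichever backend is available
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x)  # the work is dispatched to the device

# Metadata such as shape and dtype can be inspected right away
print("Shape:", y.shape, "dtype:", y.dtype)

# The result can be fed straight into another JAX computation
z = y + 1

# Explicit placement is possible too, although it is rarely needed
x_on_first_device = jax.device_put(x, devices[0])
```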

The two properties, **lazy evaluation** and being **device-agnostic**, give **`DeviceArray`** a huge advantage. You will see this in future tutorials as we dive deeper into more complex things like model building, optimization, etc.

# Numpy vs JAX-numpy

`jax.numpy` is very similar to `numpy` in terms of API. *Most of the operations* that you do in numpy are also available in `jax.numpy` with similar semantics. I am listing just a few operations to showcase this, but there are many more. Please check the [docs](https://jax.readthedocs.io/en/latest/jax.numpy.html) to see the list of functions that are available.

**Note:** Not all Numpy functions are implemented in JAX numpy (yet)

In [4]:
# Find the max element. Similarly you can find `min` as well
print(f"Maximum element in ndarray: {array_numpy.max()}")
print(f"Maximum element in DeviceArray: {array_jax.max()}")

Maximum element in ndarray: 9
Maximum element in DeviceArray: 9


In [5]:
# Reshaping
print("Original shape of ndarray: ", array_numpy.shape)
print("Original shape of DeviceArray: ", array_jax.shape)

array_numpy = array_numpy.reshape(-1, 1)
array_jax = array_jax.reshape(-1, 1)

print("\nNew shape of ndarray: ", array_numpy.shape)
print("New shape of DeviceArray: ", array_jax.shape)

Original shape of ndarray:  (10,)
Original shape of DeviceArray:  (10,)

New shape of ndarray:  (10, 1)
New shape of DeviceArray:  (10, 1)


In [6]:
# Absolute pairwise difference
print("Absolute pairwise difference in ndarray")
print(np.abs(array_numpy - array_numpy.T))

print("\nAbsolute pairwise difference in DeviceArray")
print(jnp.abs(array_jax - array_jax.T))

# Are they equal?
print("\nAre all the values same?", end=" ")
print(jnp.all(np.abs(array_numpy - array_numpy.T) == jnp.abs(array_jax - array_jax.T)))

Absolute pairwise difference in ndarray
[[0 1 2 3 4 5 6 7 8 9]
 [1 0 1 2 3 4 5 6 7 8]
 [2 1 0 1 2 3 4 5 6 7]
 [3 2 1 0 1 2 3 4 5 6]
 [4 3 2 1 0 1 2 3 4 5]
 [5 4 3 2 1 0 1 2 3 4]
 [6 5 4 3 2 1 0 1 2 3]
 [7 6 5 4 3 2 1 0 1 2]
 [8 7 6 5 4 3 2 1 0 1]
 [9 8 7 6 5 4 3 2 1 0]]

Absolute pairwise difference in DeviceArray
[[0 1 2 3 4 5 6 7 8 9]
 [1 0 1 2 3 4 5 6 7 8]
 [2 1 0 1 2 3 4 5 6 7]
 [3 2 1 0 1 2 3 4 5 6]
 [4 3 2 1 0 1 2 3 4 5]
 [5 4 3 2 1 0 1 2 3 4]
 [6 5 4 3 2 1 0 1 2 3]
 [7 6 5 4 3 2 1 0 1 2]
 [8 7 6 5 4 3 2 1 0 1]
 [9 8 7 6 5 4 3 2 1 0]]

Are all the values same? True


In [7]:
# Matrix multiplication
print("Matrix multiplication of ndarray")
print(np.dot(array_numpy, array_numpy.T))

print("\nMatrix multiplication of DeviceArray")
print(jnp.dot(array_jax, array_jax.T))

Matrix multiplication of ndarray
[[ 0  0  0  0  0  0  0  0  0  0]
 [ 0  1  2  3  4  5  6  7  8  9]
 [ 0  2  4  6  8 10 12 14 16 18]
 [ 0  3  6  9 12 15 18 21 24 27]
 [ 0  4  8 12 16 20 24 28 32 36]
 [ 0  5 10 15 20 25 30 35 40 45]
 [ 0  6 12 18 24 30 36 42 48 54]
 [ 0  7 14 21 28 35 42 49 56 63]
 [ 0  8 16 24 32 40 48 56 64 72]
 [ 0  9 18 27 36 45 54 63 72 81]]

Matrix multiplication of DeviceArray
[[ 0  0  0  0  0  0  0  0  0  0]
 [ 0  1  2  3  4  5  6  7  8  9]
 [ 0  2  4  6  8 10 12 14 16 18]
 [ 0  3  6  9 12 15 18 21 24 27]
 [ 0  4  8 12 16 20 24 28 32 36]
 [ 0  5 10 15 20 25 30 35 40 45]
 [ 0  6 12 18 24 30 36 42 48 54]
 [ 0  7 14 21 28 35 42 49 56 63]
 [ 0  8 16 24 32 40 48 56 64 72]
 [ 0  9 18 27 36 45 54 63 72 81]]


Now, let's take a look at some of the things that you can do in Numpy but not in JAX numpy, and vice versa.

# Immutability

JAX arrays are **immutable**, just like [**TensorFlow tensors**](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part1). This means JAX arrays don't support item assignment the way `ndarray` does. Let's take an example!

In [8]:
array1 = np.arange(5, dtype=np.int32)
array2 = jnp.arange(5, dtype=jnp.int32)

print("Original ndarray: ", array1)
print("Original DeviceArray: ", array2)

# Item assignment
array1[4] = 10
print("\nModified ndarray: ", array1)
print("\nTrying to modify DeviceArray-> ", end=" ")

try:
    array2[4] = 10
    print("Modified DeviceArray: ", array2)
except Exception as ex:
    print(type(ex).__name__, ex)

Original ndarray:  [0 1 2 3 4]
Original DeviceArray:  [0 1 2 3 4]

Modified ndarray:  [ 0  1  2  3 10]

Trying to modify DeviceArray->  TypeError '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?


This situation is exactly the same as with TensorFlow tensors. Similar to `tf.tensor_scatter_nd_update` in TensorFlow, JAX has [indexed update operators](https://jax.readthedocs.io/en/latest/jax.ops.html#indexed-update-operators) (earlier there used to be [**jax.ops.index_update(..)**](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update), but it is deprecated now). The syntax is pretty simple: `DeviceArray.at[idx].op(val)`. This doesn't modify the original array, though; it returns a new array with the elements updated as specified.


A question naturally comes to mind: **why immutability?** JAX relies on **pure functions**, and allowing item assignment or in-place updates goes against that philosophy.

But then why are TF tensors immutable, when TensorFlow doesn't require pure functions? If you are doing any optimization on a DAG, it is highly advisable to avoid changing the state of an op used in the computation, so that you don't introduce side effects.

In [9]:
# Modifying DeviceArray elements at specific index/indices
array2_modified = array2.at[4].set(10)

# Equivalent => array2_modified = jax.ops.index_update(array2, 4, 10)

print("Original DeviceArray: ", array2)
print("Modified DeviceArray: ", array2_modified)

Original DeviceArray:  [0 1 2 3 4]
Modified DeviceArray:  [ 0  1  2  3 10]


In [10]:
# Of course, updates come in many forms!
print(array2.at[4].add(6))
print(array2.at[4].max(20))
print(array2.at[4].min(-1))

# Equivalent but deprecated. Just to showcase the similarity to tf scatter_nd_update
# (these helpers may be unavailable in newer JAX releases; prefer the `.at[...]` syntax above)
print("\nEquivalent but deprecated")
print(jax.ops.index_add(array2, 4, 6))
print(jax.ops.index_max(array2, 4, 20))
print(jax.ops.index_min(array2, 4, -1))

[ 0  1  2  3 10]
[ 0  1  2  3 20]
[ 0  1  2  3 -1]

Equivalent but deprecated
[ 0  1  2  3 10]
[ 0  1  2  3 20]
[ 0  1  2  3 -1]
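
The `.at[...]` syntax also accepts slices and index arrays, not just a single index. The following is a small sketch using a fresh array named `arr` (a hypothetical name, so that it doesn't clash with the arrays above). Every call returns a new array and leaves `arr` untouched.

```python
import jax.numpy as jnp

arr = jnp.arange(5, dtype=jnp.int32)  # [0 1 2 3 4]

# Set a slice to a constant
sliced = arr.at[1:3].set(0)

# Add to several positions at once using an index array
fancy = arr.at[jnp.array([0, 4])].add(100)

# Multiply every second element
scaled = arr.at[::2].multiply(2)

print("Original:   ", arr)
print("Slice set:  ", sliced)
print("Fancy add:  ", fancy)
print("Strided mul:", scaled)
```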


# Asynchronous dispatch

One of the biggest differences between `ndarrays` and `DeviceArrays` is in their execution and their availability. JAX uses asynchronous dispatch to hide Python overheads. Let's take an example to understand what it means.

In [11]:
# Create two random arrays sampled from a uniform distribution
array1 = np.random.uniform(size=(8000, 8000)).astype(np.float32)
array2 = jax.random.uniform(jax.random.PRNGKey(0), (8000, 8000), dtype=jnp.float32) # More on PRNGKey later!
print("Shape of ndarray: ", array1.shape)
print("Shape of DeviceArray: ", array2.shape)

Shape of ndarray:  (8000, 8000)
Shape of DeviceArray:  (8000, 8000)


Now, let's do some computation on each array to see what happens and how much time each computation takes

In [12]:
# Dot product on ndarray
start_time = time.time()
res = np.dot(array1, array1)
print(f"Time taken by dot product op on ndarrays: {time.time()-start_time:.2f} seconds")

# Dot product on DeviceArray
start_time = time.time()
res = jnp.dot(array2, array2)
print(f"Time taken by dot product op on DeviceArrays: {time.time()-start_time:.2f} seconds")

Time taken by dot product op on ndarrays: 7.94 seconds
Time taken by dot product op on DeviceArrays: 0.02 seconds


Wow! It seems that the `DeviceArray` computation finished in no time. This is where you should remember the following:
1. Unlike the result of the `ndarray` computation, the result of the computation done on the `DeviceArray` isn't available yet. It is a **future** value that will become available on the accelerator
2. You can retrieve the value of this computation by **printing** it or by converting it into a plain old numpy `ndarray`
3. The above timing for `DeviceArray` is the time taken to **dispatch** the work, not the time taken for the actual computation
4. Asynchronous dispatch is useful since it allows Python code to “run ahead” of an accelerator device, keeping Python code out of the critical path. Provided the Python code enqueues work on the device faster than it can be executed, and that Python code does not actually need to inspect the output of a computation on the host, then a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.
5. To measure the true cost of any such operation:
     - Either convert the result to a plain numpy `ndarray` (not preferred)
     - Or use `block_until_ready()` to wait for the computation that produced it to complete (preferred way for benchmarking)
     
Let's take a look at the above two methods again to measure the correct computation time

In [13]:
# First we will time it by converting the computation results to ndarray
%time np.asarray(jnp.dot(array2, array2))

CPU times: user 1min 6s, sys: 459 ms, total: 1min 7s
Wall time: 17.5 s


array([[1973.7642, 1957.4628, 1977.2909, ..., 1968.4293, 1975.3844,
        1988.7894],
       [2025.475 , 2023.9645, 2015.592 , ..., 2023.733 , 2002.3163,
        2028.4009],
       [2010.4509, 1999.3922, 2015.3254, ..., 2001.3368, 2002.6456,
        1999.4705],
       ...,
       [1990.0709, 1980.6545, 2004.953 , ..., 2000.7068, 1989.4515,
        1998.1526],
       [2019.7246, 2013.85  , 2037.707 , ..., 2013.0159, 2011.6285,
        2014.5178],
       [2010.0378, 1999.1147, 2012.6888, ..., 2006.3755, 2002.0842,
        2011.8866]], dtype=float32)

In [14]:
# Now let's time it using the blocking method
%time jnp.dot(array2, array2).block_until_ready()

CPU times: user 33.8 s, sys: 415 ms, total: 34.3 s
Wall time: 8.65 s


DeviceArray([[1973.7642, 1957.4628, 1977.2909, ..., 1968.4293, 1975.3844,
              1988.7894],
             [2025.475 , 2023.9645, 2015.592 , ..., 2023.733 , 2002.3163,
              2028.4009],
             [2010.4509, 1999.3922, 2015.3254, ..., 2001.3368, 2002.6456,
              1999.4705],
             ...,
             [1990.0709, 1980.6545, 2004.953 , ..., 2000.7068, 1989.4515,
              1998.1526],
             [2019.7246, 2013.85  , 2037.707 , ..., 2013.0159, 2011.6285,
              2014.5178],
             [2010.0378, 1999.1147, 2012.6888, ..., 2006.3755, 2002.0842,
              2011.8866]], dtype=float32)
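
If you benchmark often, the blocking pattern can be wrapped in a small helper. This is only a sketch: `time_blocking` is a hypothetical helper, not part of JAX, and the array here is much smaller than the one above so that it runs quickly. It does one warm-up call first so that any one-time setup cost is not counted.

```python
import time

import jax
import jax.numpy as jnp

def time_blocking(fn, *args, repeats=3):
    """Average wall-clock time of fn(*args), waiting for the result each time."""
    # Warm-up call so one-time costs are excluded from the measurement
    fn(*args).block_until_ready()

    start = time.perf_counter()
    for _ in range(repeats):
        # block_until_ready() waits for the dispatched work to actually finish
        out = fn(*args).block_until_ready()
    elapsed = (time.perf_counter() - start) / repeats
    return elapsed, out

key = jax.random.PRNGKey(0)
a = jax.random.uniform(key, (500, 500), dtype=jnp.float32)

elapsed, result = time_blocking(jnp.dot, a, a)
print(f"Average time per dot product: {elapsed:.4f} seconds")
```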

# Types promotion

This is another aspect to keep in mind. `dtype` promotion in JAX is less aggressive than in numpy. A few things:
1. JAX always prefers the precision of the JAX value when promoting a Python scalar
2. JAX always prefers the type of the floating-point or complex type when promoting an integer or boolean type against floating or complex type
3. JAX uses floating point promotion rules that are more suited to modern accelerator devices like GPUs/TPUs

Let's take an example to see these in action

In [15]:
print("Types promotion in numpy =>", end=" ")
print((np.int8(32) + 4).dtype)

print("Types promotion in JAX =>", end=" ")
print((jnp.int8(32) + 4).dtype)

Types promotion in numpy => int64
Types promotion in JAX => int8


In [16]:
array1 = np.random.randint(5, size=(2), dtype=np.int32)
print("Implicit numpy casting gives: ", (array1 + 5.0).dtype)

# Check the difference in semantics of the above function in JAX
array2 = jax.random.randint(jax.random.PRNGKey(0),
                            minval=0,
                            maxval=5,
                            shape=[2],
                            dtype=jnp.int32
                           )
print("Implicit JAX casting gives: ", (array2 + 5.0).dtype)

Implicit numpy casting gives:  float64
Implicit JAX casting gives:  float32
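
You can also query the promotion rules directly, without creating any data. The sketch below uses `promote_types`, which exists in both `numpy` and `jax.numpy`, to compare how each library combines a 32-bit integer with a half-precision float. The variable names are chosen only for illustration.

```python
import numpy as np
import jax.numpy as jnp

# numpy widens to a float type that can represent every int32 exactly
np_result = np.promote_types(np.int32, np.float16)

# JAX keeps the floating-point type that is already present
jax_result = jnp.promote_types(jnp.int32, jnp.float16)

print("numpy: int32 with float16 ->", np_result)
print("JAX:   int32 with float16 ->", jax_result)

# The same rule applies when the arrays are actually combined
ints = jnp.arange(3, dtype=jnp.int32)
halves = jnp.ones(3, dtype=jnp.float16)
print("JAX array result dtype:", (ints + halves).dtype)
```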


# Automatic Differentiation

Automatic Differentiation is one of my favorite topics to cover. It is beautiful and deserves a full tutorial. Similar to how we covered AD in depth in the [TensorFlow tutorial](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3), we will cover it in depth here in future tutorials. Here we will take a look at a simple example to see how tightly it is integrated into JAX.

In [17]:
def squared(x):
    return x**2

x = 4.0
y = squared(x)

dydx = jax.grad(squared)
print("First order gradients of y wrt x: ", dydx(x))
print("Second order gradients of y wrt x: ", jax.grad(dydx)(x))

First order gradients of y wrt x:  8.0
Second order gradients of y wrt x:  2.0
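
As a small extension of the example above, here is a sketch that uses `jax.value_and_grad` to get the function value and its gradient in one call, and the `argnums` parameter of `jax.grad` to choose which argument to differentiate with respect to. The functions `cubic` and `weighted_sum` are hypothetical examples.

```python
import jax
import jax.numpy as jnp

def cubic(x):
    return x**3 + 2.0 * x

# Value and derivative together: f(2) = 12, f'(2) = 3 * 2**2 + 2 = 14
value, slope = jax.value_and_grad(cubic)(2.0)
print("Value:", value, "Gradient:", slope)

def weighted_sum(w, x):
    return jnp.sum(w * x**2)

w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([1.0, 2.0, 3.0])

# Gradient with respect to the first argument: x**2
grad_w = jax.grad(weighted_sum, argnums=0)(w, x)

# Gradient with respect to the second argument: 2 * w * x
grad_x = jax.grad(weighted_sum, argnums=1)(w, x)

print("Gradient wrt w:", grad_w)
print("Gradient wrt x:", grad_x)
```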


That's it for part 1 of the JAX tutorials! We will be looking at other things, especially **`Automatic Differentiation`**, in the next tutorial!


**References**:
1. https://jax.readthedocs.io/en/latest/
2. https://colinraffel.com/blog/you-don-t-know-jax.html
3. https://flax.readthedocs.io/en/latest/notebooks/jax_for_the_impatient.html


Let me know in the comments if you have any suggestions or queries.