Welcome to my JAX tutorial series. This is the first part of the series. Wherever you look in the machine learning domain these days, you run into something about JAX, and naturally you become curious about what JAX really is. I am here to answer that question in as much detail as possible. You can find the list of my tutorials below:


**JAX Tutorials:**

* [1. Introduction to JAX](https://www.kaggle.com/code/goktugguvercin/introduction-to-jax)
* [2. Gradients and Jacobians in JAX](https://www.kaggle.com/code/goktugguvercin/gradients-and-jacobians-in-jax)
* [3. Automatic Differentiation in JAX](https://www.kaggle.com/code/goktugguvercin/automatic-differentiation-in-jax)
* [4. Just-In-Time Compilation in JAX](https://www.kaggle.com/code/goktugguvercin/just-in-time-compilation-in-jax)

<div style="width:100%;text-align: center;"> 
<img align=middle src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="250" height="250">
</div>

In this tutorial, we will get familiar with the fundamentals of the JAX framework designed by the Google Brain team. It serves as a general introduction. I guess the first things you hear about JAX are *GPU-Supported and Differentiable NumPy*, *Just-In-Time Compilation*, *Composable Function Transformations*, and *Automatic Vectorization*. We will go through these topics one by one in this tutorial. In addition, I will mention critical points that you should be aware of while coding with JAX. I hope you like it.

# JAX as Improved Numpy

In the simplest terms, JAX is an extended and upgraded version of NumPy with additional functionality aimed at high performance in machine learning research. More specifically, JAX comes with a handy NumPy module that is imported as $jax.numpy \,\, as \,\, jnp$. JAX keeps this NumPy API as close as possible to the original one, so that the usability we are used to is preserved, and it builds on the lower-level primitive module [$jax.lax$](https://jax.readthedocs.io/en/latest/jax.lax.html#module-jax.lax) to equip that API with additional features. Not every function of the original NumPy library is implemented in JAX, but most of them are available. To see which ones are supported, you can look at the [documentation](https://jax.readthedocs.io/en/latest/jax.numpy.html).
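As a quick illustration of how the two layers relate, here is a minimal sketch (the array values are arbitrary choices for illustration): `jnp` mirrors NumPy's function names and results, while `jax.lax` is lower-level and stricter, expecting operands of matching dtypes, so type promotion has to be done explicitly with `lax.convert_element_type`.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Same function names and (numerically) the same results as NumPy.
# Note: JAX defaults to float32, while NumPy defaults to float64.
x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)
same_result = np.allclose(np.sin(x_np), jnp.sin(x_jnp), atol=1e-6)
print("jnp.sin matches np.sin:", same_result)
print("dtypes:", x_np.dtype, x_jnp.dtype)

# jax.numpy promotes mixed types implicitly, like NumPy does ...
promoted = jnp.add(1, 1.0)
print("jnp.add(1, 1.0) =", promoted)

# ... whereas jax.lax expects matching dtypes, so we convert explicitly.
a = jnp.int32(1)
b = jnp.float32(1.0)
lax_sum = lax.add(lax.convert_element_type(a, jnp.float32), b)
print("lax.add after explicit conversion =", lax_sum)
```

In practice you will mostly write `jnp` code; `jax.lax` is useful when you need an operation NumPy does not offer or want precise control over dtypes.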

In [1]:
import jax
import time
import numpy as np
import jax.numpy as jnp

from jax import grad
from jax import random

To compare the standard NumPy we have been using so far with the one introduced by the JAX framework, and to get familiar with JAX as an upgraded NumPy, let's look at some subtle differences and implementation details:

**1. Functional Programming Principle and Pure Functions:**

JAX follows the principles of functional programming, so it avoids side effects wherever possible; one such side effect is modifying an array in place. In other words, immutability is at the center of JAX, whereas standard NumPy lets us create mutable arrays. To cope with this, JAX provides the pure update function $x.at[i].set(y)$, which returns an updated copy of the data and leaves the original array untouched. The same reasoning applies to transposing and reshaping: in NumPy these operations return views that share memory with the original array, so writing to a view also changes the original, whereas in JAX their results behave as new, independent arrays.

In [2]:
x1 = np.array([[1, 2], [3, 4]])  # standard numpy array
x2 = jnp.array([[1, 2], [3, 4]])  # jax numpy array

# Exception block to catch the error posed by in-place data manipulation
try:
    x2[1, 1] = 5
except Exception as ex:
    print(type(ex).__name__, ex)

# The pure update .at[].set() returns an updated copy of the data; x2 itself stays untouched
print(x2.at[1, 1].set(5))

TypeError '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
[[1 2]
 [3 5]]
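The difference between NumPy views and JAX's pure updates can be seen side by side in a small sketch (the arrays and values are arbitrary). It also shows that `.at[...]` offers more than `set`; here `add` is used as one example of the other update methods.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: the transpose is a *view* that shares memory with the original,
# so writing through the view silently changes the original array.
a = np.array([[1, 2], [3, 4]])
a_view = a.T
a_view[0, 1] = 100            # a.T[0, 1] is the same memory as a[1, 0]
print(a)                      # a is now [[1, 2], [100, 4]]

# JAX: .T gives a new array, and .at[...] returns an updated copy.
b = jnp.array([[1, 2], [3, 4]])
b_t_updated = b.T.at[0, 1].set(100)
print(b)                      # original is unchanged: [[1, 2], [3, 4]]

# Other pure update methods follow the same pattern.
c = b.at[0, 0].add(10)        # [[11, 2], [3, 4]]
d = b.at[:, 1].set(0)         # [[1, 0], [3, 0]]
print(c)
print(d)
```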


**2. GPU Support:**

Any array created by JAX can be placed on accelerators such as GPUs and TPUs. In other words, JAX combines NumPy's flexibility with GPU support. Devices (processing units) fall into three main categories: CPUs, GPUs, and TPUs. JAX treats each of these categories as a device backend, and by passing the name of a backend to the function [$jax.devices()$](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices) we obtain the list of all devices under that backend that are available to JAX.

Whenever a JAX array is instantiated, it is allocated on the default device, $jax.devices()[0]$, and such data is called ***uncommitted***. All computations on uncommitted data are performed on the default device, and their results are also placed there. However, we can transfer data that was uncommitted on the CPU to another device, such as a GPU, in order to accelerate operations. This is done by passing one of the devices in the device list to the $jax.device\_put()$ function, which commits the data to that device. At this point, you might ask whether we always have to commit data to the GPU explicitly. The answer is **NO**. If your computer has no GPU or TPU, your default device is naturally the CPU; if it does have one of these accelerators (and a matching JAX installation), JAX picks it as the default device. The default platform can also be set explicitly through the environment variable ***JAX_PLATFORM_NAME*** (for example, to "gpu"; newer releases use ***JAX_PLATFORMS***), so there is no need to commit every array explicitly:

In [3]:
A = jnp.arange(100, dtype=jnp.float32)
B = jnp.reshape(A, (10, 10))

# Explicit Data Commitment
# I only have a CPU; users with a GPU can set the backend argument to "gpu" for acceleration
devices = jax.devices(backend="cpu")
print("All Devices: ", devices)  # list of all devices under CPU backend
B = jax.device_put(B, devices[0])

# block_until_ready() is needed because of asynchronous dispatch, discussed later in this notebook
%timeit jnp.dot(B, B.T).block_until_ready()

All Devices:  [CpuDevice(id=0)]
79.3 µs ± 355 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
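To check on your own machine which backend JAX picked and where an array lives, a short sketch like the following can help. It assumes a recent JAX release, where arrays expose a `.devices()` method; older releases (like the one used for the outputs in this notebook) used different attributes, shown in the next section.

```python
import jax
import jax.numpy as jnp

# Which backend JAX picked as the default (e.g. "cpu", "gpu" or "tpu").
print("Default backend:", jax.default_backend())
print("Default device list:", jax.devices())

# Creating an array without specifying a device: it is uncommitted
# and lives on the default device, jax.devices()[0].
x = jnp.ones((4, 4))
print("x lives on:", x.devices())

# Explicitly committing a copy of the data to a chosen device.
# The CPU backend is always available, so this works on any machine.
cpu0 = jax.devices("cpu")[0]
x_cpu = jax.device_put(x, cpu0)
print("x_cpu lives on:", x_cpu.devices())
```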


**3. DeviceArray:**

Now that we have briefly covered data and computation placement in JAX, let's look at how this data (a JAX array) is created and placed on those devices.

A JAX array is composed of three main components:

* ***aval***: "abstract value associated to that array"

This value stores metadata about the JAX array: its shape and data type, along with a flag indicating whether the array is weakly or strongly typed.

* ***device***: "optional sticky device"

In the previous section, we discussed data commitment. This field tells us whether the data is committed to a particular device and, if so, to which one.

* ***device_buffer***: "the underlying buffer owning the on-device data"

The last, and most important, field of a JAX array is device_buffer. When an array is instantiated, a buffer is allocated on the default or requested device and the data is stored in that buffer. This buffer is inherently tied to the device it lives on, so by calling $device\_buffer.device()$ we can find out which device our data is placed on. At this point, you might think that the second field (device) would be enough to learn the device of the data: **why do I need to call the $device()$ function of *device_buffer*?** The reason is that the device field only records a device when the array has been explicitly committed to one. Arrays created on the default device and never transferred with $device\_put()$ have $None$ as their $device$ value.

As you can see, JAX arrays are tightly coupled with the device they are placed on and store important details about it. That is why the type of JAX arrays is called $DeviceArray$. DeviceArrays are ndarrays backed by a single device memory buffer, and if two DeviceArrays are committed to two different devices, you are not allowed to perform binary operations between them. To get more familiar with the structure of DeviceArray, you can check its source in Google's GitHub repository: [DeviceArray](https://github.com/google/jax/blob/8b1cb8a5366edfbffc8a9c193c0cb433c4756d93/jax/_src/device_array.py#L47).

In [4]:
device_array1 = jnp.array([[1, 2], [4, 3]])
device_array2 = jnp.array([[1.3, 5.6, 6.7], [4.12, 3.34, 2.2]])

device_array1 = jax.device_put(device_array1, jax.devices(backend="cpu")[0])

print("Commited DeviceArray: ", device_array1)
print("Which device it is placed on: ", device_array1.device_buffer.device())
print("Whether it is commited, if so where it is commited:", device_array1._device)
print("Meta information about this array: ", device_array1.aval)
print("Data type of this array: ", type(device_array1))

print("\nUncommited DeviceArray: ", device_array2)
print("Which device it is placed on: ", device_array2.device_buffer.device())
print("Whether it is commited, if so where it is commited: ", device_array2._device)
print("Meta information about this array: ", device_array2.aval)
print("Data type of this array: ", type(device_array2))

Commited DeviceArray:  [[1 2]
 [4 3]]
Which device it is placed on:  TFRT_CPU_0
Whether it is commited, if so where it is commited: TFRT_CPU_0
Meta information about this array:  ShapedArray(int32[2,2])
Data type of this array:  <class 'jaxlib.xla_extension.DeviceArray'>

Uncommited DeviceArray:  [[1.3  5.6  6.7 ]
 [4.12 3.34 2.2 ]]
Which device it is placed on:  TFRT_CPU_0
Whether it is commited, if so where it is commited:  None
Meta information about this array:  ShapedArray(float32[2,3])
Data type of this array:  <class 'jaxlib.xla_extension.DeviceArray'>
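The output above comes from an older JAX release. In JAX 0.4 and later, `DeviceArray` was replaced by the unified `jax.Array` type, and internals such as `device_buffer` and `_device` are no longer the recommended way to query placement. The sketch below is a rough modern equivalent, assuming a recent JAX version: shape, dtype and devices are read directly from the array, and `jax.eval_shape` returns the abstract shape/dtype information (the idea behind `aval`) without running the computation.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 2], [4, 3]])
y = jnp.array([[1.3, 5.6, 6.7], [4.12, 3.34, 2.2]])

# Metadata is available directly on the array object.
print(type(x), isinstance(x, jax.Array))
print("x:", x.shape, x.dtype, x.devices())
print("y:", y.shape, y.dtype, y.devices())

# jax.eval_shape traces the function with abstract values only:
# we get the output shape/dtype without computing anything.
out_spec = jax.eval_shape(lambda a, b: a @ b, x, y)
print(out_spec.shape, out_spec.dtype)
```

Here `x` (int32) times `y` (float32) is promoted to a floating-point result of shape `(2, 3)`, and `eval_shape` reports exactly that before any multiplication happens.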


**4. Differentiable NumPy**

However, treating JAX as merely a GPU-supported NumPy that follows functional programming principles underestimates what it can do; it is much more than that. JAX can also automatically differentiate functions written not only with Python's native constructs, such as if-conditions, recursion, closures, and loops, but also with jax.numpy code. To accomplish this, it uses the $jax.grad()$ transformation, which takes the target function itself as its input argument and returns a new function that computes its derivative. That is why the differentiation system in JAX is described as a transformation of functions and programs.

In [5]:
# native python function
def f(x):
    return x**3

# jax-numpy function
def g(x):
    return jnp.exp(x)

# What jax.grad() returns is a function; it generates a function from a function, so it is a transformer
grad_f = grad(f)  # 3 * x * x
grad_g = grad(g)  # exp(x)

print(grad_f(2.5)) # 3 * 2.5 * 2.5 = 18.75
print(grad_g(1.0)) # exp(1) = e

18.75
2.7182817
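To illustrate the claim about Python control flow, here is a small sketch with two made-up functions (`piecewise` and `repeated_sin` are hypothetical examples): `grad` differentiates straight through an `if` and a `for` loop as long as the function is not jitted, and `jax.value_and_grad` returns the value and the derivative together. One common pitfall: `grad` expects floating-point inputs, which is why the examples above pass `2.5` and `1.0` rather than integers.

```python
import jax
import jax.numpy as jnp
from jax import grad

# Python control flow (if / for) works with grad when the function is not jitted.
def piecewise(x):
    if x > 0:
        return x ** 2        # derivative: 2x
    return -3.0 * x          # derivative: -3

def repeated_sin(x, n=3):
    for _ in range(n):       # a plain Python loop is unrolled during tracing
        x = jnp.sin(x)
    return x

print(grad(piecewise)(3.0))      # 6.0
print(grad(piecewise)(-2.0))     # -3.0
print(grad(repeated_sin)(0.0))   # 1.0 (product of cos(0) terms)

# value_and_grad returns f(x) and f'(x) in one call.
value, slope = jax.value_and_grad(lambda x: x ** 3)(2.0)
print(value, slope)              # 8.0 12.0
```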


# Compilation in JAX:

XLA (Accelerated Linear Algebra) is a domain-specific compiler designed to accelerate the linear algebra operations in machine learning models; apart from the speed improvements, it also makes more efficient use of device memory. It was released as part of the TensorFlow framework, but it is now used by JAX as well: JAX relies on the [XLA](https://www.tensorflow.org/xla) engine to compile its NumPy code and run it on accelerators. Briefly, XLA takes a machine learning model, or even a small Python function with a handful of operations, and compiles it into a set of XLA-optimized kernels. Without XLA, each operation is executed as its own kernel; XLA tries to fuse as many operations as possible into a single GPU kernel, which avoids writing intermediate results back to memory.

At this point, you might say that this is not a new concept, since XLA has existed in TensorFlow for a long time. However, with the help of the $jax.jit()$ transformation, JAX performs **Just-In-Time (JIT)** compilation, that is, compiling the source code at run time just before executing it, so that the code runs efficiently through XLA. We will examine JIT compilation in JAX in detail in later notebooks.
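As a preview, here is a minimal sketch of `jax.jit`, adapted from the SELU example in the JAX quickstart (the parameter values and input size are illustrative). The first call traces and compiles the function; later calls with the same shapes and dtypes reuse the compiled program. `.lower(...).as_text()` lets us peek at the program that is handed to XLA.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # A few elementwise ops that XLA can fuse into one kernel.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)

x = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))
y_jit = selu_jit(x).block_until_ready()   # first call: trace + compile + run
y_ref = selu(x)                           # op-by-op (eager) execution
print(jnp.allclose(y_jit, y_ref, atol=1e-5))

# Inspect the lowered program text that XLA will compile.
lowered_text = selu_jit.lower(x).as_text()
print(lowered_text[:300])
```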

# Asynchronous Dispatch in JAX:

Calling a function in Python is relatively expensive compared with many other programming languages. Python is dynamically typed, so the types of arguments, variables, and return values are only determined at run time, which adds type-inspection overhead to every call. To alleviate the impact of this overhead, JAX uses asynchronous dispatch.

Normally, when a function is invoked to perform a computation such as the dot product of two matrices or the SVD of a matrix, the caller waits for the entire operation to complete. With [*asynchronous dispatch*](https://jax.readthedocs.io/en/latest/async_dispatch.html), JAX does not wait for completion and directly returns a DeviceArray. You might wonder how a result can be returned before the calculation has finished. In fact, the returned JAX array is a future: a value that will be produced on a device but is not available yet. You can think of it as an array whose contents are still pending but whose metadata is already known, so we can inspect its shape and dtype. We can even pass it into another computation, since that computation is dispatched asynchronously too and does not need the exact values right away. However, if we request the actual values, for example by printing the array or converting it to a $numpy.ndarray$, dispatch is synchronized and the program waits for the computation to complete.

This is what asynchronous dispatch means. To summarize: as long as the result of a computation is not needed on the host, the computation can run asynchronously, off the critical path of the Python program (the main thread). This keeps the accelerator busy instead of waiting on Python, and makes it possible to enqueue operations on the device faster than they can be executed, which is why the concept is so useful. The function $block\_until\_ready()$, which we used above in the GPU support section, forces the program to wait until the computation completes. This lets us observe the effect of asynchronous dispatch:

In [6]:
# PRNGKey() will be mentioned on further notebooks
x = random.uniform(random.PRNGKey(0), (3000, 3000))
%time jnp.dot(x, x)  # ~582 µs in the run discussed below: dispatch time only
%time jnp.dot(x, x).block_until_ready()  # ~22.9 ms in the run discussed below: the full computation

CPU times: user 1.7 s, sys: 40.9 ms, total: 1.74 s
Wall time: 454 ms
CPU times: user 1.7 s, sys: 30 ms, total: 1.73 s
Wall time: 435 ms


DeviceArray([[740.13153, 739.0283 , 743.0375 , ..., 740.0252 , 766.72534,
              756.9948 ],
             [754.3393 , 750.0729 , 747.5086 , ..., 749.75885, 771.0327 ,
              753.9398 ],
             [744.30945, 745.9499 , 753.9323 , ..., 752.6093 , 768.2983 ,
              753.47156],
             ...,
             [749.59576, 753.1669 , 757.54675, ..., 753.0712 , 781.12213,
              760.89844],
             [741.66565, 727.6113 , 742.857  , ..., 735.88513, 764.0407 ,
              747.057  ],
             [750.45935, 753.6105 , 755.4329 , ..., 757.1083 , 784.2975 ,
              761.195  ]], dtype=float32)

***Let's analyze the code above:***

Here a square matrix of $3000 \times 3000 = 9 \cdot 10^6$ elements is multiplied by itself in two different ways. In the first case, it seems that only 582 microseconds were needed for the whole operation; however, an operation involving $3000^3 = 27 \cdot 10^9$ multiplications cannot be completed in such a short time. In fact, those 582 microseconds are only the time taken to dispatch the work to a helper thread; they do not measure the execution of the matrix multiplication itself. This is asynchronous dispatch in action. In the second case, $block\_until\_ready()$ forces the program to wait until the computation is completed, which lets us measure the true cost of the operation: 22.9 milliseconds. (The CPU timings printed above do not show this gap, because asynchronous dispatch is not guaranteed on the CPU backend, as explained next.)
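The same effect can be measured without IPython magics. The sketch below (matrix size chosen arbitrarily to keep it quick) warms up once to exclude compilation time, then times the dispatch and the full computation separately with `time.perf_counter`. Absolute numbers depend entirely on your hardware, and on the CPU backend the two can be close for the reasons discussed in this section.

```python
import time
import jax
import jax.numpy as jnp

x = jax.random.uniform(jax.random.PRNGKey(0), (1000, 1000))
jnp.dot(x, x).block_until_ready()     # warm-up: exclude one-time compilation

start = time.perf_counter()
y = jnp.dot(x, x)                     # returns once the work is dispatched
dispatch_time = time.perf_counter() - start

# Shape and dtype are known before the values are computed.
print("Result metadata:", y.shape, y.dtype)

y.block_until_ready()                 # wait for the actual computation
total_time = time.perf_counter() - start
print(f"dispatch: {dispatch_time * 1e3:.3f} ms, total: {total_time * 1e3:.3f} ms")
```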

***Asynchronous dispatch is not guaranteed:***

Even if the host does not request the result and no explicit synchronization such as $block\_until\_ready()$ is used, asynchronous dispatch may still not happen. There are a couple of reasons for this:

1. An HLO cost analysis based on flop counts is performed to decide whether an operation is cheap. Depending on that decision, small elementary computations, specifically on the CPU, are executed on the main thread rather than asynchronously on another thread. Executing such cheap computations asynchronously on the CPU backend could actually hurt performance, so it is avoided.

2. The number of asynchronous computations that can be in flight at once is limited, and the limit depends on the device. If too many computations are enqueued, the operation queue fills up and the main thread is blocked until some of the pending computations finish.

For these reasons, you may get confusing or misleading timing results when you run the code above on the CPU backend. If you only have a CPU, I strongly advise you to try it on Google Colab with a GPU runtime to get meaningful and informative timing results.

If you want a more detailed explanation of this topic, you can read my discussion on the google/jax GitHub page: [Asynchronous Dispatch in JAX](https://github.com/google/jax/discussions/9895?sort=new)

# Automatic Vectorization in JAX:

Automatic vectorization in JAX lets mathematical operations such as matrix multiplication, convolution, or other transformations be applied to more than one item at once, that is, batching. More specifically, JAX generates a batched version of a function written for a single example with the help of the $jax.vmap()$ transformation, which stands for vectorizing map. For example, suppose we want to implement a simple neural network composed only of fully connected (FC) layers. Assume that, in the architecture we decided on, one of the layers takes a vector of shape $512 \times 1$ as input and halves its size. We can easily implement this layer as the transformation $x^TW$ with a weight matrix $W$ of shape $512 \times 256$. **What about batching? When more than one input is passed to the network, how do we batch this linear transformation?**

We have two answers to that question:

1. Iterating over the samples (vectors) in the batch
2. Creating a system in which all samples will be multiplied by transformation (weight) matrix at the same time

The first approach is not recommended, since it is inefficient and increases computation time considerably. The second one is what we want to opt for. If all samples (vectors) in the batch are stacked from top to bottom to form a matrix $X$, we can handle the whole batch with a single matrix-matrix product $XW$. This is essentially what automatic vectorization does for us: instead of manually extending the operation from one sample to the entire batch, we can directly use this transformation. You might say that batching is not so difficult, so why do we need a transformation that does it for us? To be honest, it is not as easy as it looks for every operation. For example, batching convolution can be quite tricky; those who have done the assignments of Stanford's CS231n course know how carefully sample indices must be organized in a batched convolution.
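The two approaches can be compared directly in a short sketch. The shapes follow the example in the text; the random weights are an illustrative assumption, and separate PRNG keys from `jax.random.split` are used for `X` and `W` so that they are independent.

```python
import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (100, 512))   # batch of 100 samples, one per row
W = jax.random.normal(key_w, (512, 256))   # weight matrix of the FC layer

# Approach 1: iterate over the samples in Python (slow for large batches).
looped = jnp.stack([jnp.dot(x, W) for x in X])

# Approach 2: stack samples as rows and use one matrix-matrix product XW.
batched = jnp.dot(X, W)

print(looped.shape, batched.shape)          # (100, 256) (100, 256)
print(jnp.allclose(looped, batched, rtol=1e-3, atol=1e-3))
```

Both give the same result, but the first issues 100 separate vector-matrix products from Python, while the second is a single call that the backend can execute efficiently.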

In [7]:
X = random.normal(random.PRNGKey(0), (100, 512))
W2 = random.normal(random.PRNGKey(0), (512, 256))

# x is a single sample (one row of X) with shape (512,); for a 1-D array x.T is a no-op, so this computes x^T W
def dense_layer2(x):
    return jnp.dot(x.T, W2)

def manually_batched_dense_layer2(X):
    return jnp.dot(X, W2)

vmap_batched_dense_layer2 = jax.vmap(dense_layer2)

# both functions compute their own result
# to confirm the truth of vmap batch function, let's compare all items in the result
# (100, 512) x (512, 256) = (100, 256) --> 25600 items
# If all values in both result are equivalent, sum of boolean values should be 25600
result1 = manually_batched_dense_layer2(X)
result2 = vmap_batched_dense_layer2(X)
jnp.sum(jnp.asarray(result1 == result2, dtype=jnp.int8))

DeviceArray(25600, dtype=int32)

Automatic vectorization can also be composed with Just-In-Time (JIT) compilation for further acceleration. JAX lets us combine transformations freely, which makes everything in JAX much more handy, useful, and practical. We will investigate these compositions in later notebooks.
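As a small preview of that composition, here is a sketch with a hypothetical single-sample layer `dense(x, W)`. Unlike the cell above, where `W2` is a global, the weights are passed as an argument here, so `in_axes=(0, None)` is used to batch only over `x` while sharing `W` across all samples; the vectorized function is then wrapped in `jax.jit`.

```python
import jax
import jax.numpy as jnp

def dense(x, W):
    # Written for a single sample x of shape (512,): computes x^T W.
    return jnp.dot(x, W)

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (100, 512))   # batch of 100 samples
W = jax.random.normal(key_w, (512, 256))   # shared weight matrix

# in_axes=(0, None): map over axis 0 of x, pass W unchanged to every call.
batched_dense = jax.vmap(dense, in_axes=(0, None))

# Transformations compose: JIT-compile the vectorized function.
fast_batched_dense = jax.jit(batched_dense)

out = fast_batched_dense(X, W)
print(out.shape)  # (100, 256)
print(jnp.allclose(out, jnp.dot(X, W), rtol=1e-3, atol=1e-3))
```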

# References

* https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
* https://jax.readthedocs.io/en/latest/jax.numpy.html
* https://jax.readthedocs.io/en/latest/faq.html
* https://en.wikipedia.org/wiki/Just-in-time_compilation
* https://www.tensorflow.org/xla
* https://jax.readthedocs.io/en/latest/async_dispatch.html#async-dispatch
* https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html