# Introduction to JAX

At its core, JAX lets you write numerical code in NumPy-like syntax and then transform your functions to obtain automatic differentiation, JIT compilation, GPU/TPU acceleration, and much more.

## JAX and NumPy

Using JAX feels very much like using NumPy. The majority of your NumPy knowledge is directly transferable.

In [1]:
# basic examples: array creation, reshaping, and slicing in both JAX and NumPy as side-by-side code snippets.
import jax
import numpy as np
import jax.numpy as jnp

# Examples of array definition:
# NumPy
numpy_array = np.array([[1, 2], [3, 4]])
numpy_ones = np.ones(10)

# JAX
jax_array = jnp.array([[1, 2], [3, 4]])
jax_ones = jax.numpy.ones(10)



However, JAX arrays are different in several important ways:

- **Immutability:** Once an array is created, its contents cannot be changed, much like a Python `tuple`.
- **Device Memory:** JAX arrays can live in device memory (for example on a GPU), while NumPy arrays live in host memory.
- **Lazy Evaluation:** Unlike NumPy, which evaluates each operation immediately, JAX dispatches work asynchronously: a call can return before the computation has finished, and Python only waits when the result is actually needed. Under transformations such as `jit`, operations are additionally traced and compiled before they run. (This is hard to showcase with a simple example.)
- **Advanced Features:** Operations on JAX arrays can be automatically differentiated and parallelized (we touch on this later).

## Immutability

In [2]:
# Array immutability

# Item assignment works for numpy arrays.
numpy_array[1, 1] = 0

# While in JAX it raises a TypeError.
try:
    jax_array[1, 1] = 0
except TypeError as err:
    print(f"Exception raised: {err}")



Exception raised: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
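
The error message above points to the `.at[]` syntax. As a minimal sketch (the array values and the names `x`, `y`, `z` are chosen here only for illustration), a functional update looks like this:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])

# .at[...].set(...) returns a new array; x itself is left untouched.
y = x.at[1, 1].set(0)

# Other .at methods follow the same pattern, e.g. adding to a whole row.
z = x.at[0].add(10)

print(x)  # original values
print(y)  # copy with element [1, 1] replaced
print(z)  # copy with 10 added to the first row
```

Each update produces a new array, which is what lets JAX treat functions as pure and transform them safely.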


## Device Memory

This is how you can display the device an array lives on and move it to another device.

In [3]:
# Display the device the array is on
# (`device_buffer` may be unavailable in newer JAX releases; there, try `jax_array.devices()`)
print("Array device:", jax_array.device_buffer.device())

# List all available devices
print("List available devices", jax.devices())

# Send array to a desired device
jax_array = jax.device_put(jax_array, jax.devices("cpu")[0])
print("Array device:", jax_array.device_buffer.device())


Array device: cuda:0
List available devices [cuda(id=0)]
Array device: TFRT_CPU_0
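
Moving data in the other direction, from a device back to host memory, follows the same pattern. The sketch below is illustrative only: the array `x` is made up, and it assumes a JAX version where arrays expose the `devices()` method.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(6.0).reshape(2, 3)

# Pin the array to the first CPU device (present even on GPU machines).
cpu = jax.devices("cpu")[0]
x_cpu = jax.device_put(x, cpu)
print(x_cpu.devices())  # set of devices holding the array

# Copy back to host memory as a plain NumPy array.
x_host = jax.device_get(x_cpu)  # np.asarray(x_cpu) would also work

print(type(x_host), x_host.shape)
```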


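## GPU Acceleration
As a simple use case, multiplying two large matrices shows how much faster JAX can be on a GPU: in the run below, about 19 ms versus about 370 ms for NumPy on the CPU. Note that `jnp.array` converts the data to `float32` by default, whereas NumPy computes in `float64`, so the comparison is indicative rather than exact.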

In [4]:
# Create large matrices
large_numpy_matrix = np.random.rand(5000, 5000)
large_jax_matrix = jnp.array(large_numpy_matrix)

# Time matrix multiplication in NumPy
%timeit np.dot(large_numpy_matrix, large_numpy_matrix)

# Time matrix multiplication in JAX
%timeit jnp.dot(large_jax_matrix, large_jax_matrix).block_until_ready() # The block_until_ready ensures we wait for the result



370 ms ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
19 ms ± 847 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
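
Outside a notebook, where `%timeit` is not available, the same measurement can be sketched by hand. `time_once` is a hypothetical helper written for this illustration, not part of JAX; the important detail is that it waits on `block_until_ready()` so that asynchronous dispatch does not make the call look faster than it is.

```python
import time
import jax.numpy as jnp

def time_once(fn, *args):
    """Hypothetical helper: wall-clock time of one call, waiting for the result."""
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # JAX dispatches asynchronously; wait for completion
    return out, time.perf_counter() - start

a = jnp.ones((500, 500))
jnp.dot(a, a).block_until_ready()  # warm-up call so one-time setup cost is excluded
product, seconds = time_once(jnp.dot, a, a)
print(product.shape, f"{seconds:.6f} s")
```

A single measurement is noisy; repeating the call and averaging, as `%timeit` does, gives more reliable numbers.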


## Just-In-Time Compilation with `jit`

### What is JIT Compilation?

Just-In-Time (JIT) Compilation is a technique where the code is compiled during execution rather than before execution starts. This allows certain optimizations to be made that can't be achieved with traditional ahead-of-time compilation.

For those of you familiar with `numba`, JAX's `jit` serves a similar purpose: it takes a Python function and optimizes it for faster execution. However, while `numba` is designed primarily for speeding up CPU-bound Python code, JAX's `jit` can target both CPU and accelerators like GPUs and TPUs.

### Basic Usage

Here's how to use `jit` in JAX:


In [5]:

def slow_function(x):
    return jnp.sin(x) * jnp.cos(x)

compiled_function = jax.jit(slow_function)

# Now, compiled_function will run the optimized version
result = compiled_function(jnp.array([1.0, 2.0, 3.0]))
print(result)

# Alternatively you can use decorators
@jax.jit
def decorated_function(x):
    return jnp.sin(x) * jnp.cos(x)

print(f"JIT compiled with decorators: {decorated_function}")

[ 0.4546487  -0.37840125 -0.13970774]
JIT compiled with decorators: <PjitFunction of <function decorated_function at 0x7f4c3947c670>>


### Tracers
When you jit a function, JAX introduces what are known as "tracers" into the function. Tracers can be thought of as symbolic representations of your input, and they help JAX figure out how to optimize your function.

In [6]:
from jax import make_jaxpr # visualize the series of operations (expressed as a JAX expression or "jaxpr")

# Let's see the jax expression of our function
print(make_jaxpr(slow_function)(jnp.array([1.0, 2.0, 3.0])))

{ lambda ; a:f32[3]. let
    b:f32[3] = sin a
    c:f32[3] = cos a
    d:f32[3] = mul b c
  in (d,) }


The output is a JAX expression, a symbolic representation of your function's operations. This is what JAX uses internally to optimize your code.
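
One practical consequence of tracing is that Python side effects inside a jitted function run only while JAX traces it, and that tracing happens again for each new input shape or dtype. The sketch below illustrates this; `trace_log` and `scaled_sum` are made-up names for the example.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the shapes JAX traced the function with

@jax.jit
def scaled_sum(x):
    # This append is a Python side effect: it runs while tracing, not on every call.
    trace_log.append(x.shape)
    return jnp.sum(x * 2.0)

scaled_sum(jnp.ones(3))   # first call: traces and compiles for shape (3,)
scaled_sum(jnp.zeros(3))  # same shape and dtype: reuses the compiled version
scaled_sum(jnp.ones(4))   # new shape: traces again

print(trace_log)  # [(3,), (4,)]
```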

### Debugging JIT-compiled Code

Certain functions, such as `jax.lax.scan`, are JIT-compiled by default. This can make debugging challenging because, inside the function, you can only inspect tracers rather than the computed values. To facilitate debugging, you can disable JIT compilation as follows:

In [7]:

@jax.jit
def jitted_function(x):
    print("Inside jitted function:", x)

# Call the function with some data
# The print inside the function runs only once, while the function is traced for compilation
print("JIT enabled")
x = jnp.array([1.0, 2.0, 3.0])
for i in range(3):
    print(f"Iteration {i+1}:")
    jitted_function(x)

print("\nJIT disabled")
with jax.disable_jit():
    for i in range(3):
        print(f"Iteration {i+1}:")
        jitted_function(x)



JIT enabled
Iteration 1:
Inside jitted function: Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
Iteration 2:
Iteration 3:

JIT disabled
Iteration 1:
Inside jitted function: [1. 2. 3.]
Iteration 2:
Inside jitted function: [1. 2. 3.]
Iteration 3:
Inside jitted function: [1. 2. 3.]
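
If disabling JIT is too slow or changes the behaviour you are investigating, `jax.debug.print` is an alternative: it prints runtime values from inside compiled code. A minimal sketch, where the function `normalize` is invented for the example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    total = jnp.sum(x)
    # Unlike print, jax.debug.print receives the runtime value on every call.
    jax.debug.print("sum inside jit: {t}", t=total)
    return x / total

out = normalize(jnp.array([1.0, 2.0, 5.0]))
print(out)
```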


## Autodifferentiation with JAX
Autodifferentiation, often referred to as "autograd" or simply "AD", is the ability to compute gradients of functions automatically. JAX provides this capability seamlessly. Let's take a look at how this works.

### Toy Example 1: Linear Regression
We can start off with a standard example, fitting a linear regression by gradient descent on a Mean Squared Error loss.

Let's consider a simple linear regression model: $f(x)=mx+b$, and $y=f(x) + \text{noise}$, where we'll optimize for the slope $m$ and the intercept $b$. We'll use the Mean Squared Error as our loss function.

First, let's generate some sample data:

In [8]:
import matplotlib.pylab as plt
# Generating some sample data
true_m = 2.5
true_b = -1.0
x_data = jnp.linspace(-10, 10, 1000)
y_data = true_m * x_data + true_b + jax.random.normal(jax.random.PRNGKey(123), shape=x_data.shape) * 2.0  # Adding some noise


Let's define the model and the MSE loss:

In [15]:
def linear_model(x, m, b):
    return m * x + b

def mse_loss(params, x, y):
    m, b = params
    y_pred = linear_model(x, m, b)
    return jnp.mean((y_pred - y)**2)


We can now compute the gradient via AD by passing the loss through the `jax.grad` function:

In [16]:
gradient = jax.grad(mse_loss, argnums=0)  # 0 refers to the first argument of mse_loss, i.e., the params tuple
gradient(jnp.array([0.0, 0.0]), x_data, y_data)

Array([-166.60902  ,    1.7543185], dtype=float32)
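
When the loss value is needed alongside its gradient, `jax.value_and_grad` returns both from a single evaluation. The sketch below uses a tiny hand-made dataset (an assumption for illustration, not the notebook's data) so the numbers can be checked by hand.

```python
import jax
import jax.numpy as jnp

def mse(params, x, y):
    m, b = params
    return jnp.mean((m * x + b - y) ** 2)

x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([1.0, 3.0, 5.0])  # exactly y = 2x + 1
params = jnp.array([0.0, 0.0])

# value_and_grad returns (loss, gradient) from a single evaluation.
loss_value, grads = jax.value_and_grad(mse)(params, x, y)
print(loss_value, grads)
```

With `m = b = 0` the residuals are `-y`, so the loss is 35/3 and the gradient is `(-26/3, -6)`.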

One can now perform gradient descent, or any other gradient-based optimization, on the loss.

There are many excellent optimization libraries in the JAX ecosystem, for example [`JAXopt`](https://jaxopt.github.io/stable/) for general-purpose optimization or [`flax`](https://flax.readthedocs.io/en/latest/), which is specific to neural networks.

In [11]:
import jaxopt

# Gradient Descent
learning_rate = 0.01
params = jnp.array([0.0, 0.0])  # Initial values for m and b

for i in range(100):  # 100 gradient descent updates
    grad_params = gradient(params, x_data, y_data)
    params -= learning_rate * grad_params
    if i > 90:
        print(f"Iteration {i+1}: m = {params[0]:.4f}, b = {params[1]:.4f}, Loss = {mse_loss(params, x_data, y_data):.4f}")
print(f"\nOptimization results,  m: {params[0]}, b: {params[1]}")
      
# Gradient Descent via JAXopt
solver = jaxopt.GradientDescent(fun=mse_loss, tol=10**-6)
params, state = solver.run(init_params=(1., 1.), x=x_data, y=y_data)
print(f"JAXopt results,        m: {params[0]}, b: {params[1]}")

Iteration 92: m = 2.4941, b = -0.7404, Loss = 4.1020
Iteration 93: m = 2.4941, b = -0.7432, Loss = 4.1012
Iteration 94: m = 2.4941, b = -0.7458, Loss = 4.1005
Iteration 95: m = 2.4941, b = -0.7485, Loss = 4.0998
Iteration 96: m = 2.4941, b = -0.7510, Loss = 4.0992
Iteration 97: m = 2.4941, b = -0.7536, Loss = 4.0985
Iteration 98: m = 2.4941, b = -0.7560, Loss = 4.0979
Iteration 99: m = 2.4941, b = -0.7585, Loss = 4.0974
Iteration 100: m = 2.4941, b = -0.7608, Loss = 4.0968

Optimization results,  m: 2.4941420555114746, b: -0.7608301043510437
JAXopt results,        m: 2.4940860271453857, b: -0.8771581649780273


### Toy Example 2: Regression with a Two-layer Neural Network
Now, let's add two feed-forward layers, each with a ReLU activation.

This transforms our model from a simple linear regression into a slightly more complex neural network. The beauty of autodiff is that it can be handled in exactly the same way.

In [17]:
def two_layer_nn_model(x, w1, b1, w2, b2, w_out, b_out):
    hidden1 = jax.nn.relu(jnp.dot(x, w1) + b1)  # First ReLU activation
    hidden2 = jax.nn.relu(jnp.dot(hidden1, w2) + b2)  # Second ReLU activation
    return jnp.dot(hidden2, w_out) + b_out

def two_layer_nn_mse_loss(params, x, y):
    w1, b1, w2, b2, w_out, b_out = params
    y_pred = two_layer_nn_model(x, w1, b1, w2, b2, w_out, b_out)
    return jnp.mean((y_pred - y)**2)

gradient_two_layer_nn = jax.grad(two_layer_nn_mse_loss, argnums=0)


learning_rate = 0.1
params_two_layer_nn = (jnp.zeros((1, 10)), jnp.zeros(10), jnp.zeros((10, 10)), jnp.zeros(10), jnp.zeros(10), 0.0)  # Initialize weights and biases

for i in range(100):
    grad_params_two_layer_nn = gradient_two_layer_nn(params_two_layer_nn, x_data[:, None], y_data)  # x_data[:, None] reshapes x_data for matrix multiplication
    params_two_layer_nn = tuple(param - learning_rate * grad_param for param, grad_param in zip(params_two_layer_nn, grad_params_two_layer_nn))
    if i < 10:
        print(f"Iteration {i+1}: Loss = {two_layer_nn_mse_loss(params_two_layer_nn, x_data[:, None], y_data):.4f}")


Iteration 1: Loss = 212.3490
Iteration 2: Loss = 212.1717
Iteration 3: Loss = 212.0583
Iteration 4: Loss = 211.9857
Iteration 5: Loss = 211.9392
Iteration 6: Loss = 211.9094
Iteration 7: Loss = 211.8904
Iteration 8: Loss = 211.8782
Iteration 9: Loss = 211.8704
Iteration 10: Loss = 211.8654
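
The loss above barely moves, and the all-zero initialization is a likely reason: with zero weights every hidden activation is zero, so the only parameter that receives a nonzero gradient is the output bias. The sketch below checks this on a smaller, made-up one-hidden-layer model (`tiny_mlp_loss`, the data, and the random scale 0.1 are assumptions for illustration, not the notebook's setup).

```python
import jax
import jax.numpy as jnp

def tiny_mlp_loss(params, x, y):
    w1, b1, w_out, b_out = params
    hidden = jax.nn.relu(jnp.dot(x, w1) + b1)
    return jnp.mean((jnp.dot(hidden, w_out) + b_out - y) ** 2)

x = jnp.linspace(-1.0, 1.0, 8)[:, None]
y = 2.5 * x[:, 0] - 1.0

# All-zero start: the hidden activations are zero, so no signal reaches the weights.
zero_params = (jnp.zeros((1, 4)), jnp.zeros(4), jnp.zeros(4), 0.0)
zero_grads = jax.grad(tiny_mlp_loss)(zero_params, x, y)

# Small random start (an assumption for illustration).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
rand_params = (0.1 * jax.random.normal(k1, (1, 4)), jnp.zeros(4),
               0.1 * jax.random.normal(k2, (4,)), 0.0)
rand_grads = jax.grad(tiny_mlp_loss)(rand_params, x, y)

# Largest absolute gradient entry per parameter: (w1, b1, w_out, b_out).
print([float(jnp.abs(g).max()) for g in zero_grads])
print([float(jnp.abs(g).max()) for g in rand_grads])
```

With the zero start only the last entry is nonzero, which is why random initialization is the usual choice for neural networks.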


## Vectorization and Parallelization in JAX: `vmap` and `pmap`

Parallelization is the process of performing multiple computations simultaneously. JAX offers two primary mechanisms for parallelizing code:

1. Vectorization with **`vmap`**: This is a way to transparently turn a function that operates on single data points into one that operates on batches of data points.
2. Parallelization across devices with **`pmap`**: It allows you to distribute computations over multiple accelerators.

### Vectorization with vmap

Suppose you have a function that computes the square of a number:

In [18]:
def square(x):
    return x * x


If you want to compute the square of a batch of numbers, one way is to use a loop:

In [19]:
xs = jnp.array([1.0, 2.0, 3.0, 4.0])
squared = jnp.array([square(x) for x in xs])


With **`vmap`**, you can vectorize the square function so it can process the entire batch at once:

In [20]:
batched_square = jax.vmap(square)
squared = batched_square(xs)

Vectorization also works on multidimensional arrays: `in_axes` and `out_axes` specify which input and output axes are mapped.

In [21]:
# Vector dot product
vector_vector = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []

# Using vmap, upgrade vector_vector to handle matrix-vector multiplication.
# We parallelize across the rows of the matrix (axis 0) and broadcast the vector.
matrix_vector = jax.vmap(vector_vector, in_axes=(0, None), out_axes=0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)

# Apply vmap again to upgrade matrix_vector to handle matrix-matrix multiplication.
# This time we parallelize across the columns of the second matrix (axis 1).
matrix_matrix = jax.vmap(matrix_vector, in_axes=(None, 1), out_axes=1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

print("Compare outputs:\nvmap:\n", matrix_matrix(jax_array, jax_array), "\nregular dot product:\n", jnp.dot(jax_array, jax_array))

# The out_axes argument specifies where to store the parallelized axis in the output. 
# By setting it to 0 in the next function, we effectively transpose the output of the matrix-matrix multiplication.
matrix_matrix_transpose = jax.vmap(matrix_vector, in_axes=(None, 1), out_axes=0)      #  ([b,a], [a,c]) -> [c,b]  (c is the mapped axis)
print("\n\nTranspose outputs:\n", matrix_matrix_transpose(jax_array, jax_array))


Compare outputs:
vmap:
 [[ 7 10]
 [15 22]] 
regular dot product:
 [[ 7 10]
 [15 22]]


Transpose outputs:
 [[ 7 15]
 [10 22]]
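
`vmap` also composes with other transformations. As a sketch of one common pattern, the snippet below computes one gradient per example by mapping `jax.grad` over the batch axis while broadcasting the weights with `in_axes=(None, 0, 0)`. The function `squared_error` and the data are invented for the example.

```python
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # Loss for a single example: x has shape (3,), y is a scalar.
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([1.0, 0.0, -1.0])
xs = jnp.array([[1.0, 2.0, 3.0],
                [0.0, 1.0, 0.0]])
ys = jnp.array([0.0, 1.0])

# Differentiate w.r.t. w, then map over the batch axis of xs and ys only.
per_example_grads = jax.vmap(jax.grad(squared_error), in_axes=(None, 0, 0))(w, xs, ys)
print(per_example_grads.shape)  # (2, 3): one gradient row per example
print(per_example_grads)
```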


### Parallelization across devices with **`pmap`**
`pmap` (parallel map) is the multi-device counterpart of `vmap`: it distributes computations across multiple devices, typically several GPUs or TPUs. While `vmap` vectorizes computations, `pmap` physically runs them in parallel across devices. Here we build on the matrix multiplication example to introduce `pmap`.

**Setup.** Before using `pmap`, check which devices are available:

In [22]:
# assume that you have multiple gpus
devices = jax.devices()
print(devices)

[cuda(id=0)]


**NOTE** Take the result below with a grain of salt. In real-world scenarios, GPU computation is often orders of magnitude faster than CPU for tasks like matrix multiplication, and the data transfer between the host (CPU) and the device (GPU) can add significant overhead. So, while this is an interesting demonstration, the setup might not be efficient in practice.

In [23]:
# Create two arrays with leading dimension 2, intended for distribution across two devices.
# Each sub-array (of shape (10, 10)) will be sent to a separate device.
x_sharded = jax.random.normal(jax.random.PRNGKey(123), shape=(2, 10, 10))
y_sharded = jax.random.normal(jax.random.PRNGKey(456), shape=(2, 10, 10))  # Note: Using a different seed for variety.

# Define a parallelized matrix multiplication operation using `pmap`.
# Note: We pass `devices` explicitly to control where the operation runs.
# Only one GPU is available here, so we duplicate the device list to obtain two entries;
# on a multi-device machine, pass the list of distinct devices instead.
pmap_multiply = jax.pmap(matrix_matrix, devices=devices*2)

# Execute the parallel matrix multiplication. Each pair of sub-arrays from x_sharded and y_sharded 
# will be multiplied on a separate device.
result = pmap_multiply(x_sharded, y_sharded)

# Explore the result buffers, which provides details about the shape of the result 
# and the device on which each computation was performed.
for buffer in result.device_buffers:
    print("Shape:", buffer.shape)
    print("Device:", buffer.device_buffer.device())

# Print the shape of the `result` to demonstrate that, to the end user,
# the result appears as a standard array, even though it was computed in a parallelized manner.
print("Result Shape:", result.shape)

Shape: (10, 10)
Device: cuda:0
Shape: (10, 10)
Device: cuda:0
Result Shape: (2, 10, 10)


### Benefits and When to Use

1. **Performance Improvement:** Parallelizing code, especially on accelerators like GPUs, can lead to significant speedups. Operations like matrix multiplications, which can be heavily parallelized, benefit greatly.

2. **Memory Efficiency:** For large-scale computations, distributing data and computations across multiple devices can help in tackling memory limitations of a single device.

However, one should be aware of the communication overhead between devices. If the devices need to communicate frequently or exchange large amounts of data, the benefits of parallelization might be diminished.
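
To keep a `pmap` example runnable on any machine, one option is to size the leading axis from the number of local devices. This is a sketch under that assumption (the names `matmul`, `x`, `y` are illustrative); on a single-device machine it simply runs on one device.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # 1 on a CPU-only or single-GPU machine

# One (4, 4) matrix pair per device along the leading axis.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (n_devices, 4, 4))
y = jax.random.normal(key_y, (n_devices, 4, 4))

def matmul(a, b):
    return jnp.dot(a, b)

# pmap maps over the leading axis, placing each slice on its own device.
result = jax.pmap(matmul)(x, y)

# The same computation vectorized on a single device, for comparison.
reference = jax.vmap(matmul)(x, y)
print(result.shape, bool(jnp.allclose(result, reference, atol=1e-2)))
```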

## Pytrees

A "pytree" is a tree of Python containers (e.g., lists, tuples, dicts) that can contain numerical arrays (like `numpy.ndarray` or `jax.Array`) as leaves. In the context of JAX, pytrees are useful because many JAX operations that take array arguments also support pytree arguments, allowing you to build and manipulate more complex structures.
For a more comprehensive intro see the [official documentation](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html), which is very good.

### Basic Pytree Examples

Here are some examples of pytrees:



In [19]:
# A single array
leaf = jnp.array([1.0, 2.0, 3.0])

# A tuple of arrays
tree = (jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0]))

# A dictionary of arrays
tree = {"a": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array([4.0, 5.0])}

### Tree Utilities

JAX provides utilities to manipulate pytrees:

- **tree_map**: applies a function to each leaf in a pytree or multiple pytrees.
- **tree_reduce**: aggregates values across all the leaves of a pytree.

Many others can be found in the `jax.tree_util` module.

In [24]:

tree = {"a": jnp.array([1.0, 2.0]), "b": (jnp.array([3.0, 4.0]), jnp.array([5.0]))}
result = jax.tree_util.tree_map(lambda x: x * 2, tree)  # the jax.tree_map alias may be missing in newer releases
print(result)


{'a': Array([2., 4.], dtype=float32), 'b': (Array([6., 8.], dtype=float32), Array([10.], dtype=float32))}
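
`tree_reduce` is listed above but not demonstrated. As a minimal sketch on the same tree, it can count the scalar entries or sum all values; `tree_leaves` is shown alongside because it exposes the flat list of leaves that the reduction folds over.

```python
import jax
import jax.numpy as jnp

tree = {"a": jnp.array([1.0, 2.0]), "b": (jnp.array([3.0, 4.0]), jnp.array([5.0]))}

# Flatten the pytree into a list of its leaves.
leaves = jax.tree_util.tree_leaves(tree)

# Fold over the leaves: count scalar entries and sum all values.
n_entries = jax.tree_util.tree_reduce(lambda acc, leaf: acc + leaf.size, tree, 0)
total = jax.tree_util.tree_reduce(lambda acc, leaf: acc + jnp.sum(leaf), tree, 0.0)

print(len(leaves), n_entries, float(total))  # 3 5 15.0
```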


### Why are Pytrees Useful in JAX?

When defining models or optimization routines in JAX, it's common to have parameters structured in nested containers, especially when using neural network libraries like Flax or Haiku. Pytrees allow for easy manipulation and transformation of these structures without having to manually manage each individual array.

Understanding how to work with pytrees helps in:

- Building more complex models with structured parameters.
- Writing generic JAX code that works seamlessly with these complex structures.
- Utilizing JAX transformations like grad on functions that accept or return these structures.

### Pytrees and Neural Networks

When you're working with neural networks, the weights and biases are typically stored as arrays. As the network grows in complexity, managing these arrays individually can become cumbersome. Pytrees provide a more structured way to handle these parameters, making the code more readable and easier to maintain.

### Two-Layer Neural Network Revisited with Pytrees

First, we define our network:


In [25]:
def init_network(input_dim, hidden_dim, output_dim):
    params = {
        "layer1": {
            "weights": jax.random.normal(jax.random.PRNGKey(123), shape=(input_dim, hidden_dim)),
            "biases": jnp.zeros(hidden_dim)
        },
        "layer2": {
            "weights": jax.random.normal(jax.random.PRNGKey(246), shape=(hidden_dim, output_dim)),
            "biases": jnp.zeros(output_dim)
        }
    }
    return params

Notice how we're using nested dictionaries to represent the layers and their respective weights and biases.

Next, our forward pass:

In [26]:
def forward(params, x):
    # First layer
    z1 = jnp.dot(x, params["layer1"]["weights"]) + params["layer1"]["biases"]
    a1 = jax.nn.relu(z1)
    
    # Second layer
    z2 = jnp.dot(a1, params["layer2"]["weights"]) + params["layer2"]["biases"]
    a2 = jax.nn.relu(z2)
    
    return a2

### Advantages of Pytrees and Tree Operations

1. **Structured Representation**: Pytrees provide a clear and organized representation of model parameters, especially for models with multiple layers or components.

2. **Simplified Parameter Updates**: With the tree utilities provided by JAX, operations like updating model parameters after a gradient descent step become more streamlined.

3. **Scalability**: As your model grows, pytrees keep the code maintainable. Whether you have two layers or twenty, the structure remains consistent.

4. **Generic Code**: You can write more general-purpose functions that operate over model parameters, without having to know the exact structure or layout of the model. For example, you can write a generic training loop for any model, as long as the parameters are structured as pytrees.

5. **Compatible with JAX Transformations**: Pytrees work seamlessly with JAX functions like `grad`, `jit`, and others. When you compute the gradient of a function with respect to a pytree of parameters, JAX returns the gradients in the same pytree structure.

### Example: Gradient Descent Update with Pytrees

Using JAX's tree utilities, updating the parameters of the network after computing gradients becomes straightforward:


In [28]:
# Mean squared error loss
def loss(params, x, y_true):
    y_pred = forward(params, x)
    return jnp.mean((y_pred - y_true) ** 2)

# Compute gradients of the loss with respect to the parameters
grad_fn = jax.grad(loss)

def gradient_descent_step(params, x, y_true, learning_rate=0.01):
    gradients = grad_fn(params, x, y_true)
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, gradients)

Here, `tree_map` is used to subtract the computed gradients from the current parameters, effectively implementing a gradient descent update. The code remains clean and general, regardless of the specific structure of the `params` pytree.

In summary, pytrees and tree operations make handling the parameters of neural networks (or any machine learning model) in JAX more organized, flexible, and scalable.

In [29]:
def print_leaf_shapes(d, path=None):
    """Recursively print the shape of every array in a nested dict."""
    path = path or []  # avoid a mutable default argument
    for k, v in d.items():
        new_path = path + [k]
        if isinstance(v, dict):
            print_leaf_shapes(v, new_path)
        else:
            # Assuming the leaf is a JAX array; replace with appropriate check if not
            print(f"{' -> '.join(new_path)}: {v.shape}")


update = gradient_descent_step(init_network(1, 10, 1), x_data[:, None], y_data[:, None])

# Returns a tree with the same structure as the input parameters.
print_leaf_shapes(update)

layer1 -> biases: (10,)
layer1 -> weights: (1, 10)
layer2 -> biases: (1,)
layer2 -> weights: (10, 1)
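
The same check can be written with the tree utilities themselves instead of hand-written recursion. In this sketch, `toy_loss` and the constant parameter values are hypothetical stand-ins used only to obtain gradients with the same nested layout.

```python
import jax
import jax.numpy as jnp

params = {
    "layer1": {"weights": jnp.ones((1, 10)), "biases": jnp.zeros(10)},
    "layer2": {"weights": jnp.ones((10, 1)), "biases": jnp.zeros(1)},
}

def toy_loss(p, x):
    # Hypothetical loss, used only to obtain gradients with the params structure.
    hidden = jax.nn.relu(x @ p["layer1"]["weights"] + p["layer1"]["biases"])
    return jnp.mean(hidden @ p["layer2"]["weights"] + p["layer2"]["biases"])

grads = jax.grad(toy_loss)(params, jnp.ones((4, 1)))

# Map every leaf to its shape; no manual recursion over dicts is needed.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, grads)
same_structure = (jax.tree_util.tree_structure(grads)
                  == jax.tree_util.tree_structure(params))
print(shapes, same_structure)
```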



## Resources

- [**Awesome JAX**](https://github.com/n2cholas/awesome-jax): a curated list of JAX libraries, projects, and other resources.
