## **JAX**

JAX is a library for array-oriented numerical computation, with automatic differentiation and JIT compilation to enable high-performance machine learning research.

1. JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
2. JAX features built-in Just-in-Time (JIT) compilation via XLA, an open source machine learning compiler ecosystem.
3. JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.
4. JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.

In [9]:
import jax
import jax.numpy as jnp

With the above imports, we can immediately start using JAX in much the same way as NumPy.

In [10]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


JAX works great for many numerical and scientific programs, but only if they are written with certain constraints, as explained in [tutorial_n.ipynb](#add_link_when_done)

### **Just-in-time compilation with `jax.jit()`**

JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above code, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the `jax.jit()` function to compile this sequence of operations together using XLA.


We can use IPython's `%timeit` to quickly benchmark our `selu` function, using `block_until_ready()` to account for JAX's asynchronous dispatch. See [tutorial_async](#add_it_too) for more.

In [11]:
from jax import random

key = random.key(135)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

786 μs ± 29.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
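Outside a notebook, where the `%timeit` magic is not available, the same measurement can be sketched with `time.perf_counter()`. This is only a rough illustration: the helper name `time_call` and its `repeats` parameter are hypothetical, and the numbers it prints depend entirely on your hardware.

```python
import time

import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


def time_call(fn, *args, repeats=20):
    """Average wall-clock seconds per call (hypothetical helper)."""
    fn(*args).block_until_ready()  # warm-up call, excluded from the timing
    start = time.perf_counter()
    for _ in range(repeats):
        # Without block_until_ready() we would only time the dispatch,
        # not the computation itself.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jax.random.normal(jax.random.key(135), (1_000_000,))
seconds = time_call(selu, x)
print(f"selu: {seconds * 1e6:.0f} microseconds per call")
```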


We can speed up this function with the `jax.jit()` transformation, which JIT-compiles `selu` the first time it is called and caches the compiled version for later calls with the same input shapes and dtypes.

In [12]:
from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # warmup
%timeit selu_jit(x).block_until_ready()

221 μs ± 12 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
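To see when compilation actually happens, one option is to put a Python side effect inside a jitted function: it runs only while JAX traces the function, not when a cached version is reused. The sketch below assumes a toy function `scaled_sum` and a list `trace_log`, both made up for this illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records the input shape at each trace


@jax.jit
def scaled_sum(x):
    # This append runs only during tracing, so it shows how often
    # JAX had to trace (and compile) the function.
    trace_log.append(x.shape)
    return jnp.sum(2.0 * x)


scaled_sum(jnp.ones(3))   # first call: traced and compiled
scaled_sum(jnp.zeros(3))  # same shape and dtype: cached version reused
scaled_sum(jnp.ones(4))   # new shape: traced and compiled again

print(trace_log)  # [(3,), (4,)]
```

The second call does not add an entry because the compiled version for shape `(3,)` is reused, while the new shape `(4,)` triggers another trace.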


This is just the execution time on CPU; the same code can run on GPU or TPU, typically for an even greater speedup.

### Taking derivatives with `jax.grad()`

In addition to transforming functions via JIT compilation, JAX also provides other transformations. One such transformation is `jax.grad()`, which performs **automatic differentiation**.

Automatic differentiation (AD) is a technique for computing derivatives (gradients) of functions automatically and exactly (up to floating-point precision), using the chain rule. It is neither symbolic (like `sympy`) nor numerical (like finite differences); it is a programmatic way of doing calculus.

Let's say you have a function:

In [13]:
def f(x):
    return x**2 + 3*x + 2

and you want to compute $\frac{df}{dx}$

With AD, you don't have to compute that by hand; JAX (or another AD framework such as PyTorch) will do it automatically and exactly.

#### Types of AD

1. Forward Mode AD:
    - Computes derivatives alongside the function values as the function runs.
    - Efficient for functions with few inputs and many outputs.
2. Reverse Mode AD (backpropagation):
    - Runs the function forward, and then computes derivatives backward.
    - Efficient when you have many inputs and one output (as in neural nets).
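Both modes can be tried directly in JAX. The following is a minimal sketch using `jax.jvp` (forward mode) and `jax.vjp` (reverse mode) on the same polynomial; the variable names are chosen only for this illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    return x**2 + 3*x + 2


x = jnp.array(2.0)

# Forward mode: push a tangent (here 1.0) through the function
# alongside the value.
value_fwd, tangent_out = jax.jvp(f, (x,), (jnp.array(1.0),))

# Reverse mode: run the function forward first, then pull a
# cotangent (here 1.0) backward through it.
value_rev, pullback = jax.vjp(f, x)
(cotangent_in,) = pullback(jnp.ones_like(value_rev))

print(value_fwd, tangent_out)   # f(2) and f'(2) via forward mode
print(value_rev, cotangent_in)  # f(2) and f'(2) via reverse mode
```

For a scalar function of one variable both modes give the same derivative; the difference in cost only shows up when the number of inputs and outputs differs.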

In [14]:
grad_f = jax.grad(f)
print(grad_f(2.0))  # Should print 7.0, which is the derivative of f at x=2

7.0


**JAX** does automatic differentiation using a system of tracing and transformation of Python functions.

*JAX doesn't just run your function; it records the operations and applies the chain rule step by step. For `jax.grad()`, this is done with reverse-mode AD.*

### **Step 1: Tracing**

When you call `jax.grad(f)`, JAX doesn't immediately compute the gradient. Instead:

- It traces the function, meaning it runs the function once with a special object called a **Tracer** instead of a regular number.
- The **Tracer** object records every operation (such as `x**2` and `+`) in a computation graph.


So our function becomes a chain of elementary operations:

$ x \rightarrow x^2 \rightarrow 3x \rightarrow x^2 + 3x \rightarrow x^2 + 3x + 2 $
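Tracing can be observed by recording the type of the argument the function receives. This is a small sketch: the list `seen_types` is a hypothetical helper, and the exact tracer class name depends on the JAX version, so no specific name is assumed here.

```python
import jax

seen_types = []  # hypothetical helper: class names of the values f receives


def f(x):
    seen_types.append(type(x).__name__)
    return x**2 + 3*x + 2


f(2.0)            # ordinary call: x is a plain Python float
jax.grad(f)(2.0)  # differentiated call: x is replaced by a tracer object

print(seen_types)
```

The first entry is `float`; the second is the name of whichever tracer class JAX substituted for `x` while differentiating.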

### **Step 2: Chain Rule (Reverse Mode)**


JAX walks backward through this computation graph and applies the chain rule to compute how each operation contributed to the final result. For the function

$$ f(x) = x^2 + 3x + 2 $$

The derivative is:

$$ f'(x) = 2x + 3 $$


But JAX doesn't *differentiate algebraically* like `sympy`; it builds and walks a graph like this:


```
          x = 2.0
           |
     +-----+------+
     |            |
   x**2         3*x
     |            |
     +-----+------+
           |
          +2
           |
         Output
```
Backward pass:


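$$
\begin{aligned}
\frac{df}{dx} &= \frac{\partial\,\text{output}}{\partial (x^2)} \cdot \frac{\partial (x^2)}{\partial x} + \frac{\partial\,\text{output}}{\partial (3x)} \cdot \frac{\partial (3x)}{\partial x} \\
&= 1 \cdot 2x + 1 \cdot 3 = 2x + 3
\end{aligned}
$$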
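The same backward walk can be written out by hand in plain Python. This is only an illustration of the bookkeeping for this one function, not how JAX is implemented; the function name `f_with_grad` and the intermediate names are made up for the sketch.

```python
def f_with_grad(x):
    # Forward pass: evaluate and keep each intermediate value.
    a = x ** 2   # left branch of the graph
    b = 3 * x    # right branch of the graph
    c = a + b
    out = c + 2

    # Backward pass: start from d(out)/d(out) = 1 and walk the graph in reverse.
    d_out = 1.0
    d_c = d_out * 1.0            # out = c + 2
    d_a = d_c * 1.0              # c = a + b
    d_b = d_c * 1.0
    d_x = d_a * 2 * x + d_b * 3  # a = x**2 and b = 3*x both depend on x
    return out, d_x


print(f_with_grad(2.0))  # (12.0, 7.0)
```

The two contributions to `d_x` correspond to the two terms of the sum in the backward-pass formula.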


You can inspect JAX's intermediate representation of the function (its computation graph, called a jaxpr) with `make_jaxpr`:

In [15]:
from jax import make_jaxpr
print(make_jaxpr(f)(2.0))


{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 3.0:f32[] a
    d:f32[] = add b c
    e:f32[] = add d 2.0:f32[]
  in (e,) }


Let's see more examples:

In [16]:
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


x_small = jnp.arange(3.)

derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))  # Should print the derivative of sum_logistic at x_small


[0.25       0.19661194 0.10499357]


Let's verify using finite differences:

In [19]:
def first_finite_difference(f, x, h=1e-3):
    return jnp.array([(f(x + h * v) - f(x - h * v)) / (2 * h) for v in jnp.eye(len(x))])

print(first_finite_difference(sum_logistic, x_small))  # Should be close to the derivative computed by JAX

[0.24998187 0.1965761  0.10502338]
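Rather than comparing the printed numbers by eye, the agreement can be checked programmatically. The sketch below repeats the definitions so it runs on its own, and also compares against the closed-form derivative of the logistic function, $\sigma(x)(1 - \sigma(x))$. The tolerances are assumptions, chosen loosely for float32 arithmetic.

```python
import jax
import jax.numpy as jnp


def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


def first_finite_difference(f, x, h=1e-3):
    # Central difference along each coordinate direction.
    return jnp.array([(f(x + h * v) - f(x - h * v)) / (2 * h) for v in jnp.eye(len(x))])


x_small = jnp.arange(3.)
autodiff = jax.grad(sum_logistic)(x_small)
numerical = first_finite_difference(sum_logistic, x_small)

# Closed form: d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
sigmoid = 1.0 / (1.0 + jnp.exp(-x_small))
analytic = sigmoid * (1.0 - sigmoid)

# atol values are assumed tolerances, not library defaults.
print(jnp.allclose(autodiff, numerical, atol=1e-3))
print(jnp.allclose(autodiff, analytic, atol=1e-6))
```

The finite-difference comparison needs the looser tolerance because its error is dominated by float32 round-off divided by the small step `h`.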


We can also `jit` compile it:

In [None]:
print(grad(jit(grad(jit(grad(jit(grad(sum_logistic)))))))(2.0))  

0.0207841


So `grad()` and `jit()` can be composed and mixed arbitrarily.

Beyond scalar-valued functions, the `jax.jacobian()` transformation can be used to compute the full Jacobian matrix of vector-valued functions:

In [22]:
from jax import jacobian

print(jacobian(jnp.exp)(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]


The Jacobian is the matrix of all first-order partial derivatives of a function with multiple inputs and outputs.

If we have a function:
$$ f : \mathbb{R}^n \rightarrow \mathbb{R}^m$$

then the Jacobian is an $m \times n$ matrix:

$$ J_{ij} = \frac{\partial f_i}{ \partial x_j} $$

It tells you how much each output depends on each input.
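The $m \times n$ shape can be checked on a small example. The sketch below uses a made-up function `g` from $\mathbb{R}^3$ to $\mathbb{R}^2$ and compares JAX's forward-mode (`jax.jacfwd`) and reverse-mode (`jax.jacrev`) Jacobians, which should agree.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Hypothetical example function mapping R^3 -> R^2.
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])


x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jax.jacfwd(g)(x)  # forward mode: builds the Jacobian column by column
J_rev = jax.jacrev(g)(x)  # reverse mode: builds the Jacobian row by row

print(J_fwd.shape)  # (2, 3): m = 2 outputs, n = 3 inputs
print(jnp.allclose(J_fwd, J_rev))
```

Forward mode tends to be the cheaper choice for "tall" Jacobians (more outputs than inputs) and reverse mode for "wide" ones.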


In [23]:
def f(x): 
    return jnp.array([
        x[0] + 2 * x[1],
        x[0]**2 + x[1]**2,
    ])

The Jacobian matrix is:
$$
 J = \begin{bmatrix}
\frac{\partial f_0}{\partial x_0} & \frac{\partial f_0}{\partial x_1} \\
\frac{\partial f_1}{\partial x_0} & \frac{\partial f_1}{\partial x_1}
\end{bmatrix}
=
\begin{bmatrix}
1 & 2 \\
2x_0 & 2x_1
\end{bmatrix} $$


In [24]:
x = jnp.array([1.0, 2.0])
jacobian_f = jacobian(f)
print(jacobian_f(x))  # Should print the Jacobian matrix of f at x

[[1. 2.]
 [2. 4.]]
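At $x = (1, 2)$ the Jacobian happens to be symmetric, so it cannot tell rows from columns. A point such as $x = (3, 5)$ is a more discriminating check. The sketch below compares `jax.jacobian` against hand-derived partial derivatives; the helper `analytic_jacobian` is written only for this comparison.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.array([
        x[0] + 2 * x[1],
        x[0]**2 + x[1]**2,
    ])


def analytic_jacobian(x):
    # Hand-derived partials: row i holds the derivatives of output f_i.
    return jnp.array([
        [1.0, 2.0],
        [2 * x[0], 2 * x[1]],
    ])


x = jnp.array([3.0, 5.0])
J = jax.jacobian(f)(x)

print(J)
print(jnp.allclose(J, analytic_jacobian(x)))
```

Row 0 is constant because $f_0$ is linear in both inputs, while row 1 scales with $x$ because $f_1$ is quadratic.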
