In [None]:
import numpy as np
import jax
from jax import numpy as jnp

# JAX

JAX is a Python library that combines NumPy's familiar array operations with automatic differentiation and hardware acceleration (GPU/TPU). At its core, JAX provides composable function transformations that enable efficient gradient computation, vectorization, and just-in-time compilation. For scientific computing, JAX's key innovation is making gradient-based optimization as simple as calling `grad(f)` on a scalar-valued function `f` written with `jax.numpy`.
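As a minimal sketch of that idea, `jax.grad` turns a loss into a new function that returns its gradient, which plugs straight into plain gradient descent. The quadratic `loss`, its minimizer `[1, -2]`, the step size, and the iteration count are illustrative choices, not taken from the text above:

```python
import jax
import jax.numpy as jnp

def loss(w):
    # illustrative quadratic with its minimum at w = [1, -2]
    return jnp.sum((w - jnp.array([1.0, -2.0])) ** 2)

grad_loss = jax.grad(loss)  # new function: w -> d(loss)/dw

w = jnp.zeros(2)
lr = 0.1  # illustrative step size
for _ in range(100):
    w = w - lr * grad_loss(w)  # one gradient-descent step

print(w)
```

Each step shrinks the distance to the minimizer by a factor of 0.8 here, so after 100 steps `w` sits essentially at `[1, -2]`.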

## Automatic Differentiation

Automatic differentiation (AD) computes derivatives by systematically applying the chain rule to elementary operations. Given a function $f: \mathbb{R}^n \to \mathbb{R}^m$ composed of operations as
$$
f = f_N \circ f_{N-1} \circ \cdots \circ f_1 \; ,
$$
AD evaluates the Jacobian of $f$ at $\mathbf{x}$ (the gradient $\nabla f(\mathbf{x})$ when $m = 1$) by accumulating partial derivatives through the chain rule as
$$
\frac{\partial f}{\partial x_i} = \frac{\partial f_N}{\partial f_{N-1}} \frac{\partial f_{N-1}}{\partial f_{N-2}} \cdots \frac{\partial f_{k+1}}{\partial f_k} \cdots \frac{\partial f_1}{\partial x_i} \; ,
$$
where each factor is the Jacobian of one elementary operation, so the whole computation is a product of Jacobian matrices. Unlike numerical differentiation (finite differences), which suffers from truncation and round-off error, or symbolic differentiation, which can produce exponentially large expressions, AD computes derivatives that are exact up to floating-point rounding. In reverse mode (backpropagation), the gradient of a scalar function costs only a small constant multiple of evaluating $f$ itself.
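The chain-rule product can be checked numerically. The sketch below uses two small hypothetical maps, `f1` ($\mathbb{R}^3 \to \mathbb{R}^2$) and `f2` ($\mathbb{R}^2 \to \mathbb{R}$, kept as a length-1 array), multiplies their Jacobians by hand, and compares the result with `jax.jacfwd` (forward mode) and `jax.jacrev` (reverse mode) applied to the composition:

```python
import jax
import jax.numpy as jnp

def f1(x):
    # hypothetical elementary map R^3 -> R^2
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

def f2(u):
    # hypothetical elementary map R^2 -> R^1 (kept as a length-1 array)
    return jnp.stack([jnp.exp(u[0]) + u[1] ** 2])

def composite(x):
    return f2(f1(x))

x_pt = jnp.array([0.5, -1.0, 2.0])

J1 = jax.jacfwd(f1)(x_pt)        # (2, 3) Jacobian of f1 at x
J2 = jax.jacfwd(f2)(f1(x_pt))    # (1, 2) Jacobian of f2 at f1(x)
J_chain = J2 @ J1                # chain rule as a Jacobian product, (1, 3)

J_fwd = jax.jacfwd(composite)(x_pt)  # forward-mode AD of the composition
J_rev = jax.jacrev(composite)(x_pt)  # reverse-mode AD of the composition
print(jnp.allclose(J_chain, J_fwd), jnp.allclose(J_chain, J_rev))
```

Forward mode builds the Jacobian one input direction at a time and reverse mode one output at a time, which is why reverse mode is the natural choice for scalar losses with many parameters.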

Let's look at a simple example of using JAX for automatic differentiation. Consider the function
$$
f(x | A, b) = \sum _ i \frac{1}{1 + e^{-(A x + b)_i}} \; .
$$
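Because $f$ is so simple, its gradient also has a closed form. Writing $\mathbf{y} = A\mathbf{x} + \mathbf{b}$ and $\sigma'(y) = \sigma(y)\,(1 - \sigma(y))$ for the sigmoid $\sigma$, we get $\partial f / \partial b_i = \sigma'(y_i)$ and $\partial f / \partial A_{ij} = \sigma'(y_i)\, x_j$. The sketch below checks `jax.grad` against these formulas; the names `sigmoid_sum`, `A0`, `b0`, `x0` and the seed are illustrative choices:

```python
import numpy as np
import jax
import jax.numpy as jnp

def sigmoid_sum(params, x):
    # same function as f below: sum_i sigmoid((A x + b)_i)
    A, b = params
    return jax.nn.sigmoid(A @ x + b).sum()

rng = np.random.default_rng(0)  # arbitrary seed for this sketch
A0 = jnp.array(rng.standard_normal((10, 3)))
b0 = jnp.array(rng.standard_normal(10))
x0 = jnp.array(rng.standard_normal(3))

grad_A, grad_b = jax.grad(sigmoid_sum)((A0, b0), x0)

# Closed-form derivatives with y = A x + b and sigma'(y) = sigma(y) (1 - sigma(y))
s = jax.nn.sigmoid(A0 @ x0 + b0)
ds = s * (1.0 - s)
grad_b_manual = ds                 # df/db_i  = sigma'(y_i)
grad_A_manual = jnp.outer(ds, x0)  # df/dA_ij = sigma'(y_i) x_j

print(jnp.allclose(grad_A, grad_A_manual, atol=1e-6),
      jnp.allclose(grad_b, grad_b_manual, atol=1e-6))
```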

In [2]:
def f(params, x):
    A, b = params
    y = A @ x + b
    return jax.nn.sigmoid(y).sum()

In [3]:
# Randomness works differently in JAX (explicit PRNG keys), so we use NumPy for this simple example
x = jnp.array(np.random.randn(3)) 
A = jnp.array(np.random.randn(10, 3))
b = jnp.array(np.random.randn(10))

f_val, f_grad = jax.value_and_grad(f)((A, b), x)

In [4]:
f_val

Array(6.7408266, dtype=float32)

In [5]:
f_grad  # A tuple of gradients with respect to A and b!

(Array([[ 0.04165042,  0.03343025, -0.10853758],
        [ 0.08074013,  0.06480516, -0.21040215],
        [ 0.12888701,  0.10344972, -0.33586895],
        [ 0.11841699,  0.09504608, -0.30858496],
        [ 0.10518608,  0.08442643, -0.2741063 ],
        [ 0.02413005,  0.01936772, -0.06288093],
        [ 0.08724204,  0.07002385, -0.22734559],
        [ 0.07838579,  0.06291548, -0.20426694],
        [ 0.09324119,  0.074839  , -0.24297887],
        [ 0.1004806 ,  0.08064964, -0.2618442 ]], dtype=float32),
 Array([0.07803333, 0.15126908, 0.2414737 , 0.22185782, 0.1970693 ,
        0.04520838, 0.1634506 , 0.14685816, 0.17469019, 0.18825345],      dtype=float32))

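JAX knows how to propagate gradients through Python containers, too: `f_grad` has the same `(A, b)` tuple structure as `params`. JAX calls such nested tuples, lists, and dicts of arrays *pytrees*.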
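The same works for dictionaries. In this sketch (the names `loss_dict`, `params_dict`, `x_demo` and the learning rate are hypothetical), the gradient comes back as a dict with the same keys as the parameters, and `jax.tree_util.tree_map` applies an update leaf by leaf:

```python
import jax
import jax.numpy as jnp

def loss_dict(params, x):
    # params is a dict; any nesting of tuples/lists/dicts of arrays works
    y = params["W"] @ x + params["b"]
    return jnp.sum(y ** 2)

params_dict = {"W": jnp.ones((2, 3)), "b": jnp.zeros(2)}
x_demo = jnp.array([1.0, 2.0, 3.0])

grads = jax.grad(loss_dict)(params_dict, x_demo)  # same dict structure as params_dict

# One gradient-descent step applied to every leaf of the pytree
lr = 0.01
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params_dict, grads)
print(jax.tree_util.tree_map(jnp.shape, grads))
```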

## Automatic Vectorization with `vmap`

Let's see how `vmap` automatically vectorizes functions. Suppose we want to evaluate $f$ at multiple input points $\{x_1, x_2, \ldots, x_N\}$ simultaneously. Without `vmap`, we would need to manually loop or add batch dimensions to our function.

In [6]:
# Create a batch of inputs: shape (100, 3)
x_batch = jnp.array(np.random.randn(100, 3))

# Naive approach: loop over batch
results_loop = jnp.array([f((A, b), x_i) for x_i in x_batch])

In [7]:
# With vmap: automatically vectorize over the batch dimension
# in_axes=(None, 0) means: don't map over params, but map over axis 0 of x
f_batched = jax.vmap(f, in_axes=(None, 0))
results_vmap = f_batched((A, b), x_batch)

In [8]:
# Verify they produce the same results
print(f"Results match: {jnp.allclose(results_loop, results_vmap)}")
print(f"Output shape: {results_vmap.shape}")

Results match: True
Output shape: (100,)


The `in_axes` argument specifies which axis of each argument to map over: `None` means "don't vectorize this argument", while `0` means "map over axis 0". Transformations also compose: `vmap(grad(f, argnums=1))` computes the gradient of $f$ with respect to $x$ for every input in the batch in a single vectorized call.
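`in_axes` can also mirror the structure of a container argument. As a sketch (the names `sigmoid_sum`, `A_demo`, `x_demo`, `b_batch` and the batch size of 5 are illustrative), `((None, 0), None)` maps over the leading axis of `b` inside the `(A, b)` tuple while keeping `A` and `x` fixed:

```python
import numpy as np
import jax
import jax.numpy as jnp

def sigmoid_sum(params, x):
    # same function as f above
    A, b = params
    return jax.nn.sigmoid(A @ x + b).sum()

rng = np.random.default_rng(0)
A_demo = jnp.array(rng.standard_normal((10, 3)))
x_demo = jnp.array(rng.standard_normal(3))
b_batch = jnp.array(rng.standard_normal((5, 10)))  # 5 different bias vectors

# in_axes mirrors the arguments: params=(A, b) -> (None, 0), x -> None
f_over_b = jax.vmap(sigmoid_sum, in_axes=((None, 0), None))
vals = f_over_b((A_demo, b_batch), x_demo)  # one value per bias vector

vals_loop = jnp.stack([sigmoid_sum((A_demo, b_i), x_demo) for b_i in b_batch])
print(vals.shape, jnp.allclose(vals, vals_loop))
```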

In [9]:
# Example: batched gradients
grad_f_batched = jax.vmap(jax.grad(f, argnums=1), in_axes=(None, 0))
gradients_batch = grad_f_batched((A, b), x_batch)
print(f"Gradient shape: {gradients_batch.shape}")  # one gradient per input

Gradient shape: (100, 3)


## Just-In-Time Compilation with `jit`

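JAX can compile functions to optimized machine code using `jit` (just-in-time compilation). `jit` traces the function with abstract values (shapes and dtypes rather than concrete numbers), records the resulting computation graph, and hands it to XLA (Accelerated Linear Algebra), which optimizes and compiles it. The first call for a given combination of input shapes and dtypes pays this compilation cost; later calls with matching shapes and dtypes reuse the cached compiled code and are typically faster.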
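One way to watch tracing and caching happen, as a sketch: a Python side effect inside a jitted function runs only while JAX traces it, so counting it reveals when (re)compilation occurs. `scaled_sum` and `trace_count` are hypothetical names, and relying on side effects like this is for illustration only:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))        # first call: trace + compile
scaled_sum(jnp.ones(3) * 5.0)  # same shape/dtype: cached, no retrace
scaled_sum(jnp.ones(4))        # new shape: trace + compile again
print(trace_count)
```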

In [38]:
# Let's define a more complex function to see the speedup
def dumb_computation(x):
    x = jnp.tanh(x @ x.T)
    x = jnp.exp(-x**2) @ x
    x = jnp.sin(x) + jnp.cos(x.T @ x)
    x = jax.nn.softmax(x, axis=-1)
    x = jnp.linalg.matrix_power(x, 3)
    return jnp.sum(x * jnp.log(jnp.abs(x) + 1e-8))

# Create test data
test_x = jnp.array(np.random.randn(50, 50))

In [39]:
out = dumb_computation(test_x)
out

Array(-190.08969, dtype=float32)

In [40]:
# Without jit: interpreted execution
%timeit -n 3 -r 3 dumb_computation(test_x).block_until_ready()

674 μs ± 230 μs per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [41]:
dumb_computation_jit = jax.jit(dumb_computation)

In [42]:
out_jit = dumb_computation_jit(test_x)
out_jit

Array(-190.0897, dtype=float32)

In [43]:
# With jit: compiled execution
%timeit -n 3 -r 3 dumb_computation_jit(test_x).block_until_ready()

278 μs ± 50.8 μs per loop (mean ± std. dev. of 3 runs, 3 loops each)


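Note the `.block_until_ready()` call: JAX dispatches work asynchronously, so without it `%timeit` would mostly measure how long it takes to enqueue the computation rather than to finish it. The first call in `In [42]` already triggered compilation, so the jitted timing excludes compile time. In this run the compiled version is about 2.4× faster on average (278 μs vs. 674 μs per call), though the run-to-run spread is large; the speedup depends on the function, the input size, and the hardware, and tends to be largest when XLA can fuse many small operations. You can also use `jit` as a decorator by writing `@jax.jit` above a function definition.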
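The decorator form also takes options. A common case, sketched below with a hypothetical `power_sum`, is an argument that drives Python control flow: under plain `jax.jit` the loop bound `n` would be an abstract tracer and `range(n)` would raise an error, so `static_argnums=1` marks it as a compile-time constant. Each distinct value of `n` then gets its own compilation:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def power_sum(x, n):
    # n is a static Python int, so this loop is unrolled at trace time
    total = jnp.zeros_like(x)
    for k in range(n):
        total = total + x ** k
    return total.sum()

x_demo = jnp.array([2.0, 3.0])
print(power_sum(x_demo, 3))  # (1 + 2 + 4) + (1 + 3 + 9) = 20
```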

In [44]:
# Combining transformations: jit a vmapped gradient computation
@jax.jit
def batched_gradients_optimized(params, x_batch):
    return jax.vmap(jax.grad(f, argnums=1), in_axes=(None, 0))(params, x_batch)

# This is now fully optimized: compiled + vectorized + differentiated
result = batched_gradients_optimized((A, b), x_batch)
print(f"Shape: {result.shape}")

Shape: (100, 3)
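Nesting goes further than first derivatives. As a sketch (the names `sigmoid_sum`, `A_h`, `b_h`, `xs` and the batch of 4 are illustrative), `jax.hessian`, which is itself the forward-over-reverse composition `jacfwd(jacrev(f))`, can be vectorized and compiled to give one Hessian with respect to $x$ per input:

```python
import numpy as np
import jax
import jax.numpy as jnp

def sigmoid_sum(params, x):
    # same function as f above
    A, b = params
    return jax.nn.sigmoid(A @ x + b).sum()

rng = np.random.default_rng(0)
A_h = jnp.array(rng.standard_normal((10, 3)))
b_h = jnp.array(rng.standard_normal(10))
xs = jnp.array(rng.standard_normal((4, 3)))

hess_x = jax.hessian(sigmoid_sum, argnums=1)  # (3, 3) Hessian w.r.t. x
batched_hess = jax.jit(jax.vmap(hess_x, in_axes=(None, 0)))
H = batched_hess((A_h, b_h), xs)  # (4, 3, 3): one Hessian per input

# Same thing spelled out as an explicit forward-over-reverse composition
H_alt = jax.vmap(
    jax.jacfwd(jax.jacrev(sigmoid_sum, argnums=1), argnums=1), in_axes=(None, 0)
)((A_h, b_h), xs)
print(H.shape, jnp.allclose(H, H_alt, atol=1e-5))
```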


# Key Features of JAX for Scientific Computing

- **Functional transformations**: `grad()` for gradients, `vmap()` for automatic vectorization, `jit()` for compilation
- **Composability**: Transformations can be arbitrarily nested, e.g., `jit(vmap(grad(f)))` to compute batched gradients efficiently
- **NumPy compatibility**: Code often requires minimal changes from NumPy, using `jax.numpy` as a drop-in replacement
- **Hardware acceleration**: Transparent GPU/TPU execution without code modifications
- **Pseudorandom numbers**: Explicit PRNG state management ensures reproducibility in stochastic algorithms
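To make the last point concrete, here is a minimal sketch of how the random setup from the autodiff example could be written with `jax.random` instead of NumPy; the seed `0` and the choice to split into four keys are illustrative:

```python
import jax

key = jax.random.PRNGKey(0)  # explicit, user-managed PRNG state
key, k_x, k_A, k_b = jax.random.split(key, 4)  # independent subkeys

x = jax.random.normal(k_x, (3,))
A = jax.random.normal(k_A, (10, 3))
b = jax.random.normal(k_b, (10,))

# Reusing a key reproduces exactly the same numbers
same = jax.random.normal(k_x, (3,))
print(bool((x == same).all()))
```

Because no hidden global state is involved, each consumer of randomness must be handed its own key, typically obtained by splitting.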