# Intro to JAX

Author: Adrian Price-Whelan (CCA, 2025)

[Live notebook on Colab](https://colab.research.google.com/github/JAXtronomy/tutorials/blob/main/tutorials/Intro-to-JAX.ipynb)

[Here is the slide deck](../../_static/Intro-to-JAX-for-astronomers.pdf)

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

%xmode minimal
```

## JIT Compilation

Just-in-Time (JIT) compilation turns Python functions into optimized machine code on the fly, which can bring their performance close to that of low-level languages. This can dramatically speed up some numerical computations, especially where loops are required or where the code is not easily vectorized.

### Example: Evaluating a gravitational potential at the positions of many particles

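```python
def func(x, a):
    # a cored logarithmic potential (up to constant factors), ln(|x|^2 + a^2),
    # evaluated for each row (particle position) of x
    return jnp.log(jnp.sum(x**2, axis=1) + a**2)
```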

```python
x = np.random.normal(size=(1_000_000, 3))
```

```python
# without JIT:
%timeit func(x, 1.0).block_until_ready()
```

```python
func_jit = jax.jit(func)

# this first call compiles the function
func_jit(x, 1.0).block_until_ready()

# with JIT:
%timeit func_jit(x, 1.0).block_until_ready()
```

This example is a little contrived, since the non-JITted version is already vectorized, but it illustrates the JIT compilation process. JIT is most useful when you need to call a function many times in a loop or some other iterative process, for example in optimization, MCMC sampling, or orbit integration.
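As a sketch of the kind of iterative code where JIT pays off, here is a Newton-Raphson solver for Kepler's equation, `E - e sin(E) = M`. The function name `solve_kepler`, the starting guess, and the fixed count of 20 iterations are illustrative choices, not taken from any package. Under `jax.jit`, the Python `for` loop over a fixed `range` is unrolled during tracing and compiled into a single program.

```python
import jax
import jax.numpy as jnp


@jax.jit
def solve_kepler(M, e):
    """Hypothetical helper: solve E - e*sin(E) = M for the eccentric anomaly E.

    Uses a fixed number of Newton-Raphson steps, starting from E = M.
    """
    E = M
    for _ in range(20):  # a fixed-length Python loop is unrolled at trace time
        f = E - e * jnp.sin(E) - M
        fprime = 1.0 - e * jnp.cos(E)
        E = E - f / fprime
    return E


M = jnp.linspace(0.0, 2 * jnp.pi, 1000)  # mean anomalies
E = solve_kepler(M, 0.3).block_until_ready()

# how well does E satisfy Kepler's equation?
residual = jnp.max(jnp.abs(E - 0.3 * jnp.sin(E) - M))
print(residual)
```

Because `solve_kepler` is written with `jax.numpy` operations, it works on the whole array of mean anomalies at once. The compiled version can be reused for any input with the same shape and dtype.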

BTW: Other packages, like Numba and PyTorch, also provide JIT compilation, but JAX's JIT hands the traced program to XLA, which produces highly optimized, hardware-specific machine code (for CPU, GPU, or TPU).
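To peek at what `jax.jit` hands to XLA, you can print the traced program (a "jaxpr") with `jax.make_jaxpr`. You can also lower the jitted function for concrete input shapes and inspect the resulting text. This is a minimal sketch using the same `func` as above; the exact text printed depends on your JAX version and backend.

```python
import jax
import jax.numpy as jnp


def func(x, a):
    return jnp.log(jnp.sum(x**2, axis=1) + a**2)


x = jnp.ones((5, 3))

# JAX's intermediate representation, produced by tracing with abstract values
jaxpr = jax.make_jaxpr(func)(x, 1.0)
print(jaxpr)

# the program handed to XLA for these input shapes and dtypes
lowered = jax.jit(func).lower(x, 1.0)
print(lowered.as_text())

# compile ahead of time, then call the compiled executable directly
compiled = lowered.compile()
out = compiled(x, 1.0)  # each row gives log(3 + 1)
```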

A brief aside about `.block_until_ready()`: JAX uses *asynchronous dispatch*. A call returns as soon as the computation has been queued on the device, before the result has actually been computed. This lets Python keep running while the CPU, GPU, or TPU does the work, but it can be confusing when timing or debugging code. Without `.block_until_ready()`, it can seem like your code is running absurdly fast, because you are only timing the dispatch. `.block_until_ready()` makes Python wait until the computation has finished before continuing.
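To see the difference, here is a minimal sketch that times the same jitted call with and without `.block_until_ready()`. The function `heavy` and the array size are arbitrary choices that make the work noticeable. The actual numbers depend on your hardware, and on a CPU the gap can be small.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(x):
    # a moderately expensive computation (a matrix product) so the timing is visible
    return jnp.sum(jnp.sin(x) @ jnp.cos(x).T)


x = jnp.ones((1000, 1000))
heavy(x).block_until_ready()  # compile once, outside of the timings

t0 = time.perf_counter()
result = heavy(x)  # returns as soon as the work has been dispatched
t_dispatch = time.perf_counter() - t0
result.block_until_ready()  # let it finish before the next measurement

t0 = time.perf_counter()
result = heavy(x).block_until_ready()  # waits until the result is computed
t_total = time.perf_counter() - t0

print(f"dispatch only: {t_dispatch * 1e3:.2f} ms")
print(f"with block_until_ready: {t_total * 1e3:.2f} ms")
```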

---

## Vectorization with `vmap`

If your function can be expressed using standard `numpy` operations, writing it with `jax.numpy` so that it accepts arrays as inputs already vectorizes it. However, we often write functions that are not easily vectorized, for example functions that contain conditional statements or loops. In these cases, `jax.vmap` provides an efficient way to apply a function across batches of inputs, avoiding explicit (external) Python loops over the batch.
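Here is a minimal sketch of what `jax.vmap` does. The function `project` is a hypothetical example written for a single 3-vector, and `vmap` maps it over a leading batch axis. The result matches an explicit Python loop. `in_axes` controls which arguments are batched, and `None` means "pass this argument unchanged to every call".

```python
import jax
import jax.numpy as jnp


def project(v, direction):
    # written for a single 3-vector v: the component of v along `direction`
    unit = direction / jnp.linalg.norm(direction)
    return jnp.dot(v, unit)


vs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 vectors
direction = jnp.array([0.0, 0.0, 2.0])

# batch over the first argument only; `direction` is shared by every call
batched = jax.vmap(project, in_axes=(0, None))(vs, direction)
looped = jnp.stack([project(v, direction) for v in vs])
print(batched)  # the z-components of the 4 vectors
```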

For example, imagine you have a set of stellar spectra and you want to compute the depth of an absorption line that appears at a different location (in pixels) for each star because of Doppler shifts. You could write a function that computes the depth of the line for a single spectrum, and then use `jax.vmap` to apply it to all the stars in your dataset.

### Example: Find the depth of an absorption line in a set of spectra

First we simulate some "spectra":

```python
N = 128  # number of spectra to make

rng = np.random.default_rng(123)
pix = jnp.arange(100)

true_ctr = rng.uniform(40, 60, size=(N,))
true_depth = rng.uniform(0.1, 0.5, size=(N,))


def make_spectrum(pix, ctr, depth, scale):
    return 1 - depth * jnp.exp(-0.5 * (pix - ctr) ** 2 / scale**2)


# we'll use vmap to simulate the spectra too ^_^
spectra = jax.vmap(
    make_spectrum,
    in_axes=(None, 0, 0, None),
    out_axes=0,
)(pix, true_ctr, true_depth, 10.0).block_until_ready()
```

```python
plt.figure(figsize=(6, 4))
_ = plt.plot(spectra.T, "-", alpha=0.1, color="k")
```

Now we have a set of spectra, and we want to measure the depth of the absorption line, which sits somewhere between pixels 40 and 60. We can write a function that computes the depth of the line for a single spectrum, and then use `jax.vmap` to apply it to all the spectra in our dataset:

```python
def find_depth(spectrum):
    # find the pixel with the minimum flux:
    min_pix = jnp.argmin(spectrum)

    # locally fit a parabola, y = c0 + c1*x + c2*x**2, to the 3 pixels around
    # the minimum (x is measured in pixels relative to min_pix)
    y = jax.lax.dynamic_slice(spectrum, (min_pix - 1,), (3,))
    x = jnp.arange(-1.0, 2.0)
    A = jnp.vander(x, 3, increasing=True)  # creates a matrix with columns [1, x, x**2]
    c0, c1, c2 = jnp.linalg.lstsq(A, y, rcond=None)[0]

    # the vertex of the parabola gives the inter-pixel minimum flux
    y_min = c0 - c1**2 / (4 * c2)
    return 1 - y_min


find_depth_vmap = jax.vmap(find_depth, in_axes=0)
```

```python
%%time
for spectrum in spectra:
    find_depth(spectrum).block_until_ready()
```

```python
%%time
find_depth_vmap(spectra).block_until_ready()
```

```python
depths = find_depth_vmap(spectra).block_until_ready()
plt.scatter(depths, true_depth)
plt.xlabel("Measured depth")
plt.ylabel("True depth")
```

---

## Automatic Differentiation

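Automatic differentiation (autodiff or autograd) lets you compute gradients of functions exactly, up to floating-point precision, without deriving them by hand or resorting to finite differences. This is most useful in optimization and inference, where gradient-based methods (e.g., gradient descent, L-BFGS, or Hamiltonian Monte Carlo) can be much faster than gradient-free ones.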
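As a quick sanity check of "exact up to floating-point precision", here is a minimal sketch comparing `jax.grad` of `sin` with the analytic derivative, `cos`, and with a forward finite difference. The evaluation point and the step size `h` are arbitrary choices.

```python
import jax
import jax.numpy as jnp

f = jnp.sin
df = jax.grad(f)

x0 = 0.7
autodiff = df(x0)
analytic = jnp.cos(x0)

# forward finite difference: its accuracy depends on the choice of step size h
h = 1e-3
finite_diff = (f(x0 + h) - f(x0)) / h

print(autodiff - analytic, finite_diff - analytic)
```

The autodiff result agrees with the analytic derivative to machine precision. The finite difference is off by roughly `h/2 * sin(x0)` plus round-off.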

### Example: Evaluating the gradient of a function

```python
def some_func(x, A, P):
    return A * jnp.cos(2 * jnp.pi * x / P)
```

```python
some_func_grad = jax.grad(some_func, argnums=0)

xgrid = jnp.linspace(0, 10, 4096)
some_func_grad(xgrid, 1.0, 1.0)  # raises a TypeError: jax.grad needs a scalar-output function
```

Huh, what happened? Why did this raise an error? `jax.grad` only works for functions with scalar outputs, but when we pass in the array `xgrid`, `some_func` returns an array. To compute the derivative at every element of an array, we can instead use `jax.vmap` to vectorize the `grad`'ed function over `x`:
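Another common pattern works here because each output element of `some_func` depends only on the matching element of `x`: take the gradient of the *sum* of the outputs. The derivative of `sum_i f(x_i)` with respect to `x_j` is just `f'(x_j)`, so a single `jax.grad` call returns the elementwise derivative. This shortcut gives wrong answers for functions that mix elements (e.g., a cumulative sum), so `vmap` is the more general tool. A minimal sketch comparing both approaches against the analytic derivative:

```python
import jax
import jax.numpy as jnp


def some_func(x, A, P):
    return A * jnp.cos(2 * jnp.pi * x / P)


xgrid = jnp.linspace(0, 10, 4096)

# d/dx_j of sum_i f(x_i) equals f'(x_j) when f acts elementwise on x
grad_via_sum = jax.grad(lambda x: jnp.sum(some_func(x, 1.0, 1.0)))(xgrid)

grad_via_vmap = jax.vmap(jax.grad(some_func, argnums=0), in_axes=(0, None, None))(
    xgrid, 1.0, 1.0
)

# analytic derivative for A = 1, P = 1
analytic = -2 * jnp.pi * jnp.sin(2 * jnp.pi * xgrid)
```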

```python
some_func_grad_vmap = jax.vmap(jax.grad(some_func, argnums=0), in_axes=(0, None, None))
```

```python
plt.figure(figsize=(6, 4))
plt.plot(xgrid, some_func(xgrid, 1.0, 1.0), "-", label="f(x)")
plt.plot(xgrid, some_func_grad_vmap(xgrid, 1.0, 1.0), "-", label="df/dx")
plt.legend(loc="lower right", fontsize=20)
```

---

## Pytrees

Pytrees are JAX's structured data containers: (possibly nested) dictionaries, lists, tuples, and other registered containers whose leaves are scalars or arrays. JAX transformations like `jax.jit`, `jax.grad`, and `jax.vmap` accept functions whose inputs and outputs are pytrees. For example, `jax.grad` returns gradients with the same structure as the input parameters. Pytrees are one of my favorite features of JAX!
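Here is a minimal sketch of working with a pytree directly, using a hypothetical nested parameter dictionary. `jax.tree_util.tree_leaves` flattens it into its leaves. `jax.tree_util.tree_map` applies a function to every leaf and keeps the structure intact.

```python
import jax
import jax.numpy as jnp

# a nested pytree: a dict holding a scalar, an array, and a tuple of arrays
params = {
    "amplitude": 1.5,
    "centers": jnp.array([10.0, 20.0, 30.0]),
    "widths": (jnp.array(1.0), jnp.array(2.0)),
}

leaves = jax.tree_util.tree_leaves(params)  # 4 leaves in total
scaled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)  # same structure, doubled values
print(scaled)
```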

### Example: Computing the gradient of a function with respect to a set of parameters


```python
def model(x, a, b):
    return a * x + b


def objective(params, x, y, y_err):
    model_y = model(x, params["a"], params["b"])
    return jnp.sum((y - model_y) ** 2 / y_err**2)  # chi2
```

```python
rng = np.random.default_rng(42)
x = rng.uniform(0, 10, size=16)
y = model(x, a=8.67, b=5.309)
y_err = rng.uniform(0.1, 5.0, size=len(x))
y += rng.normal(0, y_err, size=len(x))

plt.figure(figsize=(6, 4))
plt.errorbar(x, y, yerr=y_err, fmt="o", label="data")
```

```python
objective_grad = jax.grad(objective, argnums=0)
```

```python
objective_grad({"a": 1.0, "b": 1.0}, x, y, y_err)
```
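The gradient comes back as a dictionary with the same keys as `params`, so a parameter update can be written once for any pytree with `jax.tree_util.tree_map`. Below is a minimal sketch of plain gradient descent on the same chi-squared objective. The noise-free toy data (`x_toy`, `y_toy`, `yerr_toy`), the learning rate, and the iteration count are arbitrary assumptions. The new names avoid overwriting the data above. In practice, you would likely use a dedicated optimizer instead.

```python
import jax
import jax.numpy as jnp


def model(x, a, b):
    return a * x + b


def objective(params, x, y, y_err):
    model_y = model(x, params["a"], params["b"])
    return jnp.sum((y - model_y) ** 2 / y_err**2)  # chi2


# noise-free toy data (an assumption for this sketch): the best fit is a=2, b=1
x_toy = jnp.linspace(0.0, 1.0, 20)
y_toy = model(x_toy, 2.0, 1.0)
yerr_toy = jnp.ones_like(x_toy)

objective_grad_jit = jax.jit(jax.grad(objective, argnums=0))

fit_params = {"a": 0.0, "b": 0.0}
learning_rate = 0.02
for _ in range(500):
    grads = objective_grad_jit(fit_params, x_toy, y_toy, yerr_toy)
    # the same update rule is applied to every leaf of the params pytree
    fit_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, fit_params, grads
    )

print(fit_params)
```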

### Example: Using `vmap` over a pytree of parameters

You can also use `jax.vmap` to efficiently evaluate a function for a batch of parameter values stored in a pytree. Suppose we have arrays of parameter values, one entry per parameter set, and we want to evaluate the objective for each set of parameters:

```python
params_arr = {
    "a": jnp.linspace(0, 10, 5),
    "b": jnp.linspace(-5, 5, 5),
}

objective_vmap = jax.vmap(objective, in_axes=({"a": 0, "b": 0}, None, None, None))
chi2_arr = objective_vmap(params_arr, x, y, y_err).block_until_ready()
chi2_arr
```
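`in_axes` only needs to be a *prefix* of the argument pytree. A single `0` in the first position applies to every leaf of the `params` dictionary. Here is a minimal sketch checking that the shorthand matches the fully spelled-out version. The toy data (`x_toy`, `y_toy`, `yerr_toy`) are made up for illustration.

```python
import jax
import jax.numpy as jnp


def objective(params, x, y, y_err):
    model_y = params["a"] * x + params["b"]
    return jnp.sum((y - model_y) ** 2 / y_err**2)  # chi2


x_toy = jnp.linspace(0.0, 1.0, 10)
y_toy = 3.0 * x_toy - 1.0
yerr_toy = jnp.ones_like(x_toy)

params_arr = {"a": jnp.linspace(0, 10, 5), "b": jnp.linspace(-5, 5, 5)}

full = jax.vmap(objective, in_axes=({"a": 0, "b": 0}, None, None, None))
prefix = jax.vmap(objective, in_axes=(0, None, None, None))  # 0 applies to all leaves

chi2_full = full(params_arr, x_toy, y_toy, yerr_toy)
chi2_prefix = prefix(params_arr, x_toy, y_toy, yerr_toy)
```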

## Sharp Bits

### Control flow

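Inside a JIT-compiled function, use `jax.lax.cond` (or `jnp.where`) instead of Python `if` statements that depend on array values. Likewise, use `jax.lax.scan` (or `jax.lax.fori_loop`) instead of long Python `for` loops. The reason is that `jax.jit` traces the function with abstract values that carry only shape and dtype information, not the actual data, so Python cannot decide which branch of an `if` to take. Python `for` loops with a fixed number of iterations do work under `jit`, but they are unrolled during tracing, which can make compilation slow. For example, the first version below fails, while the second works: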

```python
@jax.jit
def func_if1(x):
    if x > 0:
        return x**3
    else:
        return x**2
```

```python
func_if1(10.0)  # raises an error: under jit, `x > 0` is an abstract tracer, not a concrete bool
```

```python
@jax.jit
def func_if2(x):
    return jax.lax.cond(x > 0, lambda: x**3, lambda: x**2)
```

```python
func_if2(10.0)
```
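The advice above also mentions `jax.lax.scan` as a replacement for `for` loops. Here is a minimal sketch using a hypothetical example: integrating a 1D harmonic oscillator (unit mass and frequency) with a leapfrog, kick-drift-kick, step. `scan` applies `step` repeatedly, threads the state through as the "carry", and stacks the per-step outputs. The loop body is compiled once instead of being unrolled. The step size and number of steps are arbitrary choices.

```python
import jax
import jax.numpy as jnp


def step(state, _):
    # one leapfrog (kick-drift-kick) step for the force F = -x
    x, v = state
    dt = 0.01
    v = v - 0.5 * dt * x
    x = x + dt * v
    v = v - 0.5 * dt * x
    return (x, v), x  # (new carry, value recorded at this step)


@jax.jit
def integrate(x0, v0):
    (x_final, v_final), x_history = jax.lax.scan(step, (x0, v0), xs=None, length=1000)
    return x_final, v_final, x_history


# float32 arrays keep the carry's dtype identical on input and output
x_final, v_final, x_history = integrate(jnp.array(1.0), jnp.array(0.0))
print(x_final, jnp.cos(10.0))  # numerical vs exact solution at t = 10
```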

## Teaser: Using physical units in JAX code

`astropy.units` provides the `Quantity` object for handling data with associated physical units. However, `Quantity` is a subclass of `numpy.ndarray`, so JAX does not work natively with Astropy units. To fill this gap, we are building a new package called `unxt` for handling JAX arrays (including tracers) with physical units. This package is still in development, but it is ready for use:

```python
import unxt as u
import quaxed
```

```python
pos = u.Quantity(jnp.arange(10.0), "kpc")
time = u.Quantity(5.3, "Gyr")

result = pos / time
result.unit
```

```python
def compute_potential(r, GM, a):
    # Gravitational potential for a Hernquist model:
    return -GM / (r + a)
```
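Before adding units, it can help to check the plain-JAX version of this gradient against the analytic result, `dPhi/dr = GM / (r + a)**2`. Here is a minimal sketch with unitless floats. The values are arbitrary, and `dPhi_dr_plain` is a hypothetical name chosen to avoid clashing with the units-aware version.

```python
import jax
import jax.numpy as jnp


def compute_potential(r, GM, a):
    # Gravitational potential for a Hernquist model:
    return -GM / (r + a)


# derivative with respect to the first argument, r
dPhi_dr_plain = jax.grad(compute_potential)

grad_value = dPhi_dr_plain(1.0, 1.0, 1.0)
analytic = 1.0 / (1.0 + 1.0) ** 2  # GM / (r + a)**2 for r = GM = a = 1
```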

```python
dPhi_dr = quaxed.grad(compute_potential)
```

```python
r = u.Quantity(1.0, "kpc")
GM = u.Quantity(1.0, "kpc^3 / Gyr^2")
a = u.Quantity(1.0, "kpc")
compute_potential(r, GM, a).block_until_ready()
```

```python
dPhi_dr(r, GM, a).block_until_ready()
```