### A4.1.3. JAX Compilation

> *JAX traces Python functions into a functional IR (jaxpr), lowers it to StableHLO, and compiles it with XLA. Because every transformation operates on traced functions, JIT compilation, automatic differentiation, and vectorization compose freely.*

**Explanation:**

JAX treats numerical Python functions as data: it **traces** them to extract a functional intermediate representation, then **lowers** and **compiles** that IR.

**JAX Compilation Pipeline:**

```
Python function → tracing → jaxpr → StableHLO → XLA HLO → LLVM/PTX → executable
```

**Tracing:**

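JAX calls the function with **abstract tracer values** that carry shape and dtype but no data. Each JAX primitive the function executes is recorded into a **jaxpr** (JAX expression): a flat, functional, SSA-like IR. Python control flow that does not depend on traced values runs once, during tracing, so only the path actually taken is recorded.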

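The tracing step can be mimicked in plain Python: tracers overload arithmetic so that calling the function records primitives instead of computing values. In this sketch `Tracer`, `Trace`, `Equation`, and `make_toy_jaxpr` are hypothetical names, not JAX's internal classes; real tracers also handle broadcasting, literals, and many more primitives.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Equation:
    primitive: str
    inputs: tuple[str, ...]
    output: str


@dataclass
class Trace:
    """Records equations in the order primitives are executed."""
    equations: list[Equation] = field(default_factory=list)
    counter: int = 0

    def fresh(self) -> str:
        name = f"v{self.counter}"
        self.counter += 1
        return name


class Tracer:
    """Hypothetical abstract value: shape, dtype, and a variable name, but no data."""

    def __init__(self, trace: Trace, shape: tuple[int, ...], dtype: str, name: str):
        self.trace, self.shape, self.dtype, self.name = trace, shape, dtype, name

    def _emit(self, primitive: str, other: Tracer) -> Tracer:
        # Toy shape rule: elementwise ops on equal shapes keep the input shape.
        out = Tracer(self.trace, self.shape, self.dtype, self.trace.fresh())
        self.trace.equations.append(Equation(primitive, (self.name, other.name), out.name))
        return out

    def __add__(self, other: Tracer) -> Tracer:
        return self._emit("add", other)

    def __mul__(self, other: Tracer) -> Tracer:
        return self._emit("mul", other)


def make_toy_jaxpr(f, *avals):
    """Call f on tracers built from (shape, dtype) pairs and return the recorded IR."""
    trace = Trace()
    args = [Tracer(trace, shape, dtype, trace.fresh()) for shape, dtype in avals]
    result = f(*args)
    return [a.name for a in args], trace.equations, result.name


SQUARE_FIRST = True  # plain Python value, read while tracing


def f(x, y):
    if SQUARE_FIRST:  # Python control flow runs at trace time...
        x = x * x     # ...so only the branch actually taken is recorded
    return x + y


inputs, eqns, output = make_toy_jaxpr(f, ((3,), "f32"), ((3,), "f32"))
print("inputs:", inputs)
for eq in eqns:
    print(f"  {eq.output} = {eq.primitive} {' '.join(eq.inputs)}")
print("output:", output)
```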
**Key Transformations (composable):**

| Transform | Purpose |
|-----------|--------|
| `jax.jit` | JIT-compile via XLA |
| `jax.grad` | Automatic differentiation (reverse-mode) |
| `jax.vmap` | Automatic vectorization (batch dimension) |
| `jax.pmap` | Parallel map across devices |

These compose: `jax.jit(jax.vmap(jax.grad(f)))` is valid and, for a scalar-valued `f`, yields a compiled function that returns per-example gradients.

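Why the transformations nest can be seen with plain higher-order functions: each one takes a function and returns a new function. This is a hedged toy, not JAX's machinery. The hypothetical `toy_grad` uses forward-mode dual numbers, a different AD technique from the reverse mode of `jax.grad` (for a scalar-to-scalar function both yield the same derivative), and the hypothetical `toy_vmap` loops over a Python list where `jax.vmap` would vectorize.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Dual:
    """Dual number primal + tangent*eps with eps**2 == 0; tangent carries the derivative."""
    primal: float
    tangent: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a + a'eps)(b + b'eps) = ab + (ab' + a'b)eps
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def toy_grad(f):
    """Derivative of a scalar -> scalar f via forward-mode dual numbers (not reverse mode)."""
    def df(x: float) -> float:
        return f(Dual(x, 1.0)).tangent
    return df


def toy_vmap(f):
    """Map f over the leading axis of a list (a loop, not true vectorization)."""
    def batched(xs: list[float]) -> list[float]:
        return [f(x) for x in xs]
    return batched


def f(x):
    return 3 * x * x + 2 * x  # f'(x) = 6x + 2


per_example_grads = toy_vmap(toy_grad(f))  # transformations nest like jax.vmap(jax.grad(f))
print(per_example_grads([0.0, 1.0, 2.0]))  # [2.0, 8.0, 14.0]
```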
**Caching and Recompilation:**

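- Compiled executables are cached by **(function identity, input shapes, input dtypes)**, together with the values of any static arguments.
- Calling with a new shape or dtype triggers retracing and recompilation.
- Arguments listed in `static_argnums` (or `static_argnames`) are treated as compile-time constants; they must be hashable, and each new value triggers a recompilation.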

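The caching rule can be sketched as a decorator that keys a dictionary on an abstract signature. The wrapper `toy_jit` and its `stats` counter are hypothetical; JAX's real cache key carries more information than shape, element type, and static values, and a real cache miss triggers tracing, lowering, and compilation rather than storing the Python function. The calls at the end show which changes force a recompilation.

```python
from __future__ import annotations

from typing import Callable


def toy_jit(fn: Callable, static_argnums: tuple[int, ...] = ()) -> Callable:
    """Hypothetical jit-like wrapper: one cached 'executable' per abstract signature."""
    cache: dict[tuple, Callable] = {}
    stats = {"compiles": 0}

    def signature(args: tuple) -> tuple:
        key: list = [id(fn)]  # toy stand-in for function identity
        for i, arg in enumerate(args):
            if i in static_argnums:
                key.append(("static", arg))  # static values must be hashable
                continue
            # Toy "arrays" are nested lists: keep only shape and element type.
            shape, leaf = [], arg
            while isinstance(leaf, list):
                shape.append(len(leaf))
                leaf = leaf[0] if leaf else None
            key.append((tuple(shape), type(leaf).__name__))
        return tuple(key)

    def wrapper(*args):
        key = signature(args)
        if key not in cache:
            stats["compiles"] += 1  # real JAX would trace, lower, and compile here
            cache[key] = fn         # toy: the "executable" is fn itself
        return cache[key](*args)

    wrapper.stats = stats
    return wrapper


def scale(x, factor):
    return [v * factor for v in x]


fast_scale = toy_jit(scale, static_argnums=(1,))
fast_scale([1.0, 2.0], 2)       # first shape: compile #1
fast_scale([3.0, 4.0], 2)       # same shape and static value: cache hit
fast_scale([1.0, 2.0, 3.0], 2)  # new shape: compile #2
fast_scale([1.0, 2.0], 3)       # new static value: compile #3
print(fast_scale.stats["compiles"])  # 3
```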
**Lowering Stages:**

1. `jax.make_jaxpr(f)(x)`: inspect the jaxpr.
2. `jax.jit(f).lower(x).as_text()`: inspect the StableHLO.
3. `jax.jit(f).lower(x).compile()`: get the compiled executable.

**Example:**

```python
import jax


@jax.jit
def predict(params, x):
    return jax.nn.relu(x @ params['w'] + params['b'])
```

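On the first call, JAX traces `predict`, lowers it to StableHLO (`dot_general`, `add`, and `maximum`, together with broadcasts of the bias and the zero constant), and XLA compiles the result; XLA can fuse the elementwise bias add and ReLU so that the intermediate sum is not materialized in memory. Subsequent calls with the same shapes and dtypes reuse the cached executable.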

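The lowering and fusion of `predict` can be pictured with a toy pass: a lookup table from jaxpr primitives to StableHLO op names, followed by a greedy grouping of adjacent elementwise ops, loosely in the manner of a loop-fusion pass. The table `LOWERING_RULES`, the `zero` operand (standing in for a broadcast constant), and the grouping rule are hypothetical simplifications; actual lowering also emits ops such as `broadcast_in_dim`, and XLA's fusion decisions depend on the backend.

```python
# Hypothetical lowering table: jaxpr primitive name -> StableHLO op name.
LOWERING_RULES = {
    "dot_general": "stablehlo.dot_general",
    "add": "stablehlo.add",
    "max": "stablehlo.maximum",
}
ELEMENTWISE = {"stablehlo.add", "stablehlo.maximum"}


def lower(equations):
    """Map each (primitive, inputs, output) triple to a StableHLO-style op."""
    return [(LOWERING_RULES[prim], ins, out) for prim, ins, out in equations]


def group_elementwise(ops):
    """Greedy toy fusion: consecutive elementwise ops are placed in one group."""
    groups: list[list[str]] = []
    for name, _, _ in ops:
        if name in ELEMENTWISE and groups and groups[-1][-1] in ELEMENTWISE:
            groups[-1].append(name)
        else:
            groups.append([name])
    return groups


predict_jaxpr = [
    ("dot_general", ("x", "w"), "mm"),
    ("add", ("mm", "b"), "biased"),
    ("max", ("biased", "zero"), "out"),
]
ops = lower(predict_jaxpr)
for name, ins, out in ops:
    print(f"%{out} = {name} {', '.join('%' + i for i in ins)}")
print(group_elementwise(ops))  # matmul alone, then add + maximum grouped together
```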
```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class AbstractValue:
    """What a tracer carries in place of data: shape and dtype only."""
    shape: tuple[int, ...]
    dtype: str


@dataclass
class JaxprEquation:
    primitive: str
    inputs: list[str]
    output: str


@dataclass
class Jaxpr:
    input_vars: list[str]
    equations: list[JaxprEquation] = field(default_factory=list)
    output_var: str = ""


# Toy jaxpr for predict(params, x). The literal 0.0 is written inline, as jaxprs
# print literals; broadcasting of %b is omitted for brevity.
jaxpr = Jaxpr(
    input_vars=["%x", "%w", "%b"],
    equations=[
        JaxprEquation("dot_general", ["%x", "%w"], "%mm"),
        JaxprEquation("add", ["%mm", "%b"], "%biased"),
        JaxprEquation("max", ["%biased", "0.0"], "%out"),
    ],
    output_var="%out",
)

print("Jaxpr:")
print(f"  inputs: {', '.join(jaxpr.input_vars)}")
for equation in jaxpr.equations:
    print(f"  {equation.output} = {equation.primitive}({', '.join(equation.inputs)})")
print(f"  output: {jaxpr.output_var}")

compilation_stages = [
    ("Python function", "def predict(params, x): ..."),
    ("Tracing", "abstract values → record primitives"),
    ("Jaxpr", "flat functional SSA IR"),
    ("StableHLO", "dot_general + add + maximum"),
    ("XLA HLO", "fusion, layout, buffer assignment"),
    ("Machine code", "LLVM IR → x86 / PTX"),
]

print("\nJAX compilation pipeline:")
for stage_name, detail in compilation_stages:
    print(f"  {stage_name} → {detail}")


def cache_key(batch_size: int) -> tuple:
    """Toy cache key: function name plus the abstract value of each argument."""
    return (
        "predict",
        AbstractValue((batch_size, 784), "f32"),
        AbstractValue((784, 256), "f32"),
        AbstractValue((256,), "f32"),
    )


cache_key_a = cache_key(128)
cache_key_b = cache_key(64)

print(f"\nCache key (batch=128): {cache_key_a}")
print(f"Cache key (batch=64):  {cache_key_b}")
print(f"Same key? {cache_key_a == cache_key_b} → recompilation needed for a different batch size")
```

**References:**

[📘 Bradbury, J. et al. *JAX: Composable transformations of Python+NumPy programs.*](https://github.com/google/jax)

[📘 JAX Documentation. *How JAX primitives work.*](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_Primitives_Work.html)

---

[⬅️ Previous: TensorFlow XLA](./02_tensorflow_xla.ipynb) | [Next: TensorFlow Graph Compilation ➡️](../02_Framework_Integration/01_tensorflow_graph_compilation.ipynb)