### A4.2.2. JAX Just-in-Time Compilation

> *`jax.jit` traces a Python function into a jaxpr, lowers it to StableHLO, and compiles it via XLA on the first call; subsequent calls with the same signature dispatch directly to the cached executable.*

**Explanation:**

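**`jax.jit`** is JAX's primary compilation mechanism. Unlike TF's graph mode, which can capture stateful `tf.Variable` objects, it operates on pure functions: state goes in as arguments and comes out as return values. Because the Python body runs only while tracing, side effects such as `print` calls or mutations of globals happen once per trace, not on every call.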

**JIT Mechanics:**

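1. **First call** (or any later cache miss): JAX runs the function on abstract values that carry only shape and dtype, producing a jaxpr.
2. **Lowering**: jaxpr → StableHLO → XLA HLO.
3. **Compilation**: XLA optimizes the HLO and compiles it to machine code for the target device.
4. **Caching**: the executable is cached under a key built from `(function, input_shapes, input_dtypes, static_args)`.
5. **Subsequent calls**: a call with a matching key dispatches directly to the cached executable, with no retracing or recompilation, so the Python function body is not re-executed.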
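The five steps can be imitated in plain Python. The sketch below is a toy, not JAX: `Tracer`, `toy_jit`, and the `add`/`mul` primitive names are hypothetical, Python scalars stand in for rank-0 arrays, and steps 2 and 3 (lowering and XLA compilation) are skipped, so the recorded equation list is simply replayed. It still reproduces three real `jax.jit` behaviors: the Python body (and its side effects) runs only on a cache miss, globals read during tracing are frozen in as constants, and a new dtype forces a retrace.

```python
import operator
from dataclasses import dataclass

# Toy primitives (hypothetical); nothing here calls into JAX.
PRIMITIVES = {"add": operator.add, "mul": operator.mul}


@dataclass(frozen=True)
class AbstractValue:
    shape: tuple
    dtype: str


def abstract_of(x):
    # Toy rule: Python scalars stand in for rank-0 arrays.
    return AbstractValue((), "float32" if isinstance(x, float) else "int32")


class Tracer:
    """Placeholder for an argument during tracing; records each op applied to it."""

    def __init__(self, name, aval, eqns):
        self.name, self.aval, self.eqns = name, aval, eqns

    def _record(self, prim, other):
        # Other tracers are referenced by name; plain Python values are baked in as constants.
        rhs = other.name if isinstance(other, Tracer) else other
        out = Tracer(f"t{len(self.eqns)}", self.aval, self.eqns)
        self.eqns.append((out.name, prim, self.name, rhs))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


def run_program(eqns, in_names, out_name, args):
    # Replays the recorded equations on concrete inputs (stands in for the executable).
    env = dict(zip(in_names, args))
    for out, prim, lhs, rhs in eqns:
        rhs_value = env[rhs] if isinstance(rhs, str) else rhs
        env[out] = PRIMITIVES[prim](env[lhs], rhs_value)
    return env[out_name]


def toy_jit(fn):
    cache = {}

    def wrapper(*args):
        key = tuple(abstract_of(a) for a in args)  # step 4: key from shapes/dtypes
        if key not in cache:
            # Step 1: trace. The Python body runs only here, on a cache miss.
            eqns, in_names = [], [f"a{i}" for i in range(len(args))]
            tracers = [Tracer(n, abstract_of(a), eqns) for n, a in zip(in_names, args)]
            out = fn(*tracers)
            cache[key] = (eqns, in_names, out.name)
        # Step 5: calls with a known key go straight to execution.
        return run_program(*cache[key], args)

    wrapper.cache = cache
    return wrapper


trace_log = []
SCALE = 2.0


@toy_jit
def affine(x, w):
    trace_log.append("traced")  # side effect: fires once per trace, not per call
    return x * w + SCALE        # SCALE is read at trace time and baked in


print(affine(2.0, 3.0))  # 8.0 (traces, then runs)
print(affine(4.0, 0.5))  # 4.0 (same float32 signature: cache hit)
SCALE = 10.0
print(affine(4.0, 0.5))  # 4.0 (still uses the baked-in 2.0)
print(affine(2, 3))      # 16.0 (int32 signature: retraces and bakes in 10.0)
print(len(trace_log), next(iter(affine.cache.values()))[0])  # recorded "jaxpr" for float32
```

The stale `SCALE` value is the toy's version of a well-known `jax.jit` pitfall: anything the function reads that is not an argument is captured at trace time, which is why jitted functions should take their state as arguments.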

**`static_argnums` and `static_argnames`:**

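Arguments marked static become part of the cache key as **values** (not just shapes and dtypes), so they must be hashable. Useful for:
- Configuration flags that change computation structure.
- Integer arguments used in Python control flow (a traced argument has no concrete value to branch or loop on, so tracing fails unless the argument is static).

Each distinct static value triggers a separate trace and compilation.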
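A toy model of the cache key with static arguments. With JAX the call would be spelled `jax.jit(forward, static_argnums=(1,))` (or `static_argnames=("training",)`); here `toy_jit` and `forward` are hypothetical, and the toy calls the Python function directly instead of tracing it, so it models only how static values enter the key.

```python
def toy_jit(fn, static_argnums=()):
    """Toy cache whose key holds dynamic-argument types plus static-argument values."""
    compiled = {}

    def wrapper(*args):
        static_values = tuple(args[i] for i in static_argnums)
        try:
            hash(static_values)  # static values become dict keys, so they must hash
        except TypeError:
            raise ValueError(f"static arguments must be hashable, got {static_values!r}") from None
        dynamic_sig = tuple(type(a).__name__ for i, a in enumerate(args) if i not in static_argnums)
        key = (dynamic_sig, static_values)
        if key not in compiled:
            compiled[key] = len(compiled) + 1  # one "compilation" per new key
        return fn(*args)  # toy shortcut: real jit would run the cached executable

    wrapper.compiled = compiled
    return wrapper


def forward(x, training):
    # Python `if` on the flag: each flag value corresponds to a different program.
    if training:
        return x * 0.5
    return x


forward_jit = toy_jit(forward, static_argnums=(1,))
print(forward_jit(4.0, True))   # new key: compilation 1
print(forward_jit(6.0, True))   # same types, same static value: cache hit
print(forward_jit(4.0, False))  # new static value: compilation 2
print(len(forward_jit.compiled))

try:
    forward_jit(4.0, [True])    # a list is unhashable, so it cannot be a static value
except ValueError as err:
    print("rejected:", err)
```

Real JAX likewise rejects unhashable static arguments; passing a tuple instead of a list is the usual fix.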

**`donate_argnums`:**

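Tells XLA that the caller will not use the listed input arrays after the call, so the compiler may reuse their buffers for outputs of the same shape and dtype instead of allocating new memory. A donated array is invalidated: using it after the call raises an error.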
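A toy model of donation, assuming a device that hands out buffers and counts allocations. `ToyArray`, `ToyDevice`, and `sgd_step` are hypothetical; with JAX, a step function whose first argument is the parameters would instead be wrapped as `jax.jit(step, donate_argnums=(0,))`.

```python
from dataclasses import dataclass


@dataclass
class ToyArray:
    shape: tuple
    dtype: str
    storage: list       # stands in for a device buffer
    valid: bool = True  # False once the array has been donated

    def read(self):
        if not self.valid:
            raise RuntimeError("array was donated and is no longer usable")
        return list(self.storage)


class ToyDevice:
    def __init__(self):
        self.allocations = 0

    def allocate(self, values, dtype="float32"):
        self.allocations += 1
        return ToyArray((len(values),), dtype, list(values))


def sgd_step(device, params, grad, lr, donate_params=False):
    """Toy update `params - lr * grad`, optionally donating the `params` buffer."""
    new_values = [p - lr * g for p, g in zip(params.read(), grad.read())]
    out_shape, out_dtype = (len(new_values),), params.dtype
    if donate_params and (params.shape, params.dtype) == (out_shape, out_dtype):
        # Shapes/dtypes match: overwrite the donated buffer and hand it back as the
        # output, invalidating the input handle. No new allocation happens.
        params.storage[:] = new_values
        params.valid = False
        return ToyArray(out_shape, out_dtype, params.storage)
    return device.allocate(new_values, out_dtype)  # fresh output buffer


device = ToyDevice()
params = device.allocate([1.0, 2.0])
grad = device.allocate([0.5, 0.5])
copied = sgd_step(device, params, grad, lr=1.0)                       # allocates
donated = sgd_step(device, copied, grad, lr=1.0, donate_params=True)  # reuses `copied`
print(device.allocations, donated.read())

try:
    copied.read()  # the donated input must not be used again
except RuntimeError as err:
    print("error:", err)
```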

**Debugging:**

| Tool | Purpose |
|------|--------|
| `jax.make_jaxpr(f)(x)` | Inspect the traced jaxpr |
| `jax.jit(f).lower(x).as_text()` | Inspect StableHLO IR |
| `jax.jit(f).lower(x).compile().as_text()` | Inspect compiled HLO |
| `JAX_LOG_COMPILES=1` | Log every compilation event |
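Compile logging is most useful for spotting recompilation caused by changing input shapes. The toy below imitates that workflow without JAX; `logged_jit`, `total`, and the log message wording are hypothetical (real JAX compile logs are worded differently), and Python lists stand in for rank-1 arrays.

```python
import logging

logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")
log = logging.getLogger("toy_jit")


def logged_jit(fn):
    """Toy JIT wrapper that logs every compilation, keyed by argument shapes."""
    compiled = {}

    def wrapper(*args):
        key = tuple((len(a),) for a in args)  # toy: a list's length is its shape
        if key not in compiled:
            compiled[key] = len(compiled) + 1
            log.info("compiling %s for shapes %s (compilation #%d)", fn.__name__, key, compiled[key])
        return fn(*args)

    wrapper.compiled = compiled
    return wrapper


@logged_jit
def total(xs):
    return sum(xs)


# Batches of varying size: each new length shows up as a fresh compilation in the log.
for batch in ([1.0] * 4, [1.0] * 8, [1.0] * 4, [1.0] * 3):
    total(batch)

print(f"{len(total.compiled)} compilations for 4 calls")
```

If the log shows roughly one compilation per call, input shapes (or static values) are changing between calls.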

**Example:**

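```python
import jax
import jax.numpy as jnp

def model(params, x):  # hypothetical linear model; `model` is not defined in the original
    return x @ params["W"] + params["b"]

@jax.jit
def loss_fn(params, x, y):
    pred = model(params, x)
    return jnp.mean((pred - y) ** 2)
```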

The cell below is a pure-Python model of the JIT cache (it does not import JAX): each distinct key compiles once, and repeated keys are cache hits.

```python
from dataclasses import dataclass, field


@dataclass
class CacheKey:
    # Mirrors the (function, shapes, dtypes, static values) key described above.
    function_name: str
    input_shapes: tuple
    input_dtypes: tuple
    static_values: tuple = ()


@dataclass
class JITCache:
    entries: dict = field(default_factory=dict)
    compile_count: int = 0

    def get_or_compile(self, cache_key):
        """Return (executable, compiled_now); 'compiles' only on a cache miss."""
        key = (cache_key.function_name, cache_key.input_shapes,
               cache_key.input_dtypes, cache_key.static_values)
        if key in self.entries:
            return self.entries[key], False
        self.compile_count += 1
        executable = f"executable_{self.compile_count}"  # placeholder for a compiled binary
        self.entries[key] = executable
        return executable, True


cache = JITCache()

calls = [
    CacheKey("loss_fn", ((128, 784), (784, 256)), ("f32", "f32")),
    CacheKey("loss_fn", ((128, 784), (784, 256)), ("f32", "f32")),  # same key: hit
    CacheKey("loss_fn", ((64, 784), (784, 256)), ("f32", "f32")),   # new batch size: recompile
    CacheKey("loss_fn", ((128, 784), (784, 256)), ("f32", "f32")),  # hit
    CacheKey("predict", ((128, 784), (784, 10)), ("f32", "f32"), (True,)),
    CacheKey("predict", ((128, 784), (784, 10)), ("f32", "f32"), (False,)),  # new static value
]

print("JIT compilation trace:")
for call_index, key in enumerate(calls):
    executable, compiled = cache.get_or_compile(key)
    status = "COMPILED" if compiled else "CACHED"
    static_info = f", static={key.static_values}" if key.static_values else ""
    print(f"  call {call_index}: {key.function_name}(shapes={key.input_shapes}{static_info}) → {status} → {executable}")

print(f"\nTotal compilations: {cache.compile_count}")
print(f"Cache entries: {len(cache.entries)}")
print(f"Cache hits: {len(calls) - cache.compile_count}")
```

**References:**

[📘 JAX Documentation. *Just In Time Compilation with JAX.*](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)

[📘 JAX Documentation. *Stateful Computations in JAX.*](https://jax.readthedocs.io/en/latest/jax-101/07-state.html)

---

[⬅️ Previous: TensorFlow Graph Compilation](./01_tensorflow_graph_compilation.ipynb) | [Next: Custom Calls ➡️](./03_custom_calls.ipynb)