From @jstac:
Adding @jax.jit to intermediate functions that are only called from within other jitted functions can actually hurt performance in several ways:
1. Prevents fusion: JAX's compiler (XLA) works best when it can see the entire computation graph at once. When you jit intermediate functions separately, you create compilation boundaries that prevent the compiler from optimizing across them (for example, by fusing operations or eliminating intermediate arrays).
2. Multiple compilation overhead: Each @jax.jit decorator triggers a separate compilation, and when you call a jitted function from another jitted function, JAX has to manage multiple compiled kernels instead of one optimized kernel.
3. Missed optimization opportunities: The XLA compiler can:
   - Fuse element-wise operations
   - Eliminate temporary arrays
   - Optimize memory layout
   - Reorder operations for better cache usage
   It can only do this, however, when it sees all the operations together in one compilation unit.
4. Dispatch overhead: Calling one jitted function from another adds small dispatch costs that would not exist if everything were compiled together.
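How much these boundaries matter in practice can depend on the JAX and XLA version, so one way to check is to compile a nested-jit variant and a single-jit variant of the same computation and compare their optimized HLO. This is a minimal sketch: the helper and function names are hypothetical, and counting the word "fusion" in the HLO text is only a rough proxy for how the operations were grouped.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper, jitted on its own (the pattern the issue warns about).
@jax.jit
def scaled_square_jitted(x):
    return 2.0 * x ** 2

# The same helper as a plain Python function.
def scaled_square(x):
    return 2.0 * x ** 2

@jax.jit
def total_nested(x):
    # Calls a separately jitted helper from inside a jitted function.
    return jnp.sum(jnp.sin(scaled_square_jitted(x)))

@jax.jit
def total_flat(x):
    # The helper is traced into this one compilation unit.
    return jnp.sum(jnp.sin(scaled_square(x)))

x = jnp.linspace(0.0, 1.0, 1_000)

# Both variants should compute the same value.
assert jnp.allclose(total_nested(x), total_flat(x))

# Optimized HLO text after XLA's passes (as_text may return None on
# backends that do not expose it, hence the fallback to "").
hlo_nested = total_nested.lower(x).compile().as_text() or ""
hlo_flat = total_flat.lower(x).compile().as_text() or ""

# Rough proxy: how often the word "fusion" appears in each program.
print("'fusion' mentions, nested:", hlo_nested.count("fusion"))
print("'fusion' mentions, flat:  ", hlo_flat.count("fusion"))
```

Reading the two HLO dumps side by side shows whether the nested version actually ends up with extra kernels or calls on the JAX version in use.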
The rule of thumb: Only use @jax.jit on the "top-level" functions that users call directly. Let the inner helper functions be compiled as part of the larger computation graph.
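A minimal sketch of this pattern, assuming a hypothetical CRRA utility helper and discounting step (neither is taken from the code under review): only the entry point carries the decorator, and the helpers stay plain Python functions.

```python
import jax
import jax.numpy as jnp

# Helpers: plain Python functions, deliberately NOT decorated with @jax.jit.
def utility(c, gamma):
    # Hypothetical CRRA utility, valid for gamma != 1.
    return c ** (1.0 - gamma) / (1.0 - gamma)

def discounted_sum(u, beta):
    # Sum of beta**t * u[t] over t = 0, ..., len(u) - 1.
    t = jnp.arange(u.shape[0])
    return jnp.sum(beta ** t * u)

@jax.jit
def lifetime_utility(c, gamma=2.0, beta=0.96):
    # Only the top-level entry point is jitted; both helpers are traced
    # into this single computation graph.
    return discounted_sum(utility(c, gamma), beta)

c = jnp.ones(3)
print(lifetime_utility(c))  # approximately -(1 + 0.96 + 0.9216) = -2.8816
```

Because utility and discounted_sum are ordinary functions, they are traced into the one lifetime_utility computation when it is jitted; called directly outside any jitted context, they would simply execute eagerly, one operation at a time.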
- Examine the current code to see whether it follows this best practice for using @jax.jit.
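One way to start that audit is to trace each top-level function with jax.make_jaxpr and count the jit-call equations in its jaxpr; a count above one suggests a separately jitted helper somewhere inside. This is a sketch under stated assumptions: the helper names are hypothetical, and the set of primitive names ("pjit" in JAX 0.4.x, possibly "jit" or "xla_call" in other releases) is an assumption to adjust for the JAX version in use.

```python
import jax
import jax.numpy as jnp

# Primitive names that mark a jit call inside a jaxpr. "pjit" is used in
# JAX 0.4.x; other releases may use "jit" or "xla_call" (assumption:
# adjust for your JAX version).
JIT_PRIMITIVES = {"pjit", "jit", "xla_call"}

def _sub_jaxprs(value):
    """Yield any jaxprs stored in an equation parameter."""
    if isinstance(value, (tuple, list)):
        for v in value:
            yield from _sub_jaxprs(v)
    elif hasattr(value, "jaxpr") and hasattr(value.jaxpr, "eqns"):
        yield value.jaxpr  # ClosedJaxpr -> underlying Jaxpr
    elif hasattr(value, "eqns"):
        yield value  # already a Jaxpr

def _count(jaxpr):
    """Count jit-call equations, recursing into nested jaxprs."""
    total = 0
    for eqn in jaxpr.eqns:
        if eqn.primitive.name in JIT_PRIMITIVES:
            total += 1
        for param in eqn.params.values():
            for sub in _sub_jaxprs(param):
                total += _count(sub)
    return total

def jit_call_count(fn, *example_args):
    """Jit boundaries in fn's trace (typically 1 for a single top-level jit)."""
    return _count(jax.make_jaxpr(fn)(*example_args).jaxpr)

# Hypothetical functions to audit.
@jax.jit
def helper_jitted(x):
    return x * x + 1.0

def helper(x):
    return x * x + 1.0

@jax.jit
def step_nested(x):
    return helper_jitted(x) * 2.0

@jax.jit
def step_flat(x):
    return helper(x) * 2.0

x = jnp.arange(4.0)
print("jit calls, nested:", jit_call_count(step_nested, x))
print("jit calls, flat:  ", jit_call_count(step_flat, x))
```

A higher count for a top-level function is a hint rather than proof that a helper carries its own @jax.jit; some library functions may also appear as jit calls, so flagged functions are worth checking by hand.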