
[JAX] Remove @jax.jit on intermediate functions #719

@HumphreyYang

From @jstac:

Adding @jax.jit to intermediate functions that are only called from within other jitted functions can actually hurt performance in several ways:

1. Prevents fusion: XLA, the compiler JAX uses, works best when it can see the entire computation graph at once. When you jit intermediate functions separately, you create compilation boundaries that prevent the compiler from optimizing across them (for example, by fusing operations or eliminating intermediate arrays).
2. Multiple compilation overhead: Each @jax.jit decorator triggers a separate compilation, and when you call a jitted function from another jitted function, JAX has to manage multiple compiled kernels instead of one optimized kernel.
3. Missed optimization opportunities: The XLA compiler can do things like:
   - Fuse element-wise operations
   - Eliminate temporary arrays
   - Optimize memory layout
   - Reorder operations for better cache usage

   It can do this only when it sees all the operations together in one compilation unit.
4. Dispatch overhead: Calling one jitted function from another adds small dispatch costs that wouldn't exist if everything were compiled together.
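A minimal sketch of the compilation boundary described in point 1, assuming a recent JAX release. `jax.make_jaxpr` traces a function the way `jax.jit` does, so comparing the top-level equations of the two hypothetical callers below shows that an undecorated helper is traced inline, while a separately jitted helper is kept as a single nested call whose body lives in a sub-jaxpr. This only shows where the boundary appears in the jaxpr; it does not measure what XLA does with it afterwards, and the exact name printed for the nested call varies between JAX versions.

```python
import jax
import jax.numpy as jnp


def helper(x):
    # Hypothetical element-wise helper
    return jnp.sin(x) * 2.0


# The same helper wrapped in its own jax.jit (the pattern the issue warns about)
jitted_helper = jax.jit(helper)


def outer_flat(x):
    # Undecorated helper: its operations are traced directly into this function
    return helper(x) + 1.0


def outer_nested(x):
    # Jitted helper: shows up as one nested call in this function's jaxpr
    return jitted_helper(x) + 1.0


def top_level_primitives(f, x):
    """Return the primitive names in the top level of f's jaxpr."""
    closed = jax.make_jaxpr(f)(x)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


x = jnp.arange(3.0)
print(top_level_primitives(outer_flat, x))    # 'sin' appears at the top level
print(top_level_primitives(outer_nested, x))  # 'sin' is inside the nested call
```

Both callers compute the same values; only the structure of the traced program differs.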

The rule of thumb: Only use @jax.jit on the "top-level" functions that users call directly. Let the inner helper functions be compiled as part of the larger computation graph.
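A sketch of the rule of thumb, using hypothetical helpers (`utility`, `discounted_sum`, and `lifetime_value` are illustrative names, not code from the repository): only the entry point carries `@jax.jit`, and the helpers are plain Python functions that get traced into it.

```python
import jax
import jax.numpy as jnp


# Hypothetical helpers: plain Python functions, no @jax.jit.
# They are traced into whichever jitted function calls them.
def utility(c, gamma):
    # CRRA utility, assuming gamma != 1
    return c ** (1.0 - gamma) / (1.0 - gamma)


def discounted_sum(u, beta):
    # Sum over t of beta**t * u[t]
    t = jnp.arange(u.shape[0])
    return jnp.sum(beta ** t * u)


# Top-level entry point: the only decorated function, so both helpers
# end up in a single compilation unit.
@jax.jit
def lifetime_value(c, gamma, beta):
    return discounted_sum(utility(c, gamma), beta)


value = lifetime_value(jnp.ones(3), 2.0, 0.5)
```

If a helper is occasionally needed on its own, it can be compiled at the call site, for example `jax.jit(utility)(c, 2.0)`, without decorating its definition.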
  • Examine the current code to check whether it follows this best practice for using @jax.jit.
