Add has_side_effect parameter to FfiCallable and jax_callable#1240
Add has_side_effect parameter to FfiCallable and jax_callable#1240shi-eric merged 1 commit intoNVIDIA:mainfrom
Conversation
When using JAX's pmap, XLA may eliminate FFI calls as dead code if their outputs are not consumed by downstream operations. This adds a has_side_effect parameter through the FFI chain (FfiCallable.__init__, FfiCallable.__call__, jax_callable) that is forwarded to jax.ffi.ffi_call, telling XLA to always execute the call regardless of whether its outputs are used. This fixes a bug where BVH refit calls were being skipped under pmap in MuJoCo MJX rendering. Signed-off-by: Baruch Tabanpour <btaba@google.com>
📝 WalkthroughWalkthroughA new Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~8 minutes 🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches
🧪 Generate unit tests (beta)
Tip Issue Planner is now in beta. Read the docs and try it out! Share your feedback on Discord. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Greptile SummaryAdded
The implementation is clean, minimal, and follows JAX FFI conventions. The default value of Confidence Score: 5/5
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[User calls jax_callable with has_side_effect=True] --> B[FfiCallable.__init__ stores has_side_effect]
B --> C[FfiCallable.__call__ invoked]
C --> D[jax.ffi.ffi_call called with has_side_effect parameter]
D --> E{XLA Compiler}
E -->|has_side_effect=False| F[May eliminate as dead code if outputs unused]
E -->|has_side_effect=True| G[Always executes FFI call]
G --> H[Prevents elimination under pmap/vmap]
Last reviewed commit: 7ab2076 |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
warp/_src/jax_experimental/ffi.py (2)
1513-1542:⚠️ Potential issue | 🟠 Major
has_side_effectis absent from thejax_callablecache key — the flag is silently dropped on cache hits.The lookup key (lines 1514–1521) does not include
has_side_effect. As a result, ifjax_callableis first called with the defaulthas_side_effect=Falseand then called again with the samefuncandhas_side_effect=True, the second call returns the cachedFfiCallablethat still hashas_side_effect=False. The new value is never applied — directly defeating the purpose of this PR. Theelsebranch (lines 1540–1542) already demonstrates the pattern of reconciling a changed parameter (graph_cache_max), buthas_side_effectreceives no equivalent update there either.🐛 Proposed fix — add
has_side_effectto cache key and update on cache hitkey = ( func, num_outputs, graph_mode, vmap_method, hashable_output_dims, module_preload_mode, + has_side_effect, ) with _FFI_REGISTRY_LOCK: callable = _FFI_CALLABLE_REGISTRY.get(key) if callable is None: callable = FfiCallable( func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames, stage_in_argnames, stage_out_argnames, graph_cache_max, module_preload_mode, has_side_effect, ) _FFI_CALLABLE_REGISTRY[key] = callable else: # make sure we're using the latest graph cache max callable.graph_cache_max = graph_cache_max🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@warp/_src/jax_experimental/ffi.py` around lines 1513 - 1542, The cache key for jax_callable is missing the has_side_effect flag so cached FfiCallable instances can return with the wrong side-effect setting; include has_side_effect in the tuple assigned to key (alongside func, num_outputs, graph_mode, vmap_method, hashable_output_dims, module_preload_mode) and, in the cache-hit branch where callable.graph_cache_max is reconciled, also set callable.has_side_effect = has_side_effect to update the existing FfiCallable; this touches the key variable, the _FFI_CALLABLE_REGISTRY lookup, and the FfiCallable instance on cache hit.
1459-1499:⚠️ Potential issue | 🟡 MinorMissing docstring entry for
has_side_effectinjax_callable.The new parameter is not documented in the
Argssection.📝 Proposed addition to docstring
Argssectionmodule_preload_mode: Specify the devices where the module should be preloaded. + has_side_effect: If ``True``, instructs XLA to always execute the FFI call even + when its outputs are not consumed, preventing dead-code elimination under + ``jax.pmap`` or similar transforms.As per coding guidelines: "Follow Google-style docstrings with Warp-specific guidelines: document
__init__parameters in class docstring."🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@warp/_src/jax_experimental/ffi.py` around lines 1459 - 1499, Add a docstring entry for the has_side_effect parameter of jax_callable: state that has_side_effect: bool = False indicates whether the provided Python callback performs side effects (e.g., I/O, host state mutation) and therefore should not be treated as a pure function by JAX; document the type and default, and briefly note the behavioral impact (when True, the callback will be treated as having side effects which can affect JAX transformations, caching and graph capture semantics). Ensure this line is placed in the Args section alongside the other parameters for jax_callable.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@warp/_src/jax_experimental/ffi.py`:
- Around line 1513-1542: The cache key for jax_callable is missing the
has_side_effect flag so cached FfiCallable instances can return with the wrong
side-effect setting; include has_side_effect in the tuple assigned to key
(alongside func, num_outputs, graph_mode, vmap_method, hashable_output_dims,
module_preload_mode) and, in the cache-hit branch where callable.graph_cache_max
is reconciled, also set callable.has_side_effect = has_side_effect to update the
existing FfiCallable; this touches the key variable, the _FFI_CALLABLE_REGISTRY
lookup, and the FfiCallable instance on cache hit.
- Around line 1459-1499: Add a docstring entry for the has_side_effect parameter
of jax_callable: state that has_side_effect: bool = False indicates whether the
provided Python callback performs side effects (e.g., I/O, host state mutation)
and therefore should not be treated as a pure function by JAX; document the type
and default, and briefly note the behavioral impact (when True, the callback
will be treated as having side effects which can affect JAX transformations,
caching and graph capture semantics). Ensure this line is placed in the Args
section alongside the other parameters for jax_callable.
Add has_side_effect argument to jax_kernel and jax_callable (NVIDIAGH-1240) See merge request omniverse/warp!2028
When using JAX's pmap, XLA may eliminate FFI calls as dead code if their outputs are not consumed by downstream operations. This adds a has_side_effect parameter through the FFI chain (FfiCallable.init, FfiCallable.call, jax_callable) that is forwarded to jax.ffi.ffi_call, telling XLA to always execute the call regardless of whether its outputs are used.
Description
Before your PR is "Ready for review"
__init__.pyi,docs/api_reference/,docs/language_reference/)pre-commit run -aSummary by CodeRabbit