Skip to content

Comments

Add has_side_effect parameter to FfiCallable and jax_callable#1240

Merged
shi-eric merged 1 commit intoNVIDIA:mainfrom
btaba:has_side_effect_ffi
Feb 20, 2026
Merged

Add has_side_effect parameter to FfiCallable and jax_callable#1240
shi-eric merged 1 commit intoNVIDIA:mainfrom
btaba:has_side_effect_ffi

Conversation

@btaba
Copy link
Contributor

@btaba btaba commented Feb 20, 2026

When using JAX's pmap, XLA may eliminate FFI calls as dead code if their outputs are not consumed by downstream operations. This adds a has_side_effect parameter through the FFI chain (FfiCallable.init, FfiCallable.call, jax_callable) that is forwarded to jax.ffi.ffi_call, telling XLA to always execute the call regardless of whether its outputs are used.

Description

Before your PR is "Ready for review"

  • All commits are signed-off to indicate that your contribution adheres to the Developer Certificate of Origin requirements
  • Necessary tests have been added
  • Documentation is up-to-date
  • Auto-generated files modified by compiling Warp and building the documentation have been updated (e.g. __init__.pyi, docs/api_reference/, docs/language_reference/)
  • Code passes formatting and linting checks with pre-commit run -a

Summary by CodeRabbit

  • New Features
    • Added support for specifying side-effect behavior in FFI integrations. Users can now configure whether FFI calls have side effects during initialization.

When using JAX's pmap, XLA may eliminate FFI calls as dead code if
their outputs are not consumed by downstream operations. This adds a
has_side_effect parameter through the FFI chain (FfiCallable.__init__,
FfiCallable.__call__, jax_callable) that is forwarded to
jax.ffi.ffi_call, telling XLA to always execute the call regardless
of whether its outputs are used.

This fixes a bug where BVH refit calls were being skipped under pmap
in MuJoCo MJX rendering.

Signed-off-by: Baruch Tabanpour <btaba@google.com>
@copy-pr-bot
Copy link

copy-pr-bot bot commented Feb 20, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Copy link

coderabbitai bot commented Feb 20, 2026

📝 Walkthrough

Walkthrough

A new has_side_effect flag is introduced to the JAX Warp FFI integration. The FfiCallable class now accepts this parameter (defaulting to false), stores it as an instance attribute, and forwards it during FFI calls. The jax_callable function's signature is updated to include this parameter and propagate it when constructing FfiCallable instances.

Changes

Cohort / File(s) Summary
JAX FFI Side Effect Flag
warp/_src/jax_experimental/ffi.py
Added has_side_effect parameter (default False) to FfiCallable constructor and jax_callable function. Parameter is stored as instance attribute and forwarded to underlying ffi_call during execution.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change—adding a has_side_effect parameter to FfiCallable and jax_callable—which directly matches the pull request's primary objective.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
  • 📝 Generate docstrings (stacked PR)
  • 📝 Generate docstrings (commit on current branch)
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

Issue Planner is now in beta. Read the docs and try it out! Share your feedback on Discord.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@greptile-apps
Copy link

greptile-apps bot commented Feb 20, 2026

Greptile Summary

Added has_side_effect parameter (default False) to FfiCallable and jax_callable to prevent XLA from eliminating FFI calls as dead code when outputs are unused under JAX's pmap.

  • Parameter added to FfiCallable.__init__ (line 499) and stored as instance variable
  • Parameter added to jax_callable function signature (line 1459) with default value False
  • Parameter correctly forwarded to jax.ffi.ffi_call (line 680)
  • Fixes bug where BVH refit calls were skipped under pmap in MuJoCo MJX rendering

The implementation is clean, minimal, and follows JAX FFI conventions. The default value of False maintains backward compatibility.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The change is a straightforward parameter addition that follows established patterns. The parameter is properly threaded through the call chain, has a safe default value (False) for backward compatibility, and directly addresses a documented bug. No breaking changes or complex logic introduced.
  • No files require special attention

Important Files Changed

Filename Overview
warp/_src/jax_experimental/ffi.py Added has_side_effect parameter with default False to FfiCallable.__init__, jax_callable, and forwarded to jax.ffi.ffi_call to prevent XLA dead code elimination

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[User calls jax_callable with has_side_effect=True] --> B[FfiCallable.__init__ stores has_side_effect]
    B --> C[FfiCallable.__call__ invoked]
    C --> D[jax.ffi.ffi_call called with has_side_effect parameter]
    D --> E{XLA Compiler}
    E -->|has_side_effect=False| F[May eliminate as dead code if outputs unused]
    E -->|has_side_effect=True| G[Always executes FFI call]
    G --> H[Prevents elimination under pmap/vmap]
Loading

Last reviewed commit: 7ab2076

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
warp/_src/jax_experimental/ffi.py (2)

1513-1542: ⚠️ Potential issue | 🟠 Major

has_side_effect is absent from the jax_callable cache key — the flag is silently dropped on cache hits.

The lookup key (lines 1514–1521) does not include has_side_effect. As a result, if jax_callable is first called with the default has_side_effect=False and then called again with the same func and has_side_effect=True, the second call returns the cached FfiCallable that still has has_side_effect=False. The new value is never applied — directly defeating the purpose of this PR. The else branch (lines 1540–1542) already demonstrates the pattern of reconciling a changed parameter (graph_cache_max), but has_side_effect receives no equivalent update there either.

🐛 Proposed fix — add has_side_effect to cache key and update on cache hit
     key = (
         func,
         num_outputs,
         graph_mode,
         vmap_method,
         hashable_output_dims,
         module_preload_mode,
+        has_side_effect,
     )

     with _FFI_REGISTRY_LOCK:
         callable = _FFI_CALLABLE_REGISTRY.get(key)
         if callable is None:
             callable = FfiCallable(
                 func,
                 num_outputs,
                 graph_mode,
                 vmap_method,
                 output_dims,
                 in_out_argnames,
                 stage_in_argnames,
                 stage_out_argnames,
                 graph_cache_max,
                 module_preload_mode,
                 has_side_effect,
             )
             _FFI_CALLABLE_REGISTRY[key] = callable
         else:
             # make sure we're using the latest graph cache max
             callable.graph_cache_max = graph_cache_max
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@warp/_src/jax_experimental/ffi.py` around lines 1513 - 1542, The cache key
for jax_callable is missing the has_side_effect flag so cached FfiCallable
instances can return with the wrong side-effect setting; include has_side_effect
in the tuple assigned to key (alongside func, num_outputs, graph_mode,
vmap_method, hashable_output_dims, module_preload_mode) and, in the cache-hit
branch where callable.graph_cache_max is reconciled, also set
callable.has_side_effect = has_side_effect to update the existing FfiCallable;
this touches the key variable, the _FFI_CALLABLE_REGISTRY lookup, and the
FfiCallable instance on cache hit.

1459-1499: ⚠️ Potential issue | 🟡 Minor

Missing docstring entry for has_side_effect in jax_callable.

The new parameter is not documented in the Args section.

📝 Proposed addition to docstring Args section
         module_preload_mode: Specify the devices where the module should be preloaded.
+        has_side_effect: If ``True``, instructs XLA to always execute the FFI call even
+            when its outputs are not consumed, preventing dead-code elimination under
+            ``jax.pmap`` or similar transforms.

As per coding guidelines: "Follow Google-style docstrings with Warp-specific guidelines: document __init__ parameters in class docstring."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@warp/_src/jax_experimental/ffi.py` around lines 1459 - 1499, Add a docstring
entry for the has_side_effect parameter of jax_callable: state that
has_side_effect: bool = False indicates whether the provided Python callback
performs side effects (e.g., I/O, host state mutation) and therefore should not
be treated as a pure function by JAX; document the type and default, and briefly
note the behavioral impact (when True, the callback will be treated as having
side effects which can affect JAX transformations, caching and graph capture
semantics). Ensure this line is placed in the Args section alongside the other
parameters for jax_callable.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@warp/_src/jax_experimental/ffi.py`:
- Around line 1513-1542: The cache key for jax_callable is missing the
has_side_effect flag so cached FfiCallable instances can return with the wrong
side-effect setting; include has_side_effect in the tuple assigned to key
(alongside func, num_outputs, graph_mode, vmap_method, hashable_output_dims,
module_preload_mode) and, in the cache-hit branch where callable.graph_cache_max
is reconciled, also set callable.has_side_effect = has_side_effect to update the
existing FfiCallable; this touches the key variable, the _FFI_CALLABLE_REGISTRY
lookup, and the FfiCallable instance on cache hit.
- Around line 1459-1499: Add a docstring entry for the has_side_effect parameter
of jax_callable: state that has_side_effect: bool = False indicates whether the
provided Python callback performs side effects (e.g., I/O, host state mutation)
and therefore should not be treated as a pure function by JAX; document the type
and default, and briefly note the behavioral impact (when True, the callback
will be treated as having side effects which can affect JAX transformations,
caching and graph capture semantics). Ensure this line is placed in the Args
section alongside the other parameters for jax_callable.

@shi-eric shi-eric requested a review from nvlukasz February 20, 2026 01:49
@nvlukasz nvlukasz added this to the 1.12.0 milestone Feb 20, 2026
@shi-eric shi-eric merged commit 7ab2076 into NVIDIA:main Feb 20, 2026
3 checks passed
@shi-eric
Copy link
Contributor

Thanks @btaba, @nvlukasz reviewed this and it's now merged. This will be included in the Warp 1.12 release, which is planned for release on or around March 6.

pull bot pushed a commit to Stars1233/warp-python that referenced this pull request Feb 20, 2026
Add has_side_effect argument to jax_kernel and jax_callable (NVIDIAGH-1240)

See merge request omniverse/warp!2028
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants