
Fix MJX-Warp FFI Multi-GPU Deadlock #1181

Closed
Adityakk9031 wants to merge 1 commit into NVIDIA:main from Adityakk9031:main

Conversation

@Adityakk9031
Contributor

@Adityakk9031 Adityakk9031 commented Jan 16, 2026

Problem
In multi-GPU environments using jax.pmap, identical virtual addresses can be assigned to buffers on different devices. The MJX-Warp FFI graph cache used:

capture_key = hash((call_id, *buffer_pointers))
This led to cache collisions where a CUDA graph captured on one device was incorrectly retrieved and launched on another device, causing silent deadlocks after ~10M training steps.

Root Cause
Memory addresses are only unique per device, not globally (see the sketch after this list):

  • Device 0 allocates a buffer at 0x7f1234567890
  • Device 1 allocates a different buffer at the same virtual address 0x7f1234567890
  • Both hash to the identical cache key → cross-device graph launch → deadlock
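
A minimal standalone sketch of the collision (illustrative only; the pointer values and key construction are simplified stand-ins for the real FFI graph cache):

# Two devices report the same virtual address; the old key cannot tell them apart.
call_id = 42
ptr_on_device0 = 0x7F1234567890  # buffer allocated on device 0
ptr_on_device1 = 0x7F1234567890  # a different buffer, same virtual address on device 1

old_key_device0 = hash((call_id, ptr_on_device0))
old_key_device1 = hash((call_id, ptr_on_device1))
assert old_key_device0 == old_key_device1  # collision: device 1 retrieves device 0's graph
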
Solution
Include device_ordinal in the cache key hash:

device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
capture_key = hash((device_ordinal, call_id, *buffer_pointers))
This ensures CUDA graphs are cached uniquely per device, preventing cross-device launch conflicts.
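
A sketch of the resulting lookup flow (simplified; the wrapper function and capture callback below are hypothetical, not the exact code in ffi.py — only get_device_ordinal_from_callframe and the key layout come from the change itself):

# Per-device graph caching: a cache miss captures on the calling device,
# and a hit can only return a graph captured on that same device.
def get_or_capture_graph(graph_cache, call_frame, call_id, buffer_pointers, capture_fn):
    device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
    capture_key = hash((device_ordinal, call_id, *buffer_pointers))
    graph = graph_cache.get(capture_key)
    if graph is None:
        # Miss: capture a new CUDA graph on this device and store it under the
        # device-qualified key, so another device can never retrieve it.
        graph = capture_fn(device_ordinal)
        graph_cache[capture_key] = graph
    return graph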

Changes
File: warp/_src/jax_experimental/ffi.py

Line 684 - Added: Extract device ordinal from call frame
Line 690 - Modified: Include device_ordinal in cache key hash

Summary by CodeRabbit

  • Bug Fixes
    • CUDA GPU graph capture now differentiates captures by device, preventing cross-device cache mix-ups and improving correctness and reuse when running on multiple GPUs.


@copy-pr-bot

copy-pr-bot bot commented Jan 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai bot commented Jan 16, 2026

📝 Walkthrough

Includes the CUDA device ordinal from the JAX call frame in Warp's graph-capture key so that captures are differentiated per device; the key now hashes (device_ordinal, call_id, input/output data), which affects capture cache matching.

Changes

Cohort: Graph Capture Device Specificity
File(s): warp/_src/jax_experimental/ffi.py
Summary: Reads the device ordinal from the call frame and adds it to the graph-capture key; the capture key now hashes device-specific metadata to differentiate captures and influence cache hits/reuse.
sequenceDiagram
  participant Client
  participant JAX
  participant FFI as "FFI (warp/_src/jax_experimental/ffi.py)"
  participant CUDA as "CUDA Graph Capture"
  participant Cache

  Client->>JAX: invoke kernel (call_id, inputs)
  JAX->>FFI: forward call frame + inputs
  FFI->>FFI: extract device_ordinal from call frame
  FFI->>Cache: lookup key hash(device_ordinal, call_id, inputs/outputs)
  alt cache hit
    Cache-->>FFI: cached graph
    FFI->>CUDA: launch cached graph
  else cache miss
    FFI->>CUDA: capture new graph (includes device context)
    CUDA-->>Cache: store graph under key
    FFI->>CUDA: launch captured graph
  end
  CUDA-->>Client: results

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: Check skipped - CodeRabbit's high-level summary is enabled.
  • Title check — ✅ Passed: The title clearly and concisely summarizes the main change: fixing a multi-GPU deadlock issue in the MJX-Warp FFI graph cache by including device ordinals in the cache key.



🧹 Recent nitpick comments
warp/_src/jax_experimental/ffi.py (1)

684-685: Correct fix for the multi-GPU deadlock.

Extracting the device ordinal early and using it in the cache key properly ensures CUDA graphs are cached and retrieved per-device, preventing cross-device launch conflicts.

However, note that line 788 now performs a redundant call to get_device_ordinal_from_callframe since device_ordinal is already set here. Consider removing the duplicate assignment at line 788 to avoid the extra FFI call overhead.

♻️ Suggested cleanup to remove redundant call
@@ -785,7 +785,6 @@ class FfiCallable:
                         # early out
                         return

-                device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
                 device = wp.get_cuda_device(device_ordinal)
                 stream = wp.Stream(device, cuda_stream=cuda_stream)

📜 Recent review details

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7c20457 and 3162a10.

📒 Files selected for processing (1)
  • warp/_src/jax_experimental/ffi.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2026-01-13T17:29:11.184Z
Learnt from: shi-eric
Repo: NVIDIA/warp PR: 1171
File: warp/_src/builtins.py:7726-7735
Timestamp: 2026-01-13T17:29:11.184Z
Learning: In NVIDIA/warp (PR `#1171`), block_dim() may be called from user-defined Warp functions, while tid() may not; tid() is kernel-only. File context: warp/_src/builtins.py builtins registration.

Applied to files:

  • warp/_src/jax_experimental/ffi.py
🧬 Code graph analysis (1)
warp/_src/jax_experimental/ffi.py (1)
warp/_src/jax_experimental/xla_ffi.py (1)
  • get_device_ordinal_from_callframe (598-604)
🪛 Ruff (0.14.11)
warp/_src/jax_experimental/ffi.py

684-684: get_device_ordinal_from_callframe may be undefined, or defined from star imports

(F405)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Greptile Review
  • GitHub Check: Greptile Review
🔇 Additional comments (1)
warp/_src/jax_experimental/ffi.py (1)

686-703: LGTM! The cache key fix correctly prevents cross-device graph launches.

Including device_ordinal as the first element of the cache key tuple ensures that identical virtual addresses on different devices will hash to different keys, eliminating the root cause of the deadlock.




@greptile-apps

greptile-apps bot commented Jan 16, 2026

Greptile Summary

  • Fixed critical multi-GPU deadlock in MJX-Warp FFI by including device ordinal in CUDA graph cache key to prevent cross-device cache collisions
  • Added device ordinal extraction from call frame on line 684 and modified cache key generation on line 690 to include device-specific information
  • This prevents identical virtual memory addresses across different GPU devices from causing incorrect CUDA graph retrieval and execution on wrong devices

Important Files Changed

Filename Overview
warp/_src/jax_experimental/ffi.py Fixed multi-GPU deadlock by adding device ordinal to CUDA graph cache key hash to ensure device-specific graph caching

Confidence score: 5/5

  • This PR is extremely safe to merge with minimal risk of causing production issues
  • Score reflects a simple, well-understood bug fix that addresses a critical deadlock issue with clear root cause analysis and targeted solution
  • No files require special attention as the change is surgical and addresses a specific multi-GPU caching problem

@greptile-apps

greptile-apps bot commented Jan 16, 2026

Greptile found no issues!

From now on, if a review finishes and we haven't found any issues, we will not post anything, but you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

Signed-off-by: Aditya kumar singh <143548997+Adityakk9031@users.noreply.github.com>
@Adityakk9031
Contributor Author

@btaba google-deepmind/mujoco#3017

@Adityakk9031
Contributor Author

Adityakk9031 commented Jan 16, 2026

@shi-eric google-deepmind/mujoco#3017 check this

@shi-eric shi-eric requested a review from nvlukasz January 16, 2026 16:49
@shi-eric shi-eric added this to the 1.11.1 milestone Jan 16, 2026
Contributor

@nvlukasz nvlukasz left a comment


Nice catch, thanks! Will try to get this merged shortly.

@shi-eric
Contributor

Thanks @Adityakk9031, this was merged in a9e071d

@shi-eric shi-eric closed this Jan 22, 2026
pull bot pushed a commit to Stars1233/warp-python that referenced this pull request Jan 22, 2026
Fix JAX FFI multi-gpu graph caching (NVIDIAGH-1181)

See merge request omniverse/warp!1926
@Adityakk9031
Contributor Author

@btaba this PR has been merged

@hartikainen

Thanks for the fix @Adityakk9031. For what it's worth, as mentioned in google-deepmind/mujoco#2980 (comment), my original deadlock issue was not resolved by the change.

@Adityakk9031
Contributor Author

@hartikainen it will take some time; the team will address it later. For now, you can apply this fix locally.
