Fix MJX-Warp FFI Multi-GPU Deadlock #1181
Conversation
📝 Walkthrough
The change includes the CUDA device ordinal from the JAX call frame in Warp's graph-capture key so that captures are differentiated per device; the key now hashes (device_ordinal, call_id, input/output data), which affects capture-cache matching. The sequence diagram below shows the flow; a simplified sketch follows the diagram.
```mermaid
sequenceDiagram
    participant Client
    participant JAX
    participant FFI as "FFI (warp/_src/jax_experimental/ffi.py)"
    participant CUDA as "CUDA Graph Capture"
    participant Cache
    Client->>JAX: invoke kernel (call_id, inputs)
    JAX->>FFI: forward call frame + inputs
    FFI->>FFI: extract device_ordinal from call frame
    FFI->>Cache: lookup key hash(device_ordinal, call_id, inputs/outputs)
    alt cache hit
        Cache-->>FFI: cached graph
        FFI->>CUDA: launch cached graph
    else cache miss
        FFI->>CUDA: capture new graph (includes device context)
        CUDA-->>Cache: store graph under key
        FFI->>CUDA: launch captured graph
    end
    CUDA-->>Client: results
```
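The pattern in the diagram is, in essence, a per-device capture cache. Below is a minimal Python sketch of that pattern, not Warp's actual implementation: `CaptureCache`, `get_or_capture`, and `fake_capture` are hypothetical names invented for illustration, and only the shape of the cache key mirrors the change in this PR.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class CaptureCache:
    """Simplified per-device CUDA graph capture cache (illustrative only)."""
    graphs: Dict[int, Any] = field(default_factory=dict)

    def get_or_capture(self, device_ordinal, call_id, arg_ptrs, capture_fn):
        # The key includes the device ordinal, so identical virtual addresses
        # on different GPUs cannot collide in the cache.
        capture_key = hash((device_ordinal, call_id, *arg_ptrs))
        graph = self.graphs.get(capture_key)
        if graph is None:
            # Cache miss: capture a new graph for this device and store it.
            graph = capture_fn(device_ordinal, arg_ptrs)
            self.graphs[capture_key] = graph
        return graph


def fake_capture(device_ordinal, arg_ptrs):
    # Hypothetical stand-in for CUDA graph capture.
    return ("graph", device_ordinal, tuple(arg_ptrs))


if __name__ == "__main__":
    cache = CaptureCache()
    # Same call_id and the same virtual address seen on two different devices:
    g0 = cache.get_or_capture(0, 7, [0x7F1234567890], fake_capture)
    g1 = cache.get_or_capture(1, 7, [0x7F1234567890], fake_capture)
    assert g0 != g1  # one graph per device, no cross-device reuse
```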
Greptile Summary
Confidence score: 5/5
Greptile found no issues! From now on, if a review finishes and no issues are found, nothing will be posted, but you can confirm that your changes were reviewed in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".
Signed-off-by: Aditya kumar singh <143548997+Adityakk9031@users.noreply.github.com>
@shi-eric please check google-deepmind/mujoco#3017
nvlukasz
left a comment
Nice catch, thanks! Will try to get this merged shortly.
Thanks @Adityakk9031, this was merged in a9e071d |
Fix JAX FFI multi-gpu graph caching (NVIDIA GH-1181). See merge request omniverse/warp!1926
@btaba this PR is merged.
Thanks for the fix @Adityakk9031. For what it's worth, as mentioned in google-deepmind/mujoco#2980 (comment), my original deadlock issue was not resolved by the change.
@hartikainen it will take some time; the team will address it later. For now, you can apply this fix locally.
Problem
In multi-GPU environments using jax.pmap, identical virtual addresses can be assigned to buffers on different devices. The MJX-Warp FFI graph cache used:
```python
capture_key = hash((call_id, *buffer_pointers))
```
This led to cache collisions where a CUDA graph captured on one device was incorrectly retrieved and launched on another device, causing silent deadlocks after ~10M training steps.
Root Cause
Memory addresses are only unique per device, not globally, as illustrated by the sketch after this list:
- Device 0 allocates a buffer at 0x7f1234567890
- Device 1 allocates a different buffer at the same virtual address 0x7f1234567890
- Both hash to the identical cache key → cross-device graph launch → deadlock
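A self-contained illustration of that collision, using plain Python integers in place of real device buffers; the pointer value and call_id are made up for the example, and nothing here is Warp's actual code:

```python
# Illustrative only: models the pre-fix cache key shape.
call_id = 42
ptr_device0 = 0x7F1234567890  # virtual address of a buffer on device 0
ptr_device1 = 0x7F1234567890  # a *different* buffer on device 1, same address

old_key_device0 = hash((call_id, ptr_device0))
old_key_device1 = hash((call_id, ptr_device1))

# The keys are identical, so a graph captured on device 0 would be
# found in the cache and launched for the device 1 call.
assert old_key_device0 == old_key_device1
```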
Solution
Include device_ordinal in the cache key hash:
```python
device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
capture_key = hash((device_ordinal, call_id, *buffer_pointers))
```
This ensures CUDA graphs are cached uniquely per device, preventing cross-device launch conflicts.
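Continuing the illustration from the Root Cause section, a hedged sketch of how the fixed key shape separates the two devices (values again made up; only the tuple shape mirrors the change):

```python
# Illustrative only: the post-fix cache key shape.
call_id = 42
ptr = 0x7F1234567890  # same virtual address reported on both devices

new_key_device0 = hash((0, call_id, ptr))  # device ordinal 0
new_key_device1 = hash((1, call_id, ptr))  # device ordinal 1

# Distinct keys: each device captures and launches its own CUDA graph.
assert new_key_device0 != new_key_device1
```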
Changes
File: warp/_src/jax_experimental/ffi.py
- Line 684 (added): extract the device ordinal from the call frame
- Line 690 (modified): include device_ordinal in the cache key hash