Fix MIOpen race condition in multi-GPU CompileSolution#777
Closed
phambinhfin wants to merge 1 commit intorocm-jaxlib-v0.9.1from
Closed
Fix MIOpen race condition in multi-GPU CompileSolution#777phambinhfin wants to merge 1 commit intorocm-jaxlib-v0.9.1from
phambinhfin wants to merge 1 commit intorocm-jaxlib-v0.9.1from
Conversation
When testing JAX 0.9.1 with MaxText Gemma3-4B multimodal on 8x MI300X, segmentation fault happens during XLA compilation in MIOpen's miopen::kernels() internal function. Debug VLOG instrumentation was added to CompileSolution entry/exit points to trace thread and device info. The log shows a race condition may happen - all 8 GPUs enter MIOpen CompileSolution simultaneously on different threads, each holding their own per-GPU mutex: WITHOUT fix (concurrent access): 11:48:18.213 ENTERING CompileSolution device 3 11:48:18.219 ENTERING CompileSolution device 1 11:48:18.219 ENTERING CompileSolution device 2 11:48:18.219 ENTERING CompileSolution device 6 11:48:18.219 ENTERING CompileSolution device 4 11:48:18.219 ENTERING CompileSolution device 5 11:48:18.219 ENTERING CompileSolution device 0 11:48:18.219 ENTERING CompileSolution device 7 (8 threads overlapping inside MIOpen) 11:48:18.243 EXITING CompileSolution device 3 WITH fix (serialized access): 12:34:46.330 ENTERING CompileSolution device 4 12:34:46.365 EXITING CompileSolution device 4 12:34:46.365 ENTERING CompileSolution device 7 12:34:46.405 EXITING CompileSolution device 7 ...one at a time, zero overlap... The fix adds a process-wide static absl::Mutex in rocm_dnn.cc that serializes MIOpen Find/CompileSolution/CompileFusionPlan calls across all GPUs. Only affects compilation time, not runtime execution. After fix, segmentation fault does not happen. Training completes all steps with correct loss convergence. No performance regression on steady-state throughput (~8.7 TFLOP/s/device, ~132 tokens/s/device). Additional changes: - Use StreamExecutorAddressAllocator instead of BFC allocator for MIOpen autotuning scratch memory to avoid BFC OOM during autotuning - Replace LOG(FATAL) with absl::ResourceExhaustedError on scratch allocation failure for graceful fallback - Add fallback to default config when autotuning fails - Move do_not_autotune_ check before MIOpen calls Need to discuss more with XLA engineers about this solution.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
When testing JAX 0.9.1 with MaxText Gemma3-4B multimodal on 8x MI300X, segmentation fault happens during XLA compilation in MIOpen's
miopen::kernels()internal function.Debug VLOG instrumentation was added to
CompileSolutionentry/exit points to trace thread and device info. The log shows a race condition may happen — all 8 GPUs enter MIOpenCompileSolutionsimultaneously on different threads, each holding their own per-GPU mutex:WITHOUT fix (concurrent access):
WITH fix (serialized access):
The fix adds a process-wide static
absl::Mutexinrocm_dnn.ccthat serializes MIOpen Find/CompileSolution/CompileFusionPlan calls across all GPUs. Only affects compilation time, not runtime execution.After fix, segmentation fault does not happen. Training completes all steps with correct loss convergence. No performance regression on steady-state throughput (~8.7 TFLOP/s/device, ~132 tokens/s/device).
Additional changes:
StreamExecutorAddressAllocatorinstead of BFC allocator for MIOpen autotuning scratch memory to avoid BFC OOM during autotuningLOG(FATAL)withabsl::ResourceExhaustedErroron scratch allocation failure for graceful fallbackdo_not_autotune_check before MIOpen callsNeed to discuss more with XLA engineers about this solution.
Test plan