Fix MIOpen race condition in multi-GPU CompileSolution#777

Closed
phambinhfin wants to merge 1 commit into rocm-jaxlib-v0.9.1 from fix/miopen-race-condition-compile-solution

Conversation

@phambinhfin

Summary

When testing JAX 0.9.1 with MaxText Gemma3-4B multimodal on 8x MI300X, a segmentation fault occurs during XLA compilation, inside MIOpen's internal miopen::kernels() function.

Debug VLOG instrumentation was added at the CompileSolution entry/exit points to trace thread and device info. The log points to a race condition: all 8 GPUs enter MIOpen CompileSolution simultaneously on different threads, each holding only its own per-GPU mutex:

WITHOUT fix (concurrent access):

11:48:18.213  ENTERING CompileSolution device 3  ┐
11:48:18.219  ENTERING CompileSolution device 1  │
11:48:18.219  ENTERING CompileSolution device 2  │ 8 threads
11:48:18.219  ENTERING CompileSolution device 6  │ overlapping
11:48:18.219  ENTERING CompileSolution device 4  │ inside MIOpen
11:48:18.219  ENTERING CompileSolution device 5  │
11:48:18.219  ENTERING CompileSolution device 0  │
11:48:18.219  ENTERING CompileSolution device 7  ┘
11:48:18.243  EXITING  CompileSolution device 3

WITH fix (serialized access):

12:34:46.330  ENTERING CompileSolution device 4
12:34:46.365  EXITING  CompileSolution device 4
12:34:46.365  ENTERING CompileSolution device 7
12:34:46.405  EXITING  CompileSolution device 7
...one at a time, zero overlap...

The fix adds a process-wide static absl::Mutex in rocm_dnn.cc that serializes MIOpen Find/CompileSolution/CompileFusionPlan calls across all GPUs. This affects only compilation time, not runtime execution.

After the fix, the segmentation fault no longer occurs. Training completes all steps with correct loss convergence and no regression in steady-state throughput (~8.7 TFLOP/s/device, ~132 tokens/s/device).

Additional changes:

  • Use StreamExecutorAddressAllocator instead of BFC allocator for MIOpen autotuning scratch memory to avoid BFC OOM during autotuning
  • Replace LOG(FATAL) with absl::ResourceExhaustedError on scratch allocation failure for graceful fallback
  • Add fallback to default config when autotuning fails
  • Move do_not_autotune_ check before MIOpen calls
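The error-handling changes in the list above amount to a graceful-degradation path. A hedged sketch, with illustrative names (AllocateScratch, PickConfig, kDefaultConfig are not the real identifiers; the actual code returns absl::ResourceExhaustedError via absl::Status):

```cpp
// Sketch of the fallback flow: a scratch-allocation failure during
// autotuning returns an error instead of LOG(FATAL), and the caller falls
// back to the default config rather than crashing the process.
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>

struct AutotuneConfig {
  std::string name;
};
const AutotuneConfig kDefaultConfig{"default"};

// Stand-in for scratch allocation: returns nullopt (analogous to
// ResourceExhaustedError) when the request exceeds available memory.
std::optional<void*> AllocateScratch(size_t bytes, size_t available) {
  if (bytes > available) return std::nullopt;  // previously a LOG(FATAL)
  static char pool[1024];
  return static_cast<void*>(pool);
}

// Autotuning path: on allocation failure, use the default config.
AutotuneConfig PickConfig(size_t scratch_bytes, size_t available) {
  if (!AllocateScratch(scratch_bytes, available)) {
    return kDefaultConfig;  // graceful fallback instead of a crash
  }
  return AutotuneConfig{"autotuned"};
}
```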

This solution still needs further discussion with the XLA engineers.

Test plan

  • MaxText Gemma3-4B multimodal on 8x MI300X (gfx942) — 5 training steps, no segfault, correct loss convergence
  • VLOG instrumentation confirms serialized access with fix vs concurrent access without
  • Steady-state throughput unchanged (~8.7 TFLOP/s/device)

phambinhfin closed this Apr 5, 2026