Cherry-pick openxla/xla#44428: [ROCm] packed bf16 atomic add for scatter/segment_sum #987

Closed
magaonka-amd wants to merge 1 commit into rocm-jaxlib-v0.10.2 from cherrypick-44428-to-v0.10.2

Conversation

@magaonka-amd

Motivation

Requested in #985 review: cherry-pick openxla#44428 onto rocm-jaxlib-v0.10.2 (missing from the JAX 0.10.2 pinned XLA base 5a9e73cb; PR merged upstream Jun 17).

Commit (cherry-picked with -x)

Files changed

  • xla/backends/gpu/codegen/emitters/transforms/atomic_rmw_utils.cc
  • xla/backends/gpu/codegen/emitters/transforms/tests/lower_tensors.mlir

Test Plan

  • ROCm jaxlib build on rocm-jaxlib-v0.10.2; release-validation CI.

…ent_sum by matchin…

Imported from GitHub PR openxla#44428

…g FloatNormalization conversions.

📝 Summary of Changes
Make the atomic-RMW matcher (GetAtomicModifierParameters) look through the extf → addf(f32) → truncf body that FloatNormalization emits for bf16, recovering the narrow bf16 modifier so that scatter-add lowers to a packed atomicrmw fadd <2 x bf16> (global_atomic_pk_add_bf16) instead of a CAS loop. GpuFloatSupport and FloatNormalization are unchanged; targets without a native bf16 atomic still fall back to CAS.
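
For reviewers unfamiliar with the pattern, here is a rough Python toy model of the look-through described above. It is not the XLA C++ code in atomic_rmw_utils.cc. The `Op` tuple, `match_widened_fadd`, `choose_lowering`, and the `%`-prefixed value names are all hypothetical, and the sketch assumes for simplicity that both addends are widened by extf inside the modelled body.

```python
from typing import NamedTuple, Optional, Sequence, Tuple


class Op(NamedTuple):
    """Hypothetical stand-in for one op in the atomic-RMW body."""
    name: str                  # e.g. "arith.extf"
    operands: Tuple[str, ...]  # names of the values consumed
    result: str                # name of the value produced
    type: str                  # result element type, e.g. "f32" or "bf16"


def match_widened_fadd(body: Sequence[Op], current_arg: str,
                       yielded: str) -> Optional[str]:
    """Return the narrow modifier name if the body is
    extf -> addf(f32) -> truncf on the current value, else None."""
    defs = {op.result: op for op in body}
    trunc = defs.get(yielded)
    if trunc is None or trunc.name != "arith.truncf":
        return None
    add = defs.get(trunc.operands[0])
    if add is None or add.name != "arith.addf" or add.type != "f32":
        return None
    modifier = None
    saw_current = False
    for operand in add.operands:
        ext = defs.get(operand)
        if ext is None or ext.name != "arith.extf":
            return None
        if ext.operands[0] == current_arg:
            saw_current = True       # widened copy of the value in memory
        else:
            modifier = ext.operands[0]  # narrow value being added
    return modifier if saw_current and modifier is not None else None


def choose_lowering(body, current_arg, yielded, has_native_bf16_atomic):
    """Pick a direct atomic only when the pattern matches and the target
    has a native bf16 atomic; otherwise keep the CAS loop."""
    modifier = match_widened_fadd(body, current_arg, yielded)
    if modifier is not None and has_native_bf16_atomic:
        return ("atomicrmw fadd", modifier)
    return ("cas_loop", None)


# Body shaped like the widened bf16 add: extend both, add in f32, truncate.
body = [
    Op("arith.extf", ("%current",), "%cur_f32", "f32"),
    Op("arith.extf", ("%update",), "%upd_f32", "f32"),
    Op("arith.addf", ("%cur_f32", "%upd_f32"), "%sum_f32", "f32"),
    Op("arith.truncf", ("%sum_f32",), "%sum_bf16", "bf16"),
]

with_native = choose_lowering(body, "%current", "%sum_bf16", True)
without_native = choose_lowering(body, "%current", "%sum_bf16", False)
```

In this toy model, `with_native` selects the direct atomic with `%update` as the recovered narrow modifier, while `without_native` keeps the CAS loop, mirroring the fallback behaviour stated in the summary.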

🎯 Justification
bf16 segment_sum/scatter-add currently lowers to a slow CAS loop on MI300/MI350 even though the hardware has a packed bf16 atomic, making bf16 ~7x slower than f16.

🚀 Kind of Contribution
⚡️ Performance Improvement, 🧪 Tests

📊 Benchmark (for Performance Improvements)
Please measure and include speedups for one of the public HLOs in
`compiler/xla/tools/benchmarks/hlo/`.

🧪 Unit Tests:
Added the direct_atomic_rmw_fadd_bf16_widened test and a gfx942 RUN line with the CHECK-GFX942-MI300 prefix to lower_tensors.mlir, asserting that the packed atomicrmw fadd <2 x bf16> is emitted with no CAS loop. All 9 RUN-line prefixes pass.

🧪 Execution Tests:
What execution tests were added? For example, a new optimization should be
tested with an end-to-end execution test triggering the optimization and
asserting correctness. Please provide test cases running with at most 2 GPUs.

Copybara import of the project:

--
edcb06b by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Emit packed bf16 atomic add for scatter/segment_sum by matching FloatNormalization conversions.

Merging this change closes openxla#44428

COPYBARA_INTEGRATE_REVIEW=openxla#44428 from ROCm:rocm-bf16-atomic-scatter edcb06b
PiperOrigin-RevId: 933630040

(cherry picked from commit 97544f7)
@magaonka-amd

Author

Superseded by #993, which combines all four ROCm 0.10.2 cherry-pick PRs into a single PR against rocm-jaxlib-v0.10.2.
