Cherry-pick openxla/xla#44428: [ROCm] packed bf16 atomic add for scatter/segment_sum #987
Closed
magaonka-amd wants to merge 1 commit into
…ent_sum by matchin…
Imported from GitHub PR openxla#44428
…g FloatNormalization conversions.

📝 Summary of Changes
Make the atomic-RMW matcher (GetAtomicModifierParameters) look through the extf → addf(f32) → truncf body that FloatNormalization emits for bf16, recovering the narrow bf16 modifier so that scatter-add lowers to a packed atomicrmw fadd <2 x bf16> (global_atomic_pk_add_bf16) instead of a CAS loop. GpuFloatSupport/FloatNormalization are unchanged; targets without a native bf16 atomic still fall back to CAS.

🎯 Justification
bf16 segment_sum/scatter-add results in a slow CAS loop on MI300/MI350 even though the hardware has a packed bf16 atomic, making bf16 ~7x slower than f16.

🚀 Kind of Contribution
Please remove what does not apply: ⚡️ Performance Improvement, 🧪 Tests

📊 Benchmark (for Performance Improvements)
Please measure and include speedups for one of the public HLOs in `compiler/xla/tools/benchmarks/hlo/`.

🧪 Unit Tests:
Added direct_atomic_rmw_fadd_bf16_widened and a gfx942 CHECK-GFX942-MI300 RUN line to lower_tensors.mlir, asserting the packed atomicrmw fadd <2 x bf16> with no CAS. All 9 RUN-line prefixes pass.

🧪 Execution Tests:
What execution tests were added? For example, a new optimization should be tested with an end-to-end execution test triggering the optimization and asserting correctness. Please provide test cases running with at most 2 GPUs.

Copybara import of the project:

--
edcb06b by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Emit packed bf16 atomic add for scatter/segment_sum by matching FloatNormalization conversions.

Merging this change closes openxla#44428

COPYBARA_INTEGRATE_REVIEW=openxla#44428 from ROCm:rocm-bf16-atomic-scatter edcb06b
PiperOrigin-RevId: 933630040
(cherry picked from commit 97544f7)
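The "look through extf → addf(f32) → truncf" idea from the summary can be illustrated with a toy model. This is only a sketch in Python, not the code in atomic_rmw_utils.cc: the `Op` class, `match_widened_fadd`, and `choose_lowering` are hypothetical names invented for illustration, and the real matcher operates on MLIR operations in C++. The sketch assumes the widened body has the shape truncf(addf(extf(current), extf(update))) and recovers the narrow `update` operand.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Op:
    """Hypothetical stand-in for an IR value: op name, element type, inputs."""
    name: str
    type: str
    inputs: Tuple["Op", ...] = ()


def strip_extf(op: Op, narrow: str) -> Optional[Op]:
    """Return the narrow input if `op` is an extf of a `narrow`-typed value."""
    if op.name == "arith.extf" and len(op.inputs) == 1 and op.inputs[0].type == narrow:
        return op.inputs[0]
    return None


def match_widened_fadd(result: Op, current: Op) -> Optional[Op]:
    """Match truncf(addf(extf(current), extf(m))) and return the narrow m.

    `current` is the value being atomically updated. Returns None when the
    body does not have the widened-add shape.
    """
    if result.name != "arith.truncf" or len(result.inputs) != 1:
        return None
    if result.type != current.type:
        return None
    add = result.inputs[0]
    if add.name != "arith.addf" or len(add.inputs) != 2:
        return None
    lhs = strip_extf(add.inputs[0], result.type)
    rhs = strip_extf(add.inputs[1], result.type)
    if lhs is None or rhs is None:
        return None
    # addf is commutative, so the current value may be on either side.
    if lhs is current:
        return rhs
    if rhs is current:
        return lhs
    return None


def choose_lowering(modifier: Optional[Op], has_native_bf16_atomic: bool) -> str:
    """Pick a lowering; anything unmatched or unsupported falls back to CAS."""
    if modifier is not None and modifier.type == "bf16" and has_native_bf16_atomic:
        return "atomicrmw fadd <2 x bf16>"
    return "cas loop"


# Build the widened body that the summary describes for a bf16 add.
current = Op("block_arg", "bf16")
update = Op("value", "bf16")
widened = Op("arith.addf", "f32", (
    Op("arith.extf", "f32", (current,)),
    Op("arith.extf", "f32", (update,)),
))
body = Op("arith.truncf", "bf16", (widened,))

modifier = match_widened_fadd(body, current)
print(choose_lowering(modifier, has_native_bf16_atomic=True))
print(choose_lowering(modifier, has_native_bf16_atomic=False))
```

In this model the first call selects the packed atomic and the second falls back to the CAS loop, mirroring the statement that targets without a native bf16 atomic are unaffected.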
This was referenced Jun 24, 2026
Author
Superseded by #993, which combines all four ROCm 0.10.2 cherry-pick PRs into a single PR against rocm-jaxlib-v0.10.2.
Motivation
Requested in #985 review: cherry-pick openxla#44428 onto rocm-jaxlib-v0.10.2 (missing from the JAX 0.10.2 pinned XLA base 5a9e73cb; PR merged upstream Jun 17).

Commit (cherry-picked with -x)
97544f7a9e: PR [ROCm] Emit packed bf16 atomic add for scatter/segment_sum by matchin… openxla/xla#44428: [ROCm] Emit packed bf16 atomic add for scatter/segment_sum

Files changed
xla/backends/gpu/codegen/emitters/transforms/atomic_rmw_utils.cc
xla/backends/gpu/codegen/emitters/transforms/tests/lower_tensors.mlir

Test Plan
rocm-jaxlib-v0.10.2; release-validation CI.