[TRITON] Fix unit tests on gfx950 - part 2#2491
Merged
brunomazzottiamd merged 7 commits into main on Apr 7, 2026
Conversation
Contributor (Author)
The only test failures are the expected ones; they aren't in the scope of this PR. Everything else passed, so we're good to merge. FYI: @gyohuangxin
yzhou103 pushed a commit that referenced this pull request on Apr 8, 2026
* Fix `test_gemm_afp4wfp4.py`
Triton commit de2ba3946b ("[AMD] Refactor mfma layout") changed
`AMDMFMALayout.instr_shape` from a 2-element `[M, N]` to a 3-element
`[M, N, K]` list. Extend the previously 2-element `[32, 32]` to
`[32, 32, 64]`. K=64 is the K dimension of the
`mfma_scale_f32_32x32x64_f8f6f4` hardware instruction used for FP4
on `gfx950`.
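The API break above can be sketched as a small version-dispatch helper. This is illustrative only: `mfma_instr_shape` and the 3.6 version gate are assumptions for the sketch, not code from this PR.

```python
# Hedged sketch, assuming the [M, N] -> [M, N, K] change landed for the
# Triton 3.6 Gluon API; the helper name is hypothetical.
def mfma_instr_shape(m, n, k, triton_version=(3, 6)):
    """Return the instr_shape list expected by the given Triton version."""
    if triton_version >= (3, 6):
        return [m, n, k]  # post-refactor layout: [M, N, K]
    return [m, n]         # pre-refactor layout: [M, N]

# FP4 on gfx950 pairs with mfma_scale_f32_32x32x64_f8f6f4, hence K=64.
assert mfma_instr_shape(32, 32, 64) == [32, 32, 64]
```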
* Fix `test_gemm_a8w8.py`
* Fix `_gemm_a8w8_kernel`:
Same `instr_shape` API break (Triton de2ba3946b). The kernel uses
`mfma_scaled` for FP8 and plain `mfma` for INT8, which target
different hardware instructions with different K dimensions:
- FP8 `mfma_scale_f32_16x16x128_f8f6f4` (K=128, K_WIDTH=32)
- INT8 `mfma_i32_16x16x64_i8` (K=64, K_WIDTH=16)
`SwizzledSharedLayout.vec` is updated to match K_WIDTH per data type
specialisation.
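The per-dtype pairing described above can be summarised in a small sketch. The instruction names and K_WIDTH values come from the commit message; the dict and helper themselves are illustrative, not kernel source.

```python
# Illustrative mapping of dtype specialisation -> MFMA variant (sketch only).
MFMA_VARIANTS = {
    "fp8":  {"instr": "mfma_scale_f32_16x16x128_f8f6f4", "k": 128, "k_width": 32},
    "int8": {"instr": "mfma_i32_16x16x64_i8",            "k": 64,  "k_width": 16},
}

def swizzled_vec(dtype):
    """SwizzledSharedLayout.vec must match the MFMA K_WIDTH for the dtype."""
    return MFMA_VARIANTS[dtype]["k_width"]
```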
* Fix `_gemm_a8w8_preshuffled_kernel`:
The `linear_nk` layout and its `reshape - permute - reshape - trans`
unshuffle sequence were designed for K=32 / K_WIDTH=16, so applying
K=128 breaks the layout conversion. Since `mfma_scaled` was already
invoked with `a_scale=None` and `b_scale=None` (per-tensor scale
applied to the accumulator separately), replace it with plain `mfma`,
targeting the unscaled `mfma_f32_16x16x32_fp8_fp8` (K=32) that the
preshuffled layout was built for.
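The replacement is behaviour-preserving because, with both scales `None`, folding per-tensor scales into the operands equals computing an unscaled dot and scaling the accumulator afterwards. A minimal pure-Python check of that identity (power-of-two scales keep the float arithmetic exact; the tiny matmul stands in for the MFMA dot):

```python
def matmul(a, b):
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
sa, sb = 0.5, 0.25  # per-tensor scales

# scales folded into the operands vs. applied to the accumulator afterwards
scaled_inputs = matmul([[x * sa for x in r] for r in a],
                       [[x * sb for x in r] for r in b])
scaled_accum = [[x * sa * sb for x in r] for r in matmul(a, b)]
assert scaled_inputs == scaled_accum
```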
* Fix `test_gemm_a8w8.py`:
Relax absolute tolerance from 0.02 to 0.03 to accommodate the
preshuffled FP8 path (unscaled dot + software accumulator scale).
* Refactor Triton version detection logic out of `pa_decode_gluon.py`
This aspect should be also used by other Gluon kernels, namely
`gemm_afp4wfp4.py` and `gemm_a8w8.py`.
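A hedged sketch of what such a shared version helper might look like (the refactored module and function names are not shown in this PR excerpt; `triton_version_tuple` is a hypothetical name):

```python
import re

def triton_version_tuple(version_str):
    """Parse (major, minor) out of a Triton version string like '3.5.0'."""
    match = re.match(r"(\d+)\.(\d+)", version_str)
    if match is None:
        raise ValueError(f"unrecognized Triton version: {version_str!r}")
    return int(match.group(1)), int(match.group(2))

# Kernels could then branch once, e.g.:
# USE_NEW_GLUON_API = triton_version_tuple(triton.__version__) >= (3, 6)
```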
* Fix `test_gemm_afp4wfp4.py`
* Restrict AFP4/WFP4 AOT tests to Triton 3.5. Avoid using prebuilt AOT kernels
on newer Triton versions where the metadata format is incompatible.
* Implement compatibility for old Gluon API
* Support Gluon API for Triton compiler older than 3.6.
* Conditionally skip some cases of `test_gemm_a8w8.py::test_gemm_splitk` on
Triton 3.5. Ragged FP8 split-K lowering fails in Triton 3.5.
* Fix `ff_a16w16_fused_ungated.py`
The k-loop staggers each N-block's start position by
`k_cyclic_offset = pid_n % cdiv(K, BLOCK_SIZE_K)` to reduce `tl.atomic_add`
contention on `y_ptrs`. The `y_mask` K-boundary check incorrectly used the raw
loop counter `k` (always starting at 0) instead of `k_cyclic_offset` (the actual
K position). When the cyclic offset is non-zero, `k` understates the real
offset, producing a wrong mask and corrupting partial sums near the K boundary.
Replace `k` with `k_cyclic_offset`, consistent with the analogous bound already
used in the `w2` load mask.
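The mask bug can be reproduced in a few lines of plain Python (a sketch of the indexing only, not the Triton kernel itself):

```python
BLOCK_SIZE_K, K = 4, 10           # K is not a multiple of BLOCK_SIZE_K
offs = list(range(BLOCK_SIZE_K))  # per-lane offsets within one K block

def k_boundary_mask(block_start):
    # True only for lanes that fall inside the real K extent
    return [block_start * BLOCK_SIZE_K + o < K for o in offs]

k = 0                # raw loop counter always starts at 0
k_cyclic_offset = 2  # staggered start: this block actually covers K 8..11

buggy = k_boundary_mask(k)                # claims all four lanes are valid
fixed = k_boundary_mask(k_cyclic_offset)  # lanes at K=10,11 correctly masked
```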
* Set RNG seed in `test_pa_decode.py`
Motivation
Triton test suite isn't passing on `gfx950`. This PR fixes `test_gemm_afp4wfp4.py` and `test_gemm_a8w8.py`, slightly improving the situation.
Technical Details
Fix `test_gemm_afp4wfp4.py`
Triton commit de2ba3946b changed `AMDMFMALayout.instr_shape` from a 2-element `[M, N]` to a 3-element `[M, N, K]` list. Extend the previously 2-element `[32, 32]` to `[32, 32, 64]`. K=64 is the K dimension of the `mfma_scale_f32_32x32x64_f8f6f4` hardware instruction used for FP4 on `gfx950`.
Restrict AFP4/WFP4 AOT tests to Triton 3.5. Avoid using prebuilt AOT kernels on newer Triton versions where the metadata format is incompatible.
Fix `test_gemm_a8w8.py`
Fix `_gemm_a8w8_kernel`
Same `instr_shape` API break (Triton de2ba3946b). The kernel uses `mfma_scaled` for FP8 and plain `mfma` for INT8, which target different hardware instructions with different K dimensions:
- FP8: `mfma_scale_f32_16x16x128_f8f6f4` (K=128, K_WIDTH=32)
- INT8: `mfma_i32_16x16x64_i8` (K=64, K_WIDTH=16)
`SwizzledSharedLayout.vec` is updated to match K_WIDTH per data type specialization.
Fix `_gemm_a8w8_preshuffled_kernel`
The `linear_nk` layout and its `reshape - permute - reshape - trans` unshuffle sequence were designed for K=32 / K_WIDTH=16, so applying K=128 breaks the layout conversion. Since `mfma_scaled` was already invoked with `a_scale=None` and `b_scale=None` (per-tensor scale applied to the accumulator separately), replace it with plain `mfma`, targeting the unscaled `mfma_f32_16x16x32_fp8_fp8` (K=32) that the preshuffled layout was built for.
Fix `test_gemm_a8w8.py`
Relax absolute tolerance from 0.02 to 0.03 to accommodate the preshuffled FP8 path (unscaled dot + software accumulator scale).
Fix `ff_a16w16_fused_ungated.py`
The k-loop staggers each N-block's start position by `k_cyclic_offset = pid_n % cdiv(K, BLOCK_SIZE_K)` to reduce `tl.atomic_add` contention on `y_ptrs`. The `y_mask` K-boundary check incorrectly used the raw loop counter `k` (always starting at 0) instead of `k_cyclic_offset` (the actual K position). When the cyclic offset is non-zero, `k` understates the real offset, producing a wrong mask and corrupting partial sums near the K boundary.
Compatibility fixes for older Gluon API (Triton < 3.6.0)
This PR also implements compatibility for the old Gluon API, supporting Gluon on a Triton compiler older than version 3.6.
Test Plan
Run respective tests on `gfx950`:
The tests should pass on latest Triton TOT and Triton 3.5.0 (< 3.6.0).
Test Result
`test_gemm_afp4wfp4.py`, `test_gemm_a8w8.py` and `test_pa_decode_gluon.py` pass on `gfx950`.
TOT Triton - all test cases:
Triton 3.5.0 - only test cases of Gluon kernels, to check compatibility with the older API:
Execution of Gluon kernels only was achieved through the following patch:
Submission Checklist