
[TRITON] Fix unit tests on gfx950 - part 2 #2491

Merged
brunomazzottiamd merged 7 commits into main from bmazzott/fix-gluon-mfma-layout-on-gfx950 on Apr 7, 2026

Conversation

@brunomazzottiamd (Contributor) commented Mar 26, 2026

Motivation

The Triton test suite isn't passing on gfx950. This PR fixes test_gemm_afp4wfp4.py and test_gemm_a8w8.py, slightly improving the situation.

Technical Details

Fix test_gemm_afp4wfp4.py

Triton commit de2ba3946b changed AMDMFMALayout.instr_shape from a 2-element [M, N] to a 3-element [M, N, K] list. Extend the previously 2-element [32, 32] to [32, 32, 64]. K=64 is the K dimension of the mfma_scale_f32_32x32x64_f8f6f4 hardware instruction used for FP4 on gfx950.
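The version-dependent shape can be sketched as follows. This is illustrative only: `make_instr_shape` is a hypothetical helper, not code from this PR.

```python
# Illustrative sketch only: pick AMDMFMALayout.instr_shape based on the
# Triton version. `make_instr_shape` is a hypothetical helper name.
def make_instr_shape(triton_version: tuple) -> list:
    # Triton >= 3.6 (after commit de2ba3946b) expects a 3-element [M, N, K];
    # older releases expect a 2-element [M, N].
    if triton_version >= (3, 6):
        # K=64 matches the mfma_scale_f32_32x32x64_f8f6f4 instruction on gfx950
        return [32, 32, 64]
    return [32, 32]
```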

Restrict AFP4/WFP4 AOT tests to Triton 3.5. Avoid using prebuilt AOT kernels on newer Triton versions where the metadata format is incompatible.

Fix test_gemm_a8w8.py

Fix _gemm_a8w8_kernel

Same instr_shape API break (Triton de2ba3946b). The kernel uses mfma_scaled for FP8 and plain mfma for INT8, which target different hardware instructions with different K dimensions:

  • FP8 mfma_scale_f32_16x16x128_f8f6f4 (K=128, K_WIDTH=32)
  • INT8 mfma_i32_16x16x64_i8 (K=64, K_WIDTH=16)

SwizzledSharedLayout.vec is updated to match K_WIDTH per data type specialization.
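The per-dtype parameters above can be summarized in a small table. Names here (`MFMA_PARAMS`, `swizzled_vec`) are assumptions for illustration, not the kernel's actual code.

```python
# Illustrative table (names assumed): per-dtype MFMA instruction parameters,
# and the SwizzledSharedLayout.vec choice that must track K_WIDTH for each
# data-type specialization.
MFMA_PARAMS = {
    "fp8":  {"instr": "mfma_scale_f32_16x16x128_f8f6f4", "K": 128, "K_WIDTH": 32},
    "int8": {"instr": "mfma_i32_16x16x64_i8",            "K": 64,  "K_WIDTH": 16},
}

def swizzled_vec(dtype: str) -> int:
    # vec is set to K_WIDTH so shared-memory vectorization matches the dot operand
    return MFMA_PARAMS[dtype]["K_WIDTH"]
```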

Fix _gemm_a8w8_preshuffled_kernel

The linear_nk layout and its reshape - permute - reshape - trans unshuffle sequence were designed for K=32 / K_WIDTH=16, so applying K=128 breaks the layout conversion. Since mfma_scaled was already invoked with a_scale=None and b_scale=None (per-tensor scale applied to the accumulator separately), replace it with plain mfma, targeting the unscaled mfma_f32_16x16x32_fp8_fp8 (K=32) that the preshuffled layout was built for.
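The equivalence that makes this substitution safe: with both scales None, scaling the accumulator after the dot equals scaling the operands before it. A pure-Python stand-in (not real mfma code) for the two orderings:

```python
# Pure-Python stand-in (not real mfma code) showing why a plain dot plus a
# software accumulator scale matches mfma_scaled called with scale=None:
def dot_then_scale(a, b, a_scale, b_scale):
    acc = sum(x * y for x, y in zip(a, b))  # unscaled dot, like plain mfma
    return acc * (a_scale * b_scale)        # per-tensor scale on the accumulator

def scale_then_dot(a, b, a_scale, b_scale):
    # equivalent: apply the per-tensor scales to the operands before the dot
    return sum((x * a_scale) * (y * b_scale) for x, y in zip(a, b))
```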

Relax test_gemm_a8w8.py tolerance

Relax absolute tolerance from 0.02 to 0.03 to accommodate the preshuffled FP8 path (unscaled dot + software accumulator scale).

Fix ff_a16w16_fused_ungated.py

The k-loop staggers each N-block's start position by k_cyclic_offset = pid_n % cdiv(K, BLOCK_SIZE_K) to reduce tl.atomic_add contention on y_ptrs. The y_mask K-boundary check incorrectly used the raw loop counter k (always starting at 0) instead of k_cyclic_offset (the actual K position). When the cyclic offset is non-zero, k understates the real offset, producing a wrong mask and corrupting partial sums near the K boundary. Replace k with k_cyclic_offset, consistent with the analogous bound already used in the w2 load mask.
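A minimal sketch of the staggering, with assumed names (`pid_n`, `K`, `BLOCK_SIZE_K` mirror the kernel's parameters): the K-boundary mask must be computed from the staggered position `k_cyclic_offset`, not the raw loop counter `k`.

```python
# Illustrative sketch only, not the kernel's actual code.
def cdiv(a, b):
    return -(-a // b)

def k_blocks(pid_n, K, BLOCK_SIZE_K):
    n = cdiv(K, BLOCK_SIZE_K)
    start = pid_n % n  # each N-block starts its k-loop at a different K block
    for k in range(n):
        k_cyclic_offset = (start + k) % n
        # number of valid elements in this K block; the bug computed this
        # from the raw counter `k` instead of `k_cyclic_offset`
        valid = min(BLOCK_SIZE_K, K - k_cyclic_offset * BLOCK_SIZE_K)
        yield k, k_cyclic_offset, valid
```

For pid_n = 1, K = 80, BLOCK_SIZE_K = 32, the blocks are visited in order 1, 2, 0; at k = 1 the visited block is the ragged tail (16 valid elements), which a mask based on `k` alone would miss.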

Compatibility fixes for older Gluon API (Triton < 3.6.0)

This PR also implements compatibility with the older Gluon API, supporting Triton compilers older than version 3.6.
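The gating can be sketched as a version check. Helper names here are assumptions for illustration; the PR's actual detection logic lives in the refactored version-detection module.

```python
# Hedged sketch (helper names assumed): gate Gluon API usage on the
# installed Triton version string.
def parse_version(version: str):
    major, minor = version.split(".")[:2]
    return int(major), int(minor)

def use_new_gluon_api(triton_version: str) -> bool:
    # Triton >= 3.6 exposes the newer Gluon API (e.g. 3-element instr_shape);
    # older compilers take the compatibility path this PR adds.
    return parse_version(triton_version) >= (3, 6)
```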

Test Plan

Run respective tests on gfx950:

```shell
for t in \
  op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py \
  op_tests/triton_tests/gemm/basic/test_gemm_a8w8.py \
  op_tests/triton_tests/test_pa_decode_gluon.py \
; do echo "${t}"; pytest -q --no-header "${t}" | tail -1; done
```

The tests should pass on the latest Triton TOT and on Triton 3.5.0 (< 3.6.0).

Test Result

test_gemm_afp4wfp4.py, test_gemm_a8w8.py and test_pa_decode_gluon.py pass on gfx950.

TOT Triton - all test cases:

op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py
10656 passed, 3680 skipped in 570.54s (0:09:30)
op_tests/triton_tests/gemm/basic/test_gemm_a8w8.py
9382 passed, 1856 skipped in 1024.83s (0:17:04)
op_tests/triton_tests/test_pa_decode_gluon.py
4 passed, 1 warning in 618.54s (0:10:18)

Triton 3.5.0 - Gluon kernel test cases only, to check compatibility with the older API:

op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py
3584 passed, 3584 skipped in 223.31s (0:03:43)
op_tests/triton_tests/gemm/basic/test_gemm_a8w8.py
6414 passed, 1272 skipped in 689.21s (0:11:29)
op_tests/triton_tests/test_pa_decode_gluon.py
4 passed, 1 warning in 618.00s (0:10:18)

Running only the Gluon kernels was achieved through the following patch:

```diff
diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_a8w8.py b/op_tests/triton_tests/gemm/basic/test_gemm_a8w8.py
index 10d8f99d8..2010a2ec0 100644
--- a/op_tests/triton_tests/gemm/basic/test_gemm_a8w8.py
+++ b/op_tests/triton_tests/gemm/basic/test_gemm_a8w8.py
@@ -175,7 +175,6 @@ def generate_gemm_a8w8_inputs(
 @pytest.mark.parametrize(
     "impl",
     [
-        "triton",
         "gluon",
         "gluon_shuffle",
     ],
diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py b/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py
index 8a4953811..671f516ef 100644
--- a/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py
+++ b/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py
@@ -241,7 +241,7 @@ def run_triton(
     [True, False],
 )
 @pytest.mark.parametrize("skip_reduce", [True, False])
-@pytest.mark.parametrize("impl", ["triton", "gluon"])
+@pytest.mark.parametrize("impl", ["gluon"])
 def test_gemm_afp4_wfp4(
     M: int,
     N: int,

```

Submission Checklist

@brunomazzottiamd brunomazzottiamd self-assigned this Mar 26, 2026
@brunomazzottiamd brunomazzottiamd requested a review from a team March 26, 2026 20:09
@brunomazzottiamd brunomazzottiamd added bug Something isn't working triton ci:all labels Mar 26, 2026
@brunomazzottiamd brunomazzottiamd force-pushed the bmazzott/fix-gluon-mfma-layout-on-gfx950 branch from 1fa2e5a to 3564334 Compare March 27, 2026 18:06

@brunomazzottiamd brunomazzottiamd force-pushed the bmazzott/fix-gluon-mfma-layout-on-gfx950 branch 4 times, most recently from e4fae56 to 67bde7c Compare March 30, 2026 19:22
@brunomazzottiamd brunomazzottiamd force-pushed the bmazzott/fix-gluon-mfma-layout-on-gfx950 branch from 32f1746 to d2035ae Compare March 31, 2026 16:44

@brunomazzottiamd brunomazzottiamd force-pushed the bmazzott/fix-gluon-mfma-layout-on-gfx950 branch from 7fec184 to 9a16391 Compare April 1, 2026 13:41
@brunomazzottiamd brunomazzottiamd force-pushed the bmazzott/fix-gluon-mfma-layout-on-gfx950 branch from 9a16391 to 6cd4cc0 Compare April 1, 2026 18:36

@brunomazzottiamd brunomazzottiamd force-pushed the bmazzott/fix-gluon-mfma-layout-on-gfx950 branch from 6cd4cc0 to 4bd8467 Compare April 1, 2026 20:41
@brunomazzottiamd brunomazzottiamd requested a review from azaidy April 1, 2026 20:41
@brunomazzottiamd brunomazzottiamd force-pushed the bmazzott/fix-gluon-mfma-layout-on-gfx950 branch from df7ac2c to a5c73f9 Compare April 2, 2026 12:57
@brunomazzottiamd brunomazzottiamd force-pushed the bmazzott/fix-gluon-mfma-layout-on-gfx950 branch 3 times, most recently from 8986535 to dd453ba Compare April 7, 2026 14:39
@azaidy (Contributor) left a comment:

LGTM!

@brunomazzottiamd (Contributor, Author) commented:

The only test failures are the expected ones, i.e. test_mha_backward on Triton Tests (MI35X) / Shard 0:

=========================== short test summary info ============================
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-32-512-512-1]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-32-512-512-4]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-32-1024-1024-1]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-32-1024-1024-4]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-32-2048-2048-1]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-32-2048-2048-4]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-64-512-512-1]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-64-512-512-4]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-64-1024-1024-1]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-64-1024-1024-4]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-64-2048-2048-1]
FAILED op_tests/triton_tests/attention/test_mha.py::test_mha_backward[True-False-0.0-True-128-8-64-2048-2048-4]
==== 12 failed, 5269 passed, 2016 skipped, 6 warnings in 3920.96s (1:05:20) ====

These failures aren't in the scope of this PR. We're good to merge, everything else passed.

FYI: @gyohuangxin

@brunomazzottiamd brunomazzottiamd merged commit 957c1aa into main Apr 7, 2026
65 of 69 checks passed
@brunomazzottiamd brunomazzottiamd deleted the bmazzott/fix-gluon-mfma-layout-on-gfx950 branch April 7, 2026 20:57
yzhou103 pushed a commit that referenced this pull request Apr 8, 2026
* Fix `test_gemm_afp4wfp4.py`

Triton commit de2ba3946b ("[AMD] Refactor mfma layout") changed
`AMDMFMALayout.instr_shape` from a 2-element `[M, N]` to a 3-element
`[M, N, K]` list. Extend the previously 2-element `[32, 32]` to
`[32, 32, 64]`. K=64 is the K dimension of the
`mfma_scale_f32_32x32x64_f8f6f4` hardware instruction used for FP4
on `gfx950`.

* Fix `test_gemm_a8w8.py`

* Fix `_gemm_a8w8_kernel`:
  Same `instr_shape` API break (Triton de2ba3946b). The kernel uses
  `mfma_scaled` for FP8 and plain `mfma` for INT8, which target
  different hardware instructions with different K dimensions:
    - FP8 `mfma_scale_f32_16x16x128_f8f6f4` (K=128, K_WIDTH=32)
    - INT8 `mfma_i32_16x16x64_i8` (K=64, K_WIDTH=16)
  `SwizzledSharedLayout.vec` is updated to match K_WIDTH per data type
  specialisation.

* Fix `_gemm_a8w8_preshuffled_kernel`:
  The `linear_nk` layout and its `reshape - permute - reshape - trans`
  unshuffle sequence were designed for K=32 / K_WIDTH=16, so applying
  K=128 breaks the layout conversion. Since `mfma_scaled` was already
  invoked with `a_scale=None` and `b_scale=None` (per-tensor scale
  applied to the accumulator separately), replace it with plain `mfma`,
  targeting the unscaled `mfma_f32_16x16x32_fp8_fp8` (K=32) that the
  preshuffled layout was built for.

* Fix `test_gemm_a8w8.py`:
  Relax absolute tolerance from 0.02 to 0.03 to accommodate the
  preshuffled FP8 path (unscaled dot + software accumulator scale).

* Refactor Triton version detection logic out of `pa_decode_gluon.py`

This aspect should be also used by other Gluon kernels, namely
`gemm_afp4wfp4.py` and `gemm_a8w8.py`.

* Fix `test_gemm_afp4wfp4.py`

* Restrict AFP4/WFP4 AOT tests to Triton 3.5. Avoid using prebuilt AOT kernels
on newer Triton versions where the metadata format is incompatible.

* Implement compatibility for old Gluon API

* Support Gluon API for Triton compiler older than 3.6.
* Conditionally skip some cases of `test_gemm_a8w8.py::test_gemm_splitk` on
  Triton 3.5. Ragged FP8 split-K lowering fails in Triton 3.5.

* Fix `ff_a16w16_fused_ungated.py`

The k-loop staggers each N-block's start position by
`k_cyclic_offset = pid_n % cdiv(K, BLOCK_SIZE_K)` to reduce `tl.atomic_add`
contention on `y_ptrs`. The `y_mask` K-boundary check incorrectly used the raw
loop counter `k` (always starting at 0) instead of `k_cyclic_offset` (the actual
K position). When the cyclic offset is non-zero, `k` understates the real
offset, producing a wrong mask and corrupting partial sums near the K boundary.
Replace `k` with `k_cyclic_offset`, consistent with the analogous bound already
used in the `w2` load mask.

* Set RNG seed in `test_pa_decode.py`