From 50b748bb1ec752dbf819af323131542794d7e46f Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Wed, 30 Jul 2025 17:28:44 -0700 Subject: [PATCH 1/2] [rocm7.0_internal_testing] mx fpx: Skip test_blockwise_mxfloatx_error_messages Issue is tracked by SWDEV-535267 Signed-off-by: Jagadish Krishnamoorthy --- test/test_matmul_cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 19a09ec96e5be..6bdb4fdc6e0bc 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -1382,6 +1382,7 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, sqnr = compute_error(C_ref, C) assert sqnr.item() > approx_match_sqnr_target + @skipIfRocm @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @parametrize("recipe", ["mxfp8", "nvfp4"]) def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None: @@ -1615,6 +1616,7 @@ def test_blockwise_mxfp8_compile(self) -> None: ) torch.testing.assert_close(C, C_ref, atol=0, rtol=0) + @skipIfRocm @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg) def test_blockwise_nvfp4_compile(self) -> None: From 1639e123d22a21a1645541f939ce0196c68c7f68 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Fri, 1 Aug 2025 12:37:35 -0700 Subject: [PATCH 2/2] [rocm7.0_internal_testing] fp8: skip rowwise scaling tests fp8 rowwise scaling is not supported on ROCm 7.0, works on mainline. Skip the test for now. Tracked by https://ontrack-internal.amd.com/browse/SWDEV-532820 Signed-off-by: Jagadish Krishnamoorthy --- test/test_matmul_cuda.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 6bdb4fdc6e0bc..4f7b028bcad3f 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -46,6 +46,7 @@ parametrize, run_tests, skipIfRocm, + skipIfRocmArch, skipIfRocmVersionLessThan, TEST_CUDA, TEST_WITH_ROCM, @@ -907,6 +908,7 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) + @skipIfRocmArch("gfx950") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific") @parametrize("use_fast_accum", [True, False]) @@ -1012,6 +1014,7 @@ def test_float8_error_messages(self, device) -> None: out_dtype=torch.bfloat16, ) + @skipIfRocmArch("gfx950") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific") @parametrize("base_dtype", [torch.bfloat16])