diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index 19a09ec96e5be..4f7b028bcad3f 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -46,6 +46,7 @@
     parametrize,
     run_tests,
     skipIfRocm,
+    skipIfRocmArch,
     skipIfRocmVersionLessThan,
     TEST_CUDA,
     TEST_WITH_ROCM,
@@ -907,6 +908,7 @@ def test_float8_scale_fast_accum(self, device) -> None:
         out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
         self.assertEqual(out_fp8, out_fp8_s)
 
+    @skipIfRocmArch("gfx950")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
     @parametrize("use_fast_accum", [True, False])
@@ -1012,6 +1014,7 @@ def test_float8_error_messages(self, device) -> None:
                 out_dtype=torch.bfloat16,
             )
 
+    @skipIfRocmArch("gfx950")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
     @parametrize("base_dtype", [torch.bfloat16])
@@ -1382,6 +1385,7 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum,
         sqnr = compute_error(C_ref, C)
         assert sqnr.item() > approx_match_sqnr_target
 
+    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @parametrize("recipe", ["mxfp8", "nvfp4"])
     def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
@@ -1615,6 +1619,7 @@ def test_blockwise_mxfp8_compile(self) -> None:
         )
         torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
 
+    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
     def test_blockwise_nvfp4_compile(self) -> None:
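
For context, the decorators added above gate tests at runtime rather than at import. A minimal sketch of how an arch-conditional skip in the spirit of skipIfRocmArch could work is shown below; this is an illustration under stated assumptions, not PyTorch's actual helper. The gcnArchName probing, the prefix match against the arch string, and the skip_if_rocm_arch name are assumptions, grounded only in the "gfx950" string usage seen in the diff.

    # Illustrative sketch only -- not PyTorch's actual skipIfRocmArch helper.
    # Assumes a ROCm build (torch.version.hip is set) and that device
    # properties expose gcnArchName in a form like "gfx950:sramecc+:xnack-".
    import functools
    import unittest

    import torch


    def skip_if_rocm_arch(arch):
        """Skip the decorated test when the active ROCm device reports `arch`."""
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(self, *args, **kwargs):
                if torch.version.hip is not None and torch.cuda.is_available():
                    # Keep only the gfx prefix, dropping feature suffixes.
                    gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
                    if gcn_arch.split(":")[0] == arch:
                        raise unittest.SkipTest(f"test skipped on ROCm arch {arch}")
                return fn(self, *args, **kwargs)
            return wrapper
        return decorator

A decorator built this way would be applied exactly as in the hunks above, e.g. @skip_if_rocm_arch("gfx950") directly above a test method; raising unittest.SkipTest from inside the wrapper is the same runtime-skip pattern PyTorch's own skipIfRocm uses.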