From 30d22a1cf728bebb4f1226d1aaba94fc2e950480 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Thu, 7 Aug 2025 16:10:01 -0700 Subject: [PATCH] [rocm7.0_internal_testing]fp8: optimize skip rowwise tests Skip based on ROCm version and gfx type. Signed-off-by: Jagadish Krishnamoorthy --- test/test_matmul_cuda.py | 6 +++--- torch/testing/_internal/common_utils.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 4f7b028bcad3f..e0f64c70c1c33 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -46,7 +46,7 @@ parametrize, run_tests, skipIfRocm, - skipIfRocmArch, + skipIfRocmVersionAndArch, skipIfRocmVersionLessThan, TEST_CUDA, TEST_WITH_ROCM, @@ -908,7 +908,7 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) - @skipIfRocmArch("gfx950") + @skipIfRocmVersionAndArch((7, 1), "gfx950") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific") @parametrize("use_fast_accum", [True, False]) @@ -1014,7 +1014,7 @@ def test_float8_error_messages(self, device) -> None: out_dtype=torch.bfloat16, ) - @skipIfRocmArch("gfx950") + @skipIfRocmVersionAndArch((7, 1), "gfx950") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific") @parametrize("base_dtype", [torch.bfloat16]) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 1c73d74454cd2..10bcdff10b097 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1987,6 +1987,23 @@ def wrap_fn(self, *args, **kwargs): return wrap_fn return dec_fn +def skipIfRocmVersionAndArch(version=None, arch=None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if TEST_WITH_ROCM: + rocm_version = str(torch.version.hip) + rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) + if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version): + prop = torch.cuda.get_device_properties(0) + if prop.gcnArchName.split(":")[0] in arch: + reason = f"ROCm {version} and {arch} combination not supported" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + return wrap_fn + return dec_fn + def skipIfNotMiopenSuggestNHWC(fn): @wraps(fn) def wrapper(*args, **kwargs):