diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index e8ff44fd40986..dead0bda1c0c3 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -46,6 +46,7 @@
     parametrize,
     run_tests,
     skipIfRocm,
+    skipIfRocmVersionAndArch,
     skipIfRocmVersionLessThan,
     TEST_CUDA,
     TEST_WITH_ROCM,
@@ -1197,6 +1198,7 @@ def test_float8_scale_fast_accum(self, device) -> None:
         out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
         self.assertEqual(out_fp8, out_fp8_s)
 
+    @skipIfRocmVersionAndArch((7, 1), "gfx950")
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
@@ -1304,6 +1306,7 @@ def test_float8_error_messages(self, device) -> None:
                 out_dtype=torch.bfloat16,
             )
 
+    @skipIfRocmVersionAndArch((7, 1), "gfx950")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
     @parametrize("base_dtype", [torch.bfloat16])
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 052a968d51e22..58398f5287000 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -2024,6 +2024,47 @@ def wrap_fn(self, *args, **kwargs):
         return wrap_fn
     return dec_fn
 
+def skipIfRocmVersionAndArch(version=None, arch=None):
+    """Skip the test on ROCm older than ``version`` when the GPU arch matches ``arch``.
+
+    Args:
+        version: minimum ROCm version as a tuple of ints, e.g. ``(7, 1)``;
+            ``None`` matches every ROCm version.
+        arch: a gcnArchName such as ``"gfx950"`` or an iterable of them;
+            ``None`` matches every architecture.
+
+    Raises:
+        unittest.SkipTest: from the wrapped test, when both conditions match.
+    """
+    # Normalize once: a bare string must be matched exactly, not by substring.
+    if arch is None:
+        archs = None
+    elif isinstance(arch, str):
+        archs = (arch,)
+    else:
+        archs = tuple(arch)
+
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                rocm_version = str(torch.version.hip)
+                rocm_version = rocm_version.split("-")[0]    # ignore git sha
+                rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+                if version is None or rocm_version_tuple < tuple(version):
+                    prop = torch.cuda.get_device_properties(0)
+                    # Keep only the base name; drop anything after the first ":".
+                    device_arch = prop.gcnArchName.split(":")[0]
+                    if archs is None or device_arch in archs:
+                        reason = (
+                            f"ROCm {rocm_version} on {device_arch} not supported "
+                            f"(minimum ROCm version: {version})"
+                        )
+                        raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def skipIfNotMiopenSuggestNHWC(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):