Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions test/test_matmul_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
parametrize,
run_tests,
skipIfRocm,
skipIfRocmVersionAndArch,
skipIfRocmVersionLessThan,
TEST_CUDA,
TEST_WITH_ROCM,
Expand Down Expand Up @@ -1197,6 +1198,7 @@ def test_float8_scale_fast_accum(self, device) -> None:
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
self.assertEqual(out_fp8, out_fp8_s)

@skipIfRocmVersionAndArch((7, 1), "gfx950")
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
Expand Down Expand Up @@ -1304,6 +1306,7 @@ def test_float8_error_messages(self, device) -> None:
out_dtype=torch.bfloat16,
)

@skipIfRocmVersionAndArch((7, 1), "gfx950")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
@parametrize("base_dtype", [torch.bfloat16])
Expand Down
17 changes: 17 additions & 0 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2024,6 +2024,23 @@ def wrap_fn(self, *args, **kwargs):
return wrap_fn
return dec_fn

def skipIfRocmVersionAndArch(version=None, arch=None):
def dec_fn(fn):
@wraps(fn)
def wrap_fn(self, *args, **kwargs):
if TEST_WITH_ROCM:
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-")[0] # ignore git sha
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
prop = torch.cuda.get_device_properties(0)
if prop.gcnArchName.split(":")[0] in arch:
reason = f"ROCm {version} and {arch} combination not supported"
raise unittest.SkipTest(reason)
return fn(self, *args, **kwargs)
return wrap_fn
return dec_fn

def skipIfNotMiopenSuggestNHWC(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
Expand Down