Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/test_matmul_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
parametrize,
run_tests,
skipIfRocm,
skipIfRocmArch,
skipIfRocmVersionAndArch,
skipIfRocmVersionLessThan,
TEST_CUDA,
TEST_WITH_ROCM,
Expand Down Expand Up @@ -908,7 +908,7 @@ def test_float8_scale_fast_accum(self, device) -> None:
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
self.assertEqual(out_fp8, out_fp8_s)

@skipIfRocmArch("gfx950")
@skipIfRocmVersionAndArch((7, 1), "gfx950")
Copy link

Copilot AI Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new decorator expects an iterable for the arch parameter (based on the in operator usage), but a string is passed. This should be a list: @skipIfRocmVersionAndArch((7, 1), ["gfx950"])

Suggested change
@skipIfRocmVersionAndArch((7, 1), "gfx950")
@skipIfRocmVersionAndArch((7, 1), ["gfx950"])

Copilot uses AI. Check for mistakes.
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
@parametrize("use_fast_accum", [True, False])
Expand Down Expand Up @@ -1014,7 +1014,7 @@ def test_float8_error_messages(self, device) -> None:
out_dtype=torch.bfloat16,
)

@skipIfRocmArch("gfx950")
@skipIfRocmVersionAndArch((7, 1), "gfx950")
Copy link

Copilot AI Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new decorator expects an iterable for the arch parameter (based on the in operator usage), but a string is passed. This should be a list: @skipIfRocmVersionAndArch((7, 1), ["gfx950"])

Suggested change
@skipIfRocmVersionAndArch((7, 1), "gfx950")
@skipIfRocmVersionAndArch((7, 1), ["gfx950"])

Copilot uses AI. Check for mistakes.
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"in" operator can also used for string comparison !

@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
@parametrize("base_dtype", [torch.bfloat16])
Expand Down
17 changes: 17 additions & 0 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1987,6 +1987,23 @@ def wrap_fn(self, *args, **kwargs):
return wrap_fn
return dec_fn

def skipIfRocmVersionAndArch(version=None, arch=None):
def dec_fn(fn):
@wraps(fn)
def wrap_fn(self, *args, **kwargs):
if TEST_WITH_ROCM:
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-")[0] # ignore git sha
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
Copy link

Copilot AI Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition rocm_version_tuple is None can never be true since rocm_version_tuple is always assigned a tuple value on line 1997. This check should be removed or the logic needs to be restructured.

Suggested change
if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
if version is None or rocm_version_tuple < tuple(version):

Copilot uses AI. Check for mistakes.
prop = torch.cuda.get_device_properties(0)
if prop.gcnArchName.split(":")[0] in arch:
Copy link

Copilot AI Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will raise an exception if arch is None, but the function parameter allows None values. Add a null check: if arch and prop.gcnArchName.split(":")[0] in arch:

Suggested change
if prop.gcnArchName.split(":")[0] in arch:
if arch and prop.gcnArchName.split(":")[0] in arch:

Copilot uses AI. Check for mistakes.
reason = f"ROCm {version} and {arch} combination not supported"
raise unittest.SkipTest(reason)
return fn(self, *args, **kwargs)
return wrap_fn
return dec_fn

def skipIfNotMiopenSuggestNHWC(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
Expand Down