Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/python/direct/test_cutlass_mxfp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import pytest
import torch
from nvfuser_direct import nvf_cutlass
from python.direct_utils import microarchitecture_is
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: test_cutlass_nvfp4_gemm.py still uses the old pattern with torch.cuda.get_device_capability() and allows all architectures in the range [10.0, 12.0). For consistency, consider updating that test file as well if NVFP4 should also be restricted to only tested architectures.

Current state in test_cutlass_nvfp4_gemm.py:

compute_cap = torch.cuda.get_device_capability()
if compute_cap < (10, 0) or compute_cap >= (12, 0):
    pytest.skip(
        reason="Nvfp4 Requires compute capability 10.",
        allow_module_level=True,
    )

If NVFP4 has the same testing limitations as MxFp8 and grouped_mm, it should follow the same pattern for maintainability and clarity.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


compute_cap = torch.cuda.get_device_capability()
if compute_cap < (10, 0) or compute_cap >= (12, 0):
if not microarchitecture_is(10, 0):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] Logic change narrows supported architectures more than original

The original condition compute_cap < (10, 0) or compute_cap >= (12, 0) would skip tests outside the range [10.0, 12.0), meaning it would allow compute capabilities like 10.1, 10.2, 10.3, 11.0, 11.5, etc.

The new condition not microarchitecture_is(10, 0) only allows exactly 10.0 (major==10 AND minor==0), and skips everything else including 10.1, 10.2, 10.3, and all 11.x architectures.

If the intent was to maintain the original behavior while using the utility function, consider using a range check or adding support for additional architectures that were previously supported (10.1-10.3, 11.x). If restricting to exactly 10.0 is intentional, this is fine but represents a significant behavior change from the original code.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding explanatory comments about which architectures are supported/tested, similar to test_cutlass_gemm.py (lines 12-16). This would help future maintainers understand why the test is restricted to compute capability 10.0 only.

Example from test_cutlass_gemm.py:

# GPU Compute Capability: https://developer.nvidia.com/cuda/gpus
# tested on blackwell compute 10.0 (B200 and GB200)
# doesn't support 12.0 (RTX PRO 6000 and RTX 50XX)
# Not tested on 10.3 (B300 and GB300)
# Not tested on 12.1 (DGX Spark)
if not microarchitecture_is(10, 0):

This documentation helps clarify the intentional restriction and provides context for when the restriction might be relaxed in the future.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

pytest.skip(
reason="MxFp8 Requires compute capability 10.",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message "MxFp8 Requires compute capability 10." is misleading. The new logic using microarchitecture_is(10, 0) only runs on exactly compute capability 10.0, not all 10.x architectures.

Previous behavior: Allowed 10.0, 10.1, 10.2, 10.3, 11.x (anything from 10.0 up to but excluding 12.0)
New behavior: Only allows 10.0

Consider updating the message to be more specific, e.g., "MxFp8 requires compute capability 10.0. Other architectures have not been tested." This matches the pattern used in test_cutlass_gemm.py which has detailed comments about tested vs untested architectures.

Suggested change
reason="MxFp8 Requires compute capability 10.",
reason="MxFp8 requires compute capability 10.0. Other architectures have not been tested.",

allow_module_level=True,
Expand Down