skip test cutlass mxfp8_gemm on unsupported arches #5810
Conversation
Description

| Relevant files |
|---|
| Enhancement |
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests

⚡ Recommended focus areas for review
Logic Equivalence

Verify whether `microarchitecture_is(10, 0)` provides logic equivalent to the original condition `compute_cap < (10, 0) or compute_cap >= (12, 0)`. To be equivalent, the new condition should allow compute capabilities 10.x and 11.x, excluding everything below 10.0 and at or above 12.0.
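A quick way to check this is to evaluate both conditions over a few capability tuples. In this sketch, `new_skip` encodes the exact-match behavior that the sequence diagram below attributes to `microarchitecture_is(10, 0)` (an assumption based on this thread, not on the helper's actual source):

```python
# Minimal sketch comparing the old and new skip conditions.
# new_skip assumes microarchitecture_is(10, 0) means major == 10 and minor == 0.
def old_skip(cap):
    return cap < (10, 0) or cap >= (12, 0)

def new_skip(cap):
    return not (cap[0] == 10 and cap[1] == 0)

for cap in [(9, 0), (10, 0), (10, 3), (11, 0), (12, 0)]:
    print(cap, "old skips:", old_skip(cap), "new skips:", new_skip(cap))
# The results diverge at (10, 3) and (11, 0): the old check runs the tests,
# the new check skips them.
```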
Greptile Overview

Greptile Summary

Changes

This PR refactors the architecture compatibility check in test_cutlass_mxfp8_gemm.py.

Before:

```python
compute_cap = torch.cuda.get_device_capability()
if compute_cap < (10, 0) or compute_cap >= (12, 0):
    pytest.skip(...)
```

After:

```python
if not microarchitecture_is(10, 0):
    pytest.skip(...)
```

Key Observations

Impact

If MxFP8 GEMM operations were intended to work on architectures 10.1-10.3 (e.g., B300/GB300 at 10.3) or 11.x, this change will skip those tests. If the restriction to exactly 10.0 is intentional for safety/testing reasons, this aligns with the conservative approach used in test_cutlass_gemm.py.

Confidence Score: 3/5
Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as test_cutlass_mxfp8_gemm.py
    participant DirectUtils as python.direct_utils
    participant CUDA as torch.cuda
    participant Pytest as pytest

    Note over Test: Module load time
    Test->>DirectUtils: import microarchitecture_is
    Test->>DirectUtils: microarchitecture_is(10, 0)
    DirectUtils->>CUDA: get_device_properties(current_device())
    CUDA-->>DirectUtils: device properties (major, minor)
    DirectUtils->>DirectUtils: Check: major == 10 AND minor == 0
    alt Architecture is NOT exactly 10.0
        DirectUtils-->>Test: False
        Test->>Pytest: pytest.skip(allow_module_level=True)
        Note over Test,Pytest: Module skipped - tests won't run
    else Architecture is exactly 10.0
        DirectUtils-->>Test: True
        Note over Test: Continue loading module
        Test->>DirectUtils: import linear_to_swizzled_128_4, swizzled_to_linear_128_4
        Note over Test: Tests will run
    end
```
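For reference, the diagram implies a helper along these lines. This is a hedged reconstruction from the calls shown above, not the actual `python.direct_utils` source:

```python
# Hypothetical sketch of python.direct_utils.microarchitecture_is,
# reconstructed from the sequence diagram (not the real implementation).
import torch

def microarchitecture_is(major: int, minor: int) -> bool:
    # Exact-match check: True only when the current device reports
    # precisely this compute capability (e.g., 10.0, not 10.3).
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return props.major == major and props.minor == minor
```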
```diff
-compute_cap = torch.cuda.get_device_capability()
-if compute_cap < (10, 0) or compute_cap >= (12, 0):
+if not microarchitecture_is(10, 0):
```
[P1] Logic change narrows supported architectures more than the original

The original condition `compute_cap < (10, 0) or compute_cap >= (12, 0)` would skip tests outside the range [10.0, 12.0), meaning it would allow compute capabilities like 10.1, 10.2, 10.3, 11.0, 11.5, etc.

The new condition `not microarchitecture_is(10, 0)` only allows exactly 10.0 (major == 10 and minor == 0), and skips everything else, including 10.1, 10.2, 10.3 and all 11.x architectures.

If the intent was to maintain the original behavior while using the utility function, consider using a range check or adding support for the additional architectures that were previously allowed (10.1-10.3, 11.x). If restricting to exactly 10.0 is intentional, this is fine, but it represents a significant behavior change from the original code.
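If the original [10.0, 12.0) window is meant to be preserved, a range-based check like the following keeps the utility-function style while matching the old behavior. Note that `microarchitecture_in_range` is a hypothetical helper named only for illustration, not an existing function in `python.direct_utils`:

```python
# Hypothetical range-based helper preserving the original [10.0, 12.0) window.
# Not part of python.direct_utils; named here only for illustration.
import torch

def microarchitecture_in_range(low: tuple, high: tuple) -> bool:
    # Inclusive lower bound, exclusive upper bound, matching the old check.
    cap = torch.cuda.get_device_capability()
    return low <= cap < high

# Possible usage in the test module:
# if not microarchitecture_in_range((10, 0), (12, 0)):
#     pytest.skip(..., allow_module_level=True)
```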
Greptile Overview

Greptile Summary

This PR updates the architecture compatibility check for MxFp8 GEMM tests, restricting execution to compute capability 10.0 only. This follows the pattern established in PR #5757 for test_cutlass_gemm.py.

Key Changes
Behavioral Change

Before: the test ran on any compute capability from 10.0 up to (but excluding) 12.0.

After: the test runs only on compute capability exactly 10.0.

This is a conservative approach that prevents potential failures on untested architectures, consistent with the test_cutlass_gemm.py pattern.

Observations
Confidence Score: 4/5
Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as test_mxfp8_gemm
    participant Utils as python.direct_utils
    participant PyTorch as torch.cuda
    participant Pytest as pytest

    Note over Test: Module Load Time
    Test->>Utils: import microarchitecture_is
    Test->>Utils: microarchitecture_is(10, 0)
    Utils->>PyTorch: get_device_properties(current_device())
    PyTorch-->>Utils: device_properties
    Utils->>Utils: Check if major == 10 and minor == 0
    alt Compute capability is exactly 10.0
        Utils-->>Test: True
        Note over Test: Test execution proceeds
    else Compute capability is NOT 10.0
        Utils-->>Test: False
        Test->>Pytest: pytest.skip(allow_module_level=True)
        Note over Test,Pytest: Entire test module skipped<br/>with reason message
    end
    Note over Test: Previous Behavior (before PR)
    Note over Test: Used torch.cuda.get_device_capability()<br/>directly and checked range [10.0, 12.0)
```
```diff
-if compute_cap < (10, 0) or compute_cap >= (12, 0):
+if not microarchitecture_is(10, 0):
     pytest.skip(
         reason="MxFp8 Requires compute capability 10.",
```
The error message "MxFp8 Requires compute capability 10." is misleading. The new logic using `microarchitecture_is(10, 0)` only runs on exactly compute capability 10.0, not on all 10.x architectures.

Previous behavior: allowed 10.0, 10.1, 10.2, 10.3, 11.x (anything from 10.0 up to but excluding 12.0).

New behavior: only allows 10.0.

Consider updating the message to be more specific, e.g., "MxFp8 requires compute capability 10.0. Other architectures have not been tested." This matches the pattern used in test_cutlass_gemm.py, which has detailed comments about tested vs. untested architectures.
| reason="MxFp8 Requires compute capability 10.", | |
| reason="MxFp8 requires compute capability 10.0. Other architectures have not been tested.", |
```diff
-compute_cap = torch.cuda.get_device_capability()
-if compute_cap < (10, 0) or compute_cap >= (12, 0):
+if not microarchitecture_is(10, 0):
```
Consider adding explanatory comments about which architectures are supported/tested, similar to test_cutlass_gemm.py (lines 12-16). This would help future maintainers understand why the test is restricted to compute capability 10.0 only.

Example from test_cutlass_gemm.py:

```python
# GPU Compute Capability: https://developer.nvidia.com/cuda/gpus
# tested on blackwell compute 10.0 (B200 and GB200)
# doesn't support 12.0 (RTX PRO 6000 and RTX 50XX)
# Not tested on 10.3 (B300 and GB300)
# Not tested on 12.1 (DGX Spark)
if not microarchitecture_is(10, 0):
```

This documentation helps clarify the intentional restriction and provides context for when the restriction might be relaxed in the future.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
```diff
 import pytest
 import torch
 from nvfuser_direct import nvf_cutlass
+from python.direct_utils import microarchitecture_is
```
Note: test_cutlass_nvfp4_gemm.py still uses the old pattern with torch.cuda.get_device_capability() and allows all architectures in the range [10.0, 12.0). For consistency, consider updating that test file as well if NVFP4 should also be restricted to only tested architectures (a sketch of such an update follows below).

Current state in test_cutlass_nvfp4_gemm.py:

```python
compute_cap = torch.cuda.get_device_capability()
if compute_cap < (10, 0) or compute_cap >= (12, 0):
    pytest.skip(
        reason="Nvfp4 Requires compute capability 10.",
        allow_module_level=True,
    )
```

If NVFP4 has the same testing limitations as MxFp8 and grouped_mm, it should follow the same pattern for maintainability and clarity.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
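For illustration, here is a hedged sketch of what that update might look like if NVFP4 is given the same exact-match restriction. Whether 10.0-only is correct for NVFP4 is an open question raised in this thread, not something this PR establishes:

```python
# Hypothetical module-level guard for test_cutlass_nvfp4_gemm.py,
# mirroring the pattern this PR applies to the MxFp8 test.
import pytest
from python.direct_utils import microarchitecture_is

if not microarchitecture_is(10, 0):
    pytest.skip(
        reason="Nvfp4 requires compute capability 10.0. "
        "Other architectures have not been tested.",
        allow_module_level=True,
    )
```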
|
!build |
Same as #5810; err msg: `Exception raised from runGemm at /opt/pytorch/nvfuser/cutlass/nvfp4_scaled_mm.cu:255`

Same as #5810; err msg: `Exception raised from run_nvfp4_scaled_group_mm at /opt/pytorch/nvfuser/cutlass/nvfp4_scaled_group_mm.cu:518`

Same as #5810. Skip tests in `test_narrow_precision` that use scaled/grouped mm; err msg: `Exception raised from runGemm at /opt/pytorch/nvfuser/cutlass/nvfp4_scaled_mm.cu:255`
Err msg on unsupported hardware:

```
Exception raised from runGemm at /opt/pytorch/nvfuser/cutlass/mxfp8_scaled_mm.cu:262
```