Same as #5810
Skip tests in test_narrow_precision that use scaled/grouped mm
Error message: Exception raised from runGemm at /opt/pytorch/nvfuser/cutlass/nvfp4_scaled_mm.cu:255
This PR consolidates test skip conditions for scaled/grouped matrix multiplication tests to only run on Blackwell GPUs with compute capability 10.0. Four test functions are updated to replace separate is_pre_blackwell() and microarchitecture_is_pre(12) decorators with a unified, more restrictive microarchitecture_is(10, 0) condition. The required microarchitecture_is import is added. This change ensures these GPU-intensive tests only execute on the supported hardware.
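For reference, a minimal before/after sketch of the decorator change described above. The import path is a placeholder for illustration; the decorator forms and reasons are taken from the diff itself.

```python
import pytest

# Placeholder import path; in the repo these helpers come from the
# direct-bindings test utilities.
from testing_utils import is_pre_blackwell, microarchitecture_is, microarchitecture_is_pre


# Before: two decorators, i.e. "Blackwell or newer, but not compute 12.x"
@pytest.mark.skipif(is_pre_blackwell(), reason="Only supported on blackwell and newer devices.")
@pytest.mark.skipif(not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0")
def test_scaled_mm_old_guard():
    ...


# After: one decorator, restricted to exactly compute capability 10.0
@pytest.mark.skipif(
    not microarchitecture_is(10, 0),
    reason="Only supported on blackwell compute 10.0.",
)
def test_scaled_mm_new_guard():
    ...
```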
Confidence Score: 5/5
This PR is safe to merge with no concerns. The changes are straightforward test decorator updates with no functional code modifications.
Score of 5 reflects: (1) minimal, mechanical changes affecting only test skip conditions; (2) correct consolidation of hardware requirements from two decorators to one, making the intent clearer; (3) proper import addition; (4) changes applied consistently to all 4 affected test functions; (5) no logic errors or unintended side effects. The PR follows the pattern established in related PRs (skip scaled grouped mm test on unsupported arches #5816, skip test cutlass mxfp8_gemm on unsupported arches #5810) for test skipping on unsupported hardware.
No files require special attention
Important Files Changed

tests/python/direct/test_narrow_precision.py: Updated skip conditions for scaled/grouped MM tests from two separate decorators (is_pre_blackwell() + microarchitecture_is_pre(12)) to a single unified condition (microarchitecture_is(10, 0)), making the requirements more specific and restrictive. Added the microarchitecture_is import. Changes applied to 4 test functions, all correctly preserving test logic.
Sequence Diagram
sequenceDiagram
    participant Test as Test Execution
    participant Decorator as pytest.mark.skipif
    participant GPU as GPU Device
    participant Logic as Compute Capability Check
    Test->>Decorator: Run test with decorator
    Decorator->>GPU: Query device properties
    GPU->>Logic: Return compute capability major.minor
    Logic-->>Decorator: Check if (major == 10 && minor == 0)
    alt Matches Blackwell 10.0
        Decorator->>Test: Execute test
    else Does Not Match
        Decorator->>Test: Skip test (unsupported GPU)
    end
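The helper itself is not shown in the diff. A minimal sketch of the capability check the diagram describes, assuming it is built on torch.cuda.get_device_capability, could look like:

```python
import torch


def microarchitecture_is(major: int, minor: int) -> bool:
    """Sketch only: True iff the current CUDA device reports exactly (major, minor)."""
    if not torch.cuda.is_available():
        return False
    # get_device_capability() returns a (major, minor) tuple for the current device.
    return torch.cuda.get_device_capability() == (major, minor)
```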
The new import microarchitecture_is is added but the existing import is_pre_blackwell is still present in the file. Verify that is_pre_blackwell is no longer needed elsewhere in the file to avoid dead imports.
The skip conditions have been made more restrictive (from allowing blackwell and newer devices to only compute capability 10.0). Confirm this change aligns with the actual hardware requirements for scaled/grouped mm operations and doesn't inadvertently skip tests on newer architectures that should be supported.
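If future Blackwell-class devices (for example a hypothetical 10.x minor revision) should also run these tests, a major-only guard is one possible relaxation. This is an illustration of the reviewer's point, not code from the PR, and the helper name is invented:

```python
import torch


def microarchitecture_major_is(major: int) -> bool:
    """Hypothetical helper: match on the compute capability major version only."""
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == major


# e.g. @pytest.mark.skipif(not microarchitecture_major_is(10), reason="Requires a compute 10.x GPU.")
```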
@pytest.mark.skipif(
    not microarchitecture_is(10, 0),
    reason="Only supported on blackwell compute 10.0.",
)
@pytest.mark.parametrize("config", [[128, 256, 512], [128, 256, 512]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
def test_scaled_mm(
    nvfuser_direct_test,
    config,
    out_dtype,
):
    in_dtype = torch.float4_e2m1fn_x2
    quantization = nvfp4_quantize
    m, k, n = config
    mat1_ref = torch.randn((m, k), dtype=torch.float32, device="cuda")
    mat2_ref = torch.randn((n, k), dtype=torch.float32, device="cuda")
    mat1, scale1, global_sf1 = quantization(mat1_ref)
    mat2, scale2, global_sf2 = quantization(mat2_ref)
    alpha = 1.0 / (global_sf1 * global_sf2)

    inputs = [
        mat1,
        mat2.t(),
        linear_to_swizzled_128_4(scale1),
        linear_to_swizzled_128_4(scale2),
        alpha,
    ]

    def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
        mat1 = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float4_e2m1fn, is_cpu=False
        )
        mat2 = fd.define_tensor(
            shape=[-1, -1],
            contiguity=True,
            dtype=DataType.Float4_e2m1fn,
            is_cpu=False,
            stride_order=[0, 1],
        )
        scale1 = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
        )
        scale2 = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
        )
        alpha = fd.define_tensor(
            shape=[], contiguity=True, dtype=DataType.Float, is_cpu=False
        )
        out, _, _ = fd.ops.scaled_mm(
            mat1,
            mat2,
            scale1,
            scale2,
            alpha,
            bias=None,
            beta=None,
            dtype=torch_dtype_to_nvfuser_dtype(out_dtype),
        )
        fd.add_output(out)

    outputs, _ = nvfuser_direct_test.exec_nvfuser(
        nvfuser_fusion_id0, inputs, new_fusion_expected=None
    )

    ref_outputs = (
        torch._scaled_mm(
            mat1,
            mat2.t(),
            linear_to_swizzled_128_4(scale1),
            linear_to_swizzled_128_4(scale2),
            None,
            None,
            out_dtype,
        )
        * alpha
    )
    torch.testing.assert_close(outputs[0], ref_outputs, rtol=1e-1, atol=1e-2)
@pytest.mark.skipif(
    not microarchitecture_is(10, 0),
    reason="Only supported on blackwell compute 10.0.",
)
@pytest.mark.parametrize("config", [[1024, 1024, 1024]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
def test_scaled_mm_nv_quantized(
    nvfuser_direct_test,
    config,
    out_dtype,
):
    """Test scaled_mm with on-the-fly quantization vs pre-quantized baseline.

    Compares nvfuser's nv_block_quantize (quantizing mat1 on-the-fly) against a
    baseline using pre-quantized inputs from Transformer Engine.
    """
    m, k, n = config
    mat1_ref = torch.testing.make_tensor((m, k), dtype=torch.float, device="cuda")
    mat2_ref = torch.testing.make_tensor((n, k), dtype=torch.float, device="cuda")

    # Quantize both matrices using Transformer Engine
    mat1_quantized, mat1_scale_inv, global_sf1 = extract_te_nvfp4_metadata(mat1_ref)
    mat2_quantized, mat2_scale_inv, global_sf2 = extract_te_nvfp4_metadata(mat2_ref)

    # Alpha compensates for both quantization scales
    alpha = 1.0 / (global_sf1 * global_sf2)

    # Prepare inputs for fusion with on-the-fly quantization
    inputs_with_quantize = [
        mat1_ref,
        mat2_quantized.t(),
        global_sf1,
        linear_to_swizzled_128_4(mat2_scale_inv),
        alpha,
    ]

    # Fusion 1: Quantize mat1 on-the-fly using nv_block_quantize
    def fusion_with_nv_block_quantize(fd: FusionDefinition) -> None:
        """Defines fusion that quantizes mat1 on-the-fly before scaled_mm."""
        mat1 = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float, is_cpu=False
        )
        mat2_fp4 = fd.define_tensor(
            shape=[-1, -1],
            contiguity=True,
            dtype=DataType.Float4_e2m1fn,
            is_cpu=False,
            stride_order=[0, 1],
        )
        global_scale = fd.define_tensor(
            shape=[], contiguity=True, dtype=DataType.Float, is_cpu=False
        )
        scale2 = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
        )
        alpha = fd.define_tensor(
            shape=[], contiguity=True, dtype=DataType.Float, is_cpu=False
        )
        # Quantize mat1 on-the-fly
        mat1_fp4, scale1 = fd.ops.nv_block_quantize(mat1, global_scale, True, 16)
        # Perform scaled matrix multiplication
        out, _, _ = fd.ops.scaled_mm(
            mat1_fp4,
            mat2_fp4,
            scale1,
            scale2,
            alpha,
            bias=None,
            beta=None,
            dtype=torch_dtype_to_nvfuser_dtype(out_dtype),
        )
        fd.add_output(out)

    outputs, _ = nvfuser_direct_test.exec_nvfuser(
        fusion_with_nv_block_quantize, inputs_with_quantize
    )

    # Fusion 2: Baseline using pre-quantized inputs
    inputs_baseline = [
        mat1_quantized,
        mat2_quantized.t(),
        linear_to_swizzled_128_4(mat1_scale_inv),
        linear_to_swizzled_128_4(mat2_scale_inv),
        alpha,
    ]

    def fusion_baseline(fd: FusionDefinition) -> None:
        """Defines baseline fusion using pre-quantized inputs."""
        mat1_fp4 = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float4_e2m1fn, is_cpu=False
        )
        mat2_fp4 = fd.define_tensor(
            shape=[-1, -1],
            contiguity=True,
            dtype=DataType.Float4_e2m1fn,
            is_cpu=False,
            stride_order=[0, 1],
        )
        scale1 = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
        )
        scale2 = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
        )
        alpha = fd.define_tensor(
            shape=[], contiguity=True, dtype=DataType.Float, is_cpu=False
        )
        out, _, _ = fd.ops.scaled_mm(
            mat1_fp4,
            mat2_fp4,
            scale1,
            scale2,
            alpha,
            bias=None,
            beta=None,
            dtype=torch_dtype_to_nvfuser_dtype(out_dtype),
        )
        fd.add_output(out)

    outputs_baseline, _ = nvfuser_direct_test.exec_nvfuser(
        fusion_baseline,
        inputs_baseline,
        new_fusion_expected=None,
    )

    torch.testing.assert_close(outputs[0], outputs_baseline[0], atol=1e-2, rtol=1e-2)
@pytest.mark.skipif(
    not microarchitecture_is(10, 0),
    reason="Only supported on blackwell compute 10.0.",
)
@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
def test_cutlass_nvfp4_grouped_mm(
    nvfuser_direct_test,
    config,
    tokens_per_expert_neg_one,
    out_dtype,
):
    BLOCK_SIZE = 16
    # k dimension is multiple of 128 to avoid padding
    m, n, k = config
    # copy list and append tokens for last expert
    tokens_per_expert = list(tokens_per_expert_neg_one)
    tokens_per_expert.append(m - sum(tokens_per_expert))
    g = len(tokens_per_expert)

    mat1_ref = torch.testing.make_tensor((m, k), dtype=torch.float32, device="cuda:0")
    # format is g, n, k instead of g, k, n
    mat2_ref = torch.testing.make_tensor(
        (g, n, k), dtype=torch.float32, device="cuda:0"
    )

    offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
    blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
    problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0")

    # prepare quantization for mat2
    mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0")
    scale2 = torch.empty(
        (g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0"
    )
    acc_tokens = 0
    rounded_acc_tokens = 0
    mat2_scaled = torch.empty(
        (g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0"
    )
    for i in range(g):
        global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2_ref[i].max()
        offsets[i] = acc_tokens
        blockscale_offsets[i] = rounded_acc_tokens
        acc_tokens += tokens_per_expert[i]
        # Note: we technically don't need to round up, since k is perfectly sized.
        rounded_acc_tokens += round_up(tokens_per_expert[i], 128)
        problem_sizes[i][0] = tokens_per_expert[i]
        problem_sizes[i][1] = n
        problem_sizes[i][2] = k
        scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2_ref[i], global_sf)
        mat2_gs[i] = 1.0 / global_sf
        mat2_scaled[i] = scaled_mat2_i
        scale2[i] = linear_to_swizzled_128_4(bs_mat2_i)

    # prepare quantization for mat1
    # note: following sglang implementation, not computing global scaling factor for mat1
    # similarly, we don't need to apply mat1_gs to alpha
    mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0")
    mat1, scale1 = activation_scale_to_nvfp4(
        mat1_ref, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE
    )

    def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
        mat1 = fd.define_tensor(
            shape=[-1, -1],
            contiguity=True,
            dtype=DataType.Float4_e2m1fn,
            is_cpu=False,
        )
        mat2 = fd.define_tensor(
            shape=[-1, -1, -1],
            contiguity=True,
            dtype=DataType.Float4_e2m1fn,
            is_cpu=False,
            stride_order=[2, 0, 1],
        )
        scale1 = fd.define_tensor(
            shape=[-1, -1],
            contiguity=True,
            dtype=DataType.Float8_e4m3fn,
            is_cpu=False,
        )
        scale2 = fd.define_tensor(
            shape=[-1, -1, -1],
            contiguity=True,
            dtype=DataType.Float8_e4m3fn,
            is_cpu=False,
        )
        alpha = fd.define_tensor(
            shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False
        )
        problem_sizes = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False
        )
        offsets = fd.define_tensor(
            shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
        )
        blockscale_offsets = fd.define_tensor(
            shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
        )
        out = fd.ops.cutlass_nvfp4_grouped_mm(
            mat1,
            mat2,
            scale1,
            scale2,
            alpha,
            problem_sizes,
            offsets,
            blockscale_offsets,
            DataType.BFloat16,
        )
        fd.add_output(out)

    inputs = [
        mat1.view(torch.float4_e2m1fn_x2),
        mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2),
        scale1,
        scale2,
        mat2_gs,
        problem_sizes,
        offsets,
        blockscale_offsets,
    ]
    outputs, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)

    o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0")
    for i in range(g):
        l = offsets[i]
        l_sf = blockscale_offsets[i]
        if i == g - 1:
            r = m
        else:
            r = offsets[i + 1]
        r_sf = round_up(tokens_per_expert[i], 128) + l_sf
        # For some reason I cannot feed mat2_gs[i] as alpha in the torch kernel.
        # This triggers a cublas invalid value error.
        o_decomposed_ref[l:r] = (
            torch._scaled_mm(
                mat1[l:r],
                mat2_scaled[i].transpose(-1, -2),
                scale1[l_sf:r_sf],
                scale2[i],
                None,
                None,
                torch.bfloat16,
            )
            * mat2_gs[i]
        )
    torch.testing.assert_close(o_decomposed_ref, outputs[0], atol=1e-2, rtol=1e-2)
@pytest.mark.skipif(
    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float])
def test_fp4_vectorization(
    nvfuser_direct_test,
    dtype,
):
    inputs = [
        torch.ones(4, 8, dtype=dtype, device="cuda"),
        torch.ones(4, dtype=dtype, device="cuda"),
    ]

    def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
        T0 = fd.from_pytorch(inputs[0])
        T1 = fd.from_pytorch(inputs[1])
        T2 = fd.ops.cast(T0, DataType.Float)
        cast_T1 = fd.ops.cast(T1, DataType.Float)
        broadcast_T1 = fd.ops.broadcast(cast_T1, [False, True])
        T3 = fd.ops.div(T2, broadcast_T1)
        T4 = fd.ops.cast(T3, DataType.Float4_e2m1fn)
        T5 = fd.ops.reshape(T4, [32])
        fd.add_output(T5)

    outputs, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
    ref_outputs = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(
        -1
    )
    torch.testing.assert_close(
        outputs[0].view(dtype=torch.uint8),
        ref_outputs.view(dtype=torch.uint8),
        rtol=1e-1,
        atol=1e-2,
    )
# This is adapted from the decomposed version.
# A few things I have to change in order to pass the test:
# 1. inputs data needs to be changed from `torch.testing.make_tensor` to `torch.randn`;
# 2. output errors are much more relaxed.
@pytest.mark.skipif(
    not microarchitecture_is(10, 0),
    reason="Only supported on blackwell compute 10.0.",
)
@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])