
skip scaled/grouped mm related tests on unsupported gpus #5847

Merged
liqiangxl merged 1 commit into main from llu/skip_narrow_precision_on_some_arches on Jan 20, 2026

Conversation

Collaborator

@liqiangxl liqiangxl commented Jan 20, 2026

Same as #5810: skip tests in test_narrow_precision that use scaled/grouped mm.
Error message: Exception raised from runGemm at /opt/pytorch/nvfuser/cutlass/nvfp4_scaled_mm.cu:255

Contributor

greptile-apps Bot commented Jan 20, 2026

Greptile Summary

This PR consolidates the skip conditions for scaled/grouped matrix multiplication tests so they run only on Blackwell GPUs with compute capability 10.0. Four test functions replace the separate is_pre_blackwell() and microarchitecture_is_pre(12) decorators with a single, more restrictive microarchitecture_is(10, 0) condition, and the required microarchitecture_is import is added. This change ensures these tests execute only on supported hardware.
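To make the consolidation concrete, here is a minimal before/after sketch using the decorator forms that appear in this file (test bodies elided; the "before" pair is the same pattern test_fp4_vectorization below still uses):

    # Before: two stacked conditions bracketing the supported range
    @pytest.mark.skipif(
        is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
    )
    @pytest.mark.skipif(
        not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
    )
    def test_scaled_mm():
        ...

    # After: one exact compute-capability check
    @pytest.mark.skipif(
        not microarchitecture_is(10, 0), reason="Only supported on blackwell compute 10.0."
    )
    def test_scaled_mm():
        ...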

Confidence Score: 5/5

  • This PR is safe to merge with no concerns: the changes are straightforward test-decorator updates with no functional code modifications.
  • The score of 5 reflects: (1) minimal, mechanical changes affecting only test skip conditions; (2) a correct consolidation of the hardware requirement from two decorators into one, making the intent clearer; (3) the proper import addition; (4) consistent application to all 4 affected test functions; (5) no logic errors or unintended side effects. The PR follows the pattern established in related PRs (skip scaled grouped mm test on unsupported arches #5816; skip test cutlass mxfp8_gemm on unsupported arches #5810) for skipping tests on unsupported hardware.
  • No files require special attention.

Important Files Changed

Filename: tests/python/direct/test_narrow_precision.py
Overview: Updated skip conditions for scaled/grouped MM tests from two separate decorators (is_pre_blackwell() + microarchitecture_is_pre(12)) to a single unified condition (microarchitecture_is(10, 0)), making the requirement more specific and restrictive. Added the microarchitecture_is import. Changes applied to 4 test functions, all correctly preserving test logic.

Sequence Diagram

sequenceDiagram
    participant Test as Test Execution
    participant Decorator as pytest.mark.skipif
    participant GPU as GPU Device
    participant Logic as Compute Capability Check

    Test->>Decorator: Run test with decorator
    Decorator->>GPU: Query device properties
    GPU->>Logic: Return compute capability major.minor
    Logic-->>Decorator: Check if (major == 10 && minor == 0)
    alt Matches Blackwell 10.0
        Decorator->>Test: Execute test
    else Does Not Match
        Decorator->>Test: Skip test (unsupported GPU)
    end



github-actions Bot commented Jan 20, 2026

Auto-merge Status

❌ Internal CI is finished (pending)
✅ No failed checks
❌ PR is mergeable (blocked)
ℹ️ PR mergeable_state: blocked

Description

  • Update skip conditions for scaled/grouped mm tests to use precise compute capability check

  • Replace dual skip conditions with single microarchitecture_is(10, 0) check

  • Ensure tests only run on supported Blackwell compute 10.0 architectures

  • Import microarchitecture_is utility function for consistent architecture detection
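For reference, a minimal sketch of what a helper like microarchitecture_is (mentioned in the last bullet above) could look like, assuming it wraps torch.cuda.get_device_capability; nvfuser's actual utility may be implemented differently:

    import torch

    def microarchitecture_is(major: int, minor: int) -> bool:
        # Exact match against the current device's compute capability,
        # e.g. (10, 0) for Blackwell compute 10.0.
        return torch.cuda.is_available() and torch.cuda.get_device_capability() == (
            major,
            minor,
        )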

Changes walkthrough

Relevant files: Bug fix

tests/python/direct/test_narrow_precision.py: Update GPU skip conditions for narrow precision tests (+5/-13)

  • Add an import for the microarchitecture_is utility function
  • Update skip conditions for test_scaled_mm to use a precise compute capability check
  • Update skip conditions for test_scaled_mm_nv_quantized with the consistent architecture requirement
  • Update skip conditions for test_grouped_mm to match the new architecture detection pattern
  • Update skip conditions for test_grouped_mm_nv_quantized with the unified skip logic
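Since the same skipif condition is now repeated on each of the four tests in the walkthrough above, one common pytest alternative (not what this PR does) is to bind the marker once and reuse it. A hypothetical sketch, reusing the capability check sketched earlier:

    import pytest
    import torch

    def microarchitecture_is(major: int, minor: int) -> bool:
        # Stand-in for the utility the PR imports (see the earlier sketch);
        # the real nvfuser helper may differ.
        return torch.cuda.is_available() and torch.cuda.get_device_capability() == (
            major,
            minor,
        )

    # Evaluate the condition once at import time and reuse the marker,
    # instead of repeating the full skipif on every test.
    blackwell_10_0_only = pytest.mark.skipif(
        not microarchitecture_is(10, 0),
        reason="Only supported on blackwell compute 10.0.",
    )

    @blackwell_10_0_only
    def test_scaled_mm():
        ...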

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Import consistency

    The new import microarchitecture_is is added, but the existing is_pre_blackwell import is also still present in the file. Verify that is_pre_blackwell is still used elsewhere (test_fp4_vectorization below still references it) so the file carries no dead imports; a linter such as flake8 flags unused imports as F401.

    microarchitecture_is,

    Test coverage validation

    The skip conditions have been made more restrictive (from allowing all Blackwell-and-newer devices below compute 12.0 to only compute capability 10.0). Confirm this change aligns with the actual hardware requirements for scaled/grouped mm operations and doesn't inadvertently skip tests on newer architectures that should be supported.

    @pytest.mark.skipif(
        not microarchitecture_is(10, 0), reason="Only supported on blackwell compute 10.0."
    )
    @pytest.mark.parametrize("config", [[128, 256, 512], [128, 256, 512]])
    @pytest.mark.parametrize("out_dtype", [torch.bfloat16])
    def test_scaled_mm(
        nvfuser_direct_test,
        config,
        out_dtype,
    ):
        in_dtype = torch.float4_e2m1fn_x2
        quantization = nvfp4_quantize
    
        m, k, n = config
        mat1_ref = torch.randn((m, k), dtype=torch.float32, device="cuda")
        mat2_ref = torch.randn((n, k), dtype=torch.float32, device="cuda")
    
        mat1, scale1, global_sf1 = quantization(mat1_ref)
        mat2, scale2, global_sf2 = quantization(mat2_ref)
        alpha = 1.0 / (global_sf1 * global_sf2)
    
        inputs = [
            mat1,
            mat2.t(),
            linear_to_swizzled_128_4(scale1),
            linear_to_swizzled_128_4(scale2),
            alpha,
        ]
    
        def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
            mat1 = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Float4_e2m1fn, is_cpu=False
            )
            mat2 = fd.define_tensor(
                shape=[-1, -1],
                contiguity=True,
                dtype=DataType.Float4_e2m1fn,
                is_cpu=False,
                stride_order=[0, 1],
            )
            scale1 = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
            )
            scale2 = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
            )
            alpha = fd.define_tensor(
                shape=[], contiguity=True, dtype=DataType.Float, is_cpu=False
            )
            out, _, _ = fd.ops.scaled_mm(
                mat1,
                mat2,
                scale1,
                scale2,
                alpha,
                bias=None,
                beta=None,
                dtype=torch_dtype_to_nvfuser_dtype(out_dtype),
            )
            fd.add_output(out)
    
        outputs, _ = nvfuser_direct_test.exec_nvfuser(
            nvfuser_fusion_id0, inputs, new_fusion_expected=None
        )
    
        ref_outputs = (
            torch._scaled_mm(
                mat1,
                mat2.t(),
                linear_to_swizzled_128_4(scale1),
                linear_to_swizzled_128_4(scale2),
                None,
                None,
                out_dtype,
            )
            * alpha
        )
        torch.testing.assert_close(outputs[0], ref_outputs, rtol=1e-1, atol=1e-2)
    
    
    @pytest.mark.skipif(
        not microarchitecture_is(10, 0), reason="Only supported on blackwell compute 10.0."
    )
    @pytest.mark.parametrize("config", [[1024, 1024, 1024]])
    @pytest.mark.parametrize("out_dtype", [torch.bfloat16])
    def test_scaled_mm_nv_quantized(
        nvfuser_direct_test,
        config,
        out_dtype,
    ):
        """Test scaled_mm with on-the-fly quantization vs pre-quantized baseline.
    
        Compares nvfuser's nv_block_quantize (quantizing mat1 on-the-fly) against
        a baseline using pre-quantized inputs from Transformer Engine.
        """
        m, k, n = config
        mat1_ref = torch.testing.make_tensor((m, k), dtype=torch.float, device="cuda")
        mat2_ref = torch.testing.make_tensor((n, k), dtype=torch.float, device="cuda")
    
        # Quantize both matrices using Transformer Engine
        mat1_quantized, mat1_scale_inv, global_sf1 = extract_te_nvfp4_metadata(mat1_ref)
        mat2_quantized, mat2_scale_inv, global_sf2 = extract_te_nvfp4_metadata(mat2_ref)
    
        # Alpha compensates for both quantization scales
        alpha = 1.0 / (global_sf1 * global_sf2)
    
        # Prepare inputs for fusion with on-the-fly quantization
        inputs_with_quantize = [
            mat1_ref,
            mat2_quantized.t(),
            global_sf1,
            linear_to_swizzled_128_4(mat2_scale_inv),
            alpha,
        ]
    
        # Fusion 1: Quantize mat1 on-the-fly using nv_block_quantize
        def fusion_with_nv_block_quantize(fd: FusionDefinition) -> None:
            """Defines fusion that quantizes mat1 on-the-fly before scaled_mm."""
            mat1 = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Float, is_cpu=False
            )
            mat2_fp4 = fd.define_tensor(
                shape=[-1, -1],
                contiguity=True,
                dtype=DataType.Float4_e2m1fn,
                is_cpu=False,
                stride_order=[0, 1],
            )
            global_scale = fd.define_tensor(
                shape=[], contiguity=True, dtype=DataType.Float, is_cpu=False
            )
            scale2 = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
            )
            alpha = fd.define_tensor(
                shape=[], contiguity=True, dtype=DataType.Float, is_cpu=False
            )
    
            # Quantize mat1 on-the-fly
            mat1_fp4, scale1 = fd.ops.nv_block_quantize(mat1, global_scale, True, 16)
    
            # Perform scaled matrix multiplication
            out, _, _ = fd.ops.scaled_mm(
                mat1_fp4,
                mat2_fp4,
                scale1,
                scale2,
                alpha,
                bias=None,
                beta=None,
                dtype=torch_dtype_to_nvfuser_dtype(out_dtype),
            )
            fd.add_output(out)
    
        outputs, _ = nvfuser_direct_test.exec_nvfuser(
            fusion_with_nv_block_quantize, inputs_with_quantize
        )
    
        # Fusion 2: Baseline using pre-quantized inputs
        inputs_baseline = [
            mat1_quantized,
            mat2_quantized.t(),
            linear_to_swizzled_128_4(mat1_scale_inv),
            linear_to_swizzled_128_4(mat2_scale_inv),
            alpha,
        ]
    
        def fusion_baseline(fd: FusionDefinition) -> None:
            """Defines baseline fusion using pre-quantized inputs."""
            mat1_fp4 = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Float4_e2m1fn, is_cpu=False
            )
            mat2_fp4 = fd.define_tensor(
                shape=[-1, -1],
                contiguity=True,
                dtype=DataType.Float4_e2m1fn,
                is_cpu=False,
                stride_order=[0, 1],
            )
            scale1 = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
            )
            scale2 = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Float8_e4m3fn, is_cpu=False
            )
            alpha = fd.define_tensor(
                shape=[], contiguity=True, dtype=DataType.Float, is_cpu=False
            )
    
            out, _, _ = fd.ops.scaled_mm(
                mat1_fp4,
                mat2_fp4,
                scale1,
                scale2,
                alpha,
                bias=None,
                beta=None,
                dtype=torch_dtype_to_nvfuser_dtype(out_dtype),
            )
            fd.add_output(out)
    
        outputs_baseline, _ = nvfuser_direct_test.exec_nvfuser(
            fusion_baseline,
            inputs_baseline,
            new_fusion_expected=None,
        )
    
        torch.testing.assert_close(outputs[0], outputs_baseline[0], atol=1e-2, rtol=1e-2)
    
    
    @pytest.mark.skipif(
        not microarchitecture_is(10, 0), reason="Only supported on blackwell compute 10.0."
    )
    @pytest.mark.parametrize("config", [[1024, 128, 256]])
    @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
    @pytest.mark.parametrize("out_dtype", [torch.bfloat16])
    def test_cutlass_nvfp4_grouped_mm(
        nvfuser_direct_test,
        config,
        tokens_per_expert_neg_one,
        out_dtype,
    ):
        BLOCK_SIZE = 16
    
        # k dimension is multiple of 128 to avoid padding
        m, n, k = config
        # copy list and append tokens for last expert
        tokens_per_expert = list(tokens_per_expert_neg_one)
        tokens_per_expert.append(m - sum(tokens_per_expert))
        g = len(tokens_per_expert)
    
        mat1_ref = torch.testing.make_tensor((m, k), dtype=torch.float32, device="cuda:0")
        # format is g, n, k instead of g, k, n
        mat2_ref = torch.testing.make_tensor(
            (g, n, k), dtype=torch.float32, device="cuda:0"
        )
    
        offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
        blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
        problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0")
    
        # prepare quantization for mat2
        mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0")
        scale2 = torch.empty(
            (g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0"
        )
    
        acc_tokens = 0
        rounded_acc_tokens = 0
        mat2_scaled = torch.empty(
            (g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0"
        )
    
        for i in range(g):
            global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2_ref[i].max()
            offsets[i] = acc_tokens
            blockscale_offsets[i] = rounded_acc_tokens
            acc_tokens += tokens_per_expert[i]
            # Note: we technically don't need to round up, since k is perfectly sized.
            rounded_acc_tokens += round_up(tokens_per_expert[i], 128)
    
            problem_sizes[i][0] = tokens_per_expert[i]
            problem_sizes[i][1] = n
            problem_sizes[i][2] = k
    
            scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2_ref[i], global_sf)
            mat2_gs[i] = 1.0 / global_sf
            mat2_scaled[i] = scaled_mat2_i
            scale2[i] = linear_to_swizzled_128_4(bs_mat2_i)
    
        # prepare quantization for mat1
        # note: following sglang implementation, not computing global scaling factor for mat1
        #       similarly, we don't need to apply mat1_gs to alpha
        mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0")
        mat1, scale1 = activation_scale_to_nvfp4(
            mat1_ref, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE
        )
    
        def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
            mat1 = fd.define_tensor(
                shape=[-1, -1],
                contiguity=True,
                dtype=DataType.Float4_e2m1fn,
                is_cpu=False,
            )
            mat2 = fd.define_tensor(
                shape=[-1, -1, -1],
                contiguity=True,
                dtype=DataType.Float4_e2m1fn,
                is_cpu=False,
                stride_order=[2, 0, 1],
            )
            scale1 = fd.define_tensor(
                shape=[-1, -1],
                contiguity=True,
                dtype=DataType.Float8_e4m3fn,
                is_cpu=False,
            )
            scale2 = fd.define_tensor(
                shape=[-1, -1, -1],
                contiguity=True,
                dtype=DataType.Float8_e4m3fn,
                is_cpu=False,
            )
            alpha = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False
            )
            problem_sizes = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
            offsets = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
            blockscale_offsets = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
            out = fd.ops.cutlass_nvfp4_grouped_mm(
                mat1,
                mat2,
                scale1,
                scale2,
                alpha,
                problem_sizes,
                offsets,
                blockscale_offsets,
                DataType.BFloat16,
            )
            fd.add_output(out)
    
        inputs = [
            mat1.view(torch.float4_e2m1fn_x2),
            mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2),
            scale1,
            scale2,
            mat2_gs,
            problem_sizes,
            offsets,
            blockscale_offsets,
        ]
    
        outputs, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
    
        o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0")
        for i in range(g):
            l = offsets[i]
            l_sf = blockscale_offsets[i]
            if i == g - 1:
                r = m
            else:
                r = offsets[i + 1]
            r_sf = round_up(tokens_per_expert[i], 128) + l_sf
            # For some reason I cannot feed mat2_gs[i] as alpha in the torch kernel.
            # This triggers a cublas invalid value error.
            o_decomposed_ref[l:r] = (
                torch._scaled_mm(
                    mat1[l:r],
                    mat2_scaled[i].transpose(-1, -2),
                    scale1[l_sf:r_sf],
                    scale2[i],
                    None,
                    None,
                    torch.bfloat16,
                )
                * mat2_gs[i]
            )
    
        torch.testing.assert_close(o_decomposed_ref, outputs[0], atol=1e-2, rtol=1e-2)
    
    
    @pytest.mark.skipif(
        is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
    )
    @pytest.mark.skipif(
        not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float])
    def test_fp4_vectorization(
        nvfuser_direct_test,
        dtype,
    ):
        inputs = [
            torch.ones(4, 8, dtype=dtype, device="cuda"),
            torch.ones(4, dtype=dtype, device="cuda"),
        ]
    
        def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
            T0 = fd.from_pytorch(inputs[0])
            T1 = fd.from_pytorch(inputs[1])
            T2 = fd.ops.cast(T0, DataType.Float)
            cast_T1 = fd.ops.cast(T1, DataType.Float)
            broadcast_T1 = fd.ops.broadcast(cast_T1, [False, True])
            T3 = fd.ops.div(T2, broadcast_T1)
            T4 = fd.ops.cast(T3, DataType.Float4_e2m1fn)
            T5 = fd.ops.reshape(T4, [32])
            fd.add_output(T5)
    
        outputs, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
    
        ref_outputs = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(
            -1
        )
    
        torch.testing.assert_close(
            outputs[0].view(dtype=torch.uint8),
            ref_outputs.view(dtype=torch.uint8),
            rtol=1e-1,
            atol=1e-2,
        )
    
    
    # This is adopted from the decomposed version.
    # A few things I have to change in order to pass the test:
    #     1. inputs data needs to be changed from `torch.testing.make_tensor` to `torch.randn`;
    #     2. output errors are much more relaxed.
    @pytest.mark.skipif(
        not microarchitecture_is(10, 0), reason="Only supported on blackwell compute 10.0."
    )
    @pytest.mark.parametrize("config", [[1024, 128, 256]])
    @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
    @pytest.mark.parametrize("out_dtype", [torch.bfloat16])

@liqiangxl liqiangxl requested a review from protonu January 20, 2026 13:06

Collaborator

@protonu protonu left a comment

LGTM

@liqiangxl
Collaborator Author

!test

@liqiangxl liqiangxl added the enable-auto-merge label (Auto-merge a PR when: 1) PR mergeable 2) Internal CI complete 3) No failures) Jan 20, 2026
@liqiangxl liqiangxl merged commit 35cd8a7 into main Jan 20, 2026
62 of 63 checks passed
@liqiangxl liqiangxl deleted the llu/skip_narrow_precision_on_some_arches branch January 20, 2026 18:10
@github-actions github-actions Bot removed the enable-auto-merge label Jan 20, 2026