diff --git a/modelopt/torch/kernels/quantization/gemm/__init__.py b/modelopt/torch/kernels/quantization/gemm/__init__.py index 39b07b4faa..70f729cffb 100644 --- a/modelopt/torch/kernels/quantization/gemm/__init__.py +++ b/modelopt/torch/kernels/quantization/gemm/__init__.py @@ -32,6 +32,7 @@ # fp4_kernel works on any CUDA GPU with triton from .fp4_kernel import * from .fp8_kernel import * + from .nvfp4_fp8_sweep import * # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py new file mode 100644 index 0000000000..49e4839a3c --- /dev/null +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep. + +Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single +kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates +and emits the per-block ``best_amax`` directly. + +The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see +:func:`fp8_scale_candidates`). 
For these specific candidates, the FP8 round-trip on +the per-block scale is the identity, so the kernel can use +``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it +runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). + +Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``. +""" + +import torch +import triton +import triton.language as tl + +from .nvfp4_quant import fp4_round_magnitude + +__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] + + +def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: + """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.""" + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + return fp8_values[valid_mask] / 448.0 + + +# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300: +# BPP=16,nw=2: 6.06 ms BPP=32,nw=4: 6.06 ms BPP=64,nw=8: 5.08 ms +# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64 +# would underfill the SMs. 
+_FP8_SWEEP_AUTOTUNE_CONFIGS = [
+    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
+    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
+    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
+]
+
+
+@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
+@triton.jit
+def _fp8_scale_sweep_kernel(
+    x_ptr,  # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
+    candidates_ptr,  # [NUM_CANDIDATES] fp32
+    global_amax_ptr,  # scalar fp32
+    best_amax_ptr,  # [N_BLOCKS] fp32 output
+    N_BLOCKS,
+    BLOCK_SIZE: tl.constexpr,
+    NUM_CANDIDATES: tl.constexpr,
+    BLOCKS_PER_PROGRAM: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCKS_PER_PROGRAM
+    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
+    block_mask = block_idx < N_BLOCKS
+
+    # Load weights for this tile and pre-compute their absolute values once.
+    # The squared error is sign-invariant since FP4 quant preserves sign:
+    # (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
+    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
+    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
+    elem_mask = block_mask[:, None]
+    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))
+
+    global_amax = tl.load(global_amax_ptr).to(tl.float32)
+
+    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
+    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)
+
+    # Loop over the 126 FP8 candidates (compile-time unrolled).
+    # Scales are finite and non-negative (a positive candidate times a non-negative
+    # global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
+    # unnecessary apart from the global_amax == 0 case handled below.
+ for k in tl.static_range(NUM_CANDIDATES): + c = tl.load(candidates_ptr + k).to(tl.float32) + scale = c * global_amax / 6.0 + # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is + # the same for every candidate, so any best_idx is fine. + scale_safe = tl.where(scale == 0.0, 1.0, scale) + q_mag = fp4_round_magnitude(w_abs / scale_safe) + diff = w_abs - q_mag * scale + loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM] + is_better = loss < best_loss + best_loss = tl.where(is_better, loss, best_loss) + best_idx = tl.where(is_better, k, best_idx) + + # Map each block's winning candidate index back to its amax = global_amax * c[best]. + best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32) + best_amax = global_amax * best_c + tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask) + + +def nvfp4_fp8_scale_sweep( + x: torch.Tensor, + global_amax: torch.Tensor, + block_size: int = 16, +) -> torch.Tensor: + """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE. + + Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into + a single Triton kernel: every block's weight elements are loaded once, all 126 + candidates are evaluated in registers, and the running argmin is kept inline. + + Args: + x: Weight tensor on CUDA. Total element count must be divisible by + ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``. + global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``). + block_size: NVFP4 block size (typically 16). + + Returns: + ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. 
+ """ + if not x.is_cuda: + raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.") + if not isinstance(block_size, int) or block_size <= 0: + raise ValueError(f"block_size must be a positive int, got {block_size!r}.") + if x.numel() % block_size != 0: + raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).") + + candidates = fp8_scale_candidates(x.device).to(dtype=torch.float32) + + n_blocks = x.numel() // block_size + x_flat = x.contiguous().view(-1) + global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1) + best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device) + + grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),) + with torch.cuda.device(x.device): + _fp8_scale_sweep_kernel[grid]( + x_flat, + candidates, + global_amax_f32, + best_amax, + n_blocks, + BLOCK_SIZE=block_size, + NUM_CANDIDATES=int(candidates.numel()), + ) + return best_amax diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 1f439a7e77..79961a0b67 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -24,7 +24,7 @@ from .. import utils as quant_utils from .calibrator import _Calibrator -__all__ = ["MseCalibrator", "NVFP4MSECalibrator"] +__all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"] class MseCalibrator(_Calibrator): @@ -192,9 +192,106 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: return torch.ones_like(self._initial_amax) * self._global_amax * candidates def _generate_candidates(self, device: torch.device) -> torch.Tensor: - """Generate 126 valid FP8 E4M3 scale candidates.""" + """Generate 126 valid FP8 E4M3 scale candidates. 
+ + Kept in sync with ``fp8_scale_candidates`` in + ``modelopt.torch.kernels.quantization.gemm.nvfp4_fp8_sweep`` — the FP8 E4M3 + spec is fixed, and the parity test exercises both paths against each other. + """ uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) fp8_values = uint8_values.view(torch.float8_e4m3fn).float() valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) fp8_values = fp8_values[valid_mask] return fp8_values / 448.0 + + +class TritonNVFP4MSECalibrator(NVFP4MSECalibrator): + """Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE. + + Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126 + candidates in a single fused Triton kernel — one weight read instead of 126. + + Limitation: a single ``collect()`` call is supported per ``compute_amax`` cycle. + This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where + the calibrator is collected once per weight and immediately consumed. For + activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`. + Call :meth:`reset` to free internal state and re-enable :meth:`collect`. + """ + + def __init__( + self, + amax: torch.Tensor, + global_amax: torch.Tensor, + axis: int | tuple | list | None = None, + quant_func: Callable | None = None, + error_func: Callable | None = None, + ): + """Initialize the Triton-fused NVFP4 MSE calibrator. + + See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by + the kernel path but accepted for API parity. Tile shape and ``num_warps`` are + autotuned by the kernel per ``N_BLOCKS``. + """ + super().__init__( + amax=amax, + global_amax=global_amax, + axis=axis, + quant_func=quant_func, + error_func=error_func, + ) + # Stash shape metadata so collect() can keep working after reset() releases + # the (potentially large) _initial_amax buffer. 
+        self._initial_amax_shape = tuple(amax.shape)
+        self._initial_amax_dtype = amax.dtype
+        self._n_blocks = int(amax.numel())
+        self._best_amax: torch.Tensor | None = None
+
+    @torch.no_grad()
+    def collect(self, x: torch.Tensor):
+        """Run the fused FP8 sweep kernel and store the resulting per-block amax."""
+        from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep
+
+        if self._best_amax is not None:
+            raise RuntimeError(
+                "TritonNVFP4MSECalibrator.collect() is one-shot; call reset() to "
+                "discard the previous result before collecting again."
+            )
+
+        x = x.detach()
+        # The weight quantizer reshapes its input to [n_blocks, block_size] before
+        # calling collect (see TensorQuantizer._process_for_blockquant). Validate
+        # via ValueError so the contract still holds under ``python -O``.
+        if x.ndim != 2:
+            raise ValueError(
+                f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}."
+            )
+        block_size = x.shape[-1]
+        if block_size <= 0:
+            raise ValueError(f"x.shape[-1] must be positive; got {block_size}.")
+        n_blocks = x.shape[0]
+        if n_blocks != self._n_blocks:
+            raise ValueError(
+                f"initial amax.numel() ({self._n_blocks}) does not match the number "
+                f"of NVFP4 blocks in x ({n_blocks})."
+            )
+
+        best_amax_flat = nvfp4_fp8_scale_sweep(
+            x,
+            self._global_amax,
+            block_size=block_size,
+        )
+        # Match the original shape/dtype of the initial amax so downstream
+        # load_calib_amax behaves identically to the reference path.
+        self._best_amax = best_amax_flat.reshape(self._initial_amax_shape).to(
+            self._initial_amax_dtype
+        )
+
+    @torch.no_grad()
+    def compute_amax(self, verbose: bool = False):
+        """Return the per-block amax from ``collect`` (``None`` if none has run)."""
+        return self._best_amax
+
+    def reset(self):
+        """Reset the stored best amax. Subsequent ``collect`` calls are allowed."""
+        self._best_amax = None
+        super().reset()
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 4ce0f62a75..cd86ff1c72 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -16,6 +16,7 @@
 """Calibration utilities."""
 
 import math
+import os
 import time
 import warnings
 from collections.abc import Callable
@@ -37,7 +38,7 @@
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 
-from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator
+from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator
 from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
 from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
 from .utils import (
@@ -354,6 +355,11 @@
     weight_quantizers = []
     seen_modules = set()
 
+    # Triton-fused FP8 sweep is on by default for NVFP4 static quant; set
+    # MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference for debugging.
+    use_triton_fp8_sweep = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"
+    nvfp4_calibrator_cls = TritonNVFP4MSECalibrator if use_triton_fp8_sweep else NVFP4MSECalibrator
+
     for name, module in list(model.named_modules()):
         if isinstance(module, TensorQuantizer) and not module._disabled:
             if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
@@ -391,8 +397,7 @@
             continue
 
         if fp8_scale_sweep and is_nvfp4_static:
-            # Replace calibrator with NVFP4MSECalibrator
-            module._calibrator = NVFP4MSECalibrator(
+            module._calibrator = nvfp4_calibrator_cls(
                 amax=initial_amax,
                 axis=module._calibrator._axis,
                 global_amax=module.global_amax,
diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
new file mode 100644
index 0000000000..14d70d007c
--- /dev/null
+++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
@@ -0,0 +1,302 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Parity + speedup tests for the fused NVFP4 FP8 scale sweep Triton kernel.
+
+Compares :class:`TritonNVFP4MSECalibrator` against the reference
+:class:`NVFP4MSECalibrator` on the same inputs and asserts the resulting per-block
+amax tensors are bit-identical on the parity tests. 
Also reports a wall-clock speedup number for the +weight-MSE search step on a representative LLM-sized weight. +""" + +import time + +import pytest +import torch +from conftest import requires_triton + +from modelopt.torch.quantization.calib import NVFP4MSECalibrator, TritonNVFP4MSECalibrator +from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant + +BLOCK_SIZE = 16 + + +def _reference_quant_func(global_amax): + """Reference NVFP4 fake-quant matching what ``mse_calibrate`` plumbs in.""" + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + return quant_func + + +def _run_reference(x, per_block_amax, global_amax): + cal = NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +def _run_triton(x, per_block_amax, global_amax): + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +@requires_triton +@pytest.mark.parametrize("seed", [0, 1, 2]) +@pytest.mark.parametrize("num_blocks", [4, 64, 1024]) +def test_parity_random_weights(seed, num_blocks): + """Triton sweep must produce the exact same per-block amax as the reference.""" + torch.manual_seed(seed) + device = "cuda" + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + + assert ref.shape == tri.shape + # Both pick from the same 126-element discrete candidate set, so any disagreement + # would show up as a non-zero diff (not a small float epsilon). Demand exact match. 
+ assert torch.equal(ref, tri), ( + f"Triton sweep diverged from reference: max |diff| = " + f"{(ref - tri).abs().max().item():.3e}, " + f"differing blocks = {(ref != tri).sum().item()} / {num_blocks}" + ) + + +@requires_triton +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_parity_dtypes(dtype): + """Sweep must agree across the dtypes supported by the NVFP4 quantizer.""" + torch.manual_seed(42) + device = "cuda" + num_blocks = 256 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype) + # Promote to fp32 for the per-block amax (matches what max_calibrate produces). + per_block_amax = x.float().abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + assert torch.equal(ref, tri) + + +@requires_triton +def test_quantized_output_matches(): + """Round-tripping x through the chosen amax should give the same fake-quant result.""" + torch.manual_seed(7) + device = "cuda" + num_blocks = 128 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + assert torch.equal(ref_xq, tri_xq) + + +@requires_triton +def test_reset_allows_recollect(): + torch.manual_seed(0) + device = "cuda" + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + ) + cal.collect(x) + first = cal.compute_amax().clone() + + # collect() is one-shot per cycle until reset() is 
called. + with pytest.raises(RuntimeError, match="one-shot"): + cal.collect(x) + + cal.reset() + # After reset, the same calibrator instance can be re-used. + cal.collect(x) + assert torch.equal(first, cal.compute_amax()) + + +@requires_triton +def test_input_validation(): + """``nvfp4_fp8_scale_sweep`` should reject malformed inputs cleanly.""" + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep + + device = "cuda" + x = torch.randn(64, BLOCK_SIZE, device=device) + g = x.abs().amax() + + # CPU tensor → ValueError (not bare AssertionError). + with pytest.raises(ValueError, match="CUDA"): + nvfp4_fp8_scale_sweep(x.cpu(), g.cpu()) + + # block_size <= 0. + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=0) + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=-1) + + # Non-divisible numel. + with pytest.raises(ValueError, match="not divisible"): + nvfp4_fp8_scale_sweep(x, g, block_size=15) + + +@requires_triton +def test_mse_calibrate_dispatch(monkeypatch): + """``mse_calibrate(fp8_scale_sweep=True)`` must install the right calibrator class. + + Default path: ``TritonNVFP4MSECalibrator``. + With ``MODELOPT_NVFP4_TRITON_SWEEP=0``: ``NVFP4MSECalibrator`` (and not its subclass). 
+ """ + from _test_utils.torch.quantization.models import SimpleLinear + + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.extensions import get_cuda_ext_mx + from modelopt.torch.quantization.nn import TensorQuantizer + + if get_cuda_ext_mx() is None: + pytest.skip("cuda_ext_mx is not available") + + cfg = { + "quant_cfg": [ + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + }, + "enable": True, + }, + {"quantizer_name": "*input_quantizer", "enable": False}, + ], + "algorithm": {"method": "mse", "fp8_scale_sweep": True}, + } + + def _quantize_and_get_weight_calibrators(model): + calib_data = [model.get_input().cuda() for _ in range(2)] + + def forward_loop(m): + for batch in calib_data: + m(batch) + + mtq.quantize(model, cfg, forward_loop=forward_loop) + return [ + type(m._calibrator) + for name, m in model.named_modules() + if isinstance(m, TensorQuantizer) + and name.endswith("weight_quantizer") + and getattr(m, "_calibrator", None) is not None + ] + + # Default: triton path. + monkeypatch.delenv("MODELOPT_NVFP4_TRITON_SWEEP", raising=False) + types_default = _quantize_and_get_weight_calibrators(SimpleLinear().cuda()) + assert types_default, "expected at least one weight quantizer with a calibrator" + assert all(t is TritonNVFP4MSECalibrator for t in types_default), types_default + + # Opt-out: reference path, exact class match (TritonNVFP4MSECalibrator is a subclass). 
+ monkeypatch.setenv("MODELOPT_NVFP4_TRITON_SWEEP", "0") + types_optout = _quantize_and_get_weight_calibrators(SimpleLinear().cuda()) + assert types_optout, "expected at least one weight quantizer with a calibrator" + assert all(t is NVFP4MSECalibrator for t in types_optout), types_optout + + +def _bench(fn, warmup=2, iters=5): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) / iters + + +@requires_triton +def test_speedup_report(capsys): + """Sanity-check that the Triton path is meaningfully faster on a realistic weight. + + Uses an 8192 x 4096 weight (~33M elements, ~2M NVFP4 blocks) — roughly the size + of an LLM attention/MLP projection. Reports the speedup; does not gate on a + minimum factor (kernel timing is noisy on shared CI), but does require parity + on the chosen amax. + """ + torch.manual_seed(123) + device = "cuda" + cout, cin = 8192, 4096 + x = torch.randn(cout, cin // BLOCK_SIZE, BLOCK_SIZE, device=device, dtype=torch.float32) + x = x.reshape(-1, BLOCK_SIZE) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + # Bit-equality across millions of blocks isn't guaranteed: when two adjacent FP8 + # candidates yield near-identical per-block MSE (within fp32 noise), the reference's + # CUDA fake_e4m3fy path and our Triton inline math can break ties differently. Demand + # instead that the Triton choice produces a per-block MSE within fp32 epsilon of the + # reference's choice. 
+ n_blocks = ref_amax.numel() + n_diff = int((ref_amax != tri_amax).sum()) + if n_diff: + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + per_block_mse_ref = (x - ref_xq).pow(2).sum(dim=-1) + per_block_mse_tri = (x - tri_xq).pow(2).sum(dim=-1) + # Reference is the formal argmin, so triton's loss should be ≥ reference's. + # Allow at most 1e-5 relative gap on differing blocks (observed ~1e-7 in practice). + rel_gap = (per_block_mse_tri - per_block_mse_ref).abs() / per_block_mse_ref.clamp_min(1e-12) + worst = rel_gap.max().item() + assert worst < 1e-5, ( + f"{n_diff}/{n_blocks} blocks disagree with worst relative MSE gap {worst:.3e} " + "— exceeds tie-break tolerance" + ) + + ref_t = _bench(lambda: _run_reference(x, per_block_amax, global_amax)) + tri_t = _bench(lambda: _run_triton(x, per_block_amax, global_amax)) + speedup = ref_t / tri_t + + # Force-print regardless of pytest capture mode. + with capsys.disabled(): + n_blocks = x.numel() // BLOCK_SIZE + print( + f"\n[NVFP4 FP8 sweep] weight=({cout},{cin}) " + f"n_blocks={n_blocks} block_size={BLOCK_SIZE}\n" + f" reference NVFP4MSECalibrator: {ref_t * 1e3:8.2f} ms\n" + f" triton TritonNVFP4MSECalibrator: {tri_t * 1e3:8.2f} ms\n" + f" speedup: {speedup:.1f}x" + )