From 9a19dce29c9f231fcdc1370095c502b2cee3a93d Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 6 Jun 2025 01:39:00 +0000 Subject: [PATCH 1/6] Add FP8 current scaling to te.Sequential tests Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 243 ++++++++++-------- tests/pytorch/utils.py | 22 ++ .../pytorch/ops/basic/basic_linear.py | 28 +- 3 files changed, 179 insertions(+), 114 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index a0c0ee5faa..2caeab1a1c 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,6 +7,8 @@ from collections.abc import Iterable import io import math +import pathlib +import sys from typing import Optional import pytest @@ -24,10 +26,16 @@ ForwardLinearBiasAdd, ) from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8CurrentScalingQuantizer, Float8Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent)) +from utils import dtype_tols, make_recipe + # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -40,6 +48,13 @@ # Supported devices _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] +# Supported quantization recipes +_quantization_schemes: list[Optional[str]] = [None] +if fp8_available: + _quantization_schemes.extend(("fp8_delayed_scaling", "fp8_current_scaling")) +if mxfp8_available: + _quantization_schemes.append("mxfp8") + def maybe_skip_quantization( quantization: Optional[str], @@ -53,7 +68,10 @@ def maybe_skip_quantization( return # Check if quantization scheme is supported - if quantization == "fp8" and not fp8_available: + if ( + quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") + and not fp8_available + ): pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) @@ -61,7 +79,7 @@ def maybe_skip_quantization( if dims is not None: if not isinstance(dims, Iterable): dims = (dims,) - if quantization == "fp8": + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("FP8 GEMMs require dims that are divisible by 16") elif quantization == "mxfp8": @@ -73,39 +91,6 @@ def maybe_skip_quantization( pytest.skip("Quantization is only supported on CUDA devices") -def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: - """Estimated numerical error for a datatype - - Based on tolerances for torch.testing.assert_close. - - """ - - # Transformer Engine dtypes - if isinstance(dtype, tex.DType): - if dtype == tex.DType.kFloat8E4M3: - return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype == tex.DType.kFloat8E5M2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 - dtype = { - tex.DType.kByte: torch.uint8, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - }[dtype] - - # PyTorch dtypes - if dtype == torch.float16: - return dict(rtol=1e-3, atol=1e-5) - if dtype == torch.bfloat16: - return dict(rtol=1.6e-2, atol=1e-5) - if dtype == torch.float32: - return dict(rtol=1.3e-6, atol=1e-5) - if dtype == torch.float64: - return dict(rtol=1e-7, atol=1e-7) - raise ValueError(f"Unsupported dtype ({dtype})") - - @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], @@ -113,7 +98,7 @@ def make_reference_and_test_tensors( ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_quantization: Optional[str] = None, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -123,38 +108,40 @@ def make_reference_and_test_tensors( in Transformer Engine operations. """ + + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if test_quantization is None: + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif test_quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif test_quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, device=test_device, + ) + test = quantizer(test) + elif test_quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({test_quantization})") + + # Make sure reference and test tensors match each other ref.copy_(test) + ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - class TestSequential: """Tests for sequential container""" @@ -364,7 +351,7 @@ def test_fp8_scale_update( @pytest.mark.parametrize("init_dtype", _dtypes) @pytest.mark.parametrize("final_dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_dtype_cast( self, *, @@ -390,7 +377,7 @@ def test_dtype_cast( (size, size), test_dtype=dtype, test_device=device, - test_is_fp8=with_quantization, + test_quantization=quantization, ) # Construct operation @@ -412,7 +399,7 @@ def test_dtype_cast( assert isinstance(op.weight, QuantizedTensor) == with_quantization assert op.weight.dtype == final_dtype w_test = op.weight.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0) + torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype)) # Check forward and backward pass x = torch.zeros( @@ -429,7 +416,7 @@ def test_dtype_cast( @pytest.mark.parametrize("model_dtype", _dtypes) @pytest.mark.parametrize("autocast_dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_pyt_autocast( self, *, @@ -513,7 +500,7 @@ def test_identity( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_quantization="fp8_current_scaling" if fp8 else None, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, @@ -579,7 +566,7 @@ def test_reshape( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_quantization="fp8_current_scaling" if fp8 else None, ) x_test = x_test.contiguous(memory_format=memory_format) x_test = x_test.detach().requires_grad_() @@ -643,7 +630,7 @@ def test_bias( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_quantization="fp8_current_scaling" if fp8 else None, ) b_ref, b_test = make_reference_and_test_tensors( size, @@ -678,7 +665,7 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) @pytest.mark.parametrize("cast_forward", (False, True)) @pytest.mark.parametrize("cast_backward", (False, True)) def test_quantize( @@ -694,6 +681,7 @@ def test_quantize( """Quantize""" # Skip invalid configurations + with_quantization = quantization is not None maybe_skip_quantization(quantization) # Random data @@ -701,18 +689,21 @@ def test_quantize( in_shape, test_dtype=dtype, test_device=device, - requires_grad=False, - test_is_fp8=True, + test_quantization=quantization, + requires_grad=True, ) - x_test = x_test.dequantize().requires_grad_() + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, + test_quantization=quantization, requires_grad=False, - test_is_fp8=True, ) - dy_test = dy_test.dequantize() + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = x_ref @@ -721,13 +712,14 @@ def test_quantize( # Implementation with fusible operation op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) recipe = make_recipe(quantization) - with te.fp8_autocast(fp8_recipe=recipe): + with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) # Check tensor types - assert isinstance(y_test, QuantizedTensor) == cast_forward - assert isinstance(x_test.grad, QuantizedTensor) == cast_backward + if with_quantization: + assert isinstance(y_test, QuantizedTensor) == cast_forward + assert isinstance(x_test.grad, QuantizedTensor) == cast_backward # Check values tols = dict(rtol=0, atol=0) @@ -762,10 +754,23 @@ def _test_basic_linear( # Skip invalid configurations maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=out_shape) - if quantization == "fp8" and quantized_output and not quantized_compute: - pytest.skip("FP8 output is only supported with FP8 GEMMs") - if quantization == "fp8" and quantized_grad_input and not quantized_compute: - pytest.skip("FP8 grad input is only supported with FP8 GEMMs") + quantization_needed = any(( + quantized_compute, + quantized_input, + quantized_weight, + quantized_output, + quantized_grad_output, + quantized_grad_input, + )) + if quantization is None and quantization_needed: + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not quantization_needed: + pytest.skip("Quantization scheme is not used") + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): + if quantized_output and not quantized_compute: + pytest.skip("FP8 output is only supported with FP8 GEMMs") + if quantized_grad_input and not quantized_compute: + pytest.skip("FP8 grad input is only supported with FP8 GEMMs") if quantization == "mxfp8" and quantized_output: pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") if quantization == "mxfp8" and quantized_grad_input: @@ -776,25 +781,25 @@ def _test_basic_linear( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_input), + test_quantization=quantization, ) - if isinstance(x_test, QuantizedTensor): + if isinstance(x_test, QuantizedTensor) and not quantized_input: with torch.no_grad(): x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), + test_quantization=quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_grad_output), + test_quantization=quantization, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): + if isinstance(dy_test, QuantizedTensor) and not quantized_grad_output: dy_test = dy_test.dequantize() # Plain PyTorch implementation @@ -858,7 +863,7 @@ def _test_basic_linear( @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) def test_basic_linear( self, @@ -880,7 +885,7 @@ def test_basic_linear( ) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_input", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) @@ -899,6 +904,8 @@ def test_basic_linear_quantized( quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" + if quantization is None: + pytest.skip("Skipping case without quantization") self._test_basic_linear( dtype=torch.bfloat16, quantization=quantization, @@ -911,7 +918,8 @@ def test_basic_linear_quantized( ) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("weight_requires_grad", (False, True)) @@ -924,6 +932,7 @@ def test_linear( dtype: torch.dtype = torch.float32, device: torch.device = "cuda", quantization: Optional[str], + quantized_compute: bool, quantized_weight: bool, input_requires_grad: bool, weight_requires_grad: bool, @@ -936,16 +945,19 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, + test_quantization=quantization, ) with torch.no_grad(): if isinstance(x_test, QuantizedTensor): @@ -955,7 +967,7 @@ def test_linear( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), + test_quantization=quantization, ) b_ref, b_test = None, None if bias: @@ -968,8 +980,11 @@ def test_linear( out_shape, test_dtype=dtype, test_device=device, + test_quantization=quantization, requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) @@ -1022,7 +1037,7 @@ def test_linear( @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_layer_norm( self, *, @@ -1192,7 +1207,7 @@ def test_layer_norm_autocast( @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_rmsnorm( self, *, @@ -1301,13 +1316,13 @@ def test_add_in_place( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_quantization="fp8_current_scaling" if fp8 else None, ) x2_ref, x2_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_quantization="fp8_current_scaling" if fp8 else None, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, @@ -1366,7 +1381,7 @@ def test_make_extra_output( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_quantization="fp8_current_scaling" if fp8 else None, ) dy1_ref, dy1_test = make_reference_and_test_tensors( in_shape, @@ -1403,7 +1418,7 @@ def test_make_extra_output( @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) @pytest.mark.parametrize("cache_quantized_input", (False, True)) def test_activation( self, @@ -1426,26 +1441,24 @@ def test_activation( quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) if cache_quantized_input: - maybe_skip_quantization("fp8", device=device) + maybe_skip_quantization("fp8_current_scaling", device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, + test_quantization="fp8_current_scaling" if cache_quantized_input else None, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref: torch.Tensor @@ -1488,8 +1501,6 @@ def test_activation( tols = dtype_tols(dtype) if quantized_compute or cache_quantized_input: tols = dtype_tols(tex.DType.kFloat8E4M3) - if activation == "relu" and not cache_quantized_input: - tols = {"atol": 0, "rtol": 0} # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1498,7 +1509,7 @@ def test_activation( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) @pytest.mark.parametrize("quantize_forward", (False, True)) @pytest.mark.parametrize("quantize_backward", (False, True)) def test_swiglu( @@ -1576,7 +1587,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) @pytest.mark.parametrize("quantized_weight", (False, True)) def test_forward_linear_bias_activation( self, @@ -1610,16 +1621,16 @@ def test_forward_linear_bias_activation( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, + test_quantization=quantization, ) - if quantized_compute: + if isinstance(x_test, QuantizedTensor): with torch.no_grad(): x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), + test_quantization=quantization, ) b_ref, b_test = None, None if bias: @@ -1632,8 +1643,11 @@ def test_forward_linear_bias_activation( out_shape, test_dtype=dtype, test_device=device, + test_quantization=quantization, requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) @@ -1686,7 +1700,7 @@ def test_forward_linear_bias_activation( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_forward_linear_bias_add( self, *, @@ -1717,7 +1731,7 @@ def test_forward_linear_bias_add( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, + test_quantization=quantization, ) if isinstance(x1_test, QuantizedTensor): with torch.no_grad(): @@ -1726,7 +1740,7 @@ def test_forward_linear_bias_add( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), + test_quantization=quantization, ) b_ref, b_test = None, None if bias: @@ -1744,6 +1758,7 @@ def test_forward_linear_bias_add( out_shape, test_dtype=dtype, test_device=device, + test_quantization=quantization, requires_grad=False, ) @@ -1800,7 +1815,7 @@ def test_forward_linear_bias_add( torch.testing.assert_close(db_test, b_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_backward_linear_add( self, *, @@ -1830,7 +1845,7 @@ def test_backward_linear_add( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, + test_quantization=quantization, ) if isinstance(x_test, QuantizedTensor): with torch.no_grad(): @@ -1839,18 +1854,20 @@ def test_backward_linear_add( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), + test_quantization=quantization, ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, + test_quantization=quantization, requires_grad=False, ) dy2_ref, dy2_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, + test_quantization=quantization, requires_grad=False, ) @@ -1912,7 +1929,7 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_schemes) @pytest.mark.parametrize("quantized_weight", (False, True)) def test_linear( self, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 450c24da33..f4a8ce69c6 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -7,6 +7,7 @@ import torch import transformer_engine +import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine_torch as tex @@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: if dtype == torch.float8_e5m2: return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 raise ValueError(f"Unsupported dtype ({dtype})") + + +def make_recipe(name: Optional[str]) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name in ("fp8", "fp8_delayed_scaling"): + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "fp8_current_scaling": + return transformer_engine.common.recipe.Float8CurrentScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "fp8_block_scaling": + return transformer_engine.common.recipe.Float8BlockScaling() + raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 0e786ca96f..008025496d 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -22,7 +22,7 @@ from ...fp8 import FP8GlobalStateManager from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer, QuantizedTensor -from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase @@ -324,6 +324,32 @@ def pre_forward(self, *args, **kwargs) -> None: weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + # Recipe-specific configuration + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + if any( + not isinstance(q, Float8CurrentScalingQuantizer) + for q in (input_quantizer, weight_quantizer, grad_output_quantizer) + ): + raise RuntimeError( + "FP8 current-scaling recipe is enabled, " + f"but input quantizer is {input_quantizer.__class__.__name__}, " + f"weight quantizer is {weight_quantizer.__class__.__name__}, " + f"grad output quantizer is {grad_output_quantizer.__class__.__name__}" + ) + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + if self.sequence_parallel and self.tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + if self.sequence_parallel and self.tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization # recipe changed From d5d65de4bc34732cb5d5b4ea8eafeda711f7bb6c Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 6 Jun 2025 02:15:53 +0000 Subject: [PATCH 2/6] Helper function for test/ref tensors does not produce quantized tensor by default Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 181 ++++++++++++++---------------- 1 file changed, 85 insertions(+), 96 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 2caeab1a1c..12bd00651b 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -62,6 +62,7 @@ def maybe_skip_quantization( dims: Optional[Iterable[int] | int] = None, device: Optional[torch.device | str] = None, ) -> None: + """Skip test case if a quantization scheme is not supported""" # Don't skip if there is no quantization if quantization is None: @@ -94,11 +95,12 @@ def maybe_skip_quantization( @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_quantization: Optional[str] = None, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -107,6 +109,9 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + """ # Random reference tensor @@ -114,25 +119,29 @@ def make_reference_and_test_tensors( # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_quantization is None: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") if test.data_ptr() == ref.data_ptr(): test = test.clone() - elif test_quantization in ("fp8", "fp8_delayed_scaling"): + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test_quantization == "fp8_current_scaling": + elif quantization == "fp8_current_scaling": quantizer = Float8CurrentScalingQuantizer( fp8_dtype=tex.DType.kFloat8E4M3, device=test_device, ) test = quantizer(test) - elif test_quantization == "mxfp8": + elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) else: - raise ValueError(f"Unsupported quantization scheme ({test_quantization})") + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() # Make sure reference and test tensors match each other ref.copy_(test) @@ -364,8 +373,8 @@ def test_dtype_cast( """Check dtype cast functions""" # Skip invalid configurations - maybe_skip_quantization(quantization, device=device) with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device) # Random data dtype = torch.float32 @@ -375,9 +384,9 @@ def test_dtype_cast( dtype = torch.bfloat16 w_ref, w_test = make_reference_and_test_tensors( (size, size), + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) # Construct operation @@ -432,7 +441,7 @@ def test_pyt_autocast( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization) + maybe_skip_quantization(quantization, device=device) # Construct operation recipe = make_recipe(quantization) @@ -479,33 +488,34 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_identity( self, *, in_shape: Iterable[int] = (1,), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization="fp8_current_scaling" if fp8 else None, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -541,7 +551,7 @@ def test_identity( ), ) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling")) def test_reshape( self, *, @@ -549,31 +559,32 @@ def test_reshape( dtype: torch.dtype, device: torch.device = "cuda", memory_format: torch.memory_format = torch.contiguous_format, - fp8: bool, + quantization: Optional[str], ) -> None: in_shape, out_shape = shapes # Skip invalid configurations if memory_format == torch.channels_last and len(in_shape) != 4: pytest.skip("torch.channels_last only supports 4D tensors") - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, device=device) + with_quantization = quantization is not None # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization="fp8_current_scaling" if fp8 else None, + test_is_quantized=with_quantization, ) x_test = x_test.contiguous(memory_format=memory_format) x_test = x_test.detach().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( x_ref.reshape(out_shape).size(), + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -605,7 +616,7 @@ def test_reshape( @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", _devices) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_bias( self, *, @@ -613,24 +624,23 @@ def test_bias( in_shape: Iterable[int], dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: # Make input and bias shapes consistent in_shape = list(in_shape)[:-1] + [size] # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization="fp8_current_scaling" if fp8 else None, + test_is_quantized=with_quantization, ) b_ref, b_test = make_reference_and_test_tensors( size, @@ -639,8 +649,10 @@ def test_bias( ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -682,28 +694,23 @@ def test_quantize( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization) + maybe_skip_quantization(quantization, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, requires_grad=True, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = x_ref @@ -779,28 +786,25 @@ def _test_basic_linear( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, + test_is_quantized=quantized_input, ) - if isinstance(x_test, QuantizedTensor) and not quantized_input: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, + test_is_quantized=quantized_grad_output, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor) and not quantized_grad_output: - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -955,19 +959,15 @@ def test_linear( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) - with torch.no_grad(): - if isinstance(x_test, QuantizedTensor): - x_test = x_test.dequantize() - x_test.requires_grad_(requires_grad=input_requires_grad) w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) b_ref, b_test = None, None if bias: @@ -978,13 +978,11 @@ def test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) @@ -1290,14 +1288,14 @@ def test_rmsnorm( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_add_in_place( self, *, in_shape: Iterable[int] = (1,), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: """Add two tensors @@ -1306,28 +1304,30 @@ def test_add_in_place( """ # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device) # Random data x1_ref, x1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization="fp8_current_scaling" if fp8 else None, + test_is_quantized=with_quantization, ) x2_ref, x2_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization="fp8_current_scaling" if fp8 else None, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -1344,7 +1344,7 @@ def test_add_in_place( # Check results tols = dtype_tols(dtype) - if fp8: + if with_quantization: tols = dtype_tols(x1_test._fp8_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1355,14 +1355,14 @@ def test_add_in_place( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_schemes) def test_make_extra_output( self, *, in_shape: Iterable[int] = (1,), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: """Output tensor twice @@ -1371,28 +1371,31 @@ def test_make_extra_output( """ # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization="fp8_current_scaling" if fp8 else None, + test_is_quantized=with_quantization, ) dy1_ref, dy1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) dy2_ref, dy2_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -1446,13 +1449,10 @@ def test_activation( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization="fp8_current_scaling" if cache_quantized_input else None, test_dtype=dtype, test_device=device, - test_quantization="fp8_current_scaling" if cache_quantized_input else None, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, @@ -1619,18 +1619,15 @@ def test_forward_linear_bias_activation( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) b_ref, b_test = None, None if bias: @@ -1641,13 +1638,11 @@ def test_forward_linear_bias_activation( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) @@ -1729,18 +1724,15 @@ def test_forward_linear_bias_add( # Random data x1_ref, x1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) - if isinstance(x1_test, QuantizedTensor): - with torch.no_grad(): - x1_test = x1_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) b_ref, b_test = None, None if bias: @@ -1756,9 +1748,9 @@ def test_forward_linear_bias_add( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, requires_grad=False, ) @@ -1843,31 +1835,28 @@ def test_backward_linear_add( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, requires_grad=False, ) dy2_ref, dy2_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_quantization=quantization, requires_grad=False, ) From 3c06f31e6547a8e6969ab2b71b669f568d028b62 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 6 Jun 2025 21:26:03 +0000 Subject: [PATCH 3/6] Add FP8 current scaling to distributed te.Sequential tests Signed-off-by: Tim Moon --- tests/pytorch/distributed/test_fusible_ops.py | 160 ++++++++---------- .../test_fusible_ops_with_userbuffers.py | 17 +- tests/pytorch/test_fusible_ops.py | 42 ++--- transformer_engine/pytorch/distributed.py | 2 +- 4 files changed, 96 insertions(+), 125 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 472d20c508..fc7abb96c8 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -22,19 +22,24 @@ import transformer_engine.pytorch as te from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import dtype_tols, make_recipe + # Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: - quantization_list.append("fp8") + quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") @@ -63,11 +68,12 @@ def reset_rng(seed: int = 1234) -> None: @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -76,78 +82,54 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + """ + + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device), + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() + + # Make sure reference and test tensors match each other ref.copy_(test) + ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: - """Estimated numerical error for a datatype - - Based on tolerances for torch.testing.assert_close. - - """ - - # Transformer Engine dtypes - if isinstance(dtype, tex.DType): - if dtype == tex.DType.kFloat8E4M3: - return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype == tex.DType.kFloat8E5M2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 - dtype = { - tex.DType.kByte: torch.uint8, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - }[dtype] - - # PyTorch dtypes - if dtype == torch.float16: - return dict(rtol=1e-3, atol=1e-5) - if dtype == torch.bfloat16: - return dict(rtol=1.6e-2, atol=1e-5) - if dtype == torch.float32: - return dict(rtol=1.3e-6, atol=1e-5) - if dtype == torch.float64: - return dict(rtol=1e-7, atol=1e-7) - raise ValueError(f"Unsupported dtype ({dtype})") - - -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - def _test_all_reduce( *, local_size: int = 17, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -161,17 +143,20 @@ def _test_all_reduce( # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation @@ -202,7 +187,7 @@ def _test_all_gather( local_size: int = 13, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -216,17 +201,20 @@ def _test_all_gather( # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation @@ -260,7 +248,7 @@ def _test_reduce_scatter( local_size: int = 11, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -274,17 +262,20 @@ def _test_reduce_scatter( # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation @@ -324,7 +315,11 @@ def _test_basic_linear( tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + + # Skip invalid configurations quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return # Distributed process group process_group = world_group() @@ -348,30 +343,23 @@ def _test_basic_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -468,7 +456,11 @@ def _test_linear( tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + + # Skip invalid configurations quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return # Distributed process group process_group = world_group() @@ -492,21 +484,16 @@ def _test_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -520,13 +507,11 @@ def _test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -773,9 +758,10 @@ def run_parallel_tests() -> None: if rank == 0: print(f"Running _test_all_reduce") _test_all_reduce() - if rank == 0: - print(f"Running _test_all_gather") - _test_all_gather() + for quantization in quantization_list: + if rank == 0: + print(f"Running _test_all_gather with quantization={quantization}") + _test_all_gather(quantization=quantization) if rank == 0: print(f"Running _test_reduce_scatter") _test_reduce_scatter() diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 42070ea0f4..4862440110 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -33,7 +33,7 @@ # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import dtype_tols, str_to_dtype +from utils import dtype_tols, make_recipe, str_to_dtype # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -157,21 +157,6 @@ def make_reference_and_test_tensors( return ref, test -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - def _test_linear( *, model_config: ModelConfig, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 12bd00651b..1ad7755f88 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -49,11 +49,11 @@ _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] # Supported quantization recipes -_quantization_schemes: list[Optional[str]] = [None] +_quantization_list: list[Optional[str]] = [None] if fp8_available: - _quantization_schemes.extend(("fp8_delayed_scaling", "fp8_current_scaling")) + _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: - _quantization_schemes.append("mxfp8") + _quantization_list.append("mxfp8") def maybe_skip_quantization( @@ -360,7 +360,7 @@ def test_fp8_scale_update( @pytest.mark.parametrize("init_dtype", _dtypes) @pytest.mark.parametrize("final_dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_dtype_cast( self, *, @@ -425,7 +425,7 @@ def test_dtype_cast( @pytest.mark.parametrize("model_dtype", _dtypes) @pytest.mark.parametrize("autocast_dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_pyt_autocast( self, *, @@ -488,7 +488,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_identity( self, *, @@ -616,7 +616,7 @@ def test_reshape( @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", _devices) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_bias( self, *, @@ -677,7 +677,7 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("cast_forward", (False, True)) @pytest.mark.parametrize("cast_backward", (False, True)) def test_quantize( @@ -867,7 +867,7 @@ def _test_basic_linear( @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) def test_basic_linear( self, @@ -889,7 +889,7 @@ def test_basic_linear( ) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_input", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) @@ -922,7 +922,7 @@ def test_basic_linear_quantized( ) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("input_requires_grad", (False, True)) @@ -1035,7 +1035,7 @@ def test_linear( @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_layer_norm( self, *, @@ -1205,7 +1205,7 @@ def test_layer_norm_autocast( @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_rmsnorm( self, *, @@ -1288,7 +1288,7 @@ def test_rmsnorm( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_add_in_place( self, *, @@ -1355,7 +1355,7 @@ def test_add_in_place( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_make_extra_output( self, *, @@ -1421,7 +1421,7 @@ def test_make_extra_output( @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("cache_quantized_input", (False, True)) def test_activation( self, @@ -1509,7 +1509,7 @@ def test_activation( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantize_forward", (False, True)) @pytest.mark.parametrize("quantize_backward", (False, True)) def test_swiglu( @@ -1587,7 +1587,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_weight", (False, True)) def test_forward_linear_bias_activation( self, @@ -1695,7 +1695,7 @@ def test_forward_linear_bias_activation( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_forward_linear_bias_add( self, *, @@ -1807,7 +1807,7 @@ def test_forward_linear_bias_add( torch.testing.assert_close(db_test, b_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_backward_linear_add( self, *, @@ -1918,7 +1918,7 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("quantization", _quantization_schemes) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_weight", (False, True)) def test_linear( self, diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 55246a3d1d..868fc3a27a 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -947,7 +947,7 @@ def _all_gather_fp8( out = quantizer.make_empty(out_shape, dtype=dtype, device=device) elif isinstance(inp, Float8Tensor): out = inp.make_like(inp, shape=out_shape) - out._data = torch.empty_like( + out._data = torch.empty( out_shape, dtype=torch.uint8, device=inp.device, From f5ac3ad64eee25e2e8a3a99cef308100aecd5b41 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 6 Jun 2025 21:42:09 +0000 Subject: [PATCH 4/6] Add FP8 current scaling to Userbuffers te.Sequential tests Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 52 +++++++++++-------- .../pytorch/ops/basic/basic_linear.py | 5 +- .../ops/fused/userbuffers_forward_linear.py | 12 +++-- 3 files changed, 42 insertions(+), 27 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 4862440110..6e04a1e363 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -26,7 +26,7 @@ UserbuffersBackwardLinear, UserbuffersForwardLinear, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.utils import is_bf16_compatible @@ -40,7 +40,7 @@ mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: - quantization_list.append("fp8") + quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") @@ -118,11 +118,12 @@ def reset_rng(seed: int = 1234) -> None: @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -131,27 +132,43 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + """ - # Random data + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) - # Make copy of tensor + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device), + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() - # Make sure reference and test tensors represent exact same values + # Make sure reference and test tensors match each other ref.copy_(test) - # Return reference and test tensors ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test @@ -186,21 +203,16 @@ def _test_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -214,13 +226,11 @@ def _test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 008025496d..bbad21200c 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -353,8 +353,9 @@ def pre_forward(self, *args, **kwargs) -> None: # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization # recipe changed - if isinstance(weight_quantizer, Float8Quantizer) and isinstance( - weight, Float8TensorBase + if ( + isinstance(weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + and isinstance(weight, Float8TensorBase) ): weight._quantizer = weight_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index c35d029403..0078f7ae65 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -21,7 +21,7 @@ _2X_ACC_FPROP, ) from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer -from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...utils import canonicalize_device, canonicalize_dtype from ..basic import BasicLinear, Bias, ReduceScatter @@ -208,7 +208,9 @@ def _functional_forward( if input_quantizer is not None: if not isinstance(x_local, QuantizedTensorBase): input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - if isinstance(input_quantizer, Float8Quantizer): + if isinstance( + input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): input_quantizer.set_usage(columnwise=False) x_local = input_quantizer(x_local) input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -327,8 +329,10 @@ def fuser_forward( grad_input_quantizer = None if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() - if not recipe.delayed() and not recipe.mxfp8(): - raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe") + if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): + raise RuntimeError( + f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})" + ) input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) grad_output_quantizer = linear_op.get_quantizer("backward", 0) From cbb9d7b736b7335b6734c3a123f4099d7e285e71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Jun 2025 21:50:04 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/distributed/test_fusible_ops.py | 8 +++-- .../test_fusible_ops_with_userbuffers.py | 8 +++-- tests/pytorch/test_fusible_ops.py | 32 +++++++++++-------- .../pytorch/ops/basic/basic_linear.py | 7 ++-- 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index fc7abb96c8..6eab47488f 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -22,7 +22,10 @@ import transformer_engine.pytorch as te from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.utils import is_bf16_compatible @@ -106,7 +109,8 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "fp8_current_scaling": quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, device=test_device, + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, ) test = quantizer(test) elif quantization == "mxfp8": diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 6e04a1e363..883e7f4645 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -26,7 +26,10 @@ UserbuffersBackwardLinear, UserbuffersForwardLinear, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.utils import is_bf16_compatible @@ -156,7 +159,8 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "fp8_current_scaling": quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, device=test_device, + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, ) test = quantizer(test) elif quantization == "mxfp8": diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1ad7755f88..b0cb288376 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -26,7 +26,11 @@ ForwardLinearBiasAdd, ) from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8CurrentScalingQuantizer, Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8CurrentScalingQuantizer, + Float8Quantizer, +) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex @@ -69,10 +73,7 @@ def maybe_skip_quantization( return # Check if quantization scheme is supported - if ( - quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") - and not fp8_available - ): + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) @@ -133,7 +134,8 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "fp8_current_scaling": quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, device=test_device, + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, ) test = quantizer(test) elif quantization == "mxfp8": @@ -761,14 +763,16 @@ def _test_basic_linear( # Skip invalid configurations maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=out_shape) - quantization_needed = any(( - quantized_compute, - quantized_input, - quantized_weight, - quantized_output, - quantized_grad_output, - quantized_grad_input, - )) + quantization_needed = any( + ( + quantized_compute, + quantized_input, + quantized_weight, + quantized_output, + quantized_grad_output, + quantized_grad_input, + ) + ) if quantization is None and quantization_needed: pytest.skip("Quantization scheme is not specified") if quantization is not None and not quantization_needed: diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index bbad21200c..4bd94b3ad2 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -353,10 +353,9 @@ def pre_forward(self, *args, **kwargs) -> None: # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization # recipe changed - if ( - isinstance(weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) - and isinstance(weight, Float8TensorBase) - ): + if isinstance( + weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ) and isinstance(weight, Float8TensorBase): weight._quantizer = weight_quantizer @staticmethod From 5a477ece6c07d3463be10b1b5084b5b44084369c Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 12 Jun 2025 21:26:41 +0000 Subject: [PATCH 6/6] Debug MXFP8 tests Signed-off-by: Tim Moon --- tests/pytorch/distributed/test_fusible_ops.py | 21 ++++++------- .../test_fusible_ops_with_userbuffers.py | 1 + tests/pytorch/test_fusible_ops.py | 30 +++++++++++-------- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 6eab47488f..6f025817df 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -26,6 +26,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.utils import is_bf16_compatible @@ -130,7 +131,7 @@ def make_reference_and_test_tensors( def _test_all_reduce( *, - local_size: int = 17, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", quantization: Optional[str] = None, @@ -142,8 +143,8 @@ def _test_all_reduce( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, local_size] - out_shape = [local_size] + in_shape = [world_size, local_size, local_size] + out_shape = [local_size, local_size] # Random data reset_rng() @@ -188,7 +189,7 @@ def _test_all_reduce( def _test_all_gather( *, - local_size: int = 13, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", quantization: Optional[str] = None, @@ -200,8 +201,8 @@ def _test_all_gather( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, local_size] - out_shape = [world_size, world_size * local_size] + in_shape = [world_size, local_size, local_size] + out_shape = [world_size, world_size * local_size, local_size] # Random data reset_rng() @@ -222,7 +223,7 @@ def _test_all_gather( ) # Plain PyTorch implementation - y_ref = x_ref.tile((world_size, 1)).reshape(out_shape) + y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape) y_ref.backward(dy_ref) # Convert to distributed tensors @@ -249,7 +250,7 @@ def _test_all_gather( def _test_reduce_scatter( *, - local_size: int = 11, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", quantization: Optional[str] = None, @@ -261,8 +262,8 @@ def _test_reduce_scatter( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, world_size * local_size] - out_shape = [world_size, local_size] + in_shape = [world_size, world_size * local_size, local_size] + out_shape = [world_size, local_size, local_size] # Random data reset_rng() diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 883e7f4645..68083a0e03 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -30,6 +30,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.utils import is_bf16_compatible diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index b0cb288376..7c825ac0f2 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -31,7 +31,7 @@ Float8CurrentScalingQuantizer, Float8Quantizer, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex @@ -375,8 +375,9 @@ def test_dtype_cast( """Check dtype cast functions""" # Skip invalid configurations + in_shape = (size, size) with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data dtype = torch.float32 @@ -414,7 +415,7 @@ def test_dtype_cast( # Check forward and backward pass x = torch.zeros( - (size, size), + in_shape, dtype=init_dtype, device=device, requires_grad=True, @@ -442,8 +443,9 @@ def test_pyt_autocast( device = torch.device(device) # Skip invalid configurations + in_shape = (size, size) quantized_compute = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Construct operation recipe = make_recipe(quantization) @@ -452,7 +454,7 @@ def test_pyt_autocast( # Check forward and backward pass x = torch.zeros( - (size, size), + in_shape, dtype=model_dtype, device=device, requires_grad=True, @@ -494,7 +496,7 @@ def setup_class(cls) -> None: def test_identity( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, quantization: Optional[str], @@ -502,7 +504,7 @@ def test_identity( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -615,7 +617,7 @@ def test_reshape( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("size", (1, 7, 32)) - @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1))) + @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", _devices) @pytest.mark.parametrize("quantization", _quantization_list) @@ -634,7 +636,7 @@ def test_bias( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -697,6 +699,8 @@ def test_quantize( # Skip invalid configurations with_quantization = quantization is not None maybe_skip_quantization(quantization, device=device) + if quantization == "mxfp8": + maybe_skip_quantization(quantization, dims=in_shape) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1296,7 +1300,7 @@ def test_rmsnorm( def test_add_in_place( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, quantization: Optional[str], @@ -1309,7 +1313,7 @@ def test_add_in_place( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x1_ref, x1_test = make_reference_and_test_tensors( @@ -1363,7 +1367,7 @@ def test_add_in_place( def test_make_extra_output( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, quantization: Optional[str], @@ -1376,7 +1380,7 @@ def test_make_extra_output( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors(