diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 472d20c508..6f025817df 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -22,19 +22,28 @@ import transformer_engine.pytorch as te from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import dtype_tols, make_recipe + # Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: - quantization_list.append("fp8") + quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") @@ -63,11 +72,12 @@ def reset_rng(seed: int = 1234) -> None: @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: 
"""Construct tensors with the same values @@ -76,78 +86,55 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + """ + + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device), + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() + + # Make sure reference and test tensors match each other ref.copy_(test) + ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: - """Estimated numerical error for a datatype - - Based on tolerances for torch.testing.assert_close. 
- - """ - - # Transformer Engine dtypes - if isinstance(dtype, tex.DType): - if dtype == tex.DType.kFloat8E4M3: - return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype == tex.DType.kFloat8E5M2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 - dtype = { - tex.DType.kByte: torch.uint8, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - }[dtype] - - # PyTorch dtypes - if dtype == torch.float16: - return dict(rtol=1e-3, atol=1e-5) - if dtype == torch.bfloat16: - return dict(rtol=1.6e-2, atol=1e-5) - if dtype == torch.float32: - return dict(rtol=1.3e-6, atol=1e-5) - if dtype == torch.float64: - return dict(rtol=1e-7, atol=1e-7) - raise ValueError(f"Unsupported dtype ({dtype})") - - -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - def _test_all_reduce( *, - local_size: int = 17, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -156,22 +143,25 @@ def _test_all_reduce( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, local_size] - out_shape = [local_size] + in_shape = [world_size, local_size, local_size] + out_shape = [local_size, local_size] # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, 
test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation @@ -199,10 +189,10 @@ def _test_all_reduce( def _test_all_gather( *, - local_size: int = 13, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -211,26 +201,29 @@ def _test_all_gather( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, local_size] - out_shape = [world_size, world_size * local_size] + in_shape = [world_size, local_size, local_size] + out_shape = [world_size, world_size * local_size, local_size] # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation - y_ref = x_ref.tile((world_size, 1)).reshape(out_shape) + y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape) y_ref.backward(dy_ref) # Convert to distributed tensors @@ -257,10 +250,10 @@ def _test_all_gather( def _test_reduce_scatter( *, - local_size: int = 11, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -269,22 +262,25 @@ def _test_reduce_scatter( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = 
[world_size, world_size * local_size] - out_shape = [world_size, local_size] + in_shape = [world_size, world_size * local_size, local_size] + out_shape = [world_size, local_size, local_size] # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation @@ -324,7 +320,11 @@ def _test_basic_linear( tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + + # Skip invalid configurations quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return # Distributed process group process_group = world_group() @@ -348,30 +348,23 @@ def _test_basic_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -468,7 +461,11 @@ def _test_linear( 
tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + + # Skip invalid configurations quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return # Distributed process group process_group = world_group() @@ -492,21 +489,16 @@ def _test_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -520,13 +512,11 @@ def _test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -773,9 +763,10 @@ def run_parallel_tests() -> None: if rank == 0: print(f"Running _test_all_reduce") _test_all_reduce() - if rank == 0: - print(f"Running _test_all_gather") - _test_all_gather() + for quantization in quantization_list: + if rank == 0: + print(f"Running _test_all_gather with quantization={quantization}") + _test_all_gather(quantization=quantization) if rank == 0: print(f"Running _test_reduce_scatter") _test_reduce_scatter() diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 42070ea0f4..68083a0e03 100644 --- 
a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -26,21 +26,25 @@ UserbuffersBackwardLinear, UserbuffersForwardLinear, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.utils import is_bf16_compatible # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import dtype_tols, str_to_dtype +from utils import dtype_tols, make_recipe, str_to_dtype # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: - quantization_list.append("fp8") + quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") @@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None: @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -131,47 +136,49 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. 
+ If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + """ - # Random data + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) - # Make copy of tensor + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device), + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() - # Make sure reference and test tensors represent exact same values + # Make sure reference and test tensors match each other ref.copy_(test) - # Return reference and test tensors ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return 
transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - def _test_linear( *, model_config: ModelConfig, @@ -201,21 +208,16 @@ def _test_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -229,13 +231,11 @@ def _test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index b1706db612..f78fa581b5 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,6 +7,8 @@ from collections.abc import Iterable import io import math +import pathlib +import sys from typing import Optional import pytest @@ -24,10 +26,20 @@ ForwardLinearBiasAdd, ) from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8CurrentScalingQuantizer, + Float8Quantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import 
MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent)) +from utils import dtype_tols, make_recipe + # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -40,6 +52,13 @@ # Supported devices _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] +# Supported quantization recipes +_quantization_list: list[Optional[str]] = [None] +if fp8_available: + _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) +if mxfp8_available: + _quantization_list.append("mxfp8") + def maybe_skip_quantization( quantization: Optional[str], @@ -47,13 +66,14 @@ def maybe_skip_quantization( dims: Optional[Iterable[int] | int] = None, device: Optional[torch.device | str] = None, ) -> None: + """Skip test case if a quantization scheme is not supported""" # Don't skip if there is no quantization if quantization is None: return # Check if quantization scheme is supported - if quantization == "fp8" and not fp8_available: + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) @@ -61,7 +81,7 @@ def maybe_skip_quantization( if dims is not None: if not isinstance(dims, Iterable): dims = (dims,) - if quantization == "fp8": + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("FP8 GEMMs require dims that are divisible by 16") elif quantization == "mxfp8": @@ -73,47 +93,15 @@ def maybe_skip_quantization( pytest.skip("Quantization is only supported on CUDA devices") -def dtype_tols(dtype: 
torch.dtype | tex.DType) -> dict[str, float]: - """Estimated numerical error for a datatype - - Based on tolerances for torch.testing.assert_close. - - """ - - # Transformer Engine dtypes - if isinstance(dtype, tex.DType): - if dtype == tex.DType.kFloat8E4M3: - return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype == tex.DType.kFloat8E5M2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 - dtype = { - tex.DType.kByte: torch.uint8, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - }[dtype] - - # PyTorch dtypes - if dtype == torch.float16: - return dict(rtol=1e-3, atol=1e-5) - if dtype == torch.bfloat16: - return dict(rtol=1.6e-2, atol=1e-5) - if dtype == torch.float32: - return dict(rtol=1.3e-6, atol=1e-5) - if dtype == torch.float64: - return dict(rtol=1e-7, atol=1e-7) - raise ValueError(f"Unsupported dtype ({dtype})") - - @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -122,39 +110,49 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. 
+ """ + + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() + + # Make sure reference and test tensors match each other ref.copy_(test) + ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - class TestSequential: """Tests for sequential container""" @@ -364,7 +362,7 @@ def test_fp8_scale_update( @pytest.mark.parametrize("init_dtype", _dtypes) 
@pytest.mark.parametrize("final_dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_dtype_cast( self, *, @@ -377,8 +375,9 @@ def test_dtype_cast( """Check dtype cast functions""" # Skip invalid configurations - maybe_skip_quantization(quantization, device=device) + in_shape = (size, size) with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data dtype = torch.float32 @@ -388,9 +387,9 @@ def test_dtype_cast( dtype = torch.bfloat16 w_ref, w_test = make_reference_and_test_tensors( (size, size), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=with_quantization, ) # Construct operation @@ -412,11 +411,11 @@ def test_dtype_cast( assert isinstance(op.weight, QuantizedTensor) == with_quantization assert op.weight.dtype == final_dtype w_test = op.weight.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0) + torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype)) # Check forward and backward pass x = torch.zeros( - (size, size), + in_shape, dtype=init_dtype, device=device, requires_grad=True, @@ -429,7 +428,7 @@ def test_dtype_cast( @pytest.mark.parametrize("model_dtype", _dtypes) @pytest.mark.parametrize("autocast_dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_pyt_autocast( self, *, @@ -444,8 +443,9 @@ def test_pyt_autocast( device = torch.device(device) # Skip invalid configurations + in_shape = (size, size) quantized_compute = quantization is not None - maybe_skip_quantization(quantization) + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Construct operation recipe = make_recipe(quantization) @@ -454,7 +454,7 @@ def test_pyt_autocast( # Check forward and backward pass x = torch.zeros( - 
(size, size), + in_shape, dtype=model_dtype, device=device, requires_grad=True, @@ -492,33 +492,34 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_identity( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -554,7 +555,7 @@ def test_identity( ), ) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling")) def test_reshape( self, *, @@ -562,31 +563,32 @@ def test_reshape( dtype: torch.dtype, device: torch.device = "cuda", memory_format: torch.memory_format = torch.contiguous_format, - fp8: bool, + quantization: Optional[str], ) -> None: in_shape, out_shape = shapes # Skip invalid configurations if memory_format == torch.channels_last and len(in_shape) != 4: pytest.skip("torch.channels_last only supports 4D tensors") - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - 
pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, device=device) + with_quantization = quantization is not None # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) x_test = x_test.contiguous(memory_format=memory_format) x_test = x_test.detach().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( x_ref.reshape(out_shape).size(), + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -615,10 +617,10 @@ def test_reshape( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("size", (1, 7, 32)) - @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1))) + @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", _devices) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_bias( self, *, @@ -626,24 +628,23 @@ def test_bias( in_shape: Iterable[int], dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: # Make input and bias shapes consistent in_shape = list(in_shape)[:-1] + [size] # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) b_ref, b_test = make_reference_and_test_tensors( size, @@ 
-652,8 +653,10 @@ def test_bias( ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -678,7 +681,7 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("cast_forward", (False, True)) @pytest.mark.parametrize("cast_backward", (False, True)) def test_quantize( @@ -694,25 +697,26 @@ def test_quantize( """Quantize""" # Skip invalid configurations - maybe_skip_quantization(quantization) + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device) + if quantization == "mxfp8": + maybe_skip_quantization(quantization, dims=in_shape) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - requires_grad=False, - test_is_fp8=True, + requires_grad=True, ) - x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, - test_is_fp8=True, ) - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = x_ref @@ -721,13 +725,14 @@ def test_quantize( # Implementation with fusible operation op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) recipe = make_recipe(quantization) - with te.fp8_autocast(fp8_recipe=recipe): + with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) # Check tensor types - assert isinstance(y_test, QuantizedTensor) == cast_forward - assert isinstance(x_test.grad, QuantizedTensor) == cast_backward + if with_quantization: + assert isinstance(y_test, 
QuantizedTensor) == cast_forward + assert isinstance(x_test.grad, QuantizedTensor) == cast_backward # Check values tols = dict(rtol=0, atol=0) @@ -762,10 +767,25 @@ def _test_basic_linear( # Skip invalid configurations maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=out_shape) - if quantization == "fp8" and quantized_output and not quantized_compute: - pytest.skip("FP8 output is only supported with FP8 GEMMs") - if quantization == "fp8" and quantized_grad_input and not quantized_compute: - pytest.skip("FP8 grad input is only supported with FP8 GEMMs") + quantization_needed = any( + ( + quantized_compute, + quantized_input, + quantized_weight, + quantized_output, + quantized_grad_output, + quantized_grad_input, + ) + ) + if quantization is None and quantization_needed: + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not quantization_needed: + pytest.skip("Quantization scheme is not used") + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): + if quantized_output and not quantized_compute: + pytest.skip("FP8 output is only supported with FP8 GEMMs") + if quantized_grad_input and not quantized_compute: + pytest.skip("FP8 grad input is only supported with FP8 GEMMs") if quantization == "mxfp8" and quantized_output: pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") if quantization == "mxfp8" and quantized_grad_input: @@ -774,28 +794,25 @@ def _test_basic_linear( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_input), + test_is_quantized=quantized_input, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, 
test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_grad_output), + test_is_quantized=quantized_grad_output, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -858,7 +875,7 @@ def _test_basic_linear( @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) def test_basic_linear( self, @@ -880,7 +897,7 @@ def test_basic_linear( ) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_input", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) @@ -899,6 +916,8 @@ def test_basic_linear_quantized( quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" + if quantization is None: + pytest.skip("Skipping case without quantization") self._test_basic_linear( dtype=torch.bfloat16, quantization=quantization, @@ -911,7 +930,8 @@ def test_basic_linear_quantized( ) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) 
@pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("weight_requires_grad", (False, True)) @@ -924,6 +944,7 @@ def test_linear( dtype: torch.dtype = torch.float32, device: torch.device = "cuda", quantization: Optional[str], + quantized_compute: bool, quantized_weight: bool, input_requires_grad: bool, weight_requires_grad: bool, @@ -936,26 +957,25 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - with torch.no_grad(): - if isinstance(x_test, QuantizedTensor): - x_test = x_test.dequantize() - x_test.requires_grad_(requires_grad=input_requires_grad) w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -966,6 +986,7 @@ def test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -1022,7 +1043,7 @@ def test_linear( @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def 
test_layer_norm( self, *, @@ -1192,7 +1213,7 @@ def test_layer_norm_autocast( @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_rmsnorm( self, *, @@ -1327,14 +1348,14 @@ def test_l2normalization( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_add_in_place( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: """Add two tensors @@ -1343,28 +1364,30 @@ def test_add_in_place( """ # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x1_ref, x1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) x2_ref, x2_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -1381,7 +1404,7 @@ def test_add_in_place( # Check results tols = dtype_tols(dtype) - if fp8: + if with_quantization: tols = dtype_tols(x1_test._fp8_dtype) y_test = 
y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1392,14 +1415,14 @@ def test_add_in_place( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_make_extra_output( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: """Output tensor twice @@ -1408,28 +1431,31 @@ def test_make_extra_output( """ # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy1_ref, dy1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) dy2_ref, dy2_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -1455,7 +1481,7 @@ def test_make_extra_output( @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("cache_quantized_input", (False, True)) def 
test_activation( self, @@ -1478,26 +1504,21 @@ def test_activation( quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) if cache_quantized_input: - maybe_skip_quantization("fp8", device=device) + maybe_skip_quantization("fp8_current_scaling", device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization="fp8_current_scaling" if cache_quantized_input else None, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref: torch.Tensor @@ -1540,8 +1561,6 @@ def test_activation( tols = dtype_tols(dtype) if quantized_compute or cache_quantized_input: tols = dtype_tols(tex.DType.kFloat8E4M3) - if activation == "relu" and not cache_quantized_input: - tols = {"atol": 0, "rtol": 0} # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1550,7 +1569,7 @@ def test_activation( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantize_forward", (False, True)) @pytest.mark.parametrize("quantize_backward", (False, True)) def test_swiglu( @@ -1628,7 +1647,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) 
@pytest.mark.parametrize("quantized_weight", (False, True)) def test_forward_linear_bias_activation( self, @@ -1660,18 +1679,15 @@ def test_forward_linear_bias_activation( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1682,6 +1698,7 @@ def test_forward_linear_bias_activation( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -1738,7 +1755,7 @@ def test_forward_linear_bias_activation( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_forward_linear_bias_add( self, *, @@ -1767,18 +1784,15 @@ def test_forward_linear_bias_add( # Random data x1_ref, x1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x1_test, QuantizedTensor): - with torch.no_grad(): - x1_test = x1_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1794,6 +1808,7 @@ def test_forward_linear_bias_add( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, 
requires_grad=False, @@ -1852,7 +1867,7 @@ def test_forward_linear_bias_add( torch.testing.assert_close(db_test, b_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_backward_linear_add( self, *, @@ -1880,27 +1895,26 @@ def test_backward_linear_add( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, ) dy2_ref, dy2_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -1964,7 +1978,7 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_weight", (False, True)) def test_linear( self, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 450c24da33..f4a8ce69c6 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -7,6 +7,7 @@ import torch import transformer_engine +import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine_torch as tex @@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: if dtype == torch.float8_e5m2: return dict(rtol=0.25, atol=0.125) # 
epsilon = 0.152 raise ValueError(f"Unsupported dtype ({dtype})") + + +def make_recipe(name: Optional[str]) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name in ("fp8", "fp8_delayed_scaling"): + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "fp8_current_scaling": + return transformer_engine.common.recipe.Float8CurrentScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "fp8_block_scaling": + return transformer_engine.common.recipe.Float8BlockScaling() + raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 55246a3d1d..868fc3a27a 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -947,7 +947,7 @@ def _all_gather_fp8( out = quantizer.make_empty(out_shape, dtype=dtype, device=device) elif isinstance(inp, Float8Tensor): out = inp.make_like(inp, shape=out_shape) - out._data = torch.empty_like( + out._data = torch.empty( out_shape, dtype=torch.uint8, device=inp.device, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 0e786ca96f..4bd94b3ad2 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -22,7 +22,7 @@ from ...fp8 import FP8GlobalStateManager from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer, QuantizedTensor -from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_blockwise_tensor 
import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase @@ -324,12 +324,38 @@ def pre_forward(self, *args, **kwargs) -> None: weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + # Recipe-specific configuration + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + if any( + not isinstance(q, Float8CurrentScalingQuantizer) + for q in (input_quantizer, weight_quantizer, grad_output_quantizer) + ): + raise RuntimeError( + "FP8 current-scaling recipe is enabled, " + f"but input quantizer is {input_quantizer.__class__.__name__}, " + f"weight quantizer is {weight_quantizer.__class__.__name__}, " + f"grad output quantizer is {grad_output_quantizer.__class__.__name__}" + ) + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + grad_output_quantizer.amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + if self.sequence_parallel and self.tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + if self.sequence_parallel and self.tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization # recipe changed - if isinstance(weight_quantizer, Float8Quantizer) and isinstance( - weight, Float8TensorBase - ): + if isinstance(
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ) and isinstance(weight, Float8TensorBase): weight._quantizer = weight_quantizer @staticmethod diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index c35d029403..0078f7ae65 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -21,7 +21,7 @@ _2X_ACC_FPROP, ) from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer -from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...utils import canonicalize_device, canonicalize_dtype from ..basic import BasicLinear, Bias, ReduceScatter @@ -208,7 +208,9 @@ def _functional_forward( if input_quantizer is not None: if not isinstance(x_local, QuantizedTensorBase): input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - if isinstance(input_quantizer, Float8Quantizer): + if isinstance( + input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): input_quantizer.set_usage(columnwise=False) x_local = input_quantizer(x_local) input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -327,8 +329,10 @@ def fuser_forward( grad_input_quantizer = None if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() - if not recipe.delayed() and not recipe.mxfp8(): - raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe") + if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): + raise RuntimeError( + f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})" + ) input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) grad_output_quantizer = 
linear_op.get_quantizer("backward", 0)