diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 6d127aa741..912dc67bfc 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -6,21 +6,31 @@ import pytest import torch +import warnings import transformer_engine.common.recipe import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import ( FP8GlobalStateManager, _amax_and_scale_update, - get_default_fp8_recipe, + fp8_model_init, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch import Linear +from transformer_engine.pytorch.distributed import fp8_autocast +from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling import transformer_engine_torch as tex # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) # FP8 per tensor delayed scaling @@ -367,3 +377,96 @@ def setup_fp8_meta(): ) torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) + + @pytest.mark.parametrize( + "model_init_recipe", + [ + pytest.param( + MXFP8BlockScaling(), + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + pytest.param( + Float8BlockScaling(), + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + ], + ) + def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe): + with fp8_model_init(enabled=True, recipe=model_init_recipe): + 
linear = Linear(32, 32).cuda() + + x = torch.randn(32, 32, device="cuda") + with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()): + with pytest.raises(RuntimeError) as excinfo: + _ = linear(x) + assert "Recipe mismatch for " in str(excinfo.value) + + @pytest.mark.parametrize( + "target_recipe_class, expected_quantizer_type, available_flag, reason", + [ + pytest.param( + MXFP8BlockScaling, + MXFP8Quantizer, + mxfp8_available, + reason_for_no_mxfp8, + id="DelayedScaling->MXFP8BlockScaling", + ), + pytest.param( + Float8BlockScaling, + Float8BlockQuantizer, + fp8_block_scaling_available, + reason_for_no_fp8_block_scaling, + id="DelayedScaling->Float8BlockScaling", + ), + ], + ) + def test_dynamic_recipe_update( + self, target_recipe_class, expected_quantizer_type, available_flag, reason + ): + if not available_flag: + pytest.skip(reason) + + in_features = 32 + out_features = 32 + batch_size = 32 + linear = Linear(in_features, out_features).cuda() + initial_recipe = DelayedScaling() + + # Run initial iterations with DelayedScaling + for _ in range(3): + x = torch.randn(batch_size, in_features, device="cuda") + with fp8_autocast(enabled=True, fp8_recipe=initial_recipe): + y = linear(x) + loss = y.mean() + loss.backward() + + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, Float8Quantizer) + + # Change recipe + target_recipe = target_recipe_class() + + # Run subsequent iterations with the target recipe + for i in range(3): + x = torch.randn(batch_size, in_features, device="cuda") + if i == 0: + # Expect a warning on the first iteration with the new recipe + with pytest.warns(UserWarning, match="Recipe type changed"): + with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + y = linear(x) + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, expected_quantizer_type) + else: + # No warning expected on subsequent iterations + with warnings.catch_warnings(): + warnings.simplefilter("error") # Raise 
error if unexpected warning occurs + with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + y = linear(x) + loss = y.mean() + loss.backward() + + # Final check + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, expected_quantizer_type) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 8db26183bd..64688e2077 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -87,7 +87,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla A.scaling_mode == B.scaling_mode || (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), - "Inputs A and B to GEMM need to have compatible scaling modes!"); + "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " + + to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode)); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 80857e565c..f1ecb33272 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -180,6 +180,7 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " @@ -245,6 +246,7 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " 
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " @@ -291,7 +293,11 @@ def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." def __repr__(self) -> str: - return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," + return ( + f"recipe_type={self.__class__.__name__}, " + f"margin={self.margin}, " + f"format={str(self.fp8_format).split('.')[1]}" + ) @dataclass() @@ -375,6 +381,7 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index b725d3ab37..4d61757e1d 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -14,7 +14,7 @@ import transformer_engine_torch as tex - +from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch.tensor.quantized_tensor import ( QuantizedTensor, Quantizer, @@ -459,6 +459,10 @@ def any_feature_enabled(self) -> bool: return True return False + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + """Probably not needed for debug quantizer""" + return None + class DebugQuantizedTensor(QuantizedTensorBase): """ diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index bb3bf68887..61bf49bf84 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -44,7 +44,7 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..utils import torch_get_autocast_gpu_dtype from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from 
...common.recipe import Recipe +from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -811,6 +811,14 @@ def set_extra_state(self, state: Optional[torch.Tensor]) -> None: if state is None: return + # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing + if "recipe" not in state: + # TE 1.x only supported delayed scaling, which was the default recipe + state["recipe"] = DelayedScaling() + # TE 1.x also saved scale_inv, which is not needed with Recipe object + state.pop("scale_inv_fwd", None) + state.pop("scale_inv_bwd", None) + # Load extra items self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"] = state["recipe"] @@ -884,6 +892,8 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" + _original_recipe = self.fp8_meta.get("recipe", None) + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() @@ -922,6 +932,19 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + _current_recipe = self.fp8_meta["recipe"] + if _original_recipe is not None and not ( + issubclass(_current_recipe.__class__, _original_recipe.__class__) + or issubclass(_original_recipe.__class__, _current_recipe.__class__) + ): + warnings.warn( + f"Recipe type changed from {_original_recipe.__class__.__name__} " + f"to {_current_recipe.__class__.__name__}. " + "This may affect model behavior." 
+ ) + # Clear cached workspaces as they were created with the old recipe/quantizer type + self._fp8_workspaces.clear() + @contextmanager def prepare_forward( self, @@ -946,6 +969,7 @@ def prepare_forward( self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) + self._check_weight_tensor_recipe_correspondence() if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): assert self.fp8_meta["recipe"].reduce_amax, ( @@ -1346,6 +1370,43 @@ def _validate_name(self): ) self.name = f"Layer_{TEDebugState.get_layer_count()}" + def _check_weight_tensor_recipe_correspondence(self) -> None: + """ + Verify that the weight tensor types match their corresponding recipe type. + This is invoked from prepare_forward(). + + This enforces the following mapping from recipe types to weight tensor types: + - DelayedScaling → Float8Tensor + - Float8CurrentScaling → Float8Tensor + - MXFP8BlockScaling → MXFP8Tensor + - Float8BlockScaling → Float8BlockwiseQTensor + + Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()), + but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()). + """ + if not self.fp8 and not self.fp8_calibration: + return + if not hasattr(self, "weight_names") or not self.weight_names: + return + + recipe = self.fp8_meta["recipe"] + weight_tensors = [getattr(self, name) for name in self.weight_names] + for i, tensor in enumerate(weight_tensors): + if isinstance(tensor, QuantizedTensorBase): + quantizer = tensor._get_quantizer() + if quantizer is None: + continue + compatible_recipe_class = quantizer._get_compatible_recipe() + if compatible_recipe_class is None: + continue + if not isinstance(recipe, compatible_recipe_class): + raise RuntimeError( + f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe" + f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
+ " Please check the recipes assigned during fp8_model_init() and" + " fp8_autocast() calls." + ) + def _turn_off_unsupported_features_in_debug(self): if ( getattr(self, "ub_bulk_wgrad", False) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ce4137c660..4ab04da83f 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -4,13 +4,14 @@ """Tensor class with FP8 data quantized with NxN tiles""" from __future__ import annotations -from typing import Optional, Tuple, Iterable +from typing import Optional, Tuple, Iterable, Union import math import torch import transformer_engine_torch as tex - from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import Float8BlockScaling, Recipe from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple @@ -229,6 +230,9 @@ def calibrate(self, tensor: torch.Tensor) -> None: # where state from an estimator influences distribution parameters. pass + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return Float8BlockScaling + class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. 
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 9c8fb6a1a2..1c3e575473 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -4,13 +4,14 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Optional, Tuple, Iterable +from typing import Optional, Tuple, Iterable, Union import warnings import torch import transformer_engine_torch as tex - from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from ..utils import canonicalize_process_group, devices_match from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc @@ -166,6 +167,9 @@ def create_tensor_from_data( quantizer=self, ) + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return DelayedScaling + class Float8CurrentScalingQuantizer(Quantizer): """Builder class for FP8 tensors with per-tensor current scaling @@ -328,6 +332,9 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: """Get process group for amax reduction""" return canonicalize_process_group(self.amax_reduction_group) + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return Float8CurrentScaling + class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5b3532b301..c930cdbff5 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -6,12 +6,13 @@ from __future__ import annotations from collections.abc import Iterable import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import 
transformer_engine_torch as tex - from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple @@ -135,6 +136,9 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return MXFP8BlockScaling + class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 155113738b..a3cbe02f16 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -13,6 +13,7 @@ from torch.utils._pytree import tree_map import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe class QuantizedTensorBase: @@ -238,6 +239,10 @@ def copy(self) -> Quantizer: """Create shallow copy""" return copy.copy(self) + @abc.abstractmethod + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + """Returns recipe class that is compatible with this quantizer""" + class _QuantizeFunc(torch.autograd.Function): """Cast to FP8 from other dtype"""