From 86e15026c491bcf8b651fdf21b8866e96f45142d Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Mon, 28 Apr 2025 11:02:06 +0000 Subject: [PATCH 01/18] Check tensor-recipe compatibility Signed-off-by: Evgeny Tsykunov --- tests/pytorch/test_recipe.py | 36 ++++++++++++++- .../common/gemm/cublaslt_gemm.cu | 3 +- transformer_engine/pytorch/module/base.py | 44 ++++++++++++++++++- 3 files changed, 80 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 6d127aa741..9fc7193eea 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -13,14 +13,19 @@ from transformer_engine.pytorch.fp8 import ( FP8GlobalStateManager, _amax_and_scale_update, - get_default_fp8_recipe, + fp8_model_init, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch import Linear +from transformer_engine.pytorch.distributed import fp8_autocast +from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling import transformer_engine_torch as tex # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available() # FP8 per tensor delayed scaling @@ -367,3 +372,32 @@ def setup_fp8_meta(): ) torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) + + @pytest.mark.parametrize( + "model_init_recipe", + [ + pytest.param( + MXFP8BlockScaling(), + marks=pytest.mark.skipif( + not mxfp8_available, + reason=reason_for_no_mxfp8 + ) + ), + pytest.param( + Float8BlockScaling(), + marks=pytest.mark.skipif( + not fp8_block_scaling_available, + reason=reason_for_no_fp8_block_scaling + ) + ) + ] + ) + def 
test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe): + with fp8_model_init(enabled=True, recipe=model_init_recipe): + linear = Linear(32, 32).cuda() + + x = torch.randn(32, 32, device="cuda") + with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()): + with pytest.raises(RuntimeError) as excinfo: + _ = linear(x) + assert "Tensor type mismatch" in str(excinfo.value) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 8db26183bd..dbc3b296cc 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -87,7 +87,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla A.scaling_mode == B.scaling_mode || (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), - "Inputs A and B to GEMM need to have compatible scaling modes!"); + "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " + + to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode)); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret; diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 17848a36bf..6fc1f03175 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -36,7 +36,9 @@ ) from ..constants import dist_group_type from ..tensor import QuantizedTensor, Quantizer -from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer +from ..tensor.float8_tensor import Float8Tensor +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer, Float8BlockwiseQTensor +from ..tensor.mxfp8_tensor import MXFP8Tensor from 
..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase @@ -799,6 +801,7 @@ def prepare_forward( self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) + self._check_weight_tensor_recipe_correspondence() if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): assert self.fp8_meta["recipe"].reduce_amax, ( @@ -1179,6 +1182,45 @@ def _validate_name(self): ) self.name = f"Layer_{TEDebugState.get_layer_count()}" + def _check_weight_tensor_recipe_correspondence(self) -> None: + """ + Verify that the weight tensor types match their corresponding recipe type. + This is invoked in the forward(). + + This establishes a 1:1 correspondence between recipe types and tensor types: + - DelayedScaling → Float8Tensor + - Float8CurrentScaling → Float8Tensor + - MXFP8BlockScaling → MXFP8Tensor + - Float8BlockScaling → Float8BlockwiseQTensor + + Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()), + but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()). 
+ """ + if not self.fp8 and not self.fp8_calibration: + return + if not hasattr(self, "weight_names") or not self.weight_names: + return + + recipe = self.fp8_meta["recipe"] + expected_tensor_class = None + if recipe.delayed() or recipe.float8_current_scaling(): + expected_tensor_class = Float8Tensor + elif recipe.mxfp8(): + expected_tensor_class = MXFP8Tensor + elif recipe.float8_block_scaling(): + expected_tensor_class = Float8BlockwiseQTensor + else: + raise RuntimeError(f"Unsupported recipe type: {recipe.__class__.__name__}") + + weight_tensors = [getattr(self, name) for name in self.weight_names] + for i, tensor in enumerate(weight_tensors): + if isinstance(tensor, QuantizedTensor) and not isinstance(tensor, expected_tensor_class): + raise RuntimeError( + f"Tensor type mismatch for '{self.weight_names[i]}': expected {expected_tensor_class.__name__} for " + f"recipe {recipe.__class__.__name__}, got {tensor.__class__.__name__}. " + f"Please check the recipes assigned during fp8_model_init() and fp8_autocast() calls." 
+ ) + def _turn_off_unsupported_features_in_debug(self): if ( getattr(self, "ub_bulk_wgrad", False) From 98799f73aa6c2249feb48402177098d6de16d702 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 29 Apr 2025 08:11:33 +0000 Subject: [PATCH 02/18] Tensor class in recipe, checking for *Base Signed-off-by: Evgeny Tsykunov --- transformer_engine/common/recipe/__init__.py | 23 +++++++++++++++----- transformer_engine/pytorch/module/base.py | 20 +++++++---------- transformer_engine/pytorch/utils.py | 4 +++- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 80857e565c..20ad7aaf42 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -7,9 +7,14 @@ import warnings import os from enum import Enum -from typing import Literal, Optional, Union, Callable, NamedTuple +from typing import Literal, Optional, Type, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + class _FormatHelper(NamedTuple): """ @@ -174,6 +179,7 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False + expected_tensor_class: Type[QuantizedTensor] = Float8Tensor def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
@@ -184,7 +190,8 @@ def __repr__(self) -> str: f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}," + f"expected_tensor_class={self.expected_tensor_class.__name__}" ) @@ -239,6 +246,7 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + expected_tensor_class: Type[QuantizedTensor] = Float8Tensor def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -253,7 +261,8 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"expected_tensor_class={self.expected_tensor_class.__name__}" ) @@ -286,12 +295,14 @@ class MXFP8BlockScaling(Recipe): fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False + expected_tensor_class: Type[QuantizedTensor] = MXFP8Tensor def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
def __repr__(self) -> str: - return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," + return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," \ + f"expected_tensor_class={self.expected_tensor_class.__name__}" @dataclass() @@ -355,6 +366,7 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + expected_tensor_class: Type[QuantizedTensor] = Float8BlockwiseQTensor def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" @@ -386,5 +398,6 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"expected_tensor_class={self.expected_tensor_class.__name__}" ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6fc1f03175..855cb4fa56 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -58,6 +58,12 @@ _NUM_MAX_UB_STREAMS = 3 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None layers_atomic_ring_exchange = [] +_QUANTIZED_WEIGHT_TENSOR_TYPES = ( + QuantizedTensor, + Float8TensorBase, + MXFP8TensorBase, + Float8BlockwiseQTensorBase, +) def get_cublas_workspace_size_bytes() -> None: @@ -1202,21 +1208,11 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: return recipe = self.fp8_meta["recipe"] - expected_tensor_class = None - if recipe.delayed() or recipe.float8_current_scaling(): - expected_tensor_class = Float8Tensor - elif recipe.mxfp8(): - expected_tensor_class = MXFP8Tensor - elif recipe.float8_block_scaling(): - expected_tensor_class = Float8BlockwiseQTensor - else: - raise RuntimeError(f"Unsupported recipe type: {recipe.__class__.__name__}") - weight_tensors = [getattr(self, name) for name in self.weight_names] for 
i, tensor in enumerate(weight_tensors): - if isinstance(tensor, QuantizedTensor) and not isinstance(tensor, expected_tensor_class): + if isinstance(tensor, _QUANTIZED_WEIGHT_TENSOR_TYPES) and not isinstance(tensor, recipe.expected_tensor_class): raise RuntimeError( - f"Tensor type mismatch for '{self.weight_names[i]}': expected {expected_tensor_class.__name__} for " + f"Tensor type mismatch for '{self.weight_names[i]}': expected {recipe.expected_tensor_class.__name__} for " f"recipe {recipe.__class__.__name__}, got {tensor.__class__.__name__}. " f"Please check the recipes assigned during fp8_model_init() and fp8_autocast() calls." ) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index aa93961111..acdad75146 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -10,7 +10,6 @@ from typing import Any, Callable, List, Optional, Tuple import torch -import transformer_engine.pytorch.cpp_extensions as ext from ..debug.pytorch.debug_quantization import DebugQuantizedTensor from .tensor.quantized_tensor import QuantizedTensor @@ -262,6 +261,9 @@ def is_non_tn_fp8_gemm_supported() -> bool: @functools.lru_cache(maxsize=None) def get_cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" + # Import locally to avoid circular dependencies (cpp_extensions imports utils). 
+ import transformer_engine.pytorch.cpp_extensions as ext # pylint: disable=import-outside-toplevel + encoded_version = ext.get_cudnn_version() major_version_magnitude = 1000 if encoded_version < 90000 else 10000 major, encoded_version = divmod(encoded_version, major_version_magnitude) From 67bb47fa31db42b455a4553526abe0d65e7e6b9b Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 29 Apr 2025 09:05:36 +0000 Subject: [PATCH 03/18] Extend recipe __repr__ with recipe_type Signed-off-by: Evgeny Tsykunov --- transformer_engine/common/recipe/__init__.py | 13 ++++++++++--- transformer_engine/pytorch/module/base.py | 4 +--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 20ad7aaf42..2f1aa1048f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -186,11 +186,12 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}," + f"fp8_mha={self.fp8_mha}, " f"expected_tensor_class={self.expected_tensor_class.__name__}" ) @@ -253,6 +254,7 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " @@ -301,8 +303,12 @@ def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
def __repr__(self) -> str: - return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," \ - f"expected_tensor_class={self.expected_tensor_class.__name__}" + return ( + f"recipe_type={self.__class__.__name__}, " + f"margin={self.margin}, " + f"format={str(self.fp8_format).split('.')[1]}, " + f"expected_tensor_class={self.expected_tensor_class.__name__}" + ) @dataclass() @@ -387,6 +393,7 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return ( + f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 855cb4fa56..49ad2681cd 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -36,9 +36,7 @@ ) from ..constants import dist_group_type from ..tensor import QuantizedTensor, Quantizer -from ..tensor.float8_tensor import Float8Tensor -from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer, Float8BlockwiseQTensor -from ..tensor.mxfp8_tensor import MXFP8Tensor +from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase From e64188b86927c500420e34af4a433db7acbe44f3 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 29 Apr 2025 09:54:06 +0000 Subject: [PATCH 04/18] Warn about recipe change Signed-off-by: Evgeny Tsykunov --- transformer_engine/pytorch/module/base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 49ad2681cd..1293b81198 100644 --- a/transformer_engine/pytorch/module/base.py +++ 
b/transformer_engine/pytorch/module/base.py @@ -743,6 +743,8 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" + _original_recipe = self.fp8_meta.get("recipe", None) + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() @@ -781,6 +783,14 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + _current_recipe = self.fp8_meta["recipe"] + if _original_recipe is not None and _original_recipe.__class__ != _current_recipe.__class__: + warnings.warn( + f"Recipe type changed from {_original_recipe.__class__.__name__} " + f"to {_current_recipe.__class__.__name__}. " + f"This may affect model behavior." + ) + @contextmanager def prepare_forward( self, From eac9fe1e798c414bebed6e117bc5a1f0e00a8884 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Apr 2025 10:02:31 +0000 Subject: [PATCH 05/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_recipe.py | 18 ++++++++---------- .../common/gemm/cublaslt_gemm.cu | 2 +- transformer_engine/pytorch/module/base.py | 13 ++++++++----- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 9fc7193eea..edcaee276d 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -25,7 +25,9 @@ # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = 
FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) # FP8 per tensor delayed scaling @@ -378,19 +380,15 @@ def setup_fp8_meta(): [ pytest.param( MXFP8BlockScaling(), - marks=pytest.mark.skipif( - not mxfp8_available, - reason=reason_for_no_mxfp8 - ) + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), ), pytest.param( Float8BlockScaling(), marks=pytest.mark.skipif( - not fp8_block_scaling_available, - reason=reason_for_no_fp8_block_scaling - ) - ) - ] + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + ], ) def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe): with fp8_model_init(enabled=True, recipe=model_init_recipe): diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index dbc3b296cc..64688e2077 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -88,7 +88,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " + - to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode)); + to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode)); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret; diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 1293b81198..8eb2f67e46 100644 --- 
a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -788,7 +788,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: warnings.warn( f"Recipe type changed from {_original_recipe.__class__.__name__} " f"to {_current_recipe.__class__.__name__}. " - f"This may affect model behavior." + "This may affect model behavior." ) @contextmanager @@ -1218,11 +1218,14 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: recipe = self.fp8_meta["recipe"] weight_tensors = [getattr(self, name) for name in self.weight_names] for i, tensor in enumerate(weight_tensors): - if isinstance(tensor, _QUANTIZED_WEIGHT_TENSOR_TYPES) and not isinstance(tensor, recipe.expected_tensor_class): + if isinstance(tensor, _QUANTIZED_WEIGHT_TENSOR_TYPES) and not isinstance( + tensor, recipe.expected_tensor_class + ): raise RuntimeError( - f"Tensor type mismatch for '{self.weight_names[i]}': expected {recipe.expected_tensor_class.__name__} for " - f"recipe {recipe.__class__.__name__}, got {tensor.__class__.__name__}. " - f"Please check the recipes assigned during fp8_model_init() and fp8_autocast() calls." + f"Tensor type mismatch for '{self.weight_names[i]}': expected" + f" {recipe.expected_tensor_class.__name__} for recipe" + f" {recipe.__class__.__name__}, got {tensor.__class__.__name__}. Please check" + " the recipes assigned during fp8_model_init() and fp8_autocast() calls." 
) def _turn_off_unsupported_features_in_debug(self): From 674a3c9853b81380510d897a02718a0510b50720 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 29 Apr 2025 12:22:18 +0000 Subject: [PATCH 06/18] Enable dynamic recipe change: clear fp8 workspace Signed-off-by: Evgeny Tsykunov --- transformer_engine/pytorch/module/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8eb2f67e46..b325a3fb1a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -790,6 +790,8 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: f"to {_current_recipe.__class__.__name__}. " "This may affect model behavior." ) + # Clear cached workspaces as they were created with the old recipe/quantizer type + self._fp8_workspaces.clear() @contextmanager def prepare_forward( From 5023a32a0f704f63785fdfe6dff7966a2cfe7c9d Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 29 Apr 2025 13:13:57 +0000 Subject: [PATCH 07/18] TE 1.x checkpoint compatibility Signed-off-by: Evgeny Tsykunov --- transformer_engine/pytorch/module/base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b325a3fb1a..d0b638f6c7 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -40,7 +40,7 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from ...common.recipe import Recipe +from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -670,6 +670,14 @@ def set_extra_state(self, state: 
torch.Tensor) -> None: if state is None: return + # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing + if "recipe" not in state: + # TE 1.x only supported delayed scaling, which was the default recipe + state["recipe"] = DelayedScaling() + # TE 1.x also saved scale_inv, which is not needed with Recipe object + state.pop("scale_inv_fwd", None) + state.pop("scale_inv_bwd", None) + # Load extra items self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"] = state["recipe"] From 243b6f461cd894f1167fe8dd728bb9644bf4c740 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 29 Apr 2025 13:36:38 +0000 Subject: [PATCH 08/18] Disable warning for recipe wrappers Signed-off-by: Evgeny Tsykunov --- transformer_engine/pytorch/module/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d0b638f6c7..3df228b4ff 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -792,7 +792,10 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() _current_recipe = self.fp8_meta["recipe"] - if _original_recipe is not None and _original_recipe.__class__ != _current_recipe.__class__: + if _original_recipe is not None and not ( + issubclass(_current_recipe.__class__, _original_recipe.__class__) + or issubclass(_original_recipe.__class__, _current_recipe.__class__) + ): warnings.warn( f"Recipe type changed from {_original_recipe.__class__.__name__} " f"to {_current_recipe.__class__.__name__}. 
" From 400e20607d3496251db9cc7b02afffb7fc66b983 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Wed, 30 Apr 2025 09:23:56 +0000 Subject: [PATCH 09/18] Test recipe change Signed-off-by: Evgeny Tsykunov --- tests/pytorch/test_recipe.py | 69 ++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index edcaee276d..18c1abbc67 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -6,9 +6,12 @@ import pytest import torch +import warnings import transformer_engine.common.recipe import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import ( FP8GlobalStateManager, @@ -399,3 +402,69 @@ def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_reci with pytest.raises(RuntimeError) as excinfo: _ = linear(x) assert "Tensor type mismatch" in str(excinfo.value) + + @pytest.mark.parametrize( + "target_recipe_class, expected_quantizer_type, available_flag, reason", + [ + pytest.param( + MXFP8BlockScaling, + MXFP8Quantizer, + mxfp8_available, + reason_for_no_mxfp8, + id="DelayedScaling->MXFP8BlockScaling", + ), + pytest.param( + Float8BlockScaling, + Float8BlockQuantizer, + fp8_block_scaling_available, + reason_for_no_fp8_block_scaling, + id="DelayedScaling->Float8BlockScaling", + ), + ], + ) + def test_dynamic_recipe_update(self, target_recipe_class, expected_quantizer_type, available_flag, reason): + if not available_flag: + pytest.skip(reason) + + in_features = 32 + out_features = 32 + batch_size = 32 + linear = Linear(in_features, out_features).cuda() + initial_recipe = DelayedScaling() + + # Run initial iterations with DelayedScaling + for _ in range(3): + x = torch.randn(batch_size, in_features, device="cuda") + with 
fp8_autocast(enabled=True, fp8_recipe=initial_recipe): + y = linear(x) + loss = y.mean() + loss.backward() + + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, Float8Quantizer) + + # Change recipe + target_recipe = target_recipe_class() + + # Run subsequent iterations with the target recipe + for i in range(3): + x = torch.randn(batch_size, in_features, device="cuda") + if i == 0: + # Expect a warning on the first iteration with the new recipe + with pytest.warns(UserWarning, match="Recipe type changed"): + with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + y = linear(x) + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, expected_quantizer_type) + else: + # No warning expected on subsequent iterations + with warnings.catch_warnings(): + warnings.simplefilter("error") # Raise error if unexpected warning occurs + with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + y = linear(x) + loss = y.mean() + loss.backward() + + # Final check + for quantizer in linear.quantizers["scaling_fwd"]: + assert isinstance(quantizer, expected_quantizer_type) From 01a48da2e16a25123286ee88c7677f6ae2e0295c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Apr 2025 09:24:55 +0000 Subject: [PATCH 10/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_recipe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 18c1abbc67..4f46691546 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -422,7 +422,9 @@ def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_reci ), ], ) - def test_dynamic_recipe_update(self, target_recipe_class, expected_quantizer_type, available_flag, reason): + def test_dynamic_recipe_update( + self, 
target_recipe_class, expected_quantizer_type, available_flag, reason + ): if not available_flag: pytest.skip(reason) @@ -459,7 +461,7 @@ def test_dynamic_recipe_update(self, target_recipe_class, expected_quantizer_typ else: # No warning expected on subsequent iterations with warnings.catch_warnings(): - warnings.simplefilter("error") # Raise error if unexpected warning occurs + warnings.simplefilter("error") # Raise error if unexpected warning occurs with fp8_autocast(enabled=True, fp8_recipe=target_recipe): y = linear(x) loss = y.mean() From 7c2f5eb40f88d15e62baa76e14e1d6d29fb41f0c Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Fri, 9 May 2025 13:03:13 +0000 Subject: [PATCH 11/18] Use QuantizedTensorBase Signed-off-by: Evgeny Tsykunov --- transformer_engine/pytorch/module/base.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index feb5ea8dc4..2c1fdc5a9a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -60,12 +60,6 @@ _NUM_MAX_UB_STREAMS = 3 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None layers_atomic_ring_exchange = [] -_QUANTIZED_WEIGHT_TENSOR_TYPES = ( - QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, -) def get_cublas_workspace_size_bytes() -> None: @@ -1393,7 +1387,7 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: recipe = self.fp8_meta["recipe"] weight_tensors = [getattr(self, name) for name in self.weight_names] for i, tensor in enumerate(weight_tensors): - if isinstance(tensor, _QUANTIZED_WEIGHT_TENSOR_TYPES) and not isinstance( + if isinstance(tensor, QuantizedTensorBase) and not isinstance( tensor, recipe.expected_tensor_class ): raise RuntimeError( From 5cdda54a6ac3b39003175a47613b3029e82c6df9 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Fri, 9 May 2025 13:23:18 +0000 Subject: [PATCH 12/18] Fix circular import 
Signed-off-by: Evgeny Tsykunov --- transformer_engine/common/recipe/__init__.py | 57 +++++++++++++++----- transformer_engine/pytorch/module/base.py | 4 +- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 2f1aa1048f..39a8017b8b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -10,11 +10,6 @@ from typing import Literal, Optional, Type, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor - class _FormatHelper(NamedTuple): """ @@ -179,7 +174,16 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Type[QuantizedTensor] = Float8Tensor + + @staticmethod + def get_expected_tensor_class(): + # TODO(ksivamani): Find better design for this, adding here to avoid circular import. + # It should be a class attribute, but that will cause circular import. + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + + expected_tensor_class: Type[QuantizedTensor] = Float8Tensor + return expected_tensor_class def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
@@ -192,7 +196,7 @@ def __repr__(self) -> str: f"amax_history_len={self.amax_history_len}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={self.expected_tensor_class.__name__}" + f"expected_tensor_class={self.get_expected_tensor_class().__name__}" ) @@ -247,7 +251,16 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Type[QuantizedTensor] = Float8Tensor + + @staticmethod + def get_expected_tensor_class(): + # TODO(ksivamani): Find better design for this, adding here to avoid circular import. + # It should be a class attribute, but that will cause circular import. + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + + expected_tensor_class: Type[QuantizedTensor] = Float8Tensor + return expected_tensor_class def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -264,7 +277,7 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={self.expected_tensor_class.__name__}" + f"expected_tensor_class={self.get_expected_tensor_class().__name__}" ) @@ -297,7 +310,16 @@ class MXFP8BlockScaling(Recipe): fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Type[QuantizedTensor] = MXFP8Tensor + + @staticmethod + def get_expected_tensor_class(): + # TODO(ksivamani): Find better design for this, adding here to avoid circular import. + # It should be a class attribute, but that will cause circular import. 
+ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor + from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + + expected_tensor_class: Type[QuantizedTensor] = MXFP8Tensor + return expected_tensor_class def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -307,7 +329,7 @@ def __repr__(self) -> str: f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " - f"expected_tensor_class={self.expected_tensor_class.__name__}" + f"expected_tensor_class={self.get_expected_tensor_class().__name__}" ) @@ -372,7 +394,16 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Type[QuantizedTensor] = Float8BlockwiseQTensor + + @staticmethod + def get_expected_tensor_class(): + # TODO(ksivamani): Find better design for this, adding here to avoid circular import. + # It should be a class attribute, but that will cause circular import. 
+ from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor + from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + + expected_tensor_class: Type[QuantizedTensor] = Float8BlockwiseQTensor + return expected_tensor_class def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" @@ -406,5 +437,5 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={self.expected_tensor_class.__name__}" + f"expected_tensor_class={self.get_expected_tensor_class().__name__}" ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2c1fdc5a9a..46d4694593 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1388,11 +1388,11 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: weight_tensors = [getattr(self, name) for name in self.weight_names] for i, tensor in enumerate(weight_tensors): if isinstance(tensor, QuantizedTensorBase) and not isinstance( - tensor, recipe.expected_tensor_class + tensor, recipe.get_expected_tensor_class() ): raise RuntimeError( f"Tensor type mismatch for '{self.weight_names[i]}': expected" - f" {recipe.expected_tensor_class.__name__} for recipe" + f" {recipe.get_expected_tensor_class().__name__} for recipe" f" {recipe.__class__.__name__}, got {tensor.__class__.__name__}. Please check" " the recipes assigned during fp8_model_init() and fp8_autocast() calls." 
) From 463ad23b4085cad72fdde64b348f084fdcb34535 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Fri, 9 May 2025 13:36:30 +0000 Subject: [PATCH 13/18] Revert previous circular import fix Signed-off-by: Evgeny Tsykunov --- transformer_engine/common/recipe/__init__.py | 57 +++++--------------- transformer_engine/pytorch/module/base.py | 4 +- transformer_engine/pytorch/utils.py | 1 - 3 files changed, 15 insertions(+), 47 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 39a8017b8b..2f1aa1048f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -10,6 +10,11 @@ from typing import Literal, Optional, Type, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + class _FormatHelper(NamedTuple): """ @@ -174,16 +179,7 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False - - @staticmethod - def get_expected_tensor_class(): - # TODO(ksivamani): Find better design for this, adding here to avoid circular import. - # It should be a class attribute, but that will cause circular import. - from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor - from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor - - expected_tensor_class: Type[QuantizedTensor] = Float8Tensor - return expected_tensor_class + expected_tensor_class: Type[QuantizedTensor] = Float8Tensor def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
@@ -196,7 +192,7 @@ def __repr__(self) -> str: f"amax_history_len={self.amax_history_len}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={self.get_expected_tensor_class().__name__}" + f"expected_tensor_class={self.expected_tensor_class.__name__}" ) @@ -251,16 +247,7 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - - @staticmethod - def get_expected_tensor_class(): - # TODO(ksivamani): Find better design for this, adding here to avoid circular import. - # It should be a class attribute, but that will cause circular import. - from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor - from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor - - expected_tensor_class: Type[QuantizedTensor] = Float8Tensor - return expected_tensor_class + expected_tensor_class: Type[QuantizedTensor] = Float8Tensor def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -277,7 +264,7 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={self.get_expected_tensor_class().__name__}" + f"expected_tensor_class={self.expected_tensor_class.__name__}" ) @@ -310,16 +297,7 @@ class MXFP8BlockScaling(Recipe): fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False - - @staticmethod - def get_expected_tensor_class(): - # TODO(ksivamani): Find better design for this, adding here to avoid circular import. - # It should be a class attribute, but that will cause circular import. 
- from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor - from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor - - expected_tensor_class: Type[QuantizedTensor] = MXFP8Tensor - return expected_tensor_class + expected_tensor_class: Type[QuantizedTensor] = MXFP8Tensor def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -329,7 +307,7 @@ def __repr__(self) -> str: f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " - f"expected_tensor_class={self.get_expected_tensor_class().__name__}" + f"expected_tensor_class={self.expected_tensor_class.__name__}" ) @@ -394,16 +372,7 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - - @staticmethod - def get_expected_tensor_class(): - # TODO(ksivamani): Find better design for this, adding here to avoid circular import. - # It should be a class attribute, but that will cause circular import. 
- from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor - from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor - - expected_tensor_class: Type[QuantizedTensor] = Float8BlockwiseQTensor - return expected_tensor_class + expected_tensor_class: Type[QuantizedTensor] = Float8BlockwiseQTensor def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" @@ -437,5 +406,5 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={self.get_expected_tensor_class().__name__}" + f"expected_tensor_class={self.expected_tensor_class.__name__}" ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 46d4694593..2c1fdc5a9a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1388,11 +1388,11 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: weight_tensors = [getattr(self, name) for name in self.weight_names] for i, tensor in enumerate(weight_tensors): if isinstance(tensor, QuantizedTensorBase) and not isinstance( - tensor, recipe.get_expected_tensor_class() + tensor, recipe.expected_tensor_class ): raise RuntimeError( f"Tensor type mismatch for '{self.weight_names[i]}': expected" - f" {recipe.get_expected_tensor_class().__name__} for recipe" + f" {recipe.expected_tensor_class.__name__} for recipe" f" {recipe.__class__.__name__}, got {tensor.__class__.__name__}. Please check" " the recipes assigned during fp8_model_init() and fp8_autocast() calls." 
) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 8f5162d5a0..742d76dc12 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -11,7 +11,6 @@ import numpy as np import torch -import transformer_engine.pytorch.cpp_extensions as ext from . import torch_version from ..debug.pytorch.debug_quantization import DebugQuantizedTensor From d6ea98167c39da0185aa89d2a98b080a1a26dfdb Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 13 May 2025 08:09:14 +0000 Subject: [PATCH 14/18] Fix pytorch imports in common Signed-off-by: Evgeny Tsykunov --- transformer_engine/common/recipe/__init__.py | 45 ++++++++++++++------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 2f1aa1048f..e59ad2baec 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -10,11 +10,6 @@ from typing import Literal, Optional, Type, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor - class _FormatHelper(NamedTuple): """ @@ -179,12 +174,18 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Type[QuantizedTensor] = Float8Tensor + expected_tensor_class: Optional[Type] = None def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
+ try: + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor + self.expected_tensor_class = Float8Tensor + except ImportError: + pass def __repr__(self) -> str: + expected_tensor_class_name = self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " @@ -192,7 +193,7 @@ def __repr__(self) -> str: f"amax_history_len={self.amax_history_len}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={self.expected_tensor_class.__name__}" + f"expected_tensor_class={expected_tensor_class_name}" ) @@ -247,12 +248,18 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Type[QuantizedTensor] = Float8Tensor + expected_tensor_class: Optional[Type] = None def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
+ try: + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor + self.expected_tensor_class = Float8Tensor + except ImportError: + pass def __repr__(self) -> str: + expected_tensor_class_name = self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" return ( f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " @@ -264,7 +271,7 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={self.expected_tensor_class.__name__}" + f"expected_tensor_class={expected_tensor_class_name}" ) @@ -297,17 +304,23 @@ class MXFP8BlockScaling(Recipe): fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Type[QuantizedTensor] = MXFP8Tensor + expected_tensor_class: Optional[Type] = None def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
+ try: + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor + self.expected_tensor_class = MXFP8Tensor + except ImportError: + pass def __repr__(self) -> str: + expected_tensor_class_name = self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " - f"expected_tensor_class={self.expected_tensor_class.__name__}" + f"expected_tensor_class={expected_tensor_class_name}" ) @@ -372,7 +385,7 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Type[QuantizedTensor] = Float8BlockwiseQTensor + expected_tensor_class: Optional[Type] = None def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" @@ -390,8 +403,14 @@ def __post_init__(self) -> None: assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." 
+ try: + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor + self.expected_tensor_class = Float8BlockwiseQTensor + except ImportError: + pass def __repr__(self) -> str: + expected_tensor_class_name = self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" return ( f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " @@ -406,5 +425,5 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={self.expected_tensor_class.__name__}" + f"expected_tensor_class={expected_tensor_class_name}" ) From c5c27a6f69feb0fd2c0c943825e125c31ece0664 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 May 2025 08:14:57 +0000 Subject: [PATCH 15/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/recipe/__init__.py | 24 ++++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index e59ad2baec..0b4611641b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -180,12 +180,15 @@ def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
try: from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor + self.expected_tensor_class = Float8Tensor except ImportError: pass def __repr__(self) -> str: - expected_tensor_class_name = self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" + expected_tensor_class_name = ( + self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" + ) return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " @@ -254,12 +257,15 @@ def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." try: from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor + self.expected_tensor_class = Float8Tensor except ImportError: pass def __repr__(self) -> str: - expected_tensor_class_name = self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" + expected_tensor_class_name = ( + self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" + ) return ( f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " @@ -310,12 +316,15 @@ def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." try: from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor + self.expected_tensor_class = MXFP8Tensor except ImportError: pass def __repr__(self) -> str: - expected_tensor_class_name = self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" + expected_tensor_class_name = ( + self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" + ) return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " @@ -404,13 +413,18 @@ def __post_init__(self) -> None: assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." 
try: - from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockwiseQTensor, + ) + self.expected_tensor_class = Float8BlockwiseQTensor except ImportError: pass def __repr__(self) -> str: - expected_tensor_class_name = self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" + expected_tensor_class_name = ( + self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" + ) return ( f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " From a2910d7a34a668b9fd42ceb5bbac90080a635083 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Mon, 19 May 2025 17:33:56 +0000 Subject: [PATCH 16/18] Let quantizer know about the recipe Signed-off-by: Evgeny Tsykunov --- tests/pytorch/test_recipe.py | 2 +- transformer_engine/common/recipe/__init__.py | 56 ++----------------- .../debug/pytorch/debug_quantization.py | 6 +- transformer_engine/pytorch/module/base.py | 22 +++++--- .../pytorch/tensor/float8_blockwise_tensor.py | 6 +- .../pytorch/tensor/float8_tensor.py | 9 ++- .../pytorch/tensor/mxfp8_tensor.py | 6 +- .../pytorch/tensor/quantized_tensor.py | 5 ++ transformer_engine/pytorch/utils.py | 4 +- 9 files changed, 48 insertions(+), 68 deletions(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 4f46691546..912dc67bfc 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -401,7 +401,7 @@ def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_reci with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()): with pytest.raises(RuntimeError) as excinfo: _ = linear(x) - assert "Tensor type mismatch" in str(excinfo.value) + assert "Recipe mismatch for " in str(excinfo.value) @pytest.mark.parametrize( "target_recipe_class, expected_quantizer_type, available_flag, reason", diff --git 
a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 0b4611641b..f1ecb33272 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -7,7 +7,7 @@ import warnings import os from enum import Enum -from typing import Literal, Optional, Type, Union, Callable, NamedTuple +from typing import Literal, Optional, Union, Callable, NamedTuple from pydantic.dataclasses import dataclass @@ -174,29 +174,18 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Optional[Type] = None def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - try: - from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor - - self.expected_tensor_class = Float8Tensor - except ImportError: - pass def __repr__(self) -> str: - expected_tensor_class_name = ( - self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" - ) return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={expected_tensor_class_name}" + f"fp8_mha={self.fp8_mha}" ) @@ -251,21 +240,11 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Optional[Type] = None def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
- try: - from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor - - self.expected_tensor_class = Float8Tensor - except ImportError: - pass def __repr__(self) -> str: - expected_tensor_class_name = ( - self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" - ) return ( f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " @@ -276,8 +255,7 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={expected_tensor_class_name}" + f"fp8_mha={self.fp8_mha}" ) @@ -310,26 +288,15 @@ class MXFP8BlockScaling(Recipe): fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Optional[Type] = None def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - try: - from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor - - self.expected_tensor_class = MXFP8Tensor - except ImportError: - pass def __repr__(self) -> str: - expected_tensor_class_name = ( - self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" - ) return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}, " - f"expected_tensor_class={expected_tensor_class_name}" + f"format={str(self.fp8_format).split('.')[1]}" ) @@ -394,7 +361,6 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - expected_tensor_class: Optional[Type] = None def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" @@ -412,19 +378,8 @@ def __post_init__(self) -> None: assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop." 
assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad." assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad." - try: - from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( - Float8BlockwiseQTensor, - ) - - self.expected_tensor_class = Float8BlockwiseQTensor - except ImportError: - pass def __repr__(self) -> str: - expected_tensor_class_name = ( - self.expected_tensor_class.__name__ if self.expected_tensor_class else "None" - ) return ( f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " @@ -438,6 +393,5 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}, " - f"expected_tensor_class={expected_tensor_class_name}" + f"fp8_mha={self.fp8_mha}" ) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 4a7a156a0a..118a9dbe1f 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -14,7 +14,7 @@ import transformer_engine_torch as tex - +from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch.tensor.quantized_tensor import ( QuantizedTensor, Quantizer, @@ -455,6 +455,10 @@ def any_feature_enabled(self) -> bool: return True return False + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + """Probably not needed for debug quantizer""" + return None + class DebugQuantizedTensor: """ diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2c1fdc5a9a..edad35a680 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1387,15 +1387,19 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: recipe = self.fp8_meta["recipe"] weight_tensors = 
[getattr(self, name) for name in self.weight_names] for i, tensor in enumerate(weight_tensors): - if isinstance(tensor, QuantizedTensorBase) and not isinstance( - tensor, recipe.expected_tensor_class - ): - raise RuntimeError( - f"Tensor type mismatch for '{self.weight_names[i]}': expected" - f" {recipe.expected_tensor_class.__name__} for recipe" - f" {recipe.__class__.__name__}, got {tensor.__class__.__name__}. Please check" - " the recipes assigned during fp8_model_init() and fp8_autocast() calls." - ) + if isinstance(tensor, QuantizedTensorBase): + quantizer = tensor._get_quantizer() + if quantizer is None: + continue + compatible_recipe_class = quantizer._get_compatible_recipe() + if compatible_recipe_class is None: + continue + if not isinstance(recipe, compatible_recipe_class): + raise RuntimeError( + f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe " + f"{compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}. Please check the recipes " + f"assigned during fp8_model_init() and fp8_autocast() calls." 
+ ) def _turn_off_unsupported_features_in_debug(self): if ( diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ce4137c660..2d1231b765 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -4,12 +4,13 @@ """Tensor class with FP8 data quantized with NxN tiles""" from __future__ import annotations -from typing import Optional, Tuple, Iterable +from typing import Optional, Tuple, Iterable, Union import math import torch import transformer_engine_torch as tex +from transformer_engine.common.recipe import Float8BlockScaling, Recipe from transformer_engine_torch import DType as TE_DType from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc @@ -229,6 +230,9 @@ def calibrate(self, tensor: torch.Tensor) -> None: # where state from an estimator influences distribution parameters. pass + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return Float8BlockScaling + class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. 
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 9c8fb6a1a2..f572c4378b 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -4,12 +4,13 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Optional, Tuple, Iterable +from typing import Optional, Tuple, Iterable, Union import warnings import torch import transformer_engine_torch as tex +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from transformer_engine_torch import DType as TE_DType from ..utils import canonicalize_process_group, devices_match from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func @@ -166,6 +167,9 @@ def create_tensor_from_data( quantizer=self, ) + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return DelayedScaling + class Float8CurrentScalingQuantizer(Quantizer): """Builder class for FP8 tensors with per-tensor current scaling @@ -328,6 +332,9 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: """Get process group for amax reduction""" return canonicalize_process_group(self.amax_reduction_group) + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return Float8CurrentScaling + class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5b3532b301..a7336435ff 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -6,11 +6,12 @@ from __future__ import annotations from collections.abc import Iterable import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import transformer_engine_torch as tex +from transformer_engine.common.recipe 
import MXFP8BlockScaling, Recipe from transformer_engine_torch import DType as TE_DType from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple @@ -135,6 +136,9 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return MXFP8BlockScaling + class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 155113738b..a3cbe02f16 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -13,6 +13,7 @@ from torch.utils._pytree import tree_map import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe class QuantizedTensorBase: @@ -238,6 +239,10 @@ def copy(self) -> Quantizer: """Create shallow copy""" return copy.copy(self) + @abc.abstractmethod + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + """Returns recipe class that is compatible with this quantizer""" + class _QuantizeFunc(torch.autograd.Function): """Cast to FP8 from other dtype""" diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 742d76dc12..3abebdf1e4 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -11,6 +11,7 @@ import numpy as np import torch +import transformer_engine.pytorch.cpp_extensions as ext from . 
import torch_version from ..debug.pytorch.debug_quantization import DebugQuantizedTensor @@ -450,9 +451,6 @@ def is_non_tn_fp8_gemm_supported() -> bool: @functools.lru_cache(maxsize=None) def get_cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" - # Import locally to avoid circular dependencies (cpp_extensions imports utils). - import transformer_engine.pytorch.cpp_extensions as ext # pylint: disable=import-outside-toplevel - encoded_version = ext.get_cudnn_version() major_version_magnitude = 1000 if encoded_version < 90000 else 10000 major, encoded_version = divmod(encoded_version, major_version_magnitude) From 7d7841a2a9d02462c2a02c8743c0520dd2bbd519 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 May 2025 17:35:04 +0000 Subject: [PATCH 17/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index edad35a680..83414939ad 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1396,9 +1396,10 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: continue if not isinstance(recipe, compatible_recipe_class): raise RuntimeError( - f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe " - f"{compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}. Please check the recipes " - f"assigned during fp8_model_init() and fp8_autocast() calls." + f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe" + f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}." + " Please check the recipes assigned during fp8_model_init() and" + " fp8_autocast() calls." 
) def _turn_off_unsupported_features_in_debug(self): From e3576062294d4e0b8a935f9dd5dcbbc51e046373 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Mon, 19 May 2025 17:51:21 +0000 Subject: [PATCH 18/18] Fix imports Signed-off-by: Evgeny Tsykunov --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 2 +- transformer_engine/pytorch/tensor/float8_tensor.py | 2 +- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 2d1231b765..4ab04da83f 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -9,9 +9,9 @@ import math import torch import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import Float8BlockScaling, Recipe -from transformer_engine_torch import DType as TE_DType from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index f572c4378b..1c3e575473 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -9,9 +9,9 @@ import torch import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe -from transformer_engine_torch import DType as TE_DType from ..utils import canonicalize_process_group, devices_match from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func from .quantized_tensor import QuantizedTensor, 
Quantizer, _IdentityFunc diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index a7336435ff..c930cdbff5 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -10,9 +10,9 @@ import torch import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe -from transformer_engine_torch import DType as TE_DType from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple