diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 2ca133e77b..7d54085f0b 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -10,6 +10,7 @@ import pytest import os +import transformer_engine.pytorch as te from transformer_engine.pytorch.fp8 import ( fp8_autocast, FP8GlobalStateManager, @@ -36,11 +37,14 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.module.base import get_workspace +import transformer_engine.pytorch.ops from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8Quantizer, Float8CurrentScalingQuantizer, + Float8Quantizer, + Float8Tensor, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint from test_numerics import reset_rng_states, dtype_tols @@ -1338,3 +1342,69 @@ def backward(ctx, grad_output): # Assert that gradients are the same torch.testing.assert_close(grad_checkpoint, grad_standard) + + +@pytest.mark.parametrize( + "module_name", + ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"), +) +@pytest.mark.parametrize("quantization", ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8")) +@pytest.mark.parametrize("heuristic", ("performance", "inference")) +def test_quantized_weight_heuristics( + module_name: str, + quantization: Optional[str], + heuristic: str, +) -> None: + """Test heuristics for initializing quantized weights""" + + # Skip invalid configurations + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + # Construct quantization recipe + quantization_recipe = None + if quantization == "fp8_delayed_scaling": + quantization_recipe = recipe.DelayedScaling(heuristic=heuristic) + elif quantization == "fp8_current_scaling": + quantization_recipe = recipe.Float8CurrentScaling(heuristic=heuristic) + elif quantization == "mxfp8": + quantization_recipe = recipe.MXFP8BlockScaling(heuristic=heuristic) + + # Construct module + module = None + with fp8_model_init(recipe=quantization_recipe): + if module_name == "Linear": + module = Linear(32, 32) + elif module_name == "LayerNormLinear": + module = LayerNormLinear(32, 32) + elif module_name == "LayerNormMLP": + module = LayerNormMLP(32, 32) + elif module_name == "GroupedLinear": + module = GroupedLinear(1, 32, 32) + elif module_name == "ops.Linear": + module = transformer_engine.pytorch.ops.Linear(32, 32) + + # Check that weight parameters have expected data + for param in module.parameters(): + if isinstance(param, Float8Tensor): + assert param._data is not None, "Missing FP8 data" + if heuristic == "performance" and get_device_compute_capability() < (10, 0): + assert ( + param._transpose is not None and not param._transpose_invalid + ), "FP8 transpose is expected with 'performance' heuristic on Hopper" + if heuristic == "inference": + assert ( + param._transpose is None and param._transpose_invalid + ), "FP8 transpose is not expected for inference" + if isinstance(param, MXFP8Tensor): + assert param._rowwise_data is not None, "Missing row-wise MXFP8 data" + if heuristic == "inference": + assert ( + param._columnwise_data is None + ), "Column-wise MXFP8 data is not expected for inference" + else: + assert ( + param._columnwise_data is not None + ), "Column-wise MXFP8 data is expected for training" diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index fc8d73a136..25fd8d9cf6 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -61,10 +61,12 @@ class QParams: amax_epsilon: float = 0.0 +@dataclass class Recipe: - """ - Base recipe class. - """ + """Configuration for quantization scheme.""" + + # Recipe-specific heuristics (options: "performance", "inference") + heuristic: str = "performance" def mxfp8(self): """Whether the given recipe is MXFP8 block scaling.""" @@ -185,7 +187,8 @@ def __repr__(self) -> str: f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"heuristic={self.heuristic}" ) @@ -228,7 +231,8 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"heuristic={self.heuristic}" ) @@ -269,7 +273,8 @@ def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}" + f"format={str(self.fp8_format).split('.')[1]}, " + f"heuristic={self.heuristic}" ) @@ -349,5 +354,6 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"heuristic={self.heuristic}" ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index adcc1a9258..e6a9648f22 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1152,18 +1152,30 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) - # If primary weights are in fp8, wrap the parameter as FP8Tensor + # Wrap parameters in QuantizedTensor if needed fp8_meta_index = self.param_init_meta[name].fp8_meta_index high_precision_init_val = None if self.primary_weights_in_fp8 and fp8_meta_index is not None: + + # Keep high-precision values on CPU if needed if self.preserve_high_precision_init_val: high_precision_init_val = param.detach().cpu() + # Get quantizer quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] - assert ( - quantizer is not None - ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe. + if quantizer is None: + raise RuntimeError("Weight quantizer has not been initialized") quantizer.internal = False + + # Recipe-specific quantizer configuration + recipe = self.fp8_meta["recipe"] + if recipe is not None: + if recipe.heuristic == "inference": + # Weight needs column-wise usage for dgrad + # GEMM, so not needed for inference + quantizer.set_usage(rowwise=True, columnwise=False) + + # Quantize parameter param = quantizer(param) # Redo parameter wrap in case we broke it above @@ -1171,6 +1183,8 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. param = torch.nn.Parameter(param) + + # Keep high-precision values on CPU if needed if high_precision_init_val is not None: # - Master weights are initialized from model weights, if we use fp8 primary @@ -1214,7 +1228,7 @@ def get_weight_workspace( fsdp_group: Optional[dist_group_type] = None, workspace_dtype: Optional[torch.dtype] = None, ) -> QuantizedTensor: - """Get FP8 workspace buffer and maybe update its values + """Get workspace buffer for weights and maybe update its values The workspace buffer may be cached for future function calls. @@ -1238,15 +1252,19 @@ def get_weight_workspace( workspace_dtype: torch.dtype, default = None If weight workspace contains high-precision tensor - for example for debug quantization, this is dtype of the tensor. + """ - # FP8 primary weights + # Handle case where weights are already quantized + # Note: Make sure weights have required usages, but do not + # destroy unnecessary usages since they may be used later. if isinstance(tensor, QuantizedTensor): - if update_workspace and quantizer is not None: - tensor.update_usage( - rowwise_usage=quantizer.rowwise_usage, - columnwise_usage=quantizer.columnwise_usage, - ) + update_rowwise_usage = True if quantizer.rowwise_usage else None + update_columnwise_usage = True if quantizer.columnwise_usage else None + tensor.update_usage( + rowwise_usage=update_rowwise_usage, + columnwise_usage=update_columnwise_usage, + ) return tensor # Try getting workspace from cache diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 0e786ca96f..47b39d55e0 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -290,10 +290,13 @@ def reset_parameters(self) -> None: # Quantize if needed if self._with_quantized_weight: quantizer = self.get_quantizer("forward", 1) - quantizer.set_usage( - rowwise=True, - columnwise=torch.is_grad_enabled(), - ) + recipe = self._fp8_metas["forward"]["recipe"] + with_columnwise_usage = True + if recipe.heuristic == "inference": + # Weight needs column-wise usage for dgrad GEMM, so + # not needed for inference + with_columnwise_usage = False + quantizer.set_usage(rowwise=True, columnwise=with_columnwise_usage) with torch.no_grad(): weight = quantizer(weight) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index c930cdbff5..e20927c24a 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -384,6 +384,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Quantize to FP8 assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.internal = False self.data = self._quantizer.quantize(tensor) if self.requires_grad != tensor.requires_grad: self.requires_grad_(requires_grad=tensor.requires_grad)