From 2752a8cd80a492aba04de705185521c3f60118b5 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 4 Jun 2025 00:16:18 +0000 Subject: [PATCH 1/3] Do not initialize quantized weights with column-wise usage in inference mode Signed-off-by: Tim Moon --- tests/pytorch/test_sanity.py | 81 ++++++++++++++++++- transformer_engine/pytorch/module/base.py | 32 +++++--- .../pytorch/module/layernorm_mlp.py | 4 +- .../pytorch/tensor/mxfp8_tensor.py | 1 + 4 files changed, 104 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 2ca133e77b..8086449c7f 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -10,6 +10,7 @@ import pytest import os +import transformer_engine.pytorch from transformer_engine.pytorch.fp8 import ( fp8_autocast, FP8GlobalStateManager, @@ -38,9 +39,11 @@ from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8Quantizer, Float8CurrentScalingQuantizer, + Float8Quantizer, + Float8Tensor, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint from test_numerics import reset_rng_states, dtype_tols @@ -1338,3 +1341,79 @@ def backward(ctx, grad_output): # Assert that gradients are the same torch.testing.assert_close(grad_checkpoint, grad_standard) + + +@pytest.mark.parametrize( + "module_name", + ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"), +) +@pytest.mark.parametrize( + "quantization", + (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"), +) +def test_inference_mode( + module_name: str, + quantization: Optional[str], +) -> None: + """Test heuristics for initializing quantized weights""" + + # Tensor dimensions + sequence_length = 32 + hidden_size = 32 + + # Skip invalid configurations + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + # Construct quantization recipe + with_quantization = quantization not in (None, "None") + quantization_recipe = None + if quantization == "fp8_delayed_scaling": + quantization_recipe = recipe.DelayedScaling() + elif quantization == "fp8_current_scaling": + quantization_recipe = recipe.Float8CurrentScaling() + elif quantization == "mxfp8": + quantization_recipe = recipe.MXFP8BlockScaling() + + # Construct module + module = None + with torch.inference_mode(): + with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe): + if module_name == "Linear": + module = Linear(hidden_size, hidden_size) + elif module_name == "LayerNormLinear": + module = LayerNormLinear(hidden_size, hidden_size) + elif module_name == "LayerNormMLP": + module = LayerNormMLP(hidden_size, hidden_size) + elif module_name == "GroupedLinear": + module = GroupedLinear(1, hidden_size, hidden_size) + elif module_name == "ops.Linear": + module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size) + + def check_weights(): + """Helper function to check that weight parameters have expected data""" + for param in module.parameters(): + if isinstance(param, Float8Tensor): + assert param._data is not None, "Missing FP8 data" + assert ( + param._transpose is None and param._transpose_invalid + ), "FP8 transpose is not expected for inference" + if isinstance(param, MXFP8Tensor): + assert param._rowwise_data is not None, "Missing row-wise MXFP8 data" + assert ( + param._columnwise_data is None + ), "Column-wise MXFP8 data is not expected for inference" + + # Check that modules have expected weights after initialization + check_weights() + + # Check that modules have expected weights after forward pass + with torch.inference_mode(): + x = torch.zeros(sequence_length, hidden_size, device="cuda") + kwargs = {} + if module_name == "GroupedLinear": + kwargs["m_splits"] = [sequence_length] + y = module(x, **kwargs) + check_weights() diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 87794cc63b..edafdabb0d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1178,18 +1178,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) - # If primary weights are in fp8, wrap the parameter as FP8Tensor + # Wrap parameters in QuantizedTensor if needed fp8_meta_index = self.param_init_meta[name].fp8_meta_index high_precision_init_val = None if self.primary_weights_in_fp8 and fp8_meta_index is not None: + + # Keep high-precision values on CPU if needed if self.preserve_high_precision_init_val: high_precision_init_val = param.detach().cpu() + # Configure quantizer quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] - assert ( - quantizer is not None - ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe. + if quantizer is None: + raise RuntimeError("Weight quantizer has not been initialized") + quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False + + # Quantize parameter param = quantizer(param) # Redo parameter wrap in case we broke it above @@ -1197,6 +1202,8 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. param = torch.nn.Parameter(param) + + # Keep high-precision values on CPU if needed if high_precision_init_val is not None: # - Master weights are initialized from model weights, if we use fp8 primary @@ -1240,7 +1247,7 @@ def get_weight_workspace( fsdp_group: Optional[dist_group_type] = None, workspace_dtype: Optional[torch.dtype] = None, ) -> QuantizedTensor: - """Get FP8 workspace buffer and maybe update its values + """Get workspace buffer for weights and maybe update its values The workspace buffer may be cached for future function calls. @@ -1266,13 +1273,16 @@ def get_weight_workspace( for debug quantization, this is dtype of the tensor. """ - # FP8 primary weights + # Handle case where weights are already quantized + # Note: Make sure weights have required usages, but do not + # destroy unnecessary usages since they may be used later. if isinstance(tensor, QuantizedTensor): - if update_workspace and quantizer is not None: - tensor.update_usage( - rowwise_usage=quantizer.rowwise_usage, - columnwise_usage=quantizer.columnwise_usage, - ) + update_rowwise_usage = True if quantizer.rowwise_usage else None + update_columnwise_usage = True if quantizer.columnwise_usage else None + tensor.update_usage( + rowwise_usage=update_rowwise_usage, + columnwise_usage=update_columnwise_usage, + ) return tensor # Try getting workspace from cache diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6ff2763ee1..d89eef55cb 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1759,9 +1759,9 @@ def forward( fc2_bias = self.fc2_bias if self.use_bias else None if not self.fp8: if isinstance(fc1_weight, Float8Tensor): - fc1_weight = fc1_weight.from_float8() + fc1_weight = fc1_weight.dequantize() if isinstance(fc2_weight, Float8Tensor): - fc2_weight = fc2_weight.from_float8() + fc2_weight = fc2_weight.dequantize() # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index c930cdbff5..e20927c24a 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -384,6 +384,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Quantize to FP8 assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.internal = False self.data = self._quantizer.quantize(tensor) if self.requires_grad != tensor.requires_grad: self.requires_grad_(requires_grad=tensor.requires_grad) From 649b04c8a99c9fc2c432af3313030700fe0333f8 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 4 Jun 2025 00:36:17 +0000 Subject: [PATCH 2/3] Fix bug in test Signed-off-by: Tim Moon --- tests/pytorch/test_sanity.py | 3 ++- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 8086449c7f..f209643a6d 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1415,5 +1415,6 @@ def check_weights(): kwargs = {} if module_name == "GroupedLinear": kwargs["m_splits"] = [sequence_length] - y = module(x, **kwargs) + with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe): + y = module(x, **kwargs) check_weights() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index dc5d2aae89..04a05ebd65 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -276,7 +276,7 @@ def forward( # Configure quantizer if weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=True) + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d89eef55cb..73655a84ee 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -332,8 +332,8 @@ def forward( # which handles weight caching etc. # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) fc1_weight_final = module.get_weight_workspace( tensor=fc1_weight, quantizer=fc1_weight_quantizer, From 6bf26cc657656d2639d6bea6fee2fe50265846fa Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 9 Jun 2025 16:50:01 -0700 Subject: [PATCH 3/3] Use no-grad mode instead of inference mode in tests Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_sanity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index f209643a6d..a7ff2b2a91 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1379,7 +1379,7 @@ def test_inference_mode( # Construct module module = None - with torch.inference_mode(): + with torch.no_grad(): with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe): if module_name == "Linear": module = Linear(hidden_size, hidden_size)