Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
import os

import transformer_engine.pytorch
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

Copy link
Copy Markdown
Collaborator Author

@timmoon10 timmoon10 Jun 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it convenient to be able to access a class without explicitly doing from ... import ...:

module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)

It's just a matter of style though. Within the package we explicitly list the imports in order to guarantee only relative imports, but this isn't relevant for tests since we always use absolute imports. Also, Google's style guide recommends against relying on implicit submodule imports.

from transformer_engine.pytorch.fp8 import (
fp8_autocast,
FP8GlobalStateManager,
Expand Down Expand Up @@ -38,9 +39,11 @@
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
Float8Quantizer,
Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols
Expand Down Expand Up @@ -1338,3 +1341,80 @@ def backward(ctx, grad_output):

# Assert that gradients are the same
torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Check that quantized weights hold only inference-relevant data.

    Modules are constructed and run under no-grad/inference mode, so
    weight tensors should not carry data that is only needed for the
    backward pass (FP8 transposes, column-wise MXFP8 data).
    """

    # Problem sizes
    seq_len = 32
    num_channels = 32

    # Skip configurations unsupported on this device
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Build quantization recipe, if any
    recipe_factories = {
        "fp8_delayed_scaling": recipe.DelayedScaling,
        "fp8_current_scaling": recipe.Float8CurrentScaling,
        "mxfp8": recipe.MXFP8BlockScaling,
    }
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    recipe_factory = recipe_factories.get(quantization)
    if recipe_factory is not None:
        quantization_recipe = recipe_factory()

    # Construct module without gradient tracking
    module_factories = {
        "Linear": lambda: Linear(num_channels, num_channels),
        "LayerNormLinear": lambda: LayerNormLinear(num_channels, num_channels),
        "LayerNormMLP": lambda: LayerNormMLP(num_channels, num_channels),
        "GroupedLinear": lambda: GroupedLinear(1, num_channels, num_channels),
        "ops.Linear": lambda: transformer_engine.pytorch.ops.Linear(
            num_channels, num_channels
        ),
    }
    with torch.no_grad():
        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
            module_factory = module_factories.get(module_name)
            module = module_factory() if module_factory is not None else None

    def check_weights():
        """Assert that weight params carry only the data needed for inference"""
        for weight in module.parameters():
            if isinstance(weight, Float8Tensor):
                assert weight._data is not None, "Missing FP8 data"
                # Transpose data is only needed to compute weight gradients
                assert (
                    weight._transpose is None and weight._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(weight, MXFP8Tensor):
                assert weight._rowwise_data is not None, "Missing row-wise MXFP8 data"
                # Column-wise data is only needed for the backward pass
                assert (
                    weight._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Weights should be minimal immediately after initialization
    check_weights()

    # Weights should stay minimal after an inference-mode forward pass
    with torch.inference_mode():
        forward_kwargs = {}
        if module_name == "GroupedLinear":
            forward_kwargs["m_splits"] = [seq_len]
        x = torch.zeros(seq_len, num_channels, device="cuda")
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            y = module(x, **forward_kwargs)
    check_weights()
32 changes: 21 additions & 11 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,25 +1178,32 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
with get_rng_state_tracker().fork():
init_fn(param)

# If primary weights are in fp8, wrap the parameter as FP8Tensor
# Wrap parameters in QuantizedTensor if needed
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
high_precision_init_val = None
if self.primary_weights_in_fp8 and fp8_meta_index is not None:

# Keep high-precision values on CPU if needed
if self.preserve_high_precision_init_val:
high_precision_init_val = param.detach().cpu()

# Configure quantizer
quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
assert (
quantizer is not None
) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe.
if quantizer is None:
raise RuntimeError("Weight quantizer has not been initialized")
quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
quantizer.internal = False

# Quantize parameter
param = quantizer(param)

# Redo parameter wrap in case we broke it above
# NOTE: Currently this can only be broken when primary weights are in Fp8 but
# re-applying the nn.Parameter() wrap is a no-op when the input is already
# a parameter so we always re-apply it just for extra safety.
param = torch.nn.Parameter(param)

# Keep high-precision values on CPU if needed
if high_precision_init_val is not None:

# - Master weights are initialized from model weights, if we use fp8 primary
Expand Down Expand Up @@ -1240,7 +1247,7 @@ def get_weight_workspace(
fsdp_group: Optional[dist_group_type] = None,
workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor:
"""Get FP8 workspace buffer and maybe update its values
"""Get workspace buffer for weights and maybe update its values

The workspace buffer may be cached for future function calls.

Expand All @@ -1266,13 +1273,16 @@ def get_weight_workspace(
for debug quantization, this is dtype of the tensor.
"""

# FP8 primary weights
# Handle case where weights are already quantized
# Note: Make sure weights have required usages, but do not
# destroy unnecessary usages since they may be used later.
if isinstance(tensor, QuantizedTensor):
if update_workspace and quantizer is not None:
tensor.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
update_rowwise_usage = True if quantizer.rowwise_usage else None
update_columnwise_usage = True if quantizer.columnwise_usage else None
tensor.update_usage(
rowwise_usage=update_rowwise_usage,
columnwise_usage=update_columnwise_usage,
)
return tensor

# Try getting workspace from cache
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def forward(

# Configure quantizer
if weight_quantizer is not None:
weight_quantizer.set_usage(rowwise=True, columnwise=True)
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)

# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ def forward(
# which handles weight caching etc.
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc1_weight_final = module.get_weight_workspace(
tensor=fc1_weight,
quantizer=fc1_weight_quantizer,
Expand Down Expand Up @@ -1754,9 +1754,9 @@ def forward(
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.from_float8()
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.from_float8()
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def _set_data(self, tensor: torch.Tensor) -> None:

# Quantize to FP8
assert self._quantizer is not None, "Can't quantize without a quantizer"
self._quantizer.internal = False
self.data = self._quantizer.quantize(tensor)
if self.requires_grad != tensor.requires_grad:
self.requires_grad_(requires_grad=tensor.requires_grad)
Expand Down