Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
import os

import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import (
fp8_autocast,
FP8GlobalStateManager,
Expand All @@ -36,11 +37,14 @@
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
import transformer_engine.pytorch.ops
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
Float8Quantizer,
Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols
Expand Down Expand Up @@ -1338,3 +1342,69 @@ def backward(ctx, grad_output):

# Assert that gradients are the same
torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize("quantization", ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"))
@pytest.mark.parametrize("heuristic", ("performance", "inference"))
def test_quantized_weight_heuristics(
    module_name: str,
    quantization: Optional[str],
    heuristic: str,
) -> None:
    """Test heuristics for initializing quantized weights"""

    # Skip configurations the current device cannot run
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Map quantization scheme name to its recipe class
    recipe_classes = {
        "fp8_delayed_scaling": recipe.DelayedScaling,
        "fp8_current_scaling": recipe.Float8CurrentScaling,
        "mxfp8": recipe.MXFP8BlockScaling,
    }
    quantization_recipe = None
    recipe_cls = recipe_classes.get(quantization)
    if recipe_cls is not None:
        quantization_recipe = recipe_cls(heuristic=heuristic)

    # Construct module with quantized primary weights
    module_factories = {
        "Linear": lambda: Linear(32, 32),
        "LayerNormLinear": lambda: LayerNormLinear(32, 32),
        "LayerNormMLP": lambda: LayerNormMLP(32, 32),
        "GroupedLinear": lambda: GroupedLinear(1, 32, 32),
        "ops.Linear": lambda: transformer_engine.pytorch.ops.Linear(32, 32),
    }
    with fp8_model_init(recipe=quantization_recipe):
        module = module_factories.get(module_name, lambda: None)()

    # Inspect weight params and check data buffers match the heuristic
    for param in module.parameters():
        if isinstance(param, Float8Tensor):
            assert param._data is not None, "Missing FP8 data"
            if heuristic == "performance" and get_device_compute_capability() < (10, 0):
                assert not (
                    param._transpose is None or param._transpose_invalid
                ), "FP8 transpose is expected with 'performance' heuristic on Hopper"
            if heuristic == "inference":
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
        if isinstance(param, MXFP8Tensor):
            assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
            if heuristic == "inference":
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"
            else:
                assert (
                    param._columnwise_data is not None
                ), "Column-wise MXFP8 data is expected for training"
20 changes: 13 additions & 7 deletions transformer_engine/common/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ class QParams:
amax_epsilon: float = 0.0


@dataclass
class Recipe:
"""
Base recipe class.
"""
"""Configuration for quantization scheme."""

# Recipe-specific heuristics (options: "performance", "inference")
heuristic: str = "performance"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name is not the best - wouldn't you want performance during inference?

Copy link
Copy Markdown
Collaborator Author

@timmoon10 timmoon10 May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily if you're memory constrained.

Perhaps a naming scheme like "training_performance", "inference_performance", "training_memory", "inference_memory" would be more precise?


def mxfp8(self):
"""Whether the given recipe is MXFP8 block scaling."""
Expand Down Expand Up @@ -185,7 +187,8 @@ def __repr__(self) -> str:
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
f"fp8_mha={self.fp8_mha}, "
f"heuristic={self.heuristic}"
)


Expand Down Expand Up @@ -228,7 +231,8 @@ def __repr__(self) -> str:
f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, "
f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
f"fp8_mha={self.fp8_mha}, "
f"heuristic={self.heuristic}"
)


Expand Down Expand Up @@ -269,7 +273,8 @@ def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}"
f"format={str(self.fp8_format).split('.')[1]}, "
f"heuristic={self.heuristic}"
)


Expand Down Expand Up @@ -349,5 +354,6 @@ def __repr__(self) -> str:
f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, "
f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
f"fp8_mha={self.fp8_mha}, "
f"heuristic={self.heuristic}"
)
40 changes: 29 additions & 11 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,25 +1152,39 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
with get_rng_state_tracker().fork():
init_fn(param)

# If primary weights are in fp8, wrap the parameter as FP8Tensor
# Wrap parameters in QuantizedTensor if needed
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
high_precision_init_val = None
if self.primary_weights_in_fp8 and fp8_meta_index is not None:

# Keep high-precision values on CPU if needed
if self.preserve_high_precision_init_val:
high_precision_init_val = param.detach().cpu()

# Get quantizer
quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
assert (
quantizer is not None
) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe.
if quantizer is None:
raise RuntimeError("Weight quantizer has not been initialized")
quantizer.internal = False

# Recipe-specific quantizer configuration
recipe = self.fp8_meta["recipe"]
if recipe is not None:
if recipe.heuristic == "inference":
# Weight needs column-wise usage for dgrad
# GEMM, so not needed for inference
quantizer.set_usage(rowwise=True, columnwise=False)

# Quantize parameter
param = quantizer(param)

# Redo parameter wrap in case we broke it above
# NOTE: Currently this can only be broken when primary weights are in Fp8 but
# re-applying the nn.Parameter() wrap is a no-op when the input is already
# a parameter so we always re-apply it just for extra safety.
param = torch.nn.Parameter(param)

# Keep high-precision values on CPU if needed
if high_precision_init_val is not None:

# - Master weights are initialized from model weights, if we use fp8 primary
Expand Down Expand Up @@ -1214,7 +1228,7 @@ def get_weight_workspace(
fsdp_group: Optional[dist_group_type] = None,
workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor:
"""Get FP8 workspace buffer and maybe update its values
"""Get workspace buffer for weights and maybe update its values

The workspace buffer may be cached for future function calls.

Expand All @@ -1238,15 +1252,19 @@ def get_weight_workspace(
workspace_dtype: torch.dtype, default = None
If weight workspace contains high-precision tensor - for example
for debug quantization, this is dtype of the tensor.

"""

# FP8 primary weights
# Handle case where weights are already quantized
# Note: Make sure weights have required usages, but do not
# destroy unnecessary usages since they may be used later.
if isinstance(tensor, QuantizedTensor):
if update_workspace and quantizer is not None:
tensor.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
update_rowwise_usage = True if quantizer.rowwise_usage else None
update_columnwise_usage = True if quantizer.columnwise_usage else None
tensor.update_usage(
rowwise_usage=update_rowwise_usage,
columnwise_usage=update_columnwise_usage,
)
Comment on lines +1262 to +1267
Copy link
Copy Markdown
Collaborator Author

@timmoon10 timmoon10 May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Destroying unnecessary usages was causing problems when alternating between training steps (column-wise data needed) and validation steps (column-wise data not needed). See #1832 (comment).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH this issue is just because of optimizer not doing the right job with quantizing. If we made it so it uses the quantizer then we would not need this part at all.

Copy link
Copy Markdown
Collaborator Author

@timmoon10 timmoon10 May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The layers will configure the quantizer to avoid unnecessary allocations:

# Configure quantizer
if weight_quantizer is not None:
columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage:
columnwise_usage = (
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
)
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

This is what we want when allocating new buffers, but is overly aggressive when dealing with an existing QuantizedTensor. We could remove this logic from get_weight_workspace, but I don't like how it would ignore the configuration within the quantizer.

return tensor

# Try getting workspace from cache
Expand Down
11 changes: 7 additions & 4 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,13 @@ def reset_parameters(self) -> None:
# Quantize if needed
if self._with_quantized_weight:
quantizer = self.get_quantizer("forward", 1)
quantizer.set_usage(
rowwise=True,
columnwise=torch.is_grad_enabled(),
)
recipe = self._fp8_metas["forward"]["recipe"]
with_columnwise_usage = True
if recipe.heuristic == "inference":
# Weight needs column-wise usage for dgrad GEMM, so
# not needed for inference
with_columnwise_usage = False
quantizer.set_usage(rowwise=True, columnwise=with_columnwise_usage)
with torch.no_grad():
weight = quantizer(weight)

Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def _set_data(self, tensor: torch.Tensor) -> None:

# Quantize to FP8
assert self._quantizer is not None, "Can't quantize without a quantizer"
self._quantizer.internal = False
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes me think that internal should maybe be an option to tex.quantize rather than the member of quantizer.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have mixed opinions.

  • The layers know which tensors can be internal tensors and which must be PyTorch tensors. internal seems like a usage hint just like whether it needs row-wise/column-wise data.
  • We override internal multiple times, enough to make it feel redundant. These are usually in special cases outside a layer's normal operation (when primary weights are quantized, when setting tensor.data).

Maybe tex.quantize should have an option to force internal=False, but otherwise respect the quantizer's config? This seems a little overcomplicated though.

self.data = self._quantizer.quantize(tensor)
if self.requires_grad != tensor.requires_grad:
self.requires_grad_(requires_grad=tensor.requires_grad)
Expand Down