Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 104 additions & 1 deletion tests/pytorch/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,31 @@

import pytest
import torch
import warnings

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
_amax_and_scale_update,
get_default_fp8_recipe,
fp8_model_init,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch import Linear
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)


# FP8 per tensor delayed scaling
Expand Down Expand Up @@ -367,3 +377,96 @@ def setup_fp8_meta():
)

torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)

@pytest.mark.parametrize(
    "model_init_recipe",
    [
        pytest.param(
            MXFP8BlockScaling(),
            marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
        ),
        pytest.param(
            Float8BlockScaling(),
            marks=pytest.mark.skipif(
                not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
            ),
        ),
    ],
)
def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe):
    """Forward must fail when the autocast recipe disagrees with the weight tensor type."""
    # Quantize the weights under the parametrized (non-delayed-scaling) recipe.
    with fp8_model_init(enabled=True, recipe=model_init_recipe):
        layer = Linear(32, 32).cuda()

    inp = torch.randn(32, 32, device="cuda")
    # Running under DelayedScaling should trip the weight/recipe correspondence check.
    with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        with pytest.raises(RuntimeError, match="Recipe mismatch for "):
            layer(inp)

@pytest.mark.parametrize(
    "target_recipe_class, expected_quantizer_type, available_flag, reason",
    [
        pytest.param(
            MXFP8BlockScaling,
            MXFP8Quantizer,
            mxfp8_available,
            reason_for_no_mxfp8,
            id="DelayedScaling->MXFP8BlockScaling",
        ),
        pytest.param(
            Float8BlockScaling,
            Float8BlockQuantizer,
            fp8_block_scaling_available,
            reason_for_no_fp8_block_scaling,
            id="DelayedScaling->Float8BlockScaling",
        ),
    ],
)
def test_dynamic_recipe_update(
    self, target_recipe_class, expected_quantizer_type, available_flag, reason
):
    """Switching the fp8_autocast recipe mid-training must swap the module's
    quantizers to the new type and emit exactly one "Recipe type changed"
    warning on the first iteration under the new recipe.
    """
    if not available_flag:
        pytest.skip(reason)

    in_features = 32
    out_features = 32
    batch_size = 32
    linear = Linear(in_features, out_features).cuda()
    initial_recipe = DelayedScaling()

    # Run initial iterations with DelayedScaling
    for _ in range(3):
        x = torch.randn(batch_size, in_features, device="cuda")
        with fp8_autocast(enabled=True, fp8_recipe=initial_recipe):
            y = linear(x)
        loss = y.mean()
        loss.backward()

    # DelayedScaling maps to Float8Quantizer for the forward pass.
    for quantizer in linear.quantizers["scaling_fwd"]:
        assert isinstance(quantizer, Float8Quantizer)

    # Change recipe
    target_recipe = target_recipe_class()

    # Run subsequent iterations with the target recipe
    for i in range(3):
        x = torch.randn(batch_size, in_features, device="cuda")
        if i == 0:
            # Expect a warning on the first iteration with the new recipe
            with pytest.warns(UserWarning, match="Recipe type changed"):
                with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
                    y = linear(x)
            # Quantizers must already be rebuilt for the new recipe type.
            for quantizer in linear.quantizers["scaling_fwd"]:
                assert isinstance(quantizer, expected_quantizer_type)
        else:
            # No warning expected on subsequent iterations
            with warnings.catch_warnings():
                warnings.simplefilter("error")  # Raise error if unexpected warning occurs
                with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
                    y = linear(x)
        loss = y.mean()
        loss.backward()

    # Final check
    for quantizer in linear.quantizers["scaling_fwd"]:
        assert isinstance(quantizer, expected_quantizer_type)
3 changes: 2 additions & 1 deletion transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
A.scaling_mode == B.scaling_mode ||
(A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
(A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
"Inputs A and B to GEMM need to have compatible scaling modes!");
"Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " +
to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode));
NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
GemmParam ret;
Expand Down
9 changes: 8 additions & 1 deletion transformer_engine/common/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def __post_init__(self) -> None:

def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
Expand Down Expand Up @@ -245,6 +246,7 @@ def __post_init__(self) -> None:

def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
Expand Down Expand Up @@ -291,7 +293,11 @@ def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."

def __repr__(self) -> str:
    """Summarize the recipe: its concrete type, margin, and FP8 format."""
    # The flattened diff left both the old one-line return and the new
    # multi-line return in place, making the second unreachable; keep only
    # the new form, which also reports recipe_type.
    return (
        f"recipe_type={self.__class__.__name__}, "
        f"margin={self.margin}, "
        f"format={str(self.fp8_format).split('.')[1]}"
    )


@dataclass()
Expand Down Expand Up @@ -375,6 +381,7 @@ def __post_init__(self) -> None:

def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
Expand Down
6 changes: 5 additions & 1 deletion transformer_engine/debug/pytorch/debug_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import transformer_engine_torch as tex


from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
Expand Down Expand Up @@ -459,6 +459,10 @@ def any_feature_enabled(self) -> bool:
return True
return False

def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
    """Return the recipe class this quantizer is tied to.

    A debug quantizer is not bound to one scaling scheme, so it returns
    None; callers of this hook treat None as "skip the recipe
    correspondence check" for this quantizer.
    """
    return None


class DebugQuantizedTensor(QuantizedTensorBase):
"""
Expand Down
63 changes: 62 additions & 1 deletion transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor

Expand Down Expand Up @@ -811,6 +811,14 @@ def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
if state is None:
return

# TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
if "recipe" not in state:
# TE 1.x only supported delayed scaling, which was the default recipe
state["recipe"] = DelayedScaling()
# TE 1.x also saved scale_inv, which is not needed with Recipe object
state.pop("scale_inv_fwd", None)
state.pop("scale_inv_bwd", None)

# Load extra items
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"] = state["recipe"]
Expand Down Expand Up @@ -884,6 +892,8 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
# assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
_original_recipe = self.fp8_meta.get("recipe", None)

self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
Expand Down Expand Up @@ -922,6 +932,19 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:

self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
):
warnings.warn(
f"Recipe type changed from {_original_recipe.__class__.__name__} "
f"to {_current_recipe.__class__.__name__}. "
"This may affect model behavior."
)
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()

@contextmanager
def prepare_forward(
self,
Expand All @@ -946,6 +969,7 @@ def prepare_forward(

self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence()

if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
assert self.fp8_meta["recipe"].reduce_amax, (
Expand Down Expand Up @@ -1346,6 +1370,43 @@ def _validate_name(self):
)
self.name = f"Layer_{TEDebugState.get_layer_count()}"

def _check_weight_tensor_recipe_correspondence(self) -> None:
    """
    Verify that the weight tensor types match their corresponding recipe type.
    This is invoked in the forward().

    This establishes a 1:1 correspondence between recipe types and tensor types:
    - DelayedScaling → Float8Tensor
    - Float8CurrentScaling → Float8Tensor
    - MXFP8BlockScaling → MXFP8Tensor
    - Float8BlockScaling → Float8BlockTensor

    Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()),
    but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()).

    Raises:
        RuntimeError: if a quantized weight's compatible recipe class does not
            match the recipe currently stored in ``self.fp8_meta["recipe"]``.
    """
    # Only meaningful when FP8 execution or calibration is active.
    if not self.fp8 and not self.fp8_calibration:
        return
    # Nothing to check for modules without registered weight names.
    if not hasattr(self, "weight_names") or not self.weight_names:
        return

    recipe = self.fp8_meta["recipe"]
    # NOTE(review): getattr on nn.Module attributes was reported to be slow;
    # a faster accessor may exist — confirm before optimizing this lookup.
    weight_tensors = [getattr(self, name) for name in self.weight_names]
    for i, tensor in enumerate(weight_tensors):
        if isinstance(tensor, QuantizedTensorBase):
            quantizer = tensor._get_quantizer()
            if quantizer is None:
                continue
            # Recipe-agnostic quantizers (e.g. debug) return None and are skipped.
            compatible_recipe_class = quantizer._get_compatible_recipe()
            if compatible_recipe_class is None:
                continue
            if not isinstance(recipe, compatible_recipe_class):
                raise RuntimeError(
                    f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
                    f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
                    " Please check the recipes assigned during fp8_model_init() and"
                    " fp8_autocast() calls."
                )

def _turn_off_unsupported_features_in_debug(self):
if (
getattr(self, "ub_bulk_wgrad", False)
Expand Down
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

"""Tensor class with FP8 data quantized with NxN tiles"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable
from typing import Optional, Tuple, Iterable, Union

import math
import torch
import transformer_engine_torch as tex

from transformer_engine_torch import DType as TE_DType

from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
Expand Down Expand Up @@ -229,6 +230,9 @@ def calibrate(self, tensor: torch.Tensor) -> None:
# where state from an estimator influences distribution parameters.
pass

def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
    """Recipe class whose tensors this quantizer produces (Float8BlockScaling)."""
    return Float8BlockScaling


class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
Expand Down
11 changes: 9 additions & 2 deletions transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable
from typing import Optional, Tuple, Iterable, Union
import warnings

import torch
import transformer_engine_torch as tex

from transformer_engine_torch import DType as TE_DType

from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
from ..utils import canonicalize_process_group, devices_match
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
Expand Down Expand Up @@ -166,6 +167,9 @@ def create_tensor_from_data(
quantizer=self,
)

def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
    """Recipe class whose tensors this quantizer produces (DelayedScaling)."""
    return DelayedScaling


class Float8CurrentScalingQuantizer(Quantizer):
"""Builder class for FP8 tensors with per-tensor current scaling
Expand Down Expand Up @@ -328,6 +332,9 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type:
"""Get process group for amax reduction"""
return canonicalize_process_group(self.amax_reduction_group)

def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
    """Recipe class whose tensors this quantizer produces (Float8CurrentScaling)."""
    return Float8CurrentScaling


class Float8Tensor(Float8TensorBase, QuantizedTensor):
"""Experimental tensor class with FP8 data
Expand Down
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
import transformer_engine_torch as tex

from transformer_engine_torch import DType as TE_DType

from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple

Expand Down Expand Up @@ -135,6 +136,9 @@ def calibrate(self, tensor: torch.Tensor) -> None:
# TODO(ksivamani): No calibration needed for mxfp8?
pass

def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
    """Recipe class whose tensors this quantizer produces (MXFP8BlockScaling)."""
    return MXFP8BlockScaling


class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
"""Experimental tensor class with FP8 data
Expand Down
Loading