Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 88 additions & 97 deletions tests/pytorch/distributed/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,28 @@
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex

# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe


# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")

Expand Down Expand Up @@ -63,11 +72,12 @@ def reset_rng(seed: int = 1234) -> None:
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
test_is_quantized: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
Expand All @@ -76,78 +86,55 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.

If a quantization scheme is provided, the tensor values are
quantized so that they are representable.

"""

# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Orthogonal but why does this device arg exist here?

Copy link
Copy Markdown
Collaborator Author

@timmoon10 timmoon10 Jun 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly for completeness. We have a ref_device option since I prefer computing the reference impl on CPU, which helps catch CUDA-related bugs.

)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()

# Make sure reference and test tensors match each other
ref.copy_(test)

ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test


def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    """

    # Transformer Engine dtypes: FP8 formats get coarse tolerances
    # derived from their machine epsilon; other TE dtypes are mapped
    # onto the equivalent PyTorch dtype and handled below.
    if isinstance(dtype, tex.DType):
        if dtype == tex.DType.kFloat8E4M3:
            return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
        if dtype == tex.DType.kFloat8E5M2:
            return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
        te_to_torch = {
            tex.DType.kByte: torch.uint8,
            tex.DType.kInt32: torch.int32,
            tex.DType.kFloat32: torch.float32,
            tex.DType.kFloat16: torch.half,
            tex.DType.kBFloat16: torch.bfloat16,
        }
        dtype = te_to_torch[dtype]

    # PyTorch dtypes (integer dtypes intentionally fall through to the error)
    torch_tols = {
        torch.float16: dict(rtol=1e-3, atol=1e-5),
        torch.bfloat16: dict(rtol=1.6e-2, atol=1e-5),
        torch.float32: dict(rtol=1.3e-6, atol=1e-5),
        torch.float64: dict(rtol=1e-7, atol=1e-7),
    }
    if dtype in torch_tols:
        return torch_tols[dtype]
    raise ValueError(f"Unsupported dtype ({dtype})")


def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
    """Make recipe for quantization scheme

    Supported names: ``"fp8"``/``"fp8_delayed_scaling"`` (FP8 with
    delayed scaling), ``"fp8_current_scaling"`` (FP8 with per-tensor
    current scaling), and ``"mxfp8"``. Returns ``None`` when no
    quantization scheme is requested.

    Raises ``ValueError`` for any other name.

    """
    if name is None:
        return None
    # "fp8" is kept as a backward-compatible alias for delayed scaling,
    # matching the quantizer selection which treats the two names the same
    if name in ("fp8", "fp8_delayed_scaling"):
        return transformer_engine.common.recipe.DelayedScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    if name == "fp8_current_scaling":
        # NOTE(review): assumes transformer_engine.common.recipe exposes a
        # Float8CurrentScaling recipe — confirm against the installed TE version
        return transformer_engine.common.recipe.Float8CurrentScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    if name == "mxfp8":
        return transformer_engine.common.recipe.MXFP8BlockScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    raise ValueError(f"Unsupported quantization scheme ({name})")


def _test_all_reduce(
*,
local_size: int = 17,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:

# Distributed process group
Expand All @@ -156,22 +143,25 @@ def _test_all_reduce(
world_size = torch.distributed.get_world_size(process_group)

# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [local_size]
in_shape = [world_size, local_size, local_size]
out_shape = [local_size, local_size]

# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)

# Plain PyTorch implementation
Expand Down Expand Up @@ -199,10 +189,10 @@ def _test_all_reduce(

def _test_all_gather(
*,
local_size: int = 13,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:

# Distributed process group
Expand All @@ -211,26 +201,29 @@ def _test_all_gather(
world_size = torch.distributed.get_world_size(process_group)

# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [world_size, world_size * local_size]
in_shape = [world_size, local_size, local_size]
out_shape = [world_size, world_size * local_size, local_size]

# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)

# Plain PyTorch implementation
y_ref = x_ref.tile((world_size, 1)).reshape(out_shape)
y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape)
y_ref.backward(dy_ref)

# Convert to distributed tensors
Expand All @@ -257,10 +250,10 @@ def _test_all_gather(

def _test_reduce_scatter(
*,
local_size: int = 11,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:

# Distributed process group
Expand All @@ -269,22 +262,25 @@ def _test_reduce_scatter(
world_size = torch.distributed.get_world_size(process_group)

# Tensor dimensions
in_shape = [world_size, world_size * local_size]
out_shape = [world_size, local_size]
in_shape = [world_size, world_size * local_size, local_size]
out_shape = [world_size, local_size, local_size]

# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)

# Plain PyTorch implementation
Expand Down Expand Up @@ -324,7 +320,11 @@ def _test_basic_linear(
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:

# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return

# Distributed process group
process_group = world_group()
Expand All @@ -348,30 +348,23 @@ def _test_basic_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()

# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
Expand Down Expand Up @@ -468,7 +461,11 @@ def _test_linear(
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:

# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return

# Distributed process group
process_group = world_group()
Expand All @@ -492,21 +489,16 @@ def _test_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
Expand All @@ -520,13 +512,11 @@ def _test_linear(
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()

# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
Expand Down Expand Up @@ -773,9 +763,10 @@ def run_parallel_tests() -> None:
if rank == 0:
print(f"Running _test_all_reduce")
_test_all_reduce()
if rank == 0:
print(f"Running _test_all_gather")
_test_all_gather()
for quantization in quantization_list:
if rank == 0:
print(f"Running _test_all_gather with quantization={quantization}")
_test_all_gather(quantization=quantization)
if rank == 0:
print(f"Running _test_reduce_scatter")
_test_reduce_scatter()
Expand Down
Loading