25 changes: 19 additions & 6 deletions test/inductor/test_fp8.py
@@ -12,6 +12,7 @@
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfRocmArch,
TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -32,6 +33,17 @@
FP16_MAX_POS: float = torch.finfo(torch.float16).max
EPS: float = 1e-12

# fp8 data types for the inductor-based fp8 tests. These can be different
# from the ones used in eager mode.
f8_type_pair = (torch.float8_e4m3fn, torch.float8_e5m2)
if torch.version.hip:
arch = torch.cuda.get_device_properties(0).gcnArchName
if "gfx94" in arch:
# for gfx942, use the fnuz data types.
f8_type_pair = (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
elif "gfx120" in arch:
# for gfx1200 and gfx1201, e4m3 is not supported by Triton.
f8_type_pair = (torch.float8_e5m2,)

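For orientation, here is a minimal sketch (assuming a CUDA or ROCm build of PyTorch with a visible GPU) of how the arch-dependent pair above can be exercised through an inductor-compiled cast. The helper name is made up for illustration and is not part of the test file:

import torch

def _check_compiled_fp8_cast(dst_dtype: torch.dtype) -> None:
    # Compile a plain dtype cast through inductor and confirm the output dtype.
    @torch.compile(backend="inductor")
    def cast(x):
        return x.to(dst_dtype)

    x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
    assert cast(x).dtype == dst_dtype

for dt in f8_type_pair:
    _check_compiled_fp8_cast(dt)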
def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
# The default behavior in PyTorch for casting to `float8_e4m3fn`
@@ -180,10 +192,11 @@ def fp8_matmul_unwrapped(x):
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
y_fp8 = compiled_fp8_matmul(x) # noqa: F841

@skipIfRocmArch(("gfx1200", "gfx1201"))
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize("shape", ("15,3,13", "4,2048,4096"))
@parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)])
@parametrize("dst_types", [f8_type_pair])
def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda")
e4m3, e5m2 = dst_types
@@ -227,7 +240,7 @@ def fp8_cast(x, dtype):

@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("dst_dtype", f8_type_pair)
@parametrize("shape", ("16,16,16", "4,2048,4096"))
def test_to_fp8_saturated(
self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
@@ -249,7 +262,7 @@ def fp8_saturated(x, dtype):

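As background for test_to_fp8_saturated above: a hedged sketch of a saturated fp8 cast, i.e. clamping to the destination dtype's finite range before casting so out-of-range values map to +/-max instead of overflowing. It mirrors the intent of _to_fp8_saturated / fp8_saturated but is an illustration, not the upstream helper:

import torch
from torch import Tensor

def to_fp8_saturated_sketch(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
    # Clamp to the finite range of the target fp8 dtype, then cast.
    max_pos = torch.finfo(float8_dtype).max
    return x.clamp(min=-max_pos, max=max_pos).to(float8_dtype)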
@unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("float8_dtype", f8_type_pair)
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
@@ -274,7 +287,7 @@ def amax_fp8(x: Tensor, scale: Tensor):
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)

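For readers new to the pattern used by the amax tests: a hedged sketch of amax-based fp8 quantization, where the absolute max of the tensor is recorded (to drive later scale updates) and the current scale is applied with a saturated cast. This illustrates the technique and is not the amax_fp8 helper defined in the test:

import torch
from torch import Tensor

def amax_fp8_sketch(x: Tensor, scale: Tensor, float8_dtype: torch.dtype):
    # Record the absolute max for the caller, then apply the current scale
    # with a saturated cast so nothing overflows the fp8 range.
    amax = x.abs().max()
    fp8_max = torch.finfo(float8_dtype).max
    x_fp8 = (x * scale).clamp(min=-fp8_max, max=fp8_max).to(float8_dtype)
    return x_fp8, amax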
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("float8_dtype", f8_type_pair)
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
@@ -305,7 +318,7 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):

@unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("float8_dtype", f8_type_pair)
@parametrize("amax_keep_dim", (True, False))
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_layernorm_fp8_quant(
@@ -347,7 +360,7 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
)

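Likewise, a hedged sketch of the layer-norm-then-quantize pattern behind the layernorm tests: normalize in fp32, then apply the scale with a saturated cast. Again an illustration, not the ln_fp8 used in the test:

import torch
import torch.nn.functional as F
from torch import Tensor

def ln_fp8_sketch(x: Tensor, scale: Tensor, float8_dtype: torch.dtype) -> Tensor:
    # LayerNorm over the last dimension in fp32, then scale and saturated-cast.
    y = F.layer_norm(x.float(), normalized_shape=[x.shape[-1]])
    fp8_max = torch.finfo(float8_dtype).max
    return (y * scale).clamp(min=-fp8_max, max=fp8_max).to(float8_dtype)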
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("float8_dtype", f8_type_pair)
@parametrize("shape", ("4,2048,4096",))
@parametrize("keepdim", (False, True))
def test_layernorm_fp8_quant_benchmark(