From 72ef03041b7d5ad3afcd195760409c951305e731 Mon Sep 17 00:00:00 2001
From: Jagadish Krishnamoorthy
Date: Mon, 2 Jun 2025 15:44:15 -0700
Subject: [PATCH] [release/2.7] fp8 inductor tests: Add gfx120x support

On gfx120x, Triton supports float8_e5m2 but not e4m3.
Introduce f8_type_pair to select the fp8 dtypes used by the inductor
tests: for gfx942 use the fnuz dtypes, for gfx120x use only
float8_e5m2, and for all other archs use the default OCP fp8 dtypes.

Signed-off-by: Jagadish Krishnamoorthy
---
 test/inductor/test_fp8.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py
index 64086e5071c62..658db2dc5f5e4 100644
--- a/test/inductor/test_fp8.py
+++ b/test/inductor/test_fp8.py
@@ -12,6 +12,7 @@
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
+    skipIfRocmArch,
     TEST_WITH_ROCM,
 )
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -32,6 +33,17 @@
 FP16_MAX_POS: float = torch.finfo(torch.float16).max
 EPS: float = 1e-12
 
+# fp8 data types for the inductor based fp8 tests. These can be different
+# from the ones used in eager mode.
+f8_type_pair = (torch.float8_e4m3fn, torch.float8_e5m2)
+if torch.version.hip:
+    arch = torch.cuda.get_device_properties(0).gcnArchName
+    if "gfx94" in arch:
+        # for gfx942, use the fnuz data types.
+        f8_type_pair = (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
+    elif "gfx120" in arch:
+        # for gfx1200 and gfx1201, e4m3 is not supported on triton.
+        f8_type_pair = (torch.float8_e5m2,)
 
 def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
     # The default behavior in PyTorch for casting to `float8_e4m3fn`
@@ -180,10 +192,11 @@ def fp8_matmul_unwrapped(x):
         x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
         y_fp8 = compiled_fp8_matmul(x)  # noqa: F841
 
+    @skipIfRocmArch(("gfx1200","gfx1201"))
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
     @parametrize("shape", ("15,3,13", "4,2048,4096"))
-    @parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)])
+    @parametrize("dst_types", [f8_type_pair])
     def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
         dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda")
         e4m3, e5m2 = dst_types
@@ -227,7 +240,7 @@ def fp8_cast(x, dtype):
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
-    @parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("dst_dtype", f8_type_pair)
     @parametrize("shape", ("16,16,16", "4,2048,4096"))
     def test_to_fp8_saturated(
         self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
@@ -249,7 +262,7 @@ def fp8_saturated(x, dtype):
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
     def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
         float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
@@ -274,7 +287,7 @@ def amax_fp8(x: Tensor, scale: Tensor):
         torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
     def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
         float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
@@ -305,7 +318,7 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("amax_keep_dim", (True, False))
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
     def test_layernorm_fp8_quant(
@@ -347,7 +360,7 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
         )
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("shape", ("4,2048,4096",))
     @parametrize("keepdim", (False, True))
     def test_layernorm_fp8_quant_benchmark(