diff --git a/aten/src/ATen/cuda/CUDADataType.h b/aten/src/ATen/cuda/CUDADataType.h
index 6ee6346732fa..fba4f855a29b 100644
--- a/aten/src/ATen/cuda/CUDADataType.h
+++ b/aten/src/ATen/cuda/CUDADataType.h
@@ -90,7 +90,7 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
     case c10::ScalarType::Float8_e5m2fnuz:
       return HIP_R_8F_E5M2_FNUZ;
 #endif
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
     case c10::ScalarType::Float4_e2m1fn_x2:
       return CUDA_R_4F_E2M1;
 #endif
diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h
index fe6d1161d1ba..4c46aa736b6a 100644
--- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h
+++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h
@@ -85,6 +85,15 @@ constexpr hipDataType HipDataTypeFor() {
   return static_cast<hipDataType>(500);
 }
+template <>
+constexpr hipDataType HipDataTypeFor<c10::Float4_e2m1fn_x2>() {
+#if ROCM_VERSION >= 70000
+  return HIP_R_4F_E2M1;
+#else
+  return static_cast<hipDataType>(33);
+#endif
+}
+
 template <typename T>
 int GetBatchFromParams(const GemmParams<T>* params) {
   return 1;
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 21e6f9f65dd7..1fc9e14189e4 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -1284,6 +1284,17 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   if (use_fast_accum) {
     TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
   }
+#ifdef USE_ROCM
+  if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2 || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
+    TORCH_CHECK(ROCM_VERSION >= 70000, "Float4_e2m1fn_x2 is only supported for ROCm 7.0 and above");
+  }
+  if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
+    TORCH_CHECK(ROCM_VERSION >= 70000, "Float8_e5m2 is only supported for ROCm 7.0 and above");
+  }
+  if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
+    TORCH_CHECK(ROCM_VERSION >= 70000, "Float8_e4m3fn is only supported for ROCm 7.0 and above");
+  }
+#endif
   if (bias) {
     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
     TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index 8ec832e40a16..e8ff44fd4098 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -882,6 +882,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # largest power of 2 representable in `torch.float8_e4m3fn`
 F8E4M3_LARGEST_POW2 = 8
+# largest power of 2 representable in `torch.float4_e2m1fn_x2`
+FP4E2M1FN_LARGEST_POW2 = 1.0
 # max value of `torch.float8_e4m3fn` (448)
 F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
 # exponent bias of `torch.float8_e8m0fnu`
@@ -890,14 +892,20 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 FP4_EBITS, FP4_MBITS = 2, 1
 FP4_MAX_VAL = 6.0
-def data_to_mx_scale(x, block_size):
+def data_to_mx_scale(x, block_size, recipe):
     # simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     # section 6.3, not all edge cases (such as NaN) are handled/tested
+    if recipe == "mxfp8":
+        largest_pow2 = F8E4M3_LARGEST_POW2
+    elif recipe == "mxfp4":
+        largest_pow2 = FP4E2M1FN_LARGEST_POW2
+    else:
+        raise ValueError(f"data_to_mx_scale(): Unsupported mx recipe: {recipe}")
     orig_shape = x.shape
     x = x.reshape(-1, block_size)
     max_abs = torch.amax(torch.abs(x), 1)
     largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
-    scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
+    scale_e8m0_unbiased = largest_p2_lt_max_abs - largest_pow2
     scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
     scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
     scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
@@ -1446,10 +1454,10 @@ def test_pack_uint4(self):
         (127, 96, 1024),
         (1025, 128, 96)
     ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
-    @parametrize("recipe", ["mxfp8", "nvfp4"])
-    def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
-        if recipe == "nvfp4" and fast_accum:
-            return unittest.skip("fast_accum not supported in nvfp4 cublas gemm, skipping")
+    @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
+    def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
+        if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
+            raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
         device = "cuda"
         M, K, N = mkn
@@ -1457,9 +1465,10 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
         if not (M % 32 == 0 and K % 32 == 0 and N % 32 == 0):
             raise unittest.SkipTest("Matrix dimensions must be multiples of 32 on ROCm, skipping")
-        if recipe == "nvfp4" and K % 32 != 0:
-            return unittest.skip("K must be divisible by 32 for nvfp4 cublas gemm, skipping")
+        if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
+            raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
+        fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
         BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
         require_exact_match = True
         approx_match_sqnr_target = 22.0
@@ -1475,11 +1484,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
         elif test_case_name == "a_ones_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1490,11 +1499,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
         elif test_case_name == "a_ones_modified_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1506,11 +1515,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
         elif test_case_name == "a_ones_b_ones_modified":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1522,11 +1531,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
         elif test_case_name == "a_scale_modified_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1540,11 +1549,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 A_ref[1][0:BLOCK_SIZE] = 4
                 A[1][0:BLOCK_SIZE] = 2
                 A_scale[1][0] = 2
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
                 A_ref[1][0:BLOCK_SIZE] = 4
                 A.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
                 A_scale[1][0] = 2
@@ -1561,11 +1570,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B_ref[1][0:BLOCK_SIZE] = 4
                 B[1][0:BLOCK_SIZE] = 2
                 B_scale[1][0] = 2
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
                 B_ref[1][0:BLOCK_SIZE] = 4
                 B.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
                 B_scale[1][0] = 2
@@ -1585,7 +1594,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 # scales all-ones, element data random while being exactly representable in float4_e2m1fn_x2
                 # generate integers in [0, 16] and cast to bfloat16
                 A_ref = _floatx_unpacked_to_f32(
@@ -1600,8 +1609,8 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 ).bfloat16()
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
         elif test_case_name == "data_random_scales_from_data":
             if not K % BLOCK_SIZE == 0:
@@ -1613,17 +1622,18 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
             if recipe == "mxfp8":
                 # Calculate scales based on the inputs
-                A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
-                B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)
+                A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, recipe)
+                B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, recipe)
                 max_val = F8E4M3_MAX_VAL
                 min_val = -1 * max_val
                 A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
                 A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
                 B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
                 B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
-            else: # nvfp4
-                A_scale = data_to_nvfp4_scale(A_ref, BLOCK_SIZE)
-                B_scale = data_to_nvfp4_scale(B_ref, BLOCK_SIZE)
+            else: # nvfp4 # mxfp4
+                scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale
+                A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
+                B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
                 max_val = FP4_MAX_VAL
                 min_val = -1 * max_val
@@ -1634,13 +1644,14 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B.clamp(min=min_val, max=max_val)
                 B = _bfloat16_to_float4_e2m1fn_x2(B)
-            approx_match_sqnr_target = 15.8
+            approx_match_sqnr_target = 12.0 if torch.version.hip else 15.8
         C_ref = A_ref @ B_ref.t()
         # convert to swizzled format
-        A_scale = to_blocked(A_scale)
-        B_scale = to_blocked(B_scale)
+        if not torch.version.hip:
+            A_scale = to_blocked(A_scale)
+            B_scale = to_blocked(B_scale)
         C = torch._scaled_mm(
             A,
@@ -1657,6 +1668,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
             sqnr = compute_error(C_ref, C)
             assert sqnr.item() > approx_match_sqnr_target
+    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
     @parametrize("recipe", ["mxfp8", "nvfp4"])
     def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
@@ -1899,6 +1911,7 @@ def test_blockwise_mxfp8_compile(self) -> None:
         )
         torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
+    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
     def test_blockwise_nvfp4_compile(self) -> None:
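
Illustrative usage sketch (not part of the patch): the snippet below mirrors the all-ones cases in test_blockwise_mxfp8_nvfp4_mxfp4_numerics to show the mxfp4 path this diff enables on ROCm 7.0+, namely packed Float4_e2m1fn_x2 operands with one float8_e8m0fnu scale per 32-element block along K passed straight to torch._scaled_mm (on ROCm the scales skip the to_blocked swizzle; CUDA builds instead use the nvfp4 recipe with float8_e4m3fn scales and block size 16). The byte-level packing of 1.0 and the exact _scaled_mm call shape are inferred from the test code above, not guaranteed by this diff.

import torch

# Sketch only: assumes a ROCm 7.0+ build on hardware with MX GEMM support.
M, K, N = 64, 64, 64
BLOCK_SIZE = 32  # mxfp4 block size used by the test on ROCm

# All-ones operands, two FP4 values packed per byte: 1.0 in e2m1 is the
# nibble 0b0010, so the byte 0x22 encodes a pair of ones.
A = torch.full((M, K // 2), 0x22, dtype=torch.uint8, device="cuda").view(torch.float4_e2m1fn_x2)
B = torch.full((N, K // 2), 0x22, dtype=torch.uint8, device="cuda").view(torch.float4_e2m1fn_x2)

# One e8m0 block scale per 32 elements along K, all set to 1.0.
A_scale = torch.full((M, K // BLOCK_SIZE), 1.0, dtype=torch.float8_e8m0fnu, device="cuda")
B_scale = torch.full((N, K // BLOCK_SIZE), 1.0, dtype=torch.float8_e8m0fnu, device="cuda")

C = torch._scaled_mm(A, B.t(), A_scale, B_scale, out_dtype=torch.bfloat16)
print(C)  # expected: every entry equals K (64.0) for all-ones inputs and unit scales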