2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDADataType.h
@@ -90,7 +90,7 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
case c10::ScalarType::Float8_e5m2fnuz:
return HIP_R_8F_E5M2_FNUZ;
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
case c10::ScalarType::Float4_e2m1fn_x2:
return CUDA_R_4F_E2M1;
#endif
9 changes: 9 additions & 0 deletions aten/src/ATen/cuda/tunable/GemmHipblaslt.h
@@ -85,6 +85,15 @@ constexpr hipDataType HipDataTypeFor<c10::Float8_e8m0fnu>() {
return static_cast<hipDataType>(500);
}

template <>
constexpr hipDataType HipDataTypeFor<c10::Float4_e2m1fn_x2>() {
#if ROCM_VERSION >= 70000
return HIP_R_4F_E2M1;
#else
return static_cast<hipDataType>(33);
#endif
}

template <typename T>
int GetBatchFromParams(const GemmParams<T>* params) {
return 1;
11 changes: 11 additions & 0 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -1284,6 +1284,17 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
if (use_fast_accum) {
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
}
#ifdef USE_ROCM
if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2 || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
TORCH_CHECK(ROCM_VERSION >= 70000, "Float4_e2m1fn_x2 is only supported for ROCm 7.0 and above");
}
if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
TORCH_CHECK(ROCM_VERSION >= 70000, "Float8_e5m2 is only supported for ROCm 7.0 and above");
}
if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
TORCH_CHECK(ROCM_VERSION >= 70000, "Float8_e4m3fn is only supported for ROCm 7.0 and above");
}
#endif
if (bias) {
TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
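For context, the new checks above gate calls that reach `_scaled_mm_out_cuda` through `torch._scaled_mm`. Below is a minimal, hypothetical sketch of such a call (not taken from this PR; the shapes and tensor-wise scaling layout are illustrative, and the exact `torch._scaled_mm` signature may differ across PyTorch versions):

```python
import torch

# Illustrative only: an FP8 matmul of the kind the ROCm-version TORCH_CHECKs above guard.
# On ROCm < 7.0 the new checks are expected to raise; on ROCm >= 7.0 the call can proceed.
M, K, N = 128, 256, 64
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0, device="cuda")  # tensor-wise scales
scale_b = torch.tensor(1.0, device="cuda")
C = torch._scaled_mm(A, B.t(), scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
```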
87 changes: 50 additions & 37 deletions test/test_matmul_cuda.py
@@ -882,6 +882,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

# largest power of 2 representable in `torch.float8_e4m3fn`
F8E4M3_LARGEST_POW2 = 8
# largest power of 2 representable in `torch.float4_e2m1fn_x2`
FP4E2M1FN_LARGEST_POW2 = 1.0
# max value of `torch.float8_e4m3fn` (448)
F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
# exponent bias of `torch.float8_e8m0fnu`
@@ -890,14 +892,20 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
FP4_EBITS, FP4_MBITS = 2, 1
FP4_MAX_VAL = 6.0

def data_to_mx_scale(x, block_size):
def data_to_mx_scale(x, block_size, recipe):
# simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
# section 6.3, not all edge cases (such as NaN) are handled/tested
if recipe == "mxfp8":
largest_pow2 = F8E4M3_LARGEST_POW2
elif recipe == "mxfp4":
largest_pow2 = FP4E2M1FN_LARGEST_POW2
else:
raise ValueError(f"data_to_mx_scale(): Unsupported mx recipe: {recipe}")
orig_shape = x.shape
x = x.reshape(-1, block_size)
max_abs = torch.amax(torch.abs(x), 1)
largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
scale_e8m0_unbiased = largest_p2_lt_max_abs - largest_pow2
scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
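To illustrate the updated helper, here is a small standalone sketch (not part of the PR) that reproduces the scale arithmetic above for a single block; the e8m0 exponent bias of 127 is an assumption, since its definition is elided from this diff:

```python
import torch

F8E4M3_LARGEST_POW2 = 8        # mxfp8 target exponent
FP4E2M1FN_LARGEST_POW2 = 1.0   # mxfp4 target exponent
F8E8M0_EXP_BIAS = 127          # assumed e8m0 exponent bias

# One block whose max |x| is 4.0, i.e. floor(log2(max_abs)) == 2.
max_abs = torch.tensor(4.0)
for recipe, largest_pow2 in [("mxfp8", F8E4M3_LARGEST_POW2), ("mxfp4", FP4E2M1FN_LARGEST_POW2)]:
    unbiased = torch.floor(torch.log2(max_abs)) - largest_pow2
    biased = torch.clamp(unbiased, -F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS) + F8E8M0_EXP_BIAS
    print(recipe, int(biased))  # mxfp8 -> 121, mxfp4 -> 128
```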
@@ -1446,20 +1454,21 @@ def test_pack_uint4(self):
(127, 96, 1024),
(1025, 128, 96)
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if recipe == "nvfp4" and fast_accum:
return unittest.skip("fast_accum not supported in nvfp4 cublas gemm, skipping")
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")

device = "cuda"
M, K, N = mkn
if torch.version.hip:
if not (M % 32 == 0 and K % 32 == 0 and N % 32 == 0):
raise unittest.SkipTest("Matrix dimensions must be multiples of 32 on ROCm, skipping")

if recipe == "nvfp4" and K % 32 != 0:
return unittest.skip("K must be divisible by 32 for nvfp4 cublas gemm, skipping")
if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")

fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
require_exact_match = True
approx_match_sqnr_target = 22.0
@@ -1475,11 +1484,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)

elif test_case_name == "a_ones_b_ones":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1490,11 +1499,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)

elif test_case_name == "a_ones_modified_b_ones":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1506,11 +1515,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)

elif test_case_name == "a_ones_b_ones_modified":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1522,11 +1531,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)

elif test_case_name == "a_scale_modified_b_ones":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1540,11 +1549,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
A_ref[1][0:BLOCK_SIZE] = 4
A[1][0:BLOCK_SIZE] = 2
A_scale[1][0] = 2
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
A_ref[1][0:BLOCK_SIZE] = 4
A.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
A_scale[1][0] = 2
@@ -1561,11 +1570,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
B_ref[1][0:BLOCK_SIZE] = 4
B[1][0:BLOCK_SIZE] = 2
B_scale[1][0] = 2
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_ref[1][0:BLOCK_SIZE] = 4
B.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
B_scale[1][0] = 2
@@ -1585,7 +1594,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
# scales all-ones, element data random while being exactly representable in float4_e2m1fn_x2
# generate integers in [0, 16] and cast to bfloat16
A_ref = _floatx_unpacked_to_f32(
@@ -1600,8 +1609,8 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
).bfloat16()
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)

elif test_case_name == "data_random_scales_from_data":
if not K % BLOCK_SIZE == 0:
@@ -1613,17 +1622,18 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r

if recipe == "mxfp8":
# Calculate scales based on the inputs
A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)
A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, recipe)
B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, recipe)
max_val = F8E4M3_MAX_VAL
min_val = -1 * max_val
A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
else: # nvfp4
A_scale = data_to_nvfp4_scale(A_ref, BLOCK_SIZE)
B_scale = data_to_nvfp4_scale(B_ref, BLOCK_SIZE)
else: # nvfp4 # mxfp4
scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale
A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
Comment on lines +1635 to +1636

Copilot AI, Aug 7, 2025:

[nitpick] The conditional expression `recipe if recipe == "mxfp4" else None` is confusing. Consider extracting this logic into a clearer variable assignment or using a more explicit approach to pass the correct arguments to each scaling function.

Suggested change — replace
    A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
    B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
with
    scale_recipe_arg = recipe if recipe == "mxfp4" else None
    A_scale = scale_func(A_ref, BLOCK_SIZE, scale_recipe_arg)
    B_scale = scale_func(B_ref, BLOCK_SIZE, scale_recipe_arg)
max_val = FP4_MAX_VAL
min_val = -1 * max_val

@@ -1634,13 +1644,14 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
B = B.clamp(min=min_val, max=max_val)
B = _bfloat16_to_float4_e2m1fn_x2(B)

approx_match_sqnr_target = 15.8
approx_match_sqnr_target = 12.0 if torch.version.hip else 15.8

C_ref = A_ref @ B_ref.t()

# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
if not torch.version.hip:
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)

C = torch._scaled_mm(
A,
Expand All @@ -1657,6 +1668,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
sqnr = compute_error(C_ref, C)
assert sqnr.item() > approx_match_sqnr_target

@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
@parametrize("recipe", ["mxfp8", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
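The SQNR threshold above (`approx_match_sqnr_target`, relaxed to 12.0 dB on ROCm) is computed by the test file's `compute_error` helper, whose body is not shown in this diff. A sketch of the usual definition it is assumed to follow:

```python
import torch

def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: 20 * log10(||x|| / ||x - y||).
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)
```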
@@ -1899,6 +1911,7 @@ def test_blockwise_mxfp8_compile(self) -> None:
)
torch.testing.assert_close(C, C_ref, atol=0, rtol=0)

@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
def test_blockwise_nvfp4_compile(self) -> None:
