diff --git a/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py b/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py index 0f2dd5fadd..4bc99433d2 100644 --- a/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py +++ b/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py @@ -2490,6 +2490,9 @@ def create_tensors_abc_for_all_groups( strides_abc = [] ptrs_abc = [] + # FP4 packing: 2 elements per byte, so the K storage dimension is halved + k_fct = 2 if ab_dtype == cutlass.Float4E2M1FN else 1 + # Iterate through all groups and create tensors for each group for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): # Create tensors A, B, C @@ -2499,7 +2502,7 @@ def create_tensors_abc_for_all_groups( cute_tensor_a, ref_torch_fp32_tensor_a, stride_mk_a, - ) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype) + ) = create_tensor_and_stride(l, m, k // k_fct, a_major == "m", ab_dtype) ( ptr_b, @@ -2507,7 +2510,7 @@ def create_tensors_abc_for_all_groups( cute_tensor_b, ref_torch_fp32_tensor_b, stride_nk_b, - ) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype) + ) = create_tensor_and_stride(l, n, k // k_fct, b_major == "n", ab_dtype) ( ptr_c,