diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 25a463aeaa..9ff0c11757 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -40,10 +40,11 @@ ScalingMode, QuantizerFactory, QuantizeLayout, + noop_quantizer_set, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation -from transformer_engine.jax.dense import dense +from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense GEMM_CASES = [ @@ -1204,24 +1205,6 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) -# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm() -def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer): - ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - lhs_q = lhs_quantizer.quantize( - lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, - ) - rhs_q = rhs_quantizer.quantize( - rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, - ) - return lhs_q, rhs_q - - # E5M2 * E5M2 is not supported fwd_bwd_dtypes = [ [jnp.float8_e4m3fn, jnp.float8_e4m3fn], @@ -1229,219 +1212,194 @@ def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer [jnp.float8_e5m2, jnp.float8_e4m3fn], ] -""" -@pytest_parametrize_wrapper( - "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]] -) +GROUPED_DENSE_INPUT_SHAPES = [ + # (n_groups, m, n, k), the actual m will be multiplied by 32 + (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4 + (8, 64, 32, 128), + (8, 64, 128, 256), +] + + +@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES) class TestGroupedDense: - def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list): - ref_out_list = [] - for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): - dim_nums = (contracting_dims, ((), ())) - ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums)) - return ref_out_list - - def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list): + def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): + lhs_contract_dim, _ = contracting_dims + assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 + if bias is None: + bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype) + else: + assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2]) + remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop() + lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis) + rhs = jnp.split(rhs, rhs.shape[0], axis=0) + bias = jnp.split(bias, bias.shape[0], axis=0) + ref_out = [] + dim_num = (contracting_dims, ((), ())) + for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): + out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0) + ref_out.append(jnp.squeeze(out_i)) + return ref_out + + def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, len(shape_list) * 2) - - lhs_list, rhs_list, contracting_dims_list = [], [], [] - for i, ((m, n, k), data_layout) in 
enumerate(zip(shape_list, layout_list)): - lhs = jax.random.uniform( - subkeys[2 * i], - (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), - dtype=dtype, - ) - rhs = jax.random.uniform( - subkeys[2 * i + 1], - (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), - dtype=dtype, - ) - lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) - rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) - contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + subkeys = jax.random.split(key, 4) + n_groups, m, n, k = input_shape + + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) + assert group_sizes.sum() == m + + # *32 to make sure that input shape works for MXFP8 + group_sizes = group_sizes * 32 + m = m * 32 - lhs_list.append(lhs) - rhs_list.append(rhs) - contracting_dims_list.append(contracting_dims) + lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) + rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) + bias_shape = (n_groups, n) - return lhs_list, rhs_list, contracting_dims_list + lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype) + rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype) + bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None + + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,) + contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + + return lhs, rhs, group_sizes, contracting_dims, bias + + def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): + assert out.dtype == ref_list[0].dtype + out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + for i in range(len(ref_list)): + assert_allclose(out_list[i], ref_list[i], dtype=dtype) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) - @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) - def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list): - lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( - dtype, shape_list, layout_list + @pytest_parametrize_wrapper("layout", ["NN"]) + def test_grouped_gemm_fp16(self, dtype, input_shape, layout): + lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( + dtype, input_shape, layout ) - ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) - primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list) - for i in range(len(shape_list)): - assert_allclose(primitive_out[i], ref_out[i], dtype=dtype) + ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) - def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list): + @pytest_parametrize_wrapper("layout", ["NN"]) + def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): + if 
scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            pytest.skip("MXFP8 is not supported in grouped_gemm yet")
+
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
+            scaling_mode=scaling_mode,
+            fwd_dtype=fwd_dtype,
+            bwd_dtype=bwd_dtype,
+            is_2x2x=False,
+            n_groups=input_shape[0],
         )
+        # quantizer_set.{x, kernel} use fwd_dtype, while quantizer_set.grad uses bwd_dtype
+        # We want to test E4M3 * E5M2, so manually set quantizer_set.kernel.q_dtype to bwd_dtype
+        quantizer_set.kernel.q_dtype = bwd_dtype
+        for quantizer in quantizer_set.kernel.quantizers:
+            quantizer.q_dtype = bwd_dtype
+
         out_dtype = jnp.bfloat16
-        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            out_dtype, shape_list, layout_list
+        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
+            out_dtype, input_shape, layout
+        )
+        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
+        prim_out = tex.grouped_gemm(
+            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
         )
-        q_lhs_list = []
-        q_rhs_list = []
-        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
-            # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to
-            # test the case where lhs and rhs have different q_dtypes
-            q_lhs, q_rhs = _quantize_gemm_pair(
-                lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
-            )
-            q_lhs_list.append(q_lhs)
-            q_rhs_list.append(q_rhs)
-
-        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
-        primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)
 
         allclose_dtype = jnp.float8_e4m3fn
-        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
+        if jnp.float8_e5m2 in fwd_bwd_dtype:
             allclose_dtype = jnp.float8_e5m2
-        for i in range(len(shape_list)):
-            assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
 
-    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
-    def test_grouped_dense_grad_fp16(self, dtype, shape_list):
-        group_size = len(shape_list)
-        layout_list = ["NN" for _ in range(group_size)]
+        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)
 
-        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            dtype, shape_list, layout_list
+    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
+        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
+        # Note: we use jnp.sum instead of jnp.mean to make the gradients larger
+        # and prevent them from being clamped to zero
+        out_sum_list = [jnp.sum(out) for out in out_list]
+        return jnp.sum(jnp.asarray(out_sum_list))
+
+    def _primitive_sum_grouped_dense(
+        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
+    ):
+        out = grouped_dense(
+            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
         )
-        bias_list = []
-        key = jax.random.PRNGKey(1)
-        for shape in shape_list:
-            n = shape[1]
-            bias = jax.random.uniform(key, n, dtype=dtype)
-            bias_list.append(bias)
-
-        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
-            out_list = []
-            for i in range(len(x_list)):
-                out_list.append(
-                    dense(
-                        x_list[i],
-                        kernel_list[i],
-                        bias_list[i],
-                        contracting_dims=contracting_dims_list[i],
-                    )
-                )
-            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
-            # and 
prevent them from being clamp to zero - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) + return jnp.sum(jnp.asarray(out)) - def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list): - out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list) - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) + @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) + def test_grouped_dense_grad_fp16(self, dtype, input_shape): + x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( + dtype, + input_shape, + with_bias=True, + ) - value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) + value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) - ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( - x_list, kernel_list, bias_list, contracting_dims_list + ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( + x, kernel, bias, group_sizes, contracting_dims ) - primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( - value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list) + prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( + x, kernel, bias, group_sizes, contracting_dims ) - assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype) - for i in range(group_size): - assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype) - assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype) - assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype) + assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype) + assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype) + assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype) + assert_allclose(prim_dbias, ref_dbias, dtype=dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) + @pytest.mark.parametrize( + "fwd_bwd_dtype", + [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)], + ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list): - group_size = len(shape_list) - layout_list = ["NN" for _ in range(group_size)] - fwd_dtype, bwd_dtype = fwd_bwd_dtype - if fwd_dtype == jnp.float8_e5m2: - pytest.skip("We never use E5M2 for fwd_dtype in training") - - # Question: should we use different quantizers for different groups? 
- ref_quantizer_set_list = [] - quantizer_set_list = [] - for _ in range(group_size): - ref_quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True - ) - ref_quantizer_set_list.append(ref_quantizer_set) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True - ) - quantizer_set_list.append(quantizer_set) + def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + pytest.skip("MXFP8 is not supported in grouped_dense yet") - out_dtype = jnp.bfloat16 - x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( - out_dtype, shape_list, layout_list + fwd_dtype, bwd_dtype = fwd_bwd_dtype + dtype = jnp.bfloat16 + x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( + dtype, + input_shape, + with_bias=True, ) - bias_list = [] - key = jax.random.PRNGKey(1) - for shape in shape_list: - n = shape[1] - bias = jax.random.uniform(key, n, dtype=out_dtype) - bias_list.append(bias) - - def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): - out_list = [] - for i in range(len(x_list)): - out_list.append( - dense( - x_list[i], - kernel_list[i], - bias_list[i], - contracting_dims=contracting_dims_list[i], - quantizer_set=quantizer_set_list[i], - ) - ) - # Note: we use jnp.sum instead of jnp.mean to make the gradient larger - # and prevent them from being clamp to zero - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) - - def primitive_func( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ): - out_list = grouped_dense( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ) - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) - - value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) - ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( - x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list + quantizer_set = QuantizerFactory.create_set( + scaling_mode=scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=True, + n_groups=group_sizes.size, ) - primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( - value_n_grad_primitive_func( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ) + value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) + + ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( + x, + kernel, + bias, + group_sizes, + contracting_dims, + ) + prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( + x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set ) - allclose_dtype = jnp.float8_e4m3fn - if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2: - allclose_dtype = jnp.float8_e5m2 - assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype) - for i in range(group_size): - assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype) - assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype) - 
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
-"""
+        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
+        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
+        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
+        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index 4080ae1668..fa8785dcc7 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -525,6 +525,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
   const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
   const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
+  const auto workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
@@ -533,6 +534,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
+  NVTE_CHECK(workspace_alignment % 256 == 0,
+             "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment);
 
   const auto status =
       cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py
index c38a04f85a..cc02ec3404 100644
--- a/transformer_engine/jax/cpp_extensions/gemm.py
+++ b/transformer_engine/jax/cpp_extensions/gemm.py
@@ -6,22 +6,28 @@
 from typing import Tuple, Sequence, Union, Dict
 from functools import partial, reduce
 import operator
+import math
 
 import jax
 import jax.numpy as jnp
 
 from transformer_engine_jax import get_device_compute_capability
 
 from .base import BasePrimitive, register_primitive
+from .quantization import grouped_quantize
 from ..quantize import (
     ScaledTensor,
+    GroupedScaledTensor1x,
     ScalingMode,
     Quantizer,
+    GroupedQuantizer,
     QuantizeConfig,
+    QuantizerSet,
+    QuantizeLayout,
     noop_quantizer_set,
 )
 
-__all__ = ["gemm"]
+__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"]
 
 
 num_cublas_streams = 4
@@ -34,6 +40,11 @@ def get_cublas_workspace_size_bytes() -> None:
     return 4_194_304
 
 
+def is_gemm_with_all_layouts_supported() -> bool:
+    """Return True on Blackwell (compute capability >= 10.0), False otherwise."""
+    return get_device_compute_capability(0) >= 100
+
+
 class GroupedGemmPrimitive(BasePrimitive):
     """
     Primitive for grouped GEMM
     """
 
     name = "te_grouped_gemm_ffi"
     multiple_results = True
-    impl_static_args = ()
+    impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15)
     inner_primitive = None
     outer_primitive = None
 
     @staticmethod
-    def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
+    def abstract(
+        lhs_data_aval,
+        lhs_scale_inv_aval,
+        rhs_data_aval,
+        rhs_scale_inv_aval,
+        bias_aval,
+        group_sizes_aval,
+        group_offset_aval,
+        *,
+        M,
+        N,
+        K,
+        lhs_is_trans,
+        rhs_is_trans,
+        scaling_mode,
+        out_dtype,
+        has_bias,
+        is_grouped_dense_wgrad,
+    ):
         """
+        Grouped GEMM operation.
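+
+        For example, with group_sizes = [2, 3] and K = 4, lhs_data holds a
+        (5, 4) matrix flattened to 1D; rows 0-1 belong to group 0, rows 2-4
+        to group 1, and each group is multiplied by its own rhs matrix.
+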
+        Args:
-            *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
-                args[ 0 : num_gemms] are the lhs tensors,
-                args[ num_gemms : 2*num_gemms] are the rhs tensors,
-                args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
-                args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
-                args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
-            num_gemms: Number of GEMM operations to perform.
-            scaling_mode: Scaling mode for the GEMM operations.
-            out_dtype: Data type of the output tensors.
-            has_bias: Boolean indicating if bias tensors are provided.
+            lhs_data: Left-hand side input matrix data, 1D flattened array
+            lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array
+            rhs_data: Right-hand side input matrix data, 1D flattened array
+            rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array
+            bias: Bias matrix of shape (G, N)
+            group_sizes: 1D array containing the sizes of each group
+            group_offset: 1D array containing offsets for each group (not yet implemented)
+            M: Number of rows in the output matrix
+            N: Number of columns in the output matrix
+            K: Size of the contracting dimension
+            lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed
+            rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed
+            scaling_mode: Scaling mode for the GEMM operations
+            out_dtype: Data type of the output tensors
+            has_bias: Boolean indicating if bias tensors are provided
+            is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation
+                where both lhs and rhs are 2D matrices and output is (G, M, N)
 
         Returns:
-            A tuple of ShapedArray objects of size num_gemms+1:
-                ret[0 : num_gemms]: GEMM output tensors,
-                ret[num_gemms]:workspace tensor.
+            A tuple of ShapedArray objects: the GEMM output of shape (M, N), or
+            (G, M, N) when is_grouped_dense_wgrad, and the workspace buffer.
         """
-        del scaling_mode
-        expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
-        assert (
-            len(args) == expected_num_args
-        ), f"Expected {expected_num_args} input arguments, but got {len(args)}"
-        A_list = args[0:num_gemms]
-        B_list = args[num_gemms : 2 * num_gemms]
-        # A and B have shapes [1, m, k] and [1, n, k]
-        out_list_aval = tuple(
-            jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
-            for A, B in zip(A_list, B_list)
-        )
+        del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
+        del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
+        # TODO(Phuong): move some shape checks from Cpp to here
         workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
+        workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size
         workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
-        return (*out_list_aval, workspace_aval)
+        out_shape = (M, N)
+        if is_grouped_dense_wgrad:
+            out_shape = (group_sizes_aval.size, M, N)
+        out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
+        return (out_aval, workspace_aval)
 
     @staticmethod
     def outer_abstract(*args, **kwargs):
         (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
-        return out_aval
+        return (out_aval,)
 
     @staticmethod
-    def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
+    def lowering(
+        ctx,
+        *args,
+        M,
+        N,
+        K,
+        lhs_is_trans,
+        rhs_is_trans,
+        scaling_mode,
+        out_dtype,
+        has_bias,
+        is_grouped_dense_wgrad,
+    ):
         del out_dtype
         return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
             ctx,
             *args,
-            num_gemms=num_gemms,
-            scaling_mode=int(scaling_mode),
+            M=M,
+            N=N,
+            K=K,
+            lhs_is_trans=lhs_is_trans,
+            rhs_is_trans=rhs_is_trans,
+            scaling_mode=scaling_mode.value,
             has_bias=has_bias,
+            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
         )
 
     @staticmethod
-    def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
+    def impl(
+        lhs_data,
+        lhs_scale_inv,
+        rhs_data,
+        rhs_scale_inv,
+        bias,
+        group_sizes,
+        group_offset,
+        M,
+        N,
+        K,
+        lhs_is_trans,
+        rhs_is_trans,
+        scaling_mode,
+        out_dtype,
+        has_bias,
+        is_grouped_dense_wgrad,
+    ):
         assert GroupedGemmPrimitive.inner_primitive is not None
-        out = GroupedGemmPrimitive.inner_primitive.bind(
-            *args,
-            num_gemms=num_gemms,
-            scaling_mode=scaling_mode.value,
+        (out, _) = GroupedGemmPrimitive.inner_primitive.bind(
+            lhs_data,
+            lhs_scale_inv,
+            rhs_data,
+            rhs_scale_inv,
+            bias,
+            group_sizes,
+            group_offset,
+            M=M,
+            N=N,
+            K=K,
+            lhs_is_trans=lhs_is_trans,
+            rhs_is_trans=rhs_is_trans,
+            scaling_mode=scaling_mode,
             out_dtype=out_dtype,
             has_bias=has_bias,
+            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
         )
-        return out[:-1]  # out is [out_list, wkspace], only return out_list
+        return (out,)
 
 
 register_primitive(GroupedGemmPrimitive)
@@ -285,7 +362,7 @@ def gemm(
     lhs: Union[jnp.ndarray, ScaledTensor],
     rhs: Union[jnp.ndarray, ScaledTensor],
     contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
-    quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
+    quantizer_set: QuantizerSet = noop_quantizer_set,
 ) -> jnp.ndarray:
     """General matrix multiplication with optional quantization.
@@ -310,130 +387,190 @@ def gemm( return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) -""" -def swizzled_scale(scales): - # Swizzle the scale tensor for FP8 GEMM - assert scales.ndim == 2 - rows, cols = scales.shape - scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) - scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) - scales = scales.reshape(rows, cols) - return scales +def grouped_gemm( + lhs: Union[jnp.ndarray, GroupedScaledTensor1x], + rhs: Union[jnp.ndarray, GroupedScaledTensor1x], + group_sizes: jnp.ndarray, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), + bias: jnp.ndarray = None, + precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, + preferred_element_type: jnp.dtype = None, + group_offset: jnp.array = None, + quantizer_set: QuantizerSet = noop_quantizer_set, +) -> jnp.ndarray: + """ + Grouped GEMM operation. + + Args: + lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x + rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x + group_sizes: 1D array containing the sizes of each group + contracting_dims: Tuple of two sequences representing the contracting dimensions + bias: Bias tensor of shape (G, N) + precision: JAX precision for the GEMM operation + preferred_element_type: Preferred data type for the output tensor + group_offset: 1D array containing offsets for each group (not yet implemented) + quantizer_set: Set of quantizers for FP8 quantization of the input and output + Returns: + A jnp.ndarray containing the result of the grouped GEMM operation -def grouped_gemm( - lhs_list: List[Union[jnp.ndarray, ScaledTensor]], - rhs_list: List[Union[jnp.ndarray, ScaledTensor]], - contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], - bias_list: List[jnp.ndarray] = None, -) -> List[jnp.ndarray]: - # Grouped GEMM for multiple pairs of tensors. 
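+
+    Example (default contracting_dims): with lhs of shape (M, K) and rhs of shape
+    (G, N, K), the rows of lhs are split into G consecutive row blocks by
+    group_sizes, and row block g is contracted with rhs[g] over K, i.e.
+    out[rows_g] = lhs[rows_g] @ rhs[g].T for each group g.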
-    assert (
-        len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
-    ), "lhs_list, rhs_list, contracting_dims_list must have the same length"
-
-    num_gemms = len(lhs_list)
-    lhs_list_ = []
-    rhs_list_ = []
-    lhs_sinv_list_ = []
-    rhs_sinv_list_ = []
-    bias_list_ = []
-    for i in range(num_gemms):
-        lhs = lhs_list[i]
-        rhs = rhs_list[i]
-        contracting_dims = contracting_dims_list[i]
-        dim_nums = (contracting_dims, ((), ()))
-        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
-            scaling_mode = lhs.scaling_mode
-            lhs_shape = lhs.data.shape
-            rhs_shape = rhs.data.shape
-            out_dtype = lhs.dq_dtype
-            # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
-            if lhs.scaling_mode.is_tensor_scaling():
-                assert not (
-                    lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
-                ), "FP8 GEMM does not support E5M2 * E5M2"
-                ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
-                if lhs.data_layout == "T":
-                    lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
-                if rhs.data_layout == "T":
-                    rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
-                dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
+    Note:
+        Tested shapes:
+        lhs: [M, K] or [K, M]
+        rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K]
+    """
+    # TODO(Phuong): implement the group_offset
+    if group_offset is None:
+        group_offset = jnp.zeros((1,), jnp.int32)
+
+    # TODO(Phuong): implement the precision
+    del precision
+
+    if isinstance(lhs, jnp.ndarray):
+        assert isinstance(rhs, jnp.ndarray)
+        out_dtype = lhs.dtype
+        lhs_shape = lhs.shape
+        rhs_shape = rhs.shape
+        lhs_data = lhs
+        rhs_data = rhs
+        lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32)
+        scaling_mode = ScalingMode.NO_SCALING
+    elif isinstance(lhs, GroupedScaledTensor1x):
+        assert isinstance(rhs, GroupedScaledTensor1x)
+        out_dtype = lhs.dq_dtype
+        lhs_shape = lhs.original_shape
+        rhs_shape = rhs.original_shape
+        lhs_data = lhs.data
+        rhs_data = rhs.data
+        lhs_scale_inv = lhs.scale_inv
+        rhs_scale_inv = rhs.scale_inv
+        assert lhs.scaling_mode == rhs.scaling_mode
+        scaling_mode = lhs.scaling_mode
+    else:
+        raise TypeError("Unsupported lhs type object!")
+
+    out_dtype = preferred_element_type or out_dtype
+
+    lhs_contract_dim, rhs_contract_dim = contracting_dims
+
+    lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1
+    lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1)
+
+    # rhs_shape [G, K, N]
+    rhs_is_trans = rhs_contract_dim[0] != 1
+    rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim)
+
+    is_grouped_dense_wgrad = False
+    if len(rhs_shape) == 2:
+        rhs_is_trans = rhs_contract_dim[0] != 0
+        is_grouped_dense_wgrad = True
+
+    # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this?
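+    # Example for this branch: in the wgrad of a grouped dense layer, lhs is the
+    # forward input x of shape (num_tokens, in_features) and rhs is grad of shape
+    # (num_tokens, out_features); group_sizes partitions the contracted num_tokens
+    # dimension, so group g computes x[rows_g].T @ grad[rows_g] and the per-group
+    # results stack to an output of shape (G, in_features, out_features).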
+    if (
+        is_grouped_dense_wgrad
+        and not isinstance(lhs, ScaledTensor)
+        and not isinstance(rhs, ScaledTensor)
+    ):
+        lhs_is_trans = True
+        rhs_is_trans = False
+        lhs_flatten_axis = 1
+        rhs_flatten_axis = 1
+
+    if (
+        not isinstance(lhs, ScaledTensor)
+        and not isinstance(rhs, ScaledTensor)
+        and quantizer_set != noop_quantizer_set
+    ):
+        assert isinstance(quantizer_set.x, GroupedQuantizer)
+        assert type(quantizer_set.x) is type(quantizer_set.kernel)
+        scaling_mode = quantizer_set.x.scaling_mode
+        if (
+            # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
+            # scaling_mode.is_tensor_scaling()
+            # and is_gemm_with_all_layouts_supported()
+            scaling_mode.is_1d_block_scaling()
+        ):
+            lhs_is_rowwise = rhs_is_rowwise = True
         else:
-            # For jnp.ndarray, only consider contracting_dims, data_layout is always NN
-            scaling_mode = ScalingMode.NO_SCALING
-            lhs_shape = lhs.shape
-            rhs_shape = rhs.shape
-            out_dtype = lhs.dtype
-
-        (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
-        lhs_dn = (lhs_contract, lhs_batch)
-        rhs_dn = (rhs_contract, rhs_batch)
-
-        lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
-        rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
-
-        # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
-        if scaling_mode == ScalingMode.NO_SCALING:
-            lhs_3d = _shape_normalization(lhs, lhs_dn)
-            rhs_3d = _shape_normalization(rhs, rhs_dn)
-        elif scaling_mode.is_tensor_scaling():
-            lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
-            rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
-        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            lhs_3d = _shape_normalization(lhs.data, lhs_dn)
-            rhs_3d = _shape_normalization(rhs.data, rhs_dn)
-            lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
-            rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
-            # swizzled_scale requires a matrix
-            lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
-            rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
+            lhs_is_rowwise = not lhs_is_trans
+            rhs_is_rowwise = lhs_is_trans
+        quantizer_set.x.q_layout = (
+            QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
+        )
+        quantizer_set.kernel.q_layout = (
+            QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE
+        )
+        lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis)
+        rhs_q = grouped_quantize(
+            rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis
+        )
+        lhs_data = lhs_q.data
+        rhs_data = rhs_q.data
+        lhs_scale_inv = lhs_q.scale_inv
+        rhs_scale_inv = rhs_q.scale_inv
+
+    assert not (
+        lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
+    ), "FP8 GEMM does not support E5M2 * E5M2"
+
+    # FP8 GEMM only supports the NT layout on Hopper and earlier GPUs,
+    # thus an additional transpose is required
+    # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
+    if scaling_mode.is_tensor_scaling():  # and not is_gemm_with_all_layouts_supported():
+        lhs_is_trans = False
+        rhs_is_trans = True
+        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
+            lhs_layout_is_T = lhs.data_layout == "T"
+            rhs_layout_is_T = rhs.data_layout == "T"
         else:
-            raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}")
-
-        # Note: already_transposed doesn't matter for the output shape
-        # x.shape = [B, D1, D2]
-        #   contracting_dims = (2, )    --> output.shape = [1, B * D1, D2]
-        #   contracting_dims = 
(0, 1, ) --> output.shape = [1, D2, B * D1] - # x.shape = [D1, D2] - # contracting_dims = (1, ) --> output.shape = [1, D1, D2] - # contracting_dims = (0, ) --> output.shape = [1, D2, D1] - bm = lhs_remain_shape[0] - bn = rhs_remain_shape[0] - kl = lhs_3d.shape[-1] - kr = rhs_3d.shape[-1] - assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" - if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): - print("grouped_gemm input pair {i} has invalid problem shape for lowering: ") - print(f"m = {bm}, n = {bn}, k = {kl}; ") - print("cuBLAS requires the problem shapes being multiples of 16") - assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0) - - lhs_list_.append(lhs_3d) - rhs_list_.append(rhs_3d) - if scaling_mode == ScalingMode.NO_SCALING: - lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode.is_tensor_scaling(): - lhs_sinv_list_.append(lhs.scale_inv) - rhs_sinv_list_.append(rhs.scale_inv) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_sinv_list_.append(lhs_scale_inv) - rhs_sinv_list_.append(rhs_scale_inv) - if bias_list is not None: - bias_list_.append(bias_list[i]) - - out_list = GroupedGemmPrimitive.outer_primitive.bind( - *lhs_list_, - *rhs_list_, - *lhs_sinv_list_, - *rhs_sinv_list_, - *bias_list_, - num_gemms=num_gemms, - scaling_mode=scaling_mode, + lhs_layout_is_T = lhs_q.data_layout == "T" + rhs_layout_is_T = rhs_q.data_layout == "T" + lhs_ndim = len(lhs_shape) + rhs_ndim = len(rhs_shape) + if lhs_layout_is_T: + lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) + if rhs_layout_is_T: + rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) + lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T) + rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T) + + # Calling GroupedGEMM Custom Call + K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) + K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) + assert K_lhs == K_rhs + M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) + N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G + + if is_grouped_dense_wgrad: + N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) + else: + assert group_sizes.size == rhs_shape[0] + + assert group_offset.size == 1 + + has_bias = bias is not None + assert not has_bias or bias.shape == (group_sizes.size, N) + bias = jnp.empty((), jnp.float32) if bias is None else bias + + # TODO(Phuong): support MXFP8_1D_SCALING + assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported" + + (out,) = GroupedGemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K_lhs, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, - has_bias=1 if bias_list is not None else 0, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) - - return out_list -""" + return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 7ed0db0298..07d8f81df0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -47,7 +47,7 @@ from jax.extend import ffi # pylint: 
disable=ungrouped-imports
 
-__all__ = ["quantize", "quantize_dbias", "grouped_quantize"]
+__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
 
 
 class BaseDBiasQuantizePrimitive(BasePrimitive):
@@ -1032,3 +1032,24 @@ def grouped_quantize(
         group_axis=group_axis,
     )
     return out
+
+
+def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
+    """
+    Compute the grouped bias gradient.
+
+    Args:
+        grad: jnp.ndarray of shape (M, N)
+        group_sizes: jnp.ndarray of shape (num_groups,), sum(group_sizes) == M
+
+    Returns:
+        dbias: jnp.ndarray of shape (num_groups, N)
+    """
+    assert grad.ndim == 2, "Input grad must be a 2D tensor."
+    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."
+
+    segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes)
+    grad_fp32 = grad.astype(jnp.float32)
+    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
+    dbias = dbias_fp32.astype(grad.dtype)
+    return dbias
diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp
index 0825bd2f73..d9d519fa00 100644
--- a/transformer_engine/jax/csrc/extensions/gemm.cpp
+++ b/transformer_engine/jax/csrc/extensions/gemm.cpp
@@ -13,43 +13,127 @@
 #include "transformer_engine/multi_stream.h"
 #include "xla/ffi/api/c_api.h"
 
+#define MXFP8_BLOCK_SIZE 32
+
 namespace transformer_engine {
 namespace jax {
 
-Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
-                          Variadic_Result_Type output_list, int64_t num_gemms,
-                          JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
+Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
+                          Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
+                          Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
+                          Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
+                          bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
+                          bool is_grouped_dense_wgrad) {
   // Notes on matrix layouts and transpose:
   // Jax uses row-major data_layout, on entering this function, each input matrix pair:
-  //   A: row-major with size [m, k],
-  //   B: row-major with size [n, k], needs transpose,
+  //   A: row-major [m, k] for N - [k, m] for T
+  //   B: row-major [k, n] for N - [n, k] for T
   // on exiting this function, JAX expect:
   //   C: row-major with size [m, n].
   // cuBLAS uses column-major data_layout, in this view, each input matrix pair:
-  //   A: column-major with size [k, m], needs transpose,
-  //   B: column-major with size [k, n].
+  //   A: column-major with size [k, m] for T - [m, k] for N
+  //   B: column-major with size [n, k] for T - [k, n] for N
+  //
   // If we call cuBLAS GEMM for A * B, the output will be:
   //   C: column-major with size [m, n] --> row-major with size [n, m].
   // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
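+  // Equivalently: a row-major C = A * B is, in the column-major view, C^T = B^T * A^T.
+  // cuBLAS reads the row-major A and B buffers as A^T and B^T, so calling GEMM(B, A)
+  // produces C^T in column-major order, which is exactly C in row-major order.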
-  if (num_gemms <= 0) {
-    return ffi_with_cuda_error_check();
+  int num_streams = nvte_get_num_compute_streams();
+
+  // Inputs
+  auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
+  auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
+  auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
+  auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
+  auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
+  auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
+  auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
+  auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
+  auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
+  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
+
+  NVTE_CHECK(group_sizes.dimensions().size() == 1);
+  size_t num_gemms = group_sizes.dimensions()[0];
+
+  // Outputs
+  auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
+  auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
+  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
+  auto workspace_total_size = product(workspace->dimensions());
+
+  auto lhs_sinv_size = product(lhs_sinv.dimensions());
+  auto rhs_sinv_size = product(rhs_sinv.dimensions());
+  auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams;
+  auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
+  auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
+
+  size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
+  size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
+  size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
+  size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
+  size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
+  size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
+
+  NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
+  NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
+             "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
+
+  size_t expected_lhs_size = m * k;
+  size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
+  size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
+  size_t actual_lhs_size = product(lhs_data.dimensions());
+  size_t actual_rhs_size = product(rhs_data.dimensions());
+  size_t actual_out_size = product(output->dimensions());
+  NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
+             expected_lhs_size, ", got ", actual_lhs_size);
+  if (!is_grouped_dense_wgrad) {
+    NVTE_CHECK(expected_rhs_size == actual_rhs_size,
+               "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
+               " = ", expected_rhs_size, ", got ", actual_rhs_size);
+    NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
+               " * ", n, " = ", expected_out_size, ", got ", actual_out_size);
+  } else {
+    NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
+               " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
+    NVTE_CHECK(expected_out_size == actual_out_size,
+               "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
+               " = ", expected_out_size, ", got ", actual_out_size);
   }
-  size_t expected_input_size = has_bias ?
5 * num_gemms : 4 * num_gemms;
-  size_t expected_output_size = num_gemms + 1;
-  size_t actual_input_size = input_list.size();
-  size_t actual_output_size = output_list.size();
-  NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu",
-             expected_input_size, actual_input_size);
-  NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu",
-             expected_output_size, actual_output_size);
-
-  bool trans_lhs = true;
-  bool trans_rhs = false;
+
+  size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
+  std::vector<int32_t> dim_list_host(num_gemms);
+  auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
+  cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
+                  stream);
+  // Note: This may break cudaGraph.
+  cudaStreamSynchronize(stream);
+  size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
+  if (!is_grouped_dense_wgrad) {
+    NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
+               ", got sum(group_sizes)=", sum_group_sizes);
+  } else {
+    NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
+               ", got sum(group_sizes)=", sum_group_sizes);
+  }
+
   auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
   bool grad = false;
   bool accumulate = false;
   bool use_split_accumulator = false;
+  auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
+  const int arch = cuda::sm_arch();
+
+  // It is weird that TE/Common GEMM only uses colwise for MXFP8
+  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
+  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
+  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
+  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
+
+  if (arch < 100 && is_fp8_gemm) {
+    NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
+               "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
+               "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
+  }
 
   // These lists are to keep the TensorWrapper objects alive
   std::vector<TensorWrapper> lhs_wrapper_list;
@@ -67,96 +151,83 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
   std::vector<NVTETensor> out_list;
   std::vector<NVTETensor> workspace_list;
 
-  int lhs_list_offset = 0;
-  int rhs_list_offset = num_gemms;
-  int lhs_sinv_list_offset = 2 * num_gemms;
-  int rhs_sinv_list_offset = 3 * num_gemms;
-  int bias_list_offset = 4 * num_gemms;
-  int out_list_offset = 0;
-  for (int i = 0; i < num_gemms; i++) {
-    Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value();
-    Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value();
-    Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value();
-    Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value();
-    Result_Type out_i = output_list.get(out_list_offset + i).value();
-
-    DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
-    DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
-    DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
-
-    void *lhs_ptr = lhs_i.untyped_data();
-    void *rhs_ptr = rhs_i.untyped_data();
-    void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
-    void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
-    void *out_ptr = out_i->untyped_data();
-
-    // Placeholder for bias since it can be empty
-    DType bias_dtype = DType::kFloat32;
-    void *bias_ptr = nullptr;
-
-    auto lhs_shape_ = lhs_i.dimensions();
-    auto rhs_shape_ = rhs_i.dimensions();
-
-    // lhs and 
rhs has shape [1, m, k] and [1, n, k]
-    size_t m = lhs_shape_[1];
-    size_t n = rhs_shape_[1];
-    size_t k = lhs_shape_[2];
-
-    auto lhs_shape = std::vector<size_t>{m, k};
-    auto rhs_shape = std::vector<size_t>{n, k};
-    auto out_shape = std::vector<size_t>{n, m};
-    auto lhs_sinv_shape = std::vector<size_t>{1, 1};
-    auto rhs_sinv_shape = std::vector<size_t>{1, 1};
-
-    if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
-        scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
-        scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
-      float *amax_dptr = nullptr;
-      float *scale_dptr = nullptr;
-      auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
-                                  reinterpret_cast<float *>(lhs_sinv_ptr));
-      auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
-                                  reinterpret_cast<float *>(rhs_sinv_ptr));
-      lhs_wrapper_list.push_back(std::move(lhs_i_));
-      rhs_wrapper_list.push_back(std::move(rhs_i_));
-    } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
-      // Note: the scale_inv array should have been swizzled in Python before lowering
-      auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
-      auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
-      for (int i = 0; i < 2; i++) {
-        lhs_sinv_shape[i] = lhs_sinv_shape_[i];
-        rhs_sinv_shape[i] = rhs_sinv_shape_[i];
-      }
-
-      NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
-      TensorWrapper lhs_i_(nvte_scaling_mode);
-      TensorWrapper rhs_i_(nvte_scaling_mode);
-      lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
-      rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
-      lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
-      rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
-
-      lhs_wrapper_list.push_back(std::move(lhs_i_));
-      rhs_wrapper_list.push_back(std::move(rhs_i_));
-    } else {
-      NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
+  for (size_t i = 0; i < num_gemms; i++) {
+    // Matrix data shapes
+    size_t m_i = dim_list_host[i];
+    auto lhs_shape = std::vector<size_t>{m_i, k};
+    auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
+    auto out_shape = std::vector<size_t>{m_i, n};
+    if (is_grouped_dense_wgrad) {
+      size_t k_i = dim_list_host[i];
+      lhs_shape[0] = lhs_is_trans ? k_i : m;
+      lhs_shape[1] = lhs_is_trans ? m : k_i;
+      rhs_shape[0] = rhs_is_trans ? n : k_i;
+      rhs_shape[1] = rhs_is_trans ? 
k_i : n;
+      out_shape[0] = m;
+      out_shape[1] = n;
+    }
 
-    auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);
-    void *pre_gelu_ptr = nullptr;
-    auto bias_shape = std::vector<size_t>{0};
-    auto pre_gelu_shape = std::vector<size_t>{0};
-    if (has_bias) {
-      auto bias_i_get = input_list.get(bias_list_offset + i);
-      Buffer_Type bias_i = bias_i_get.value();
-      bias_ptr = bias_i.untyped_data();
-      bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
-      bias_shape[0] = n;
+    // Set matrix data pointers
+    auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
+    auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
+    auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
+    void *lhs_vptr = static_cast<void *>(lhs_ptr);
+    void *rhs_vptr = static_cast<void *>(rhs_ptr);
+    if (rhs_use_colwise)  // MatA to enter cuBLAS
+      rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape);
+    else
+      rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape);
+    if (lhs_use_colwise)  // MatB to enter cuBLAS
+      lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape);
+    else
+      lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape);
+
+    // Scale_inv shapes
+    auto lhs_sinv_size = std::vector<size_t>{1};
+    auto rhs_sinv_size = std::vector<size_t>{1};
+    if (is_mxfp8_scaling) {
+      NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires the K dim to be divisible by ",
+                 MXFP8_BLOCK_SIZE, ", got ", k);
+      size_t scale_k = k / MXFP8_BLOCK_SIZE;
+      lhs_sinv_size[0] = m_i * scale_k;
+      rhs_sinv_size[0] = n * scale_k;
+      // Need to add swizzle here
+    }
+
+    // Set scale_inv pointers
+    void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
+    void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
+    if (is_fp8_gemm) {
+      if (rhs_use_colwise)  // MatA to enter cuBLAS
+        rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
+      else
+        rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
+      if (lhs_use_colwise)  // MatB to enter cuBLAS
+        lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
+      else
+        lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
+    } else {
+      NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
+                 "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
     }
+
     auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
-    auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);
+    auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);
+
+    // Update pointers for the next GEMM pair
+    lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes;
+    rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes;
+    out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes;
+    if (is_fp8_gemm) {
+      lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes;
+      rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes;
+    }
+    if (has_bias) bias_ptr += n * bias_dtype_bytes;
 
-    out_wrapper_list.push_back(std::move(out_i_));
+    // Move objects to the lists to keep them alive
+    lhs_wrapper_list.push_back(std::move(lhs_i));
+    rhs_wrapper_list.push_back(std::move(rhs_i));
+    out_wrapper_list.push_back(std::move(out_i));
     bias_wrapper_list.push_back(std::move(bias_i));
     pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
 
@@ -167,11 +238,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
     out_list.push_back(out_wrapper_list.back().data());
   }
 
-  auto workspace_get = output_list.get(num_gemms);
-  Result_Type workspace = workspace_get.value();
-  uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
-  auto num_streams = 
nvte_get_num_compute_streams();
-  size_t workspace_size = workspace->dimensions()[0] / num_streams;
   auto workspace_shape = std::vector<size_t>{workspace_size};
   for (int i = 0; i < num_streams; i++) {
     auto workspace_i =
@@ -182,7 +248,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
   }
 
   nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
-                                pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad,
+                                pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad,
                                 workspace_list.data(), accumulate, use_split_accumulator,
                                 num_math_sm, stream);
 
@@ -192,11 +258,23 @@
 XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                               FFI::Bind()
                                   .Ctx<FFI_Stream_Type>()  // stream
-                                  .RemainingArgs()         // input list
-                                  .RemainingRets()         // output list
-                                  .Attr<int64_t>("num_gemms")
+                                  .Arg<Buffer_Type>()      // lhs_data
+                                  .Arg<Buffer_Type>()      // lhs_sinv
+                                  .Arg<Buffer_Type>()      // rhs_data
+                                  .Arg<Buffer_Type>()      // rhs_sinv
+                                  .Arg<Buffer_Type>()      // bias
+                                  .Arg<Buffer_Type>()      // group_sizes
+                                  .Arg<Buffer_Type>()      // group_offset
+                                  .Ret<Buffer_Type>()      // output
+                                  .Ret<Buffer_Type>()      // workspace
+                                  .Attr<int64_t>("M")
+                                  .Attr<int64_t>("N")
+                                  .Attr<int64_t>("K")
+                                  .Attr<bool>("lhs_is_trans")
+                                  .Attr<bool>("rhs_is_trans")
                                   .Attr<JAXX_Scaling_Mode>("scaling_mode")
-                                  .Attr<int64_t>("has_bias"),
+                                  .Attr<bool>("has_bias")
+                                  .Attr<bool>("is_grouped_dense_wgrad"),
                               FFI_CudaGraph_Traits);
 
 }  // namespace jax
diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py
index 55d60e4189..bba101c722 100644
--- a/transformer_engine/jax/dense.py
+++ b/transformer_engine/jax/dense.py
@@ -153,28 +153,28 @@ def _dense_bwd_rule(
 
     # GEMM NT
     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
-    g_constracting_dim = tuple(
+    g_contracting_dim = tuple(
         range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
     )
     # k_non_contracting_dims
-    k_constracting_dim = tuple(
+    k_contracting_dim = tuple(
         dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
     )
 
     dgrad = tex.gemm(
         casted_grad.get_rowwise_tensor(),
         rowwise_casted_kernel,
-        (g_constracting_dim, k_constracting_dim),
+        (g_contracting_dim, k_contracting_dim),
     )
 
     dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
 
     # GEMM TN
     # x_non_contracting_dims
-    g_constracting_dim = x_constracting_dim = tuple(
+    g_contracting_dim = x_contracting_dim = tuple(
         range(0, len(x_shape) - len(fwd_x_contracting_dims))
     )
 
     wgrad = tex.gemm(
-        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
+        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
     )
 
     wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
@@ -184,135 +184,240 @@
 _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
 
 
-"""
 def grouped_dense(
-    x_list,
-    kernel_list,
-    bias_list,
-    contracting_dims_list,
-    quantizer_set_list=None,
+    x: jnp.ndarray,
+    kernel: jnp.ndarray,
+    group_sizes: jnp.ndarray,
+    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
+    bias: jnp.ndarray = None,
+    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
+    preferred_element_type: jnp.dtype = None,
+    group_offset: jnp.array = None,
+    quantizer_set: QuantizerSet = noop_quantizer_set,
 ):
-    # Perform grouped_dense layer transformation with optional quantization.
+    """
+    Perform grouped dense (linear) layer transformation with optional quantization.
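+
+    The rows of x are partitioned into contiguous groups, where group g contains
+    group_sizes[g] rows and is multiplied by its own weight matrix kernel[g].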
- output_list = _grouped_dense( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + Args: + x: Input tensor of shape (M, K) + kernel: Weight matrix of shape (G, K, N) + group_sizes: 1D array of shape (G,) specifying the size of each group + contracting_dims: Tuple of sequences specifying which dimensions to contract + (currently only supports ((1,), (1,))) + bias: Bias tensor of shape (G, N) + precision: JAX precision for the GEMM operation + preferred_element_type: Preferred data type for the output tensor + group_offset: 1D array containing offsets for each group (not yet implemented) + quantizer_set: Set of quantizers for FP8 quantization of the input and output + + Returns: + A jnp.ndarray containing the result of the grouped linear operation + """ + output = _grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ) - return output_list + return output -@partial(jax.custom_vjp, nondiff_argnums=(3,)) -def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): - output_list, _ = _grouped_dense_fwd_rule( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list +@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) +def _grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, +): + output, _ = _grouped_dense_fwd_rule( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ) - return output_list + return output def _grouped_dense_fwd_rule( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ): - use_bias = bias_list is not None - output_list = [] - x_rowwise_list = [] - x_colwise_list = [] - kernel_colwise_list = [] - kernel_rowwise_list = [] - x_shape_list = [] - kernel_shape_list = [] - if quantizer_set_list is None: - x_rowwise_list = x_list - x_colwise_list = x_list - kernel_colwise_list = kernel_list - kernel_rowwise_list = kernel_list - x_shape_list = [x.shape for x in x_list] - kernel_shape_list = [kernel.shape for kernel in kernel_list] + use_bias = bias is not None + is_noop_quantizer_set = quantizer_set == noop_quantizer_set + + if is_noop_quantizer_set: + grouped_gemm_x = x + grouped_gemm_kernel = kernel + ctx_x = x + ctx_kernel = kernel + flatten_axis_k = None else: - for i in range(len(x_list)): # pylint: disable=consider-using-enumerate - q_x = tex.quantize(x_list[i], quantizer_set_list[i].x) - q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel) - x_rowwise_list.append(q_x.get_rowwise_tensor()) - x_colwise_list.append(q_x.get_colwise_tensor()) - kernel_colwise_list.append(q_kernel.get_colwise_tensor()) - kernel_rowwise_list.append(q_kernel.get_rowwise_tensor()) - x_shape_list.append(x_rowwise_list[-1].data.shape) - kernel_shape_list.append(kernel_rowwise_list[-1].data.shape) - - output_list = tex.grouped_gemm( - x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list + x_contracting_dims, k_contracting_dims = contracting_dims + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis + + assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)" + assert kernel.ndim == 3, 
"Grouped dense expects a 3D kernel tensor of shape (G, K, N)" + # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? + assert x_contracting_dims == (1,) and k_contracting_dims == (1,), ( + "grouped_dense for FP8 can only handle x_contracting_dims=(1,) " + "and k_contracting_dims=(1,) for now, " + f"got {x_contracting_dims=} and {k_contracting_dims=}" + ) + k_contracting_dims = (0,) + + casted_x = tex.grouped_quantize( + x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x + ) + casted_kernel = tex.grouped_quantize( + kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k + ) + contracting_dims = (x_contracting_dims, k_contracting_dims) + + # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have + # rowwise_casted_x.original_shape == (M, K) + # colwise_casted_kernel.original_shape == (G, N, K) + grouped_gemm_x = casted_x.get_rowwise_tensor() + grouped_gemm_kernel = casted_kernel.get_colwise_tensor() + # TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()? + ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None + ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None + + output = tex.grouped_gemm( + grouped_gemm_x, + grouped_gemm_kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, ) ctx = ( - x_colwise_list, - kernel_rowwise_list, - x_shape_list, - kernel_shape_list, + group_sizes, + ctx_x, + ctx_kernel, + x.shape, + kernel.shape, use_bias, - quantizer_set_list, + is_noop_quantizer_set, + quantizer_set, + flatten_axis_k, ) - return output_list, ctx + return output, ctx -def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list): +def _grouped_dense_bwd_rule( + contracting_dims, precision, preferred_element_type, group_offset, ctx, grad +): + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims + ( - colwise_x_list, - rowwise_kernel_list, - x_shape_list, - kernel_shape_list, + group_sizes, + ctx_x, + ctx_kernel, + x_shape, + kernel_shape, use_bias, - quantizer_set_list, + is_noop_quantizer_set, + quantizer_set, + flatten_axis_k, ) = ctx - group_size = len(grad_list) - dbias_list = [] - grad_rowwise_list = [] - grad_colwise_list = [] - dgrad_contracting_dims_list = [] - wgrad_contracting_dims_list = [] - for i in range(group_size): - grad = grad_list[i] - x_shape = x_shape_list[i] - kernel_shape = kernel_shape_list[i] - fwd_contracting_dims = contracting_dims_list[i] - - if quantizer_set_list is None: - casted_grad = grad - dbias = tex.quantization._jax_dbias(grad) - grad_rowwise_list.append(grad) - grad_colwise_list.append(grad) - else: - quantizer_set = quantizer_set_list[i] - casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad - ) - grad_rowwise_list.append(casted_grad.get_rowwise_tensor()) - grad_colwise_list.append(casted_grad.get_colwise_tensor()) - dbias_list.append(dbias) - - # GEMM NT - fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims + if is_noop_quantizer_set: + # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) 
-def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
+def _grouped_dense_bwd_rule(
+    contracting_dims, precision, preferred_element_type, group_offset, ctx, grad
+):
+    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
+
     (
-        colwise_x_list,
-        rowwise_kernel_list,
-        x_shape_list,
-        kernel_shape_list,
+        group_sizes,
+        ctx_x,
+        ctx_kernel,
+        x_shape,
+        kernel_shape,
         use_bias,
-        quantizer_set_list,
+        is_noop_quantizer_set,
+        quantizer_set,
+        flatten_axis_k,
     ) = ctx
 
-    group_size = len(grad_list)
-    dbias_list = []
-    grad_rowwise_list = []
-    grad_colwise_list = []
-    dgrad_contracting_dims_list = []
-    wgrad_contracting_dims_list = []
-    for i in range(group_size):
-        grad = grad_list[i]
-        x_shape = x_shape_list[i]
-        kernel_shape = kernel_shape_list[i]
-        fwd_contracting_dims = contracting_dims_list[i]
-
-        if quantizer_set_list is None:
-            casted_grad = grad
-            dbias = tex.quantization._jax_dbias(grad)
-            grad_rowwise_list.append(grad)
-            grad_colwise_list.append(grad)
-        else:
-            quantizer_set = quantizer_set_list[i]
-            casted_grad, dbias = tex.quantize_dbias(
-                grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
-            )
-            grad_rowwise_list.append(casted_grad.get_rowwise_tensor())
-            grad_colwise_list.append(casted_grad.get_colwise_tensor())
-        dbias_list.append(dbias)
-
-        # GEMM NT
-        fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims
+    if is_noop_quantizer_set:
+        # The leading 1 in the range below excludes the group dimension; for the default
+        # case this reduces to the hardcoded dims (TODO: consider using them directly):
+        # g_contracting_dim = (1,)
+        # k_contracting_dim = (2,)
         g_contracting_dim = tuple(
-            range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
+            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
         )
         k_contracting_dim = tuple(
-            dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
+            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
         )
         dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
-        dgrad_contracting_dims_list.append(dgrad_contracting_dims)
+        dgrad_grad = grad
+        dgrad_kernel_T = ctx_kernel
 
-        # GEMM TN
+        # g_contracting_dim = (0,)
+        # x_contracting_dim = (0,)
         g_contracting_dim = x_contracting_dim = tuple(
             range(0, len(x_shape) - len(fwd_x_contracting_dims))
         )
         wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
-        wgrad_contracting_dims_list.append(wgrad_contracting_dims)
+        wgrad_x_T = ctx_x
+        wgrad_grad = grad
+    else:
+        casted_grad = tex.grouped_quantize(
+            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
+        )
+
+        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
+        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
+        # extra transpose for FP8 in grouped_gemm
+        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
+        g_contracting_dim = (1,)
+        k_contracting_dim = (2,)
+        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
+        dgrad_grad = casted_grad.get_rowwise_tensor()
+        dgrad_kernel_T = ctx_kernel
+
+        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work
+        # after the extra transpose for FP8 in grouped_gemm
+        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
+        g_contracting_dim = (0,)
+        x_contracting_dim = (1,)
+        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
+        wgrad_x_T = ctx_x
+        wgrad_grad = casted_grad.get_colwise_tensor()
+
+    dgrad = tex.grouped_gemm(
+        dgrad_grad,
+        dgrad_kernel_T,
+        group_sizes,
+        dgrad_contracting_dims,
+        precision=precision,
+        preferred_element_type=preferred_element_type,
+        group_offset=group_offset,
+    )
 
-    dgrad_list = tex.grouped_gemm(
-        grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list
+    wgrad = tex.grouped_gemm(
+        wgrad_x_T,
+        wgrad_grad,
+        group_sizes,
+        wgrad_contracting_dims,
+        precision=precision,
+        preferred_element_type=preferred_element_type,
+        group_offset=group_offset,
     )
-    wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list)
 
-    return dgrad_list, wgrad_list, dbias_list, quantizer_set_list
+    group_sizes_grad = None
+    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
+
+    return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set
 
 
 _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
-"""
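A standalone check of the index arithmetic in the unquantized branch above, specialized to the default case of x (M, K), kernel (G, K, N), grad (M, N), and forward contracting dims ((1,), (1,)); the names are local stand-ins for the variables in the diff:

grad_ndim, kernel_ndim, x_ndim = 2, 3, 2
fwd_x_contracting_dims = (1,)
fwd_k_contracting_dims = (1,)

# dgrad GEMM: contract grad's N axis with kernel's N axis (group axis excluded)
g_contracting_dim = tuple(
    range(1 + grad_ndim - kernel_ndim + len(fwd_k_contracting_dims), grad_ndim)
)
k_contracting_dim = tuple(
    d for d in range(1, kernel_ndim) if d not in fwd_k_contracting_dims
)
assert (g_contracting_dim, k_contracting_dim) == ((1,), (2,))

# wgrad GEMM: contract over the token axis M
g_dim = x_dim = tuple(range(0, x_ndim - len(fwd_x_contracting_dims)))
assert g_dim == x_dim == (0,)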
diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py
index 45ec4fd1fa..06a2562fb1 100644
--- a/transformer_engine/jax/quantize/dequantizer.py
+++ b/transformer_engine/jax/quantize/dequantizer.py
@@ -127,14 +127,16 @@ def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatte
     def dequantize(scaled_tensor):
         """Dequantize a tensor using block scaling.
 
-        This function dequantizes a tensor that was quantized using block scaling
-        by applying the inverse scaling factor to each block of data.
-
         Args:
-            scaled_tensor: The quantized tensor to dequantize
+            data: The quantized tensor data
+            scale_inv: The inverse scaling factors
+            dq_dtype: The data type for dequantized values
+            scaling_mode: The scaling mode used for quantization
+            is_colwise: Whether the scaling is column-wise
+            flatten_axis: The axis along which the tensor could be flattened to 2D
 
         Returns:
-            The dequantized tensor in the specified data type
+            The dequantized tensor
         """
         return BlockScaleDequantizer._dequantize_func(
             scaled_tensor.data,