diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 977c3c2912..7e6605c9fe 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -57,13 +57,14 @@ def __call__(self, x, mask, disable_dropout=False): self_attn_mask_type="padding", enable_relative_embedding=False, enable_sequence_parallel=self.enable_seq_paral, + mlp_activations=("gelu", "linear"), ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) if self.enable_seq_paral: - # Trigger all-gather to collect a complete tensor alone seqence on each device. + # Trigger all-gather to collect a complete tensor alone sequence on each device. x = jax.lax.with_sharding_constraint( x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) ) @@ -459,7 +460,7 @@ def setUpClass(cls): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -467,7 +468,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -475,14 +476,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_with_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp(self): @@ -491,7 +492,7 @@ def test_te_delayed_scaling_fp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8_with_sp(self): @@ -500,7 +501,7 @@ def test_te_mxfp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 if __name__ == "__main__": diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1efc7e1f3c..4dc07a2eea 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -29,7 +29,7 @@ ScaledTensor, ScalingMode, QuantizerFactory, - QuantizeAxis, + QuantizeLayout, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation @@ -82,8 +82,9 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray): if isinstance(a, ScaledTensor1x): - if a.layout == "T": - b_transpose = jnp.transpose(b, (-1, *range(b.ndim - 1))) + if a.data_layout == "T": + flatten_axis = a.data.ndim - a.flatten_axis + b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis))) assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype) else: assert_allclose(a.dequantize(), b, dtype=a.data.dtype) @@ -141,7 +142,8 @@ def primitive_func(self, inputs, activation_type, quantizer): def test_act_grad(self, shape, activation_type): key = jax.random.PRNGKey(0) x = jax.random.uniform(key, shape, jnp.float32) - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.expand_dims(x, axis=-2) + x = jnp.repeat(x, len(activation_type), axis=-2) value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) @@ -159,7 +161,8 @@ def test_act_grad(self, shape, activation_type): @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type): x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.expand_dims(x, axis=-2) + x = jnp.repeat(x, len(activation_type), axis=-2) self.activation_type = activation_type value_n_grad_primitive_func = jit( @@ -169,7 +172,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=output_type, - q_axis=QuantizeAxis.ROWWISE, + q_layout=QuantizeLayout.ROWWISE, ) prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) @@ -182,19 +185,22 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_act_forward_with_delayed_scaling_fp8( - self, random_inputs, activation_type, output_type, q_axis + self, random_inputs, activation_type, output_type, q_layout ): x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.expand_dims(x, axis=-2) + x = jnp.repeat(x, len(activation_type), axis=-2) self.activation_type = activation_type te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=output_type, - q_axis=q_axis, + q_layout=q_layout, ) te_output = tex.act_lu(x, activation_type, te_quantizer) @@ -203,19 +209,21 @@ def test_act_forward_with_delayed_scaling_fp8( assert_bitwise_scaled_tensors(te_output, jax_output) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) - @pytest_parametrize_wrapper("shape", [(128, 128)]) + @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)]) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_act_forward_with_block_scaling_fp8( - self, random_inputs, activation_type, output_type, q_axis + self, random_inputs, activation_type, output_type, q_layout ): x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.repeat(x, len(activation_type), axis=-2) self.activation_type = activation_type quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_axis=q_axis + scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) output = tex.act_lu(x, activation_type, quantizer) @@ -324,9 +332,11 @@ def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp @pytest.mark.skipif(not is_fp8_supported, reason=reason) # No Norm FWD E5M2 in TE backend @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_norm_grad_with_delayed_scaling_fp8( - self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis + self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout ): """ Test transformer_engine.jax.layernorm.layernorm @@ -335,7 +345,9 @@ def test_norm_grad_with_delayed_scaling_fp8( pytest.skip("RMSNorm and zero_centered_gamma is not supported!") quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_axis=q_axis + scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + q_dtype=out_dtype, + q_layout=q_layout, ) self._test_norm_grad( n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer @@ -351,7 +363,7 @@ def _test_norm_forward( inp_dtype, out_dtype, scaling_mode, - q_axis, + q_layout, ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 3) @@ -363,7 +375,7 @@ def _test_norm_forward( gamma = jnp.asarray(gamma, inp_dtype) quantizer, ref_quantizer = QuantizerFactory.create( - n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_axis=q_axis + n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout ) if norm_type == "layernorm": beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1) @@ -391,9 +403,11 @@ def _test_norm_forward( @pytest.mark.skipif(not is_fp8_supported, reason=reason) # No Norm FWD E5M2 in TE backend @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_norm_forward_with_delayed_scaling_fp8( - self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis + self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout ): if norm_type == "rmsnorm" and zero_centered_gamma is True: pytest.skip("RMSNorm and zero_centered_gamma is not supported!") @@ -407,7 +421,7 @@ def test_norm_forward_with_delayed_scaling_fp8( inp_dtype=inp_dtype, out_dtype=out_dtype, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, - q_axis=q_axis, + q_layout=q_layout, ) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @@ -424,7 +438,7 @@ def test_norm_forward_with_block_scaling_fp8( inp_dtype=inp_dtype, out_dtype=out_dtype, scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, - q_axis=QuantizeAxis.ROWWISE_COLWISE, + q_layout=QuantizeLayout.ROWWISE_COLWISE, ) @@ -434,14 +448,14 @@ def test_norm_forward_with_block_scaling_fp8( } ALL_QUANTIZE_TEST_SHAPES = [ - (128, 128), - (4, 256, 512), + (32, 64), + (2, 64, 32), ] QUANTIZE_TEST_SHAPES = { "L0": [ - (256, 128), - (64, 16, 2, 256), + (32, 256, 128), + (64, 32, 32, 256), ], "L2": ALL_QUANTIZE_TEST_SHAPES, } @@ -457,48 +471,52 @@ def test_norm_forward_with_block_scaling_fp8( @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) +@pytest_parametrize_wrapper("flatten_axis", [-1, -2]) @pytest_parametrize_wrapper( - "q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE] + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] ) class TestQuantize: """ Purely quantization related tests that will always test on a wider set of types and shapes """ - def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis): + def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): key = jax.random.PRNGKey(0) # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling) quantizer = QuantizerFactory.create( scaling_mode=scaling_mode, q_dtype=q_dtype, - q_axis=q_axis, + q_layout=q_layout, ) + # Adding dimension to test if padding is done correctly when flatten 3D to 2D + if flatten_axis == -2: + input_shape = input_shape[:-1] + (2,) + input_shape[-1:] n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) - scaled_tensor = quantizer.quantize(x) + scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis) assert_dequantized_scaled_tensor(scaled_tensor, x) - def test_quantize_bitwise(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis): - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( - input_shape - ): - pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") + def test_quantize_bitwise( + self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis + ): key = jax.random.PRNGKey(0) + if flatten_axis == -2: + input_shape = input_shape[:-1] + (2,) + input_shape[-1:] input = jax.random.uniform(key, input_shape, in_dtype) te_quantizer, jax_quantizer = QuantizerFactory.create( - n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis + n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) - jax_output = _jax_quantize(input, quantizer=jax_quantizer) + jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) - te_output = tex.quantize(input, quantizer=te_quantizer) - assert_bitwise_scaled_tensors(jax_output, te_output) + te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + assert_bitwise_scaled_tensors(te_output, jax_output) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @@ -508,9 +526,13 @@ class TestFusedQuantize: @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) - def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_axis): - transpose_axis = -1 + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) + @pytest_parametrize_wrapper("flatten_axis", [-1, -2]) + def test_quantize_dbias( + self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis + ): if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( input_shape ): @@ -520,35 +542,37 @@ def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_ input = jax.random.uniform(key, input_shape, in_dtype) jax_quantizer, te_quantizer = QuantizerFactory.create( - n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis + n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) - te_output, te_dbias = jit(lambda input: tex.quantize_dbias(input, quantizer=te_quantizer))( - input - ) + te_output, te_dbias = jit( + lambda input: tex.quantize_dbias( + input, quantizer=te_quantizer, flatten_axis=flatten_axis + ) + )(input) jax_output, jax_dbias = jit( lambda input: _jax_quantize_dbias( - input, - quantizer=jax_quantizer, + input, quantizer=jax_quantizer, flatten_axis=flatten_axis ) )(input) - assert_bitwise_scaled_tensors(jax_output, te_output) + assert_bitwise_scaled_tensors(te_output, jax_output) - assert_allclose(jax_dbias, te_dbias) + assert_allclose(te_dbias, jax_dbias) def _test_quantize_dact_dbias( - self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_axis + self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.expand_dims(x, axis=-2) + x = jnp.repeat(x, len(activation_type), axis=-2) dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1) jax_quantizer, te_quantizer = QuantizerFactory.create( - n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis + n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) is_casted_output = te_quantizer is not None @@ -573,12 +597,12 @@ def _test_quantize_dact_dbias( )(dz, x) if is_casted_output: - assert_bitwise_scaled_tensors(jax_output, te_output) + assert_bitwise_scaled_tensors(te_output, jax_output) else: - assert_allclose(jax_output, te_output) + assert_allclose(te_output, jax_output) if is_dbias: - assert_allclose(jax_dbias, te_dbias) + assert_allclose(te_dbias, jax_dbias) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @@ -597,7 +621,7 @@ def test_quantize_dact_dbias_no_quantization( scaling_mode=ScalingMode.NVTE_NO_SCALING, activation_type=activation_type, is_dbias=is_dbias, - q_axis=QuantizeAxis.ROWWISE, + q_layout=QuantizeLayout.ROWWISE, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -605,9 +629,11 @@ def test_quantize_dact_dbias_no_quantization( @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_quantize_dact_dbias_delayed_scaling( - self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis + self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout ): self._test_quantize_dact_dbias( in_dtype=in_dtype, @@ -616,7 +642,7 @@ def test_quantize_dact_dbias_delayed_scaling( scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, activation_type=activation_type, is_dbias=is_dbias, - q_axis=q_axis, + q_layout=q_layout, ) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @@ -626,9 +652,11 @@ def test_quantize_dact_dbias_delayed_scaling( ) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_quantize_dact_dbias_mxfp8_scaling( - self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis + self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout ): if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0: # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes. @@ -645,75 +673,75 @@ def test_quantize_dact_dbias_mxfp8_scaling( scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, activation_type=activation_type, is_dbias=is_dbias, - q_axis=q_axis, + q_layout=q_layout, ) class TestDense: - def _ref_gemm_with_jnp_dot(self, a, b, layout): - if layout[0] == "T": + def _ref_gemm_with_jnp_dot(self, a, b, data_layout): + if data_layout[0] == "T": a = jnp.swapaxes(a, -1, -2) - if layout[1] == "T": + if data_layout[1] == "T": b = jnp.swapaxes(b, -1, -2) return jnp.dot(a, b) - def _generate_gemm_input(self, m, n, k, layout): + def _generate_gemm_input(self, m, n, k, data_layout): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform( subkeys[0], - (m if layout[0] == "N" else k, k if layout[0] == "N" else m), + (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), dtype=jnp.bfloat16, ) / jnp.sqrt(k) w = jax.random.uniform( subkeys[1], - (k if layout[1] == "N" else n, n if layout[1] == "N" else k), + (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), dtype=jnp.bfloat16, ) / jnp.sqrt(n) - lhs_contracting_dim = (1,) if layout[0] == "N" else (0,) - rhs_contracting_dim = (0,) if layout[1] == "N" else (1,) + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) return (x, w, contracting_dims) - @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) - @pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_bf16(self, m, n, k, layout): - x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) + @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) + def test_gemm_bf16(self, m, n, k, data_layout): + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) primitive_out = tex.gemm(x, w, contracting_dims) - ref_out = self._ref_gemm_with_jnp_dot(x, w, layout) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, layout): - x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) + @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) + def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout): + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False ) primitive_out = tex.gemm( x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set ) - ref_out = self._ref_gemm_with_jnp_dot(x, w, layout) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=q_dtype) - @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) def test_dense_grad_bf16(self, m, n, k): - layout = "NN" - x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) + data_layout = "NN" + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) def primitive_func(x, w, contracting_dims): primitive_out = dense(x, w, contracting_dims=contracting_dims) return jnp.mean(primitive_out) - def ref_func(x, w, layout): - return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, layout)) + def ref_func(x, w, data_layout): + return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout)) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1)) @@ -722,19 +750,19 @@ def ref_func(x, w, layout): primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func( x, w, contracting_dims ) - ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, layout) + ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): - layout = "NN" - x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) + data_layout = "NN" + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) key = jax.random.PRNGKey(1) bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) @@ -745,9 +773,9 @@ def primitive_func(x, w, bias, contracting_dims, quantizer_set): ) return jnp.mean(primitive_out) - def ref_func(x, w, bias, layout): + def ref_func(x, w, bias, data_layout): return jnp.mean( - self._ref_gemm_with_jnp_dot(x, w, layout) + jnp.expand_dims(bias, axis=0) + self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0) ) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) @@ -763,7 +791,9 @@ def ref_func(x, w, bias, layout): value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) ) - ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(x, w, bias, layout) + ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func( + x, w, bias, data_layout + ) assert_allclose(primitive_out, ref_out, dtype=q_dtype) assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) @@ -791,7 +821,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("m,n,k", [(512, 128, 128)]) + @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @@ -873,7 +903,7 @@ def ref_func(x, w, gamma, beta): assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("m,n,k", [(512, 128, 256)]) + @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @@ -898,13 +928,13 @@ def test_layernorm_mlp_grad( x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) kernel_1 = jax.random.normal( - subkeys[1], (k, len(activation_type) * n), jnp.bfloat16 + subkeys[1], (k, len(activation_type), n), jnp.bfloat16 ) / jnp.sqrt(k) kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n) gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16) beta = None # was tested in TestNorm if use_bias: - bias_1 = jax.random.normal(subkeys[3], (len(activation_type) * n), jnp.bfloat16) + bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16) bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16) else: bias_1 = None @@ -1039,19 +1069,19 @@ def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list): subkeys = jax.random.split(key, len(shape_list) * 2) lhs_list, rhs_list, contracting_dims_list = [], [], [] - for i, ((m, n, k), layout) in enumerate(zip(shape_list, layout_list)): + for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)): lhs = jax.random.uniform( subkeys[2 * i], - (m if layout[0] == "N" else k, k if layout[0] == "N" else m), + (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), dtype=dtype, ) rhs = jax.random.uniform( subkeys[2 * i + 1], - (k if layout[1] == "N" else n, n if layout[1] == "N" else k), + (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), dtype=dtype, ) - lhs_contracting_dim = (1,) if layout[0] == "N" else (0,) - rhs_contracting_dim = (0,) if layout[1] == "N" else (1,) + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) lhs_list.append(lhs) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index efc24fe6ea..4350d5e8f3 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -45,11 +45,17 @@ SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) DTYPES = [jnp.bfloat16, jnp.float16] -INPUT_SHAPE = [[2, 64, 64]] # [batch, seqlen, hidden_in] +INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES) +KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES) +KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES) +LN_SCALE_AXES = (W_NO_SHARD_AXES,) +LN_BIAS_AXES = (W_NO_SHARD_AXES,) +BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) +BIAS_2_AXES = (W_NO_SHARD_AXES,) INTERMEDIATE = 64 @@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs(): configs.append( [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] ) - if is_devices_enough(4): configs.append( [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] @@ -80,13 +85,13 @@ def generate_inputs(self, input_shape, activation_type, use_bias, dtype): x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype) k1 = jax.random.normal( - subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype + subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype ) / jnp.sqrt(hidden_in) k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt( INTERMEDIATE ) if use_bias: - b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype) + b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype) b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype) else: b1 = None @@ -111,10 +116,12 @@ def layernorm_fp8_mlp_prim_func( layernorm_input_axes = LAYERNORM_INPUT_AXES dot_1_input_axes = DOT_1_INPUT_AXES dot_2_input_axes = DOT_2_INPUT_AXES + kernel_1_axes = KERNEL_1_AXES + kernel_2_axes = KERNEL_2_AXES else: layernorm_input_axes = None - dot_1_input_axes = None - dot_2_input_axes = None + dot_1_input_axes = dot_2_input_axes = None + kernel_1_axes = kernel_2_axes = None quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) @@ -130,6 +137,8 @@ def layernorm_fp8_mlp_prim_func( norm_input_axes=layernorm_input_axes, dot_1_input_axes=dot_1_input_axes, dot_2_input_axes=dot_2_input_axes, + kernel_1_axes=kernel_1_axes, + kernel_2_axes=kernel_2_axes, activation_type=activation_type, quantizer_sets=quantizer_sets, ) @@ -142,7 +151,7 @@ def layernorm_fp8_mlp_prim_func( @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - def test_layernorm_fp8_mlp_primitive( + def test_layernorm_mlp_grad( self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe ): device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config @@ -168,12 +177,12 @@ def test_layernorm_fp8_mlp_primitive( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): - k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp")) + k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) k1_ = jax.device_put(k1, k1_sharding) k2_ = jax.device_put(k2, k2_sharding) if use_bias: - b1_sharding = NamedSharding(mesh, PartitionSpec("tp")) + b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) b1_ = jax.device_put(b1, b1_sharding) else: b1_sharding = b1_ = None @@ -267,16 +276,7 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, # input: [batch, seqlen, hidden] intermediate_dim=INTERMEDIATE, activations=activation_type, - scale_axes=(W_NO_SHARD_AXES,), - ln_bias_axes=(W_NO_SHARD_AXES,), - kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), - kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), use_bias=use_bias, - bias_axes_1=(W_JOINED_AXES, W_TP_AXES), - bias_axes_2=(W_NO_SHARD_AXES,), - layernorm_input_axes=LAYERNORM_INPUT_AXES, - dot_1_input_axes=DOT_1_INPUT_AXES, - dot_2_input_axes=DOT_2_INPUT_AXES, ) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) mlp_out_single, ln_out_single = ln_mlp_single.apply( @@ -295,13 +295,13 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, intermediate_dim=INTERMEDIATE, activations=activation_type, - scale_axes=(W_NO_SHARD_AXES,), - ln_bias_axes=(W_NO_SHARD_AXES,), - kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), - kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), + scale_axes=LN_SCALE_AXES, + ln_bias_axes=LN_BIAS_AXES, + kernel_axes_1=KERNEL_1_AXES, + kernel_axes_2=KERNEL_2_AXES, use_bias=use_bias, - bias_axes_1=(W_JOINED_AXES, W_TP_AXES), - bias_axes_2=(W_NO_SHARD_AXES,), + bias_axes_1=BIAS_1_AXES, + bias_axes_2=BIAS_2_AXES, layernorm_input_axes=LAYERNORM_INPUT_AXES, dot_1_input_axes=DOT_1_INPUT_AXES, dot_2_input_axes=DOT_2_INPUT_AXES, @@ -334,7 +334,7 @@ def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - def test_layernorm_fp8_mlp_layer( + def test_layernorm_mlp_layer_fp8( self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe ): self._test_layernorm_mlp( diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index a2d0a6f4d9..ef6def2d03 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g): (x, _) = ctx assert x.dtype == g.dtype dx = tex.dact_lu(g, x, activation_type) - dx = jnp.reshape(dx, x.shape) return (dx, None) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 70227e1620..d7676781c3 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -26,12 +26,12 @@ should_apply_1x_fused_dbias_war_for_arch_l_100, NamedSharding, ) -from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias +from .quantization import _jax_dbias, _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ( Quantizer, - QuantizeAxis, + QuantizeLayout, DelayedScaleQuantizer, ScalingMode, ) @@ -110,41 +110,31 @@ def abstract( """ te_act_lu_p abstract """ - del act_enum, act_len, scale_shapes + del act_enum, scale_shapes dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - - out_shape = ( - *x_aval.shape[:-2], - 1, - x_aval.shape[-1], + assert x_aval.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x_aval.shape} and act_len {act_len}" ) + + out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) + updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer) - - if len(rowwise_scale_inv_shape) > 1: - rowwise_scale_inv_shape = ( - rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:] - ) - if len(colwise_scale_inv_shape) > 1: - colwise_scale_inv_shape = ( - colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:] - ) - + ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1) + if not is_2x: + out_shape = (1,) + colwise_scale_inv_shape = (1,) + colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - - colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) - if is_2x: - colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval @@ -211,15 +201,8 @@ def impl( ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - rowwise_scale_inv_shape = ( - rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:] - ) - if is_2x: - colwise_scale_inv_shape = ( - colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:] - ) + ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1) + # Slice out padding for MXFP8, noop for DelayedScaling scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) @@ -227,6 +210,7 @@ def impl( colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) + return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax @staticmethod @@ -292,11 +276,14 @@ def infer_sharding_from_operands( is_outer, ) # Unused. x_spec = get_padded_spec(arg_infos[0]) - out_spec = (*x_spec[:-2], None, x_spec[-2]) + scale_spec = get_padded_spec(arg_infos[1]) + + out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") + if is_2x: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_out_spec = multidim_transpose(out_spec) + colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec else: @@ -304,18 +291,24 @@ def infer_sharding_from_operands( colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out" ) + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec + + if is_2x: + colwise_scale_inv_spec = scale_inv_spec + scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv" ) - amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax") - - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "ActLuPrimitive.colwise_scale_inv" + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax") + colwise_scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv" ) + return ( out_sharding, colwise_out_sharding, @@ -340,14 +333,14 @@ def partition( ): del result_infos, is_outer # Unused. x_spec = get_padded_spec(arg_infos[0]) - out_spec = (*x_spec[:-1], x_spec[-1]) - if act_len == 2 and x_spec[-1] is None: - # Ensure last axis is partitioned and not the gating axis - out_spec = (*x_spec[:-2], None, x_spec[-2]) + scale_spec = get_padded_spec(arg_infos[1]) + + out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") + if is_2x: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_out_spec = multidim_transpose(out_spec) + colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec else: @@ -355,21 +348,25 @@ def partition( colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out" ) + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec + + if is_2x: + colwise_scale_inv_spec = scale_inv_spec + scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv" ) - amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax") - - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "ActLuPrimitive.colwise_scale_inv" + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax") + colwise_scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv" ) - arg_shardings = list(arg_i.sharding for arg_i in arg_infos) - arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec)) - arg_shardings = tuple(arg_shardings) + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, colwise_out_sharding, @@ -413,6 +410,7 @@ def sharded_impl(x, scale): register_primitive(ActLuPrimitive) +# TODO(Jeremy): replace is_2x with q_layout class DActLuDBiasQuantizePrimitive(BasePrimitive): """ DActLu DBias Cast Transpose Primitive @@ -445,42 +443,41 @@ def abstract( te_dact_dbias_quantize_p abstract """ del act_enum, scale_shapes - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dtype + dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dz_dtype + assert x_aval.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x_aval.shape} and act_len {act_len}" + ) assert scale_aval.dtype == jnp.float32 ir_hidden_size = dz_aval.shape[-1] - gi_hidden_size = x_aval.shape[-1] + gi_hidden_size = act_len * x_aval.shape[-1] assert act_len * ir_hidden_size == gi_hidden_size out_shape = x_aval.shape out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) + updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) - - scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - - colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) - - dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) if is_2x: - # Don't transpose output for MXFP8 - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - t_shape = out_shape + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: - t_shape = multidim_transpose(out_shape) - colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) + colwise_out_shape = out_shape + else: + colwise_out_shape = (1,) + colwise_scale_inv_shape = (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) + scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) if is_dbias: - dbias_shape = gi_hidden_size - dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype) + dbias_shape = (act_len, ir_hidden_size) (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes( x_aval.size // gi_hidden_size, gi_hidden_size, @@ -489,9 +486,14 @@ def abstract( scaling_mode, is_2x, ) - wkspace_aval = x_aval.update( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) + wkspace_shape = wkspace_info[0] + wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) + else: + dbias_shape = (1,) + wkspace_shape = (1,) + wkspace_dtype = jnp.float32 + dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype) + wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype) return ( out_aval, @@ -587,23 +589,16 @@ def impl( ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x.shape, is_padded=False) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv = jax.lax.slice( - scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape + ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2) + # Slice out padding for MXFP8, noop for DelayedScaling + scale_inv = jax.lax.slice( + scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape + ) + if is_2x: + colwise_scale_inv = jax.lax.slice( + colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) - if is_2x: - colwise_scale_inv = jax.lax.slice( - colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape - ) - return ( - out, - colwise_out, - scale_inv, - colwise_scale_inv, - updated_amax, - dbias, - ) # Exclude wkspace + return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias @staticmethod def batcher( @@ -670,15 +665,16 @@ def infer_sharding_from_operands( result_infos, ): del out_dtype, result_infos, act_enum - del scale_dtype, scale_shapes, is_dbias, act_len, is_outer + del scale_dtype, scale_shapes, act_len, is_outer x_spec = get_padded_spec(arg_infos[1]) + scale_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" ) if is_2x: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_x_spec = multidim_transpose(x_spec) + colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec else: @@ -687,23 +683,32 @@ def infer_sharding_from_operands( mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" ) - dbias_shaprding = NamedSharding( + dbias_spec = x_spec[-2:] if is_dbias else (None,) + dbias_sharding = NamedSharding( mesh, - PartitionSpec(x_spec[-1]), + PartitionSpec(*dbias_spec), desc="DActLuDBiasQuantizePrimitive.dbias", ) + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = x_spec + + if is_2x: + colwise_scale_inv_spec = scale_inv_spec + scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" ) amax_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax" + mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "DActLuDBiasQuantizePrimitive.colwise_scale_inv" + colwise_scale_inv_sharding = NamedSharding( + mesh, + PartitionSpec(*colwise_scale_inv_spec), + desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv", ) return ( out_sharding, @@ -711,7 +716,7 @@ def infer_sharding_from_operands( scale_inv_sharding, colwise_scale_inv_sharding, amax_sharding, - dbias_shaprding, + dbias_sharding, ) @staticmethod @@ -731,10 +736,15 @@ def partition( ): del result_infos, is_outer x_spec = get_padded_spec(arg_infos[1]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec), desc="out") + scale_spec = get_padded_spec(arg_infos[2]) + + out_sharding = NamedSharding( + mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" + ) + if is_2x: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_x_spec = multidim_transpose(x_spec) + colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec else: @@ -743,38 +753,39 @@ def partition( mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" ) - dbias_shaprding = NamedSharding( + dbias_spec = x_spec[-2:] if is_dbias else (None,) + dbias_sharding = NamedSharding( mesh, - PartitionSpec(x_spec[-1]), + PartitionSpec(*dbias_spec), desc="DActLuDBiasQuantizePrimitive.dbias", ) + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = x_spec + + if is_2x: + colwise_scale_inv_spec = scale_inv_spec + scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv" ) - amax_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax" - ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "DActLuDBiasQuantizePrimitive.colwise_scale_inv" + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax") + colwise_scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv" ) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - arg_shardings = ( - arg_shardings[1], - arg_shardings[1], - *arg_shardings[2:], - ) # dz and x are the same + out_shardings = ( out_sharding, colwise_out_sharding, scale_inv_sharding, colwise_scale_inv_sharding, amax_sharding, - dbias_shaprding, + dbias_sharding, ) def sharded_impl(dz, x, scale): @@ -816,14 +827,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S """ JAX native activation implementation """ - x = jnp.split(inputs, len(activation_type), axis=-1) + act_len = len(activation_type) + assert inputs.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {inputs.shape} and act_len {act_len}" + ) + + x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): x_i = _convert_to_activation_function(act_fn)(x[idx]) acts.append(x_i) x = reduce(operator.mul, acts) + x = jnp.squeeze(x, axis=-2) if quantizer: - return quantizer.quantize(x) + return quantizer.quantize(x, flatten_axis=-1) return x @@ -837,6 +855,12 @@ def _jax_quantize_dact_dbias( """ JAX implementation of dact_lu and dbias with optional quantization """ + act_len = len(activation_type) + assert x.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x.shape} and act_len {act_len}" + ) + _, vjp_func = jax.vjp( partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) ) @@ -844,10 +868,10 @@ def _jax_quantize_dact_dbias( dbias = None if is_dbias: - dbias = _jax_dbias(dx).astype(x.dtype) + dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2) if quantizer is not None: - dx = quantizer.quantize(dx, dq_dtype=x.dtype) + dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2) else: dx = dx.astype(x.dtype) @@ -863,6 +887,7 @@ def act_lu( Args: x: Input tensor to be processed. + Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. @@ -873,12 +898,17 @@ def act_lu( A ScaledTensor containing the quantized activated input. """ act_type_id = ActivationEnum[activation_type].value + act_len = len(activation_type) + assert x.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x.shape} and act_len {act_len}" + ) if not ActLuPrimitive.enabled(): return _jax_act_lu(x, activation_type, quantizer) # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: return _jax_act_lu(x, activation_type, quantizer) # TE/common does not support 2x quantization for DelayedScaling yet @@ -889,16 +919,15 @@ def act_lu( return war_output scale = jnp.empty((1,), jnp.float32) - output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type)) + output_shape = (*x.shape[:-2], x.shape[-1]) if quantizer is None: - x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type))) out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( x, scale, out_dtype=x.dtype, act_enum=act_type_id, - act_len=len(activation_type), + act_len=act_len, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, is_2x=False, scale_dtype=jnp.float32, @@ -911,7 +940,6 @@ def act_lu( if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale - x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type))) ( rowwise_casted_output, colwise_casted_output, @@ -923,25 +951,15 @@ def act_lu( scale, out_dtype=quantizer.q_dtype, act_enum=act_type_id, - act_len=len(activation_type), + act_len=act_len, scaling_mode=quantizer.scaling_mode.value, is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), - scale_shapes=quantizer.get_scale_shapes(output_shape), + # output does not have act axis + scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1), is_outer=True, ) - rowwise_casted_output = rowwise_casted_output.reshape(output_shape) - if len(rowwise_scale_inv.shape) > 1: - rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2) # Remove act axis - if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE): - colwise_output_shape = output_shape - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: - colwise_output_shape = multidim_transpose(output_shape) - colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape) - if len(colwise_scale_inv.shape) > 1: - colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2) # Remove act axis - quantizer.update(updated_amax) return ScaledTensorFactory.create( @@ -951,8 +969,8 @@ def act_lu( colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), ) @@ -968,7 +986,7 @@ def quantize_dact_dbias( Args: dz: Gradient of the output with respect to the activation output. x: Input tensor that was processed by the forward pass. - Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations + Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. @@ -979,21 +997,25 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the bias. """ + act_len = len(activation_type) + assert x.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x.shape} and act_len {act_len}" + ) + if not DActLuDBiasQuantizePrimitive.enabled(): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): - out, _ = quantize_dact_dbias( - dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None - ) - return quantize_dbias(out, is_dbias=True, quantizer=quantizer) + out = dact_lu(dz, x, activation_type, quantizer=None) + return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2) - is_gated = len(activation_type) == 2 + is_gated = act_len == 2 # TE/common does not support DelayedScaling2x for gated-act yet if is_gated: war_output = try_apply_delayed_scaling_2x_war( @@ -1003,6 +1025,7 @@ def quantize_dact_dbias( activation_type=activation_type, is_dbias=is_dbias, quantizer=quantizer, + flatten_axis=-2, ) if war_output is not None: return war_output @@ -1025,12 +1048,12 @@ def quantize_dact_dbias( scale_shapes=((), ()), # unused is_dbias=False, act_enum=act_type_id, - act_len=len(activation_type), + act_len=act_len, is_outer=True, ) dbias = None if is_dbias: - dbias = _jax_dbias(output).astype(x.dtype) + dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) return output.astype(x.dtype), dbias if isinstance(quantizer, DelayedScaleQuantizer): @@ -1041,16 +1064,9 @@ def quantize_dact_dbias( dgated = dact_lu( dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type ) - # TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype) - else: - out, dbias = quantize_dbias( - dgated, - quantizer=quantizer, - is_dbias=True, - dq_dtype=x.dtype, - ) + out, dbias = _quantize_dbias_impl( + dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + ) return out, dbias out_shape = x.shape @@ -1070,10 +1086,11 @@ def quantize_dact_dbias( scaling_mode=quantizer.scaling_mode.value, is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), - scale_shapes=quantizer.get_scale_shapes(out_shape), + # output has act axis + scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2), is_dbias=is_dbias, act_enum=act_type_id, - act_len=len(activation_type), + act_len=act_len, is_outer=True, ) @@ -1090,8 +1107,9 @@ def quantize_dact_dbias( colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), + flatten_axis=-2, # as output has act axis ) return out, dbias diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0fad75817f..736105dd75 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -6,9 +6,9 @@ from typing import Tuple, Sequence, Union, Dict, List from functools import partial, reduce import operator -from transformer_engine_jax import get_device_compute_capability import jax import jax.numpy as jnp +from transformer_engine_jax import get_device_compute_capability from .base import BasePrimitive, register_primitive @@ -183,10 +183,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision): # Reshape + Transpose # [..., M, K] -> [B, M, K] # [..., K, M] -> [B, M, K] - lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N") - rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T") + lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N") + rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T") - # _shape_normalization ensures contracting_dims=2 and batch_dims=0 dim_nums = (((2,), (2,)), ((0,), (0,))) out_3d = jax.lax.dot_general( lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype @@ -203,9 +202,9 @@ def _jax_gemm_delayed_scaling_fp8( ), "rhs does not have delayed tensor scaling mode" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums - if lhs.layout == "T": + if lhs.data_layout == "T": lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract) - if rhs.layout == "T": + if rhs.data_layout == "T": rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract) lhs_dn = (lhs_contract, lhs_batch) @@ -403,19 +402,19 @@ def grouped_gemm( lhs_shape = lhs.data.shape rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype - # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout + # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 ), "FP8 GEMM does not support E5M2 * E5M2" ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - if lhs.layout == "T": + if lhs.data_layout == "T": lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim - if rhs.layout == "T": + if rhs.data_layout == "T": rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) else: - # For jnp.ndarray, only consider contracting_dims, layout is always NN + # For jnp.ndarray, only consider contracting_dims, data_layout is always NN scaling_mode = ScalingMode.NVTE_NO_SCALING lhs_shape = lhs.shape rhs_shape = rhs.shape @@ -432,8 +431,8 @@ def grouped_gemm( lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: - lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N") - rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T") + lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") + rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 980ea556bb..c79eda5568 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -19,7 +19,7 @@ import transformer_engine_jax from ..sharding import get_padded_spec as te_get_padded_spec -from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis +from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout TEDType = transformer_engine_jax.DType @@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim): return axis if axis >= 0 else ndim + axis -def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1): +def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1): """ te_cast_transpose_p multi-dims transpose static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be involved into transpose, -1 means all axes involve into transpose. - transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for - transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary + transpose_axis: int, Indicate how to split multi-dimensions tensors to 2D matrix for + transpose. Note, transpose_axis should be greater than static_axis_boundary examples: X in shape (dim0, dim1, dim2, dim3, dim4) - static_axis_boundary == -1, transpose_axis_boundary == 2 + static_axis_boundary == -1, transpose_axis == 2 Xt = (dim2, dim3, dim4, dim0, dim1) - static_axis_boundary == 0, transpose_axis_boundary == 2 + static_axis_boundary == 0, transpose_axis == 2 Xt = (dim0, dim2, dim3, dim4, dim1) - static_axis_boundary == 0, transpose_axis_boundary == 3 + static_axis_boundary == 0, transpose_axis == 3 Xt = (dim0, dim3, dim4, dim1. dim2) """ if static_axis_boundary < 0: static_axis_boundary = -1 # means no static axes assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. transpose_start_idx = static_axis_boundary + 1 - transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape)) - assert transpose_start_idx < transpose_axis_boundary + transpose_axis = normalize_axis_boundary(transpose_axis, len(shape)) + assert transpose_start_idx < transpose_axis return ( *shape[:transpose_start_idx], - *shape[transpose_axis_boundary:], - *shape[transpose_start_idx:transpose_axis_boundary], + *shape[transpose_axis:], + *shape[transpose_start_idx:transpose_axis], ) @@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant break return ( quantizer is not None - and quantizer.q_axis == QuantizeAxis.ROWWISE + and quantizer.q_layout == QuantizeLayout.ROWWISE and arch_l_100 and is_dbias ) -def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): +def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs): """ Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling. It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result. @@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): # 2x is not supported by TE kernels for delayed scaling # so revert to 1x and transpose in JAX - quantizer.q_axis = QuantizeAxis.ROWWISE + quantizer.q_layout = QuantizeLayout.ROWWISE rowwise = f(*args, **kwargs, quantizer=quantizer) other_outputs = None if isinstance(rowwise, tuple): other_outputs = rowwise[1:] rowwise = rowwise[0] - quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE - colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1))) + quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE + if flatten_axis < 0: + flatten_axis += rowwise.data.ndim + assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds" + colwise_data = jnp.transpose( + rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis)) + ) output_2x = ScaledTensorFactory.create( data=rowwise.data, scale_inv=rowwise.scale_inv, @@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): colwise_scale_inv=rowwise.scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=rowwise.dq_dtype, - q_axis=QuantizeAxis.ROWWISE_COLWISE, - layout=quantizer.get_layout(), + q_layout=QuantizeLayout.ROWWISE_COLWISE, + data_layout=quantizer.get_data_layout(), + flatten_axis=flatten_axis, ) if other_outputs is not None: return (output_2x,) + other_outputs diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 4a342dd4e0..74882c92db 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -30,7 +30,7 @@ from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ( Quantizer, - QuantizeAxis, + QuantizeLayout, DelayedScaleQuantizer, ScalingMode, ) @@ -277,14 +277,14 @@ def impl( rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( x.shape, is_padded=False ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv = scale_inv.flatten()[ - : reduce(operator.mul, rowwise_scale_inv_shape) - ].reshape(rowwise_scale_inv_shape) - if is_2x: - colwise_scale_inv = colwise_scale_inv.flatten()[ - : reduce(operator.mul, colwise_scale_inv_shape) - ].reshape(colwise_scale_inv_shape) + # slice out padding for mxfp8, noop for DelayedScaling + scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( + rowwise_scale_inv_shape + ) + if is_2x: + colwise_scale_inv = colwise_scale_inv.flatten()[ + : reduce(operator.mul, colwise_scale_inv_shape, 1) + ].reshape(colwise_scale_inv_shape) return ( out, colwise_out, @@ -816,7 +816,7 @@ def layernorm_fwd( return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) scale = ( @@ -900,8 +900,8 @@ def layernorm_fwd( colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), ) return scaled_tensor, mu, rsigma @@ -997,7 +997,7 @@ def rmsnorm_fwd( return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) scale = ( @@ -1082,8 +1082,8 @@ def rmsnorm_fwd( colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), ) return scaled_tensor, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 551b4b4bdb..034e149c50 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -2,6 +2,8 @@ # # See LICENSE for license information. """JAX/TE custom ops for quantization""" +import operator +from functools import reduce from typing import Tuple, Optional from packaging import version @@ -24,7 +26,7 @@ ) from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory -from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode +from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax import ffi # pylint: disable=ungrouped-imports @@ -50,7 +52,8 @@ class DBiasQuantizePrimitive(BasePrimitive): 6, 7, 8, - ) # out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer + 9, + ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer inner_primitive = None outer_primitive = None @@ -61,7 +64,8 @@ def abstract( *, out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -73,49 +77,52 @@ def abstract( del scale_shapes dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + out_shape = x_aval.shape assert scale_aval is None or scale_aval.dtype == jnp.float32 - rowwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) - - if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): - rowwise_out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + rowwise_out_shape = out_shape + else: + rowwise_out_shape = (1,) + rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) + ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) + else: + colwise_out_shape = out_shape + else: + colwise_out_shape = (1,) + colwise_scale_inv_shape = (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - - colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) - - dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): - t_shape = multidim_transpose(x_aval.shape) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - # Don't transpose output for MXFP8 - t_shape = x_aval.shape - colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) if is_dbias: - gi_hidden_size = x_aval.shape[-1] - dbias_shape = (gi_hidden_size,) - dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype) + dbias_shape = x_aval.shape[flatten_axis:] + gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1) (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes( x_aval.size // gi_hidden_size, gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), ) - wkspace_aval = x_aval.update( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) + wkspace_shape = wkspace_info[0] + wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) + else: + dbias_shape = (1,) + wkspace_shape = (1,) + wkspace_dtype = jnp.float32 + dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype) + wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype) return ( rowwise_out_aval, @@ -151,7 +158,8 @@ def lowering( *, out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -169,7 +177,8 @@ def lowering( x, scale, scaling_mode=scaling_mode, - q_axis=q_axis, + q_layout=q_layout, + flatten_axis=flatten_axis, is_dbias=is_dbias, ) @@ -179,7 +188,8 @@ def impl( scale, out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -203,7 +213,8 @@ def impl( scale, out_dtype=out_dtype, scaling_mode=scaling_mode, - q_axis=q_axis, + q_layout=q_layout, + flatten_axis=flatten_axis, scale_dtype=scale_dtype, scale_shapes=scale_shapes, is_dbias=is_dbias, @@ -211,16 +222,14 @@ def impl( ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x.shape, is_padded=False) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): - scale_inv = jax.lax.slice( - scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape - ) - if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): - colwise_scale_inv = jax.lax.slice( - colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape - ) + ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis) + scale_inv = jax.lax.slice( + scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape + ) + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + colwise_scale_inv = jax.lax.slice( + colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape + ) return ( out, colwise_out, @@ -237,7 +246,8 @@ def batcher( *, out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -260,7 +270,8 @@ def batcher( scale, out_dtype=out_dtype, scaling_mode=scaling_mode, - q_axis=q_axis, + q_layout=q_layout, + flatten_axis=flatten_axis, scale_dtype=scale_dtype, scale_shapes=scale_shapes, is_dbias=is_dbias, @@ -272,7 +283,8 @@ def batcher( def infer_sharding_from_operands( out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -281,16 +293,17 @@ def infer_sharding_from_operands( arg_infos, result_infos, ): - del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer) # Unused. + del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused. x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding( mesh, - PartitionSpec(*x_spec[:-1], x_spec[-1]), + PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.out_sharding", ) - if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_out_spec = multidim_transpose(x_spec) + colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec else: @@ -300,26 +313,35 @@ def infer_sharding_from_operands( PartitionSpec(*colwise_out_spec), desc="DBiasQuantizePrimitive.colwise_out_sharding", ) - scale_inv_sharding = NamedSharding( + + dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) + dbias_sharding = NamedSharding( mesh, - PartitionSpec(*get_padded_spec(arg_infos[1])), - desc="DBiasQuantizePrimitive.scale_inv", + PartitionSpec(*dbias_spec), + desc="DBiasQuantizePrimitive.dbias_sharding", ) - amax_sharding = scale_inv_sharding.duplicate_with_new_description( - desc="DBiasQuantizePrimitive.amax_sharding" + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = x_spec + + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + colwise_scale_inv_spec = scale_inv_spec + + scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "DBiasQuantizePrimitive.colwise_scale_inv" + amax_sharding = NamedSharding( + mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax" ) - dbias_sharding = NamedSharding( + colwise_scale_inv_sharding = NamedSharding( mesh, - PartitionSpec(x_spec[-1]), - desc="DBiasQuantizePrimitive.dbias_sharding", + PartitionSpec(*colwise_scale_inv_spec), + desc="DBiasQuantizePrimitive.colwise_scale_inv", ) + return ( out_sharding, colwise_out_sharding, @@ -333,7 +355,8 @@ def infer_sharding_from_operands( def partition( out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -344,14 +367,15 @@ def partition( ): del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding( mesh, - PartitionSpec(*x_spec[:-1], x_spec[-1]), + PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.out_sharding", ) - if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_out_spec = multidim_transpose(x_spec) + colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec else: @@ -361,26 +385,35 @@ def partition( PartitionSpec(*colwise_out_spec), desc="DBiasQuantizePrimitive.colwise_out_sharding", ) - scale_inv_sharding = NamedSharding( + + dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) + dbias_sharding = NamedSharding( mesh, - PartitionSpec(*get_padded_spec(arg_infos[1])), - desc="DBiasQuantizePrimitive.scale_inv", + PartitionSpec(*dbias_spec), + desc="DBiasQuantizePrimitive.dbias_sharding", ) - amax_sharding = scale_inv_sharding.duplicate_with_new_description( - desc="DBiasQuantizePrimitive.amax_sharding" + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = x_spec + + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + colwise_scale_inv_spec = scale_inv_spec + + scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "DBiasQuantizePrimitive.colwise_scale_inv" + amax_sharding = NamedSharding( + mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax" ) - dbias_sharding = NamedSharding( + colwise_scale_inv_sharding = NamedSharding( mesh, - PartitionSpec(x_spec[-1]), - desc="DBiasQuantizePrimitive.dbias_sharding", + PartitionSpec(*colwise_scale_inv_spec), + desc="DBiasQuantizePrimitive.colwise_scale_inv", ) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, @@ -404,7 +437,8 @@ def sharded_impl(x, scale): scale, out_dtype=out_dtype, scaling_mode=scaling_mode, - q_axis=q_axis, + q_layout=q_layout, + flatten_axis=flatten_axis, scale_dtype=scale_dtype, scale_shapes=scale_shapes, is_dbias=is_dbias, @@ -436,49 +470,45 @@ def sharded_impl(x, scale): register_primitive(DBiasQuantizePrimitive) -def _jax_quantize(x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None): +def _jax_quantize( + x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 +): if quantizer is None: return x - return quantizer.quantize(x, dq_dtype=dq_dtype) + return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) -def _jax_dbias(dx: jnp.ndarray): +def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): + assert flatten_axis < 0 + dtype = dtype or dx.dtype dbias = jnp.sum( - dx, - axis=tuple(range(dx.ndim - 1)), + dx.astype(jnp.float32), + axis=tuple(range(dx.ndim + flatten_axis)), keepdims=False, ) - dbias = dbias.ravel() # C++ function returns an 1D array for dbias - return dbias + return dbias.astype(dtype) def _jax_quantize_dbias( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, + flatten_axis: int = -1, ): if quantizer is None: return x, None - return quantizer.quantize(x, dq_dtype=dq_dtype), _jax_dbias(x) - - -def _jax_dbias( - dx: jnp.ndarray, -): - dbias = jnp.sum( - dx.astype(jnp.float32), - axis=tuple(range(dx.ndim - 1)), - keepdims=False, + return ( + quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis), + _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis), ) - dbias = dbias.ravel() # C++ function returns an 1D array for dbias - return dbias.astype(dx.dtype) -def _quantize_impl( +def _quantize_dbias_impl( x: jnp.ndarray, quantizer: Quantizer, is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, + flatten_axis: int = -1, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -488,40 +518,51 @@ def _quantize_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + dq_dtype = dq_dtype or x.dtype + if not DBiasQuantizePrimitive.enabled(): if is_dbias: return _jax_quantize_dbias( x, quantizer=quantizer, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) - return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None + return ( + _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), + None, + ) # TE/common doesn't support colwise only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: if is_dbias: return _jax_quantize_dbias( x, quantizer=quantizer, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) - return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None + return ( + _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), + None, + ) scale = jnp.empty((), jnp.float32) # TE/common dbias_quantize does not support 1x on arch < 100 if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): - out, _ = _quantize_impl( + out, _ = _quantize_dbias_impl( x=x, is_dbias=False, quantizer=quantizer, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) - dbias = _jax_dbias(x) + dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias if quantizer is None: if is_dbias: - return x, _jax_dbias(x) + return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return x, None if isinstance(quantizer, DelayedScaleQuantizer): @@ -539,9 +580,10 @@ def _quantize_impl( scale, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - q_axis=quantizer.q_axis.value, + q_layout=quantizer.q_layout.value, + flatten_axis=flatten_axis, scale_dtype=quantizer.get_scale_dtype(), - scale_shapes=quantizer.get_scale_shapes(x.shape), + scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis), is_dbias=is_dbias, is_outer=True, ) @@ -557,18 +599,18 @@ def _quantize_impl( colwise_data=colwise_casted_output, colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, - dq_dtype=dq_dtype if dq_dtype is not None else x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + dq_dtype=dq_dtype, + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), + flatten_axis=flatten_axis, ) - return out, dbias + return out, dbias.astype(dq_dtype) -# TODO(Phuong): do not expose dq_dtype to users def quantize( x: jnp.ndarray, quantizer: Quantizer, - dq_dtype: Optional[jnp.dtype] = None, + flatten_axis: int = -1, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -576,26 +618,25 @@ def quantize( x: Input tensor to be quantized. Shape: (..., K) where K is the hidden size. quantizer: Quantizer for FP8 quantization of the output. - dq_dtype: Optional dtype for dequantization. - If None, uses the same dtype as the input tensor. + flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. + Defaults to -1. Returns: A ScaledTensor containing the quantized input tensor. """ - out, _ = _quantize_impl( + out, _ = _quantize_dbias_impl( x, quantizer=quantizer, - dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) return out -# TODO(Phuong): do not expose dq_dtype to users def quantize_dbias( dz: jnp.ndarray, quantizer: Quantizer, is_dbias: bool = True, - dq_dtype: Optional[jnp.dtype] = None, + flatten_axis: int = -1, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -604,8 +645,8 @@ def quantize_dbias( Shape: (..., K) where K is the hidden size. quantizer: Quantizer for FP8 quantization of the output. is_dbias: If True, compute bias gradient. Defaults to True. - dq_dtype: Optional dtype for dequantization. - If None, uses the same dtype as the input tensor. + flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. + Defaults to -1. Returns: A tuple containing: @@ -614,9 +655,6 @@ def quantize_dbias( - The bias gradient tensor. Shape: (K,) or empty if is_dbias is False. """ - return _quantize_impl( - dz, - quantizer=quantizer, - is_dbias=is_dbias, - dq_dtype=dq_dtype, + return _quantize_dbias_impl( + dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis ) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 861db97a26..e71597e4b3 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -11,14 +11,6 @@ #include "transformer_engine/cast.h" #include "xla/ffi/api/c_api.h" -namespace { -bool is_gated(NVTE_Activation_Type act_type) { - return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU || - act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU || - act_type == NVTE_Activation_Type::SREGLU; -} -} // namespace - namespace transformer_engine { namespace jax { @@ -44,38 +36,56 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto act_len = input_dims[input_dims.size() - 2]; auto scaling_mode = static_cast(scaling_mode_enum); auto is_2x = static_cast(is_2x_int); + auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis auto input_shape = std::vector{m, act_len * n}; auto output_shape = std::vector{m, n}; + auto output_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); auto output_tensor = TensorWrapper(scaling_mode); output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); if (is_fp8_dtype(out_dtype)) { - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{ - product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), - scale_inv_buf->dimensions().back()}); - } - - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { - NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); - cudaMemsetAsync(amax, 0, sizeof(float), stream); - output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); + NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + cudaMemsetAsync(amax, 0, sizeof(float), stream); + output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); + } else { + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), + product(scale_inv_buf->dimensions(), flatten_axis, + scale_inv_buf->dimensions().size())}); + } } if (is_2x) { - output_tensor.set_columnwise_data(colwise_output, static_cast(out_dtype), output_shape); - output_tensor.set_columnwise_scale_inv( - colwise_scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), - std::vector{product(colwise_scale_inv_buf->dimensions(), 0, - colwise_scale_inv_buf->dimensions().size() - 1), - colwise_scale_inv_buf->dimensions().back()}); + auto &tmp_shape = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); + + if (is_fp8_dtype(out_dtype)) { + // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling + auto &tmp_buf = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{1}); + } else { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{ + product(tmp_buf->dimensions(), 0, flatten_axis), + product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); + } + } } switch (act_type) { @@ -162,8 +172,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } if (is_2x) { - output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, - output_trans_shape); + auto &tmp_shape = scaling_mode == static_cast(NVTE_DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter if (is_fp8_dtype(out_dtype)) { @@ -190,9 +202,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type output_trans_buf, - Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, - Result_Type amax_out_buf, Result_Type dbias_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, bool is_dbias, int64_t act_enum) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); @@ -201,11 +213,15 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto *input = input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data(); + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *amax = reinterpret_cast(amax_buf->untyped_data()); auto scaling_mode = static_cast(scaling_mode_enum); + auto act_type = static_cast(act_enum); + auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis auto *output = output_buf->untyped_data(); - auto *output_trans = output_trans_buf->untyped_data(); + auto *colwise_output = colwise_output_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data(); void *workspace = workspace_buf->untyped_data(); @@ -213,17 +229,18 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto act_input_dims = act_input_buf.dimensions(); auto workspace_dims = workspace_buf->dimensions(); // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims - // n = ir_dz_shape[-1], ir_dz_shape == input_dims - auto input_ranks = input_dims.size(); - auto act_input_ranks = act_input_dims.size(); - auto m = product(act_input_dims, 0, act_input_dims.size() - 1); - // 'n' will be 2x the size of input_dims.back() if the dactivation is dgated - auto n = act_input_dims.back(); - auto input_shape = std::vector{m, input_dims.back()}; - auto act_input_shape = std::vector{m, n}; - auto output_shape = std::vector{m, n}; - auto output_trans_shape = std::vector{m, n}; - auto dbias_shape = std::vector{n}; + // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims + auto act_len = act_input_dims[act_input_dims.size() - 2]; + NVTE_CHECK(act_input_dims.back() == input_dims.back(), + "Shape mismatch between activation input and gradient input"); + auto m = product(act_input_dims, 0, act_input_dims.size() - 2); + auto n = input_dims.back(); + + auto input_shape = std::vector{m, n}; + auto act_input_shape = std::vector{m, n * act_len}; + auto output_shape = std::vector{m, n * act_len}; + auto output_trans_shape = std::vector{n * act_len, m}; + auto dbias_shape = std::vector{n * act_len}; std::vector workspace_shape(workspace_dims.begin(), workspace_dims.end()); auto input_tensor = TensorWrapper(input, input_shape, in_dtype); @@ -231,50 +248,56 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto output_tensor = TensorWrapper(scaling_mode); output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{ - product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), - scale_inv_buf->dimensions().back()}); - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); - cudaMemsetAsync(amax_out, 0, sizeof(float), stream); + NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + cudaMemsetAsync(amax, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax_out, DType::kFloat32, std::vector{1}); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); + } else { + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), + product(scale_inv_buf->dimensions(), flatten_axis, + scale_inv_buf->dimensions().size())}); } } if (is_2x) { - output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + auto &tmp_shape = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &colwise_scale_inv_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; - output_tensor.set_columnwise_scale_inv( - colwise_scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), - std::vector{product(colwise_scale_inv_buf->dimensions(), 0, - colwise_scale_inv_buf->dimensions().size() - 1), - colwise_scale_inv_buf->dimensions().back()}); + auto &tmp_buf = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{1}); + } else { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{ + product(tmp_buf->dimensions(), 0, flatten_axis), + product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); + } } } auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); - auto act_type = static_cast(act_enum); - // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead - NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && - is_gated(act_type)), - "TE/common does not support delayed scaling for 2x with gated activations."); + NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); + NVTE_CHECK( + !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { switch (act_type) { diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 74909319cc..c1e008a5bc 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -44,12 +44,12 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh cudaStreamSynchronize(stream); // Notes on matrix layouts and transpose: - // Jax uses row-major layout, on entering this function, each input matrix pair: + // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major with size [m, k], // B: row-major with size [n, k], needs transpose, // on exiting this function, JAX expect: // C: row-major with size [m, n]. - // cuBLAS uses column-major layout, in this view, each input matrix pair: + // cuBLAS uses column-major data_layout, in this view, each input matrix pair: // A: column-major with size [k, m], needs transpose, // B: column-major with size [k, n]. // If we call cuBLAS GEMM for A * B, the output will be: diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 09ccf6be86..c8526e20c0 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -34,7 +34,7 @@ inline size_t product(const std::vector &shape) { return ret; } -enum class QuantizeAxis { +enum class QuantizeLayout { ROWWISE, COLWISE, ROWWISE_COLWISE, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index c777a02c99..ebdfe461c7 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -144,11 +144,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) .export_values(); - pybind11::enum_(m, "QuantizeAxis", - pybind11::module_local()) - .value("ROWWISE", transformer_engine::jax::QuantizeAxis::ROWWISE) - .value("COLWISE", transformer_engine::jax::QuantizeAxis::COLWISE) - .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeAxis::ROWWISE_COLWISE) + pybind11::enum_(m, "QuantizeLayout", + pybind11::module_local()) + .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE) + .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) + .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) .export_values(); } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index c8f98dd43f..b48ee8a9b9 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -42,10 +42,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type output_trans_buf, - Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, - Result_Type amax_out_buf, Result_Type dbias_buf, - Result_Type workspace_buf, int64_t scaling_mode_enum, - int64_t quantize_axis_enum, bool is_dbias) { + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, + int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias, + int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -55,7 +55,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input = input_buf.untyped_data(); auto scaling_mode = static_cast(scaling_mode_enum); - auto const quantize_axis = static_cast(quantize_axis_enum); + auto const quantize_layout = static_cast(quantize_layout_enum); auto *output = output_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data(); @@ -63,9 +63,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T void *workspace = workspace_buf->untyped_data(); auto input_dims = input_buf.dimensions(); + int64_t input_ndim = input_dims.size(); + if (flatten_axis < 0) flatten_axis += input_ndim; + NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!"); + auto workspace_dims = workspace_buf->dimensions(); - auto m = product(input_dims, 0, input_dims.size() - 1); - auto n = input_dims.back(); + auto m = product(input_dims, 0, flatten_axis); + auto n = product(input_dims, flatten_axis, input_ndim); auto input_shape = std::vector{m, n}; auto output_shape = std::vector{m, n}; auto output_trans_shape = std::vector{n, m}; @@ -75,37 +79,54 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto output_tensor = TensorWrapper(scaling_mode); - if (quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { + if (quantize_layout == QuantizeLayout::ROWWISE || + quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{ - product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), - scale_inv_buf->dimensions().back()}); - } - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); - NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); - output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - cudaMemsetAsync(amax_out, 0, sizeof(float), stream); - output_tensor.set_amax(amax_out, DType::kFloat32, std::vector{1}); + if (is_fp8_dtype(out_dtype)) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *amax = reinterpret_cast(amax_buf->untyped_data()); + NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); + NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); + cudaMemsetAsync(amax, 0, sizeof(float), stream); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{1}); + } else { + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), + product(scale_inv_buf->dimensions(), flatten_axis, + scale_inv_buf->dimensions().size())}); + } + } } - if (quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { - output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + if (quantize_layout == QuantizeLayout::COLWISE || + quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + auto &tmp_shape = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &colwise_scale_inv_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; - output_tensor.set_columnwise_scale_inv( - colwise_scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), - std::vector{product(colwise_scale_inv_buf->dimensions(), 0, - colwise_scale_inv_buf->dimensions().size() - 1), - colwise_scale_inv_buf->dimensions().back()}); + auto &tmp_buf = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{1}); + } else { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{ + product(tmp_buf->dimensions(), 0, flatten_axis), + product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); + } } auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); @@ -133,8 +154,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // dbias .Ret() // wkspace .Attr("scaling_mode") - .Attr("q_axis") - .Attr("is_dbias"), + .Attr("q_layout") + .Attr("is_dbias") + .Attr("flatten_axis"), FFI_CudaGraph_Traits); Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 43336768cb..2ef8b91c86 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -15,7 +15,11 @@ import jax.numpy as jnp from . import cpp_extensions as tex -from .quantize import QuantizerSet, noop_quantizer_set +from .quantize import ( + QuantizerSet, + noop_quantizer_set, + with_sharding_constraint_by_logical_axes, +) def dense( @@ -23,6 +27,8 @@ def dense( kernel: jnp.ndarray, bias: jnp.ndarray = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), + input_axes: Tuple[str, ...] = None, + kernel_axes: Tuple[str, ...] = None, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -48,12 +54,12 @@ def dense( bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) else: - output = _dense(x, kernel, bias, contracting_dims, quantizer_set) + output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set) return output -@partial(jax.custom_vjp, nondiff_argnums=(3,)) -def _dense(x, kernel, bias, contracting_dims, quantizer_set): +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) +def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set): kernel: Weight matrix bias: Optional bias tensor contracting_dims: Contracting dimensions specification + input_axes: Logical axes for sharding the activation input + kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ - output, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set) + output, _ = _dense_fwd_rule( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set + ) return output -def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): +def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): """Forward pass rule for dense layer transformation. - Args: - x: Input tensor - kernel: Weight matrix - bias: Optional bias tensor - contracting_dims: Contracting dimensions specification - quantizer_set: QuantizerSet which contains quantizers for different tensor types - Returns: Tuple of (output, context) for backward pass """ x_contracting_dims, k_contracting_dims = contracting_dims - casted_x = tex.quantize(x, quantizer_set.x) - casted_kernel = tex.quantize(kernel, quantizer_set.kernel) + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + + casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x) + casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) + + casted_kernel = tex.quantize( + kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel + ) + casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # GEMM NN output = tex.gemm( @@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): casted_kernel.get_colwise_tensor(), (x_contracting_dims, k_contracting_dims), ) + use_bias = bias is not None if use_bias: bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape @@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): kernel.shape, use_bias, quantizer_set, + flatten_axis_k, ) return output, ctx -def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argument +def _dense_bwd_rule( + contracting_dims, input_axes, kernel_axes, ctx, grad +): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. - Args: - contracting_dims: Contracting dimensions specification - ctx: Context from forward pass - grad: Gradient from upstream - Returns: Tuple of gradients with respect to inputs """ @@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu kernel_shape, use_bias, quantizer_set, + flatten_axis_k, ) = ctx - casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) + casted_grad, dbias = tex.quantize_dbias( + grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad + ) # GEMM NT # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim @@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu rowwise_casted_kernel, (g_constracting_dim, k_constracting_dim), ) + dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims @@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu wgrad = tex.gemm( colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim) ) + wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) return dgrad, wgrad, dbias, quantizer_set diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a0d1e33e38..a944848881 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -28,6 +28,7 @@ from ..sharding import with_sharding_constraint_by_logical_axes from ..cpp_extensions import is_softmax_kernel_available from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode +from ..sharding import get_non_contracting_logical_axes PRNGKey = Any Shape = Tuple[int, ...] @@ -406,6 +407,10 @@ class DenseGeneral(TransformerEngineBase): :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. + input_axes: Tuple[str, ...], default = None + Indicate the logical axes of sharding constraint to the input, like + (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert + sharding constraint. Optimization parameters ----------------------- @@ -429,6 +434,7 @@ class DenseGeneral(TransformerEngineBase): axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = False + input_axes: Tuple[str, ...] = () def __post_init__(self): if self.kernel_init is None: @@ -460,29 +466,35 @@ def __call__(self, inputs: Array) -> Array: axis = _normalize_axes(axis, inputs.ndim) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + + if self.kernel_axes: + assert len(kernel_shape) == len(self.kernel_axes), ( + "Expected len(kernel_shape) to match len(kernel_axes)," + f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}" + ) kernel = nn_partitioning.param_with_axes( "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) + if not QuantizeConfig.is_fp8_enabled(): kernel = kernel.astype(input_dtype) - kernel_compute_shape = ( - reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1), - reduce(operator.mul, features, 1), - ) - kernel = jnp.reshape(kernel, kernel_compute_shape) if self.use_bias: bias = nn_partitioning.param_with_axes( "bias", self.bias_init, features, self.dtype, axes=self.bias_axes - ) - bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype) + ).astype(input_dtype) else: bias = None quantizer_set = self.generate_quantizer_set() contract_ind = tuple(range(0, len(axis))) y = dense( - inputs, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set + inputs, + kernel, + contracting_dims=(axis, contract_ind), + input_axes=self.input_axes, + kernel_axes=self.kernel_axes, + quantizer_set=quantizer_set, ) if self.enable_low_rank_adaptation: @@ -491,20 +503,14 @@ def __call__(self, inputs: Array) -> Array: *features[:-1], self.low_rank_adaptation_dim, ) - lora_a_kernel_init_shape = ( - kernel_compute_shape[0], - *features[:-1], - self.low_rank_adaptation_dim, - ) - lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) + lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape) lora_a_kernel = nn_partitioning.param_with_axes( "lora_a_kernel", self.kernel_init, - lora_a_kernel_init_shape, + lora_a_kernel_shape, self.dtype, axes=lora_a_kernel_axes, ) - lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) @@ -527,7 +533,6 @@ def __call__(self, inputs: Array) -> Array: y += jnp.reshape(bias, bias_shape) assert y.dtype == input_dtype - y = y.reshape(*inputs.shape[: self.axis], *features) return y @@ -678,6 +683,7 @@ def __call__(self, inputs: Array) -> Array: The output tensors of layer normalization. If :attr:`return_layernorm_output=False`, then this would be None. """ + assert self.axis == -1, "Only support axis = =-1 at this moment" input_dtype = inputs.dtype ln_output = None @@ -692,10 +698,7 @@ def __call__(self, inputs: Array) -> Array: if self.enable_layernorm: inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) - - assert self.axis == -1 # Only support axis = =-1 at this moment features = inputs.shape[-1] - scale, ln_bias = _create_layernorm_parameters( self.layernorm_type, (features,), @@ -731,17 +734,12 @@ def __call__(self, inputs: Array) -> Array: axis = _normalize_axes(axis, y.ndim) - kernel_shape = tuple(y.shape[ax] for ax in axis) + features + kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) if not QuantizeConfig.is_fp8_enabled(): kernel = kernel.astype(input_dtype) - kernel_compute_shape = ( - reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1), - reduce(operator.mul, features, 1), - ) - kernel = jnp.reshape(kernel, kernel_compute_shape) contract_ind = tuple(range(0, len(axis))) @@ -756,11 +754,19 @@ def __call__(self, inputs: Array) -> Array: epsilon=self.epsilon, layernorm_input_axes=self.layernorm_input_axes, dot_input_axes=self.dot_input_axes, + kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) - z = dense(y, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set) + z = dense( + y, + kernel, + contracting_dims=(axis, contract_ind), + input_axes=self.dot_input_axes, + kernel_axes=self.kernel_axes, + quantizer_set=quantizer_set, + ) if self.enable_low_rank_adaptation: lora_a_kernel_shape = ( @@ -768,20 +774,14 @@ def __call__(self, inputs: Array) -> Array: *features[:-1], self.low_rank_adaptation_dim, ) - lora_a_kernel_init_shape = ( - kernel_compute_shape[0], - *features[:-1], - self.low_rank_adaptation_dim, - ) - lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) + lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape) lora_a_kernel = nn_partitioning.param_with_axes( "lora_a_kernel", self.kernel_init, - lora_a_kernel_init_shape, + lora_a_kernel_shape, self.dtype, axes=lora_a_kernel_axes, ) - lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) @@ -803,8 +803,7 @@ def __call__(self, inputs: Array) -> Array: if self.use_bias: bias = nn_partitioning.param_with_axes( "bias", self.bias_init, features, self.dtype, axes=self.bias_axes - ) - bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype) + ).astype(input_dtype) if bias is not None: bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape @@ -814,7 +813,7 @@ def __call__(self, inputs: Array) -> Array: z = z / self.depth_scaling assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}" - z = z.reshape(*inputs.shape[: self.axis], *features) + # z = z.reshape(*inputs.shape[: self.axis], *features) return z, ln_output # dense_output, layer_norm_output @@ -989,6 +988,8 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: The output tensors of layer normalization. If :attr:`return_layernorm_output=False`, then this would be None. """ + assert self.axis == -1, "Only support axis == -1 at this moment" + ffn1_quantizer_set = self.generate_quantizer_set("_0") ffn2_quantizer_set = self.generate_quantizer_set("_1") @@ -1027,7 +1028,6 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ) # LayerNorm if self.enable_layernorm: - assert self.axis == -1 # Only support axis == -1 at this moment inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) features = inputs.shape[-1] @@ -1071,7 +1071,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations = len(normalized_acts) axis = _canonicalize_tuple(self.axis) axis = _normalize_axes(axis, y.ndim) - kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) + kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim) kernel_1 = nn_partitioning.param_with_axes( "wi_kernel", kernel_1_init, @@ -1081,17 +1081,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): self.dtype, axes=self.kernel_axes_1, ) - kernel_1_compute_shape = ( - reduce(operator.mul, [y.shape[ax] for ax in axis], 1), - num_activations * self.intermediate_dim, - ) - kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape) + if not QuantizeConfig.is_fp8_enabled(): kernel_1 = kernel_1.astype(input_dtype) - if self.kernel_axes_1 is not None: - kernel_1 = with_sharding_constraint_by_logical_axes( - kernel_1, self.kernel_axes_1[:-2] + self.kernel_axes_1[-1:] - ) + hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple @@ -1102,27 +1095,20 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): self.dtype, axes=self.kernel_axes_2, ) - kernel_2_compute_shape = ( - self.intermediate_dim, - reduce(operator.mul, hidden_size_tuple, 1), - ) - kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape) if not QuantizeConfig.is_fp8_enabled(): kernel_2 = kernel_2.astype(input_dtype) - if self.kernel_axes_2 is not None: - kernel_2 = with_sharding_constraint_by_logical_axes(kernel_2, self.kernel_axes_2) + contract_ind = tuple(range(0, len(axis))) if self.use_bias: - bias_1_shape = num_activations * self.intermediate_dim + bias_1_shape = (num_activations, self.intermediate_dim) bias_1 = nn_partitioning.param_with_axes( "wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1, - ) - bias_1 = bias_1.reshape(kernel_1_compute_shape[-1]).astype(input_dtype) + ).astype(input_dtype) bias_2_shape = (hidden_size,) bias_2 = nn_partitioning.param_with_axes( @@ -1131,8 +1117,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): bias_2_shape, self.dtype, axes=self.bias_axes_2, - ) - bias_2 = bias_2.reshape(kernel_2_compute_shape[-1]).astype(input_dtype) + ).astype(input_dtype) else: bias_1 = None bias_2 = None @@ -1141,8 +1126,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn2_ckpt_name = "ffn2" if use_fused_layernorm_mlp: - assert self.axis == -1 # Only support axis = =-1 at this moment - out = layernorm_mlp( y, scale, @@ -1155,6 +1138,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): norm_input_axes=self.layernorm_input_axes, dot_1_input_axes=self.dot_1_input_axes, dot_2_input_axes=self.dot_2_input_axes, + kernel_1_axes=self.kernel_axes_1, + kernel_2_axes=self.kernel_axes_2, ffn1_ckpt_name=ffn1_ckpt_name, ffn2_ckpt_name=ffn2_ckpt_name, activation_type=normalized_acts, @@ -1175,6 +1160,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): epsilon=self.epsilon, layernorm_input_axes=self.layernorm_input_axes, dot_input_axes=self.dot_1_input_axes, + kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, ) else: @@ -1183,35 +1169,31 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): y, kernel_1, contracting_dims=(axis, contract_ind), + input_axes=self.dot_1_input_axes, + kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, ) + dot_1_output_axes = ( + *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis), + *get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind), + ) + x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes) if self.enable_low_rank_adaptation: - wi_lora_a_kernel_shape = ( - kernel_1_compute_shape[0], - num_activations, - self.low_rank_adaptation_dim, - ) - wi_lora_a_kernel_init_shape = ( - kernel_1_each_shape[0], - num_activations, - self.low_rank_adaptation_dim, - ) - wi_lora_a_kernel_init_each_shape = ( - kernel_1_each_shape[0], + wi_lora_a_kernel_each_shape = ( + kernel_1_each_shape[: len(axis)], self.low_rank_adaptation_dim, ) - wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape) + wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_each_shape + 1) wi_lora_a_kernel = nn_partitioning.param_with_axes( "wi_lora_a_kernel", kernel_1_init, num_activations, - -1, - wi_lora_a_kernel_init_each_shape, + -2, + wi_lora_a_kernel_each_shape, self.dtype, axes=wi_lora_a_kernel_axes, ) - wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype) wi_lora_b_kernel_shape = ( @@ -1232,7 +1214,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): x += _apply_low_rank_adaptation( y, axis, - num_activations * self.intermediate_dim, + (num_activations, self.intermediate_dim), wi_lora_a_kernel, wi_lora_b_kernel, self.low_rank_adaptation_alpha, @@ -1246,11 +1228,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): z = activation(x, normalized_acts) else: activations = [] - x = jnp.split(x, num_activations, axis=-1) + x = jnp.split(x, num_activations, axis=-2) for idx, act_fn in enumerate(normalized_acts): x_i = _convert_to_activation_function(act_fn)(x[idx]) activations.append(x_i) z = reduce(operator.mul, activations) + z = jnp.squeeze(z, axis=-2) z = z.astype(input_dtype) z = nn.Dropout( @@ -1264,7 +1247,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): # DenseGeneral 2 out = dense( - z, kernel_2, contracting_dims=(axis, contract_ind), quantizer_set=ffn2_quantizer_set + z, + kernel_2, + contracting_dims=(axis, contract_ind), + input_axes=self.dot_2_input_axes, + kernel_axes=self.kernel_axes_2, + quantizer_set=ffn2_quantizer_set, ) if self.enable_low_rank_adaptation: diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 3fe32401bd..727ff78c2d 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -33,10 +33,9 @@ def layernorm_dense( norm_type: str = "layernorm", zero_centered_gamma: bool = False, epsilon: float = 1e-6, - # The logic axes of sharding constraint to the layernorm input. layernorm_input_axes: Tuple[str, ...] = None, - # The logic axes of sharding constraint to the dot input. dot_input_axes: Tuple[str, ...] = None, + kernel_axes: Tuple[str, ...] = None, quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -56,6 +55,7 @@ def layernorm_dense( epsilon: Small constant for numerical stability in normalization layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input + kernel_axes: Logical axes for sharding the weight matrix quantizer_set: Set of quantizers for different tensor types Returns: @@ -78,6 +78,7 @@ def layernorm_dense( epsilon, layernorm_input_axes, dot_input_axes, + kernel_axes, quantizer_set, ) return output @@ -91,6 +92,7 @@ def layernorm_dense( 7, 8, 9, + 10, ), ) def _layernorm_dense( @@ -104,6 +106,7 @@ def _layernorm_dense( epsilon: float, layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], + kernel_axes: Tuple[str, ...], quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -139,6 +142,7 @@ def _layernorm_dense( epsilon, layernorm_input_axes, dot_input_axes, + kernel_axes, quantizer_set, ) return output @@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule( epsilon, layernorm_input_axes, dot_input_axes, + kernel_axes, quantizer_set, ): """Forward pass rule for layernorm_dense. @@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule( x_contracting_dims = (len(x.shape) - 1,) k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] - assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) @@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule( norm_type, quantizer_set.x, ) + casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) - casted_kernel = tex.quantize(kernel, quantizer_set.kernel) - - casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) + flatten_axis = 1 - len(kernel.shape) + casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel) + casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out...) @@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule( k_contracting_dims, use_bias, quantizer_set, + flatten_axis, ) return output, ctx @@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule( epsilon, layernorm_input_axes, dot_input_axes, # pylint: disable=unused-argument + kernel_axes, ctx, grad, ): @@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule( k_contracting_dims_in_fwd, use_bias, quantizer_set, + flatten_axis, ) = ctx - grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes) - - casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) + casted_grad, dbias = tex.quantize_dbias( + grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad + ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim g_constracting_dim = tuple( @@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule( (x_constracting_dim, g_constracting_dim), ) + wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) + dx, dgamma, dbeta = tex.normalization_bwd( dgrad, x, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index f6caad62e3..e7e3fd2fb9 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -23,6 +23,7 @@ from . import cpp_extensions as tex from .layernorm import canonicalize_norm_type from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set +from .sharding import get_non_contracting_logical_axes def layernorm_mlp( @@ -37,6 +38,8 @@ def layernorm_mlp( norm_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None, + kernel_1_axes: Tuple[str, ...] = None, + kernel_2_axes: Tuple[str, ...] = None, ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), @@ -66,6 +69,8 @@ def layernorm_mlp( norm_input_axes: Logical axes for sharding the layernorm input dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication + kernel_1_axes: Logical axes for sharding the first weight matrix + kernel_2_axes: Logical axes for sharding the second weight matrix ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation @@ -109,6 +114,8 @@ def layernorm_mlp( norm_input_axes, dot_1_input_axes, dot_2_input_axes, + kernel_1_axes, + kernel_2_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type, @@ -117,7 +124,7 @@ def layernorm_mlp( return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -132,6 +139,8 @@ def _layernorm_mlp( norm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], + kernel_1_axes: Tuple[str, ...], + kernel_2_axes: Tuple[str, ...], ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], @@ -179,6 +188,8 @@ def _layernorm_mlp( norm_input_axes, dot_1_input_axes, dot_2_input_axes, + kernel_1_axes, + kernel_2_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type, @@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule( norm_input_axes, dot_1_input_axes, dot_2_input_axes, + kernel_1_axes, + kernel_2_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type, @@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ + del kernel_2_axes + ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets # x should be in shape of (batch..., hidden) - # Kernel_1 should be in shape of (hidden_in, activation_len * intermediate) + # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) # Kernel_2 should be in shape of (intermediate, hidden_in) - assert len(kernel_1.shape) == 2 + assert len(kernel_1.shape) == 3 assert len(kernel_2.shape) == 2 - assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type) + assert kernel_1.shape[-2] == len(activation_type) x_contracting_dims = (len(x.shape) - 1,) k_contracting_dims = (0,) assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] - assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0] use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None @@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule( norm_type, quantizer=ffn1_quantizer_set.x, ) - - casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel) - casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) + casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel) + # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out) dot_1_output = tex.gemm( @@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule( casted_kernel_1.get_colwise_tensor(), (x_contracting_dims, k_contracting_dims), ) + + dot_1_output_axes = ( + *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), + *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), + ) + dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) + if use_bias_1: bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape @@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule( (x_contracting_dims, k_contracting_dims), ) + dot_2_output_axes = ( + *get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims), + *get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims), + ) + dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes) + if use_bias_2: bias_2_shape = bias_2.shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape @@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule( norm_input_axes, dot_1_input_axes, dot_2_input_axes, - ffn1_ckpt_name, # pylint: disable=unused-argument - ffn2_ckpt_name, # pylint: disable=unused-argument + kernel_1_axes, + kernel_2_axes, + ffn1_ckpt_name, + ffn2_ckpt_name, activation_type, ctx, grad, @@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name ( x, mu, @@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule( ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - g_constracting_dim_2 = tuple( + g_contracting_dims_2 = tuple( range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim) ) # k_non_contracting_dims - k_constracting_dim_2 = tuple( + k_contracting_dims_2 = tuple( dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd ) @@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule( dgrad_2 = tex.gemm( casted_grad.get_rowwise_tensor(), rowwise_casted_kernel_2, - (g_constracting_dim_2, k_constracting_dim_2), + (g_contracting_dims_2, k_contracting_dims_2), ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - x_constracting_dim = g_constracting_dim = tuple( + x_contracting_dims = g_contracting_dims = tuple( range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) ) @@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule( wgrad_2 = tex.gemm( colwise_casted_act_out, casted_grad.get_colwise_tensor(), - (x_constracting_dim, g_constracting_dim), + (x_contracting_dims, g_contracting_dims), ) + wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) casted_dact_out, dbias_1 = tex.quantize_dact_dbias( dgrad_2, @@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule( ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - g_constracting_dim_1 = tuple( - range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim) + dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim + g_contracting_dims_1 = tuple( + range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) ) # k_non_contracting_dims - k_constracting_dim_1 = tuple( + k_contracting_dims_1 = tuple( dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd ) @@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule( dgrad_1 = tex.gemm( casted_dact_out.get_rowwise_tensor(), rowwise_casted_kernel_1, - (g_constracting_dim_1, k_constracting_dim_1), + (g_contracting_dims_1, k_contracting_dims_1), ) - dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes) + dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) # TN GEMM # (hidden, batch...) x (hidden, batch...) wgrad_1 = tex.gemm( colwise_casted_ln_out, casted_dact_out.get_colwise_tensor(), - (x_constracting_dim, g_constracting_dim), + (x_contracting_dims, g_contracting_dims), ) + wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) + dx, dgamma, dbeta = tex.normalization_bwd( dgrad_1, x, diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index cdbe764ab2..b1e9ba03b4 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -57,18 +57,27 @@ def _dq_func_block_scaling(scaled_tensor): data = scaled_tensor.data.astype(jnp.float32) data_shape = data.shape scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32) + flatten_axis = scaled_tensor.flatten_axis + flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis + assert ( + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" scale_shape = scaled_tensor.scaling_mode.get_scale_shape( - scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False + data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis ) scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding + data = data.reshape( - *data_shape[:-2], - scale_shape[-2], - int(data_shape[-2] / scale_shape[-2]), + *data_shape[: flatten_axis - 1], + scale_shape[flatten_axis - 1], + int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]), + *data_shape[flatten_axis:-1], scale_shape[-1], int(data_shape[-1] / scale_shape[-1]), ) - scale = jnp.expand_dims(scale, axis=(-1, -3)) + + # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. + scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1)) # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape( data_shape diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 629e3f5bc2..bd7045453b 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -14,7 +14,7 @@ import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class -from transformer_engine_jax import QuantizeAxis +from transformer_engine_jax import QuantizeLayout from .scaling_modes import ScalingMode from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory @@ -24,7 +24,7 @@ ) __all__ = [ - "QuantizeAxis", + "QuantizeLayout", "Quantizer", "QuantizerSet", "DelayedScaleQuantizer", @@ -45,12 +45,12 @@ class Quantizer(ABC): Attributes: q_dtype: The data type for quantized values scaling_mode: The scaling mode to use for quantization - q_axis: The quantization axis (row-wise, column-wise, or both) + q_layout: The quantization axis (row-wise, column-wise, or both) """ q_dtype: jnp.dtype scaling_mode: ScalingMode - q_axis: QuantizeAxis + q_layout: QuantizeLayout def tree_flatten(self): """Flatten the quantizer for JAX tree operations. @@ -59,7 +59,7 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = () - aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) + aux_data = (self.q_dtype, self.scaling_mode, self.q_layout) return (children, aux_data) @classmethod @@ -85,30 +85,31 @@ def is_2x2x(self) -> bool: Returns: True if using both row-wise and column-wise quantization """ - return self.q_axis == QuantizeAxis.ROWWISE_COLWISE + return self.q_layout == QuantizeLayout.ROWWISE_COLWISE @abstractmethod - def get_layout(self) -> str: - """Get the data layout. + def get_data_layout(self) -> str: + """Get the data data_layout. Returns: - Data layout in string format + Data data_layout in string format """ @abstractmethod - def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: + def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x: """Core quantization function to be implemented by subclasses. Args: x: Input tensor to quantize is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values, default is x.dtype + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x containing the quantized data """ - def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None): + def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1): """Quantize a tensor using the internal _quantize_func(). Args: @@ -116,21 +117,26 @@ def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None): is_rowwise: Whether to use row-wise quantization is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ if (is_rowwise and is_colwise) or self.is_2x2x(): - rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) - colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) + rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) + colwise_tensor = self._quantize_func( + x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis + ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) if is_colwise: - return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) + return self._quantize_func( + x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis + ) - return self._quantize_func(x, dq_dtype=dq_dtype) + return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) - def get_scale_shapes(self, data_shape, is_padded=True): + def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1): """Get shapes for scale tensors. Args: @@ -140,7 +146,7 @@ def get_scale_shapes(self, data_shape, is_padded=True): Returns: Tuple of (rowwise_scale_shape, colwise_scale_shape) """ - return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded) + return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis) def get_scale_dtype(self): """Get the data type for scale tensors. @@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer): Attributes: scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING - q_axis: Quantization axis (default: ROWWISE_COLWISE) + q_layout: Quantization axis (default: ROWWISE_COLWISE) scale: Current scaling factor amax_history: History of maximum absolute values """ scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING - q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) amax_history: jnp.ndarray = field( @@ -181,35 +187,37 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = (self.scale, self.amax_history) - aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) + aux_data = (self.q_dtype, self.scaling_mode, self.q_layout) return (children, aux_data) - def get_layout(self) -> str: - """Get the data layout string. + def get_data_layout(self) -> str: + """Get the data data_layout string. Returns: - Data layout in string format + Data data_layout in string format Raises: ValueError: If quantization axis is invalid """ - layout = "NT" - if self.q_axis == QuantizeAxis.ROWWISE_COLWISE: - return layout - if self.q_axis == QuantizeAxis.ROWWISE: - return layout[0] - if self.q_axis == QuantizeAxis.COLWISE: - return layout[1] - raise ValueError(f"Invalid q_axis: {self.q_axis}") - - def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: + data_layout = "NT" + if self.q_layout == QuantizeLayout.ROWWISE_COLWISE: + return data_layout + if self.q_layout == QuantizeLayout.ROWWISE: + return data_layout[0] + if self.q_layout == QuantizeLayout.COLWISE: + return data_layout[1] + raise ValueError(f"Invalid q_layout: {self.q_layout}") + + def _quantize_func( + self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1 + ) -> ScaledTensor1x: """Quantize function helper for delayed scaling FP8. Args: x: Input tensor to quantize is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values - + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x containing the quantized data """ @@ -232,9 +240,12 @@ def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> Sca scale_inv=scale_inv, scaling_mode=self.scaling_mode, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) - def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None): + def quantize( + self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None, flatten_axis=-1 + ): """Quantize a tensor using the internal _quantize_func(). Args: @@ -242,32 +253,40 @@ def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype is_rowwise: Whether to use row-wise quantization is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if flatten_axis < 0: + flatten_axis += x.ndim + assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" + is_rowwise = ( is_rowwise if is_rowwise is not None - else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x()) + else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) ) is_colwise = ( is_colwise if is_colwise is not None - else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x()) + else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) ) - rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) + rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = None if is_colwise: colwise_tensor = ScaledTensorFactory.create_1x( - data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))), + data=jnp.transpose( + rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis)) + ), scale_inv=rowwise_tensor.scale_inv, scaling_mode=self.scaling_mode, dq_dtype=dq_dtype, is_colwise=True, - layout="T", + data_layout="T", + flatten_axis=flatten_axis, ) if is_colwise and is_rowwise: return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer): Attributes: scaling_mode: Set to NVTE_MXFP8_1D_SCALING - q_axis: Quantization axis (default: ROWWISE_COLWISE) + q_layout: Quantization axis (default: ROWWISE_COLWISE) """ scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING - q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE - def get_layout(self) -> str: - """Get the data layout string. + def get_data_layout(self) -> str: + """Get the data data_layout string. Returns: - Data layout in string format + Data data_layout in string format """ if self.is_2x2x(): return "NN" return "N" - def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: + def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x: """Quantize function helper for block scaling FP8. Args: x: Input tensor to quantize is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x containing the quantized data """ # TODO(Phuong): use quantize_func from JAX + if flatten_axis < 0: + flatten_axis = x.ndim + flatten_axis + assert ( + 0 <= flatten_axis < x.ndim + ), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}" + dq_dtype = dq_dtype if dq_dtype is not None else x.dtype x_shape = x.shape - scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False) + scale_shape = self.scaling_mode.get_scale_shape( + x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis + ) scale_dtype = self.scaling_mode.get_scale_dtype() x = x.reshape( - *x_shape[:-2], - scale_shape[-2], - int(x_shape[-2] / scale_shape[-2]), + *x_shape[: flatten_axis - 1], + scale_shape[flatten_axis - 1], + int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]), + *x_shape[flatten_axis:-1], scale_shape[-1], int(x_shape[-1] / scale_shape[-1]), ) - amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True) + amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True) MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32) scales = amax.astype(jnp.float32) / MAX @@ -409,6 +438,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: self.scaling_mode, is_colwise=is_colwise, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) def _cast_to_e8m0_with_rounding_up(self, scales): @@ -509,7 +539,7 @@ def create( n_quantizers: int = 1, scaling_mode: ScalingMode = None, q_dtype: jnp.dtype = None, - q_axis: QuantizeAxis = None, + q_layout: QuantizeLayout = None, **kwargs, ) -> Quantizer: """Create one or more quantizers with specified parameters. @@ -518,7 +548,8 @@ def create( n_quantizers: Number of quantizers to create scaling_mode: Scaling mode to use q_dtype: Quantization data type - q_axis: Quantization axis + q_layout: Quantization axis + flatten_axis: The quantization axis for the tensor **kwargs: Additional arguments for quantizer initialization Returns: @@ -534,7 +565,7 @@ def create( quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) quantizers.append( quantizer_type( - q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs + q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs ) ) return quantizers[0] if len(quantizers) == 1 else tuple(quantizers) @@ -554,11 +585,11 @@ def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> Quanti A QuantizerSet instance """ if is_2x2x: - q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE + q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE else: - q_axis_x = QuantizeAxis.ROWWISE - q_axis_kernel = QuantizeAxis.COLWISE - q_axis_dgrad = None + q_layout_x = QuantizeLayout.ROWWISE + q_layout_kernel = QuantizeLayout.COLWISE + q_layout_dgrad = None if "quantize_meta_set" in kwargs: quantize_meta_set = kwargs.get("quantize_meta_set") @@ -577,9 +608,11 @@ def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> Quanti else: args_x = args_kernel = args_grad = {} - q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x) - q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel) - q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad) + q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x) + q_kernel = QuantizerFactory.create( + 1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel + ) + q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) @staticmethod diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 805c034334..a9c93a3553 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -40,7 +40,11 @@ def get_scale_dtype(self) -> jnp.dtype: @abstractmethod def get_scale_shape( - self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, ) -> Tuple[int, ...]: """Get the shape for scale tensors. @@ -48,7 +52,7 @@ def get_scale_shape( data_shape: The shape of the tensor being quantized is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape - + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: The shape for scale tensors """ @@ -69,7 +73,11 @@ def get_scale_dtype(self) -> jnp.dtype: return jnp.float32 def get_scale_shape( - self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, ) -> Tuple[int, ...]: """Get the shape for scale tensors in delayed scaling. @@ -77,6 +85,7 @@ def get_scale_shape( data_shape: The shape of the tensor being scaled is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: The shape for scale tensors - (1,) @@ -113,8 +122,35 @@ def get_scale_dtype(self) -> jnp.dtype: """ return jnp.float8_e8m0fnu + def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim): + """Remove excess padding from the scale shape and return the shape with respect to the original data shape.""" + if len(data_shape) > 1: + # handle last dim + assert data_shape[-1] % scale_block_dim == 0 + last = data_shape[-1] // scale_block_dim + scale_shape = (last,) + assert n_scale_blocks % last == 0 + n_scale_blocks //= last + # handle middle dim, exclude first and last + for mid in reversed(data_shape[1:-1]): + scale_shape = (mid,) + scale_shape + assert n_scale_blocks % mid == 0 + n_scale_blocks //= mid + scale_shape = (n_scale_blocks,) + scale_shape + else: + scale_shape = (n_scale_blocks,) + + assert len(scale_shape) == len( + data_shape + ), f"scale_shape {scale_shape}, data_shape {data_shape}" + return scale_shape + def get_scale_shape( - self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, ) -> Tuple[int, ...]: """Get the shape for scale tensors in block scaling. @@ -122,6 +158,7 @@ def get_scale_shape( data_shape: The shape of the tensor being quantized is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: The shape for scale tensors @@ -135,35 +172,48 @@ def get_scale_shape( block_x, block_y = self._block_dims alignment_x, alignment_y = block_alignment - seq_axis = len(data_shape) - 2 - + if flatten_axis < 0: + flatten_axis = len(data_shape) + flatten_axis assert ( - data_shape[seq_axis] % block_x == 0 - ), f"Input data of shape {data_shape} should be padded by {block_x} in axes={seq_axis}" + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" + + assert data_shape[flatten_axis - 1] % block_x == 0, ( + f"Data shape {data_shape} should be divisible by block_x {block_x} in axis" + f" {flatten_axis - 1}" + ) assert ( data_shape[-1] % block_y == 0 - ), f"Input data of shape {data_shape} should be padded by {block_y} in axis -1" + ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1" - # NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1 - n_block_seq = data_shape[seq_axis] // block_x - n_block_y = data_shape[-1] // block_y + flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1) + flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1) - n_flat_first_dim = reduce(operator.mul, data_shape[:seq_axis], 1) * n_block_seq + assert flattened_first_dim % block_x == 0, ( + f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape" + f" {data_shape} - should be divisible by block_x {block_x}" + ) + assert flattened_last_dim % block_y == 0, ( + "Flattened last dim - mutiplication of" + f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be" + f" divisible by block_y {block_y}" + ) - # Padding - n_flat_first_dim = ((n_flat_first_dim + alignment_x - 1) // alignment_x) * alignment_x - n_block_y = ((n_block_y + alignment_y - 1) // alignment_y) * alignment_y + n_block_x = int(flattened_first_dim / block_x) + n_block_y = int(flattened_last_dim / block_y) - out_shape = () - for i in range(seq_axis): - d = data_shape[i] - out_shape += (d,) - assert n_flat_first_dim % d == 0 - n_flat_first_dim //= d + # padding + n_block_x = int(((n_block_x + alignment_x - 1) // alignment_x) * alignment_x) + n_block_y = int(((n_block_y + alignment_y - 1) // alignment_y) * alignment_y) - out_shape += (n_flat_first_dim, n_block_y) + first_dim_scale_shape = self._apply_scale_shape_correction( + data_shape[:flatten_axis], n_block_x, block_x + ) + last_dim_scale_shape = self._apply_scale_shape_correction( + data_shape[flatten_axis:], n_block_y, block_y + ) - return out_shape + return (*first_dim_scale_shape, *last_dim_scale_shape) # (Phuong: Map the NVTEScalingMode value to the ScalingMode @@ -208,34 +258,40 @@ def get_scale_dtype(self): """ return self._get_impl().get_scale_dtype() - def get_scale_shape_2x(self, data_shape, is_padded=True) -> Tuple[Tuple[int]]: + def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]: """Get shapes for both row-wise and column-wise scaling. Args: data_shape: Shape of the data tensor is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: Tuple of (rowwise_scale_shape, colwise_scale_shape) """ rowwise_scale_shape = self.get_scale_shape( - data_shape, is_colwise=False, is_padded=is_padded + data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis + ) + colwise_scale_shape = self.get_scale_shape( + data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis ) - colwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=True, is_padded=is_padded) return (rowwise_scale_shape, colwise_scale_shape) - def get_scale_shape(self, data_shape, is_colwise, is_padded=True) -> Tuple[int]: + def get_scale_shape( + self, data_shape, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[int]: """Get the shape for scale tensors in this mode. Args: data_shape: Shape of the data tensor is_colwise: Whether to use column-wise scaling is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: The shape for scale tensors """ - return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded) + return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) def __eq__(self, other): """Compare this scaling mode with another. diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 8c01dd9af0..c34a235d94 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -15,7 +15,7 @@ import jax.numpy as jnp from jax.tree_util import register_pytree_node_class -from transformer_engine_jax import QuantizeAxis +from transformer_engine_jax import QuantizeLayout from .scaling_modes import ScalingMode from .dequantizer import Dequantizer @@ -84,6 +84,17 @@ def get_colwise_tensor(self): ValueError: If called on a tensor that doesn't support column-wise access """ + @abstractmethod + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + @register_pytree_node_class @dataclass @@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor): dq_dtype: The data type for dequantized values _dq_func: The dequantization function is_colwise: Whether the tensor uses column-wise quantization - layout: The layout specification for the tensor + data_layout: The data_layout specification for the tensor + flatten_axis: The quantization axis for the tensor """ data: jnp.ndarray @@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor): dq_dtype: jnp.dtype _dq_func: Callable is_colwise: bool - layout: str + data_layout: str + flatten_axis: int = -1 def __post_init__(self): """Validates and adjusts the scale_inv shape after initialization. @@ -117,11 +130,22 @@ def __post_init__(self): Ensures the scale_inv shape matches the expected shape based on the scaling mode and quantization direction. Pads the scale_inv if necessary. """ + flatten_axis = ( + len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis + ) + assert ( + 0 < flatten_axis < len(self.data.shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}" + + if self.data_layout == "T": + flatten_axis = self.data.ndim - flatten_axis + self.flatten_axis = flatten_axis + expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True + self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis ) expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False + self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis ) if self.scale_inv.shape != expected_scale_shape: assert self.scale_inv.shape == expected_unpadded_scale_shape, ( @@ -144,7 +168,14 @@ def tree_flatten(self): A tuple containing (children, aux_data) for tree operations """ children = (self.data, self.scale_inv) - aux_data = (self.scaling_mode, self.dq_dtype, self._dq_func, self.is_colwise, self.layout) + aux_data = ( + self.scaling_mode, + self.dq_dtype, + self._dq_func, + self.is_colwise, + self.data_layout, + self.flatten_axis, + ) return (children, aux_data) def dequantize(self): @@ -183,6 +214,46 @@ def get_colwise_tensor(self): raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!") + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + # axis_names were given for N layout, so needs to be transpose for T layout + if self.data_layout == "T": + assert self.flatten_axis > 0 + flatten_axis = -self.flatten_axis + axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis]) + else: + axis_names = logical_axis_names + + data = with_sharding_constraint_by_logical_axes(self.data, axis_names) + + if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + # TODO(Phuong): Handle padding !? + scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) + else: + scale_inv = self.scale_inv + + # TODO(Phuong): constaint padded scale_inv? + return ScaledTensor1x( + data=data, + scale_inv=scale_inv, + scaling_mode=self.scaling_mode, + dq_dtype=self.dq_dtype, + _dq_func=self._dq_func, + is_colwise=self.is_colwise, + data_layout=self.data_layout, + flatten_axis=self.flatten_axis, + ) + @register_pytree_node_class @dataclass @@ -233,6 +304,27 @@ def get_colwise_tensor(self): """ return self.colwise_tensor + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes( + logical_axis_names + ) + colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes( + logical_axis_names + ) + + return ScaledTensor2x(rowwise_tensor, colwise_tensor) + @dataclass class ScaledTensorFactory: @@ -244,7 +336,13 @@ class ScaledTensorFactory: @staticmethod def create_1x( - data, scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, is_colwise=False, layout="N" + data, + scale_inv, + scaling_mode, + dq_dtype=jnp.bfloat16, + is_colwise=False, + data_layout="N", + flatten_axis=-1, ): """Creates a single-scale quantized tensor. @@ -254,13 +352,16 @@ def create_1x( scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) is_colwise: Whether to use column-wise quantization (default: False) - layout: The layout specification (default: "N") + data_layout: The data_layout specification (default: "N") + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x instance """ dq_func = Dequantizer.funcs.get(scaling_mode) - return ScaledTensor1x(data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, layout) + return ScaledTensor1x( + data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis + ) @staticmethod def create_2x( @@ -270,7 +371,8 @@ def create_2x( colwise_scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, - layout="NN", + data_layout="NN", + flatten_axis=-1, ): """Creates a double-scale quantized tensor. @@ -281,7 +383,8 @@ def create_2x( colwise_scale_inv: The column-wise inverse scaling factors scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) - layout: The layout specification (default: "NN") + data_layout: The data_layout specification (default: "NN") + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor2x instance @@ -294,7 +397,8 @@ def create_2x( dq_dtype, dq_func, is_colwise=False, - layout=layout[0], + data_layout=data_layout[0], + flatten_axis=flatten_axis, ) colwise_tensor = ScaledTensor1x( colwise_data, @@ -303,7 +407,8 @@ def create_2x( dq_dtype, dq_func, is_colwise=True, - layout=layout[1], + data_layout=data_layout[1], + flatten_axis=flatten_axis, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -315,8 +420,9 @@ def create( colwise_scale_inv: jnp.ndarray, scaling_mode: ScalingMode, dq_dtype: jnp.dtype = jnp.bfloat16, - layout: str = "NN", - q_axis: QuantizeAxis = QuantizeAxis.ROWWISE, + data_layout: str = "NN", + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, + flatten_axis: int = -1, ): """Creates a scaled tensor based on the quantization axis. @@ -327,13 +433,13 @@ def create( colwise_scale_inv: The column-wise inverse scaling factors scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) - layout: The layout specification (default: "NN") - q_axis: The quantization axis (default: ROWWISE) + data_layout: The data_layout specification (default: "NN") + q_layout: The quantization axis (default: ROWWISE) Returns: - Either a ScaledTensor1x or ScaledTensor2x instance depending on q_axis + Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout """ - if q_axis == QuantizeAxis.ROWWISE_COLWISE: + if q_layout == QuantizeLayout.ROWWISE_COLWISE: return ScaledTensorFactory.create_2x( data, scale_inv, @@ -341,12 +447,19 @@ def create( colwise_scale_inv, scaling_mode, dq_dtype, - layout=layout, + data_layout=data_layout, + flatten_axis=flatten_axis, ) - is_colwise = q_axis == QuantizeAxis.COLWISE + is_colwise = q_layout == QuantizeLayout.COLWISE return ScaledTensorFactory.create_1x( - data, scale_inv, scaling_mode, dq_dtype, is_colwise=is_colwise, layout=layout[0] + data, + scale_inv, + scaling_mode, + dq_dtype, + is_colwise=is_colwise, + data_layout=data_layout[0], + flatten_axis=flatten_axis, ) @@ -360,24 +473,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, . Returns: The tensor with applied sharding constraints """ - if isinstance(x, ScaledTensor1x): - return ScaledTensor1x( - data=with_sharding_constraint_by_logical_axes(x.data, logical_axis_names), - scale_inv=x.scale_inv, - scaling_mode=x.scaling_mode, - dq_dtype=x.dq_dtype, - _dq_func=x._dq_func, - is_colwise=x.is_colwise, - layout=x.layout, - ) - if isinstance(x, ScaledTensor2x): - return ScaledTensor2x( - rowwise_tensor=with_sharding_constraint_by_logical_axes( - x.rowwise_tensor, logical_axis_names - ), - colwise_tensor=with_sharding_constraint_by_logical_axes( - x.colwise_tensor, logical_axis_names - ), - ) + if isinstance(x, ScaledTensor): + return x.apply_sharding_constraint_by_logical_axes(logical_axis_names) return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 8e7ce93986..df3f38cbd1 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -89,7 +89,11 @@ def generate_pspec(logical_axis_names): Convert logical axes to PartitionSpec """ rules = get_sharding_map_logic_axis_to_mesh_axis() - mesh_axis_names = [rules[name] for name in logical_axis_names] + # mesh_axis_names = [rules[name] for name in logical_axis_names] + mesh_axis_names = [] + for name in logical_axis_names: + axis_name = rules[name] if name in rules else None + mesh_axis_names.append(axis_name) pspec = jax.sharding.PartitionSpec(*mesh_axis_names) return pspec @@ -112,7 +116,7 @@ def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: t """ A wrapper function to jax.lax.with_sharding_constraint to accept logical axes. """ - if logical_axis_names is None: + if not logical_axis_names: return x assert len(x.shape) == len(logical_axis_names) @@ -315,3 +319,25 @@ class ShardingType(Enum): TP_ROW = (MajorShardingType.TP, "tp_row") DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col") DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row") + + +def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims): + """Get logical axes for non-contracting dimensions. + + Args: + ndim: Number of dimensions in the tensor. + logical_axes: Tuple of logical axes for each dimension. + contracting_dims: Set of dimensions that are being contracted. + + Returns: + Tuple of logical axes for non-contracting dimensions. + """ + if not logical_axes: + logical_axes = (None,) * ndim + elif len(logical_axes) < ndim: + logical_axes = logical_axes + (None,) * (ndim - len(logical_axes)) + assert len(logical_axes) == ndim + + non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims] + non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims) + return non_contracting_logical_axes