[JAX] Flatten_axis for quantization and Sharding propagation fixes #1644
phu0ngng merged 34 commits into NVIDIA:main from
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
if flatten_axis < 0:
    flatten_axis += rowwise.data.ndim
assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
This is `0 < flatten_axis` rather than `0 <= flatten_axis` because we need at least one axis to the left of the flatten axis, so that after flattening we still have two axes to give to TE, right?
Yes, TE requires the data to be 2D for MXFP8, so we can't accept `flatten_axis = 0` and mistakenly flatten the data to 1D.
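For illustration, here is a minimal sketch of the normalization and bounds check being discussed, applied to an actual flatten. This is a hypothetical helper using NumPy, not TE's internals:

```python
import numpy as np

def flatten_to_2d(data, flatten_axis):
    """Merge dims before flatten_axis into rows, the rest into columns.

    Hypothetical helper: the assert mirrors the one in the diff above,
    requiring at least one axis on each side of flatten_axis so the
    result is always 2D, never 1D.
    """
    if flatten_axis < 0:
        flatten_axis += data.ndim
    assert 0 < flatten_axis < data.ndim, "flatten_axis is out of bounds"
    rows = int(np.prod(data.shape[:flatten_axis]))
    cols = int(np.prod(data.shape[flatten_axis:]))
    return data.reshape(rows, cols)
```

For example, a `(2, 3, 4)` tensor with `flatten_axis=-1` becomes `(6, 4)`, while `flatten_axis=0` is rejected because it would yield a 1D result.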
/te-ci jax L1
    n_scale_blocks //= mid
    scale_shape = (n_scale_blocks,) + scale_shape
else:
    scale_shape = (n_scale_blocks,)
Do we support quantizing 1D tensors at all? I thought we required them to be at least 2D. But if we do support 1D tensors, I agree this is the correct scale shape, so I'm okay with this change.
No, we don't. But this function gets called for each part of the 2D-flattened tensor here: https://github.com/NVIDIA/TransformerEngine/pull/1644/files#diff-0158638e30529db0bb268ae65eb085b2d22b52e6e0ff4891fe2c7ea9959eea79R209-R214
Therefore, the data shape can be 1D. Perhaps I should rename it to `partial_data_shape`.
Oh I see now, that makes sense. Thanks!
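To make the 1D case concrete, here is a hedged sketch of why a 1D partial data shape yields a one-element scale shape under block scaling. The names `partial_scale_shape` and `block_size` are illustrative; TE's actual helper also handles padding and flattening, so this only shows the 1D-vs-N-D branching:

```python
def partial_scale_shape(partial_data_shape, block_size=32):
    """Sketch: one scale per block_size elements along the last axis.

    Illustrative only; the real TE function differs in detail, but the
    branch on 1D vs N-D input matches the diff above.
    """
    *lead, last = partial_data_shape
    assert last % block_size == 0, "last dim must divide into blocks"
    n_scale_blocks = last // block_size
    if lead:
        return (*lead, n_scale_blocks)
    return (n_scale_blocks,)
```

So a 2D shape like `(4, 64)` gives `(4, 2)`, while a 1D partial shape `(64,)` gives the 1-tuple `(2,)`, matching the `else` branch.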
  if is_2x:
      if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
-         colwise_x_spec = multidim_transpose(x_spec)
+         colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
Note for @jreiffers, change in signature of this function
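A rough sketch of what the new `transpose_axis` parameter could mean for a partition spec. This is purely illustrative; the actual `multidim_transpose` in TE may differ:

```python
def multidim_transpose_sketch(spec, transpose_axis=-1):
    """Rotate the spec so dims from transpose_axis onward come first.

    With the default transpose_axis=-1 only the last dim moves to the
    front; transpose_axis=-2 moves the last two. Illustrative only, not
    TE's actual implementation.
    """
    transpose_axis %= len(spec)
    return (*spec[transpose_axis:], *spec[:transpose_axis])
```

For example, `('dp', None, 'tp')` with `transpose_axis=-2` becomes `(None, 'tp', 'dp')`.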
- out_spec = (*x_spec[:-2], None, x_spec[-2])
+ out_spec = (*x_spec[:-2], x_spec[-1])
  scale_spec = get_padded_spec(arg_infos[1])
Note for @jreiffers, how the signature of this primitive changed in this PR
…VIDIA#1644)

* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout
* add flatten_axis option
* added gated act to test encoder
* sharding constraint fixes
* fix padding when flattening first dim needs to be padded
* update test sizes so that padding is tested
* rm output sharding as it can be done in the flax module
* sharding scale_inv for mxfp8

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>
Description
In #1627, we enforced that all tensors be flattenable to a 2D tensor with `axis=-1`. This required additional reshaping in JAX to merge dimensions, which loses sharding information.
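As a sketch of that `axis=-1` restriction (hypothetical helper, not TE code): every leading dimension had to be merged into one, and the reshape implied by this merge is where the sharding of the collapsed dimensions is lost:

```python
import math

def merged_2d_shape(shape):
    """2D shape after flattening everything except the last axis.

    Illustrative: a (batch, seq, hidden) input collapses to
    (batch * seq, hidden), discarding the per-axis sharding of the
    merged leading dims.
    """
    return (math.prod(shape[:-1]), shape[-1])
```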
In this PR, we introduce `flatten_axis`, which allows flattening the tensor to 2D along any axis. With this, merging axes is no longer needed, so the sharding information can be propagated correctly.

Type of change
Checklist: