[JAX] Flatten_axis for quantization and Sharding propagation fixes#1644

Merged
phu0ngng merged 34 commits into NVIDIA:main from phu0ngng:quantize_axis
Apr 4, 2025

Conversation


@phu0ngng phu0ngng commented Apr 3, 2025

Description

In #1627, we required all tensors to be flattenable to 2D with axis=-1. This required additional reshaping in JAX that merged dimensions, resulting in the loss of sharding information.

In this PR, we introduce flatten_axis, which allows flattening a tensor to 2D at any axis. With this, merging axes is no longer needed, so the sharding information can be propagated correctly.
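The idea can be illustrated with a small sketch. This is a hypothetical helper for illustration only (the name `flatten_to_2d` and the NumPy usage are assumptions, not the actual TE code):

```python
import math

import numpy as np

def flatten_to_2d(x: np.ndarray, flatten_axis: int = -1) -> np.ndarray:
    # Collapse all axes before flatten_axis into rows and the remaining
    # axes into columns, yielding the 2D view quantization kernels expect.
    if flatten_axis < 0:
        flatten_axis += x.ndim
    assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds"
    rows = math.prod(x.shape[:flatten_axis])
    cols = math.prod(x.shape[flatten_axis:])
    return x.reshape(rows, cols)

x = np.zeros((4, 8, 16))            # e.g. (batch, seq, hidden)
print(flatten_to_2d(x, -1).shape)   # (32, 16): leading axes merged
print(flatten_to_2d(x, 1).shape)    # (4, 128): trailing axes merged
```

Because the split is a pure reshape at a chosen boundary, no extra axis-merging pass is needed before quantization, which is what lets the sharding spec survive.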

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng added 20 commits April 1, 2025 05:16
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
phu0ngng and others added 2 commits April 3, 2025 15:34
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
if flatten_axis < 0:
    flatten_axis += rowwise.data.ndim
assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
Collaborator
This is 0 < flatten_axis rather than 0 <= flatten_axis because we need at least one axis to the left of the flatten axis, so that after flattening we have two axes to give to TE, right?

Collaborator Author

Yes, TE requires the data to be 2D for MXFP8. So we can't accept axis = 0 and mistakenly flatten the data to 1D.
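The bounds check being discussed can be exercised in isolation. A minimal sketch (`flattened_shape` is a hypothetical helper, not the TE implementation):

```python
import math

def flattened_shape(shape, flatten_axis):
    # Mirror of the check above: at least one axis must remain on each
    # side of the split so the result is genuinely 2D.
    if flatten_axis < 0:
        flatten_axis += len(shape)
    assert 0 < flatten_axis < len(shape), "flatten_axis is out of bounds"
    return (math.prod(shape[:flatten_axis]), math.prod(shape[flatten_axis:]))

print(flattened_shape((4, 8, 16), -1))  # (32, 16)
# flattened_shape((4, 8, 16), 0) raises: axis 0 leaves no row axes, which
# would effectively collapse the data to 1D, and MXFP8 requires 2D data.
```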

phu0ngng added 6 commits April 4, 2025 06:32
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
phu0ngng and others added 3 commits April 4, 2025 07:38
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng marked this pull request as ready for review April 4, 2025 14:49
phu0ngng and others added 3 commits April 4, 2025 08:01
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng
Collaborator Author

phu0ngng commented Apr 4, 2025

/te-ci jax L1

    n_scale_blocks //= mid
    scale_shape = (n_scale_blocks,) + scale_shape
else:
    scale_shape = (n_scale_blocks,)
Collaborator

@jberchtold-nvidia jberchtold-nvidia Apr 4, 2025
Do we support quantizing 1D tensors at all? I thought we required them to be at least 2D. But if we do support 1D tensors, I agree this is the correct scale shape, so I'm okay with this change.

Collaborator Author

No, we don't.

But this function gets called for each part of the 2D-flattened tensor here: https://github.com/NVIDIA/TransformerEngine/pull/1644/files#diff-0158638e30529db0bb268ae65eb085b2d22b52e6e0ff4891fe2c7ea9959eea79R209-R214
Therefore, the data shape can be 1D. Perhaps I should rename it to partial_data_shape.

Collaborator
Oh I see now, that makes sense. Thanks!

@phu0ngng phu0ngng merged commit ff884e2 into NVIDIA:main Apr 4, 2025
22 checks passed
@phu0ngng phu0ngng deleted the quantize_axis branch April 4, 2025 18:47
 if is_2x:
     if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
-        colwise_x_spec = multidim_transpose(x_spec)
+        colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
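For context, a plausible sketch of what a spec transpose with a `transpose_axis` parameter could look like. This is a hypothetical illustration (name, rotation semantics, and example spec are all assumptions, not the actual TE `multidim_transpose`):

```python
def multidim_transpose_sketch(spec, transpose_axis=-1):
    # Rotate the entries at and after transpose_axis to the front: the
    # column-wise (transposed) data carries those axes first.
    transpose_axis %= len(spec)
    return (*spec[transpose_axis:], *spec[:transpose_axis])

x_spec = ("dp", None, "tp")                   # assumed example spec
print(multidim_transpose_sketch(x_spec, -2))  # (None, 'tp', 'dp')
```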
Collaborator
Note for @jreiffers, change in signature of this function

-out_spec = (*x_spec[:-2], None, x_spec[-2])
+out_spec = (*x_spec[:-2], x_spec[-1])
 scale_spec = get_padded_spec(arg_infos[1])
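The effect of the out_spec change can be seen on an assumed example partition spec (the spec values and the interpretation in the comments are illustrative assumptions):

```python
x_spec = ("dp", None, "tp")  # assumed (batch, seq, hidden) partition spec

# Before this PR: an extra axis was inserted and the second-to-last spec
# entry reused, so the sharding of the last input axis was not kept.
old_out_spec = (*x_spec[:-2], None, x_spec[-2])
# After this PR: the output keeps the sharding of the last input axis.
new_out_spec = (*x_spec[:-2], x_spec[-1])

print(old_out_spec)  # ('dp', None, None)
print(new_out_spec)  # ('dp', 'tp')
```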
Collaborator
Note for @jreiffers, how the signature of this primitive changed in this PR

lhb8125 pushed a commit to lhb8125/TransformerEngine that referenced this pull request Apr 8, 2025
…VIDIA#1644)

* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout

* add flatten_axis option

* added gated act to test encoder

* sharding constraint fixes

* fix padding when flattening first dim needs to be padded

* update test sizes so that padding is tested

* rm output sharding as it can be done in the flax module

* sharding scale_inv for mxfp8

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
wdykas pushed a commit to wdykas/TransformerEngine that referenced this pull request Apr 14, 2025
…VIDIA#1644)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>
