
[JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quantization#2270

Merged
jberchtold-nvidia merged 20 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/nvfp4-recipe-flags
Oct 22, 2025

Conversation

@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Oct 13, 2025

Description

Support recipe flags for disabling stochastic rounding (SR), the Randomized Hadamard Transform (RHT), and 2D quantization for the NVFP4 recipe in TE/JAX.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Support disable_stochastic_rounding: when the flag is set, the NVFP4 quantize config omits the stochastic-rounding state from the QuantizeMeta it returns.

  • Support disable_rht: the should_use_rht check now lives in the quantizer itself and consults the quantize config. A new uses_rht field on ScaledTensor1x consolidates this logic and avoids duplicate downstream re-checking of RHT, such as in the GEMM.

  • Support disable_2d_quantization: when the flag is set, the NVFP4 quantize config always returns 1D quantization.

  • Fix a bug where the SR rng state on NVFP4Quantizer was lost across VJP boundaries because pytree flatten and unflatten overrides were missing, so SR was never actually applied. Specialized tree_flatten and tree_unflatten methods are now implemented on NVFP4Quantizer (see the sketch after this list).

  • Add checks to test_helper.py verifying that the NVFP4 quantizer's SR state is preserved across a VJP boundary. The test fails without the pytree fix and passes with it.

  • Fix the SR state sharding assertion when the input data is replicated along some or all axes.

  • Fix the use of Python's built-in hash(), which is not deterministic across runs, in SR state creation. hashlib is now used so the same top-level seed yields a consistent SR rng state (see the hashing sketch after this list).
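
For illustration, here is a minimal sketch of the pytree-registration pattern behind the VJP fix. The class and field names below are stand-ins rather than the actual NVFP4Quantizer implementation; the point is that the SR rng state must be returned as a leaf from tree_flatten so JAX rebuilds it in tree_unflatten instead of dropping it during VJP tracing.

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class SketchQuantizer:
    """Stand-in for a quantizer carrying SR rng state (illustrative only)."""

    def __init__(self, scale, rng_state):
        self.scale = scale          # ordinary array leaf
        self.rng_state = rng_state  # SR rng state that must survive flatten/unflatten

    def tree_flatten(self):
        # Expose the rng state as a pytree leaf so it is carried across
        # jax.vjp / custom_vjp boundaries rather than silently discarded.
        return (self.scale, self.rng_state), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        del aux_data
        return cls(*children)

# Round-trip through flatten/unflatten, as happens inside VJP tracing:
q = SketchQuantizer(jnp.float32(1.0), jax.random.PRNGKey(0))
leaves, treedef = jax.tree_util.tree_flatten(q)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert (restored.rng_state == q.rng_state).all()

And a hedged sketch of the deterministic-hashing change: Python's built-in hash() is salted per interpreter run, so seeds derived from it differ between runs, while hashlib digests are stable. The helper name below is hypothetical.

import hashlib

def derive_sr_seed(top_level_seed: int, tag: str) -> int:
    # hash() is randomized per process (PYTHONHASHSEED), so it cannot seed
    # rng state that must be reproducible across runs; sha256 is deterministic.
    digest = hashlib.sha256(f"{top_level_seed}:{tag}".encode()).digest()
    return int.from_bytes(digest[:8], "little")

# Same top-level seed and tag always yield the same SR seed:
assert derive_sr_seed(1234, "dense/kernel") == derive_sr_seed(1234, "dense/kernel")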

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/nvfp4-recipe-flags branch from 030ebff to b1a8736 on October 13, 2025 23:54
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

Fix issue with SR state being erased due to pytree handling of NVFP4Quantizer

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/nvfp4-recipe-flags branch from b6ce86b to 559e7e2 on October 14, 2025 15:57
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@phu0ngng phu0ngng changed the title from "[JAX] Support nvfp4 recipe flags for disabling SR, RHT, and 2D quantization" to "[JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quantization" on Oct 14, 2025
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/nvfp4-recipe-flags branch from 080d081 to 42b6350 on October 14, 2025 17:52
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/nvfp4-recipe-flags branch from 5cce979 to d1d179d on October 14, 2025 22:36
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/nvfp4-recipe-flags branch from d1d179d to 1d01859 on October 14, 2025 22:36
…vfp4-recipe-flags

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/nvfp4-recipe-flags branch from b89b260 to be9ca3e on October 14, 2025 22:41
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

is_colwise: Whether the tensor uses column-wise quantization
data_layout: The data_layout specification for the tensor
flatten_axis: The quantization axis for the tensor
uses_rht: Whether the tensor uses the Randomized Hadamard Transform (RHT)
Collaborator

@phu0ngng phu0ngng Oct 15, 2025


Nitpick: I suggest naming it applied_rht or with_rht as RHT is already applied to the data when the tensor is created.

Collaborator Author


Updated

is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
uses_rht=False,
Collaborator


Could you add a TODO here that we will need to update this when we work on making GroupedGEMM + NVFP4 work?

Collaborator Author


Updated

group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
rowwise_uses_rht: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
Collaborator


We don't support rowwise + RHT at all, right? If we leave this here, I think we should probably add an assertion.

Collaborator Author


Added an assert here
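
A sketch of what that guard could look like (the exact message and placement in the merged change may differ):

assert not rowwise_uses_rht, (
    "Row-wise NVFP4 tensors do not support the Randomized Hadamard Transform (RHT)"
)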

Comment thread transformer_engine/jax/quantize/quantizer.py
Comment thread transformer_engine/jax/quantize/quantizer.py
uses_rht=use_rht,
)

def should_use_rht(self, q_layout=None):
Collaborator


Same here, the quantizer probably should not have the right to decide whether to apply RHT.

Collaborator Author


Done

Comment thread tests/jax/test_helper.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

Move decision of whether to use RHT into helper.py and add dedicated RHT tests

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/nvfp4-recipe-flags branch from 98e78f5 to 18fd85e on October 15, 2025 23:56
jberchtold-nvidia and others added 2 commits October 15, 2025 16:58
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

jberchtold-nvidia and others added 2 commits October 16, 2025 14:56
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

@phu0ngng previously approved these changes Oct 20, 2025
Collaborator

@phu0ngng phu0ngng left a comment


I like the new design very much!
Left a small comment but please feel free to merge it.

Comment on lines +171 to +177
def has_rht_applied(q: AbstractBaseTensor) -> bool:
    return isinstance(q, ScaledTensor1x) and q.has_rht_applied

assert uses_rht(lhs_q) == uses_rht(rhs_q), (
    "With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise"
    " quantized as well. This is to ensure the RHT is applied to both and will cancel out in"
    " the GEMM."
)
assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), (
    "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized"
    " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the"
    " GEMM."
)
Collaborator


Could we simplify this block to

assert lhs_q.has_rht_applied == rhs_q.has_rht_applied, msg

?

Collaborator Author


The lhs_q and rhs_q are not guaranteed to be ScaledTensor1x. If we are running a non-quantized GEMM, they could be NoScaleTensor which doesn't have has_rht_applied.

@phu0ngng phu0ngng added the 2.9.0 label Oct 20, 2025
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

@jberchtold-nvidia jberchtold-nvidia merged commit 818b30c into NVIDIA:main Oct 22, 2025
24 of 25 checks passed
@jberchtold-nvidia jberchtold-nvidia deleted the jberchtold/nvfp4-recipe-flags branch October 22, 2025 15:51
KshitijLakhani pushed a commit that referenced this pull request Oct 24, 2025
[JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quantization (#2270)

* [JAX] Support recipe flags for disabling SR, RHT, and 2D quantization

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* lint

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix issue with SR state being erased due to pytree handling of NVFP4Quantizer

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add test for SR state preservation across VJP boundaries

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix sharding of SR rng state

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* lint

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* update tolerances slightly now that SR is enabled

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* lint

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use hashlib for deterministic hashes across runs for SR

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* rename uses_rht on scaled tensors to has_applied_rht

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* add assert

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Move decision of whether to use RHT into helper.py and add dedicated RHT tests

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* lint

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix use_rht attr usage

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix pure-jax rht usage criteria

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Adjust tolerances after rebase

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>