[JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quantization (#2270)
Conversation
```
is_colwise: Whether the tensor uses column-wise quantization
data_layout: The data_layout specification for the tensor
flatten_axis: The quantization axis for the tensor
uses_rht: Whether the tensor uses the Randomized Hadamard Transform (RHT)
```
Nitpick: I suggest naming it `applied_rht` or `with_rht`, as RHT is already applied to the data when the tensor is created.
```python
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
uses_rht=False,
```
Could you add a TODO here that we will need to update this when we work on making GroupedGEMM + NVFP4 work?
```
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
rowwise_uses_rht: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
```
We don't support rowwise + RHT at all, right? If we leave this here, I think we should probably add an assertion.
Added an assert here
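For readers following along, a minimal sketch of such a guard might look like this; the names are illustrative and may differ from the actual TE/JAX code:

```python
# Illustrative only: reject the unsupported rowwise + RHT combination early.
assert not rowwise_uses_rht, (
    "Row-wise quantization with RHT is not supported; RHT is only applied to "
    "column-wise NVFP4 tensors."
)
```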
```python
uses_rht=use_rht,
)

def should_use_rht(self, q_layout=None):
```
Same here, the quantizer probably should not have the right to decide whether to apply RHT.
phu0ngng left a comment:
I like the new design very much!
Left a small comment but please feel free to merge it.
```python
def has_rht_applied(q: AbstractBaseTensor) -> bool:
    return isinstance(q, ScaledTensor1x) and q.has_rht_applied
```

```diff
-assert uses_rht(lhs_q) == uses_rht(rhs_q), (
-    "With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise"
-    " quantized as well. This is to ensure the RHT is applied to both and will cancel out in"
-    " the GEMM."
+assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), (
+    "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized"
+    " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the"
+    " GEMM."
```
Could we simplify this block to `assert lhs_q.has_rht_applied == rhs_q.has_rht_applied, msg`?
The `lhs_q` and `rhs_q` are not guaranteed to be `ScaledTensor1x`. If we are running a non-quantized GEMM, they could be `NoScaleTensor`, which doesn't have `has_rht_applied`.
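A minimal self-contained sketch of why the `isinstance` guard is needed; the toy classes below only mimic the tensor types discussed in this thread, not the real TE/JAX definitions:

```python
from dataclasses import dataclass


@dataclass
class NoScaleTensor:
    data: object  # unquantized operand; carries no RHT metadata


@dataclass
class ScaledTensor1x:
    data: object
    has_rht_applied: bool = False


def has_rht_applied(q) -> bool:
    # NoScaleTensor has no `has_rht_applied` attribute, so reading it directly
    # would raise AttributeError; the isinstance check defaults it to False.
    return isinstance(q, ScaledTensor1x) and q.has_rht_applied


lhs_q, rhs_q = NoScaleTensor(None), NoScaleTensor(None)
assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q)  # holds: both False
```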
Merged as: [JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quantization (#2270)

Squashed commits:
* [JAX] Support recipe flags for disabling SR, RHT, and 2D quantization
* lint
* Fix issue with SR state being erased due to pytree handling of NVFP4Quantizer
* Add test for SR state preservation across VJP boundaries
* Fix sharding of SR rng state
* lint
* update tolerances slightly now that SR is enabled
* lint
* Use hashlib for deterministic hashes across runs for SR
* rename uses_rht on scaled tensors to has_applied_rht
* add assert
* Move decision of whether to use RHT into helper.py and add dedicated RHT tests
* lint
* fix use_rht attr usage
* fix pure-jax rht usage criteria
* Adjust tolerances after rebase
Description
Support recipe flags for disabling SR, RHT, and 2D quantization for the NVFP4 recipe in TE/JAX.
Changes
- Support `disable_stochastic_rounding` by not returning any stochastic rounding state in the `QuantizeMeta` returned from the NVFP4 quantize config if `disable_stochastic_rounding` is set.
- Support `disable_rht` by moving the `should_use_rht` check into the quantizer itself and checking the quantize config for whether RHT is disabled. A new `uses_rht` field has been added to `ScaledTensor1x` to further consolidate this logic and avoid duplicate downstream re-checking of RHT, such as in the GEMM.
- Support `disable_2d_quantization` by updating the NVFP4 quantize config to always return 1D quantization if `disable_2d_quantization` is set. (A hypothetical sketch of how these three flags gate the config follows this list.)
- Fix a bug where the SR rng state on the `NVFP4Quantizer` was being lost across VJP boundaries due to missing pytree flatten and unflatten overrides, resulting in SR not being applied. Specialized `tree_flatten` and `tree_unflatten` methods have now been implemented on `NVFP4Quantizer` to fix this issue (see the pytree sketch below).
- Add checks to `test_helper.py` to verify that the NVFP4 quantizer's SR state is preserved across a VJP boundary. Confirmed the test fails before the pytree-flattening fix and passes after it.
- Fix the SR state sharding assertion when the input data is replicated along all or some axes.
- Fix usage of Python's default `hash` function, which is not deterministic across runs, in SR state creation. We now use `hashlib` for consistent hashes and SR rng state given the same top-level seed (see the hashing sketch below).
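As a rough, non-authoritative illustration of how these three recipe flags could gate the quantize config, here is a minimal sketch. All names here (`RecipeFlags`, `QuantizeMetaSketch`, `make_quantize_meta`, `should_use_rht`, `scaling_dims`) are hypothetical and do not reflect the actual TE/JAX API:

```python
# Hypothetical sketch only; names do not match the real TE/JAX classes.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class RecipeFlags:
    disable_stochastic_rounding: bool = False
    disable_rht: bool = False
    disable_2d_quantization: bool = False


@dataclass
class QuantizeMetaSketch:
    sr_rng_state: Optional[Any] = None  # None means no stochastic rounding


def make_quantize_meta(flags: RecipeFlags, rng_state: Any) -> QuantizeMetaSketch:
    # When SR is disabled, simply omit the rng state from the returned meta.
    return QuantizeMetaSketch(
        sr_rng_state=None if flags.disable_stochastic_rounding else rng_state
    )


def should_use_rht(flags: RecipeFlags, is_colwise: bool) -> bool:
    # RHT is only applied to colwise NVFP4 tensors and can be disabled globally.
    return is_colwise and not flags.disable_rht


def scaling_dims(flags: RecipeFlags) -> int:
    # Fall back to 1D block scaling when 2D quantization is disabled.
    return 1 if flags.disable_2d_quantization else 2
```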
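A rough sketch of the pytree fix described above: an object that carries rng state must define flatten/unflatten hooks so JAX threads that state through jit and VJP boundaries instead of dropping it. The `ToyQuantizer` class below is hypothetical and only mirrors the idea, not the actual `NVFP4Quantizer` implementation:

```python
import jax
from jax.tree_util import register_pytree_node


class ToyQuantizer:
    """Hypothetical stand-in for a quantizer carrying SR rng state."""

    def __init__(self, sr_rng_state, block_size=16):
        self.sr_rng_state = sr_rng_state  # traced array: must be a pytree leaf
        self.block_size = block_size      # static config: goes into aux_data


def _flatten(q):
    # Children are the arrays JAX should trace; aux_data holds static fields.
    return (q.sr_rng_state,), (q.block_size,)


def _unflatten(aux_data, children):
    return ToyQuantizer(children[0], block_size=aux_data[0])


register_pytree_node(ToyQuantizer, _flatten, _unflatten)

# Without this registration, JAX would not treat the rng state as a leaf and
# it would be lost when the quantizer crosses a custom-VJP boundary.
q = ToyQuantizer(jax.random.PRNGKey(0))
leaves, treedef = jax.tree_util.tree_flatten(q)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert (restored.sr_rng_state == q.sr_rng_state).all()
```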
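And a short sketch of the deterministic-hashing point: Python's built-in `hash` is salted per process, so deriving SR seeds from it yields different rng state on every run, while `hashlib` is stable. `derive_sr_seed` is an illustrative helper, not the actual TE/JAX function:

```python
import hashlib


def derive_sr_seed(top_level_seed: int, tag: str) -> int:
    # sha256 of (seed, tag) yields the same 32-bit value in every process,
    # unlike hash((top_level_seed, tag)), which is salted per run.
    digest = hashlib.sha256(f"{top_level_seed}:{tag}".encode()).digest()
    return int.from_bytes(digest[:4], "little")


# Reproducible across runs given the same top-level seed:
assert derive_sr_seed(42, "layer0/kernel") == derive_sr_seed(42, "layer0/kernel")
```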