Merged
Changes from all commits
20 commits
9497a15
[JAX] Support recipe flags for disabling SR, RHT, and 2D quantization
jberchtold-nvidia Oct 13, 2025
b1a8736
lint
jberchtold-nvidia Oct 13, 2025
559e7e2
Fix issue with SR state being erased due to pytree handling of NVFP4Q…
jberchtold-nvidia Oct 14, 2025
01654b4
Add test for SR state preservation across VJP boundaries
jberchtold-nvidia Oct 14, 2025
42b6350
Fix sharding of SR rng state
jberchtold-nvidia Oct 14, 2025
e30e05e
lint
jberchtold-nvidia Oct 14, 2025
1d01859
update tolerances slightly now that SR is enabled
jberchtold-nvidia Oct 14, 2025
be9ca3e
Merge remote-tracking branch 'github-upstream/main' into jberchtold/n…
jberchtold-nvidia Oct 14, 2025
7e8e5d2
lint
jberchtold-nvidia Oct 14, 2025
30efd28
Use hashlib for deterministic hashes across runs for SR
jberchtold-nvidia Oct 15, 2025
3da43bd
rename uses_rht on scaled tensors to has_applied_rht
jberchtold-nvidia Oct 15, 2025
c0d1569
add assert
jberchtold-nvidia Oct 15, 2025
18fd85e
Move decision of whether to use RHT into helper.py and add dedicated …
jberchtold-nvidia Oct 15, 2025
4254f5a
Merge branch 'main' into jberchtold/nvfp4-recipe-flags
jberchtold-nvidia Oct 15, 2025
217632f
lint
jberchtold-nvidia Oct 16, 2025
d9c4efa
fix use_rht attr usage
jberchtold-nvidia Oct 16, 2025
d2bc2b3
Merge branch 'main' into jberchtold/nvfp4-recipe-flags
jberchtold-nvidia Oct 16, 2025
e36a5c0
fix pure-jax rht usage criteria
jberchtold-nvidia Oct 21, 2025
ffab48f
Merge branch 'main' into jberchtold/nvfp4-recipe-flags
jberchtold-nvidia Oct 21, 2025
6fac614
Adjust tolerances after rebase
jberchtold-nvidia Oct 21, 2025
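The new recipe flags are exercised at the bottom of `tests/jax/test_helper.py` below. As a quick orientation, here is a minimal sketch of how they look in use; the flag names and the `NVFP4BlockScaling` / `autocast` / `MeshResource` calls are taken from that test, while the import path for `autocast` and the body of the `with` block are assumptions, not code from this PR.

```python
# Sketch only: mirrors the recipe construction in tests/jax/test_helper.py.
# The transformer_engine.jax import path for autocast is an assumption.
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.jax import autocast
from transformer_engine.jax.sharding import MeshResource

# Default NVFP4 recipe: stochastic rounding (on dgrad), the randomized
# Hadamard transform, and 2D quantization of kernel tensors are enabled.
recipe = NVFP4BlockScaling()

# Each feature can now be switched off independently via the new flags.
recipe_minimal = NVFP4BlockScaling(
    disable_stochastic_rounding=True,
    disable_rht=True,
    disable_2d_quantization=True,
)

with autocast(enabled=True, recipe=recipe_minimal, mesh_resource=MeshResource()):
    pass  # quantizers created inside this context honor the flags
```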
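One commit above, "Use hashlib for deterministic hashes across runs for SR", is motivated by the fact that Python's built-in `hash()` for strings is salted per interpreter run (PYTHONHASHSEED), so any stochastic-rounding seed derived from it would change between runs. The sketch below shows the general technique with a hypothetical helper name; it is not the code added by this PR.

```python
import hashlib

def deterministic_seed(name: str, base_seed: int = 0) -> int:
    """Hypothetical helper: derive a stable 32-bit seed from a label.

    hashlib digests are identical across Python processes, unlike the
    builtin hash(), which is randomized by PYTHONHASHSEED.
    """
    digest = hashlib.sha256(f"{base_seed}:{name}".encode()).digest()
    return int.from_bytes(digest[:4], "little")

# Example: a reproducible seed for a stochastic-rounding RNG.
sr_seed = deterministic_seed("dgrad_quantizer", base_seed=0)
```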
4 changes: 2 additions & 2 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -672,7 +672,7 @@ def test_te_mxfp8(self):
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling")
- assert result[0] < 0.451 and result[1] > 0.79
+ assert result[0] < 0.451 and result[1] > 0.788

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
@@ -710,7 +710,7 @@ def test_te_mxfp8_shardy(self):
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
- assert result[0] < 0.451 and result[1] > 0.79
+ assert result[0] < 0.451 and result[1] > 0.788


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/jax/encoder/test_single_gpu_encoder.py
@@ -390,7 +390,7 @@ def test_te_nvfp4(self):
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
- assert actual[0] < 0.476 and actual[1] > 0.775
+ assert actual[0] < 0.477 and actual[1] > 0.769


if __name__ == "__main__":
155 changes: 101 additions & 54 deletions tests/jax/test_custom_call_compute.py
@@ -40,7 +40,6 @@
QuantizerFactory,
QuantizeLayout,
noop_quantizer_set,
- should_use_rht,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
@@ -685,21 +684,14 @@ class TestQuantize:
Purely quantization related tests that will always test on a wider set of types and shapes
"""

- def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
- """Temporary hack to skip unsupported FP4 cases until we implement them"""
+ def _skip_unsupported_dtypes(self, q_dtype, scaling_mode):
+ """Skip unsupported dtypes for given scaling mode. For example, NVFP4 only supports the float4_e2m1 dtype not float8 dtypes."""
if q_dtype not in scaling_mode.get_compatible_q_dtypes():
pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
return

- # HACK: FIXME TODO(jberchtold)
- row = reduce(operator.mul, input_shape[flatten_axis:], 1)
- col = reduce(operator.mul, input_shape[:flatten_axis], 1)
- will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
- if will_use_rht and (row % 64 != 0 or col % 128 != 0):
- pytest.skip("Unfused RHT is not supported currently, skipping")

def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
- self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
+ self._skip_unsupported_dtypes(q_dtype, scaling_mode)

key = jax.random.PRNGKey(0)

@@ -780,22 +772,8 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt
assert_dequantized_scaled_tensor(scaled_tensor, x)

def _should_use_precise_comparison(
- self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+ self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis
):
- # TODO(jberchtold): Remove this hack once we have a better solution to ensure bitwise identical results between TE and JAX RHT+quant implementations. Currently for certain shapes the quantized fp4 data differs by a small amount on <0.5% of the values.
- RHT_SLIGHT_MISMATCH_SHAPES = [
- ((32, 256, 128), -1),
- ((64, 32, 32, 256), -1),
- ((8192, 2, 4096), -2),
- ]
-
- if (
- should_use_rht(scaling_mode, q_layout=q_layout)
- and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES
- ):
- # TE fused RHT+quant and JAX RHT+quant have slight implementation differences which can lead to small numerical differences on certain shapes
- return False

if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
# With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation
return False
@@ -805,7 +783,7 @@ def _should_use_precise_comparison(
def test_quantize_bitwise(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
):
- self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
+ self._skip_unsupported_dtypes(q_dtype, scaling_mode)

key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
@@ -816,28 +794,20 @@ def test_quantize_bitwise(

jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)

- try:
- te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
- except AssertionError as e:
- if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
- error_message = e.args[0]
- if "RHT requires input to be bfloat16" in error_message:
- # Successfully caught the expected error, early return from the test
- return
- raise e
+ te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)

assert_bitwise_scaled_tensors(
te_output,
jax_output,
precise_comparison=self._should_use_precise_comparison(
- in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+ in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
),
)

def test_quantize_bitwise_jitted(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
):
- self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
+ self._skip_unsupported_dtypes(q_dtype, scaling_mode)

key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
@@ -851,21 +821,13 @@ def test_quantize_bitwise_jitted(

jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)

- try:
- te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
- except AssertionError as e:
- if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
- error_message = e.args[0]
- if "RHT requires input to be bfloat16" in error_message:
- # Successfully caught the expected error, early return from the test
- return
- raise e
+ te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)

assert_bitwise_scaled_tensors(
te_output,
jax_output,
precise_comparison=self._should_use_precise_comparison(
- in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+ in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
),
)

@@ -985,12 +947,6 @@ def _test_sr(

def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
"""Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""
- # HACK: FIXME TODO(jberchtold)
- row = reduce(operator.mul, input_shape[flatten_axis:], 1)
- col = reduce(operator.mul, input_shape[:flatten_axis], 1)
- will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
- if will_use_rht and (row % 64 != 0 or col % 128 != 0):
- pytest.skip("Unfused RHT is not supported currently, skipping")

key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)
@@ -1007,6 +963,97 @@ def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout,
assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)


@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper(
"scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING]
)
class TestRandomizedHadamardTransform:

@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
@pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)])
def test_rht_quantize_bitwise_jitted(
self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis
):
key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)

te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
use_rht=True,
)

jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))

jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis)

te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis)

assert_bitwise_scaled_tensors(te_output, jax_output)

def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T":
a = jnp.swapaxes(a, -1, -2)
if data_layout[1] == "T":
b = jnp.swapaxes(b, -1, -2)
return jnp.dot(a, b)

def _generate_gemm_input(self, m, n, k, data_layout):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(
subkeys[0],
(m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=jnp.bfloat16,
) / jnp.sqrt(k)
w = jax.random.uniform(
subkeys[1],
(k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=jnp.bfloat16,
) / jnp.sqrt(n)
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

return (x, w, contracting_dims)

@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
# We do not test NN and TT layouts here as they do not have both inputs using RHT due to RHT only supporting the colwise layout currently
@pytest_parametrize_wrapper("data_layout", ["TN", "NT"])
@pytest_parametrize_wrapper("with_jax_gemm", [True, False])
def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm):
key = jax.random.PRNGKey(0)

lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
lhs_quantizer = QuantizerFactory.create(
scaling_mode=lhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
use_rht=True,
)
rhs_quantizer = QuantizerFactory.create(
scaling_mode=rhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
use_rht=True,
)
with use_jax_gemm(enabled=with_jax_gemm):
primitive_out = tex.gemm(
x,
w,
contracting_dims=contracting_dims,
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
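The new `TestRandomizedHadamardTransform` class above compares the TE fused RHT+quantize path against a pure-JAX reference. For readers unfamiliar with the transform: it multiplies fixed-size blocks of the tensor by a sign-randomized Hadamard matrix before quantization so that outliers are spread across each block. The sketch below is a generic textbook formulation, not the TE kernel or the reference implementation used in these tests; the 16-wide block, the axis choice, and the sign convention are assumptions.

```python
import jax
import jax.numpy as jnp

def hadamard_matrix(n: int) -> jnp.ndarray:
    """Sylvester construction of an orthonormal n x n Hadamard matrix (n a power of two)."""
    h = jnp.array([[1.0]])
    while h.shape[0] < n:
        h = jnp.block([[h, h], [h, -h]])
    return h / jnp.sqrt(n)

def randomized_hadamard_transform(x: jnp.ndarray, key, block: int = 16) -> jnp.ndarray:
    """Apply diag(random signs) followed by H to each `block`-wide group of the last axis."""
    h = hadamard_matrix(block)
    signs = jnp.where(jax.random.bernoulli(key, 0.5, (block,)), 1.0, -1.0)
    blocks = x.reshape(*x.shape[:-1], -1, block)  # last dim must be divisible by `block`
    out = jnp.einsum("...i,ij->...j", blocks * signs, h)
    return out.reshape(x.shape)

x = jax.random.normal(jax.random.PRNGKey(0), (64, 128), jnp.bfloat16)
y = randomized_hadamard_transform(x, jax.random.PRNGKey(1))
```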
82 changes: 81 additions & 1 deletion tests/jax/test_helper.py
@@ -3,11 +3,13 @@
# See LICENSE for license information.

import unittest
from functools import partial

import flax
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.common.recipe import (
@@ -24,15 +26,51 @@
ScalingMode,
update_collections,
TensorSource,
QuantizerFactory,
QuantizeLayout,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax.flax.module import TransformerEngineBase

is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)


def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
"""Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""

# Define a function with a custom VJP (vector-Jacobian product)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def quantizer_check(inner_quantizer_set, assertion_func, x):
return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)

def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
assertion_func(inner_quantizer_set.x, TensorSource.X)
assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
return x

def quantizer_check_bwd(ctx, g):
return (g,)

quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
return quantizer_check(outer_quantizer_set, assertion_func, x)


class TestModule(TransformerEngineBase):
"""A simple module to test quantizer creation and reconstruction across VJP boundaries."""

# Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
assertion_func: callable

@nn.compact
def __call__(self, x):
quantizer_set = self.generate_quantizer_set()
return quantizer_check_vjp(quantizer_set, self.assertion_func, x)


class TestHelper(unittest.TestCase):

@unittest.skipIf(not is_fp8_supported, reason=reason)
@@ -89,12 +127,43 @@ def _compare_nvfp4_scaling(self, test):
for tensor_source in TensorSource:
target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING
- if tensor_source == TensorSource.KERNEL
+ if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING
)
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
)
self.assertEqual(
get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding
)
self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht)
self.assertEqual(
get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization
)

def _compare_nvfp4_scaling_quantizers(self, test):
"""Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""

def assertion_func(quantizer, tensor_source):
if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
self.assertIsNone(quantizer.stochastic_rounding_rng_state)
else:
self.assertIsNotNone(quantizer.stochastic_rounding_rng_state)

expected_rht = (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
and not test.disable_rht
)
self.assertEqual(quantizer.use_rht, expected_rht)

x = jnp.ones((), dtype=jnp.float32)
test_module = TestModule(assertion_func=assertion_func)
param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": param_key, "sr_rng": sr_key}
variables = test_module.init(rngs, x)

jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)

@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_delayed_scaling(self):
Expand Down Expand Up @@ -171,5 +240,16 @@ def test_autocast_nvfp4_block_scaling(self):
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)

bs = NVFP4BlockScaling(
disable_stochastic_rounding=True,
disable_rht=True,
disable_2d_quantization=True,
)
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)

self._check_default_state()