
[JAX] Change dtype of intermediate result aval of fused_topk_and_score_function_fwd to fp32#2752

Merged
tdophung merged 4 commits into NVIDIA:main from tdophung:router_fixed_grads_dtype on Mar 11, 2026

Conversation

@tdophung
Collaborator

@tdophung tdophung commented Mar 10, 2026

Fixed the aval of the intermediate results (the softmaxed/sigmoided logits from the forward pass), which are passed as residuals to fused_topk_and_score_function_bwd, to use CompType, which is currently hardcoded to fp32 in transformerengine/common. This prevents this buffer from being read incorrectly when the logits dtype is not fp32. This was observed when integrating the router into MaxText: while the grads flowing into fused_topk_and_score_function_bwd were free of NaNs, the outputs from this kernel were NaNs.
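
For illustration, the shape of the Python-side fix is roughly the following (a simplified sketch, not the actual FusedTopkWithScoreFunctionFwdPrimitive.abstract, whose real signature in cpp_extensions/router.py takes more arguments):

    import jax.numpy as jnp
    from jax.core import ShapedArray

    # Simplified sketch of the forward abstract rule; names are illustrative.
    def fused_topk_fwd_abstract(logits_aval):
        # probs keep the logits dtype (e.g. bfloat16)...
        probs_aval = ShapedArray(logits_aval.shape, logits_aval.dtype)
        routing_map_aval = ShapedArray(logits_aval.shape, jnp.bool_)
        # ...but the residual must match what the CUDA kernel actually writes:
        # CompType, currently hardcoded to fp32 in transformerengine/common.
        # Declaring it as logits_aval.dtype made the backward misread the buffer.
        intermediate_aval = ShapedArray(logits_aval.shape, jnp.float32)
        return probs_aval, routing_map_aval, intermediate_aval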

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Change dtype of intermediate result to CompType

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…ass as residuals to CompType, which is currently fp32. This prevents incorrect reading of this buffer when logits dtype used are not fp32

Signed-off-by: tdophung <tdophung@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented Mar 10, 2026

Greptile Summary

This PR fixes a dtype mismatch bug in the JAX MoE router: the intermediate buffer (holding softmax/sigmoid values saved for the backward pass) was declared by JAX's abstract evaluation as the logits dtype (e.g., bfloat16), while the CUDA kernel always writes it as float32 (CompType). The mismatch caused the backward kernel to misread the buffer, producing NaN gradients at the logits level.
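
To see why this produces garbage rather than merely lower precision, note that reinterpreting the raw bytes of an fp32 buffer under a bf16 dtype scrambles the values entirely. A standalone illustration (plain numpy plus the ml_dtypes package, not TE code):

    import numpy as np
    from ml_dtypes import bfloat16  # assumes ml_dtypes is installed

    # The forward kernel writes fp32 values into the residual buffer...
    buf = np.array([0.25, 1.0, -3.5], dtype=np.float32)

    # ...but a consumer that believes the buffer is bf16 splits every
    # 4-byte fp32 value into two unrelated 2-byte bf16 values.
    misread = buf.view(bfloat16)
    print(misread.shape)  # (6,) -- twice as many elements as were written
    print(misread)        # reinterpreted bit patterns, not rounded values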

Key changes:

  • FusedTopkWithScoreFunctionFwdPrimitive.abstract (router.py): intermediate_aval dtype changed from i_dtype to jnp.float32, with an explanatory comment.
  • FusedTopkWithScoreFunctionForwardFFI (router.cpp): Reads intermediate_buf->element_type(), validates it equals DType::kFloat32 via NVTE_CHECK, then wraps as DType::kFloat32. Previously used the logits dtype.
  • FusedTopkWithScoreFunctionBackwardFFI (router.cpp): Symmetrically adds an NVTE_CHECK for the intermediate dtype, and introduces a separate grad_dtype derived from grad_probs_buf. This also fixes a secondary pre-existing bug where grad_probs and grad_logits tensors were previously wrapped using the intermediate buffer's dtype rather than the actual grad dtype.
  • No regression test is included for the fix, which makes it harder to confirm the bug is resolved and to prevent future regressions. Consider adding a test that exercises the forward and backward passes with bfloat16 logits and verifies that saved_scores is float32 and that the backward grads are NaN-free; a sketch of such a test follows this list.
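
A hedged sketch of what such a regression test could look like, assuming a hypothetical Python-level wrapper named fused_topk_with_score_function (the real entry point and its parameters in transformer_engine/jax may differ):

    import jax
    import jax.numpy as jnp

    # Hypothetical import path and signature, for illustration only.
    from transformer_engine.jax.cpp_extensions import fused_topk_with_score_function

    def test_router_bwd_nan_free_with_bf16_logits():
        logits = jax.random.normal(jax.random.PRNGKey(0), (32, 8)).astype(jnp.bfloat16)

        def loss(x):
            probs, _routing_map = fused_topk_with_score_function(x, topk=2)
            return jnp.sum(probs.astype(jnp.float32))

        grads = jax.grad(loss)(logits)
        # Before this fix, bf16 logits produced NaN grads here even though
        # the incoming grads were NaN-free.
        assert not jnp.any(jnp.isnan(grads.astype(jnp.float32)))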

Confidence Score: 4/5

  • This PR is safe to merge — the fix is minimal, well-reasoned, and correctly aligned across Python aval and C++ kernel layers.
  • The root cause is clearly identified (JAX abstract dtype != CUDA kernel's actual output dtype), the fix is applied consistently in both the Python abstract and the C++ FFI handlers, and the secondary backward bug (wrong dtype for grad tensors) is also addressed. The only concern is the absence of a regression test, which slightly lowers confidence.
  • No files require special attention beyond the missing test coverage noted in the summary.

Important Files Changed

Filename: transformer_engine/jax/cpp_extensions/router.py
Overview: Fixed the intermediate_aval dtype from the logits dtype to jnp.float32 in FusedTopkWithScoreFunctionFwdPrimitive.abstract, with a clear comment explaining the rationale. The backward abstract is unchanged and correctly derives the grad_logits dtype from grad_probs.

Filename: transformer_engine/jax/csrc/extensions/router.cpp
Overview: The forward now validates that the intermediate dtype is float32 via NVTE_CHECK and wraps the buffer as DType::kFloat32. The backward is symmetrically updated: it validates the intermediate is float32, derives a separate grad_dtype from grad_probs_buf, and uses it for both grad tensors, fixing a secondary bug where grad_probs was previously wrapped using the intermediate buffer's (incorrect) dtype.

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Tracing
    participant PyFwd as FwdPrimitive.abstract
    participant CppFwd as ForwardFFI (C++)
    participant PyBwd as BwdPrimitive.abstract
    participant CppBwd as BackwardFFI (C++)

    JAX->>PyFwd: logits_aval (dtype=bf16/fp16)
    PyFwd-->>JAX: probs_aval(bf16), routing_map_aval(bool), intermediate_aval(fp32) ✅ fixed
    JAX->>CppFwd: logits(bf16), intermediate buffer(fp32)
    CppFwd->>CppFwd: NVTE_CHECK intermediate==kFloat32 ✅ new guard
    CppFwd-->>JAX: probs(bf16), routing_map(bool), intermediate(fp32)
    JAX->>PyBwd: intermediate_aval(fp32), grad_probs_aval(bf16)
    PyBwd-->>JAX: grad_logits_aval(bf16)
    JAX->>CppBwd: intermediate(fp32), grad_probs(bf16)
    CppBwd->>CppBwd: NVTE_CHECK intermediate==kFloat32 ✅ new guard
    CppBwd->>CppBwd: grad_dtype = grad_probs dtype ✅ fixed (was intermediate dtype)
    CppBwd-->>JAX: grad_logits(bf16)

Comments Outside Diff (1)

  1. transformer_engine/jax/cpp_extensions/router.py, lines 235-251

    Consider asserting intermediate dtype in backward abstract

    The forward abstract now explicitly declares intermediate_aval as jnp.float32, and the C++ guards it with NVTE_CHECK. However, FusedTopkWithScoreFunctionBwdPrimitive.abstract does not validate intermediate_aval.dtype at the Python level.

    Adding an assertion here would surface dtype mismatches (e.g., if the forward residual is accidentally re-cast before being passed as saved_scores) as a clear Python error rather than a C++ abort deep in the kernel dispatch:

        assert intermediate_aval.dtype == jnp.float32, (
            f"intermediate (saved_scores) must be float32 (CompType); got {intermediate_aval.dtype}. "
            "Check FusedTopkWithScoreFunctionFwdPrimitive.abstract."
        )

Last reviewed commit: 1c791d5

Comment on lines +101 to +112
  // intermediate is always float32 (CompType) regardless of logits dtype.
  auto intermediate_dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type());
  auto grad_dtype = convert_ffi_datatype_to_te_dtype(grad_probs_buf.element_type());
  auto dims = intermediate_buf.dimensions();
  auto num_tokens = static_cast<int>(product(dims, 0, dims.size() - 1));
  auto num_experts = static_cast<int>(dims[dims.size() - 1]);

  auto flat_shape =
      std::vector<size_t>{static_cast<size_t>(num_tokens), static_cast<size_t>(num_experts)};

- auto intermediate_tensor = TensorWrapper(intermediate_buf.untyped_data(), flat_shape, dtype);
- auto grad_probs_tensor = TensorWrapper(grad_probs_buf.untyped_data(), flat_shape, dtype);
- auto grad_logits_tensor = TensorWrapper(grad_logits_buf->untyped_data(), flat_shape, dtype);
+ auto intermediate_tensor =
+     TensorWrapper(intermediate_buf.untyped_data(), flat_shape, intermediate_dtype);
Contributor


Inconsistency in defensive dtype handling vs. forward

The comment on line 101 states // intermediate is always float32 (CompType) regardless of logits dtype., but unlike the forward pass (line 45) which explicitly hardcodes DType::kFloat32 for the intermediate_tensor, the backward reads the dtype dynamically from intermediate_buf.element_type(). After this fix the two are consistent in practice (because the Python aval now declares float32), but the backward's approach is less self-documenting and doesn't self-enforce the invariant.

If the Python aval were ever accidentally reverted, the forward would still correctly treat the buffer as float32, while the backward would silently pick up the wrong dtype from whatever JAX passes.

Consider making the backward match the forward's explicit approach:

  // intermediate is always float32 (CompType) regardless of logits dtype.
  auto intermediate_dtype = DType::kFloat32;
  auto grad_dtype = convert_ffi_datatype_to_te_dtype(grad_probs_buf.element_type());

@tdophung tdophung changed the title from "[JAX] Change dtype of intermediate result aval of fused_topk_and_score_function_fwd to CompType" to "[JAX] Change dtype of intermediate result aval of fused_topk_and_score_function_fwd to fp32" on Mar 10, 2026
@tdophung
Collaborator Author

Since @denera mentioned that CompType will soon be configurable instead of hardcoded to fp32, I wonder if I should just export the CompType to router.cpp and set the aval dtype in cpp_extensions/router.py to that exported CompType, instead of hardcoding it to fp32 like this.

Wdyt @jberchtold-nvidia @phu0ngng ?

@jberchtold-nvidia
Collaborator

> Since @denera mentioned that CompType will soon be configurable instead of hardcoded to fp32, I wonder if I should just export the CompType to router.cpp and set the aval dtype in cpp_extensions/router.py to that exported CompType, instead of hardcoding it to fp32 like this.
>
> Wdyt @jberchtold-nvidia @phu0ngng ?

I have a slight preference to wait on this until the change has been made in TE common and have this PR just be the bugfix to make the intermediate buf always fp32.

If we make this change to have a generic CompType now before TE common is ready, we could get ahead on the boilerplate logic here. However, we'll still need asserts in Python to prevent users from passing a comp dtype other than fp32 since TE common as of now doesn't support it. Then once TE common supports compute dtypes other than fp32, we'll need a follow-up TE/JAX PR to update this assert to the new supported compute dtypes and add unit tests for them.

Given we'll need a follow-up PR here anyway once TE common supports other compute dtypes, I feel like it's easier to have this PR just be the bugfix and wait on generalizing until the TE common changes have merged.

Wdyt?
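
For concreteness, the Python-side guard described above could look roughly like this once a configurable comp dtype exists (a sketch only; the parameter and constant names are hypothetical):

    import jax.numpy as jnp

    # Hypothetical constant; TE common currently supports only fp32.
    SUPPORTED_ROUTER_COMP_DTYPES = (jnp.float32,)

    def _validate_router_comp_dtype(comp_dtype):
        # Fail fast at trace time instead of erroring inside the CUDA kernel.
        assert comp_dtype in SUPPORTED_ROUTER_COMP_DTYPES, (
            f"Router comp dtype must be one of {SUPPORTED_ROUTER_COMP_DTYPES}, "
            f"got {comp_dtype}"
        )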

  auto routing_map_tensor = TensorWrapper(routing_map, flat_shape, DType::kByte);
- auto intermediate_tensor = TensorWrapper(intermediate, flat_shape, dtype);
+ // intermediate is always float32 (CompType) regardless of logits dtype.
+ auto intermediate_tensor = TensorWrapper(intermediate, flat_shape, DType::kFloat32);
Collaborator


Can we add an NVTE_CHECK to assert the dtype of intermediate_buf is float32, just in case something ever breaks on the Python abstract side?

@tdophung
Collaborator Author

> > Since @denera mentioned that CompType will soon be configurable instead of hardcoded to fp32, I wonder if I should just export the CompType to router.cpp and set the aval dtype in cpp_extensions/router.py to that exported CompType, instead of hardcoding it to fp32 like this.
> > Wdyt @jberchtold-nvidia @phu0ngng ?
>
> I have a slight preference to wait on this until the change has been made in TE common and have this PR just be the bugfix to make the intermediate buf always fp32.
>
> If we make this change to have a generic CompType now before TE common is ready, we could get ahead on the boilerplate logic here. However, we'll still need asserts in Python to prevent users from passing a comp dtype other than fp32 since TE common as of now doesn't support it. Then once TE common supports compute dtypes other than fp32, we'll need a follow-up TE/JAX PR to update this assert to the new supported compute dtypes and add unit tests for them.
>
> Given we'll need a follow-up PR here anyway once TE common supports other compute dtypes, I feel like it's easier to have this PR just be the bugfix and wait on generalizing until the TE common changes have merged.
>
> Wdyt?

That makes sense. I did not think about the possibility that users could pass in a different dtype and break the backend kernels, since they are not ready for anything but fp32. It will require another PR anyway. I will keep this as fp32 and add the NVTE_CHECK.

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Collaborator Author

/te-ci jax

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment


LGTM, thanks!

Comment on lines +107 to +110
NVTE_CHECK(convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()) == DType::kFloat32,
           "intermediate_output must be float32 (CompType); got dtype ",
           static_cast<int>(convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type())),
           ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py.");
Contributor


convert_ffi_datatype_to_te_dtype called twice; use a local variable

The backward's NVTE_CHECK calls convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()) twice — once for the comparison and once to format the error integer. This is inconsistent with the forward pass, which stores the result in intermediate_dtype before the check and reuses it. Calling the conversion twice is unnecessary and makes the code harder to read.

Suggested change
- NVTE_CHECK(convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()) == DType::kFloat32,
-            "intermediate_output must be float32 (CompType); got dtype ",
-            static_cast<int>(convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type())),
-            ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py.");
+ auto intermediate_dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type());
+ // intermediate uses CompType (set by the abstract via ROUTER_COMP_DTYPE).
+ NVTE_CHECK(intermediate_dtype == DType::kFloat32,
+            "intermediate_output must be float32 (CompType); got dtype ",
+            static_cast<int>(intermediate_dtype),
+            ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py.");

Collaborator


instead of static_cast, how about using to_string?

NVTE_CHECK(intermediate_dtype == DType::kFloat32,
           "intermediate_output must be float32 (CompType); got dtype ",
           static_cast<int>(intermediate_dtype),
           ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py.");
Contributor


ROUTER_COMP_DTYPE constant does not exist in router.py

Both NVTE_CHECK error messages (here in the forward and on line 110 in the backward) tell the developer to "Check ROUTER_COMP_DTYPE in cpp_extensions/router.py", but no such constant is defined anywhere in the codebase. The dtype is hardcoded directly as jnp.float32 inside FusedTopkWithScoreFunctionFwdPrimitive.abstract. A developer who hits this assertion and searches for ROUTER_COMP_DTYPE will find nothing, making the error harder to diagnose.

Consider updating both messages to point to the actual location:

Suggested change
- ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py.");
+ ". Check FusedTopkWithScoreFunctionFwdPrimitive.abstract in cpp_extensions/router.py.");

And correspondingly on line 110:

". Check FusedTopkWithScoreFunctionBwdPrimitive.abstract in cpp_extensions/router.py."

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Collaborator Author

/te-ci jax

@tdophung
Collaborator Author

/te-ci jax

@tdophung tdophung merged commit 4c5b1a2 into NVIDIA:main Mar 11, 2026
9 of 15 checks passed
