
[JAX] Support calling MOE router kernels from JAX side #2711

Open
tdophung wants to merge 16 commits into NVIDIA:main from tdophung:router_jax

Conversation


@tdophung tdophung commented Feb 26, 2026

Description

The current router kernels are present in common and callable from the PyTorch side, but not from JAX. This PR adds support for calling the routers from JAX, either standalone or for later integration into the MaxText MoE layer.

Fixes #2710

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  1. Add custom calls to router kernels, including:
    a. fused_topk_with_score_function_kernel: the main routing kernel, which outputs a sparse probs matrix for the chosen experts plus a routing map to feed into permutation. It supports two scoring functions (softmax and sigmoid) as well as the group_topk algorithm.
    b. fused_score_for_moe_aux_loss: step 1 of the side path that calculates the auxiliary loss for load balancing. This step calculates the binary routing map and the dense probs matrix over all experts.
    c. fused_moe_aux_loss: step 2 of the auxiliary-loss calculation. It computes the loss $$L_{\text{aux}} = C \cdot \sum_{i=1}^{E} \left(\sum_{t=1}^{T} p_{t,i}\right) \cdot f_i$$ where:
  • $p_{t,i}$ = probability that token $t$ assigns to expert $i$ (from fused_score_for_moe_aux_loss)
  • $f_i$ = number of tokens routed to expert $i$ (tokens_per_expert[i], derived from routing_map.sum(dim=0))
  • $C = \frac{E \cdot \text{coeff}}{K \cdot T^2}$ where $T$ = total_num_tokens, $K$ = topk, $E$ = num_experts
  2. Add custom partitioning for each of the above kernels where possible (sharded on the num-token dimension for the first two kernels, pure replication for the last).

  3. Add tests for both the single-GPU and distributed cases to verify sharding correctness.
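As an unfused reference for what the kernels above compute, here is a minimal NumPy sketch of the softmax top-k routing path and the auxiliary-loss formula. Function and variable names are illustrative only, not the TE API:

```python
import numpy as np

def topk_softmax_routing(logits, topk):
    # Reference for the main routing kernel (softmax path): dense softmax
    # scores, then keep the top-k experts per token as a sparse probs
    # matrix plus a binary routing map.
    scores = np.exp(logits - logits.max(-1, keepdims=True))
    scores /= scores.sum(-1, keepdims=True)
    idx = np.argsort(-scores, axis=-1)[:, :topk]
    routing_map = np.zeros(scores.shape, dtype=bool)
    np.put_along_axis(routing_map, idx, True, axis=-1)
    return np.where(routing_map, scores, 0.0), routing_map

def moe_aux_loss(probs, routing_map, topk, coeff):
    # Reference for the aux loss:
    #   L_aux = C * sum_i (sum_t p_{t,i}) * f_i,  C = E*coeff / (K*T^2)
    T, E = probs.shape
    f = routing_map.sum(axis=0)   # tokens per expert, f_i
    p = probs.sum(axis=0)         # sum over tokens of p_{t,i}
    C = E * coeff / (topk * T * T)
    return C * float(p @ f)
```

The fused kernels compute the same quantities in one pass on the GPU; this sketch is only meant to pin down the math being wrapped.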

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


greptile-apps bot commented Feb 26, 2026

Greptile Summary

This PR adds JAX wrappers for the three MoE router CUDA kernels, making them accessible from the JAX side for integration with frameworks like Maxtext. It introduces XLA FFI handlers in C++, JAX primitive classes with forward/backward/batcher/partition/shardy-rule support, a high-level public API with custom VJP, and comprehensive single-GPU and distributed tests via Shardy.

Key changes and critical findings:

  • New C++ FFI layer (router.cpp): Four handlers wrapping the underlying router kernels. The FusedMoEAuxLossForwardFFI handler passes num_tokens for both total_num_tokens and num_rows without documenting the replication invariant that makes this safe. A clarifying comment would prevent future maintenance issues.

  • Primitive layer (cpp_extensions/router.py): The FusedMoEAuxLossFwdPrimitive.batcher has a correctness issue—it passes the batched probs tensor directly without reshaping, which would cause the kernel to misinterpret batch/token/expert dimensions if vmap is used. Either the batcher must properly handle batching, or it should explicitly reject vmap with a clear error message.

  • Public API (router.py): Clean custom VJP design with proper residual storage, score-function validation, and gradient routing.

  • Tests: Good single-GPU and Shardy-distributed coverage validating that sharded execution on the token dimension is correct.

Confidence Score: 2/5

  • The core functionality for Shardy-distributed routing is well-tested and sound, but the vmap batcher for aux loss contains a correctness bug that would silently produce wrong results if vmap is used.
  • The PR implements the intended Shardy use case correctly (token-dimension sharding with full replication of global reduction ops), with good test coverage. However, the FusedMoEAuxLossFwdPrimitive.batcher is broken: it does not reshape the batched probs tensor before passing it to the kernel, which would cause silent correctness failures under vmap. Although vmap may not be the intended use pattern, the presence of a batcher suggests vmap support was intended. This is a critical issue that must be fixed before merge. Additionally, the total_num_tokens == num_rows invariant in the C++ handler should be documented to prevent future bugs.
  • transformer_engine/jax/cpp_extensions/router.py (vmap batcher correctness) and transformer_engine/jax/csrc/extensions/router.cpp (replication invariant documentation)

Last reviewed commit: 6bfa933


@greptile-apps greptile-apps bot left a comment


9 files reviewed, 1 comment


@tdophung tdophung requested review from jberchtold-nvidia and phu0ngng and removed request for jberchtold-nvidia February 26, 2026 19:35
@tdophung
Collaborator Author

I created this PR without being aware of https://github.com/NVIDIA/TransformerEngine/pull/2385/changes

But after reviewing it, I see that both PRs are doing very similar things. I will start by addressing @phu0ngng's comments on that PR.

)

@staticmethod
def infer_sharding_from_operands(
Collaborator


Remove

Collaborator Author


I cannot remove infer_sharding_from_operands and partition() function just yet, because the BasePrimitive still requires it. I'll leave them here and we can remove them when we remove it from the BasePrimitive

Collaborator


We don't want to remove the partition(), just infer_sharding_from_operands().

Depends on whether #2702 or this PR merges first; we can rebase the other PR.

probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff,
)
residuals = (const_buf, tokens_per_expert, num_rows, num_cols)
return aux_loss.squeeze(), residuals
Collaborator


I wonder if we can make the primitive to return a scalar abstract value output instead of 1D vector, then we don't need to squeeze/reshape after.

Collaborator Author


Do you mean returning the value as a scalar from the C++, or returning the value as a 1D vector from the C++, then in the JAX primitive impl, we squeeze after bind, then return scalar from here?

Collaborator


No, I mean in the abstract of the primitive here https://github.com/tdophung/TransformerEngine/blob/3fdeeef253e5496ceb6ca6de6053d9f619b0c2ba/transformer_engine/jax/cpp_extensions/router.py#L425, could we simply allocate the right output shape so that we don't need to squeeze it later?

aux_loss_aval = probs_aval.update(shape=(), dtype=i_dtype)
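A quick NumPy illustration of the shape bookkeeping under discussion (purely illustrative): allocating the output with shape `()` up front makes the later `squeeze()` unnecessary.

```python
import numpy as np

# Output declared as a length-1 vector forces a squeeze at the call site:
aux_loss_1d = np.zeros((1,), dtype=np.float32)
assert aux_loss_1d.squeeze().shape == ()

# Declaring the abstract value with shape () yields a scalar directly:
aux_loss_scalar = np.zeros((), dtype=np.float32)
assert aux_loss_scalar.shape == ()
```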


greptile-apps bot commented Feb 28, 2026

Additional Comments (1)

transformer_engine/jax/cpp_extensions/router.py, line 27
inconsistent section header formatting - uneven equals signs

# =============================================================================

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


greptile-apps bot commented Feb 28, 2026

Additional Comments (2)

transformer_engine/jax/router.py, line 188
score_function is actually used on line 203

    score_function,

transformer_engine/jax/router.py, line 362
g is actually used on line 366

    g,


greptile-apps bot commented Feb 28, 2026

Additional Comments (1)

transformer_engine/jax/cpp_extensions/router.py, line 30
Section header has inconsistent formatting (space in the middle of equals signs)

# =============================================================================

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +265 to +268
total_num_tokens : int
Total token count for normalization.
num_experts : int
Number of experts.
Collaborator


Can we also infer these two arguments from the input arrays?


Comment on lines +330 to +333
total_num_tokens, # pylint: disable=unused-argument
num_experts, # pylint: disable=unused-argument
topk, # pylint: disable=unused-argument
coeff, # pylint: disable=unused-argument
Collaborator


Minor suggestion: use del args instead of disabling pylint.

expert_bias,
compute_aux_scores,
)
residuals = (routing_map, intermediate_output)
Collaborator


Should we rename intermediate_output to something more meaningful?

Collaborator Author


I'll rename it to saved_scores

Comment on lines +218 to +219
Result_Type grad_probs_buf, // [num_rows, num_cols]
int64_t num_rows, int64_t num_cols) {
Collaborator


Should we infer num_rows, num_cols from grad_probs_buf?
Same for other places.

In general, I would recommend passing the function args only when needed to keep the API clean.

double scaling_factor, int64_t score_function, int64_t compute_aux_scores) {
auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type());
auto dims = logits_buf.dimensions();
auto num_tokens = compute_num_tokens(dims);

tdophung and others added 11 commits March 3, 2026 11:30
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e op

Signed-off-by: tdophung <tdophung@nvidia.com>
…n to make basePrimitive class happy

Signed-off-by: tdophung <tdophung@nvidia.com>
…topk_and_score

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>

greptile-apps bot commented Mar 3, 2026

Additional Comments (2)

transformer_engine/jax/csrc/extensions/pybind.cpp, line 94
The "te_inspect_ffi" key is registered twice in the Registrations() dictionary. Lines 84–85 contain the first registration, but lines 93–94 contain an identical second registration. The second assignment silently overwrites the first in Python dictionaries, resulting in dead code that will confuse maintainers.

Remove the first te_inspect_ffi registration (lines 84–85) and keep only the one after the router registrations:

  // Router
  dict["te_fused_topk_with_score_function_forward_ffi"] =
      EncapsulateFFI(FusedTopkWithScoreFunctionForwardHandler);
  dict["te_fused_topk_with_score_function_backward_ffi"] =
      EncapsulateFFI(FusedTopkWithScoreFunctionBackwardHandler);
  dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler);
  dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler);
  dict["te_inspect_ffi"] =
      pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));

transformer_engine/jax/cpp_extensions/router.py, line 538
The backward shardy rule for FusedMoEAuxLossBwdPrimitive introduces an output dimension num_tokens that does not appear in any input dimension:

"const_buf_one, num_experts, grad_one -> num_tokens num_experts"

The inputs are const_buf_one (shape (1,)), num_experts (shape (num_experts,)), and grad_one (scalar). The output dimension num_tokens has no corresponding input to trace its sharding, which violates Shardy's semantic requirements for sharding rule propagation. While the current partition() function overrides this with PartitionSpec(None, None), the rule is semantically incorrect.

Consider either (1) renaming the output dimension to match an actual input dimension, or (2) explicitly marking it as a non-sharded batch dimension using Shardy's notation for replicated axes.

Signed-off-by: tdophung <tdophung@nvidia.com>

greptile-apps bot commented Mar 3, 2026

Additional Comments (3)

transformer_engine/jax/csrc/extensions/pybind.cpp, line 1754
Duplicate te_inspect_ffi registration

The router handlers were inserted immediately before the existing te_inspect_ffi entry, but the new block accidentally repeats the te_inspect_ffi registration first. This means the key "te_inspect_ffi" is written to dict twice (the second write silently overwrites the first).

The new te_inspect_ffi line (added by this PR) on the first + block should be removed — it was the original line that got accidentally duplicated when the router block was inserted.

  // Router
  dict["te_fused_topk_with_score_function_forward_ffi"] =
      EncapsulateFFI(FusedTopkWithScoreFunctionForwardHandler);
  dict["te_fused_topk_with_score_function_backward_ffi"] =
      EncapsulateFFI(FusedTopkWithScoreFunctionBackwardHandler);
  dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler);
  dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler);
  dict["te_inspect_ffi"] =
      pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));

transformer_engine/jax/cpp_extensions/router.py, line 1456
Wrong PartitionSpec for scalar aux_loss output

aux_loss is declared with shape () (a scalar) in abstract, but the partition function assigns it PartitionSpec(None), which describes a 1-D tensor with a single unreplicated axis. JAX will reject this spec when it validates sharding against the actual shape ().

The scalar output needs an empty PartitionSpec(), while const_buf (shape (1,)) can keep PartitionSpec(None).

        scalar_sharding = NamedSharding(mesh, PartitionSpec())
        const_buf_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_shardings = [scalar_sharding, const_buf_sharding]

transformer_engine/jax/router.py, line 2086
num_groups=1 / group_topk=1 defaults conflict with the -1 disabled sentinel used in tests

The public API defaults num_groups=1 and group_topk=1 and documents 1 as meaning "no grouping". However, test_fused_router.py calls this same function with -1 as the disabled sentinel:

# test_fused_router.py
fused_topk_with_score_function(
    logits_,
    topk,
    use_pre_softmax,
    num_groups if num_groups else -1,   # None  →  -1
    group_topk if group_topk else -1,   # None  →  -1
    ...
)

If the underlying C++ kernel treats any value <= 0 as "grouping disabled" (which the use of -1 strongly implies), then a caller using the default num_groups=1, group_topk=1 will actually trigger the grouped top-k path with 1 group / top-1 per group instead of plain top-k. The defaults should be changed to -1 to match the actual sentinel, and the docstring updated accordingly:

def fused_topk_with_score_function(
    logits: jnp.ndarray,
    topk: int,
    use_pre_softmax: bool = False,
    num_groups: int = -1,
    group_topk: int = -1,
    scaling_factor: float = 1.0,
    score_function: Union[str, ScoreFunction] = ScoreFunction.SOFTMAX,
    expert_bias: Optional[jnp.ndarray] = None,
    compute_aux_scores: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
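One way to keep the public defaults and the kernel sentinel in sync would be a small normalization helper at the API boundary. This is a hypothetical sketch, not part of the PR:

```python
def normalize_group_args(num_groups, group_topk):
    # Hypothetical helper: map any "grouping disabled" value (None or <= 0)
    # onto the single -1 sentinel the kernel expects, so callers and tests
    # cannot disagree about what "disabled" looks like.
    num_groups = -1 if num_groups is None or num_groups <= 0 else num_groups
    group_topk = -1 if group_topk is None or group_topk <= 0 else group_topk
    return num_groups, group_topk
```

With defaults of -1 plus this normalization, the plain top-k path is taken unless the caller explicitly opts into grouping.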

Signed-off-by: tdophung <tdophung@nvidia.com>

greptile-apps bot commented Mar 3, 2026

Additional Comments (1)

transformer_engine/jax/cpp_extensions/router.py, line 204
Shardy rule dimension-name conflict for empty expert_bias

The rule names the second input's dimension num_experts, identical to the last dimension of logits. When expert_bias is None (the common case for both softmax and the un-biased sigmoid path), fused_topk_with_score_function substitutes jnp.empty((0,), dtype=logits.dtype) at line 144 in router.py, giving it shape (0,). Shardy will then attempt to unify num_experts = 0 (from the empty bias) with num_experts = E (from the logits), producing a dimension-name conflict and a runtime error.

The distributed tests (test_distributed_router.py lines 107–111) set jax_use_shardy_partitioner=True and call fused_topk_with_score_function without expert_bias, so this code path is exercised.

Use a distinct symbolic name for the bias dimension so it is never unified with the expert count from logits:

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return (
            "num_tokens num_experts, num_bias -> num_tokens num_experts, num_tokens num_experts,"
            " num_tokens num_experts"
        )

This keeps num_bias independent of num_experts, so sizes of 0 (empty) and E (full bias) can coexist without a conflict.

Signed-off-by: tdophung <tdophung@nvidia.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (4)

transformer_engine/jax/csrc/extensions/router.cpp, line 175
nvte_fused_moe_aux_loss_forward receives the local num_tokens for both the data-shape argument and the normalization-constant T in C = E·coeff/(k·T²). At line 172–175, the call passes num_tokens (derived from probs_dims[0]) for both argument positions 3 and 5. Position 5 should receive the global total token count used in the loss normalization, not the local shard size.

Because FusedMoEAuxLossFwdPrimitive.partition (cpp_extensions/router.py:429–444) propagates the input's sharding directly rather than enforcing replication on probs, a user who passes token-sharded probs will cause the kernel to use the local shard size as T in the denominator. This produces a loss scaled by N² (where N is the DP degree) compared to the correct value, since C = E·coeff/(k·T²) grows quadratically as T shrinks.

Fix: Either (1) pass total_num_tokens as a separate static attribute (similar to the backward pass using num_tokens as a static arg), or (2) explicitly enforce replicated sharding on probs in the partition function to make the invariant explicit:

arg_shardings = (
    NamedSharding(mesh, PartitionSpec(None, None)),  # probs replicated
    arg_infos[1].sharding
)

transformer_engine/jax/cpp_extensions/router.py, line 204
The Shardy sharding rule maps the second input (expert_bias) to dimension label num_experts. However, when no bias is provided (the default), the caller sets expert_bias = jnp.empty((0,), dtype=logits.dtype) at transformer_engine/jax/router.py:132, 144. Shardy infers num_experts = 0 from the empty bias, while logits contributes num_experts = E from the first input label. These conflicting bindings of the same label to different concrete sizes will cause dimension-constraint propagation failures and could lead to compilation errors or silent incorrect sharding when Shardy is enabled.

Fix: Use a distinct label for the bias dimension (e.g., bias_experts) when bias is independent of the expert count, or make the constraint conditional. For example:

return (
    "num_tokens num_experts, bias_experts -> num_tokens num_experts, num_tokens num_experts,"
    " num_tokens num_experts"
)

At minimum, add a test that runs this rule with expert_bias = jnp.empty((0,)) under jax_use_shardy_partitioner=True to verify no errors occur.


tests/jax/test_distributed_router.py, line 88
jax.config.update("jax_use_shardy_partitioner", True) is called unconditionally inside _impl_test at lines 88, 221, and 357 without saving and restoring the original value. JAX configuration is process-global and persists for all subsequent tests in the same process. Any tests that run after these distributed tests will have Shardy unexpectedly enabled, which can cause crashes or incorrect behavior in code that has not been validated with Shardy.

Fix: Save and restore the config value:

original = jax.config.jax_use_shardy_partitioner
jax.config.update("jax_use_shardy_partitioner", True)
try:
    # ... test body ...
finally:
    jax.config.update("jax_use_shardy_partitioner", original)

This pattern should be applied at lines 88, 221, and 357.


transformer_engine/jax/cpp_extensions/router.py, line 544
In Shardy's dimension propagation, every output dimension label must appear in at least one input for sharding to be forward-propagated. Here, num_tokens appears only in the output (grad_probs), not in any of the three inputs (const_buf_one, num_experts, grad_one). Shardy cannot infer any sharding for num_tokens from the inputs, so the dimension is effectively always unsharded. While this is consistent with the partition function using PartitionSpec(None, None), the mismatch between the rule and inputs means Shardy may raise an error or produce unexpected behavior in future JAX versions.

Fix: Use a distinct sentinel label like i, or document why the backward output is always replicated. For example:

return "const_buf_one, num_experts, grad_one -> i num_experts"

Alternatively, if output replication is intentional, consider removing the dimension label entirely or adding a comment explaining the invariant.


tdophung commented Mar 4, 2026

/te-ci jax

Signed-off-by: tdophung <tdophung@nvidia.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

transformer_engine/jax/cpp_extensions/router.py, line 426
Vmap batcher for FusedMoEAuxLossFwdPrimitive mishandles batch dimensions

The batcher passes probs directly without reshaping when called under vmap. When vmap is applied with probs of shape [B, T, E] (batch dim 0), the underlying FFI kernel receives the 3-D tensor and interprets it as:

  • num_tokens = probs_dims[0] → B (batch size, not token count)
  • num_experts = probs_dims[1] → T (token count, not expert count)

This produces incorrect results. While the partition rule enforces full replication for the intended Shardy use case, the presence of a batcher suggests vmap support was intended. Either the batcher should properly reshape the input (or iterate over the batch), or it should raise NotImplementedError with a clear message that vmap is not supported.
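Of the two options, the explicit rejection is the smaller change. A minimal sketch of that shape (an illustrative stub, not the actual TE class):

```python
class FusedMoEAuxLossFwdPrimitiveStub:
    # Illustrative stub: a batcher that refuses vmap outright instead of
    # letting the kernel misread a batched [B, T, E] probs tensor as [T, E].
    @staticmethod
    def batcher(batched_args, batch_dims, **kwargs):
        raise NotImplementedError(
            "vmap is not supported for fused_moe_aux_loss; "
            "apply it per batch element instead."
        )
```

Registering such a batcher turns the silent wrong-answer failure mode into a loud, actionable error at trace time.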


transformer_engine/jax/csrc/extensions/router.cpp, line 175
Document the total_num_tokens == num_rows invariant

The call to nvte_fused_moe_aux_loss_forward passes num_tokens (derived from probs buffer shape) for both the total_num_tokens and num_rows parameters. This is only correct when probs are fully replicated across devices (local token count == global token count).

Please add a comment explaining that this invariant is enforced by the partition() function, which mandates full replication via PartitionSpec(). This will prevent future maintenance issues where someone might bypass the partition rule and introduce a correctness bug in the loss denominator calculation.

Comment on lines +297 to +311
residuals = (const_buf, tokens_per_expert, probs.shape[0])
return aux_loss, residuals


def _fused_moe_aux_loss_bwd(topk, coeff, residuals, g):
del topk, coeff
const_buf, tokens_per_expert, num_tokens = residuals
grad_aux_loss = g.reshape(1)

grad_probs = fused_moe_aux_loss_bwd(
const_buf,
tokens_per_expert,
grad_aux_loss,
num_tokens,
)
Collaborator Author


@phu0ngng num_tokens is the probs.shape[0] that was passed back via the residuals already.

Do you mean something else, like passing probs in residuals and then pass probs.shape[0] in place of num_tokens in the call to fused_moe_aux_loss_bwd?

Collaborator


Just double-checked the code again; I think the above comment is outdated.
Could you rename the remaining num_rows and num_cols?
