[JAX] Support calling MOE router kernels from JAX side #2711
tdophung wants to merge 16 commits into NVIDIA:main
Conversation
Greptile Summary
This PR adds JAX wrappers for the three MoE router CUDA kernels, making them accessible from the JAX side for integration with frameworks like MaxText. It introduces XLA FFI handlers in C++, JAX primitive classes with forward/backward/batcher/partition/Shardy-rule support, a high-level public API with custom VJP, and comprehensive single-GPU and distributed tests via Shardy.
Confidence Score: 2/5
Last reviewed commit: 6bfa933
I created this PR without being aware of https://github.com/NVIDIA/TransformerEngine/pull/2385/changes. But after reviewing it, I see that both PRs are doing very similar things. I will start with addressing @phu0ngng's comments on that PR.
)

@staticmethod
def infer_sharding_from_operands(
I cannot remove infer_sharding_from_operands and the partition() function just yet, because the BasePrimitive still requires them. I'll leave them here, and we can remove them when we remove them from the BasePrimitive.
We don't want to remove the partition(), just infer_sharding_from_operands().
Depends on whether #2702 or this PR merges first; we can rebase the other PR.
transformer_engine/jax/router.py
Outdated
    probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff,
)
residuals = (const_buf, tokens_per_expert, num_rows, num_cols)
return aux_loss.squeeze(), residuals
I wonder if we can make the primitive to return a scalar abstract value output instead of 1D vector, then we don't need to squeeze/reshape after.
Do you mean returning the value as a scalar from the C++, or returning it as a 1D vector from the C++ and then, in the JAX primitive impl, squeezing after bind and returning a scalar from here?
No, I mean in the abstract of the primitive here https://github.com/tdophung/TransformerEngine/blob/3fdeeef253e5496ceb6ca6de6053d9f619b0c2ba/transformer_engine/jax/cpp_extensions/router.py#L425, could we simply allocate the right output shape so that we don't need to squeeze it later?
aux_loss_aval = probs_aval.update(shape=(), dtype=i_dtype)
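Illustrating the suggestion with a small numpy sketch (the variable names here are illustrative, not the primitive's actual code): declaring the output with an empty shape () up front removes the need to squeeze a length-1 vector after bind.

```python
import numpy as np

# Sketch of the suggestion: if the primitive's abstract value is declared
# with an empty shape (), the result is already a scalar and the caller
# no longer needs .squeeze() afterwards.
aux_loss_1d = np.zeros((1,), dtype=np.float32)    # current: 1-D output
aux_loss_scalar = np.zeros((), dtype=np.float32)  # proposed: scalar output

assert aux_loss_1d.squeeze().shape == ()          # extra step needed today
assert aux_loss_scalar.shape == ()                # no squeeze needed
```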
Additional Comments (1)
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
transformer_engine/jax/router.py
Outdated
total_num_tokens : int
    Total token count for normalization.
num_experts : int
    Number of experts.
Can we also infer these two arguments from the input arrays?
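A hedged sketch of how the two scalars could be inferred from the input arrays (the helper name is hypothetical; the real code would read the shapes inside the abstract/impl rules):

```python
import numpy as np

def infer_router_dims(probs, tokens_per_expert):
    # total_num_tokens is the leading dimension of probs;
    # num_experts is the length of tokens_per_expert.
    return probs.shape[0], tokens_per_expert.shape[0]

probs = np.zeros((128, 8), dtype=np.float32)        # [num_tokens, num_experts]
tokens_per_expert = np.zeros((8,), dtype=np.int32)  # [num_experts]
assert infer_router_dims(probs, tokens_per_expert) == (128, 8)
```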
transformer_engine/jax/router.py
Outdated
total_num_tokens,  # pylint: disable=unused-argument
num_experts,  # pylint: disable=unused-argument
topk,  # pylint: disable=unused-argument
coeff,  # pylint: disable=unused-argument
Minor suggestion: use del args instead of disabling pylint.
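A minimal sketch of the suggested pattern (the surrounding function and its arguments are hypothetical): deleting the unused arguments at the top of the function keeps pylint quiet without per-argument disable comments.

```python
def _aux_loss_fwd_rule(probs_sum, total_num_tokens, num_experts, topk, coeff):
    # The static args are only needed by the primitive itself; delete them
    # here instead of sprinkling pylint disables on each parameter.
    del total_num_tokens, num_experts, topk, coeff
    return probs_sum * 2.0

assert _aux_loss_fwd_rule(1.5, 128, 8, 2, 0.01) == 3.0
```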
transformer_engine/jax/router.py
Outdated
    expert_bias,
    compute_aux_scores,
)
residuals = (routing_map, intermediate_output)
Should we rename intermediate_output to something more meaningful?
I'll rename it to saved_scores
Result_Type grad_probs_buf,  // [num_rows, num_cols]
int64_t num_rows, int64_t num_cols) {
Should we infer num_rows, num_cols from grad_probs_buf?
Same for other places.
In general, I would recommend passing the function args only when needed to keep the API clean.
double scaling_factor, int64_t score_function, int64_t compute_aux_scores) {
  auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type());
  auto dims = logits_buf.dimensions();
  auto num_tokens = compute_num_tokens(dims);
Also, we have a product function for this https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/csrc/extensions/ffi.h#L93-L98
Additional Comments (2)
Remove the first …
The Shardy rule "const_buf_one, num_experts, grad_one -> num_tokens num_experts" names an output dimension that does not appear among the inputs. Consider either (1) renaming the output dimension to match an actual input dimension, or (2) explicitly marking it as a non-sharded batch dimension using Shardy's notation for replicated axes.
Additional Comments (3)
The router handlers were inserted immediately before the existing … The new …
The scalar output needs an empty …
The public API defaults num_groups and group_topk to None; the tests map None to -1 before calling:

# test_fused_router.py
fused_topk_with_score_function(
    logits_,
    topk,
    use_pre_softmax,
    num_groups if num_groups else -1,  # None → -1
    group_topk if group_topk else -1,  # None → -1
    ...
)

If the underlying C++ kernel treats any value …
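One way to keep the sentinel mapping in a single place is a tiny helper (the name is hypothetical, not from the PR). Note that a plain truthiness check such as `num_groups if num_groups else -1` also maps a legitimate 0 to -1, so an explicit `is None` test is safer:

```python
def to_sentinel(value, sentinel=-1):
    # Map Python None to the kernel's integer sentinel; leave everything
    # else (including a legitimate 0) untouched.
    return sentinel if value is None else value

assert to_sentinel(None) == -1
assert to_sentinel(4) == 4
assert to_sentinel(0) == 0   # a truthiness check would wrongly yield -1 here
```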
Additional Comments (1)
The rule names the second input's dimension …
The distributed tests (…)
Use a distinct symbolic name for the bias dimension so it is never unified with the expert count from …
This keeps …
Additional Comments (4)
Because … Fix: Either (1) pass …

Fix: Use a distinct label for the bias dimension (e.g., …). At minimum, add a test that runs this rule with …

Fix: Save and restore the config value. This pattern should be applied at lines 88, 221, and 357.

Fix: Use a distinct sentinel label like …. Alternatively, if output replication is intentional, consider removing the dimension label entirely or adding a comment explaining the invariant.
/te-ci jax
Additional Comments (2)
The batcher passes …
This produces incorrect results. While the partition rule enforces full replication for the intended Shardy use case, the presence of a batcher suggests vmap support was intended. Either the batcher should properly reshape the input (or iterate over the batch), or it should raise an error.
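If proper vmap support is out of scope for this PR, one option (a sketch, not the PR's actual code) is a batcher that fails loudly rather than silently producing wrong results from an unreshaped batched input:

```python
def fused_moe_aux_loss_batcher(batched_args, batch_dims, **params):
    # Hypothetical batcher: refuse vmap outright instead of mis-shaping
    # the batched input and returning incorrect results.
    raise NotImplementedError(
        "vmap over fused_moe_aux_loss is not supported; "
        "flatten the batch dimension into the token dimension first."
    )

try:
    fused_moe_aux_loss_batcher((), ())
except NotImplementedError as e:
    assert "vmap" in str(e)
```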
The call to …
Please add a comment explaining that this invariant is enforced by the …
    residuals = (const_buf, tokens_per_expert, probs.shape[0])
    return aux_loss, residuals


def _fused_moe_aux_loss_bwd(topk, coeff, residuals, g):
    del topk, coeff
    const_buf, tokens_per_expert, num_tokens = residuals
    grad_aux_loss = g.reshape(1)

    grad_probs = fused_moe_aux_loss_bwd(
        const_buf,
        tokens_per_expert,
        grad_aux_loss,
        num_tokens,
    )
@phu0ngng num_tokens is the probs.shape[0] that was passed back via the residuals already.
Do you mean something else, like passing probs in residuals and then pass probs.shape[0] in place of num_tokens in the call to fused_moe_aux_loss_bwd?
Just double-check the code again. I think the above comment is just outdated.
Could you rename the remaining num_rows and num_cols?
Description
The router kernels currently live in common and are callable from the PyTorch side but not from JAX. This PR supports calling the router kernels from JAX, either standalone or for later integration into the MaxText MoE layer.
Fixes #2710
Type of change
Changes
Please list the changes introduced in this PR:
a. fused_topk_with_score_function_kernel: the main routing kernel, which outputs a sparse probs matrix for the chosen experts plus a routing map to feed into permutation. It supports two scoring functions (softmax and sigmoid) as well as the group_topk algorithm.
b. fused_score_for_moe_aux_loss: step 1 of the side path that calculates the auxiliary loss for load balancing. This step calculates the binary routing map and the dense probs matrix for every expert.
c. fused_moe_aux_loss: step 2 of calculating the auxiliary loss. This computes the loss from the outputs of fused_score_for_moe_aux_loss (tokens_per_expert[i], derived from routing_map.sum(dim=0)) together with total_num_tokens and num_experts.

Add custom partitioning for each of the above kernels where possible (sharded on the num-token dimension for the first two kernels, and pure replication for the last kernel).
Add tests for both the single-GPU and distributed cases to verify sharding correctness.
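As a rough, hedged reference for what the first kernel computes (softmax path only; the names and exact semantics below are simplified from the description above, not taken from the kernel source):

```python
import numpy as np

def topk_softmax_routing(logits, topk):
    # Simplified reference: softmax the logits, keep the top-k experts per
    # token, zero out the rest, and emit a binary routing map.
    scores = np.exp(logits - logits.max(axis=-1, keepdims=True))
    scores /= scores.sum(axis=-1, keepdims=True)
    top_idx = np.argsort(scores, axis=-1)[:, -topk:]
    routing_map = np.zeros_like(scores, dtype=bool)
    np.put_along_axis(routing_map, top_idx, True, axis=-1)
    sparse_probs = np.where(routing_map, scores, 0.0)
    return sparse_probs, routing_map

logits = np.array([[0.1, 2.0, 0.3, 1.5]])
probs, rmap = topk_softmax_routing(logits, topk=2)
assert rmap.sum() == 2            # exactly top-k experts selected per token
assert probs[~rmap].sum() == 0.0  # probs are sparse outside the routing map
```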
Checklist: