[JAX] Change dtype of intermediate result aval of fused_topk_and_score_function_fwd to fp32#2752
Conversation
…ass as residuals to CompType, which is currently fp32. This prevents incorrect reading of this buffer when the logits dtype used is not fp32.
Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile Summary

This PR fixes a dtype mismatch bug in the JAX MoE router. Key changes:

- The forward primitive's abstract now declares the intermediate residual aval as fp32 (CompType) instead of inheriting the logits dtype.
- New NVTE_CHECK guards in the forward and backward FFI assert that the intermediate buffer is fp32.
- The backward now derives the grad_logits dtype from grad_probs rather than from the intermediate buffer.
Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
participant JAX as JAX Tracing
participant PyFwd as FwdPrimitive.abstract
participant CppFwd as ForwardFFI (C++)
participant PyBwd as BwdPrimitive.abstract
participant CppBwd as BackwardFFI (C++)
JAX->>PyFwd: logits_aval (dtype=bf16/fp16)
PyFwd-->>JAX: probs_aval(bf16), routing_map_aval(bool), intermediate_aval(fp32) ✅ fixed
JAX->>CppFwd: logits(bf16), intermediate buffer(fp32)
CppFwd->>CppFwd: NVTE_CHECK intermediate==kFloat32 ✅ new guard
CppFwd-->>JAX: probs(bf16), routing_map(bool), intermediate(fp32)
JAX->>PyBwd: intermediate_aval(fp32), grad_probs_aval(bf16)
PyBwd-->>JAX: grad_logits_aval(bf16)
JAX->>CppBwd: intermediate(fp32), grad_probs(bf16)
CppBwd->>CppBwd: NVTE_CHECK intermediate==kFloat32 ✅ new guard
CppBwd->>CppBwd: grad_dtype = grad_probs dtype ✅ fixed (was intermediate dtype)
CppBwd-->>JAX: grad_logits(bf16)
```
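The "✅ fixed" step on the forward side is the heart of the PR: the primitive's abstract must declare the intermediate residual with fp32 rather than the logits dtype. A minimal sketch of that shape/dtype logic, with a simplified signature (the real `FusedTopkWithScoreFunctionFwdPrimitive.abstract` takes more arguments) and assuming a JAX version where `jax.core.ShapedArray` is available:

```python
import jax.numpy as jnp
from jax import core

def fused_topk_fwd_abstract(logits_aval):
    # probs follow the logits dtype; routing_map is boolean.
    probs_aval = core.ShapedArray(logits_aval.shape, logits_aval.dtype)
    routing_map_aval = core.ShapedArray(logits_aval.shape, jnp.bool_)
    # The fix: the intermediate residual is always fp32 (CompType),
    # regardless of whether logits are bf16/fp16/fp32.
    intermediate_aval = core.ShapedArray(logits_aval.shape, jnp.float32)
    return probs_aval, routing_map_aval, intermediate_aval
```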
```diff
+  // intermediate is always float32 (CompType) regardless of logits dtype.
+  auto intermediate_dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type());
+  auto grad_dtype = convert_ffi_datatype_to_te_dtype(grad_probs_buf.element_type());
   auto dims = intermediate_buf.dimensions();
   auto num_tokens = static_cast<int>(product(dims, 0, dims.size() - 1));
   auto num_experts = static_cast<int>(dims[dims.size() - 1]);

   auto flat_shape =
       std::vector<size_t>{static_cast<size_t>(num_tokens), static_cast<size_t>(num_experts)};

-  auto intermediate_tensor = TensorWrapper(intermediate_buf.untyped_data(), flat_shape, dtype);
-  auto grad_probs_tensor = TensorWrapper(grad_probs_buf.untyped_data(), flat_shape, dtype);
-  auto grad_logits_tensor = TensorWrapper(grad_logits_buf->untyped_data(), flat_shape, dtype);
+  auto intermediate_tensor =
+      TensorWrapper(intermediate_buf.untyped_data(), flat_shape, intermediate_dtype);
```
Inconsistency in defensive dtype handling vs. forward
The comment on line 101 states "// intermediate is always float32 (CompType) regardless of logits dtype.", but unlike the forward pass (line 45), which explicitly hardcodes DType::kFloat32 for the intermediate_tensor, the backward reads the dtype dynamically from intermediate_buf.element_type(). After this fix the two are consistent in practice (because the Python aval now declares float32), but the backward's approach is less self-documenting and doesn't enforce the invariant itself.
If the Python aval were ever accidentally reverted, the forward would still correctly treat the buffer as float32, while the backward would silently pick up the wrong dtype from whatever JAX passes.
Consider making the backward match the forward's explicit approach:
```cpp
// intermediate is always float32 (CompType) regardless of logits dtype.
auto intermediate_dtype = DType::kFloat32;
auto grad_dtype = convert_ffi_datatype_to_te_dtype(grad_probs_buf.element_type());
```
Since @denera mentioned that CompType will soon be configurable instead of hardcoded to fp32, I wonder if I should just export the CompType to … Wdyt @jberchtold-nvidia @phu0ngng?
I have a slight preference to wait on this until the change has been made in TE common, and have this PR just be the bugfix that makes the intermediate buf always fp32. If we make the change to a generic CompType now, before TE common is ready, we could get ahead on the boilerplate logic here. However, we'd still need asserts in Python to prevent users from passing a comp dtype other than fp32, since TE common doesn't support anything else yet. Then, once TE common supports other compute dtypes, we'd need a follow-up TE/JAX PR to update that assert to the new supported compute dtypes and add unit tests for them. Given we'll need a follow-up PR here anyway, I feel it's easier to have this PR just be the bugfix and wait on the generalization until the TE common changes have merged. Wdyt?
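For concreteness, the interim Python-side guard discussed above might look like the following sketch; the names `SUPPORTED_COMP_DTYPES` and `validate_comp_dtype` are hypothetical, not from this PR:

```python
import jax.numpy as jnp

# Hypothetical allow-list: widen this once TE common supports
# compute dtypes other than fp32.
SUPPORTED_COMP_DTYPES = (jnp.float32,)

def validate_comp_dtype(comp_dtype):
    # Reject compute dtypes the TE common router kernels cannot handle yet.
    if comp_dtype not in SUPPORTED_COMP_DTYPES:
        raise ValueError(
            f"Router compute dtype {comp_dtype} is not supported yet; "
            f"TE common currently only implements fp32.")
```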
```diff
   auto routing_map_tensor = TensorWrapper(routing_map, flat_shape, DType::kByte);
-  auto intermediate_tensor = TensorWrapper(intermediate, flat_shape, dtype);
+  // intermediate is always float32 (CompType) regardless of logits dtype.
+  auto intermediate_tensor = TensorWrapper(intermediate, flat_shape, DType::kFloat32);
```
Can we add an NVTE_CHECK to assert that the dtype of intermediate_buf is float32, just in case something ever breaks on the Python abstract side?
That makes sense. I did not think about the possibility that users could pass in a different dtype and mess up the backend kernels, since they are not ready for anything but fp32. It will require another PR anyway. I will keep this as fp32 and add the NVTE_CHECK.
Signed-off-by: tdophung <tdophung@nvidia.com>
/te-ci jax
jberchtold-nvidia left a comment
LGTM, thanks!
```cpp
NVTE_CHECK(convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()) == DType::kFloat32,
           "intermediate_output must be float32 (CompType); got dtype ",
           static_cast<int>(convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type())),
           ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py.");
```
convert_ffi_datatype_to_te_dtype called twice; use a local variable
The backward's NVTE_CHECK calls convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()) twice — once for the comparison and once to format the error integer. This is inconsistent with the forward pass, which stores the result in intermediate_dtype before the check and reuses it. Calling the conversion twice is unnecessary and makes the code harder to read.
```diff
-NVTE_CHECK(convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()) == DType::kFloat32,
-           "intermediate_output must be float32 (CompType); got dtype ",
-           static_cast<int>(convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type())),
-           ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py.");
+auto intermediate_dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type());
+// intermediate uses CompType (set by the abstract via ROUTER_COMP_DTYPE).
+NVTE_CHECK(intermediate_dtype == DType::kFloat32,
+           "intermediate_output must be float32 (CompType); got dtype ",
+           static_cast<int>(intermediate_dtype),
+           ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py.");
```
Instead of static_cast, how about using to_string?
```cpp
NVTE_CHECK(intermediate_dtype == DType::kFloat32,
           "intermediate_output must be float32 (CompType); got dtype ",
           static_cast<int>(intermediate_dtype),
           ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py.");
```
ROUTER_COMP_DTYPE constant does not exist in router.py
Both NVTE_CHECK error messages (here in the forward and on line 110 in the backward) tell the developer to "Check ROUTER_COMP_DTYPE in cpp_extensions/router.py", but no such constant is defined anywhere in the codebase. The dtype is hardcoded directly as jnp.float32 inside FusedTopkWithScoreFunctionFwdPrimitive.abstract. A developer who hits this assertion and searches for ROUTER_COMP_DTYPE will find nothing, making the error harder to diagnose.
Consider updating both messages to point to the actual location:
| ". Check ROUTER_COMP_DTYPE in cpp_extensions/router.py."); | |
| ". Check FusedTopkWithScoreFunctionFwdPrimitive.abstract in cpp_extensions/router.py."); |
And correspondingly on line 110:
```cpp
". Check FusedTopkWithScoreFunctionBwdPrimitive.abstract in cpp_extensions/router.py."
```
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci jax
Fixed the aval for intermediate results (softmaxed/sigmoided logits from fwd), passed as residuals to fused_topk_and_score_function_bwd, to CompType, which is currently hardcoded to fp32 in transformer_engine/common. This prevents incorrect reading of this buffer when the logits dtype used is not fp32. This was observed when integrating the router into maxtext: while the grads flowing into fused_topk_and_score_function_bwd were clean of NaNs, the outputs from this kernel were NaNs.

Fixes # (issue)
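To see one way this kind of mismatch corrupts values, consider 16-bit data reinterpreted as 32-bit floats. This is a standalone illustration of the failure mode, not TE code:

```python
import numpy as np
import jax.numpy as jnp

# Eight well-behaved bf16 probabilities occupy 16 bytes...
probs = jnp.linspace(0.0, 1.0, 8, dtype=jnp.bfloat16)
raw = np.asarray(probs).tobytes()

# ...but read back as fp32 they become four meaningless values
# (denormals, huge magnitudes, or NaNs, depending on the bit patterns).
misread = np.frombuffer(raw, dtype=np.float32)
print(misread)
```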
Type of change
Changes
Change dtype of intermediate result to CompType
Checklist: