Bias Prob Scaling for GroupedLinear and Fused MOE Layers #2864
vthumbe1503 merged 8 commits into NVIDIA:main
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR adds per-token bias probability scaling (bias * probs) for GroupedLinear and the fused MoE layers.

Confidence Score: 5/5 — Safe to merge; all remaining findings are minor P2 style suggestions with no correctness impact on the actual usage paths. The bias-prob-scaling math is correct end-to-end (forward bias * probs, backward dbias via sum(dy * probs) and dprobs via sum(dy * bias)). The Triton kernel is logically sound: masked tiles handle empty/short groups, atomic-adds are used correctly for multi-CTA accumulation, and both outputs (dbias, dscales) are produced in float32. Prior P1 concerns (shared tensor assumption, dscales dtype assertion) have been acknowledged/addressed. The only open items are two P2 style suggestions (contiguity enforcement, explicit float32 cast for scales_vals).

Important Files Changed: transformer_engine/common/triton/grouped_dbias_dscales.py — two minor style suggestions around memory-layout safety and dtype consistency.
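As a rough reference for the math summarized above, here is a minimal PyTorch sketch of the per-token scaled-bias term and its gradients for a single expert group (names and shapes are illustrative only; this is not the PR's actual kernel):

```python
import torch

tokens, hidden = 8, 16
dy = torch.randn(tokens, hidden)   # upstream gradient
probs = torch.randn(tokens)        # per-token routing probabilities
bias = torch.randn(hidden)         # this group's bias

# Forward: each token's bias contribution is scaled by its probability,
# i.e. out[i] = (x[i] @ W.T) + bias * probs[i].

# Backward contributions of the scaled-bias term:
dbias = (dy * probs[:, None]).sum(dim=0)   # sum_i dy_i * prob_i   -> (hidden,)
dprobs = (dy * bias[None, :]).sum(dim=1)   # sum_j dy_ij * bias_j  -> (tokens,)
```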
Sequence Diagram

```mermaid
sequenceDiagram
participant User
participant GroupedLinear_FC2
participant FusedFwd as ForwardGroupedMLP (MXFP8)
participant TritonKernel as _grouped_dbias_dscales_kernel
participant FusedBwd as BackwardGroupedMLP (MXFP8)
User->>GroupedLinear_FC2: forward(x, split_sizes, probs)
alt scale_bias=True (non-fused path)
GroupedLinear_FC2->>GroupedLinear_FC2: out = xW^T
GroupedLinear_FC2->>GroupedLinear_FC2: out += bias * probs (per-token)
GroupedLinear_FC2->>GroupedLinear_FC2: save(split_sizes, probs, xs, ws)
end
alt MXFP8 fused path
User->>FusedFwd: forward(x, split_sizes, probs_swiglu, split_sizes, probs_fc2)
FusedFwd->>FusedFwd: FC1 GEMM + SwiGLU (scaled by probs_swiglu)
FusedFwd->>FusedFwd: FC2 GEMM with prob_tensor=probs_fc2
end
User->>GroupedLinear_FC2: backward(dy)
alt scale_bias=True (non-fused path)
GroupedLinear_FC2->>TritonKernel: _compute_grouped_dbias_dscales(dy, probs, bias, offsets)
TritonKernel-->>GroupedLinear_FC2: dbias (num_groups, hidden), dprobs (total_tokens)
GroupedLinear_FC2-->>User: grad_input, grad_weights, grad_biases, grad_probs
end
alt MXFP8 fused path
User->>FusedBwd: backward(dy)
FusedBwd->>FusedBwd: FC2 DGLU kernel -> dprob_tensor (grad via SwiGLU)
FusedBwd->>TritonKernel: _compute_grouped_dbias_dscales(fc2_dy, probs, bias, offsets, dscales=dprob_tensor)
TritonKernel-->>FusedBwd: fc2_dbias, grad_scales (accumulated)
FusedBwd-->>User: grad_input, grad_weights, fc2_dbias, grad_scales
end
```
Reviews (5): Last reviewed commit: "Update docstring"
```diff
  fc2_grad_extra = (None, None) if fc2_op._scale_bias else (None,)
  return (
      grad_input,
      [fc1_grad_params, (), fc2_grad_params],
-     [(None,), (grad_scales,), (None,)],
+     [(None,), (grad_scales,), fc2_grad_extra],
```
FC2 scales gradient silently routed through SwiGLU path — undocumented assumption
When fc2_op._scale_bias is True, fc2_grad_extra returns (None, None) — the gradient for FC2's second extra input (its routing-probability tensor) is not propagated here. Instead, the full grad_scales (SwiGLU prob gradient + FC2 bias contribution accumulated by _compute_grouped_dbias_dscales) is returned via (grad_scales,) in the SwiGLU slot.
This is only correct when the SwiGLU probability tensor and the FC2 probability tensor are the same Python object. The test enforces this by passing probs_test to both, but there is no runtime assertion. If a caller passes distinct tensors, the FC2-prob gradient is silently dropped. Consider adding a guard or documenting this constraint in the class docstring.
This is true. If we were to check for this rigorously, we could check that the data pointers are the same between the SwiGLU and FC2. We would also want to check that FC1 and FC2 are getting the same splits. I don't think this is a blocker though.
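A runtime guard along the lines suggested above might look like this (a sketch only; `probs_swiglu` and `probs_fc2` are hypothetical names for the two probability inputs, and the check is not part of this PR):

```python
import torch

def check_shared_prob_tensor(probs_swiglu: torch.Tensor, probs_fc2: torch.Tensor) -> None:
    # Hypothetical guard: the backward pass routes the FC2 probability gradient
    # through the SwiGLU slot, so both ops must receive the exact same tensor,
    # otherwise the FC2-prob gradient would be silently dropped.
    assert probs_swiglu.data_ptr() == probs_fc2.data_ptr(), (
        "SwiGLU and FC2 probability inputs must be the same tensor"
    )
```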
```python
if scale_bias:
    fc2_biases = fc2_op._get_bias_tensors(dtype)
    bias_packed = torch.stack(fc2_biases)
    scales_f32 = scales.detach().to(dtype=torch.float32)
    fc2_dbias_packed_result, grad_scales = _compute_grouped_dbias_dscales(
        fc2_dy,
        scales_f32,
        bias_packed,
        split_sizes,
        offsets=fc1_ctx.base_split_offsets,
        dscales=grad_scales,
    )
```
scales here is the SwiGLU probability, not the FC2-specific probability
scales is restored from swiglu_ctx.saved_tensors (line 298) — it is the probability saved during the FC1/SwiGLU forward, not the FC2-specific routing probability. For the dbias/dscales computation of FC2 the FC2 probability should be used (d_bias2_g = Σ_i dy_i · prob2_i).
This works correctly today because the test always passes the same probs_test tensor to both ops, but it is an undocumented constraint. A comment explaining the shared-tensor assumption would help future maintainers.
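For reference, a plain-PyTorch sketch of the per-group reductions this comment refers to (d_bias_g = Σ_i dy_i · prob_i within group g); the names are assumed and this is only a readability aid, not the Triton kernel:

```python
import torch

def grouped_dbias_dscales_reference(dy, probs, bias_packed, split_sizes):
    # dy:          (total_tokens, hidden) upstream gradient into the FC2 output
    # probs:       (total_tokens,) per-token probabilities
    # bias_packed: (num_groups, hidden) stacked per-expert biases
    # split_sizes: tokens per group, summing to total_tokens
    dbias = torch.zeros_like(bias_packed, dtype=torch.float32)
    dscales = torch.zeros(dy.shape[0], dtype=torch.float32, device=dy.device)
    start = 0
    for g, n in enumerate(split_sizes):
        dy_g = dy[start:start + n].float()
        dbias[g] = (dy_g * probs[start:start + n, None]).sum(dim=0)
        dscales[start:start + n] = (dy_g * bias_packed[g].float()).sum(dim=1)
        start += n
    return dbias, dscales
```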
```python
if dbias is None:
    dbias = torch.zeros(num_groups, hidden, dtype=torch.float32, device=dy.device)
if dscales is None:
    dscales = torch.zeros(total_tokens, dtype=torch.float32, device=dy.device)
```
No dtype validation for pre-allocated dscales

When dscales is provided by the caller it is passed directly to tl.atomic_add in the Triton kernel, which requires a float32 pointer. If a caller passes a non-float32 tensor, the atomic add silently corrupts the output. Consider adding an assertion:
```python
if dscales is None:
    dscales = torch.zeros(total_tokens, dtype=torch.float32, device=dy.device)
else:
    assert dscales.dtype == torch.float32, (
        f"_compute_grouped_dbias_dscales: dscales must be float32, got {dscales.dtype}"
    )
```

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: