Skip to content

Bias Prob Scaling for GroupedLinear and Fused MOE Layers#2864

Merged
vthumbe1503 merged 8 commits intoNVIDIA:mainfrom
vthumbe1503:grouped_bias_add
Apr 10, 2026
Merged

Bias Prob Scaling for GroupedLinear and Fused MOE Layers#2864
vthumbe1503 merged 8 commits intoNVIDIA:mainfrom
vthumbe1503:grouped_bias_add

Conversation

@vthumbe1503
Copy link
Copy Markdown
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title bias*prob, dbias+dprob triton kernel Bias Prob Scaling for GroupedLinear and FusedGrouped Layers Apr 9, 2026
vthumbe1503 and others added 2 commits April 9, 2026 15:37
@vthumbe1503 vthumbe1503 requested a review from timmoon10 April 9, 2026 22:49
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 9, 2026

Greptile Summary

This PR adds per-token bias probability scaling (scale_bias=True) to GroupedLinear and the fused MoE layers: when enabled, the forward pass computes y = xW^T + bias * probs and a new fused Triton kernel (_grouped_dbias_dscales_kernel) computes dbias and dprobs together in a single backward pass. It also enables MXFP8 + bias support in the fused grouped MLP by removing the previous skip condition.

Confidence Score: 5/5

Safe to merge; all remaining findings are minor P2 style suggestions with no correctness impact on the actual usage paths.

The bias-prob-scaling math is correct end-to-end (forward biasprobs, backward dbias via sum(dyprobs) and dprobs via sum(dy*bias)). The Triton kernel is logically sound: masked tiles handle empty/short groups, atomic-adds are used correctly for multi-CTA accumulation, and both outputs (dbias, dscales) are produced in float32. Prior P1 concerns (shared tensor assumption, dscales dtype assertion) have been acknowledged/addressed. The only open items are two P2 style suggestions (contiguity enforcement, explicit float32 cast for scales_vals).

transformer_engine/common/triton/grouped_dbias_dscales.py — two minor style suggestions around memory-layout safety and dtype consistency.

Important Files Changed

Filename Overview
transformer_engine/common/triton/grouped_dbias_dscales.py New Triton kernel computing fused dbias+dscales; kernel logic is correct but assumes C-contiguous input layout without enforcing it.
transformer_engine/pytorch/triton/grouped_dbias_dscales.py Python wrapper for the Triton kernel; adds dtype assertions for dbias/dscales, clean API, correct grid computation.
transformer_engine/pytorch/ops/basic/grouped_linear.py Adds scale_bias parameter; forward applies bias*probs per-token, backward calls fused kernel for dbias/dprobs correctly.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py MXFP8 backward path updated to accumulate dprob from GEMM kernel into fused dbias/dscales computation; fc2_grad_extra correctly sized.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py MXFP8 forward path now propagates fc2_scales (routing probs) into the FC2 GEMM prob_tensor instead of a constant-ones tensor.
tests/pytorch/test_fusible_ops.py Reference computation updated to match bias*probs semantics; mxfp8+bias skip removed; fc2 extra inputs correctly conditioned on bias.

Sequence Diagram

sequenceDiagram
    participant User
    participant GroupedLinear_FC2
    participant FusedFwd as ForwardGroupedMLP (MXFP8)
    participant TritonKernel as _grouped_dbias_dscales_kernel
    participant FusedBwd as BackwardGroupedMLP (MXFP8)

    User->>GroupedLinear_FC2: forward(x, split_sizes, probs)
    alt scale_bias=True (non-fused path)
        GroupedLinear_FC2->>GroupedLinear_FC2: out = xW^T
        GroupedLinear_FC2->>GroupedLinear_FC2: out += bias * probs (per-token)
        GroupedLinear_FC2->>GroupedLinear_FC2: save(split_sizes, probs, xs, ws)
    end
    alt MXFP8 fused path
        User->>FusedFwd: forward(x, split_sizes, probs_swiglu, split_sizes, probs_fc2)
        FusedFwd->>FusedFwd: FC1 GEMM + SwiGLU (scaled by probs_swiglu)
        FusedFwd->>FusedFwd: FC2 GEMM with prob_tensor=probs_fc2
    end

    User->>GroupedLinear_FC2: backward(dy)
    alt scale_bias=True (non-fused path)
        GroupedLinear_FC2->>TritonKernel: _compute_grouped_dbias_dscales(dy, probs, bias, offsets)
        TritonKernel-->>GroupedLinear_FC2: dbias (num_groups, hidden), dprobs (total_tokens)
        GroupedLinear_FC2-->>User: grad_input, grad_weights, grad_biases, grad_probs
    end
    alt MXFP8 fused path
        User->>FusedBwd: backward(dy)
        FusedBwd->>FusedBwd: FC2 DGLU kernel -> dprob_tensor (grad via SwiGLU)
        FusedBwd->>TritonKernel: _compute_grouped_dbias_dscales(fc2_dy, probs, bias, offsets, dscales=dprob_tensor)
        TritonKernel-->>FusedBwd: fc2_dbias, grad_scales (accumulated)
        FusedBwd-->>User: grad_input, grad_weights, fc2_dbias, grad_scales
    end
Loading

Reviews (5): Last reviewed commit: "Update docstring" | Re-trigger Greptile

Comment on lines +667 to +671
fc2_grad_extra = (None, None) if fc2_op._scale_bias else (None,)
return (
grad_input,
[fc1_grad_params, (), fc2_grad_params],
[(None,), (grad_scales,), (None,)],
[(None,), (grad_scales,), fc2_grad_extra],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 FC2 scales gradient silently routed through SwiGLU path — undocumented assumption

When fc2_op._scale_bias is True, fc2_grad_extra returns (None, None) — the gradient for FC2's second extra input (its routing-probability tensor) is not propagated here. Instead, the full grad_scales (SwiGLU prob gradient + FC2 bias contribution accumulated by _compute_grouped_dbias_dscales) is returned via (grad_scales,) in the SwiGLU slot.

This is only correct when the SwiGLU probability tensor and the FC2 probability tensor are the same Python object. The test enforces this by passing probs_test to both, but there is no runtime assertion. If a caller passes distinct tensors, the FC2-prob gradient is silently dropped. Consider adding a guard or documenting this constraint in the class docstring.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is true. If we were to check for this rigorously, we could check that the data pointers are the same between the SwiGLU and FC2. We would also want to check that FC1 and FC2 are getting the same splits. I don't think this is a blocker though.

Comment on lines +488 to +499
if scale_bias:
fc2_biases = fc2_op._get_bias_tensors(dtype)
bias_packed = torch.stack(fc2_biases)
scales_f32 = scales.detach().to(dtype=torch.float32)
fc2_dbias_packed_result, grad_scales = _compute_grouped_dbias_dscales(
fc2_dy,
scales_f32,
bias_packed,
split_sizes,
offsets=fc1_ctx.base_split_offsets,
dscales=grad_scales,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 scales here is the SwiGLU probability, not the FC2-specific probability

scales is restored from swiglu_ctx.saved_tensors (line 298) — it is the probability saved during the FC1/SwiGLU forward, not the FC2-specific routing probability. For the dbias/dscales computation of FC2 the FC2 probability should be used (d_bias2_g = Σ_i dy_i · prob2_i).

This works correctly today because the test always passes the same probs_test tensor to both ops, but it is an undocumented constraint. A comment explaining the shared-tensor assumption would help future maintainers.

Comment on lines +60 to +63
if dbias is None:
dbias = torch.zeros(num_groups, hidden, dtype=torch.float32, device=dy.device)
if dscales is None:
dscales = torch.zeros(total_tokens, dtype=torch.float32, device=dy.device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 No dtype validation for pre-allocated dscales

When dscales is provided by the caller it is passed directly to tl.atomic_add in the Triton kernel, which requires a float32 pointer. If a caller passes a non-float32 tensor the atomic add silently corrupts the output. Consider adding an assertion:

if dscales is None:
    dscales = torch.zeros(total_tokens, dtype=torch.float32, device=dy.device)
else:
    assert dscales.dtype == torch.float32, (
        f"_compute_grouped_dbias_dscales: dscales must be float32, got {dscales.dtype}"
    )

@vthumbe1503 vthumbe1503 changed the title Bias Prob Scaling for GroupedLinear and FusedGrouped Layers Bias Prob Scaling for GroupedLinear and Fused MOE Layers Apr 9, 2026
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 9, 2026

Tip:

Greploop — Automatically fix all review issues by running /greploops in Claude Code. It iterates: fix, push, re-review, repeat until 5/5 confidence.

Use the Greptile plugin for Claude Code to query reviews, search comments, and manage custom context directly from your terminal.

timmoon10
timmoon10 previously approved these changes Apr 9, 2026
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
timmoon10
timmoon10 previously approved these changes Apr 10, 2026
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@vthumbe1503 vthumbe1503 merged commit 580e7aa into NVIDIA:main Apr 10, 2026
10 of 12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants