
fix: correct MoE auxiliary loss gradient scaling#1412

Merged
hemildesai merged 6 commits into main from hemil/moe-aux-loss-fix
Mar 2, 2026

Conversation

@hemildesai
Contributor

Summary

Fixes #1408

MoE auxiliary (load-balancing) loss gradients were incorrectly scaled due to two bugs:

  1. MoEAuxLossAutoScaler.apply()'s return value was discarded: the gate weights tensor was never reassigned from the return value, so the custom autograd backward never attached to the computation graph, and aux loss gradients silently never flowed (see the sketch after this list).

  2. MoEAuxLossAutoScaler.main_loss_backward_scale was never set — without it, aux loss gradients were inadvertently divided by dp_group_size (non-PP, from FSDP gradient averaging) or by num_label_tokens (PP, from FSDP + post-hoc PP scaling in scale_grads_and_clip_grad_norm). This caused the effective aux loss coefficient to silently shrink with larger parallelism configs, leading to worse load balancing.
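
A minimal sketch of bug 1 and its fix (simplified; the real Gate.forward carries more state than shown here):

    # Inside Gate.forward(), after computing routing weights and aux_loss.
    # Buggy: the result of apply() is dropped, so the custom Function node
    # never enters the autograd graph and its backward never runs.
    MoEAuxLossAutoScaler.apply(weights, aux_loss)

    # Fixed: reassign, so downstream ops consume the Function's output and
    # its backward fires during the main backward pass.
    weights = MoEAuxLossAutoScaler.apply(weights, aux_loss)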

Root cause detail

The MoEAuxLossAutoScaler custom autograd function piggybacks aux loss gradients onto the gate weights tensor. Its backward multiplies aux loss grad by main_loss_backward_scale to counteract downstream gradient scaling:

  • Non-PP: FSDP allreduce divides all grads by dp_group_size. Setting main_loss_backward_scale = dp_group_size cancels this.
  • PP: FSDP divides by dp_group_size, then scale_grads_and_clip_grad_norm divides by num_label_tokens / dp_group_size. The dp_group_size factors cancel, leaving a net 1/num_label_tokens. Setting main_loss_backward_scale = num_label_tokens cancels this.

After the fix, effective aux loss gradient = aux_loss_coeff × ∂(aux_loss)/∂θ, invariant to parallelism configuration.
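
For reference, a self-contained sketch of the autoscaler pattern (class and attribute names follow the PR; the body is illustrative rather than the exact layers.py code):

    import torch

    class MoEAuxLossAutoScaler(torch.autograd.Function):
        """Piggybacks the aux loss gradient onto the gate weights tensor."""

        main_loss_backward_scale = torch.tensor(1.0)  # set by the recipe each step

        @staticmethod
        def forward(ctx, output, aux_loss):
            ctx.save_for_backward(aux_loss)
            return output  # value unchanged; only the graph gains a node

        @staticmethod
        def backward(ctx, grad_output):
            (aux_loss,) = ctx.saved_tensors
            # Pre-scale the aux loss grad so that downstream division by
            # dp_group_size (FSDP) or num_label_tokens (PP) cancels out.
            aux_loss_grad = torch.ones_like(aux_loss) * MoEAuxLossAutoScaler.main_loss_backward_scale
            return grad_output, aux_loss_grad

Worked example, non-PP with dp_group_size = 8: FSDP averaging divides every gradient by 8, so a backward scale of 8 restores the aux loss contribution to exactly aux_loss_coeff × ∂(aux_loss)/∂θ.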

Changes

  • layers.py: capture MoEAuxLossAutoScaler.apply() return value; scale input by aux_loss_coeff instead of token count
  • train_ft.py: set main_loss_backward_scale to num_label_tokens (PP) or dp_group_size (non-PP) before forward-backward (sketched after this list)
  • vlm/finetune.py: same as train_ft.py for the VLM recipe
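
A sketch of the recipe-side change, assuming hypothetical names pp_enabled, num_label_tokens, and dp_group_size for whatever the recipes actually compute:

    # Before forward-backward: record how much downstream machinery will
    # divide gradients by, so the autoscaler's backward can cancel it.
    scale = num_label_tokens if pp_enabled else dp_group_size
    MoEAuxLossAutoScaler.main_loss_backward_scale = torch.tensor(float(scale))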

Test plan

  • TestMoEAuxLossAutoScaler — verifies apply() returns output with grad_fn, backward scales correctly, defaults to 1.0 (condensed after this list)
  • TestGateAuxLossGradientFlow — verifies Gate wires aux loss into autograd graph, gradients reach router params, aux_loss_coeff scales correctly
  • TestRunTrainOptimStepSetsMoEScale — verifies recipe sets scale to num_label_tokens (PP) or dp_group_size (non-PP)
  • All existing MoE unit tests pass (310 passed)
  • All existing recipe unit tests pass
  • Ruff lint + format clean
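
A condensed illustration of what the autoscaler test checks, reusing the sketch above (the real tests import MoEAuxLossAutoScaler from layers.py and cover more cases):

    import torch

    def test_apply_attaches_grad_fn_and_scales_backward():
        weights = torch.randn(4, 2, requires_grad=True)
        aux_loss = weights.pow(2).mean()  # stand-in aux loss on the same graph
        out = MoEAuxLossAutoScaler.apply(weights, aux_loss)
        assert out.grad_fn is not None  # bug 1: the Function node must attach
        MoEAuxLossAutoScaler.main_loss_backward_scale = torch.tensor(4.0)
        out.sum().backward()
        assert weights.grad is not None  # aux loss gradient reaches the gate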

🤖 Generated with Claude Code

@copy-pr-bot

copy-pr-bot Bot commented Feb 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hemildesai
Contributor Author

/ok to test fc2dbe9

@hemildesai
Contributor Author

/ok to test 7a3d2a4

@hemildesai
Contributor Author

/ok to test fa41512

hemildesai and others added 3 commits March 2, 2026 10:13
docs: explain MoE aux loss backward scale derivation in vlm finetune

Same inline comment as train_ft.py, explaining why main_loss_backward_scale
is set to dp_group_size (non-PP) vs num_label_tokens (PP).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
fix: clamp context_length to avoid division by zero in aux loss

When all tokens are masked (e.g. a padding-heavy CP rank),
context_length can be zero, causing a division-by-zero in the
f_i and P_i computations. Clamp to min=1 to produce a safe
zero aux loss instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
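
A sketch of the clamp against the standard Switch-style load-balancing loss (illustrative names, not the repo's exact code):

    import torch

    def load_balancing_aux_loss(tokens_per_expert, router_probs, aux_loss_coeff, topk):
        # tokens_per_expert: (E,) routed-token counts; router_probs: (T, E).
        # With every token masked, T (context_length) is 0; clamping the
        # denominator to 1 keeps f_i and P_i finite, and since the counts and
        # probability sums are then zero, the aux loss safely evaluates to 0.
        context_length = max(router_probs.shape[0], 1)
        f_i = tokens_per_expert / (topk * context_length)  # fraction routed per expert
        P_i = router_probs.sum(dim=0) / context_length     # mean router probability
        num_experts = router_probs.shape[1]
        return aux_loss_coeff * num_experts * (f_i * P_i).sum()
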
feat: store cp_mesh on MoE and assign it during apply_cp

When cp_mesh is not passed explicitly to MoE.forward() (e.g. via
AutoPipeline schedules that don't thread it through), the module now
falls back to self.cp_mesh, which is set by parallelizer.apply_cp
during model parallelization.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
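
The fallback pattern, sketched (only the cp_mesh plumbing; the expert computation itself is elided):

    import torch

    class MoE(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.cp_mesh = None  # parallelizer.apply_cp assigns this later

        def forward(self, x, cp_mesh=None):
            # Schedules that don't thread cp_mesh through (e.g. AutoPipeline)
            # pass None here; fall back to the mesh stored at parallelization.
            cp_mesh = cp_mesh if cp_mesh is not None else self.cp_mesh
            # ... route tokens to experts, using cp_mesh for CP communication ...
            return x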
hemildesai force-pushed the hemil/moe-aux-loss-fix branch from fa41512 to e69ad94 on March 2, 2026 18:15
@hemildesai
Contributor Author

/ok to test e69ad94

hemildesai merged commit 45095e6 into main Mar 2, 2026
51 checks passed
hemildesai deleted the hemil/moe-aux-loss-fix branch March 2, 2026 20:11
hemildesai added a commit that referenced this pull request Mar 4, 2026
* fix: correct MoE auxiliary loss gradient scaling

MoE auxiliary (load-balancing) loss gradients were incorrectly scaled in
the training pipeline due to two bugs:

1. MoEAuxLossAutoScaler.apply() return value was not captured, so aux
   loss gradients never flowed back through the autograd graph. The gate
   weights tensor must be reassigned from the return value for the custom
   backward to attach.

2. MoEAuxLossAutoScaler.main_loss_backward_scale was never set by the
   training recipes, so aux loss gradients were inadvertently divided by
   dp_group_size (non-PP case, from FSDP gradient averaging) or by
   num_label_tokens (PP case, from post-hoc pipeline-parallel scaling).
   This caused the effective aux loss coefficient to silently shrink with
   larger parallelism configs, leading to worse load balancing.

The fix ensures effective aux loss gradient = aux_loss_coeff * d(aux_loss)/dθ,
invariant to parallelism configuration.

Changes:
- layers.py: capture MoEAuxLossAutoScaler.apply() return value and scale
  by aux_loss_coeff instead of token count
- train_ft.py, vlm/finetune.py: set main_loss_backward_scale to
  num_label_tokens (PP) or dp_group_size (non-PP) before forward-backward
  so the custom backward correctly counteracts FSDP/PP scaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* test: add unit tests for MoE aux loss gradient scaling fix

Tests cover three aspects of the fix:

1. MoEAuxLossAutoScaler (test_layers.py::TestMoEAuxLossAutoScaler):
   - apply() returns output unchanged but with grad_fn
   - backward scales aux_loss grad by main_loss_backward_scale
   - defaults to scale=1.0 when main_loss_backward_scale is unset
   - grad_output passes through unmodified

2. Gate aux loss gradient flow (test_layers.py::TestGateAuxLossGradientFlow):
   - weights returned by Gate.forward() carry grad_fn from apply()
   - backward through weights produces gradients on gate parameters
   - aux_loss_coeff correctly scales the value passed to apply()
   - no aux loss computed when coeff is zero

3. Recipe main_loss_backward_scale (test_train_ft.py::TestRunTrainOptimStepSetsMoEScale):
   - PP enabled: scale set to num_label_tokens
   - PP disabled: scale set to dp_group_size

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* docs: explain MoE aux loss backward scale derivation in train_ft

Add an inline comment explaining why main_loss_backward_scale is set to
dp_group_size (non-PP) vs num_label_tokens (PP), tracing the gradient
scaling pipeline through FSDP allreduce and PP post-hoc rescaling.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* docs: explain MoE aux loss backward scale derivation in vlm finetune

Same inline comment as train_ft.py, explaining why main_loss_backward_scale
is set to dp_group_size (non-PP) vs num_label_tokens (PP).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* fix: clamp context_length to avoid division by zero in aux loss

When all tokens are masked (e.g. a padding-heavy CP rank),
context_length can be zero, causing a division-by-zero in the
f_i and P_i computations. Clamp to min=1 to produce a safe
zero aux loss instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* feat: store cp_mesh on MoE and assign it during apply_cp

When cp_mesh is not passed explicitly to MoE.forward() (e.g. via
AutoPipeline schedules that don't thread it through), the module now
falls back to self.cp_mesh, which is set by parallelizer.apply_cp
during model parallelization.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

---------

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
SwekeR-463 pushed a commit to SwekeR-463/Automodel that referenced this pull request Mar 11, 2026
linnanwang pushed a commit that referenced this pull request Apr 24, 2026


Development

Successfully merging this pull request may close these issues.

[Bug] MoE aux_loss gradient never flows — MoEAuxLossAutoScaler return value discarded
