Skip to content

[PyTorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True #2936

Merged
vthumbe1503 merged 10 commits into
NVIDIA:mainfrom
vthumbe1503:delay_wgrad_bug
Apr 29, 2026
Merged

[PyTorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True #2936
vthumbe1503 merged 10 commits into
NVIDIA:mainfrom
vthumbe1503:delay_wgrad_bug

Conversation

@vthumbe1503
Copy link
Copy Markdown
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 added bug Something isn't working 2.15.0 labels Apr 28, 2026
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 changed the title Main_Grad buffer isnt overwritten when overwrite_main_grad=True [Pytorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True Apr 28, 2026
@vthumbe1503 vthumbe1503 changed the title [Pytorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True [PyTorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True Apr 28, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR fixes a bug where weight.main_grad was not overwritten when overwrite_main_grad=True (the MegatronFSDP path): the old if accumulate_into_main_grad: guard prevented grouped_wgrad from being aliased to main_grad, so the GEMM silently wrote into a scratch allocation and the main_grad buffer was left untouched. The fix unconditionally creates grouped_wgrad from main_grad for both the accumulate and overwrite paths, and switches the post-GEMM grad_added_to_main_grad bookkeeping to the op-level _accumulate_into_main_grad flag so it fires correctly regardless of overwrite_main_grad. A new MegatronTrainingHelper harness and a dedicated test_grouped_mlp_overwrite_main_grad test (NaN-sentinel fill, parametrized over delay_wgrad_compute / zero_out_wgrad / single_grouped_weight) make the corrected behavior easy to verify.

Confidence Score: 5/5

Safe to merge — the fix is minimal, correct, and well-tested with NaN sentinels.

No P0 or P1 issues found. The fix correctly addresses the root cause (missing alias of grouped_wgrad to main_grad when overwrite_main_grad=True) without affecting the accumulate path. The GEMM accumulate flag is correctly driven by the local accumulate_into_main_grad (False when overwriting, True when accumulating), and the new test covers both single/grouped-weight and deferred-wgrad variants.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Core fix: unconditionally aliases grouped_wgrad to main_grad for both accumulate and overwrite paths, and switches post-GEMM grad_added_to_main_grad bookkeeping to use the op-level _accumulate_into_main_grad flag instead of the stale local variable.
tests/pytorch/test_fusible_ops.py Adds MegatronTrainingHelper test harness and a dedicated test_grouped_mlp_overwrite_main_grad test that seeds main_grad with NaN sentinels to loudly catch missed writes; also refactors the existing accumulate-into-main_grad verification to use the new helper.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[_compute_grad_params called] --> B{fc_op._accumulate_into_main_grad?}
    B -- No --> C[Allocate scratch grouped_wgrad]
    B -- Yes --> D[Get weight_param.main_grad]
    D --> E{overwrite_main_grad attr?}
    E -- False / missing --> F["accumulate_into_main_grad = True\n(GEMM will add into main_grad)"]
    E -- True --> G["accumulate_into_main_grad = False\n(GEMM will overwrite main_grad)"]
    F --> H["grouped_wgrad <- alias of main_grad (FIX: always)"]
    G --> H
    C --> I[Launch GEMM\naccumulate=False]
    H --> J["Launch GEMM\naccumulate=accumulate_into_main_grad"]
    I --> K[Return scratch grad as param.grad]
    J --> L{fc_op._accumulate_into_main_grad\n& grad_added_to_main_grad attr?}
    L -- Yes --> M["Set grad_added_to_main_grad=True\nReturn dummy wgrad as param.grad"]
    L -- No --> N[Return packed_wgrad as param.grad]
Loading

Reviews (4): Last reviewed commit: "Merge branch 'delay_wgrad_bug' of https:..." | Re-trigger Greptile

Copy link
Copy Markdown
Collaborator

@zhongbozhu zhongbozhu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

vthumbe1503 and others added 6 commits April 28, 2026 19:31
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

vthumbe1503 and others added 2 commits April 28, 2026 20:28
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

Copy link
Copy Markdown
Member

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM. The underlying problem is that GroupedLinear._accumulate_into_main_grad has become a deceptive name since we added support for overwrite_main_grad. Unfortunately we're stuck with the accumulate_into_main_grad kwarg in our APIs, but we can change the internal variable to something more descriptive like _output_into_main_grad.

Comment on lines +250 to +252
if delay_wgrad or fc_op._accumulate_into_main_grad:
w_list = [None] * num_groups
if accumulate_into_main_grad:
if fc_op._accumulate_into_main_grad:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think code duplication is worth it if it makes it easier to read.

Suggested change
if delay_wgrad or fc_op._accumulate_into_main_grad:
w_list = [None] * num_groups
if accumulate_into_main_grad:
if fc_op._accumulate_into_main_grad:
if delay_wgrad:
w_list = [None] * num_groups
if fc_op._accumulate_into_main_grad:
w_list = [None] * num_groups

return ref, test


class MegatronTrainingHelper:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should consider moving this into a separate test. The purpose of test_fusible_ops.py is to do minimal tests with te.Sequential, and Megatron integrations are orthogonal to that. Creating a separate test file would also make it a logical place put tests that are completely unrelated to te.Sequential, like te.Linear or attention.

assert (
wp.grad is not None
), "weight.grad is None; the Megatron protocol expects a dummy tensor stand-in here."
assert wp.grad.data_ptr() == expected_dummy.data_ptr(), (
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, but it's also a little strict isn't it? The real thing we are testing is:

Suggested change
assert wp.grad.data_ptr() == expected_dummy.data_ptr(), (
assert wp.grad.data_ptr() != wp.main_grad.data_ptr(), (

@vthumbe1503 vthumbe1503 merged commit b4aeed1 into NVIDIA:main Apr 29, 2026
11 of 14 checks passed
KshitijLakhani pushed a commit that referenced this pull request Apr 29, 2026
…True (#2936)

* fix

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* zero_out should also be tested

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@gb-nvl-059-compute03.nvidia.com>
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
…True (NVIDIA#2936)

* fix

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* zero_out should also be tested

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@gb-nvl-059-compute03.nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2.15.0 bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants