[PyTorch] Main_Grad buffer isn't overwritten when overwrite_main_grad=True #2936
/te-ci pytorch
Greptile Summary
This PR fixes a bug where …
Confidence Score: 5/5
Safe to merge: the fix is minimal, correct, and well-tested with NaN sentinels. No P0 or P1 issues found. The fix correctly addresses the root cause (missing alias of grouped_wgrad to main_grad when overwrite_main_grad=True) without affecting the accumulate path. The GEMM accumulate flag is correctly driven by the local accumulate_into_main_grad (False when overwriting, True when accumulating), and the new test covers both single/grouped-weight and deferred-wgrad variants. No files require special attention.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[_compute_grad_params called] --> B{fc_op._accumulate_into_main_grad?}
B -- No --> C[Allocate scratch grouped_wgrad]
B -- Yes --> D[Get weight_param.main_grad]
D --> E{overwrite_main_grad attr?}
E -- False / missing --> F["accumulate_into_main_grad = True\n(GEMM will add into main_grad)"]
E -- True --> G["accumulate_into_main_grad = False\n(GEMM will overwrite main_grad)"]
F --> H["grouped_wgrad <- alias of main_grad (FIX: always)"]
G --> H
C --> I[Launch GEMM\naccumulate=False]
H --> J["Launch GEMM\naccumulate=accumulate_into_main_grad"]
I --> K[Return scratch grad as param.grad]
J --> L{fc_op._accumulate_into_main_grad\n& grad_added_to_main_grad attr?}
L -- Yes --> M["Set grad_added_to_main_grad=True\nReturn dummy wgrad as param.grad"]
L -- No --> N[Return packed_wgrad as param.grad]
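
The buffer-selection branch of the flowchart can be sketched in plain Python. This is a minimal illustration, not the Transformer Engine implementation: `select_wgrad_output` and `fake_wgrad_gemm` are hypothetical names, Python lists stand in for tensors, and the "GEMM" is just an element-wise write or add. It also shows the NaN-sentinel idea mentioned in the summary: if `main_grad` were not overwritten, the NaNs would survive.

```python
import math
from types import SimpleNamespace


def select_wgrad_output(op_outputs_into_main_grad, weight_param, numel):
    """Pick the wgrad output buffer and the accumulate flag (hypothetical helper).

    Returns (out_buffer, accumulate).
    """
    if not op_outputs_into_main_grad:
        # main_grad is not involved: fresh scratch buffer, plain overwrite.
        return [0.0] * numel, False
    # Overwrite mode is an optional attribute on the parameter.
    overwrite = getattr(weight_param, "overwrite_main_grad", False)
    # Always alias main_grad here, whether accumulating or overwriting.
    return weight_param.main_grad, not overwrite


def fake_wgrad_gemm(wgrad, out, accumulate):
    """Stand-in for the wgrad GEMM: add into or overwrite `out` in place."""
    for i, value in enumerate(wgrad):
        out[i] = out[i] + value if accumulate else value


# NaN sentinel: stale NaNs in main_grad must be replaced, not added to.
param = SimpleNamespace(main_grad=[math.nan] * 3, overwrite_main_grad=True)
out, accumulate = select_wgrad_output(True, param, numel=3)
fake_wgrad_gemm([1.0, 2.0, 3.0], out, accumulate)
print(param.main_grad)  # [1.0, 2.0, 3.0]
```

With `overwrite_main_grad` absent or False, the same call returns `accumulate=True` and the update is added to whatever `main_grad` already holds.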
/te-ci L1 pytorch
timmoon10 left a comment
Overall LGTM. The underlying problem is that GroupedLinear._accumulate_into_main_grad has become a deceptive name since we added support for overwrite_main_grad. Unfortunately we're stuck with the accumulate_into_main_grad kwarg in our APIs, but we can change the internal variable to something more descriptive like _output_into_main_grad.
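
A minimal sketch of the rename suggested above, assuming a hypothetical `GroupedLinearSketch` class (not the real `GroupedLinear`) and a hypothetical `wgrad_accumulates` helper: the public kwarg keeps its name, while the internal attribute only records that the wgrad is written into `main_grad`.

```python
from types import SimpleNamespace


class GroupedLinearSketch:
    """Hypothetical stand-in used only to illustrate the naming idea."""

    def __init__(self, accumulate_into_main_grad=False):
        # The public kwarg keeps its established name for API compatibility.
        # Internally the flag means "wgrad is output into main_grad";
        # add-versus-overwrite is decided per parameter.
        self._output_into_main_grad = accumulate_into_main_grad

    def wgrad_accumulates(self, weight_param):
        """True only when the wgrad should be added to main_grad."""
        if not self._output_into_main_grad:
            return False
        return not getattr(weight_param, "overwrite_main_grad", False)


op = GroupedLinearSketch(accumulate_into_main_grad=True)
print(op.wgrad_accumulates(SimpleNamespace(overwrite_main_grad=True)))  # False
```

Separating "where the gradient goes" from "how it is written" keeps the two questions from being conflated under one name.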
if delay_wgrad or fc_op._accumulate_into_main_grad:
    w_list = [None] * num_groups
if accumulate_into_main_grad:
if fc_op._accumulate_into_main_grad:
Nit: I think code duplication is worth it if it makes it easier to read.
if delay_wgrad or fc_op._accumulate_into_main_grad:
    w_list = [None] * num_groups
if accumulate_into_main_grad:
if fc_op._accumulate_into_main_grad:

if delay_wgrad:
    w_list = [None] * num_groups
if fc_op._accumulate_into_main_grad:
    w_list = [None] * num_groups
return ref, test


class MegatronTrainingHelper:
We should consider moving this into a separate test. The purpose of test_fusible_ops.py is to do minimal tests with te.Sequential, and Megatron integrations are orthogonal to that. Creating a separate test file would also give us a logical place to put tests that are completely unrelated to te.Sequential, like te.Linear or attention.
assert (
    wp.grad is not None
), "weight.grad is None; the Megatron protocol expects a dummy tensor stand-in here."
assert wp.grad.data_ptr() == expected_dummy.data_ptr(), (
This is fine, but it's also a little strict, isn't it? The real thing we are testing is:
assert wp.grad.data_ptr() == expected_dummy.data_ptr(), (
assert wp.grad.data_ptr() != wp.main_grad.data_ptr(), (
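
The difference between the two assertions can be illustrated with a small stand-alone sketch. Everything here is an assumption made for illustration: `finalize_wgrad` is a hypothetical helper, Python lists stand in for tensors, and object identity (`is`) stands in for comparing `data_ptr()` values. The branch follows the flowchart: when the parameter carries a `grad_added_to_main_grad` attribute, the flag is set and a dummy is returned as `param.grad`.

```python
from types import SimpleNamespace


def finalize_wgrad(param, wgrad, dummy):
    """Return what should become param.grad after backward (hypothetical)."""
    if hasattr(param, "grad_added_to_main_grad"):
        # Signal that the real gradient already lives in main_grad.
        param.grad_added_to_main_grad = True
        return dummy
    return wgrad


param = SimpleNamespace(main_grad=[1.0, 2.0], grad_added_to_main_grad=False)
dummy = [0.0, 0.0]
param.grad = finalize_wgrad(param, param.main_grad, dummy)

# Strict check: grad must be exactly the expected dummy buffer.
strict_ok = param.grad is dummy
# Looser check: grad merely must not alias main_grad.
loose_ok = param.grad is not param.main_grad
print(strict_ok, loose_ok)  # True True
```

The looser check would keep passing if the implementation later returned a different dummy buffer, which is the property the review comment cares about.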
…True (#2936)
* fix
* add test
* cleanup
* zero_out should also be tested

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@gb-nvl-059-compute03.nvidia.com>