[PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear#1488

Merged: timmoon10 merged 4 commits into NVIDIA:main from yaox12:xiny/fix_grouped_linear on Feb 19, 2025

Conversation

@yaox12 yaox12 commented Feb 17, 2025

Description

Because of incorrect indentation, the wgrad computation is skipped when ctx.fuse_wgrad_accumulation == True.

Also update the test to cover this case.
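A minimal sketch of how an indentation slip of this kind can skip the weight-gradient computation. The function names, the element-wise "GEMM", and the exact branch layout are hypothetical stand-ins for illustration, not the actual GroupedLinear backward code:

```python
def backward_buggy(x, dy, main_grad, fuse_wgrad_accumulation):
    """Hypothetical backward: the wgrad update is nested under `else` by mistake."""
    if fuse_wgrad_accumulation:
        wgrad = main_grad  # accumulate directly into the existing buffer
    else:
        wgrad = [0.0] * len(main_grad)  # fresh buffer
        # BUG: this loop is indented one level too deep,
        # so it never runs on the fused path.
        for i in range(len(wgrad)):
            wgrad[i] += x[i] * dy[i]
    return wgrad


def backward_fixed(x, dy, main_grad, fuse_wgrad_accumulation):
    """Same logic with the wgrad update dedented so both paths run it."""
    if fuse_wgrad_accumulation:
        wgrad = main_grad
    else:
        wgrad = [0.0] * len(main_grad)
    # Stand-in for the wgrad GEMM: runs regardless of the fuse flag.
    for i in range(len(wgrad)):
        wgrad[i] += x[i] * dy[i]
    return wgrad


x, dy = [1.0, 2.0], [3.0, 4.0]
buggy_fused = backward_buggy(x, dy, [10.0, 10.0], True)
fixed_fused = backward_fixed(x, dy, [10.0, 10.0], True)
fixed_unfused = backward_fixed(x, dy, [10.0, 10.0], False)
```

In this sketch the buggy fused path returns the buffer untouched, while the fixed version accumulates into it.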

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 added the bug and 2.1.0 labels Feb 17, 2025
@yaox12 yaox12 requested a review from timmoon10 February 17, 2025 08:22
yaox12 commented Feb 17, 2025

/te-ci pytorch

Signed-off-by: Xin Yao <xiny@nvidia.com>
yaox12 commented Feb 17, 2025

/te-ci pytorch


ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad:
    ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]

Member

It's recommended to use ctx.save_for_backward instead of storing tensors directly in ctx. They warn about messing up the grad graph and memory leaks, although I'm not sure what cases they are specifically worried about.

Member Author

@yaox12 yaox12 Feb 19, 2025

I agree. We were saving main_grad tensors using ctx.save_for_backward in TE 1.x, but I see there is a comment here:

# Since main_grad can be modified inplace, it should not be a part of saved_tensors

I'm wondering if we have seen issues with ctx.save_for_backward?

Member

Interesting, we should follow the example of Linear then.

@ksivaman This change is from commit 7e58678 in the internal repo. Do you remember why we can't store main_grad in saved_tensors?

Member Author

Oh, I know: the previous prepare_for_saving saved tensor.data instead of the tensor itself. For main_grad, which needs to be modified in place, this could be an issue.

Now that #1474 has changed prepare_for_saving to save the tensor itself, this is no longer a problem.
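For reference, a minimal sketch of the pattern recommended in this thread: tensors go through ctx.save_for_backward, and only non-tensor state is stored as a ctx attribute. The ScaleByWeight function and its scale attribute are hypothetical illustrations and are unrelated to the GroupedLinear implementation:

```python
import torch


class ScaleByWeight(torch.autograd.Function):
    """Hypothetical op computing y = x * w * scale."""

    @staticmethod
    def forward(ctx, x, w):
        # Tensors needed in backward go through save_for_backward.
        ctx.save_for_backward(x, w)
        # Plain Python values can be stored directly on ctx.
        ctx.scale = 2.0
        return x * w * ctx.scale

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # One gradient per forward input, in the same order.
        return grad_out * w * ctx.scale, grad_out * x * ctx.scale


x = torch.tensor([1.0, 2.0], requires_grad=True)
w = torch.tensor([3.0, 4.0], requires_grad=True)
ScaleByWeight.apply(x, w).sum().backward()
```

After the backward call, x.grad and w.grad hold the gradients returned by the custom backward.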

outputs.append(p.grad)
if getattr(p, "main_grad", None) is not None:
    outputs.append(p.main_grad)
    assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True

Member

It turns out Mcore expects p.grad to not be None: #1474 (comment)
#1474 sets grad to an uninitialized tensor and assumes Mcore will ignore it.
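Given that p.grad may hold a placeholder rather than None, one option for a test is to prefer main_grad whenever it exists and make no assertion about p.grad. The sketch below assumes a hypothetical collect_wgrads helper and uses SimpleNamespace objects in place of real parameters. It is not the test code from this PR:

```python
from types import SimpleNamespace


def collect_wgrads(params):
    """Hypothetical helper: pick main_grad when present, otherwise grad."""
    outputs = []
    for p in params:
        main_grad = getattr(p, "main_grad", None)
        # No assertion on p.grad here: with fused accumulation it may be
        # None or an uninitialized placeholder, depending on the version.
        outputs.append(main_grad if main_grad is not None else p.grad)
    return outputs


# Stand-ins for parameters: one on the fused path, one on the regular path.
fused = SimpleNamespace(grad=None, main_grad=[1.0, 2.0])
plain = SimpleNamespace(grad=[3.0, 4.0])
result = collect_wgrads([fused, plain])
```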

@timmoon10 timmoon10 self-requested a review February 18, 2025 23:33
@timmoon10
Member

/te-ci pytorch

@timmoon10 timmoon10 merged commit fceff07 into NVIDIA:main Feb 19, 2025
timmoon10 added a commit that referenced this pull request Feb 21, 2025
* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update tests

Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>