Fix memory overhead of linear layer when all gather from sequence parallel #2125
```diff
@@ -349,9 +349,14 @@ def _create_columnwise(self):

     def _transpose_columnwise_data(self):
         """Plainly transpose the columnwise data and scale inv."""
         if self._columnwise_data is not None:
+            # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically
+            # deallocated by GC. Manually deallocating is a temporary hack.
+            _old_data = self._columnwise_data
             self._columnwise_data = tex.fp8_transpose(
                 self._columnwise_data, self._fp8_dtype, out=None
             )
+            _old_data.data = _empty_tensor()
+            del _old_data
```
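For context on what the added lines do, here is a minimal sketch of the deallocation pattern, with a plain `torch.Tensor` standing in for the FP8 extension types and `torch.Tensor()` standing in for `_empty_tensor()` (both stand-ins are assumptions, not TE code). Reassigning `.data` detaches the Python tensor object from its storage, so the old buffer can be freed even while other references to the object stay alive.

```python
import torch


def transpose_and_release(t: torch.Tensor) -> torch.Tensor:
    """Return a transposed copy of `t` and aggressively drop `t`'s storage."""
    old = t
    new = t.t().contiguous()      # stand-in for tex.fp8_transpose(..., out=None)
    # Point the old tensor object at an empty tensor. Any holder of a
    # reference to `old` (e.g. an autograd ctx) keeps a valid Python object,
    # but the original storage is no longer pinned by it.
    old.data = torch.Tensor()     # stand-in for _empty_tensor()
    del old
    return new


x = torch.randn(512, 256)
saved_elsewhere = x               # simulates a reference held outside this scope
x = transpose_and_release(x)
print(x.shape)                    # torch.Size([256, 512])
print(saved_elsewhere.numel())    # 0 -- the old buffer is no longer pinned
```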
Comment on lines +354 to +359
Collaborator

Shouldn't Python refcounting deallocate the old data automatically? If not, then there's a larger bug in […]. I see that we return […].
Contributor (Author)

It is a very good question that also confuses me. I have a similar question: why do we need to use […]? I'm not very familiar with the code and its deeper implementation. Could you please share this problem with some TE experts to help solve it? Or do you think we could merge this PR first and find the root cause later? It is somewhat urgent for the runnability and performance of DSV3 / MLA long-context training.
Collaborator

We call […]. The forward GEMM input tensor is stored within the autograd ctx, so GC will not deallocate it until after the backward has finished. However, we don't need this buffer after the wgrad GEMM, and ideally it would be reused for the LayerNorm grad.
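A hedged sketch of the lifetime issue described above, using a toy `torch.autograd.Function` (illustrative only, not TE's actual GEMM path): a tensor saved for backward is kept alive by the autograd graph until backward runs, so dropping the local name alone does not release its memory.

```python
import torch


class Square(torch.autograd.Function):
    """Toy op whose backward needs the forward input."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # ctx keeps a reference to x until backward
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


inp = torch.randn(2048, 2048, requires_grad=True)
activation = inp.clone()          # stand-in for the forward GEMM input
out = Square.apply(activation)
loss = out.sum()

del activation                    # dropping the local name frees nothing:
                                  # the autograd graph still references it
loss.backward()                   # only after backward can that buffer go away
print(inp.grad.shape)
```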
Contributor (Author)

As I described in this thread, I think it is not the business of the C++ extension, but the business of […].
timmoon10 marked this conversation as resolved.
```diff

     def __repr__(self):
         if self._rowwise_data is not None:
```