Introduce Inductor passes to micro-pipeline all-gather-matmul and matmul-reduce-scatter in certain cases #126598
base: gh/yifuwang/84/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126598
Note: Links to docs will display an error until the docs builds have been completed. ⏳ 1 Pending, 2 Unrelated Failures. As of commit ccc11e5 with merge base 4afc5c7: UNSTABLE. The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
## Context

See context [here](#122163).
This pass looks very clean and easy to understand! I mainly have some questions.
```python
@parametrize("A_dims", [2, 3])
@parametrize("gather_dim", [0, 1, 2])
@fresh_inductor_cache()
def test_fuse_all_gather_matmul(self, A_dims, gather_dim):
```
No need to add it in this PR; let's try to add the e2e integration test with ColwiseParallel for all-gather matmuls in follow-up PRs.
Why not test this in this PR? 🤔
Added a test for DTensor-based sequence parallelism.
```python
    match.nodes,
    aten.cat.default,
)[0]
shard_node = ag_node.args[0]
```
`shard_node` is not getting used?
```python
second reshape node is replaced with `new_node`.
In addition, we ensure that the original mm node ends up with zero
users by replacing it with a reverse reshape of `new_node`.
```
Could you elaborate on this part? How does replacing it with a reverse reshape of `new_node` result in the original mm node having zero users?
An ND matmul shows up in fx graphs as a reshape -> mm -> reshape sequence: the first reshape flattens the leading dims while the second one unflattens them. Consider the following fake fx graph:

```python
buf_0 = allgather(...)
buf_1 = aten.reshape(buf_0, ...)
buf_2 = aten.mm(buf_1, ...)
buf_3 = aten.reshape(buf_2, ...)
```

Since `fused_all_gather_matmul` semantically performs `matmul`s (as opposed to `mm`s), its results will replace `buf_0` and `buf_3`. It's okay if `buf_1` ends up with non-zero users, since it's just a view on `buf_0`. However, if for some reason `buf_2` ends up with non-zero users and can't be removed after the fusion (e.g. `buf_2` being returned for some reason), we'd be performing an extra `mm`.

To ensure `buf_2` has zero users after the fusion, we replace `buf_2` with the reverse reshape from `buf_3`, since `buf_3` is always available and is itself a reshape of `buf_2`.
```python
patterns = PatternMatcherPass()


def _is_backward(graph: torch.fx.Graph) -> bool:
```
Ditto: this function doesn't seem to be used; is it for debugging?
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 jobs have failed; the first few are: trunk / macos-13-py3-arm64 / test (default, 1, 3, macos-m1-stable), trunk / macos-13-py3-arm64 / test (default, 3, 3, macos-m1-stable). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
Stack from ghstack (oldest at bottom):
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire