
Introduce Inductor passes to micro-pipeline all-gather-matmul and matmul-reduce-scatter in certain cases #126598

Open

wants to merge 16 commits into base: gh/yifuwang/84/base

Conversation

…mul-reduce-scatter in certain cases

[ghstack-poisoned]

pytorch-bot bot commented May 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126598

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 2 Unrelated Failures

As of commit ccc11e5 with merge base 4afc5c7:

UNSTABLE - The following jobs failed, but the failures were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

## Context
See context [here](#122163).
@wanchaol (Contributor) left a comment

This pass looks very clean and easy to understand! I mainly have some questions.

```python
@parametrize("A_dims", [2, 3])
@parametrize("gather_dim", [0, 1, 2])
@fresh_inductor_cache()
def test_fuse_all_gather_matmul(self, A_dims, gather_dim):
```
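For illustration, here is a minimal plain-Python sketch (no torch; `mm` and `all_gather_dim0` are hypothetical stand-ins for the real ops) of the semantic invariant such a test must check: gathering shards along dim 0 and then doing one matmul equals doing a per-shard matmul and concatenating the results.

```python
# Hedged sketch of the invariant behind fusing all-gather + matmul.
# Shards are nested lists; all_gather is modeled as plain concatenation.

def mm(a, b):
    """2D matmul on nested lists: (m, k) @ (k, n) -> (m, n)."""
    return [[sum(x * b[i][j] for i, x in enumerate(row)) for j in range(len(b[0]))]
            for row in a]

def all_gather_dim0(shards):
    """Stand-in for all_gather along dim 0: concatenate the per-rank shards."""
    return [row for shard in shards for row in shard]

# Two "ranks", each holding a (1, 2) shard of A; B is a (2, 2) identity.
shards = [[[1, 2]], [[3, 4]]]
B = [[1, 0], [0, 1]]

# Unfused: gather first, then one big matmul.
unfused = mm(all_gather_dim0(shards), B)

# What a fused kernel can exploit: matmul each shard as it arrives,
# then concatenate; for gather_dim=0 the results must match.
fused = all_gather_dim0([mm(s, B) for s in shards])

assert fused == unfused
```

A fused implementation can overlap the per-shard matmuls with the communication, which is the point of the micro-pipelining pass.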
Contributor

No need to add it in this PR; let's add the e2e integration test with ColwiseParallel for all-gather matmuls in follow-up PRs.

Contributor

Why not test this in this PR 🤔

Contributor Author

Added a test for dtensor-based seq-par.

```python
    match.nodes,
    aten.cat.default,
)[0]
shard_node = ag_node.args[0]
```
Contributor

`shard_node` is not being used?

second reshape node is replaced with `new_node`.
In addition, we ensure that the original mm node ends up with zero
users by replacing it with a reverse reshape of `new_node`.
Contributor

Could you elaborate on this part more? How does replacing it with a reverse reshape of `new_node` result in the original mm node having zero users?

Contributor Author

An ND-matmul shows up in fx graphs as reshape -> mm -> reshape sequences. The first reshape flattens the leading dims while the second one unflattens them. Consider the following fake fx graph:

```python
buf_0 = allgather(...)
buf_1 = aten.reshape(buf_0, ...)
buf_2 = aten.mm(buf_1, ...)
buf_3 = aten.reshape(buf_2, ...)
```

Since fused_all_gather_matmul semantically performs matmuls (as opposed to mms), its results will replace buf_0 and buf_3. It's okay if buf_1 ends up with non-zero users, since it's just a view on buf_0. However, if for some reason buf_2 ends up with non-zero users and can't be removed after the fusion (e.g. buf_2 being returned for some reason), we'd be performing an extra mm.

To ensure buf_2 has zero users after the fusion, since buf_3 is always available and it's a reshape from buf_2, we replace buf_2 with the reverse reshape from buf_3.
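The `buf_0`/`buf_1`/`buf_2`/`buf_3` walkthrough above can be sketched without torch. Below, `flatten_leading` and `unflatten_leading` are hypothetical stand-ins for the two reshapes; the point is that the second reshape is invertible, so `buf_2` can always be recovered from `buf_3` via the reverse reshape.

```python
# Hedged sketch: ND matmul as reshape -> mm -> reshape on nested lists,
# and why a reverse reshape of buf_3 can replace all remaining uses of buf_2.

def mm(a, b):
    """2D matrix multiply on nested lists: (m, k) @ (k, n) -> (m, n)."""
    return [[sum(x * b[i][j] for i, x in enumerate(row)) for j in range(len(b[0]))]
            for row in a]

def flatten_leading(a3):
    """reshape (d0, d1, k) -> (d0*d1, k): the first reshape in the sequence."""
    return [row for block in a3 for row in block]

def unflatten_leading(a2, d0):
    """reshape (d0*d1, n) -> (d0, d1, n): the second reshape in the sequence."""
    d1 = len(a2) // d0
    return [a2[i * d1:(i + 1) * d1] for i in range(d0)]

# A: (2, 2, 2), B: (2, 2) identity, so the matmul returns A itself.
A = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
B = [[1, 0], [0, 1]]

# Exactly the buf_1 / buf_2 / buf_3 sequence from the fx graph above.
buf_1 = flatten_leading(A)
buf_2 = mm(buf_1, B)
buf_3 = unflatten_leading(buf_2, d0=2)
assert buf_3 == A  # B is identity

# The "reverse reshape": buf_2 is recoverable from buf_3, so any remaining
# users of buf_2 can be repointed at flatten_leading(buf_3), leaving the
# original mm node with zero users.
assert flatten_leading(buf_3) == buf_2
```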

```python
patterns = PatternMatcherPass()


def _is_backward(graph: torch.fx.Graph) -> bool:
```
Contributor

Ditto: this function doesn't seem to be used; is it for debugging?

@yifuwang added the topic: not user facing label May 29, 2024
@yifuwang (Contributor Author)

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label May 31, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / macos-13-py3-arm64 / test (default, 1, 3, macos-m1-stable), trunk / macos-13-py3-arm64 / test (default, 3, 3, macos-m1-stable)


@yifuwang mentioned this pull request Jun 1, 2024
@yifuwang (Contributor Author)

yifuwang commented Jun 1, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


Labels: ciflow/inductor, ciflow/trunk, merging, module: inductor, oncall: distributed, topic: not user facing

4 participants