Add missing device meshes in tests #4117
Conversation
SetSelector selector(
    std::unordered_set<TensorView*>(output_tvs.begin(), output_tvs.end()));
MaxLogicalDomainInfoSpanningTree(input_tv, &selector).traverse(&propagator);
scheduler_utils::parallelizeAllLike(
This does not set the device mesh.
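For illustration only, a minimal sketch (not the PR's actual change) of how the scheduling snippet above could also assign a mesh: copy the reference tensor's mesh onto each output that lacks one. input_tv and output_tvs come from the snippet; TensorView::hasDeviceMesh() is assumed to be available alongside the setDeviceMesh()/getDeviceMesh() calls that appear elsewhere in this diff.

// Minimal sketch, not the PR's actual change: give each output the
// reference tensor's mesh if it does not already have one.
// hasDeviceMesh() is assumed here; setDeviceMesh()/getDeviceMesh()
// appear elsewhere in this diff.
for (TensorView* out_tv : output_tvs) {
  if (!out_tv->hasDeviceMesh()) {
    out_tv->setDeviceMesh(input_tv->getDeviceMesh());
  }
}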
Review updated until commit 3a7baef
!test
wujingyue left a comment
I no longer understand the reason behind this PR.
shardAllLike(ref_input, outputs_without_mesh);
tv->setDeviceMesh(output_with_mesh->getDeviceMesh());
@wujingyue I prefer merging this to decouple updates to the presegmentation passes. Wdyt?
Sure. I think we talked about this PR offline. Can you describe the motivations in the PR description before I approve? IIUC, this tries to work around upcoming problems in sharding propagation, and the added meshes won't be necessary eventually.
!test
Can you fix the grammar here?
This PR is to decouple upcoming changes to the multi-device-related presegmentation passes.
Since the presegmentation passes will be updated serially, this works around test failures in the interim.
For example, some tests manually set loop domains as required for communication. The propagateShardings preseg pass will update the loop domain if a tensorview does not have a device mesh. Since our reorderShardedAxis preseg pass is not yet fully functional for DID loop split, we can run into errors due to incomplete changes to the preseg passes. This PR is to ease rolling out the changes incrementally.
Eventually, once the updates to all passes are complete, we should not require setting meshes.
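For context, a minimal sketch of what adding the missing device meshes looks like in a test. It assumes the usual MultiDeviceTest fixture (communicator_), the DeviceMesh::createForNumDevices helper, and placeholder tensorviews tv0/tv1 from a hypothetical fusion definition; none of these names are taken from this PR's diff.

// Minimal sketch, assuming the MultiDeviceTest fixture's communicator_
// and DeviceMesh::createForNumDevices; tv0/tv1 are placeholder
// tensorviews from a hypothetical fusion definition.
const DeviceMesh mesh =
    DeviceMesh::createForNumDevices(communicator_->size());

// Setting the mesh explicitly keeps propagateShardings from rewriting
// the manually scheduled loop domains of these tensorviews.
tv0->setDeviceMesh(mesh);
tv1->setDeviceMesh(mesh);

// Shard the outermost loop axis across devices.
tv0->axis(0)->parallelize(ParallelType::DIDx);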