Do we need an N-dim sub-DeviceMesh? #126530

Closed
botbw opened this issue May 17, 2024 · 2 comments
Labels: module: dtensor, oncall: distributed, triaged

Comments


botbw commented May 17, 2024

🚀 The feature, motivation and pitch

Hey there! Currently torch.distributed._tensor.DeviceMesh only supports slicing out 1-D sub-meshes. Would it be possible to manipulate it like an NDArray and generate N-dim sub-meshes?

For example, in 2-D tensor parallelism combined with pipeline parallelism, the mesh looks like [pp, tp0, tp1] == [2, 2, 2]. If an all-gather/all-reduce over the whole tp group is needed for rank == 0, mesh['tp0'] and mesh['tp1'] only give [0, 1] and [0, 2], so rank 3 is missing.

import os

from torch.distributed._tensor import DeviceMesh

# Repro: run with e.g. `torchrun --nproc_per_node=8 <this_script>.py`
rank = int(os.environ["RANK"])

if __name__ == "__main__":
    # 3-D mesh of shape [pp, tp0, tp1] == [2, 2, 2]
    mesh = DeviceMesh(
        "cpu",
        [
            [
                [0, 1],
                [2, 3],
            ],  # pp_rank == 0
            [
                [4, 5],
                [6, 7],
            ],  # pp_rank == 1
        ],
        mesh_dim_names=["pp", "tp0", "tp1"],
    )
    if rank == 0:  # pp_rank == tp0_rank == tp1_rank == 0
        # Each 1-D slice covers only 2 of the 4 tp ranks; rank 3 never appears.
        print(mesh["tp0"])
        print(mesh["tp1"])

A workaround is to represent tp0 and tp1 as a single dim, but there could be scenarios in which tp0 is replicated and only tp1 needs the communication; a sketch of that flattening workaround follows below.
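
For illustration, here is a minimal sketch of that flattening workaround (the flattened "tp" dim name and the 8-rank layout are choices made for this example, not part of the original report):

import os

from torch.distributed._tensor import DeviceMesh

# Sketch: collapse tp0 and tp1 into a single 4-way "tp" dim so that one
# 1-D slice covers the whole tensor-parallel group for the calling rank.
# Run with e.g. `torchrun --nproc_per_node=8 <this_script>.py`
rank = int(os.environ["RANK"])

if __name__ == "__main__":
    mesh = DeviceMesh(
        "cpu",
        [
            [0, 1, 2, 3],  # pp_rank == 0
            [4, 5, 6, 7],  # pp_rank == 1
        ],
        mesh_dim_names=["pp", "tp"],
    )
    if rank == 0:
        # Covers ranks 0-3, but tp0 and tp1 can no longer be addressed separately.
        print(mesh["tp"])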

Alternatively, could you please suggest another workaround for this?

Alternatives

No response

Additional context

No response

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @msaroufim

@awgu added the oncall: distributed label May 17, 2024
@yf225 added the module: dtensor and triaged labels May 20, 2024
botbw (Author) commented May 21, 2024

Hey, any updates on this?
@awgu @yf225 @wanchaol @wz337

wz337 (Contributor) commented May 21, 2024

@botbw Thanks for raising the issue. We are actually working on this feature and hoping to have it in the nightlies ASAP so we can get it into the 2.4 release as well.
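
For reference, a minimal sketch of what N-dim submesh slicing could look like once the feature lands (this assumes DeviceMesh indexing accepts multiple mesh dim names, as in later PyTorch releases; the exact API may differ):

# Sketch only: assumes a PyTorch version where DeviceMesh supports slicing
# a submesh with multiple dim names, e.g. mesh["tp0", "tp1"].
# Run with e.g. `torchrun --nproc_per_node=8 <this_script>.py`
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cpu", (2, 2, 2), mesh_dim_names=("pp", "tp0", "tp1"))

# A 2-D submesh spanning both tensor-parallel dims for the calling rank,
# so a collective over it reaches all 4 tp ranks in the same pp group.
tp_submesh = mesh["tp0", "tp1"]
print(tp_submesh)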

bigfootjon pushed a commit that referenced this issue Jun 5, 2024
Fixes #126530

Pull Request resolved: #127465
Approved by: https://github.com/wconstab

(cherry picked from commit e72232f)
bigfootjon pushed a commit that referenced this issue Jun 5, 2024
petrex pushed a commit to petrex/pytorch that referenced this issue Jun 5, 2024