[distributed][Tensor Parallelism] Implement early transforms for column-wise and row-wise linear and embedding #410

Merged: 95 commits merged into main from crpa/tensor-parallel on May 31, 2024

Conversation

@crcrpar crcrpar (Collaborator) commented May 13, 2024

This implements a trace transform that converts one or more linear and/or embedding layers into column-wise or row-wise tensor-parallel ones by (1) sharding their weight and bias and (2) inserting the needed communication and/or scatter operations before and/or after the modified layers.

Of the four supported ops, only row-wise parallel linear leads to a BoundSymbol modification: the bias term is omitted from the sharded linear, and the bias is added to the result of the communication (after post-processing).
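For context, here is a minimal sketch of what that bias handling means for a row-wise parallel linear (illustrative only; the function name, process-group handling, and shapes are assumptions, not the transform's actual code):

import torch
import torch.distributed as dist

def row_parallel_linear(x_local: torch.Tensor, w_local: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Each rank holds a shard of the weight along the input dimension,
    # so the local matmul produces a partial sum and must omit the bias.
    y = x_local @ w_local.t()
    # Sum the partial results across ranks.
    dist.all_reduce(y, op=dist.ReduceOp.SUM)
    # The bias is added exactly once, after the communication.
    return y + bias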


Example

import torch
import torch.nn as nn
import torch.nn.functional as F

import thunder
import thunder.distributed


class Model(nn.Module):
    def __init__(self, n_in: int, n_hidden: int, n_out: int) -> None:
        super().__init__()
        self.l1 = nn.Linear(n_in, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.l2(F.gelu(self.l1(x)))


# rank, n_in, n_hidden, and n_out are assumed to be defined; rank is this
# process's index in the initialized process group.
device = torch.device(f"cuda:{rank}")

model = Model(n_in, n_hidden, n_out).to(device)
jitted_model = thunder.jit(model)
tp_jitted_model = thunder.distributed.column_parallel(jitted_model, ("l1",))
tp_jitted_model = thunder.distributed.row_parallel(tp_jitted_model, ("l2",))

x = torch.randn(..., device=device)  # batch of inputs with n_in features
y = tp_jitted_model(x)
assert y.size(1) == n_out
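Assuming the usual column/row-parallel convention (stated as an assumption here, not verified against the transform's internals), column_parallel shards l1's weight along its output dimension and row_parallel shards l2's along its input dimension, with the transform inserting the necessary communication around each converted layer as described above. The snippet expects one process per GPU with a distributed process group already initialized, e.g. when launched via torchrun.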

cc @Borda @apaz-cli @carmocca @awaelchli @crcrpar

@crcrpar crcrpar requested a review from IvanYashchuk May 13, 2024 13:16
@crcrpar crcrpar force-pushed the crpa/tensor-parallel branch 2 times, most recently from 60844aa to f2278ed on May 13, 2024 16:53
@github-actions github-actions bot added the "documentation" (Improvements or additions to documentation) label on May 13, 2024
@crcrpar crcrpar marked this pull request as ready for review May 13, 2024 18:04
@crcrpar crcrpar force-pushed the crpa/tensor-parallel branch 4 times, most recently from 25b2d14 to f724a88 on May 17, 2024 16:04
@crcrpar crcrpar (Collaborator, Author) commented May 17, 2024

The failures as of f724a88 look related to #432.

Review threads (outdated, resolved): thunder/distributed/tensor_parallel.py (2 threads)
@crcrpar crcrpar changed the title from "[distributed][Tensor Parallelism] Implement Column-wise Linear" to "[distributed][Tensor Parallelism] Implement early transform for Column-wise Parallel" on May 20, 2024
@crcrpar crcrpar marked this pull request as draft May 23, 2024 07:07
@crcrpar crcrpar changed the title from "[distributed][Tensor Parallelism] Implement early transform for Column-wise Parallel" to "[distributed][Tensor Parallelism] Implement early transforms for column-wise and row-wise linear and embedding" on May 24, 2024
@crcrpar crcrpar marked this pull request as ready for review May 28, 2024 12:48
@lantiga lantiga (Collaborator) left a comment:

Amazing work @crcrpar

Mostly nitpicks; it was great fun reviewing this.

Review threads (resolved): thunder/distributed/__init__.py (3), thunder/distributed/prims.py (2), thunder/distributed/tensor_parallel/row_wise.py (1), thunder/executors/torchex.py (1), thunder/tests/distributed/test_ddp.py (3)
@t-vi t-vi (Collaborator) left a comment:
Supergood, LGTM. Very excited to see the test_tensor_parallel_both_column_and_row be the first actual example of composing early_transforms!
I added a few minor nits.

Review threads (resolved): thunder/distributed/prims.py (3), thunder/executors/torchex.py (1), thunder/distributed/tensor_parallel/column_wise.py (2)
This is to avoid passing the preprocessed input into other ops that are supposed to take the original input.

For example, suppose we have two embeddings and only one of them is column-parallel; the previous implementation modified the input regardless of each embedding's parallelism.
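To illustrate the scenario described above (a hypothetical setup; module names, sizes, and device handling are made up for illustration, not taken from the tests):

import torch
import torch.nn as nn

import thunder
import thunder.distributed


class TwoEmbeddings(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb_parallel = nn.Embedding(1000, 64)    # to be converted to column-parallel
        self.emb_replicated = nn.Embedding(1000, 64)  # left as-is

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # Both embeddings consume the same original indices; only the
        # column-parallel one should see the transform's input preprocessing.
        return self.emb_parallel(idx) + self.emb_replicated(idx)


jitted = thunder.jit(TwoEmbeddings().to("cuda"))
# Only "emb_parallel" is converted; "emb_replicated" must keep receiving the unmodified input.
tp_model = thunder.distributed.column_parallel(jitted, ("emb_parallel",))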

crcrpar added 15 commits May 30, 2024 07:38
@t-vi t-vi enabled auto-merge (squash) May 31, 2024 10:50
@lantiga lantiga (Collaborator) left a comment:

Looks great! Ship it! 🚀

@t-vi t-vi merged commit 9107a3d into main May 31, 2024
37 checks passed
@t-vi t-vi deleted the crpa/tensor-parallel branch May 31, 2024 16:47
@crcrpar crcrpar added the "tensor parallel" (distributed - tensor parallel) label on Jun 6, 2024
Labels: distributed · documentation (Improvements or additions to documentation) · tensor parallel (distributed - tensor parallel)
Projects: None yet
Development: Successfully merging this pull request may close these issues: None yet
4 participants