
Set allocation domain of sharded tensor #2271

Merged: 9 commits merged into NVIDIA:main on May 30, 2024

Conversation

cowanmeg (Collaborator)

Sets the allocation domain of sharded tensors during the pass propagateShardingsAndSetAllocationDomain. The two passes are merged in an attempt to reduce the number of passes over all expressions in the fusion.

The allocation domain is set to the tv's leaf domain. Since presegmentation passes and scheduling occur after the sharding passes, the leaf domain is identical to the rfactor domain. Once DID parallelization of the leaf domain is allowed, the leaf and rfactor domains will no longer be the same.

This will avoid issues such as #2245 (comment) and allow the AllocationDomainPass presegmentation pass to be turned on for the distributed matmul tests.
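Conceptually, the pass does something like the following (illustrative pseudocode, not the actual implementation; `setAllocationDomain` and `getLeafDomain` are the relevant TensorView APIs, but `shardedTensorViews` is a made-up helper name):

```cpp
// Sketch only: pin each sharded TensorView's allocation domain to its leaf
// domain (identical to the rfactor domain at this point in compilation) so
// that later presegmentation passes cannot silently reorder it.
for (TensorView* tv : shardedTensorViews(fusion)) {  // hypothetical helper
  tv->setAllocationDomain(tv->getLeafDomain(), /*new_contiguity=*/true);
}
```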

@cowanmeg cowanmeg marked this pull request as draft May 20, 2024 15:03
@wujingyue wujingyue marked this pull request as ready for review May 20, 2024 16:46
@wujingyue (Collaborator) left a comment

Is it safe now to replace https://github.com/NVIDIA/Fuser/pull/2245/files#diff-db5ba7cef14ad9a3c1eaab113a6f0a6f875e92890b34078f8185e0970022ce45R298 with a check for contiguity? Although there's little performance penalty to call .contiguous on an already contiguous tensor, leaving it there would give readers the wrong impression that the input slices can be non-contiguous, which IIRC shouldn't happen with this PR.

@cowanmeg (Collaborator, Author)

Is it safe now to replace https://github.com/NVIDIA/Fuser/pull/2245/files#diff-db5ba7cef14ad9a3c1eaab113a6f0a6f875e92890b34078f8185e0970022ce45R298 with a check for contiguity? Although there's little performance penalty to call .contiguous on an already contiguous tensor, leaving it there would give readers the wrong impression that the input slices can be non-contiguous, which IIRC shouldn't happen with this PR.

Yup, the added PipelineTwoStage tests pass when the contiguous call is removed, but PipelineTestStagedReduction.StagedReduction/ReductionOnly is failing. I would hold off on reviewing until I fix it!

@jjsjann123 (Collaborator) left a comment

Some nitpicks, but overall it looks straightforward to me.

I'll let @wujingyue stamp since he seems to have more questions. (feel free to delegate it to me for a more thorough review if you don't have time to wrap it up before your vacation 🍹 )

@jjsjann123 (Collaborator)

!build

@jjsjann123 (Collaborator)

FYI, be extra careful with CI for PRs touching allocation domain. Some of our schedulers have some sharp corners on this and could give you some unexpected failures.

@cowanmeg (Collaborator, Author)

Sorry everyone who viewed this before the big changes 😓
The new update only sets the allocation domain of TVs that are in a resharding expression. This fixed most of the tests, except one DistributedMatmulTest.
There are some big changes to contiguity that this doesn't address, namely that DID axes continue to have true/false settings even though they shouldn't have a value at all, since they aren't allocated. I'll fix that in a later PR!
Also, this isn't blocking anything, so feel free to review after you get back @wujingyue
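To illustrate the contiguity point (a standalone sketch, not nvFuser code; the helper name is hypothetical): an axis that is DID-parallelized is never allocated on a single device, so its contiguity flag should carry no value rather than true or false:

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical helper: contiguity flags for an ndims-dimensional tensor whose
// did_axis is DID-parallelized. Allocated axes get a concrete true/false;
// the sharded axis is not allocated locally, so it gets std::nullopt.
std::vector<std::optional<bool>> contiguityWithDid(
    size_t ndims,
    size_t did_axis) {
  std::vector<std::optional<bool>> flags(ndims, true);
  flags[did_axis] = std::nullopt;
  return flags;
}
```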

@cowanmeg (Collaborator, Author)

!build --dist

@wujingyue wujingyue self-requested a review May 20, 2024 21:52
@cowanmeg (Collaborator, Author)

!build

@jjsjann123 (Collaborator) left a comment

Looks pretty clean to me now. Stamping.

for (auto tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
  for (auto c : tv->getContiguity()) {
    if (c.has_value()) {
      NVF_CHECK(
Collaborator:

I'm confused. I thought setShardedAllocationDomain ought to make inputs of resharding exprs contiguous rather than expect them to be contiguous. Am I missing something?

Collaborator:

^^^ That doesn't sound right. You cannot change the contiguity / stride order of inputs.

IIUC, the code here validates that each input entry is contiguous and then later explicitly sets the allocation domain on each one if it is implicit.

One part I'm not totally sure about: "Resharding expression input must be contiguous". Should this check also apply to TensorViews for which we are not specifying an allocation domain?

Collaborator:

You cannot change contiguity / stride order on inputs.

You are right for fusion inputs. However, resharding Expr's inputs usually have more flexibility.

Thinking more about this, maybe the logic belongs somewhere near

TensorView* input_permute = permute(input, {{sharding_axis, 0}});
TensorView* output_permute = set(input_permute);
TensorView* new_output = permute(output_permute, {{0, sharding_axis}});
? That's where we actively reorder the resharded dimension to be outermost in rfactor, without, however, enforcing it to be outermost in allocation.
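The effect of that reordering can be shown with a standalone stride argument (illustrative code, not nvFuser's): slicing a row-major tensor along axis k yields one contiguous block only when every axis outside k is trivial, which is exactly what moving the sharded axis to position 0 guarantees.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative helper (not nvFuser code): is one shard, i.e. a slice along
// sharded_axis of a row-major tensor, a single contiguous block of memory?
// Row-major slicing keeps the original strides, so the slice is dense only
// when all axes outer to the sharded axis have extent 1.
bool shardIsContiguous(const std::vector<int64_t>& sizes, size_t sharded_axis) {
  for (size_t i = 0; i < sharded_axis; ++i) {
    if (sizes[i] != 1) {
      return false;
    }
  }
  return true;
}
```

With the sharded axis outermost (`sharded_axis == 0`) the shard is always one dense block, which is why the permute/set/permute sequence avoids non-contiguous buffers.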

Collaborator:

ProcessGroup only accepts contiguous tensors (because nccl and ucc only deal with contiguous buffers, passed as void pointers). So for now it is reasonable to only support contiguous tensors. Later, we could add support for non-contiguous tensors, but to do that we'll have no choice but to make as many process-group calls as there are contiguous components.

So IMO this "assert" makes sense for this PR, and we could remove it later by implementing what I described above.

@wujingyue this line

TensorView* input_permute = permute(input, {{sharding_axis, 0}});
TensorView* output_permute = set(input_permute);
TensorView* new_output = permute(output_permute, {{0, sharding_axis}});

is a "trick" to allow non-outermost resharding while avoiding non-contiguous buffers, by reordering the axes to place the sharded axis at the outermost position (so the actual buffer is contiguous in memory).
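The contiguity requirement itself can be stated concretely (a standalone sketch, not the ProcessGroup implementation): a tensor can be handed to nccl/ucc as a single void* only if its sizes and strides describe one dense row-major block.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone sketch (not ProcessGroup code): true iff (sizes, strides)
// describe one dense row-major block, i.e. a buffer that can be passed to
// nccl/ucc as a single void pointer.
bool isDenseRowMajor(const std::vector<int64_t>& sizes,
                     const std::vector<int64_t>& strides) {
  int64_t expected = 1;
  for (size_t i = sizes.size(); i-- > 0;) {
    if (sizes[i] == 1) {
      continue;  // size-1 axes impose no stride constraint
    }
    if (strides[i] != expected) {
      return false;
    }
    expected *= sizes[i];
  }
  return true;
}
```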

insertReshardings(&fusion);
insertShardedAxisReordering(&fusion);
setShardedAllocationDomain(&fusion);
for (auto expr : fusion.exprs()) {
Collaborator:

How about making the check more specific? For example, is it possible to check the following:

  • I'd expect there to be only one resharding Expr, which is a sum.
  • I'd also expect the input of that Expr to have DID as the first IterDomain in its allocation domain.

@wujingyue (Collaborator)

I'll let @wujingyue stamp since he seems to have more questions. (feel free to delegate it to me for a more thorough review if you don't have time to wrap it up before your vacation 🍹 )

I won't have time to take another look before my vacation. As long as you and @jjsjann123 are confident in the change, please merge it without me. Anyhow, this is a strict improvement and my remaining comments are about potentially making allocation domains around resharding more correct.

@samnordmann (Collaborator) left a comment

Looks good to me! Thanks!

expr);
}
}
setShardedAllocationDomain(tv);
Collaborator:

Can somebody please explain why this is necessary? IIUC, we confirm this tensor is contiguous. Isn't that sufficient?

Collaborator:

We need to explicitly set the allocation domain to avoid optimization passes mutating it (i.e., an empty allocation domain is fair game for optimization passes).

For example, allocation order inference might come in and change the stride order if it's left empty, which would trigger a scheduling error because we cannot yet support stride order for resharding operations.

Collaborator:

So, this is not necessary if the allocation order inference is not done?

Collaborator:

I think so.

For the record, it's not just allocation order inference; alias passes could also update the allocation domain, introducing a similar issue. See Jingyue's comment in #2245 (comment)

cc'ing @cowanmeg for sanity check

@cowanmeg (Collaborator, Author)

!build

1 similar comment
@cowanmeg (Collaborator, Author)

!build

@cowanmeg (Collaborator, Author)

I see an unrelated tolerance error in IndexingOpTest.TorchGatherSumAdd_CUDA, so I will merge this.

@cowanmeg cowanmeg merged commit b60ea8a into NVIDIA:main May 30, 2024
36 of 37 checks passed
protonu pushed a commit that referenced this pull request May 30, 2024