Set allocation domain of sharded tensor #2271
Conversation
Is it safe now to replace https://github.com/NVIDIA/Fuser/pull/2245/files#diff-db5ba7cef14ad9a3c1eaab113a6f0a6f875e92890b34078f8185e0970022ce45R298 with a check for contiguity? Although there's little performance penalty to calling `.contiguous()` on an already contiguous tensor, leaving it there would give readers the wrong impression that the input slices can be non-contiguous, which IIRC shouldn't happen with this PR.
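The check-vs-call distinction can be sketched in plain Python. This is an illustrative model of row-major contiguity (the rule `aten::is_contiguous` follows), not nvFuser or ATen code:

```python
# Illustrative model: a dense row-major tensor is contiguous iff each
# stride equals the product of the sizes to its right (size-1 dims are
# ignored, since their stride never affects addressing).
def is_contiguous(sizes, strides):
    expected = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True

print(is_contiguous([2, 3, 4], [12, 4, 1]))  # True: dense row-major
print(is_contiguous([2, 3, 4], [24, 4, 1]))  # False: padded outer stride
```

A pure check like this is O(rank) and never copies, whereas `.contiguous()` on a non-contiguous tensor materializes a fresh buffer.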
Yup, the added
Some nitpicks, but overall it looks straightforward to me.
I'll let @wujingyue stamp since he seems to have more questions. (feel free to delegate it to me for a more thorough review if you don't have time to wrap it up before your vacation 🍹 )
!build
FYI, be extra careful with CI for PRs touching
Sorry everyone who viewed before the big changes 😓
!build --dist
!build
Looks pretty clean to me now. stamping.
csrc/multidevice/utils.cpp
Outdated
```cpp
for (auto tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
  for (auto c : tv->getContiguity()) {
    if (c.has_value()) {
      NVF_CHECK(
```
I'm confused. I thought setShardedAllocationDomain ought to make inputs of resharding exprs contiguous rather than expect them to be contiguous. Am I missing something?
^^^ That doesn't sound right. You cannot change contiguity / stride order on inputs.
IIUC, the code here validates that each input is contiguous and then later explicitly sets the allocation domain on each of them if it is implicit.
One part I'm not totally sure about is the check "Resharding expression input must be contiguous". Should this check also apply to TensorViews for which we are not specifying an allocation domain?
> You cannot change contiguity / stride order on inputs.
You are right for fusion inputs. However, a resharding Expr's inputs usually have more flexibility.
Thinking more about this, maybe the logic should live somewhere near
Fuser/csrc/multidevice/utils.cpp
Lines 393 to 395 in 690134d
```cpp
TensorView* input_permute = permute(input, {{sharding_axis, 0}});
TensorView* output_permute = set(input_permute);
TensorView* new_output = permute(output_permute, {{0, sharding_axis}});
```
ProcessGroup only accepts contiguous tensors (because NCCL and UCC only deal with contiguous buffers, passed as void pointers). So for now it is reasonable to only support contiguous tensors. Later, we could add support for non-contiguous tensors, but to do that we'll have no choice but to make as many process-group calls as there are contiguous components.
So imo this "assert" makes sense for this PR, and we could remove it later by implementing what I described above.
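The "one call per contiguous component" idea can be sketched in plain Python. This is an illustrative model of row-major strides, not the ProcessGroup or nvFuser API, and `contiguous_piece_count` is a hypothetical helper name:

```python
# Illustrative model: a strided row-major view decomposes into equally
# sized contiguous pieces; supporting non-contiguous tensors would mean
# one collective call per piece.
def contiguous_piece_count(sizes, strides):
    # Walk from innermost dim outward; dims stay fused into one piece
    # while the strides remain dense, then each outer dim multiplies
    # the number of separate pieces.
    pieces, expected, dense = 1, 1, True
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if dense and stride == expected:
            expected *= size
        else:
            dense = False
            pieces *= size
    return pieces

print(contiguous_piece_count([8, 4, 16], [64, 16, 1]))   # 1: fully contiguous
print(contiguous_piece_count([8, 4, 16], [128, 16, 1]))  # 8: padded outer dim
```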
@wujingyue this line
Fuser/csrc/multidevice/utils.cpp
Lines 393 to 395 in 690134d
```cpp
TensorView* input_permute = permute(input, {{sharding_axis, 0}});
TensorView* output_permute = set(input_permute);
TensorView* new_output = permute(output_permute, {{0, sharding_axis}});
```
is a "trick" to allow non-outermost resharding while avoiding non-contiguous buffers, by reordering the axes to place the sharded axis at the outermost position (so the actual buffer is contiguous in memory).
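Why the permute trick works can be shown with a small stride computation in plain Python (an illustrative model, not nvFuser code; the helper names are made up for this sketch): taking one index along an axis of a row-major tensor yields a contiguous view only when that axis is outermost.

```python
# Illustrative model: a per-device shard (one index along the sharded
# axis) of a row-major buffer is contiguous iff the sharded axis is
# outermost -- hence permuting it to position 0 before communicating.
def row_major_strides(sizes):
    strides, acc = [], 1
    for size in reversed(sizes):
        strides.append(acc)
        acc *= size
    return strides[::-1]

def shard_is_contiguous(sizes, sharded_axis):
    # Dropping the sharded axis keeps the remaining sizes and strides.
    strides = row_major_strides(sizes)
    sub_sizes = sizes[:sharded_axis] + sizes[sharded_axis + 1:]
    sub_strides = strides[:sharded_axis] + strides[sharded_axis + 1:]
    expected = 1
    for size, stride in zip(reversed(sub_sizes), reversed(sub_strides)):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True

print(shard_is_contiguous([8, 4, 16], 0))  # True: outermost shard
print(shard_is_contiguous([8, 4, 16], 1))  # False: inner shard strides skip memory
```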
```cpp
insertReshardings(&fusion);
insertShardedAxisReordering(&fusion);
setShardedAllocationDomain(&fusion);
for (auto expr : fusion.exprs()) {
```
How do you like making the check more specific? For example, is it possible to check the following:
- I'd expect there to be only one resharding Expr, which is a sum.
- I'd also expect the input of that Expr to have DID as the first IterDomain in its allocation domain.
I won't have time to take another look before my vacation. As long as you and @jjsjann123 are confident in the change, please merge it without me. Anyhow, this is a strict improvement, and my remaining comments are about potentially making allocation domains around resharding more correct.
Looks good to me! Thanks!
csrc/multidevice/utils.cpp
Outdated
```cpp
        expr);
    }
  }
  setShardedAllocationDomain(tv);
```
Can somebody please explain why this is necessary? IIUC, we confirm this tensor is contiguous. Isn't that sufficient?
We need to explicitly set the allocation domain to prevent optimization passes from mutating it (i.e. an empty allocation domain is fair game for optimization passes).
For example, allocation order inference might come in and change the stride order if it's left empty, which would trigger a scheduling error because we cannot yet support stride order for resharding operations.
So, this is not necessary if the allocation order inference is not done?
I think so.
For the record, it's not just allocation order inference; alias passes could also update the allocation domain, introducing a similar issue. See Jingyue's comment in #2245 (comment).
cc'ing @cowanmeg for a sanity check
!build
!build
I see an unrelated tolerance error in
Sets the allocation domain of sharded tensors during the pass `propagateShardingsAndSetAllocationDomain`. The two passes are merged in an attempt to reduce the number of passes over all expressions in the fusion. The allocation domain is set to the tv's leaf domain. Since presegmentation passes and scheduling occur after the sharding passes, the leaf domain is identical to the rfactor domain. Once DID parallelization of the leaf domain is allowed, the leaf and rfactor domains will no longer be the same. This avoids issues such as #2245 (comment) and allows turning on the `AllocationDomainPass` presegmentation pass for distributed matmul tests.