Broadcast-based allgather in host for-loop #5925
Review updated until commit 46c4698

| Relevant files |
|---|
| Enhancement |
| Bug fix |
| Tests |

PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests
⚡ Recommended focus areas for review
Missing performance data
While this is acknowledged, it would be helpful to have some baseline performance numbers or at least an estimate of the expected performance gap. Please consider adding any available performance metrics or explaining why the current implementation is expected to be faster than the previous approach once integrated with multicast.
!test
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
!test
Additional Comments (1)
Consider extending validation to require |
!test
wujingyue left a comment
It's great to see this work functionally!
!test
!test
    if (auto* swizzle1d = dynamic_cast<Swizzle1D*>(transform)) {
      if (swizzle1d->out()->isParallelized()) {
        continue;
      }
      auto it = loop.erase(swizzle1d->out()).second;
      loop.insert(it, swizzle1d->in(), std::monostate());
      continue;
    }
Missing loop.contains guard for Swizzle1D
The Split case immediately below explicitly guards with if (!loop.contains(split->outer()) || !loop.contains(split->inner())) { continue; } before erasing. The rationale is that an intermediate ID can be absent from loop if it was already replaced by an earlier iteration (e.g., a downstream transform already "consumed" it in the reverse traversal).
The new Swizzle1D case lacks an equivalent guard. If swizzle1d->out() has already been replaced (because a transform downstream of the swizzle was processed first and removed it from loop), calling loop.erase(swizzle1d->out()) on an absent key returns a potentially invalid iterator, and the subsequent loop.insert(it, swizzle1d->in(), ...) then inserts at the wrong position or inserts a duplicate. This diverges from the consistent defensive pattern used for Split.
Suggested change:

      if (auto* swizzle1d = dynamic_cast<Swizzle1D*>(transform)) {
        if (swizzle1d->out()->isParallelized()) {
          continue;
        }
    +   if (!loop.contains(swizzle1d->out())) {
    +     continue;
    +   }
        auto it = loop.erase(swizzle1d->out()).second;
        loop.insert(it, swizzle1d->in(), std::monostate());
        continue;
      }
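The guard pattern discussed above can be illustrated with a minimal, self-contained sketch. It does not use the real `loop` container from the PR: `LoopIds` (a `std::list` of strings) and `replaceIfPresent` are hypothetical stand-ins chosen only to show why the presence check must come before the erase-then-insert step.

```cpp
#include <algorithm>
#include <cassert>
#include <list>
#include <string>

// Hypothetical stand-in for the ordered `loop` container in the PR:
// an ordered list of ID names. The real container's API may differ.
using LoopIds = std::list<std::string>;

// Replaces `out` with `in` at the same position.
// Returns false and leaves `loop` untouched when `out` is absent,
// which mirrors the `if (!loop.contains(...)) { continue; }` guard.
bool replaceIfPresent(LoopIds& loop,
                      const std::string& out,
                      const std::string& in) {
  auto it = std::find(loop.begin(), loop.end(), out);
  if (it == loop.end()) {
    return false;  // Guard: `out` was already replaced; nothing to undo.
  }
  it = loop.erase(it);   // Iterator to the element after the erased one.
  loop.insert(it, in);   // Insert before it, preserving the position.
  return true;
}
```

With `loop = {"i0", "swz_out", "i2"}`, the first call to `replaceIfPresent(loop, "swz_out", "swz_in")` yields `{"i0", "swz_in", "i2"}`, and a second identical call is a no-op instead of inserting a duplicate or inserting at an arbitrary position.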
    # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
    h, t = 8192, 8192
Copy-paste error in comment — wrong test name
The comment says this ports RowParallelLinear_Forward, but this benchmark exercises column_parallel_linear_forward (column-parallel, not row-parallel). The functional test test_column_parallel_linear_forward above correctly refers to ColumnAndSequenceParallelLinear_Forward.
| # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. | |
| h, t = 8192, 8192 | |
| # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward. |
!test
!test
    // Check if we are going from DID -> Stream, which is a ring allgather.
    // This can be executed as a broadcast or send recvs, which is decided
    // by the presence of a swizzle in the stream id definition.
    if (c_logical_stream_id == p2c.at(p_logical_id)) {
When c_logical_stream_id is nullptr (the consumer has no stream-parallelized ID), the expression p2c.at(p_logical_id) is still evaluated to perform the pointer comparison. If p_logical_id happens not to be a key in p2c, this throws std::out_of_range before we even reach the existing Allgather/Gather path below.
The condition should short-circuit on the null check first:
Suggested change:

    - if (c_logical_stream_id == p2c.at(p_logical_id)) {
    + if (c_logical_stream_id != nullptr &&
    +     c_logical_stream_id == p2c.at(p_logical_id)) {
This makes the intent explicit and avoids the map lookup entirely when there is no stream-parallel consumer ID.
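A small sketch of the short-circuit behavior, assuming `p2c` behaves like a `std::unordered_map` (whose `at` throws `std::out_of_range` for a missing key). The `Id` type, the `IdMap` alias, and both helper functions are hypothetical names introduced only for this illustration; they are not part of the PR.

```cpp
#include <cassert>
#include <stdexcept>
#include <unordered_map>

struct Id {};  // Hypothetical stand-in for an iteration-domain type.

// Assumption: the producer-to-consumer map is a standard unordered_map.
using IdMap = std::unordered_map<const Id*, const Id*>;

// Unguarded form: the lookup always runs, so a missing key throws
// std::out_of_range even when c_stream is nullptr.
bool matchesUnguarded(const Id* c_stream, const IdMap& p2c, const Id* p) {
  return c_stream == p2c.at(p);
}

// Guarded form: && short-circuits, so p2c.at(p) is never evaluated
// when c_stream is nullptr.
bool matchesGuarded(const Id* c_stream, const IdMap& p2c, const Id* p) {
  return c_stream != nullptr && c_stream == p2c.at(p);
}
```

Note that the guarded form only skips the lookup in the null case. If `c_stream` is non-null and `p` is missing from the map, `at` still throws, so the guard addresses the specific scenario described in the comment rather than every missing-key case.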
The broadcast version is very slow, so I am not comparing timings until we integrate this with multicast.