
Send / Recv Order #156

Closed
merrymercy opened this issue Nov 3, 2021 · 3 comments
Labels
wontfix This will not be worked on

Comments

@merrymercy
Member

merrymercy commented Nov 3, 2021

We want to schedule the order of send/recv operations when a stage sends tensors to multiple stages.

Principle:

  • Sender: Send to earlier stages first
  • Receiver: Receive from earlier stages first

This needs more thought.

@ZYHowell
Collaborator

ZYHowell commented Nov 4, 2021

The problem is not only the ordering among multiple senders/receivers, but also the ordering between sends and receives, because in our current implementation both use the default stream.
Here is a simple explanation:
Suppose a send/recv takes 1 unit of time and a computation takes k = 2 units (in a real BERT layer, k >> 1).
Denote microbatches by 1, 2, ..., and the communication between mesh i and mesh j by cij. If a time unit on Mesh1 is occupied by c12, the same unit is also occupied on Mesh2. All communication and computation for the same microbatch have the same color.
Denote meshes by M1, M2, ....

Then, if we use the principle "for each mesh, always handle the prior stage first, whether sending or receiving", the timeline looks like this:
[Image: timeline when every mesh handles the prior stage first]
That is, computation on Mesh4 cannot start until c34 is done, but c34 comes after c23 on Mesh3, and c23 comes after c12 on Mesh2. Consequently, Mesh i waits for all of its prior communication to finish (i units of time).
A better schedule looks like this:
[Image: timeline of the alternating schedule]
That is, in the first unit we run c12, c34, c56, ..., and in the second we run c23, c45, .... With this schedule, only 2 units are needed no matter how many meshes there are.
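To make the chain argument concrete, here is a minimal sketch that counts communication time units on a chain of meshes. It assumes a simplified model that is not the project's runtime: each cij takes 1 unit, every mesh issues its send/recv calls strictly in order on a single stream (as with the shared default stream described above), a cij can only run once both meshes have reached it, and computation is ignored. The helper names (`simulate`, `chain_queues`, `odd_even_queues`) are hypothetical.

```python
from typing import Dict, List, Tuple

# A communication cij between mesh i and mesh j, stored as (i, j) with i < j.
Edge = Tuple[int, int]


def simulate(queues: Dict[int, List[Edge]]) -> int:
    """Count the time units needed to finish every mesh's communication queue.

    Assumed model: every cij takes 1 unit, each mesh runs its queue in order
    on one stream, and cij runs only when it is at the head of BOTH mesh i's
    and mesh j's queues. Computation is ignored.
    """
    heads = {m: 0 for m in queues}
    t = 0
    while any(heads[m] < len(q) for m, q in queues.items()):
        ready = set()
        for m, q in queues.items():
            if heads[m] == len(q):
                continue
            i, j = q[heads[m]]
            peer = j if m == i else i
            pq = queues[peer]
            # Both endpoints must currently be blocked on the same cij.
            if heads[peer] < len(pq) and pq[heads[peer]] == (i, j):
                ready.add((i, j))
        if not ready:
            raise RuntimeError("deadlock: no communication can make progress")
        for i, j in ready:  # all ready communications run in the same unit
            heads[i] += 1
            heads[j] += 1
        t += 1
    return t


def chain_queues(n: int) -> Dict[int, List[Edge]]:
    """Prior-stage-first order on a chain M1-M2-...-Mn: each mesh handles
    c(i-1)i before ci(i+1)."""
    queues: Dict[int, List[Edge]] = {m: [] for m in range(1, n + 1)}
    for i in range(1, n):
        queues[i].append((i, i + 1))
        queues[i + 1].append((i, i + 1))
    return queues


def odd_even_queues(n: int) -> Dict[int, List[Edge]]:
    """Alternating order: c12, c34, c56, ... first, then c23, c45, ...."""
    return {
        m: sorted(q, key=lambda e: e[0] % 2 == 0)  # edges with odd i sort first
        for m, q in chain_queues(n).items()
    }


if __name__ == "__main__":
    for n in (4, 6, 8):
        print(n, simulate(chain_queues(n)), simulate(odd_even_queues(n)))
```

Under these assumptions, the prior-stage-first order needs n - 1 units to drain an n-mesh chain, because each ci(i+1) has to wait for c(i-1)i on mesh i, while the alternating order needs only 2 units for any n >= 3. This is the linear-versus-constant gap described above.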

@ZYHowell
Collaborator

ZYHowell commented Nov 4, 2021

In our current case, where only the embedding gives a mesh multiple senders/receivers, the order of send/recv has little influence once #157 is addressed.
[Image: two timelines (top and bottom) comparing send/recv orders]
In the last column of the figure above, the top timeline is slightly better than the bottom one because of a tricky communication, c23. It is only one communication faster: even though the first microbatch finishes much earlier on Mesh3 and later meshes, those meshes eventually have to wait to receive results for the second microbatch. As a result, only the time of one c23 is saved.

@ZYHowell
Collaborator

ZYHowell commented Nov 4, 2021

In the general case, where each microbatch has some communications c_ij with j neither i+1 nor i-1 (common in U-Net), always sending to/receiving from earlier stages first is not enough.
Here is an example using the same notation as above, with an additional communication c14 for each microbatch:
There are two policies: the first runs c14 before c12, and the second runs c12 before c14.
[Image: timelines of the two policies]
Running c14 before c12 gives slightly better performance.

However, when we extend this to 6 stages with extra communications between (1, 6) and (2, 5), the situation is completely different because Mesh2 is too busy: it has three communications (c12, c23, c25), so its communication delays other meshes and creates bubbles.
[Image: 6-stage timelines before and after swapping c25 and c23]

As marked in the figure, swapping c25 and c23 makes the pipeline tighter, which is the opposite of the first example, where sending/receiving with later stages first was better.
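Since neither "earlier stages first" nor "later stages first" wins in every case, one way to explore the 6-stage example is to brute-force every per-mesh order. The sketch below does this under a simplified, assumed model that is not the project's runtime: each cij takes 1 unit, each mesh runs its send/recv calls in order on one stream, a cij runs only when both meshes have reached it, and computation is ignored. All names (`simulate`, `prior_first`, `best_order`, `EDGES`) are hypothetical.

```python
import itertools
from typing import Dict, List, Optional, Tuple

Edge = Tuple[int, int]  # cij between mesh i and mesh j, with i < j
Queues = Dict[int, List[Edge]]

# The 6-stage example: a chain plus the extra (1, 6) and (2, 5) communications.
EDGES: List[Edge] = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (1, 6), (2, 5)]


def simulate(queues: Queues) -> Optional[int]:
    """Time units to drain all queues under the assumed model, or None if
    the chosen order deadlocks (meshes waiting on each other in a cycle)."""
    heads = {m: 0 for m in queues}
    t = 0
    while any(heads[m] < len(q) for m, q in queues.items()):
        ready = set()
        for m, q in queues.items():
            if heads[m] < len(q):
                i, j = q[heads[m]]
                peer = i + j - m
                pq = queues[peer]
                # cij runs only if it is at the head of both endpoints' queues.
                if heads[peer] < len(pq) and pq[heads[peer]] == (i, j):
                    ready.add((i, j))
        if not ready:
            return None
        for i, j in ready:
            heads[i] += 1
            heads[j] += 1
        t += 1
    return t


def incident(edges: List[Edge], n: int) -> Queues:
    return {m: [e for e in edges if m in e] for m in range(1, n + 1)}


def prior_first(edges: List[Edge], n: int) -> Queues:
    # Every mesh handles lower-numbered peers first, for both sends and recvs.
    return {m: sorted(q, key=lambda e: e[0] + e[1] - m)
            for m, q in incident(edges, n).items()}


def best_order(edges: List[Edge], n: int) -> Tuple[int, Queues]:
    """Try every per-mesh permutation (only feasible for tiny graphs)."""
    inc = incident(edges, n)
    meshes = sorted(inc)
    best: Optional[Tuple[int, Queues]] = None
    for combo in itertools.product(*(itertools.permutations(inc[m]) for m in meshes)):
        queues = {m: list(order) for m, order in zip(meshes, combo)}
        t = simulate(queues)
        if t is not None and (best is None or t < best[0]):
            best = (t, queues)
    assert best is not None
    return best


if __name__ == "__main__":
    print("prior-first:", simulate(prior_first(EDGES, 6)))
    t, queues = best_order(EDGES, 6)
    print("best:", t)
    for m, q in queues.items():
        print(f"  M{m}:", q)
```

In this communication-only model, the prior-first order takes 5 units, while the search finds a 3-unit order. That is the lower bound, since Mesh2 and Mesh5 each handle three communications. Because computation and microbatch pipelining are ignored, the order the search picks need not match the timelines in the figures above.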

ZYHowell added the wontfix label (This will not be worked on) on Nov 15, 2021.