Skip to content

Commit

Permalink
Fix overlap communication of ZeRO stage 1 and 2 (microsoft#5606)
Browse files Browse the repository at this point in the history
`deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer.average_tensor`
only sets reduction stream waiting for default stream. This is ok in
cases where the computation time is longer than the communication time,
but when the communication time is longer, it may result in a rewrite of
the ipg_buffer when the communication is not completed.



![image](https://github.com/microsoft/DeepSpeed/assets/35059704/950cbf8a-f439-4cf9-a364-dcdfd47f46a0)



To fix this bug, the easiest way is just add default stream to wait for
reduction stream at the **same point**. For example, in point 1, the
`reduction stream` needs to wait for '2', so we add a wait_stream to
`reduction stream` waiting for `default stream`. Also, the `default
stream` needs to wait for 'A', so we need to add a wait_stream to
`default stream` waiting for `reduction stream` before the 'B'.


![image](https://github.com/microsoft/DeepSpeed/assets/35059704/588a9469-d3f9-4c39-976d-3ae0502cf1d1)



Compared with the modification of
microsoft#5523, wait_stream does not
cause host synchronization.

Compared with the modification of
microsoft#5545, the modification is
more simple and the logic is the same, just waiting for what needs to
wait.

---

With this modification, losses of Qwen-1.5 with and without overlap_comm
are totally identical.


![image](https://github.com/microsoft/DeepSpeed/assets/35059704/4d48d54e-e55b-4230-8b99-93549910a43f)

---

On the contrary, there is an obvious gap with a small sequence length,
which means a short computation time.


![image](https://github.com/microsoft/DeepSpeed/assets/35059704/c80af498-3358-4e36-9b13-8f266551d51d)

Co-authored-by: gp513 <guopeng34@huawei.com>
Co-authored-by: CurryRice233 <nmeia@qq.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
5 people authored and sfc-gh-reyazda committed Jun 10, 2024
1 parent 8a4d03c commit 5e5b1f7
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ def average_tensor(self, tensor):
stream = self.reduction_stream
if not get_accelerator().resolves_data_dependency():
stream.wait_stream(get_accelerator().current_stream())
get_accelerator().current_stream().wait_stream(stream)
else:
stream = get_accelerator().current_stream()

Expand Down

0 comments on commit 5e5b1f7

Please sign in to comment.