
[hybrid performance] Optimize pipeline send wait #34086

Merged

merged 2 commits into PaddlePaddle:develop on Jul 12, 2021

Conversation


@wangxicoding wangxicoding commented Jul 11, 2021

PR types

Performance optimization

PR changes

Others

Describe

1. Add a nop op that performs no operation. Its main uses (needed for static graphs; dynamic graphs do not need it):

  • Hold a variable so its device memory is not released. In the pipeline it can hold the variable used by send until the send finishes.
  • Add a topological dependency. In recompute, it can add a topological dependency between the recomputed block and the backward ops.

2. Optimize the wait_comm of the pipeline forward send: replace sync_comm with nop while still guaranteeing the send completes, removing unnecessary synchronization.
As the figure below shows, from the execution order, once the backward recv of the current stage has completed, the forward send must have completed as well, so in this scenario the send needs no synchronization. However, if no op consumes the variable used by send, that variable will be reclaimed by GC, which causes errors when send runs on a separate communication stream. To solve this, a nop op is added after the backward recv, ensuring the variable is reclaimed only after the recv completes.

(figure: pipeline execution order showing the forward send overlapping with the backward recv of the current stage)
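The GC-holding mechanism above can be sketched in plain Python (hypothetical names, not the Paddle API): model a garbage collector that frees a variable once no later op lists it as an input, and show how inserting a nop op after the backward recv keeps the send buffer alive until that point.

```python
# Minimal sketch, assuming a program is a list of (op_name, input_vars) pairs.
# `last_use` models the GC rule: a variable may be freed after the last op
# that reads it.

def last_use(program, var):
    """Index of the last op that reads `var`; -1 if never read."""
    idx = -1
    for i, (op, inputs) in enumerate(program):
        if var in inputs:
            idx = i
    return idx

# The forward send runs asynchronously on a communication stream; without a
# later consumer, its buffer would be freed right after the send op is issued.
program = [
    ("forward_send", ["hidden@stage0"]),
    ("backward_recv", ["grad@stage1"]),
    ("backward_ops", ["grad@stage1"]),
]
assert last_use(program, "hidden@stage0") == 0  # freed immediately after send

# Insert a nop after backward_recv that takes the send buffer as an input:
# by the time the backward recv finishes, the forward send must also be done,
# so releasing the buffer there is safe and no sync_comm is needed.
program.insert(2, ("nop", ["hidden@stage0"]))
assert last_use(program, "hidden@stage0") == 2  # held until after the recv
```

This is only an illustration of the dependency trick; in the actual PR the nop op is inserted into the static-graph pipeline program.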

Tested with the gpt2-en model on V100 32G:

  • With NVLINK

| GPUs | Optimization | dtype | speed (tokens/s) | Speedup |
|------|--------------|-------|------------------|---------|
| 4-GPU pp | baseline | fp32 | 25858 | |
| | | fp16 | 64029 | |
| 4-GPU pp | send wait optimized | fp32 | 25917 | 0.22% |
| | | fp16 | 65005 | 1.52% |

  • With SHM

| GPUs | Optimization | dtype | speed (tokens/s) | Speedup |
|------|--------------|-------|------------------|---------|
| 4-GPU pp | baseline | fp32 | 25306 | |
| | | fp16 | 59507 | |
| 4-GPU pp | send wait optimized | fp32 | 25382 | 0.3% |
| | | fp16 | 60772 | 2.1% |

@paddle-bot-old
Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@wangxicoding wangxicoding changed the title Optimize pp send wait [hybrid performance] Optimize pipeline send wait Jul 12, 2021
@sandyhouse sandyhouse left a comment


LGTM

@Aurelius84 Aurelius84 left a comment


LGTM for the op DataType registration.

@JZ-LIANG JZ-LIANG left a comment


LGTM for the pipeline case. But for the recompute scenario, the major training bottleneck is GPU memory rather than the concurrency of computation (backward & recompute). Allowing recompute to run concurrently might reduce the maximum batch size recompute can reach.

@wangxicoding wangxicoding merged commit 5f65ff9 into PaddlePaddle:develop Jul 12, 2021
@wangxicoding wangxicoding deleted the optimize_pp_send_wait branch July 12, 2021 12:05
```diff
 assert dev_type == "gpu" or dev_type == 'npu', (
     "Now only gpu and npu devices are supported "
     "for pipeline parallelism.")
-if not device in device_list:
+if device not in device_list:
```
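The two spellings in the diff are semantically identical in Python; `x not in y` is simply the idiomatic form recommended by PEP 8. A quick check (variable names here are illustrative):

```python
# `not device in device_list` parses as `not (device in device_list)`,
# so both forms always evaluate to the same boolean.
device_list = ["gpu", "npu"]
for device in ("gpu", "npu", "xpu"):
    assert (not device in device_list) == (device not in device_list)
```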

What is the reason for this change?
