
Support pure fp16 for gpt static #1353

Merged
merged 24 commits into PaddlePaddle:develop on Nov 29, 2021

Conversation

zhangbo9674 (Contributor) commented Nov 23, 2021

PR types

Performance optimization

PR changes

Models

Description

Support pure fp16 for gpt static.

Speed test:
Environment: V100-32G-6, clip is ClipByNorm, global_batch_size=8, use_recompute=false.
amp - O1: [screenshot]
amp - O2: [screenshot]

@zhangbo9674 zhangbo9674 changed the title Support gpt-3 for static purefp16 Support pure fp16 for gpt static Nov 23, 2021
@@ -329,7 +329,8 @@ def build_dataset(index, name, num_samples):
     sample_ids=sample_ids,
     sample_lens=sample_lens,
     eos_id=eos_id,
-    seed=args.seed)
+    seed=args.seed,
Collaborator

Just copy a separate version into gpt-3/static; otherwise this is incompatible with gpt/run_pretrain.py.

Contributor Author

Done.

-    if args.grad_clip > 0:
-        clip = paddle.fluid.clip.GradientClipByGlobalNorm(
+    if args.grad_clip > 0:
+        clip = paddle.fluid.clip.GradientClipByNorm(
Collaborator

Same as before: do not change this here.
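For context on why swapping these two clip classes is not a drop-in change, here is a plain-Python sketch (illustrative only, not Paddle's implementation) of how the two strategies differ: clip-by-norm rescales each gradient tensor by its own L2 norm, while clip-by-global-norm rescales all gradients jointly, preserving the direction of the overall update.

```python
import math

# Sketch of the two clipping strategies discussed above (not Paddle code).
def clip_by_norm(grad, clip_norm):
    """Rescale one gradient by its own L2 norm if it exceeds clip_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= clip_norm:
        return grad
    return [g * clip_norm / norm for g in grad]

def clip_by_global_norm(grads, clip_norm):
    """Rescale all gradients together by their joint L2 norm."""
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if global_norm <= clip_norm:
        return grads
    scale = clip_norm / global_norm
    return [[g * scale for g in grad] for grad in grads]
```

Per-tensor clipping can change the relative magnitudes of updates across parameters, which is why the reviewer asks to keep the original call rather than silently switching strategies.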

Contributor Author

Done.

@@ -20,6 +20,7 @@
 import random
 import time
 import sys
+from paddle.fluid import core
Collaborator

Move this up to sit with the other paddle imports?

Collaborator

It looks like `core` is not actually used here.

Contributor Author

Done.

@@ -357,6 +351,11 @@ def do_train(args):
     exe = paddle.static.Executor(place)
     exe.run(startup_program)
     test_program = main_program.clone(for_test=True)

+    if args.use_amp and args.use_fp16:
Collaborator

The logic here needs to be made clearer. This works, but let's discuss it and settle on a common convention for the future:

- use_fp16 [True, False]
- fp16_level ["amp", "pure_fp16"]

Member

We should survey how comparable products define this, and ideally align with the other mainstream usage.

Contributor Author

One issue: Paddle's static graph mode distinguishes these cases with the concepts amp and pure_fp16, and its mixed-precision API is also defined in terms of pure_fp16.

Contributor Author
@zhangbo9674 zhangbo9674 Nov 23, 2021

Changed to:

- use_amp [True, False]
- amp_level ["O1", "O2"]
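As a hedged sketch of how these flags might be wired up (the helper name is hypothetical; `use_pure_fp16` is the switch in Paddle's static-graph mixed-precision decorator referenced elsewhere in this thread): O1 keeps fp32 parameters and runs selected ops in fp16, while O2 runs pure fp16.

```python
# Hypothetical helper (not from the PR) mapping the agreed flags onto the
# pure-fp16 switch of Paddle's static-graph AMP decorator.
def amp_decorate_kwargs(use_amp: bool, amp_level: str) -> dict:
    if not use_amp:
        return {}
    if amp_level not in ("O1", "O2"):
        raise ValueError("amp_level must be 'O1' or 'O2', got %r" % amp_level)
    # O1: mixed precision with fp32 parameters; O2: pure fp16.
    return {"use_pure_fp16": amp_level == "O2"}
```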

Collaborator
@ZHUI ZHUI Nov 24, 2021

Should we also unify the argument style of the gpt-3 dygraph version here?

TODO: if this is changed, the related benchmark scripts need to be updated:
https://github.com/PaddlePaddle/PaddleNLP/blob/develop/tests/benchmark/run_benchmark.sh#L59
https://github.com/PaddlePaddle/benchmark/blob/master/dynamic_graph/gpt/paddle/run_benchmark.sh#L61

I can take care of the PaddlePaddle/benchmark change; please update PaddleNLP/blob/develop/tests/benchmark/run_benchmark.sh along the way.

@@ -227,11 +228,15 @@ def forward(self,
         # scale dot product attention
         product = layers.matmul(
             x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

+        fuse = True
Collaborator

It would be better to make this configurable.

Contributor Author

The dygraph version uses softmax_mask_fuse_upper_triangle directly, so the static graph version adopts the same strategy here.
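For readers unfamiliar with the fused op: here is a plain-Python sketch of the math it computes for one row of attention scores. In causal (GPT-style) attention, scores at positions after the query are masked out before the softmax; the fused op does the mask and the softmax in one kernel, while this sketch only illustrates the result.

```python
import math

# Sketch of the math behind an upper-triangle masked softmax (one row).
def causal_masked_softmax(scores, query_pos):
    # Mask out future positions (j > query_pos) with -inf before softmax.
    masked = [s if j <= query_pos else float("-inf")
              for j, s in enumerate(scores)]
    m = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

# A query at position 1 in a length-3 row attends only to positions 0 and 1.
weights = causal_masked_softmax([1.0, 1.0, 1.0], query_pos=1)
```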

@ZHUI
Collaborator

ZHUI commented Nov 24, 2021

If you have before/after speed comparison results, please paste them in the PR description.

@@ -23,6 +23,7 @@
 from paddle.fluid import layers

Collaborator
@ZHUI ZHUI Nov 24, 2021

I cannot leave comments on the dataset.py file directly, but this should go up two directory levels; otherwise the path to the data_tools directory is wrong:

# Used to load data_tools path.
sys.path.insert(0, "../")

->

# Used to load data_tools path.
sys.path.insert(0, "../../")
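Why the depth matters: relative entries in sys.path are resolved against the current working directory at import time. A self-contained demonstration with a hypothetical layout (the directory names below are illustrative, mirroring the review comment, not the actual repo tree):

```python
import os
import sys
import tempfile

# Hypothetical layout: a `data_tools` package sits two levels above the
# script's working directory, so "../" is not enough and "../../" is needed.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "data_tools"))
with open(os.path.join(root, "data_tools", "__init__.py"), "w") as f:
    f.write("VALUE = 42\n")

workdir = os.path.join(root, "gpt-3", "static")
os.makedirs(workdir)
os.chdir(workdir)

sys.path.insert(0, "../../")  # up two levels: reaches `root`
import data_tools             # now importable

print(data_tools.VALUE)  # -> 42
```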

Contributor Author

Done.



if args.use_amp and args.amp_level=="O2":
    optimizer.amp_init(place)

Collaborator

Could you briefly explain how this is used?

Collaborator

The way amp_init is used here feels a little odd.

Contributor Author

Pure fp16 requires casting the network parameters from fp32 to fp16, and amp_init is what performs that parameter dtype conversion:
https://github.com/PaddlePaddle/Paddle/blob/ed7a21dea0ddcffb6f7f33ce21c5c368f5c7866b/python/paddle/fluid/contrib/mixed_precision/decorator.py#L207
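To make the effect of that cast concrete, here is an illustrative plain-Python sketch (not Paddle's amp_init): struct's "e" format packs a Python float as an IEEE half-precision value, so a pack/unpack round trip shows what value each fp32 parameter would actually hold after being cast to fp16.

```python
import struct

# Sketch (not Paddle code): cast fp32 parameter values down to fp16,
# as amp_init does for the network parameters before pure-fp16 execution.
def cast_params_to_fp16(params):
    return {name: struct.unpack("e", struct.pack("e", value))[0]
            for name, value in params.items()}

weights = {"linear.w": 0.1, "linear.b": 1.0}
fp16_weights = cast_params_to_fp16(weights)
# 1.0 is exactly representable in fp16; 0.1 picks up a small rounding error.
```

In the PR itself this cast is done in place by optimizer.amp_init(place), after the startup program has been run.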

Collaborator

cc: @zhiqiu Shall we discuss the usage of this API a bit more? It seems very likely that users will forget to call it.

Collaborator

This API is currently required; we can look into optimizing it later. One possible approach is to also pass the startup program into decorate and convert it to fp16 there.

Collaborator

OK, let's record a TODO here then. @zhangbo9674

@@ -329,7 +329,8 @@ def build_dataset(index, name, num_samples):
     sample_ids=sample_ids,
     sample_lens=sample_lens,
     eos_id=eos_id,

Collaborator

Restore the original code here.

Contributor Author

Done.

ZHUI previously approved these changes Nov 29, 2021

Collaborator
@ZHUI ZHUI left a comment

LGTM

Collaborator
@ZHUI ZHUI left a comment

LGTM

@ZHUI ZHUI merged commit d49b6b1 into PaddlePaddle:develop Nov 29, 2021