Support pure fp16 for gpt static #1353
Conversation
…nto dev/support_gpt3_static_fp16
@@ -329,7 +329,8 @@ def build_dataset(index, name, num_samples):
         sample_ids=sample_ids,
         sample_lens=sample_lens,
         eos_id=eos_id,
-        seed=args.seed)
+        seed=args.seed,
Just copy a version into gpt-3/static directly; otherwise it's incompatible with gpt/run_pretrain.py.
Done.
-        if args.grad_clip > 0:
-            clip = paddle.fluid.clip.GradientClipByGlobalNorm(
+        if args.grad_clip > 0:
+            clip = paddle.fluid.clip.GradientClipByNorm(
Same as before; leave this unchanged.
Done.
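For readers following the clip discussion above: the two Paddle clip strategies differ in whether the norm is computed per tensor or across all gradients. A minimal numpy sketch of the semantics (an illustration of the two definitions, not Paddle's implementation):

```python
import numpy as np

def clip_by_norm(grad, clip_norm):
    # GradientClipByNorm: each gradient tensor is clipped by its OWN L2 norm
    norm = np.linalg.norm(grad)
    return grad if norm <= clip_norm else grad * (clip_norm / norm)

def clip_by_global_norm(grads, clip_norm):
    # GradientClipByGlobalNorm: one scale factor from the norm over ALL
    # gradient tensors, preserving their relative magnitudes
    global_norm = np.sqrt(sum(np.sum(g * g) for g in grads))
    scale = min(1.0, clip_norm / global_norm)
    return [g * scale for g in grads]
```

Global-norm clipping keeps the direction of the full gradient vector intact, which is why the reviewer asks to keep `GradientClipByGlobalNorm` here.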
@@ -20,6 +20,7 @@
 import random
 import time
 import sys
+from paddle.fluid import core
Move this next to the other paddle imports?
core doesn't appear to be used here.
Done.
@@ -357,6 +351,11 @@ def do_train(args):
     exe = paddle.static.Executor(place)
     exe.run(startup_program)
     test_program = main_program.clone(for_test=True)
+
+    if args.use_amp and args.use_fp16:
The logic here needs to be made clearer. This works for now, but let's discuss it and settle on a common convention going forward:
- use_fp16 [True, False]
- fp16_level ["amp", "pure_fp16"]
We should survey how comparable products define this, and ideally align with mainstream usage.
One issue: Paddle's static graph distinguishes these as "amp" vs "pure_fp16", and its mixed-precision interface is also defined in terms of pure_fp16.
Changed to:
- use_amp [True, False]
- amp_level ["O1", "O2"]
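For reference, the agreed flag pair might be wired up as follows (a minimal sketch; the parser setup and `str2bool` helper are illustrative, not the PR's actual argument code):

```python
import argparse

def str2bool(v):
    # accept common truthy spellings from the command line
    return str(v).lower() in ("true", "1", "yes")

parser = argparse.ArgumentParser()
parser.add_argument("--use_amp", type=str2bool, default=False,
                    help="Enable mixed-precision training.")
parser.add_argument("--amp_level", choices=["O1", "O2"], default="O1",
                    help="O1: auto mixed precision; O2: pure fp16.")

args = parser.parse_args(["--use_amp", "true", "--amp_level", "O2"])

# O2 corresponds to Paddle static graph's pure-fp16 mode
use_pure_fp16 = args.use_amp and args.amp_level == "O2"
```

Keeping a boolean switch plus a level string mirrors the O1/O2 naming used by other mixed-precision frameworks, which is the alignment discussed above.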
Should we unify the argument style for the gpt-3 dynamic graph here as well?
TODO: if this changes, the benchmark scripts need to be updated:
https://github.com/PaddlePaddle/PaddleNLP/blob/develop/tests/benchmark/run_benchmark.sh#L59
https://github.com/PaddlePaddle/benchmark/blob/master/dynamic_graph/gpt/paddle/run_benchmark.sh#L61
I can take care of the PaddlePaddle/benchmark change; the PaddleNLP tests/benchmark/run_benchmark.sh one can be updated in passing here.
@@ -227,11 +228,15 @@ def forward(self,
         # scale dot product attention
         product = layers.matmul(
             x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
+
+        fuse = True
It would be better to make this configurable.
The dynamic graph uses softmax_mask_fuse_upper_triangle directly, so the static graph adopts the same strategy here.
If you have before/after speed comparison results, please include them in the PR description.
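For readers unfamiliar with the fused op: assuming the standard definition (apply a causal upper-triangular mask, then softmax over the last axis), its semantics can be sketched in numpy as:

```python
import numpy as np

def softmax_mask_upper_triangle(scores):
    # Reference semantics of the fused kernel: positions above the diagonal
    # (future tokens) are masked out before the softmax.
    seq_len = scores.shape[-1]
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    masked = np.where(mask, -1e4, scores)          # large negative ≈ -inf
    e = np.exp(masked - masked.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

out = softmax_mask_upper_triangle(np.zeros((4, 4), dtype=np.float32))
```

The fused kernel performs the mask and softmax in one pass instead of materializing the masked scores, which is where the speedup comes from.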
@@ -23,6 +23,7 @@
+from paddle.fluid import layers
I can't leave comments on the dataset.py file, but this should go up two directory levels; otherwise the path to the data_tools directory is wrong:
# Used to load data_tools path.
sys.path.insert(0, "../")
->
# Used to load data_tools path.
sys.path.insert(0, "../../")
Done.
+        if args.use_amp and args.amp_level=="O2":
+            optimizer.amp_init(place)
Could you briefly explain how this is used?
The way amp_init has to be called here feels a bit odd.
Pure fp16 requires converting the network parameters from fp32 to fp16; amp_init is the API that performs this parameter dtype conversion:
https://github.com/PaddlePaddle/Paddle/blob/ed7a21dea0ddcffb6f7f33ce21c5c368f5c7866b/python/paddle/fluid/contrib/mixed_precision/decorator.py#L207
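Conceptually, the cast amp_init performs can be sketched with numpy (a simplified illustration of the idea, not Paddle's implementation): the fp32 values produced by the startup program are cast to fp16 for the compute kernels, while the optimizer keeps fp32 master weights for accurate updates.

```python
import numpy as np

# fp32 parameter as produced by the startup program
param_fp32 = np.array([0.12345678, 3.14159265], dtype=np.float32)

# master weights: the optimizer keeps the fp32 copy for accurate updates
master_weight = param_fp32.copy()

# amp_init-style cast: the fp16 copy actually used by compute kernels
param_fp16 = param_fp32.astype(np.float16)
```

Because the cast must happen after `exe.run(startup_program)` initializes the fp32 values but before the first training step, forgetting the `amp_init(place)` call leaves fp16 kernels reading uninitialized or wrongly-typed parameters, which is the usability concern raised below.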
cc: @zhiqiu Shall we discuss the usage of this API a bit more? It seems quite likely that users will forget to call it.
This API is currently required; we'll look into streamlining it later. One possible route is to also pass the startup program into decorate and convert it to fp16 there.
OK, let's leave a TODO here then. @zhangbo9674
@@ -329,7 +329,8 @@ def build_dataset(index, name, num_samples):
         sample_ids=sample_ids,
         sample_lens=sample_lens,
         eos_id=eos_id,
Restore the original code here.
Done.
…nto dev/support_gpt3_static_fp16
LGTM
…nto dev/support_gpt3_static_fp16
…ngbo9674/PaddleNLP into dev/support_gpt3_static_fp16
LGTM
PR types
Performance optimization
PR changes
Models
Description
Support pure fp16 for gpt static.
Speed Test:
![image](https://user-images.githubusercontent.com/82555433/143164127-60aa27ea-ee73-4c2d-af40-44304527cca7.png)
![image](https://user-images.githubusercontent.com/82555433/143164227-fd772778-febc-4af9-a337-a5dd5e1a2da7.png)
Environment: V100-32G ×6, clip is ClipByNorm, global_batch_size=8, use_recompute=False.
amp - O1:
amp - O2: