[amp] refine transformer fp16 train #1574
Conversation
      scaler = paddle.amp.GradScaler(
          init_loss_scaling=args.scale_loss)
-     with paddle.amp.auto_cast():
+     with paddle.amp.auto_cast(custom_black_list={'scale', 'reduce_sum', 'elementwise_div'} if amp_level=='O2' else {}, level=amp_level):
run pre-commit.
done, tks!
@@ -71,6 +72,11 @@ def parse_args():
          default=None,
          type=str,
          help="The eos token. It should be provided when use custom vocab_file. ")
+     parser.add_argument(
If you add --num_workers via the arg parser, note that it also needs to be added in:
tests/transformer/train.py
examples/machine_translation/transformer/predict.py
examples/machine_translation/transformer/static/train.py
examples/machine_translation/transformer/static/predict.py
examples/machine_translation/transformer/deploy/python/inference.py
examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py
Alternatively, add it to the yaml files:
examples/machine_translation/transformer/configs/transformer.base.yaml
examples/machine_translation/transformer/configs/transformer.big.yaml
at line 28, under the "Args for reader" section.
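A minimal sketch of the arg-parser route, matching the style of the `parser.add_argument` calls in the diff above (the default value and help text here are assumptions, not from the PR):

```python
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_workers",
        default=0,
        type=int,
        help="Number of DataLoader worker processes. "
             "0 loads data in the main process. ")
    return parser
```

Whichever route is chosen, the flag must resolve to the same config key in every entry-point script listed above.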
done, tks!
LGTM
PR types
Performance optimization
PR changes
Models
Description
Optimize transformer dygraph fp16 training speed
1. Optimizations:
Use the multi_tensor strategy for the Adam optimizer
Use the set_to_zero=False strategy for clear_grad
Support num_workers > 0 for the dataloader: --num_workers
Optimize the order of some code calls
2、性能测试: