Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[amp] refine transformer fp16 train #1574

Merged
merged 9 commits into from
Jan 12, 2022

Conversation

zhangbo9674
Copy link
Contributor

@zhangbo9674 zhangbo9674 commented Jan 10, 2022

PR types

Performance optimization

PR changes

Models

Description

优化transformer动态图fp16速度

1、优化点:

  • Adam优化器使用multi_tensor策略

  • clear_grad使用set_to_zero=False策略

  • dataloader num_workers支持>0 :--num_workers

  • 优化一些代码调用顺序

2、性能测试:

图片

scaler = paddle.amp.GradScaler(
init_loss_scaling=args.scale_loss)
with paddle.amp.auto_cast():
with paddle.amp.auto_cast(custom_black_list={'scale', 'reduce_sum', 'elementwise_div'} if amp_level=='O2' else {}, level=amp_level):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

run pre-commit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, tks!

@@ -71,6 +72,11 @@ def parse_args():
default=None,
type=str,
help="The eos token. It should be provided when use custom vocab_file. ")
parser.add_argument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

若使用 arg parser 的方式新增 --num_workers,注意还需要在:

  • tests/transformer/train.py
  • examples/machine_translation/transformer/predict.py
  • examples/machine_translation/transformer/static/train.py
  • examples/machine_translation/transformer/static/predict.py
  • examples/machine_translation/transformer/deploy/python/inference.py
  • examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py

也需要补充下。

或是加在 yaml 文件:

  • examples/machine_translation/transformer/configs/transformer.base.yaml
  • examples/machine_translation/transformer/configs/transformer.big.yaml
    28 行,Args for reader 的位置。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, tks!

Copy link
Contributor

@FrostML FrostML left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@FrostML FrostML merged commit 1e4975d into PaddlePaddle:develop Jan 12, 2022
FrostML added a commit that referenced this pull request Jan 25, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants