Support distributed Adam with T5 and support overlapped grad reductions with pipeline parallelism #4900
Conversation
If params are bf16, dist Adam will only store the 16-bit remainder needed to reconstruct fp32 params. Signed-off-by: Tim Moon <tmoon@nvidia.com>
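For intuition: a bf16 value is (ignoring rounding in the bf16 cast, which a real implementation has to account for) the top 16 bits of the corresponding fp32 master value, so keeping only the low 16 bits alongside the bf16 param is enough to rebuild the fp32 master exactly. A minimal numpy sketch of the idea, illustrative only and not the Apex implementation:

```python
import numpy as np

# bf16 is (up to rounding) the high 16 bits of fp32; the optimizer then only
# needs to keep the low 16 bits ("remainder") to reconstruct fp32 masters.
master = np.array([0.123456789, -3.14159], dtype=np.float32)
bits = master.view(np.uint32)

bf16_bits = (bits >> 16).astype(np.uint16)      # already held as the bf16 model params
remainder = (bits & 0xFFFF).astype(np.uint16)   # extra 16 bits stored by the optimizer

rebuilt = ((bf16_bits.astype(np.uint32) << 16) | remainder).view(np.float32)
assert np.array_equal(rebuilt, master)
```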
Add support for overlapped grad sync with pipeline parallelism in GPT-3. Requires dist Adam optimizer. Signed-off-by: Tim Moon <tmoon@nvidia.com>
This pull request introduces 1 alert and fixes 1 when merging 065a89b into abbe643 - view on LGTM.com.
This pull request introduces 1 alert and fixes 1 when merging 7943ebc into b9cf05c - view on LGTM.com.
Add support for overlapped grad sync with pipeline parallelism in T5. Requires dist Adam optimizer. Signed-off-by: Tim Moon <tmoon@nvidia.com>
This pull request fixes 1 alert when merging e06d34a into b9cf05c - view on LGTM.com.
Running T5 41B on 32 Selene nodes, I see a 1.2x speedup over the pure data-parallel implementation, 66% of the expected memory savings, and nearly identical loss values after 20 steps. Full results with T5 41B and GPT-3 175B: the run configurations are detailed inside. Note that I ran with a relatively small global batch size, which makes communication a more significant portion of runtime.
This pull request fixes 1 alert when merging d528a89 into f1825bc - view on LGTM.com.
This pull request fixes 1 alert when merging 811b59c into f1825bc - view on LGTM.com.
This pull request fixes 1 alert when merging ebd98c4 into 971485c - view on LGTM.com.
This pull request fixes 1 alert when merging b2a61ad into 73fcfd7 - view on LGTM.com.
Changes were made to support pipeline parallelism with an interleaved schedule, which distributed Adam does not currently support. Signed-off-by: Tim Moon <tmoon@nvidia.com>
[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
This pull request fixes 1 alert when merging a088304 into 8656574 - view on LGTM.com.
Review comments (resolved):
nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
Review suggestions from @ericharper and @crcrpar. Signed-off-by: Tim Moon <tmoon@nvidia.com>
[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
Running on a DGX A100 node for 50 steps with 2-way data, tensor, and pipeline parallelism, I see nearly identical learning behavior with and without the distributed optimizer.
I get runtime failures when I run GPT-2 with FP32 and with pipeline parallelism enabled.
This pull request fixes 1 alert when merging aed0e00 into 85fc659 - view on LGTM.com.
LGTM. Thanks!
Remove error when dist Adam and interleaved pipeline parallelism are enabled. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Re-approving
This pull request fixes 1 alert when merging 190f992 into 0336000 - view on LGTM.com.
With NVIDIA/apex#1514 the distributed optimizer supports interleaved pipeline parallelism. Running GPT-2 124M for 20 steps, I get the same loss values with and without the distributed optimizer.
Support distributed Adam with T5 and support overlapped grad reductions with pipeline parallelism (NVIDIA#4900)

* Avoid storing extra copy of params in dist Adam optimizer. If params are bf16, dist Adam will only store the 16-bit remainder needed to reconstruct fp32 params. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Add support for dist Adam in GPT-3 without O2-level AMP. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Add support for dist Adam in Megatron-LM models. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Debug dist Adam support without Megatron AMP O2. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Add support for overlapped grad sync with pipeline parallelism in GPT-3. Requires dist Adam optimizer. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Debug dist Adam support for T5. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Add support for overlapped grad sync with pipeline parallelism in T5. Requires dist Adam optimizer. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Update Apex commits in Dockerfile and Jenkinsfile. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Support distributed Adam in Megatron grad scaler class. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Update dist Adam to accommodate changes in GPT model. Changes were made to support pipeline parallelism with an interleaved schedule, which distributed Adam does not currently support. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Minor tweaks to dist Adam integration. Review suggestions from @ericharper and @crcrpar. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Remove error when dist Adam and interleaved pipeline parallelism are enabled. Signed-off-by: Tim Moon <tmoon@nvidia.com>

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
Signed-off-by: 1-800-bad-code <shane.carroll@utsa.edu>
What does this PR do?
Generalizes distributed Adam support from GPT-3 to T5 and other Megatron-LM models, and implements several performance optimizations, including overlapping gradient reductions with pipeline-parallel training.
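For readers unfamiliar with the overlapping trick: instead of waiting for backward to finish and then reducing all gradients at once, each gradient's reduction is launched asynchronously as soon as that gradient is produced, so communication hides behind the remaining backward compute. The sketch below illustrates only that generic idea with plain PyTorch hooks and all-reduce; it is not the NeMo/Apex distributed-Adam implementation (which reduce-scatters into sharded optimizer state and coordinates with the pipeline schedule), and the script name and hyperparameters are made up.

```python
# overlap_sketch.py -- generic illustration of overlapped gradient reduction.
# Launch with e.g.: torchrun --nproc_per_node=2 overlap_sketch.py
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")  # "nccl" on GPUs; gloo keeps this runnable on CPU
    torch.manual_seed(0)
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    handles = []

    def reduce_grad(param):
        # Fires as soon as this param's grad has been accumulated, i.e. while the
        # rest of backward is still running, so the all-reduce overlaps with compute.
        param.grad /= dist.get_world_size()
        handles.append(dist.all_reduce(param.grad, async_op=True))

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(reduce_grad)  # needs PyTorch >= 2.1

    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()          # hooks launch async reductions layer by layer
    for h in handles:
        h.wait()             # drain outstanding reductions before the optimizer step
    opt.step()
    opt.zero_grad()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```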
Collection: NLP
Changelog
Usage
Set optimizer to distributed_fused_adam in the config file NeMo/examples/nlp/language_modeling/conf/megatron_t5_config.yaml (line 137 in 265f7b1).
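The change amounts to switching the optimizer name in the optim block, assuming the usual NeMo layout where optimizer settings live under model.optim; apart from optim.name, the keys and values below are illustrative and may not match the actual megatron_t5_config.yaml:

```yaml
# Illustrative excerpt; only optim.name is the setting this PR is about.
model:
  optim:
    name: distributed_fused_adam   # replaces the default optimizer name (e.g. fused_adam)
    lr: 0.0001
    weight_decay: 0.01
    betas: [0.9, 0.999]
```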
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The contributor guidelines list specific people who can review PRs to various areas.
Additional Information