Support distributed Adam with T5 and support overlapped grad reductions with pipeline parallelism #4900

Conversation

timmoon10
Collaborator

@timmoon10 timmoon10 commented Sep 7, 2022

What does this PR do?

This PR generalizes distributed Adam support from GPT-3 to T5 and other Megatron-LM models. It also implements several performance optimizations.

Collection: NLP

Changelog

  • When params are BF16, distributed Adam stores 16-bit param remainders instead of FP32 main params (see the sketch after this list)
  • Decouple distributed Adam support from Megatron O2-level optimizations
  • Add support for the Apex distributed Adam optimizer in other Megatron-LM models, namely T5
  • Add support for overlapped grad reductions with pipeline or sequence parallelism
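
Since BF16 shares FP32's sign and exponent layout, each FP32 main param can be split into its upper 16 bits (a valid BF16 value) and its lower 16 bits (the remainder), and the two halves can be recombined to recover the FP32 value exactly. The NumPy sketch below only illustrates this idea; the actual Apex distributed Adam implementation may differ in storage layout and rounding.

```python
import numpy as np

def split_fp32(params_fp32):
    """Split FP32 params into BF16 bits (upper 16 bits) and a 16-bit remainder (lower 16 bits)."""
    bits = np.ascontiguousarray(params_fp32, dtype=np.float32).view(np.uint32)
    bf16_bits = (bits >> 16).astype(np.uint16)     # bit pattern of the BF16 model params
    remainder = (bits & 0xFFFF).astype(np.uint16)  # extra 16 bits kept by the optimizer
    return bf16_bits, remainder

def merge_fp32(bf16_bits, remainder):
    """Reconstruct the exact FP32 params from the BF16 bits and the remainder."""
    bits = (bf16_bits.astype(np.uint32) << 16) | remainder.astype(np.uint32)
    return bits.view(np.float32)

p = np.random.randn(4).astype(np.float32)
assert np.array_equal(merge_fp32(*split_fp32(p)), p)
```

With this layout the optimizer keeps 16 bits per param instead of a full 32-bit FP32 copy of the weights.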

Usage

Set the optimizer to distributed_fused_adam in the model config file:
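
A minimal sketch, assuming the standard NeMo Megatron YAML layout; the name field is what selects the distributed optimizer, and the remaining fields are the usual (illustrative) optimizer settings:

```yaml
model:
  optim:
    name: distributed_fused_adam  # selects the Apex distributed Adam optimizer
    lr: 2e-4                      # illustrative values; keep your existing settings
    weight_decay: 0.01
    betas: [0.9, 0.98]
```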

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

Additional Information

timmoon10 and others added 9 commits August 23, 2022 11:53
If params are bf16, dist Adam will only store 16-bit remainder needed to reconstruct fp32 params.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Requires dist Adam optimizer.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@lgtm-com

lgtm-com bot commented Sep 7, 2022

This pull request introduces 1 alert and fixes 1 when merging 065a89b into abbe643 - view on LGTM.com

new alerts:

  • 1 for Unused import

fixed alerts:

  • 1 for Unused import

@lgtm-com

lgtm-com bot commented Sep 8, 2022

This pull request introduces 1 alert and fixes 1 when merging 7943ebc into b9cf05c - view on LGTM.com

new alerts:

  • 1 for Unused import

fixed alerts:

  • 1 for Unused import

Requires dist Adam optimizer

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@lgtm-com

lgtm-com bot commented Sep 9, 2022

This pull request fixes 1 alert when merging e06d34a into b9cf05c - view on LGTM.com

fixed alerts:

  • 1 for Unused import

@timmoon10
Collaborator Author

timmoon10 commented Sep 10, 2022

Running T5 41B on 32 Selene nodes, I see a 1.2x speedup over the pure data-parallel implementation, 66% of the expected memory savings, and nearly identical loss values after 20 steps.

Full results for T5 41B and GPT-3 175B, with the run configurations detailed inside. Note that I ran with a relatively small global batch size, which makes communication a more significant portion of the runtime.

@timmoon10 timmoon10 changed the title from "Enable overlapped grad reductions with pipeline or sequence parallelism" to "Support distributed Adam with T5 and support overlapped grad reductions with pipeline parallelism" Sep 15, 2022
…tion

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 marked this pull request as ready for review September 15, 2022 20:52
@lgtm-com

lgtm-com bot commented Sep 15, 2022

This pull request fixes 1 alert when merging d528a89 into f1825bc - view on LGTM.com

fixed alerts:

  • 1 for Unused import

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@lgtm-com

lgtm-com bot commented Sep 20, 2022

This pull request fixes 1 alert when merging 811b59c into f1825bc - view on LGTM.com

fixed alerts:

  • 1 for Unused import

@lgtm-com

lgtm-com bot commented Sep 27, 2022

This pull request fixes 1 alert when merging ebd98c4 into 971485c - view on LGTM.com

fixed alerts:

  • 1 for Unused import

@lgtm-com

lgtm-com bot commented Sep 27, 2022

This pull request fixes 1 alert when merging b2a61ad into 73fcfd7 - view on LGTM.com

fixed alerts:

  • 1 for Unused import

Changes were made to the GPT model to support interleaved pipeline parallelism, which distributed Adam does not currently support.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 force-pushed the dist-adam-pipeline-parallel-async-grad-reduction branch from 2638974 to 4ef0255 Compare October 19, 2022 23:09
@lgtm-com

lgtm-com bot commented Oct 19, 2022

This pull request fixes 1 alert when merging a088304 into 8656574 - view on LGTM.com

fixed alerts:

  • 1 for Unused import

@timmoon10 timmoon10 marked this pull request as draft October 20, 2022 02:45
@timmoon10 timmoon10 marked this pull request as ready for review October 20, 2022 04:41
@timmoon10
Collaborator Author

timmoon10 commented Oct 20, 2022

Running on a DGX A100 node for 50 steps with 2-way data, tensor, and pipeline parallelism, I see nearly identical learning behavior with and without the distributed optimizer:

| Model | ZeRO | O2 | Data type | Throughput | Train loss | Val loss |
|---|---|---|---|---|---|---|
| GPT-2 124M | Yes | No | FP16 | 2.69it/s | 8.4 | 7.870 |
| GPT-2 124M | No | No | FP16 | 3.31it/s | 8.4 | 7.870 |
| GPT-2 124M | Yes | No | BF16 | 3.19it/s | 8.28 | 7.790 |
| GPT-2 124M | No | No | BF16 | 3.39it/s | 8.28 | 7.790 |
| GPT-2 124M | Yes | Yes | BF16 | 3.44it/s | 8.3 | 7.800 |
| GPT-2 124M | No | Yes | BF16 | 3.60it/s | 8.3 | 7.800 |
| T5 220M | Yes | No | FP32 | 1.46it/s | 7.64 | 7.530 |
| T5 220M | No | No | FP32 | 1.69it/s | 7.64 | 7.530 |
| T5 220M | Yes | No | FP16 | 1.43it/s | 8.45 | 8.290 |
| T5 220M | No | No | FP16 | 1.43it/s | 8.45 | 8.290 |
| T5 220M | Yes | No | BF16 | 1.50it/s | 7.66 | 7.560 |
| T5 220M | No | No | BF16 | 1.45it/s | 7.65 | 7.550 |
| T5 220M | Yes | Yes | BF16 | 1.58it/s | 7.65 | 7.540 |
| T5 220M | No | Yes | BF16 | 1.61it/s | 7.65 | 7.540 |

I get runtime failures when I run GPT-2 with FP32 and with pipeline parallelism enabled. This error shows up in the main branch as well.

@lgtm-com

lgtm-com bot commented Oct 20, 2022

This pull request fixes 1 alert when merging aed0e00 into 85fc659 - view on LGTM.com

fixed alerts:

  • 1 for Unused import

ericharper previously approved these changes Oct 20, 2022
Collaborator

@ericharper ericharper left a comment

LGTM. Thanks!

…enabled

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Collaborator

@ericharper ericharper left a comment

Re-approving

@lgtm-com

lgtm-com bot commented Oct 20, 2022

This pull request fixes 1 alert when merging 190f992 into 0336000 - view on LGTM.com

fixed alerts:

  • 1 for Unused import

@timmoon10
Collaborator Author

With NVIDIA/apex#1514 the distributed optimizer supports interleaved pipeline parallelism. Running GPT-2 124M for 20 steps, I get the same loss values with and without the distributed optimizer.

@ericharper ericharper merged commit 456410d into NVIDIA:main Oct 21, 2022
1-800-BAD-CODE pushed a commit to 1-800-BAD-CODE/NeMo that referenced this pull request Nov 13, 2022
…ns with pipeline parallelism (NVIDIA#4900)

* Avoid storing extra copy of params in dist Adam optimizer

If params are bf16, dist Adam will only store 16-bit remainder needed to reconstruct fp32 params.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for dist Adam in GPT-3 without O2-level AMP

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for dist Adam in Megatron-LM models

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug dist Adam support without Megatron AMP O2

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for overlapped grad sync with pipeline parallelism in GPT-3

Requires dist Adam optimizer.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug dist Adam support for T5

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for overlapped grad sync with pipeline parallelism in T5

Requires dist Adam optimizer

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update Apex commits in Dockerfile and Jenkinsfile

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support distributed Adam in Megatron grad scaler class.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update dist Adam to accommodate changes in GPT model

Changes were made to the GPT model to support interleaved pipeline parallelism, which distributed Adam does not currently support.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor tweaks to dist Adam integration

Review suggestions from @ericharper and @crcrpar.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove error when dist Adam and interleaved pipeline parallelism are enabled

Signed-off-by: Tim Moon <tmoon@nvidia.com>

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
Signed-off-by: 1-800-bad-code <shane.carroll@utsa.edu>
hainan-xv pushed a commit to hainan-xv/NeMo that referenced this pull request Nov 29, 2022
…ns with pipeline parallelism (NVIDIA#4900)
hainan-xv pushed a commit to hainan-xv/NeMo that referenced this pull request Nov 29, 2022
…ns with pipeline parallelism (NVIDIA#4900)