
Support overlapped grad sync with Megatron pipeline parallelism #1475

Merged: 4 commits into NVIDIA:master on Sep 20, 2022

Conversation

timmoon10 (Contributor) commented:

This PR adds functionality so that the distributed Adam optimizer can overlap grad reduce-scatters with backward compute in the first pipeline stage. Async grad reductions are disabled in the other pipeline stages, since overlapping there slows down the backward pass; instead, those reductions should be launched externally while waiting for the first stage to finish. Note that this is not compatible with DistributedDataParallel, since I am not aware of an option to manually trigger a reduction after using no_sync.
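To make the scheduling concrete, here is a minimal sketch of how a pipeline schedule could drive such an optimizer. The `no_sync()` context manager, the explicit `grad_sync()` call, and the stage/microbatch flags are illustrative assumptions for this sketch, not the exact apex or NeMo API:

```python
# Minimal sketch of the scheduling idea (hypothetical names, not apex internals).
# Assumptions: `optimizer` is a distributed optimizer that exposes a DDP-style
# no_sync() context manager *and* an explicit grad_sync() call to launch the
# deferred reductions; is_first_stage / is_last_microbatch come from the
# pipeline schedule.
import contextlib

def backward_step(loss, optimizer, is_first_stage, is_last_microbatch):
    # Overlap grad reductions with backward compute only in the first pipeline
    # stage on the last microbatch: that stage finishes its backward pass last,
    # so its reduce-scatters can hide behind compute without slowing other
    # stages down.
    overlap = is_first_stage and is_last_microbatch
    sync_ctx = contextlib.nullcontext() if overlap else optimizer.no_sync()
    with sync_ctx:
        loss.backward()
    if not overlap:
        # Other stages keep grads local during backward and launch the
        # reductions explicitly while waiting for the first stage to finish.
        optimizer.grad_sync()
```

The explicit `grad_sync()`-style call is the piece the PR description says is missing from DistributedDataParallel: after `no_sync()`, DDP offers no way to trigger the reduction manually, which is why this scheme has to live in the distributed optimizer itself.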

I've also done some refactoring of distributed Adam to support NeMo-Megatron integration, mostly so that I can selectively disable async grad reductions for model-parallel operations:

  • Each bucket independently keeps track of which grads it has received
  • Exposed a function that creates the callback functions for launching async grad reductions, which I override in NeMo (see the sketch after this list)
  • Perform collective communication for checkpointing in the main stream in order to reduce memory pool overheads
  • Change the default parameter value for the grad norm functions from [] to None, which allows us to pass in empty iterators
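As a rough illustration of the first two bullets, here is a hedged sketch of what the per-bucket bookkeeping plus a callback factory could look like. The class and function names are made up for this example, and the async all-reduce stands in for the optimizer's actual reduce-scatter:

```python
# Hedged sketch of the per-bucket bookkeeping described above; GradBucket and
# make_grad_callback are illustrative names, not the actual apex internals.
import torch
import torch.distributed as dist

class GradBucket:
    """Tracks which of its parameters have produced a gradient this step."""

    def __init__(self, params):
        self.params = list(params)
        self.received = set()

    def is_full(self):
        return len(self.received) == len(self.params)

def make_grad_callback(bucket, process_group=None, sync_enabled=lambda: True):
    """Build a per-parameter callback: record the grad in its bucket and launch
    an async reduction once the bucket is complete. An integration (e.g. NeMo)
    can override this factory, or supply a sync_enabled predicate, to skip the
    reduction for model-parallel params or for non-first pipeline stages."""

    def callback(param):
        bucket.received.add(param)
        if bucket.is_full() and sync_enabled():
            # Flatten the bucket and reduce it asynchronously so the
            # communication overlaps with the remaining backward compute.
            flat = torch.cat([p.grad.flatten() for p in bucket.params])
            dist.all_reduce(flat, group=process_group, async_op=True)

    return callback
```

Wiring this up amounts to registering the callback on each parameter's gradient-accumulation hook (e.g. `Tensor.register_post_accumulate_grad_hook` on recent PyTorch versions), and an integration such as NeMo can override the factory to return a no-op for model-parallel parameters or for stages where async reduction is disabled.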

Commits:
  • Refactor how dist Adam handles overlapped grad sync: Each grad bucket independently keeps track of grads that have been generated. Add helper function to create callback functions. Change default param arg in grad norm functions to None. Perform communication for checkpointing in main stream to avoid memory pool overheads.
  • Support Megatron pipeline parallelism with async grad reduction: Enables async grad reduction in the first pipeline stage during the last backward pass, and disables async grad reduction in all other pipeline stages.
@crcrpar (Collaborator) left a comment:

knowing this would be too early...

Commit (review suggestions from crcrpar): Add unit test for pipeline parallelism with custom sync context. Style tweaks.
@timmoon10 marked this pull request as ready for review on September 15, 2022 20:25
@crcrpar (Collaborator) left a comment:

SGTM

Review comment on tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py (outdated, resolved)
@crcrpar self-requested a review on September 20, 2022 04:28
@crcrpar (Collaborator) left a comment:

thank you :)

@crcrpar merged commit 2b0e837 into NVIDIA:master on Sep 20, 2022
@crcrpar added this to the 22.11 milestone on Oct 25, 2022
hubertlu-tw pushed a commit to ROCm/apex that referenced this pull request Dec 29, 2022
Support overlapped grad sync with Megatron pipeline parallelism (NVIDIA#1475)

* Refactor how dist Adam handles overlapped grad sync

Each grad bucket independently keeps track of grads that have been generated. Add helper function to create callback functions. Change default param arg in grad norm functions to None. Perform communication for checkpointing in main stream to avoid memory pool overheads.

* Support Megatron pipeline parallelism with async grad reduction

Enables async grad reduction in first pipeline stage during last backward pass, and disables async grad reduction in all other pipeline stages.

* Review suggestions from crcrpar

Add unit test for pipeline parallelism with custom sync context. Style tweaks.

* Use unittest assert functions in pipeline parallelism test

Review suggestion from crcrpar
yuanzhedong pushed a commit to yuanzhedong/apex that referenced this pull request Jul 14, 2023
Support overlapped grad sync with Megatron pipeline parallelism (NVIDIA#1475)