Support overlapped grad sync with Megatron pipeline parallelism #1475
Conversation
Each grad bucket independently keeps track of grads that have been generated. Add helper function to create callback functions. Change default param arg in grad norm functions to None. Perform communication for checkpointing in main stream to avoid memory pool overheads.
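The per-bucket tracking described above can be sketched in plain Python. This is an illustrative mock, not apex's actual implementation: the names `GradBucket` and `make_grad_hook` are hypothetical, and `_launch_reduction` stands in for the async reduce-scatter the real optimizer would enqueue.

```python
class GradBucket:
    """Hypothetical sketch: a bucket that tracks which of its grads have
    been generated and fires its reduction independently once full."""

    def __init__(self, param_names):
        self.param_names = set(param_names)
        self.generated = set()
        self.reduced = False

    def mark_generated(self, name):
        # Record that this param's grad was produced during backward.
        self.generated.add(name)
        if self.generated == self.param_names and not self.reduced:
            self.reduced = True
            self._launch_reduction()

    def _launch_reduction(self):
        # Placeholder: the real optimizer would launch an async
        # reduce-scatter on a side stream here.
        pass


def make_grad_hook(bucket, name):
    """Helper that creates a callback bound to one parameter's bucket,
    mirroring the 'helper function to create callback functions' above."""
    def hook(*_):
        bucket.mark_generated(name)
    return hook
```

Each parameter gets its own hook, so buckets fill and reduce independently of one another rather than waiting on a single global counter.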
Enables async grad reduction in first pipeline stage during last backward pass, and disables async grad reduction in all other pipeline stages.
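The stage-dependent toggle could look roughly like the context manager below. This is a sketch under assumptions: the flag name `overlap_grad_sync` and the function name `grad_sync_context` are illustrative, not apex's API.

```python
from contextlib import contextmanager

@contextmanager
def grad_sync_context(optimizer, is_first_stage, is_last_backward):
    """Hypothetical sketch: enable async grad reduction only in the first
    pipeline stage during the last backward pass; disable it everywhere
    else, restoring the previous setting on exit."""
    prev = optimizer.overlap_grad_sync
    optimizer.overlap_grad_sync = is_first_stage and is_last_backward
    try:
        yield
    finally:
        optimizer.overlap_grad_sync = prev
```

Non-first stages then perform their reductions explicitly after backward, while the first stage overlaps them with compute.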
knowing this would be too early...
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py
Add unit test for pipeline parallelism with custom sync context. Style tweaks.
SGTM
Review suggestion from @crcrpar
thank you :)
…atron pipeline parallelism (NVIDIA#1475)

* Refactor how dist Adam handles overlapped grad sync

  Each grad bucket independently keeps track of grads that have been generated. Add helper function to create callback functions. Change default param arg in grad norm functions to None. Perform communication for checkpointing in main stream to avoid memory pool overheads.

* Support Megatron pipeline parallelism with async grad reduction

  Enables async grad reduction in first pipeline stage during last backward pass, and disables async grad reduction in all other pipeline stages.

* Review suggestions from crcrpar

  Add unit test for pipeline parallelism with custom sync context. Style tweaks.

* Use unittest assert functions in pipeline parallelism test

  Review suggestion from crcrpar
This PR adds functionality so that the distributed Adam optimizer can overlap grad reduce-scatters with backward compute in the first pipeline stage. Async grad reduction is disabled in the other pipeline stages since it slows down the backward pass; their reductions should instead be performed externally while waiting for the first stage to finish. Note that this is not compatible with `DistributedDataParallel`, since I am not aware of an option to manually trigger a reduction after using `no_sync`. I've also done some refactoring of distributed Adam to support NeMo-Megatron integration, mostly so that I can selectively disable async grad reductions for model-parallel operations. The default param arg in the grad norm functions has changed from `[]` to `None`. This allows us to pass in empty iterators
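The `no_sync` limitation mentioned above can be made concrete with a small sketch. This is a hypothetical mock, not apex's or PyTorch's actual class: unlike DDP's `no_sync`, it records deferred reductions and exposes a manual trigger. The names `OverlapGradSync`, `on_grad_ready`, and `trigger_grad_sync` are illustrative.

```python
from contextlib import contextmanager

class OverlapGradSync:
    """Hypothetical sketch of a grad-sync switch that, unlike DDP's
    no_sync, lets you trigger the deferred reductions manually later."""

    def __init__(self):
        self._async_enabled = True
        self._pending = []

    @contextmanager
    def no_sync(self):
        # Temporarily defer reductions instead of launching them.
        self._async_enabled = False
        try:
            yield
        finally:
            self._async_enabled = True

    def on_grad_ready(self, name):
        if self._async_enabled:
            return f"reduced:{name}"   # reduce immediately (overlapped)
        self._pending.append(name)     # defer until manually triggered
        return None

    def trigger_grad_sync(self):
        """Manually reduce everything deferred inside no_sync()."""
        done, self._pending = [f"reduced:{n}" for n in self._pending], []
        return done
```

With DDP, grads skipped under `no_sync` are only reduced by the *next* backward pass outside the context; a manual `trigger_grad_sync`-style hook is what the pipeline schedule needs to flush non-first stages on its own timetable.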