[Schedule] Support sequence parallelism #6
Merged
comaniac merged 14 commits into awslabs:main on Jan 23, 2023
Conversation
chhzh123 reviewed Jan 19, 2023
szhengac reviewed Jan 20, 2023
szhengac approved these changes Jan 21, 2023
comaniac commented Jan 21, 2023
szhengac reviewed Jan 21, 2023
comaniac (Contributor, Author):
Thanks @szhengac
Description

Note that `dist.reduce_scatter` always scatters along the first dimension, so we implicitly transpose the input and output tensors when needed.

UPDATE: Per offline discussion, the programming model is changed:

- Other 1: For readability and flexibility, `.sync` now requires users to always specify the op.
- Other 2: The `.hook` primitive is integrated into `.sync`, which now allows a custom hook function. Accordingly, a unit test was added to verify correctness in both forward and backward passes.
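To make the revised model concrete, here is a purely illustrative toy, not the actual schedule API: `.sync` takes a mandatory op and an optional custom hook that runs as part of the sync. All class and parameter names below are hypothetical.

```python
# Toy sketch of the changed programming model: sync always takes an explicit
# op, and accepts an optional hook instead of a separate .hook primitive.
# ToySchedule and its members are illustrative names, not the PR's API.

class ToySchedule:
    def __init__(self):
        self.ops = []

    def sync(self, op, hook=None):
        # The op must always be specified explicitly for readability.
        if op not in ("all_gather", "reduce_scatter", "all_reduce"):
            raise ValueError(f"unknown sync op: {op}")
        self.ops.append(op)
        if hook is not None:
            hook(op)  # the custom hook runs as part of sync

# Usage: register a hook that records which op was synchronized.
log = []
sched = ToySchedule()
sched.sync("all_gather", hook=log.append)
assert log == ["all_gather"]
```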
- Use `dist.all_gather_into_tensor` when available. This API was previously `dist._all_gather_base`, which will be deprecated soon. Note that `dist.all_gather_into_tensor` always concatenates along the first dimension, so we implicitly transpose the input and output tensors when needed. Even with the transpose overhead, it still seems faster than `dist.all_gather` + `torch.cat` based on my micro-benchmarks.
- Move the sharding logic to another place and establish a registration mechanism for shardable modules.
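The transpose trick used for both first-dimension collectives can be illustrated in a single process without `torch.distributed`. The sketch below mimics gathering 2-D shards along dim 1: transpose each shard, concatenate along dim 0 (what the collective does), then transpose back. Plain nested lists stand in for tensors; none of this is the PR's actual code.

```python
# Single-process illustration of the transpose trick described above.
# dist.all_gather_into_tensor and dist.reduce_scatter only operate along
# the first dimension, so to gather along dim 1 we transpose first, run
# the first-dimension concat, and transpose back.

def transpose(mat):
    """Swap the two dimensions of a 2-D list-of-lists."""
    return [list(col) for col in zip(*mat)]

def gather_along_dim1(shards):
    """Simulate all_gather_into_tensor along dim 1 of 2-D shards."""
    moved = [transpose(s) for s in shards]        # move dim 1 to the front
    gathered = [row for s in moved for row in s]  # first-dimension concat
    return transpose(gathered)                    # move it back

# Two "ranks" each hold half of a 2x4 tensor, split along dim 1.
full = [[0, 1, 2, 3], [4, 5, 6, 7]]
shard0 = [[0, 1], [4, 5]]
shard1 = [[2, 3], [6, 7]]
assert gather_along_dim1([shard0, shard1]) == full
```

The same move-to-front/move-back pattern applies to the reduce_scatter direction, which is where the transpose overhead mentioned above comes from.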
- [Misc] Refactor `task_lint.sh` so that we can run it locally to verify linting without installing transformers every time.
- [Misc] Add `conftest.py` to enforce the test order; otherwise the distributed tests may hang if the distributed devices are not running the same test at the same time.
- [Misc] Remove DeepSpeed from the docker image for now, so that we can make it public for CI.
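A hypothetical sketch of how such a `conftest.py` could enforce a deterministic order (the PR's actual file may differ): sorting the collected items by node id makes every distributed worker run the tests in the same sequence, which avoids the hang described above. `pytest_collection_modifyitems` is a standard pytest hook; the sort key is an assumption.

```python
# Hypothetical conftest.py sketch: force a deterministic test order so that
# all distributed ranks execute the same test at the same time.

def deterministic_order(nodeids):
    """Return test node ids in a fixed (lexicographic) order."""
    return sorted(nodeids)

def pytest_collection_modifyitems(session, config, items):
    # Reorder in place so every rank agrees on the execution order.
    items.sort(key=lambda item: item.nodeid)
```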
cc @szhengac @chhzh123
Checklist