
[Schedule] Support sequence parallelism#6

Merged
comaniac merged 14 commits into awslabs:main from comaniac:seq_para
Jan 23, 2023

Conversation

@comaniac
Contributor

@comaniac comaniac commented Jan 19, 2023

Description

  • Support sequence parallelism. Specifically, we can now schedule as follows:
# output would be partial
sch["attention.out_proj"].shard("weight", axis=1)
# this indicates that we sync the output in forward pass, but defer the gather
# (at output axis=1) until resid_dropout. in other words, the schedule becomes:
# linear -> reduce_scatter -> dropout -> all_gather.
sch["attention.out_proj"].sync(
    mode="forward_defer_gather",
    gather_at=(sch["attention.resid_dropout"], 1)
)

Note that dist.reduce_scatter always scatters along the first dimension, so we implicitly transpose input and output tensors when needed.
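The transpose trick can be illustrated with a minimal plain-Python sketch (not the PR's code; it simulates the collective on nested lists): since the reduce-scatter primitive only splits along dim 0, scattering along axis=1 is emulated by transposing, scattering along dim 0, and transposing each shard back.

```python
# Hypothetical sketch: emulate reduce_scatter on nested lists to show why
# scattering along axis=1 requires a transpose. Not the PR's implementation.

def transpose2d(m):
    return [list(row) for row in zip(*m)]

def elementwise_sum(tensors):
    # Reduce step: sum the per-rank partial results element-wise.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*tensors)]

def reduce_scatter(tensors, world_size, axis=0):
    if axis == 1:
        # Scatter dim 1 via: transpose -> dim-0 scatter -> transpose back.
        shards = reduce_scatter([transpose2d(t) for t in tensors], world_size)
        return [transpose2d(s) for s in shards]
    total = elementwise_sum(tensors)
    n = len(total) // world_size
    return [total[r * n:(r + 1) * n] for r in range(world_size)]

# Two ranks hold partial outputs; scattering along axis=1 gives each rank
# one column block of the summed result.
partials = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
shards = reduce_scatter(partials, world_size=2, axis=1)
# shards[0] == [[6], [10]], shards[1] == [[8], [12]]
```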


UPDATE
Per offline discussion, the programming model is changed:

sch["attention.out_proj"].sync(mode="fwd_post", sync_op_or_fn="reduce_scatter", axis=1)
sch["attention.resid_dropout"].sync(mode="fwd_post", sync_op_or_fn="all_gather", axis=1)

Other 1: For readability and flexibility, .sync now requires users to always specify the op.
Other 2: The .hook primitive is integrated into .sync, which now also accepts a custom hook function.
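The name sync_op_or_fn suggests a dispatch between a registered op name and a user-supplied hook. A minimal sketch of that dispatch is below; the names SYNC_OPS and resolve_sync_fn are assumptions for illustration, not the actual scheduler API.

```python
# Hypothetical sketch of the "op name or custom hook" dispatch implied by
# sync_op_or_fn. The registry contents here are placeholders.

SYNC_OPS = {
    "all_gather": lambda tensor, axis: ("all_gather", axis),
    "reduce_scatter": lambda tensor, axis: ("reduce_scatter", axis),
}

def resolve_sync_fn(sync_op_or_fn):
    """Return the hook for a named op, or pass a custom callable through."""
    if callable(sync_op_or_fn):
        return sync_op_or_fn
    try:
        return SYNC_OPS[sync_op_or_fn]
    except KeyError:
        raise ValueError(f"Unknown sync op: {sync_op_or_fn!r}")
```

With this shape, sync(mode="fwd_post", sync_op_or_fn="all_gather", axis=1) and sync(mode="fwd_post", sync_op_or_fn=my_hook) share one code path.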


  • Accordingly, a unit test was added to verify the correctness for both forward and backward.

  • Use dist.all_gather_into_tensor when available. This API replaces dist._all_gather_base, which will be deprecated soon. Note that dist.all_gather_into_tensor always concatenates along the first dimension, so we implicitly transpose the input and output tensors when needed. Even with the transpose overhead, it still appears faster than dist.all_gather + torch.cat in my micro-benchmarks.
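The "when available" check can be done with a simple attribute lookup, so the same code runs on older torch versions that only ship dist._all_gather_base. A hedged sketch (the helper name pick_all_gather is an assumption, not the PR's code):

```python
# Hypothetical sketch: prefer dist.all_gather_into_tensor when the installed
# torch.distributed provides it, falling back to the older private API.
# Written against a generic module-like object so the logic is self-contained.

def pick_all_gather(dist_module):
    fn = getattr(dist_module, "all_gather_into_tensor", None)
    if fn is None:
        # Older torch: only the soon-to-be-deprecated API exists.
        fn = getattr(dist_module, "_all_gather_base")
    return fn
```

In real use one would call pick_all_gather(torch.distributed) once at import time and reuse the result.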

  • Move the sharding logic to a separate location and establish a registration mechanism for shardable modules.
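A registration mechanism for shardable modules is commonly a decorator that maps a module class name to its sharding function. A minimal sketch under that assumption (SHARDABLE_REGISTRY and register_shardable are illustrative names, not the PR's actual code):

```python
# Hypothetical sketch of a shardable-module registry: each module class name
# maps to a function describing how its parameters are sharded.

SHARDABLE_REGISTRY = {}

def register_shardable(module_cls_name):
    """Decorator that registers a sharding function for a module class."""
    def decorator(shard_fn):
        SHARDABLE_REGISTRY[module_cls_name] = shard_fn
        return shard_fn
    return decorator

@register_shardable("Linear")
def shard_linear(weight_shape, axis, world_size):
    # Each rank keeps 1/world_size of the weight along the sharded axis.
    shape = list(weight_shape)
    shape[axis] //= world_size
    return tuple(shape)
```

The scheduler can then look up SHARDABLE_REGISTRY[type(mod).__name__] when .shard is called, and new module types become shardable by registering a function rather than editing the core logic.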

  • [Misc] Refactor task_lint.sh so that we can run it locally to verify linting without installing transformers every time.

  • [Misc] Add conftest.py to enforce the test order; otherwise the distributed tests may get stuck when the distributed devices are not running the same test at the same time.
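One standard way to enforce a deterministic order in conftest.py is pytest's pytest_collection_modifyitems hook, which sorts the collected items in place so every distributed worker runs the tests in the same sequence. A hedged sketch (sorting by node id is an assumption about the chosen ordering):

```python
# Hypothetical conftest.py sketch: impose a deterministic test order so all
# distributed workers reach the same test at the same time.

def pytest_collection_modifyitems(session, config, items):
    # Sort in place by test node id; pytest runs the reordered list.
    items.sort(key=lambda item: item.nodeid)
```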

  • [Misc] Remove DeepSpeed from the Docker image for now, so that we can make it public for CI.

cc @szhengac @chhzh123

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

@comaniac comaniac changed the title Seq para [Schedule] Support sequence parallelism Jan 19, 2023
@comaniac comaniac merged commit 54af148 into awslabs:main Jan 23, 2023
@comaniac
Contributor Author

Thanks @szhengac

@comaniac comaniac deleted the seq_para branch January 23, 2023 18:09