
[Primitive][shard] Use autograd function for all sync ops#33

Merged
comaniac merged 2 commits into awslabs:main from comaniac:dont_use_bwd_hook
Feb 1, 2023

Conversation


@comaniac comaniac commented Feb 1, 2023

Description

We found that the behavior of PyTorch backward hooks can be inconsistent, so for safety this PR replaces every use of a backward hook with a forward pre-hook plus a custom autograd function. Specifically:

Before

register_backward_hook(dist.all_reduce)

Now

class _ReduceBackwardGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # no-op: return the input unchanged
        return x

    @staticmethod
    def backward(ctx, grad):
        # all-reduce the gradient
        dist.all_reduce(grad)
        return grad

register_forward_pre_hook(allreduce_backward_gradient)

Accordingly, _ReduceBackwardGradient is added. In addition, this PR also adds unit tests for all supported sync ops.
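For context, the pattern above can be sketched end-to-end in a single process. This is an illustrative sketch only, not the PR's actual code: the names `_AllReduceBackward` and `_allreduce_backward_gradient` are hypothetical, and the all-reduce is guarded so the example runs without an initialized process group.

```python
import torch


class _AllReduceBackward(torch.autograd.Function):
    """Identity in forward; all-reduces the gradient in backward."""

    @staticmethod
    def forward(ctx, x):
        # No-op: pass the activation through unchanged.
        return x

    @staticmethod
    def backward(ctx, grad):
        # In a real distributed run this all-reduces the gradient
        # across ranks; here it is skipped when no group is initialized.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(grad)
        return grad


def _allreduce_backward_gradient(module, inputs):
    # Forward pre-hook: wrap each input so the autograd function's
    # backward runs when gradients flow back into this module.
    return tuple(_AllReduceBackward.apply(x) for x in inputs)


linear = torch.nn.Linear(4, 4)
linear.register_forward_pre_hook(_allreduce_backward_gradient)

x = torch.randn(2, 4, requires_grad=True)
linear(x).sum().backward()
print(x.grad.shape)
```

Because the wrapper is an identity in forward, the module's outputs are unchanged; the gradient synchronization happens only in backward, at a well-defined point in the autograd graph rather than via a backward hook.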

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

cc @szhengac @chhzh123

@comaniac comaniac changed the title [Primitive][shard] Use auto-grad fn for all sync ops [Primitive][shard] Use autograd function for all sync ops Feb 1, 2023
@comaniac comaniac merged commit 1e15ee2 into awslabs:main Feb 1, 2023

comaniac commented Feb 1, 2023

Thanks @szhengac

@comaniac comaniac deleted the dont_use_bwd_hook branch February 1, 2023 01:08