
RuntimeError: function AllGatherGradBackward returned an incorrect number of gradients (expected 2, got 1) #6624

Closed
ArvinZhuang opened this issue Mar 22, 2021 · 0 comments · Fixed by #6625
Labels
bug Something isn't working help wanted Open to be worked on priority: 1 Medium priority task


@ArvinZhuang
Contributor

🐛 Bug

When using self.all_gather in training_step to gather a tensor that carries a gradient function, then computing and returning a loss from it, the following is thrown:
RuntimeError: function AllGatherGradBackward returned an incorrect number of gradients (expected 2, got 1)

I think the bug is that forward(ctx, tensor, group=group.WORLD) in the distributed.AllGatherGrad function takes two inputs (tensor and group), but backward(ctx, *grad_output) returns only one gradient. Autograd requires backward to return one gradient per forward input, using None for inputs that do not need a gradient.

class AllGatherGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group=group.WORLD):
        ctx.group = group

        gathered_tensor = [
            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
        ]

        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
        gathered_tensor = torch.stack(gathered_tensor, dim=0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, *grad_output):
        grad_output = torch.cat(grad_output)

        torch.distributed.all_reduce(
            grad_output,
            op=torch.distributed.ReduceOp.SUM,
            async_op=False,
            group=ctx.group
        )

        # Bug: forward has two inputs (tensor, group), but only one
        # gradient is returned here.
        return grad_output[torch.distributed.get_rank()]

The error can be fixed by changing return grad_output[torch.distributed.get_rank()] to return grad_output[torch.distributed.get_rank()], None; the added None stands in as the gradient for the non-tensor group argument.
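The rule autograd enforces here can be illustrated without any distributed setup: backward must return exactly as many values as forward has inputs (excluding ctx). The sketch below mimics that arity check in plain Python; check_backward_arity and the placeholder names are hypothetical, not PyTorch internals.

```python
import inspect

def check_backward_arity(forward_fn, backward_outputs):
    # Count forward's real inputs, excluding the ctx argument.
    n_inputs = len(inspect.signature(forward_fn).parameters) - 1
    if len(backward_outputs) != n_inputs:
        raise RuntimeError(
            f"function returned an incorrect number of gradients "
            f"(expected {n_inputs}, got {len(backward_outputs)})"
        )

def forward(ctx, tensor, group=None):
    # Two real inputs: `tensor` and `group`.
    pass

# Buggy backward: one value returned for two forward inputs -> error.
try:
    check_backward_arity(forward, ("grad_tensor",))
except RuntimeError as e:
    print(e)

# Fixed backward: a gradient for `tensor`, None for the `group` arg.
check_backward_arity(forward, ("grad_tensor", None))
```

This is why appending , None resolves the error: group is a forward input, so backward must account for it even though it has no gradient.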

Environment

  • PyTorch Version: 1.7.1
  • OS: Linux
  • How you installed PyTorch: pip
  • Python version: 3.7
  • PyTorch Lightning version: 1.1.8
@ArvinZhuang ArvinZhuang added bug Something isn't working help wanted Open to be worked on labels Mar 22, 2021
@Borda Borda added the priority: 1 Medium priority task label Mar 23, 2021