
Support uneven DDP inputs with pytorch model.join #3325

Open
edenlightning opened this issue Sep 2, 2020 · 23 comments
Labels: 3rd party (Related to a 3rd-party) · distributed (Generic distributed-related topic) · feature (Is an improvement or enhancement) · help wanted (Open to be worked on)

Comments

@edenlightning
Contributor

edenlightning commented Sep 2, 2020

See more details: pytorch/pytorch#38174

cc @Borda @tchaton @rohitgr7 @akihironitta @awaelchli

@edenlightning added the feature, help wanted, and distributed labels Sep 2, 2020
@stale

stale bot commented Oct 21, 2020

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label Oct 21, 2020
@carmocca
Member

Interested in this issue! Hopefully some progress is made soon 👍

stale bot removed the won't fix label Oct 21, 2020
@xvr-hlt

xvr-hlt commented Nov 19, 2020

Interested in this also :)

@rohan-varma

Is there any progress on this issue? Happy to help in any way.

@edenlightning
Contributor Author

@rohan-varma that would be great!! Want to try and submit a draft PR? We can help from there.

@rohan-varma

@edenlightning Sounds good, I also pinged the slack channel for any feedback/discussions.

@alanhdu
Contributor

alanhdu commented Dec 11, 2020

We'd also be very interested in this feature. Let us know if there's anything we can do to help!

@rohan-varma

The PR #5141 is ready for review, in case anyone wants to take a look.

stale bot added the won't fix label Jan 14, 2021
Lightning-AI deleted a comment from stale bot Jan 14, 2021
stale bot removed the won't fix label Jan 14, 2021
@stale

stale bot commented Feb 13, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label Feb 13, 2021
stale bot closed this as completed Feb 20, 2021
@edenlightning removed the won't fix label Feb 22, 2021
@edenlightning reopened this Feb 22, 2021
@edenlightning added this to the 1.3 milestone Feb 22, 2021
@ananthsub
Contributor

ananthsub commented Mar 24, 2021

I discussed this more with @rohan-varma - DDP join docs: https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.join

This module currently does not support custom distributed collective operations in the forward pass, such as SyncBatchNorm or other custom defined collectives in the model’s forward pass.

As the LightningModule is wrapped in another module which is then wrapped with DDP, the LightningModule's training_step becomes the forward pass run by the DDP-wrapped module: https://github.com/PyTorchLightning/pytorch-lightning/blob/d471fa30b3bf95cfe601014bac544754067241ca/pytorch_lightning/plugins/training_type/ddp.py#L223-L227

As a result, any collective call (such as metric syncing or all_gather) that happens during the training step would break join. Therefore I lean towards closing this out given the caveats. @awaelchli @justusschock what do you think?
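To make the caveat concrete, here is a minimal sketch (illustrative names only, not code from this repo) of a training_step that issues its own collective and therefore would not play well with join on uneven inputs:

```python
# Illustrative only: a LightningModule whose training_step issues its own
# collective. Under join, a rank that has exhausted its batches stops
# participating, so the all_gather below would hang or mismatch on the
# remaining ranks.
import torch
import pytorch_lightning as pl


class CollectiveInStep(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # user-issued collective inside what DDP treats as the forward pass:
        gathered = self.all_gather(loss)  # breaks with uneven inputs under join
        self.log("train_loss", gathered.mean())
        return loss
```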

@awaelchli
Member

Agree, I also don't see how this can be supported at the moment.

@edenlightning added the 3rd party label Apr 15, 2021
@tmbdev

tmbdev commented Dec 25, 2021

@ananthsub commented 6 hours ago

However, the manner in which join tracks collectives can quickly run into issues with other collectives that run in the forward pass / training_step.

In PyTorch, the "with Join" construct is used as a simple wrapper around training steps. It should work in simple cases, even if there are more complex cases where it doesn't.

So, why not simply add an option to the trainer that wraps the invocations of training_step in a with Join block? That should be pretty straightforward, and it would leave it up to users to determine when with Join is the right thing to use and when it isn't.
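For reference, this is roughly what the raw construct looks like in a plain PyTorch loop, a sketch following the generic join tutorial (it assumes the usual rendezvous env vars, e.g. via torchrun or mp.spawn):

```python
# Sketch of the plain-PyTorch usage that such a Trainer option would wrap.
import torch
import torch.distributed as dist
from torch.distributed.algorithms import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def train(rank: int, world_size: int):
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = DDP(torch.nn.Linear(8, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # Deliberately uneven: rank r gets (r + 1) batches.
    batches = [torch.randn(4, 8) for _ in range(rank + 1)]

    with Join([model]):
        for batch in batches:
            optimizer.zero_grad()
            model(batch).sum().backward()
            optimizer.step()
    # Ranks that finish early have their DDP collectives shadowed by Join
    # until every rank has left the context.
```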

@ananthsub reopened this Dec 25, 2021
@awaelchli
Member

awaelchli commented Dec 28, 2021

So, why not simply add an option to the trainer that enables wrapping the invocations of training_step with with Join?

The join here is specific to PyTorch DDP. If it were implemented, it would have to live inside the DDP plugin/strategy. For simple cases it may work, but no collective calls are allowed except the ones under DDP.forward()/DDP.backward(), if I understand correctly.

If we did want to do it "correctly", we would probably have to set throw_on_early_termination=True and then handle the error in all custom collective calls, including the ones in torchmetrics. I don't know whether that would work, but it's probably not feasible.
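A rough sketch of that variant in plain PyTorch (ddp_model, dataloader and optimizer are placeholders from the usual DDP setup); the except branch is the part that is hard to guarantee for arbitrary user code and torchmetrics:

```python
# Sketch only: with throw_on_early_termination=True, every rank raises a
# RuntimeError as soon as any rank runs out of input, so all ranks can stop
# the epoch at the same iteration instead of shadowing collectives.
from torch.distributed.algorithms import Join

try:
    with Join([ddp_model], throw_on_early_termination=True):
        for batch in dataloader:
            optimizer.zero_grad()
            ddp_model(batch).sum().backward()
            optimizer.step()
except RuntimeError:
    # Every custom collective call (self.log(sync_dist=True), torchmetrics
    # syncs, user all_gathers) would also need to survive this early exit.
    pass
```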

@kaushikb11 added the priority: 0 label Jan 24, 2022
@kaushikb11 self-assigned this Jan 24, 2022
@carmocca
Member

carmocca commented Feb 3, 2022

To recap, the plan would be:

  • Enable "join" as an optional feature of the DDP strategy: Trainer(strategy=DDPStrategy(uneven_input_support: bool)). We could also add a registry string for it.
  • Add support for "joining" the training_step.
    • Is there a benefit to doing it for validation_step and test_step? Probably not
    • Could validation_step and test_step use UnrepeatedDistributedSampler just as trainer.predict? Probably yes.
  • When the feature is enabled, we don't automatically use the generic DistributedSampler as we wouldn't want to duplicate data to make inputs even.
  • We print a big warning about how this feature is experimental and describe all its caveats.
  • This would be PyTorch 1.10+ only.

Some sources:
https://pytorch.org/docs/stable/distributed.algorithms.join.html#torch.distributed.algorithms.Join
https://pytorch.org/tutorials/advanced/generic_join.html
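A sketch of what the user-facing API from the first bullet could look like; neither the uneven_input_support flag nor the "ddp_join" registry string exists yet, both are placeholders for the proposal:

```python
# Hypothetical API sketch for the proposal above; not existing Lightning code.
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DDPStrategy(uneven_input_support=True),  # proposed flag, experimental
)

# or, via a hypothetical strategy registry string:
trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp_join")
```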

@carmocca added this to the 1.6 milestone Feb 3, 2022
@awaelchli
Member

awaelchli commented Feb 5, 2022

Is there a benefit to doing it for validation_step and test_step? Probably not

I assume there is, if collectives are being used. For example, sync_dist=True in self.log or similar. However, we don't wrap the model in DDP during val and test, so join won't be available anyway.

When the feature is enabled, we don't automatically use the generic DistributedSampler as we wouldn't want to duplicate data to make inputs even.

pytorch/pytorch#49180 is great! Hopefully this will clarify the drop_last argument, which has a slightly misleading/incomplete description :) We would indeed need the UnrepeatedDistributedSampler.

@otaj
Contributor

otaj commented Aug 17, 2022

Hi everyone, I'm gathering information on what is needed in order to support this properly.

  1. Use torch.distributed.algorithms.Join (https://pytorch.org/docs/stable/distributed.algorithms.join.html) as a context manager in which the model is run.
  2. Use UnrepeatedDistributedSamplerWrapper
  3. Check all modules that could issue syncs (such as SyncBatchNorm) and pass them as arguments to the Join context manager from 1.
  4. Figure out what to do with calls to self.log(..., sync_dist=True)

Is that it? cc @awaelchli, @carmocca.
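For item 2, the essential difference from the stock DistributedSampler is that no indices are repeated to pad every rank to equal length, so shards may be uneven. A minimal standalone sketch of the idea (not the actual Lightning class):

```python
# Minimal sketch of "unrepeated" distributed sampling: indices are split
# round-robin across ranks with no padding, so shard lengths may differ by one.
from torch.utils.data import Sampler


class UnrepeatedSamplerSketch(Sampler):
    def __init__(self, dataset_len: int, num_replicas: int, rank: int):
        self.indices = list(range(rank, dataset_len, num_replicas))

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


# With 10 samples on 4 ranks: ranks 0/1 see 3 samples, ranks 2/3 see 2,
# which is exactly the uneven-input situation Join has to handle.
```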

@justusschock
Member

@otaj Almost.

Additionally, all metrics from torchmetrics would have to be considered as well, since they are also capable of issuing syncs on their own. And in general, the user can run arbitrary syncing calls within each of the steps, which have to be considered too (that will be the trickiest part, I guess).

@otaj
Contributor

otaj commented Aug 18, 2022

oh, those torchmetrics are going to be fun... 😅 I think capturing user calls can be solved with yet another context manager (our own custom one), what do you think?
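As a very rough, entirely hypothetical illustration of that idea (none of these names exist in Lightning): a process-local flag that Lightning-controlled sync points could consult before issuing a collective, e.g. to defer the reduction rather than issue a call that already-joined ranks will never match.

```python
# Hypothetical sketch of the "yet another context manager" idea; not real code.
from contextlib import contextmanager

_UNDER_JOIN = False


@contextmanager
def capture_user_syncs():
    global _UNDER_JOIN
    _UNDER_JOIN = True
    try:
        yield
    finally:
        _UNDER_JOIN = False


def maybe_sync(tensor, sync_fn):
    # Hypothetical hook that self.log(sync_dist=True) / metric syncing could
    # route through instead of calling the collective directly.
    if _UNDER_JOIN:
        return tensor  # skip/defer the collective while under join
    return sync_fn(tensor)
```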

@justusschock
Member

If we can capture user calls with that, it might work similarly with torchmetrics. So let's ignore those metrics for now, and if you get a working solution for everything else, I'm sure we'll manage to integrate metrics with that :D

@Borda
Member

Borda commented Sep 19, 2022

let's check the option with LightningLite first 🦦

@awaelchli
Member

Here is the corresponding issue as suggested in planning: #14635
