Support for sharded optimizers when dumping checkpoints outside of the DDP sharded training type plugin #6387
🐛 Bug

Using fairscale distributed optimizers without the DDP sharded training type plugin leads to crashes or inconsistent state in the trainer when checkpointing. This will also occur with PyTorch's latest prototype version of the ZeroRedundancyOptimizer. @SeanNaren

Imagine this scenario:

- When saving a checkpoint, the trainer calls `optimizer.state_dict()` on all ranks.
- For fairscale/pytorch distributed optimizers, we need to consolidate the state dict on a rank first. Afterwards, we should look up the state dict only from that rank.

One could add a callback which implements `on_save_checkpoint` to call `consolidate_state_dict()` on the optimizer across all ranks. However, the trainer calls `state_dict` on all ranks, leading to the exception here: https://github.com/facebookresearch/fairscale/blob/1204c7cf54ec301d46a0d3f3fd703da6b306f8f5/fairscale/optim/oss.py#L354-L358. This error occurs only on non-zero ranks. As a result, the error in checkpointing is compounded by the exception handling logic for training and its interaction with checkpointing discussed here: #6343 (comment)

Proposal to fix: call `consolidate_state_dict` on all ranks, and then fetch the optimizer state only from rank 0 (see the sketch below).

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @tchaton @akihironitta @blefaudeux
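A minimal sketch of the proposed flow, assuming fairscale's `OSS` optimizer (whose `consolidate_state_dict(recipient_rank=...)` gathers the sharded state onto one rank); the helper name here is illustrative, not Lightning's actual implementation:

```python
import torch.distributed as dist
from fairscale.optim import OSS

def checkpointable_optimizer_state(optimizer):
    """Illustrative helper: return the full optimizer state on rank 0, None elsewhere."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    if isinstance(optimizer, OSS):
        # Must run on ALL ranks: it performs collective ops to gather
        # each rank's shard onto the recipient rank.
        optimizer.consolidate_state_dict(recipient_rank=0)
        # Only the recipient rank holds the consolidated state; calling
        # state_dict() on other ranks raises (see the fairscale link above).
        return optimizer.state_dict() if rank == 0 else None
    # Plain optimizers already hold their full state on every rank.
    return optimizer.state_dict()
```

The key point is that the collective consolidation happens on every rank, while the state lookup is restricted to the recipient rank.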
Comments

This is new to me, but on fairscale's or pytorch's side it's easy to make the checkpointing compatible with calls from all ranks. It was not the default because some frameworks (classy and vissl at least) only call state_dict from a single rank, and until now I thought that lightning was doing the same. If it's useful, both behaviors can be supported through a flag, for instance.

Thanks for the issue @ananthsub! Just to make sure, I see the issue being that if the user wants to use a sharded optimizer outside the plugin, we do not support this. I'll need to think this over more. Just to clear up: currently we call `state_dict` on all ranks.

In this case, could we upstream the check: if fairscale is available and if the optimizer is of type `OSS`, then we call `consolidate_state_dict` before fetching the state? A sketch of that check follows below.
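A hedged sketch of that upstreaming idea; the helper names are hypothetical, and the import guard uses a plain importlib check rather than any specific Lightning utility:

```python
import importlib.util

def _fairscale_available() -> bool:
    # Cheap import guard: true only if the fairscale package is installed.
    return importlib.util.find_spec("fairscale") is not None

def maybe_consolidate(optimizer) -> None:
    """Hypothetical hook: consolidate sharded optimizer state before checkpointing."""
    if _fairscale_available():
        from fairscale.optim import OSS
        if isinstance(optimizer, OSS):
            # Collective call; must run on every rank before rank 0
            # reads optimizer.state_dict().
            optimizer.consolidate_state_dict(recipient_rank=0)
```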
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!