
Support for sharded optimizers when dumping checkpoints outside of the DDP sharded training type plugin #6387

Closed
ananthsub opened this issue Mar 7, 2021 · 4 comments · Fixed by facebookresearch/fairscale#500 or #14208
Labels: checkpointing (Related to checkpointing) · priority: 1 (Medium priority task)

ananthsub (Contributor) commented Mar 7, 2021

🐛 Bug

Using fairscale distributed optimizers without DDP sharded leads to crashes or an inconsistent state in the trainer when checkpointing. This will also occur with PyTorch's latest prototype ZeroRedundancyOptimizer. @SeanNaren

Imagine this scenario:

  • Someone wraps their optimizer with fairscale OSS inside their LightningModule (see the sketch after this list), but does not use the DDP sharded plugin.
  • At checkpoint time, when the trainer dumps the checkpoint dict, it looks up the optimizer state.
  • The optimizer state goes through the training type plugin.
  • The training type plugin calls optimizer.state_dict(). For fairscale/PyTorch distributed optimizers, we first need to consolidate the state dict onto one rank, and afterwards read the state dict only from that rank.
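
For reference, here is a minimal sketch of that setup (the model and hyperparameters are hypothetical), assuming fairscale is installed and the distributed process group is already initialized when configure_optimizers runs:

```python
import torch
import pytorch_lightning as pl
from fairscale.optim.oss import OSS


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        # Wrap a regular optimizer class in fairscale's OSS so that the optimizer
        # state is sharded across ranks. The trainer is unaware of the wrapping
        # because no sharded training type plugin is selected.
        return OSS(params=self.parameters(), optim=torch.optim.Adam, lr=1e-3)
```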

One could add a callback that implements on_save_checkpoint to call consolidate_state_dict() on the optimizer across all ranks. However, the trainer then calls state_dict() on all ranks, leading to the exception here: https://github.com/facebookresearch/fairscale/blob/1204c7cf54ec301d46a0d3f3fd703da6b306f8f5/fairscale/optim/oss.py#L354-L358
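
A sketch of what such a callback could look like (the class name is made up, and the hook signature assumed here is the standard Lightning Callback.on_save_checkpoint):

```python
import pytorch_lightning as pl
from fairscale.optim.oss import OSS


class ConsolidateOSSStateDict(pl.Callback):
    """Hypothetical workaround: gather the sharded optimizer state before checkpointing."""

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        for optimizer in trainer.optimizers:
            if isinstance(optimizer, OSS):
                # Collective call: every rank must participate so the shards
                # can be gathered onto rank 0.
                optimizer.consolidate_state_dict(recipient_rank=0)
        # The trainer afterwards still calls optimizer.state_dict() on every rank,
        # which is what raises on non-zero ranks (only rank 0 holds the full state).
```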

This error occurs only on non-zero ranks. As a result, the checkpointing failure is compounded by the exception-handling logic for training and its interaction with checkpointing, as discussed here: #6343 (comment)

Proposal to fix:

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @tchaton @akihironitta @blefaudeux

@ananthsub ananthsub added bug Something isn't working help wanted Open to be worked on labels Mar 7, 2021
blefaudeux commented Mar 7, 2021

This is new to me, but on fairscale's or PyTorch's side it is easy to make the checkpointing compatible with calls from all ranks. It was not the default because some frameworks (classy and vissl at least) only call state_dict from a single rank, and until now I thought Lightning was doing the same. If it's useful, both behaviors can be supported, through a flag for instance.

SeanNaren (Contributor) commented

Thanks for the issue @ananthsub!

Just to make sure I follow: the issue is that if a user wants to use a sharded optimizer outside of the plugin, we do not support it.

I'll need to think this over more. Currently, configure_optimizers is called before distributed communication is initialized (the FSDP plugin integration will add an option for the training type plugin to delay this until after), so wrapping your optimizers in OSS should lead to a crash unless you initialize distributed yourself.

Just to clear up: we call consolidate_state_dict on all processes, but only get the state dict from rank 0. Just to make sure I understand, @ananthsub, are you suggesting upstreaming the consolidation/return to Fairscale?

@edenlightning edenlightning added priority: 1 Medium priority task and removed bug Something isn't working labels Mar 8, 2021
ananthsub (Contributor, Author) commented

> Just to clear up: we call consolidate_state_dict on all processes, but only get the state dict from rank 0. Just to make sure I understand, @ananthsub, are you suggesting upstreaming the consolidation/return to Fairscale?

In this case, could we upstream the optimizer_state from the sharded plugin into the base training type plugin?
https://github.com/PyTorchLightning/pytorch-lightning/blob/523c59bfddca48d003ce20168e727e6683f3efd4/pytorch_lightning/plugins/training_type/sharded.py#L56-L60

If fairscale is available and the optimizer is of type OSS, then we call consolidate_state_dict on all ranks and return the optimizer state from rank 0; otherwise we return the optimizer state from each rank as today. A sketch of what that could look like is below.
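
A minimal sketch of the proposed behavior, written as a standalone function for illustration (the empty-dict return on non-zero ranks and the function placement are assumptions, not the actual Lightning API):

```python
from typing import Any, Dict

import torch
from torch.optim import Optimizer

try:
    from fairscale.optim import OSS
    _FAIRSCALE_AVAILABLE = True
except ImportError:
    _FAIRSCALE_AVAILABLE = False


def optimizer_state(optimizer: Optimizer) -> Dict[str, Any]:
    """Return the optimizer state dict, consolidating sharded optimizers first."""
    if _FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS):
        # Collective call: every rank must participate so the shards can be
        # gathered onto the recipient rank (rank 0 by default).
        optimizer.consolidate_state_dict(recipient_rank=0)
        if torch.distributed.get_rank() != 0:
            # Only rank 0 holds the consolidated state; other ranks return an
            # empty dict instead of calling state_dict(), which would raise.
            return {}
    return optimizer.state_dict()
```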


stale bot commented Apr 10, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Apr 10, 2021
@stale stale bot closed this as completed Apr 18, 2021
@ananthsub ananthsub reopened this Feb 3, 2022
@stale stale bot removed the won't fix This will not be worked on label Feb 3, 2022
@ananthsub ananthsub added checkpointing Related to checkpointing and removed help wanted Open to be worked on labels Feb 16, 2022
@ananthsub ananthsub added this to the 1.6 milestone Feb 16, 2022
@carmocca carmocca modified the milestones: 1.6, 1.5.x Feb 16, 2022
@Borda Borda modified the milestones: 1.5.x, 1.6 Mar 21, 2022
@awaelchli awaelchli modified the milestones: 1.6, 1.7 Mar 21, 2022
@carmocca carmocca modified the milestones: pl:1.7, pl:future Jul 19, 2022
@carmocca carmocca modified the milestones: pl:future, pl:1.8 Aug 26, 2022