[refactor] Move save_function to accelerator 1/n [DeepSpeed] #6689

Merged: 5 commits into master, Mar 29, 2021

Conversation

@tchaton (Contributor) commented Mar 26, 2021

What does this PR do?

This PR refactors save_checkpoint so that the responsibility flows trainer -> checkpoint_connector -> accelerator -> training type plugin.
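
A minimal sketch of that delegation chain, using simplified class and method names rather than the exact signatures in the PR diff (the bodies are placeholders, not Lightning's real code):

from typing import Any, Dict

import torch


class TrainingTypePlugin:
    """Sketch: the plugin decides how state is persisted (single file, sharded directory, ...)."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        torch.save(checkpoint, filepath)


class Accelerator:
    def __init__(self, training_type_plugin: TrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # The accelerator only forwards to the training type plugin.
        self.training_type_plugin.save_checkpoint(checkpoint, filepath)


class CheckpointConnector:
    def __init__(self, trainer: "Trainer") -> None:
        self.trainer = trainer

    def dump_checkpoint(self, weights_only: bool = False) -> Dict[str, Any]:
        # In Lightning this collects epoch, global_step, state_dict and,
        # unless weights_only is True, optimizer and lr scheduler states.
        checkpoint: Dict[str, Any] = {"state_dict": {}}
        if not weights_only:
            checkpoint["optimizer_states"] = []
        return checkpoint

    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
        checkpoint = self.dump_checkpoint(weights_only)
        self.trainer.accelerator.save_checkpoint(checkpoint, filepath)


class Trainer:
    def __init__(self, accelerator: Accelerator) -> None:
        self.accelerator = accelerator
        self.checkpoint_connector = CheckpointConnector(self)

    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
        # trainer -> checkpoint_connector -> accelerator -> training type plugin
        self.checkpoint_connector.save_checkpoint(filepath, weights_only)

The public Trainer.save_checkpoint signature stays the same; only the place where the file is actually written moves down the chain.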

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@pep8speaks commented Mar 26, 2021

Hello @tchaton! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-03-29 13:54:43 UTC

@SeanNaren (Contributor) commented:

This is also motivated by the need for the training type plugin to handle checkpoint saving in sharded environments, where the model has been sharded across multiple processes. Eventually the same should happen for loading checkpoints when restoring!
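
For illustration, a sharded-aware plugin hook could look roughly like the sketch below; the class name, constructor, and file layout are hypothetical, not the actual DeepSpeed/FSDP implementation:

import os

import torch
import torch.distributed as dist


class ShardedSavePluginSketch:
    """Illustrative only: every rank writes its own shard into a directory,
    instead of rank 0 writing one consolidated file."""

    def __init__(self, lightning_module: torch.nn.Module) -> None:
        self.lightning_module = lightning_module

    def save_checkpoint(self, checkpoint: dict, filepath: str) -> None:
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        save_dir = os.path.splitext(filepath)[0]  # e.g. "last.ckpt" -> "last"
        os.makedirs(save_dir, exist_ok=True)
        # Each process persists only the parameters it actually holds.
        shard = dict(checkpoint, state_dict=self.lightning_module.state_dict())
        torch.save(shard, os.path.join(save_dir, f"rank_{rank}_model_states.pt"))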

@codecov bot commented Mar 26, 2021

Codecov Report

Merging #6689 (d5aa78e) into master (b730a5a) will decrease coverage by 9%.
The diff coverage is 71%.

@@           Coverage Diff            @@
##           master   #6689     +/-   ##
========================================
- Coverage      91%     82%     -9%     
========================================
  Files         192     192             
  Lines       12238   13106    +868     
========================================
- Hits        11152   10759    -393     
- Misses       1086    2347   +1261     

@carmocca (Contributor) left a comment:

Why don't we keep the current logic and just call:

(in dump_checkpoint)

model = self.trainer.lightning_module

checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    'state_dict': self.trainer.accelerator.model_state(model),
}

Just as we do for the optimizer?

[Review comment on pytorch_lightning/utilities/cloud_io.py (outdated, resolved)]
@SeanNaren (Contributor) commented Mar 26, 2021

Why don't we keep the current logic and just call:

(in dump_checkpoint)

model = self.trainer.lightning_module

checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    'state_dict': self.trainer.accelerator.model_state(model),
}

Just as we do for the optimizer?

I'm probably missing some pieces here, but in the DeepSpeed Plugin we need to save checkpoints differently from the standard saving of one file from rank 0. Each process needs to save its weights to a directory, so this logic needs to be managed by either the training type plugin or a separate class!

@carmocca (Contributor) commented:

I'm probably missing some pieces here, but in the DeepSpeed Plugin we need to save checkpoints differently from the standard saving of one file from rank 0. Each process needs to save its weights to a directory, so this logic needs to be managed by either the training type plugin or a separate class!

We should then make sure we have the requirements clear to properly refactor things. If it's just about that, we could extract the model state outside of dump checkpoint:

(in save_checkpoint, rough)

        model_state = self.trainer.accelerator.model_state(model)
        # dump states as a checkpoint dictionary object
        checkpoint = self.dump_checkpoint(model_state, weights_only)
        if self.trainer.is_global_zero:
            checkpoint = self.on_save(checkpoint)
            try:
                atomic_save(checkpoint, filepath)
            ...

But I'm sure there are other requirements, so we should state them clearly

@tchaton (Contributor, Author) commented Mar 26, 2021

I'm probably missing some pieces here, but in the DeepSpeed Plugin we need to save checkpoints differently from the standard saving of one file from rank 0. Each process needs to save its weights to a directory, so this logic needs to be managed by either the training type plugin or a separate class!

We should then make sure we have the requirements clear to properly refactor things. If it's just about that, we could extract the model state outside of dump checkpoint:

(in save_checkpoint, rough)

        model_state = self.trainer.accelerator.model_state(model)
        # dump states as a checkpoint dictionary object
        checkpoint = self.dump_checkpoint(model_state, weights_only)
        if self.trainer.is_global_zero:
            checkpoint = self.on_save(checkpoint)
            try:
                atomic_save(checkpoint, filepath)
            ...

But I'm sure there are other requirements, so we should state them clearly

The code you shared doesn't adapt so well to DeepSpeed and FSDP, I think.
I find it cleaner to leave the entire responsibility to the training type plugin, especially as more plugins will be added.

    def save_checkpoint(self, trainer: 'pl.Trainer', filepath: str, weights_only: bool = False) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.
        Args:
            filepath: write-target file's path
            weights_only: saving model weights only
        """
        if torch.distributed.get_world_size() > 1:
            # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
            # dump states as a checkpoint dictionary object
            client_state = dump_checkpoint(trainer, weights_only)
            save_dir = self._filepath_to_dir(filepath)
            _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
            client_state = {k: v for k, v in client_state.items() if k not in _exclude_keys}
            self.deepspeed_engine.save_checkpoint(save_dir, client_state=client_state)

        else:
            super().save_checkpoint(trainer, filepath, weights_only)

@tchaton tchaton enabled auto-merge (squash) March 26, 2021 18:14
@ananthsub (Contributor) left a comment:

  • We should spec out what the path for loading will be too. I think it's going to be much trickier, especially with resharding, to determine what the execution order should look like. Since we're saving to a directory, what does last.ckpt mean anymore? Do we have last-{rank}.ckpt? How will this be specified on the trainer params? Do we now take a directory instead of a specific path? What happens if that directory contains other checkpoint files that are unrelated to the one at hand? Would we accidentally load those too?

  • The main difference with DeepSpeed and FSDP is that the model and optimizer states can vary across ranks. Should the other pieces of the checkpoint dict (trainer progress, callback states) remain inside of the checkpoint connector, or should everything be part of the training type plugin?

  • Not specific to this PR, but this is an existing inefficiency: we potentially call this multiple times inside of the checkpoint callback (once for save top K, once for save last). If the existing saved file is available, we should copy it over to last.ckpt instead of going through the whole dump again (a rough sketch follows at the end of this comment).

  • Could we get rid of this and rely directly on the training type plugin? https://github.com/PyTorchLightning/pytorch-lightning/blob/21fc5eb21e6db07bcc222afa4204b3d5fb5be323/pytorch_lightning/callbacks/model_checkpoint.py#L217

I think this was there for mocks before, but I don't think it's needed now.
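
A rough sketch of the copy-over optimization from the third bullet; the helper and its arguments are hypothetical and not part of ModelCheckpoint today:

import shutil
from pathlib import Path
from typing import Optional


def save_last(trainer, last_filepath: str, just_saved_filepath: Optional[str] = None) -> None:
    # If a checkpoint was already written for "save top k" in this same step,
    # copy it to last.ckpt instead of dumping the full state a second time.
    if just_saved_filepath and Path(just_saved_filepath).is_file():
        shutil.copyfile(just_saved_filepath, last_filepath)
    else:
        trainer.save_checkpoint(last_filepath)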

[Review comment on pytorch_lightning/utilities/cloud_io.py (outdated, resolved)]
@SeanNaren (Contributor) commented:

I'll make an RFC for us to properly track the changes here and explain the motivations, since it seems this is too low-level a place to start.

@tchaton (Contributor, Author) commented Mar 26, 2021

  • We should spec out what the path for loading will be too. I think it's going to be much trickier, especially with resharding, to determine what the execution order should look like. Since we're saving to a directory, what does last.ckpt mean anymore? Do we have last-{rank}.ckpt? How will this be specified on the trainer params? Do we now take a directory instead of a specific path? What happens if that directory contains other checkpoint files that are unrelated to the one at hand? Would we accidentally load those too?
  • The main difference with DeepSpeed and FSDP is that the model and optimizer states can vary across ranks. Should the other pieces of the checkpoint dict (trainer progress, callback states) remain inside of the checkpoint connector, or should everything be part of the training type plugin?
  • Not specific to this PR, but this is an existing inefficiency: we potentially call this multiple times inside of the checkpoint callback (once for save top K, once for save last). If the existing saved file is available, we should copy it over to last.ckpt instead of going through the whole dump again.
  • Could we get rid of this and rely directly on the training type plugin? https://github.com/PyTorchLightning/pytorch-lightning/blob/21fc5eb21e6db07bcc222afa4204b3d5fb5be323/pytorch_lightning/callbacks/model_checkpoint.py#L217

I think this was there for mocks before, but I don't think it's needed now.

Hey @ananthsub,

  • For DeepSpeed and FSDP, we should support an extra `consolidated_checkpoint=True` argument on the plugin. If `consolidated_checkpoint` is False, the checkpoint_path should be a directory; otherwise, a file path (a rough sketch is near the end of this comment).

Currently in https://github.com/PyTorchLightning/pytorch-lightning/pull/6546/files

last/
   trainer_state.pt # handled by DeepSpeed as client_cache
   rank_{rank}_model_states.pt
   rank_{rank}_optimizer_states.pt

In Feat/ds update PR: https://github.com/PyTorchLightning/pytorch-lightning/blob/924d9e2c40a7a0af7766d1131b1b963c95b721a3/pytorch_lightning/plugins/training_type/deepspeed.py#L458

  • self.save_function = trainer.save_checkpoint is still needed, as trainer.save_checkpoint -> checkpoint_connector.save_checkpoint -> accelerator.save_checkpoint -> training_type_plugin.save_checkpoint.

IMO, we will clean up this API once FSDP is at the same stage as DeepSpeed around saving/reloading.
Right now, the DeepSpeed PR #6546 is ready to be merged.
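
A minimal sketch of the proposed flag, assuming a hypothetical `consolidated_checkpoint` argument and placeholder save helpers (neither exists in the plugin today):

import os
from typing import Any, Dict


class DeepSpeedPluginSketch:
    def __init__(self, consolidated_checkpoint: bool = True) -> None:
        self.consolidated_checkpoint = consolidated_checkpoint

    def save_checkpoint(self, checkpoint: Dict[str, Any], checkpoint_path: str) -> None:
        if self.consolidated_checkpoint:
            # checkpoint_path is a single file: consolidate and write once.
            self._save_consolidated(checkpoint, checkpoint_path)
        else:
            # checkpoint_path is a directory: each rank writes its own
            # rank_{rank}_model_states.pt / rank_{rank}_optimizer_states.pt.
            os.makedirs(checkpoint_path, exist_ok=True)
            self._save_sharded(checkpoint, checkpoint_path)

    def _save_consolidated(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        ...  # placeholder

    def _save_sharded(self, checkpoint: Dict[str, Any], save_dir: str) -> None:
        ...  # placeholder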

Best,
T.C

@SeanNaren (Contributor) commented:

To track the conversation on what the API should look like I made an RFC here: #6691

@tchaton (Contributor, Author) commented Mar 29, 2021

To track the conversation on what the API should look like I made an RFC here: #6691

Yes.
Let's remove the trainer arguments as you suggested.
We can just pass the dump_checkpoint output and leave the logic to the training_type_plugin for saving.
@SeanNaren @ananthsub

@carmocca (Contributor) left a comment:

Approving after speaking with @SeanNaren

The current goal is to keep it simple: rather than saving just one state dict per process, save the entire checkpoint per process.

The people who are using this know what they're doing (they've already had to read the documentation to use the model parallel hook). I get the eventual goal of reducing duplication, but I think regardless of the decision you'll have X state dicts saved to disk, so those are just optimizations IMO. If we wanted to be really specific we could do:

pl.Trainer(
    plugins=DeepSpeedPlugin(stage=3),
    callbacks=[ShardedCheckpoint(...)]
)

Ideally, the end result for me is:

last/
   trainer_state.pt # contains the basic checkpoint data, saved by the checkpoint connector
   rank_{rank}_model_states.pt # contains the model state, saved by the training_type_plugin
   rank_{rank}_optimizer_states.pt # contains the optimizer state, saved by the training_type_plugin

And a ShardedCheckpoint would be used to read this special checkpoint, where we may convert any present ModelCheckpoint into ShardedCheckpoint
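
A hypothetical sketch of that conversion; `ShardedCheckpoint` does not exist in Lightning and only stands in for the proposed sharded-aware callback:

from pytorch_lightning.callbacks import ModelCheckpoint


class ShardedCheckpoint(ModelCheckpoint):
    """Placeholder for a callback that understands the per-rank directory layout."""


def convert_model_checkpoints(callbacks):
    # Swap every plain ModelCheckpoint for the sharded-aware variant,
    # carrying over the user's configuration (dirpath, monitor, ...).
    return [
        ShardedCheckpoint(dirpath=cb.dirpath, monitor=cb.monitor)
        if type(cb) is ModelCheckpoint else cb
        for cb in callbacks
    ]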

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Review comment on the save_checkpoint docstring:

    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        trainer: PyTorch Lightning Trainer

(Member) Suggested change:
-    trainer: PyTorch Lightning Trainer
+    checkpoint: dict containing model and trainer state

@tchaton (Contributor, Author): Done!

@tchaton tchaton merged commit 646cf2f into master Mar 29, 2021
@tchaton tchaton deleted the move_save_chekpoint branch March 29, 2021 19:02
@mergify bot added the "ready" label (PRs ready to be merged) on Nov 6, 2022
Labels: ready (PRs ready to be merged), refactor
6 participants