Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lr scheduler state not being dumped to checkpoint in deepspeed strategy #11307

Merged
merged 5 commits into from Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -403,6 +403,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))


- Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy ([#11307](https://github.com/PyTorchLightning/pytorch-lightning/pull/11307))


## [1.5.7] - 2021-12-21

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/strategies/deepspeed.py
Expand Up @@ -741,7 +741,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH) -> None:
)
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
_exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
_exclude_keys = ["state_dict", "optimizer_states"]
carmocca marked this conversation as resolved.
Show resolved Hide resolved
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)

Expand Down
4 changes: 4 additions & 0 deletions tests/strategies/test_deepspeed_strategy.py
Expand Up @@ -558,6 +558,10 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if not hasattr(self, "model"):
self.configure_sharded_model()

# Lightning saves the lr schedulers, but DeepSpeed saves the optimizer states separately
assert len(checkpoint["lr_schedulers"]) == 1
assert "optimizer_states" not in checkpoint


class ManualModelParallelClassificationModel(ModelParallelClassificationModel):
@property
Expand Down