
Commit

fix deepspeed lr scheduler state dumping
fix confirmed
awaelchli committed Jan 4, 2022
1 parent 7fa1aeb commit 8b7865c
Showing 2 changed files with 3 additions and 1 deletion.
pytorch_lightning/strategies/deepspeed.py (1 addition, 1 deletion)
@@ -741,7 +741,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH) -> None:
             )
         # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
         # dump states as a checkpoint dictionary object
-        _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
+        _exclude_keys = ["state_dict", "optimizer_states"]
         checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
         self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
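Note: the sketch below is not part of the commit; it is a minimal illustration, with made-up checkpoint contents, of what the one-line change does. With "lr_schedulers" no longer excluded, the scheduler state stays in the client_state dict handed to deepspeed_engine.save_checkpoint, so it is written into the DeepSpeed checkpoint and is available again on load, which the new test below asserts.

# Minimal sketch of the key filtering in save_checkpoint (example data is made up).
checkpoint = {
    "state_dict": {},          # written by deepspeed_engine.save_checkpoint itself
    "optimizer_states": [],    # likewise handled by DeepSpeed
    "lr_schedulers": [{"last_epoch": 3}],
    "epoch": 3,
}

old_exclude = ["state_dict", "optimizer_states", "lr_schedulers"]  # before this commit
new_exclude = ["state_dict", "optimizer_states"]                   # after this commit

old_client_state = {k: v for k, v in checkpoint.items() if k not in old_exclude}
new_client_state = {k: v for k, v in checkpoint.items() if k not in new_exclude}

assert "lr_schedulers" not in old_client_state        # scheduler state used to be dropped
assert new_client_state["lr_schedulers"] is not None  # now persisted with the checkpoint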

tests/plugins/test_deepspeed_plugin.py (2 additions, 0 deletions)
@@ -558,6 +558,8 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         if not hasattr(self, "model"):
             self.configure_sharded_model()
 
+        assert checkpoint["lr_schedulers"] is not None
+
 
 class ManualModelParallelClassificationModel(ModelParallelClassificationModel):
     @property
