diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py
index 452f3c8e1a8b48..4cd108a99821d4 100644
--- a/pytorch_lightning/strategies/deepspeed.py
+++ b/pytorch_lightning/strategies/deepspeed.py
@@ -741,7 +741,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH) -> None:
             )
         # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
         # dump states as a checkpoint dictionary object
-        _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
+        _exclude_keys = ["state_dict", "optimizer_states"]
         checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
         self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
 
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index c349e40b2aa4cc..3af1ffbf59d4d8 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -558,6 +558,8 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         if not hasattr(self, "model"):
             self.configure_sharded_model()
 
+        assert checkpoint["lr_schedulers"] is not None
+
 
 class ManualModelParallelClassificationModel(ModelParallelClassificationModel):
     @property
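
For context (not part of the patch): below is a minimal sketch of the filtering step that the first hunk changes. The checkpoint contents are invented for illustration; only the `_exclude_keys` list and the dict comprehension mirror `save_checkpoint` in `deepspeed.py`. With `"lr_schedulers"` no longer excluded, scheduler states remain in the `client_state` dict passed to `deepspeed_engine.save_checkpoint`, so they can be restored on load, which is what the new assertion in `on_load_checkpoint` verifies.

```python
# Sketch only: illustrates the exclusion filter, not the full strategy code.
# The example checkpoint values are made up.

checkpoint = {
    "state_dict": {},          # module weights: handled by DeepSpeed's partitioned checkpointing
    "optimizer_states": [],    # likewise managed by the DeepSpeed engine itself
    "lr_schedulers": [{"scheduler": {"last_epoch": 3}}],  # now kept in client_state
    "epoch": 3,
}

# Before this change, "lr_schedulers" was also excluded, so scheduler state never
# reached deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint).
_exclude_keys = ["state_dict", "optimizer_states"]
client_state = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}

assert client_state["lr_schedulers"] is not None  # mirrors the new test assertion
```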