diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 303cc985aad3a..36d487e3419c7 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -667,7 +667,6 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
             filepath: write-target file's path
         """
         if self.zero_stage_3 and self._multi_device and self.is_global_zero:
-            # todo (sean): Add link to docs once docs are merged.
             warning_cache.warn(
                 "When saving the DeepSpeed Stage 3 checkpoint, "
                 "each worker will save a shard of the checkpoint within a directory. "
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 9b4a1f8a4ba99..b0268f24177ef 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -402,15 +402,20 @@ def test_deepspeed_fp32_works(tmpdir):
 
 @RunIf(min_gpus=2, deepspeed=True, special=True)
 def test_deepspeed_stage_3_save_warning(tmpdir):
-    """Test to ensure that DeepSpeed Stage 3 gives a warning when saving."""
+    """Test to ensure that DeepSpeed Stage 3 gives a warning when saving on rank zero."""
     model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
    )
     trainer.fit(model)
     checkpoint_path = os.path.join(tmpdir, "model.pt")
-    with pytest.warns(UserWarning, match="each worker will save a shard of the checkpoint within a directory."):
+    with pytest.warns(UserWarning) as record:
+        # both ranks need to call save checkpoint
         trainer.save_checkpoint(checkpoint_path)
+    if trainer.is_global_zero:
+        assert len(record) == 1
+        match = "each worker will save a shard of the checkpoint within a directory."
+        assert match in str(record[0].message)
 
 
 @RunIf(min_gpus=1, deepspeed=True, special=True)
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index c37681e4831ca..d21e8efc7a5cb 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -191,7 +191,7 @@ def training_epoch_end(self, outputs) -> None:
 
 
 def test_batch_loop_releases_loss(tmpdir):
-    """Test that loss/graph is released so that it can be garbage collected before the next training step"""
+    """Test that loss/graph is released so that it can be garbage collected before the next training step."""
 
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):