Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,6 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
filepath: write-target file's path
"""
if self.zero_stage_3 and self._multi_device and self.is_global_zero:
# todo (sean): Add link to docs once docs are merged.
warning_cache.warn(
"When saving the DeepSpeed Stage 3 checkpoint, "
"each worker will save a shard of the checkpoint within a directory. "
Expand Down
9 changes: 7 additions & 2 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,20 @@ def test_deepspeed_fp32_works(tmpdir):

@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_stage_3_save_warning(tmpdir):
    """Test to ensure that DeepSpeed Stage 3 gives a warning when saving on rank zero."""
    boring_model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
    )
    trainer.fit(boring_model)

    ckpt_path = os.path.join(tmpdir, "model.pt")
    with pytest.warns(UserWarning) as warning_records:
        # both ranks need to call save checkpoint
        trainer.save_checkpoint(ckpt_path)

    # only the global-zero rank is expected to emit the sharded-checkpoint warning
    if trainer.is_global_zero:
        assert len(warning_records) == 1
        expected = "each worker will save a shard of the checkpoint within a directory."
        assert expected in str(warning_records[0].message)


@RunIf(min_gpus=1, deepspeed=True, special=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/loops/test_training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def training_epoch_end(self, outputs) -> None:


def test_batch_loop_releases_loss(tmpdir):
"""Test that loss/graph is released so that it can be garbage collected before the next training step"""
"""Test that loss/graph is released so that it can be garbage collected before the next training step."""

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
Expand Down