From d469978903992ffbc5af5855ac830749e04c7452 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 26 Apr 2022 19:36:27 +0530
Subject: [PATCH 1/3] Fix to ensure the checkpoint states are saved in a common filepath with deepspeed

---
 pytorch_lightning/strategies/deepspeed.py   |  3 +++
 tests/strategies/test_deepspeed_strategy.py | 14 ++++++++++++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py
index 7874322a4d74a..b6a353ae15e7b 100644
--- a/pytorch_lightning/strategies/deepspeed.py
+++ b/pytorch_lightning/strategies/deepspeed.py
@@ -762,6 +762,9 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Op
             TypeError:
                 If ``storage_options`` arg is passed in
         """
+        # broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath
+        filepath = self.broadcast(filepath)
+
         if storage_options is not None:
             raise TypeError(
                 "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
diff --git a/tests/strategies/test_deepspeed_strategy.py b/tests/strategies/test_deepspeed_strategy.py
index 319289d200f4f..e295caa7e622c 100644
--- a/tests/strategies/test_deepspeed_strategy.py
+++ b/tests/strategies/test_deepspeed_strategy.py
@@ -1269,13 +1269,19 @@ def test_deepspeed_with_meta_device(tmpdir):
 def test_deepspeed_multi_save_same_filepath(tmpdir):
     """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old
     sharded checkpoints."""
-    model = BoringModel()
+
+    class CustomModel(BoringModel):
+        def training_step(self, *args, **kwargs):
+            self.log("grank", self.global_rank)
+            return super().training_step(*args, **kwargs)
+
+    model = CustomModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         strategy="deepspeed",
         accelerator="gpu",
         devices=2,
-        callbacks=[ModelCheckpoint(save_top_k=1, save_last=True)],
+        callbacks=[ModelCheckpoint(filename="{epoch}_{step}_{grank}", save_top_k=1, save_last=True)],
         limit_train_batches=1,
         limit_val_batches=0,
         num_sanity_val_steps=0,
@@ -1284,6 +1290,10 @@ def test_deepspeed_multi_save_same_filepath(tmpdir):
         enable_model_summary=False,
     )
     trainer.fit(model)
+
+    expected = ["last.ckpt", "epoch=1_step=2_grank=0.0.ckpt"]
+    assert set(expected) == set(os.listdir(trainer.checkpoint_callback.dirpath))
+
     ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, "last.ckpt")
     expected = ["latest", "zero_to_fp32.py", "checkpoint"]
     assert set(expected) == set(os.listdir(ckpt_path))

From b47fb351c025cf2d9ff3ad8cffecb9bf9d096ce9 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 26 Apr 2022 19:38:54 +0530
Subject: [PATCH 2/3] chlog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index daee5ae803144..ccda5e46c67e9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -154,6 +154,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))
 
+- Fixed an issue to ensure all the checkpoint states are saved in a common filepath with `DeepspeedStrategy` ([#12887](https://github.com/PyTorchLightning/pytorch-lightning/pull/12887))
+
+
 ## [1.6.1] - 2022-04-13
 
 ### Changed

From 8e17c09178fd54ba6ad211a368c1c23b9b0abf65 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 26 Apr 2022 11:50:21 -0400
Subject: [PATCH 3/3] small improvements

---
 tests/strategies/test_deepspeed_strategy.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/strategies/test_deepspeed_strategy.py b/tests/strategies/test_deepspeed_strategy.py
index e295caa7e622c..dd9bfcea236d2 100644
--- a/tests/strategies/test_deepspeed_strategy.py
+++ b/tests/strategies/test_deepspeed_strategy.py
@@ -1281,7 +1281,7 @@ def training_step(self, *args, **kwargs):
         strategy="deepspeed",
         accelerator="gpu",
         devices=2,
-        callbacks=[ModelCheckpoint(filename="{epoch}_{step}_{grank}", save_top_k=1, save_last=True)],
+        callbacks=[ModelCheckpoint(filename="{epoch}_{step}_{grank}", save_top_k=1)],
         limit_train_batches=1,
         limit_val_batches=0,
         num_sanity_val_steps=0,
@@ -1291,12 +1291,13 @@ def training_step(self, *args, **kwargs):
     )
     trainer.fit(model)
 
-    expected = ["last.ckpt", "epoch=1_step=2_grank=0.0.ckpt"]
-    assert set(expected) == set(os.listdir(trainer.checkpoint_callback.dirpath))
+    filepath = "epoch=1_step=2_grank=0.0.ckpt"
+    expected = {filepath}
+    assert expected == set(os.listdir(trainer.checkpoint_callback.dirpath))
 
-    ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, "last.ckpt")
-    expected = ["latest", "zero_to_fp32.py", "checkpoint"]
-    assert set(expected) == set(os.listdir(ckpt_path))
+    ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filepath)
+    expected = {"latest", "zero_to_fp32.py", "checkpoint"}
+    assert expected == set(os.listdir(ckpt_path))
 
 
 @RunIf(min_gpus=2, standalone=True, deepspeed=True)
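
For context beyond the patches themselves: when the `ModelCheckpoint` filename template contains a per-rank logged value (here `grank`, i.e. `global_rank`), each process can resolve a different filepath, and DeepSpeed would then write its shards into different checkpoint directories; broadcasting the rank-0 path keeps every rank saving into one common directory. Below is a minimal sketch of that idea using plain `torch.distributed` rather than Lightning's strategy API; the helper name, the template string, and the metrics dict are hypothetical, not part of the patch.

import os

import torch.distributed as dist


def resolve_common_filepath(dirpath: str, template: str, metrics: dict) -> str:
    """Format a checkpoint name from per-rank metrics, then adopt rank 0's path on every rank."""
    # each rank may format a different name, e.g. "epoch=1_step=2_grank=1.0.ckpt" on rank 1
    local_path = os.path.join(dirpath, template.format(**metrics) + ".ckpt")
    # broadcast rank 0's path so all processes save their shards under one common directory
    # (assumes dist.init_process_group(...) has already been called)
    paths = [local_path]
    dist.broadcast_object_list(paths, src=0)
    return paths[0]

With a broadcast like this in place, a two-GPU DeepSpeed run produces a single sharded checkpoint directory rather than one per rank, which is what the updated test asserts by listing the contents of `trainer.checkpoint_callback.dirpath`.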