diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9b58ee426a39b..d90e008fa1ef7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -155,10 +155,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue to ensure all the checkpoint states are saved in a common filepath with `DeepspeedStrategy` ([#12887](https://github.com/PyTorchLightning/pytorch-lightning/pull/12887))
 
-- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
+- Fixed an issue wrt recursive invocation of DDP configuration in hpu parallel plugin ([#12912](https://github.com/PyTorchLightning/pytorch-lightning/pull/12912))
+
+- Fixed fit loop restart logic to enable resume using the checkpoint ([#12821](https://github.com/PyTorchLightning/pytorch-lightning/pull/12821))
 
-- Fixed an issue wrt recursive invocation of DDP configuration in hpu parallel plugin ([#12912](https://github.com/PyTorchLightning/pytorch-lightning/pull/12912))
+
+- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
 
 - Fixed an issue where sharded grad scaler is passed in when using BF16 with the `ShardedStrategy` ([#12915](https://github.com/PyTorchLightning/pytorch-lightning/pull/12915))
 
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index db3f60fb28ede..40334387c0688 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -123,15 +123,10 @@ def running_loss(self) -> TensorRunningAccum:
 
     @Loop.restarting.setter
     def restarting(self, restarting: bool) -> None:
-        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
-        # current values are equal
-        values = (
-            self.epoch_progress.current.ready,
-            self.epoch_progress.current.started,
-            self.epoch_progress.current.processed,
-        )
-        finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
-        restarting &= finished_before_on_train_end
+        # if the last epoch completely finished, we are not actually restarting
+        values = self.epoch_progress.current.ready, self.epoch_progress.current.started
+        epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values)
+        restarting = restarting and epoch_unfinished or self._iteration_based_training()
         Loop.restarting.fset(self, restarting)  # call the parent setter
 
     @property
@@ -205,6 +200,10 @@ def reset(self) -> None:
 
     def on_run_start(self) -> None:  # type: ignore[override]
         """Calls the ``on_train_start`` hook."""
+        # update the current_epoch in-case of checkpoint reload
+        if not self._iteration_based_training():
+            self.epoch_progress.current.completed = self.epoch_progress.current.processed
+
         # reset train dataloader and val dataloader
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
 
@@ -336,6 +335,9 @@ def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
         return self.epoch_loop._should_accumulate()
 
+    def _iteration_based_training(self) -> bool:
+        return self.trainer.max_steps != -1
+
 
 def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
     training_step_fx = getattr(trainer.lightning_module, "training_step")
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 0ebaab553b593..6c9dab1480d3f 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -620,8 +620,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         "state_dict": ANY,
         "loops": ANY,
     }
-    saved_ckpt1 = {**loaded_ckpt, "global_step": 2, "epoch": 0}
-    saved_ckpt2 = {**loaded_ckpt, "global_step": 4, "epoch": 1}
+    saved_ckpt = {**loaded_ckpt, "global_step": 4, "epoch": 1}
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
@@ -651,23 +650,12 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         dict(name="on_epoch_start"),
         dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
         dict(name="on_train_epoch_start"),
-        dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
-        dict(name="Callback.state_dict"),
-        dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt1)),
-        dict(name="on_save_checkpoint", args=(saved_ckpt1,)),
-        dict(name="on_train_epoch_end"),
-        dict(name="Callback.on_epoch_end", args=(trainer, model)),
-        dict(name="on_epoch_end"),
-        dict(name="Callback.on_epoch_start", args=(trainer, model)),
-        dict(name="on_epoch_start"),
-        dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
-        dict(name="on_train_epoch_start"),
         *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)),
         dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
         dict(name="Callback.state_dict"),
-        dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt2)),
-        dict(name="on_save_checkpoint", args=(saved_ckpt2,)),
+        dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
+        dict(name="on_save_checkpoint", args=(saved_ckpt,)),
         dict(name="on_train_epoch_end"),
         dict(name="Callback.on_epoch_end", args=(trainer, model)),
         dict(name="on_epoch_end"),
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 0d6c9772b9c45..136e8ee516bbb 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -199,7 +199,7 @@ def on_train_start(self):
             if self.trainer.state.fn == TrainerFn.TUNING:
                 self._test_on_val_test_predict_tune_start()
             else:
-                assert self.trainer.current_epoch == state_dict["epoch"]
+                assert self.trainer.current_epoch == state_dict["epoch"] + 1
                 assert self.trainer.global_step == state_dict["global_step"]
                 assert self._check_model_state_dict()
                 assert self._check_optimizers()
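Not part of the diff above: a minimal, self-contained sketch of how the fixed restart behaviour can be exercised end to end, assuming the PyTorch Lightning 1.6-era public API (`Trainer.save_checkpoint`, `Trainer.fit(ckpt_path=...)`). The `DemoModel` module, the `epoch0.ckpt` file name, and the dataset sizes are illustrative assumptions, not code from this PR.

# Illustrative sketch (assumed API: pytorch_lightning ~1.6): resume from an
# end-of-epoch checkpoint and continue with the next epoch instead of
# replaying the epoch that already finished.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class DemoModel(pl.LightningModule):  # hypothetical minimal module, not from this PR
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        x, y = torch.randn(64, 32), torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=16)


if __name__ == "__main__":
    # First run: train exactly one epoch and checkpoint at its end.
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_progress_bar=False)
    trainer.fit(DemoModel())
    trainer.save_checkpoint("epoch0.ckpt")

    # Second run: resume with a larger epoch budget. With the restart fix,
    # the completed epoch is not repeated; only the remaining epoch runs.
    resumed = pl.Trainer(max_epochs=2, logger=False, enable_progress_bar=False)
    resumed.fit(DemoModel(), ckpt_path="epoch0.ckpt")
    print(resumed.current_epoch, resumed.global_step)

The scenario matches the `on_run_start` change in the diff: for epoch-based training, `epoch_progress.current.completed` is aligned with `current.processed` when the checkpoint is reloaded, so the fit loop treats the saved epoch as done and starts from the next one, while iteration-based runs (`max_steps != -1`) keep the previous restart behaviour.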