7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -155,10 +155,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue to ensure all the checkpoint states are saved in a common filepath with `DeepspeedStrategy` ([#12887](https://github.com/PyTorchLightning/pytorch-lightning/pull/12887))
 
 
-- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
-- Fixed an issue wrt recursive invocation of DDP configuration in hpu parallel plugin ([#12912](https://github.com/PyTorchLightning/pytorch-lightning/pull/12912))
 
 
+- Fixed fit loop restart logic to enable resume using the checkpoint ([#12821](https://github.com/PyTorchLightning/pytorch-lightning/pull/12821))
+
+- Fixed an issue wrt recursive invocation of DDP configuration in hpu parallel plugin ([#12912](https://github.com/PyTorchLightning/pytorch-lightning/pull/12912))
+
+- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
 
 
 - Fixed an issue where sharded grad scaler is passed in when using BF16 with the `ShardedStrategy` ([#12915](https://github.com/PyTorchLightning/pytorch-lightning/pull/12915))
20 changes: 11 additions & 9 deletions pytorch_lightning/loops/fit_loop.py
@@ -123,15 +123,10 @@ def running_loss(self) -> TensorRunningAccum:

     @Loop.restarting.setter
     def restarting(self, restarting: bool) -> None:
-        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
-        # current values are equal
-        values = (
-            self.epoch_progress.current.ready,
-            self.epoch_progress.current.started,
-            self.epoch_progress.current.processed,
-        )
-        finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
-        restarting &= finished_before_on_train_end
+        # if the last epoch completely finished, we are not actually restarting
+        values = self.epoch_progress.current.ready, self.epoch_progress.current.started
+        epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values)
+        restarting = restarting and epoch_unfinished or self._iteration_based_training()
         Loop.restarting.fset(self, restarting)  # call the parent setter
 
     @property
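As an aside for readers of this hunk, here is a minimal, hedged sketch (standalone code, not the library's, with hypothetical counter values) of how the new setter resolves the `restarting` flag. Note that Python's precedence makes the final expression read as `(restarting and epoch_unfinished) or iteration_based`:

```python
from dataclasses import dataclass


@dataclass
class EpochCounters:
    ready: int = 0      # epochs that have begun
    started: int = 0    # epochs whose start hooks have run
    processed: int = 0  # epochs whose batches have all been processed
    completed: int = 0  # epochs counted as fully completed (incremented after the epoch-end hooks)


def resolve_restarting(restarting: bool, current: EpochCounters, iteration_based: bool) -> bool:
    # The current epoch counts as unfinished if it was readied/started but
    # its batches were not all processed yet.
    epoch_unfinished = any(v != current.processed for v in (current.ready, current.started))
    return restarting and epoch_unfinished or iteration_based


# Checkpoint written after the epoch's batches finished (ready == started == processed):
assert resolve_restarting(True, EpochCounters(1, 1, 1, 0), iteration_based=False) is False
# Checkpoint written mid-epoch (processed lags behind ready/started):
assert resolve_restarting(True, EpochCounters(1, 1, 0, 0), iteration_based=False) is True
```

With `iteration_based=True` the flag is forced to `True` regardless of the counters, which is the coupling to `max_steps` that the reviewers question further down.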
@@ -205,6 +200,10 @@ def reset(self) -> None:

     def on_run_start(self) -> None:  # type: ignore[override]
         """Calls the ``on_train_start`` hook."""
+        # update the current_epoch in-case of checkpoint reload
Contributor:

in case

+        if not self._iteration_based_training():
Contributor Author (@rohitgr7, Apr 26, 2022):

During restart, if it's not iteration-based training, we need to update the current epoch so that training starts from a fresh epoch rather than the old one, in cases where the checkpoint being reloaded was saved before `on_train_end`.

Contributor:

Does this make the following 2 lines unnecessary?

https://github.com/PyTorchLightning/pytorch-lightning/blob/46c59d04db4156ae98e184e1d9321932f7e2ebf7/pytorch_lightning/loops/fit_loop.py#L169-L171

Since `on_run_start` runs before `done`, and `stop_epochs` is only valid under "not iteration-based".

Contributor Author (@rohitgr7, Apr 29, 2022):

yes! good catch! will update

but even if that's the case, do we increment `epoch_progress` during iteration-based training? not sure. need to check

Contributor Author (@rohitgr7):

> Does this make the following 2 lines unnecessary?

maybe... will check. I remember a test was failing: https://github.com/PyTorchLightning/pytorch-lightning/runs/6101153350?check_suite_focus=true

but it should not, I guess, since we increment `current.completed` after every epoch.

Contributor Author (@rohitgr7):

okay, looks like `done` is called within `skip`, so we need to keep it.

Contributor:

Hey @rohitgr7, I don't understand this logic. There is no true iteration-based training in Lightning and we always have epochs. We may restart from a completed epoch or from an incomplete epoch, regardless of how the max_ flags on the trainer are set.

Contributor Author (@rohitgr7):

If a user is using max_steps=7 and uses a checkpoint from step=4 to restart, we need to start from step 5. Although I guess the dataloaders are re-iterated from the beginning in that case.

Contributor (@awaelchli, May 5, 2022):

This explanation is not satisfying. It doesn't explain what max_steps has to do with the way we restart.

You could also have max_epochs=1 where the epoch size is 7 (equivalent to max_steps=7) and you would still restore the checkpoint at step 4 the exact same way.

max_steps / max_epochs are the stopping conditions. Why they should affect the way we restart is beyond my understanding.

Contributor Author (@rohitgr7):

> You could also have max_epochs=1 where the epoch size is 7 (equivalent to max_steps=7) and you would still restore the checkpoint at step 4 the exact same way.

Yes, it does, but in that case we do start from step=4, just on an entirely new epoch. The reason we need to separate this a little is that we don't need to update the current epoch there. Although I just noticed that we do update the current_epoch at the end even if the user is training based on max_steps. Do you think we should remove this condition and keep incrementing the current epoch on each restart even though the training is based on max_steps?

Contributor:

Whatever the fix is, it will not at all be conditioned on the max_steps / max_epochs trainer flags. I don't know how to solve this problem and we need to brainstorm what should be fixed.

All I know right now is that this PR did something weird 😅 and that can't be the fix.

+            self.epoch_progress.current.completed = self.epoch_progress.current.processed
+
         # reset train dataloader and val dataloader
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
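To make the scenario debated in the thread above concrete, here is a hedged illustration (hypothetical counter values, plain Python rather than library code) of the checkpoint-saved-before-`on_train_end` case that this `on_run_start` change targets:

```python
# Epoch progress roughly as a checkpoint written from epoch 0's
# on_train_epoch_end would record it: the epoch's batches are processed,
# but the epoch has not been counted as completed yet.
current = {"ready": 1, "started": 1, "processed": 1, "completed": 0}

iteration_based_training = False  # e.g. the user set max_epochs, not max_steps

# Without the bump, the resumed run keeps the old epoch counter and replays
# epoch 0's end-of-epoch hooks (see the removed expectations in
# tests/models/test_hooks.py below); with it, training resumes on epoch 1.
if not iteration_based_training:
    current["completed"] = current["processed"]

assert current["completed"] == 1
```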

@@ -336,6 +335,9 @@ def _should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated."""
return self.epoch_loop._should_accumulate()

def _iteration_based_training(self) -> bool:
return self.trainer.max_steps != -1


def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
training_step_fx = getattr(trainer.lightning_module, "training_step")
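A short, hedged usage sketch for the `_iteration_based_training` helper added above: `max_steps` defaults to `-1`, so the helper only returns `True` when the user asked for step-based stopping (the `max_steps=7` case from the discussion). The model and checkpoint path in the comments are placeholders.

```python
import pytorch_lightning as pl

# Step-based run: the fit loop's _iteration_based_training() would be True.
trainer = pl.Trainer(max_steps=7)
# trainer.fit(model, ckpt_path="step-4.ckpt")  # hypothetical mid-run resume

# Epoch-based run: max_steps stays at its default of -1, the helper returns
# False, and on_run_start bumps the completed-epoch counter on restart.
trainer = pl.Trainer(max_epochs=3)
```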
18 changes: 3 additions & 15 deletions tests/models/test_hooks.py
@@ -620,8 +620,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
"state_dict": ANY,
"loops": ANY,
}
saved_ckpt1 = {**loaded_ckpt, "global_step": 2, "epoch": 0}
saved_ckpt2 = {**loaded_ckpt, "global_step": 4, "epoch": 1}
saved_ckpt = {**loaded_ckpt, "global_step": 4, "epoch": 1}
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
@@ -651,23 +650,12 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
dict(name="on_epoch_start"),
dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
dict(name="on_train_epoch_start"),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
dict(name="Callback.state_dict"),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt1)),
dict(name="on_save_checkpoint", args=(saved_ckpt1,)),
dict(name="on_train_epoch_end"),
dict(name="Callback.on_epoch_end", args=(trainer, model)),
dict(name="on_epoch_end"),
dict(name="Callback.on_epoch_start", args=(trainer, model)),
dict(name="on_epoch_start"),
dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
dict(name="on_train_epoch_start"),
*model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0),
dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
dict(name="Callback.state_dict"),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt2)),
dict(name="on_save_checkpoint", args=(saved_ckpt2,)),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
dict(name="on_save_checkpoint", args=(saved_ckpt,)),
dict(name="on_train_epoch_end"),
dict(name="Callback.on_epoch_end", args=(trainer, model)),
dict(name="on_epoch_end"),
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -199,7 +199,7 @@ def on_train_start(self):
         if self.trainer.state.fn == TrainerFn.TUNING:
             self._test_on_val_test_predict_tune_start()
         else:
-            assert self.trainer.current_epoch == state_dict["epoch"]
+            assert self.trainer.current_epoch == state_dict["epoch"] + 1
             assert self.trainer.global_step == state_dict["global_step"]
             assert self._check_model_state_dict()
             assert self._check_optimizers()
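Hedged arithmetic for what the updated assertion encodes, using the step/epoch values from the `test_hooks.py` expectations above (two optimizer steps per epoch in that test):

```python
ckpt = {"epoch": 0, "global_step": 2}  # written at the end of epoch 0

# After this PR, resuming such a checkpoint starts a fresh epoch while keeping
# the global step, so inside on_train_start of the resumed run:
resumed_current_epoch = ckpt["epoch"] + 1
resumed_global_step = ckpt["global_step"]

assert resumed_current_epoch == 1
assert resumed_global_step == 2
```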