Fix fit loop restart logic to enable resume using the checkpoint #12821
@@ -123,15 +123,10 @@ def running_loss(self) -> TensorRunningAccum:

     @Loop.restarting.setter
     def restarting(self, restarting: bool) -> None:
-        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
-        # current values are equal
-        values = (
-            self.epoch_progress.current.ready,
-            self.epoch_progress.current.started,
-            self.epoch_progress.current.processed,
-        )
-        finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
-        restarting &= finished_before_on_train_end
+        # if the last epoch completely finished, we are not actually restarting
+        values = self.epoch_progress.current.ready, self.epoch_progress.current.started
+        epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values)
+        restarting = restarting and epoch_unfinished or self._iteration_based_training()
         Loop.restarting.fset(self, restarting)  # call the parent setter

     @property
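To make the new condition concrete, here is a minimal, self-contained sketch; the `Progress` dataclass and `resolve_restarting` helper below are illustrative stand-ins, not the actual Lightning classes. Note that `restarting and epoch_unfinished or self._iteration_based_training()` binds as `(restarting and epoch_unfinished) or iteration_based` under Python's operator precedence:

```python
from dataclasses import dataclass


@dataclass
class Progress:
    """Illustrative stand-in for ``epoch_progress.current`` in the real loop."""

    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0


def resolve_restarting(restarting: bool, current: Progress, iteration_based: bool) -> bool:
    # the epoch counts as unfinished if ready/started moved past processed,
    # i.e. the previous run stopped somewhere in the middle of an epoch
    epoch_unfinished = any(v != current.processed for v in (current.ready, current.started))
    # binds as (restarting and epoch_unfinished) or iteration_based
    return restarting and epoch_unfinished or iteration_based


# epoch finished cleanly before the checkpoint: not treated as a restart
print(resolve_restarting(True, Progress(3, 3, 3, 2), iteration_based=False))  # False
# stopped mid-epoch: resume continues inside the epoch
print(resolve_restarting(True, Progress(3, 3, 2, 2), iteration_based=False))  # True
# step-based (max_steps) training always keeps the restarting flag
print(resolve_restarting(True, Progress(3, 3, 3, 2), iteration_based=True))   # True
```

In other words, the new check compares ready/started against processed (rather than against completed, as before), and iteration-based runs bypass the check entirely.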
@@ -205,6 +200,10 @@ def reset(self) -> None:

     def on_run_start(self) -> None:  # type: ignore[override]
         """Calls the ``on_train_start`` hook."""
+        # update the current_epoch in-case of checkpoint reload
+        if not self._iteration_based_training():
Contributor (Author): during restart, if it's not iteration-based training, we need to update the current epoch so that it starts from a fresh epoch rather than the old one, for the cases where the checkpoint is reloaded using the one saved before

Contributor: Does this make the following 2 lines unnecessary? Since

Contributor (Author): yes! good catch! will update. But even if that's the case, do we increment epoch_progress during iteration-based training? Not sure, need to check.

Contributor (Author): maybe... will check. I remember a test was failing: https://github.com/PyTorchLightning/pytorch-lightning/runs/6101153350?check_suite_focus=true but it should not, I guess, since we increment

Contributor (Author): okay, looks like

Contributor: Hey @rohitgr7

Contributor (Author): if a user is using

Contributor: This explanation is not satisfying. It doesn't answer how max_steps has anything to do with the way we restart. You could also have max_steps / max_epochs as the stopping conditions; them affecting the way we restart is beyond my understanding.

Contributor (Author): yes, it does, but in that case we do start from step=4, just on an entirely new epoch. The reason we need to separate this a little is that we don't need to update the current epoch. Although I just noticed that we do update the current_epoch at the end even if the user is performing training based on

Contributor: Whatever the fix is, it will not at all be conditioned on the max_steps / max_epochs trainer flags. I don't know how to solve this problem and we need to brainstorm what should be fixed. All I know right now is that this PR did something weird 😅 and that can't be the fix.
+            self.epoch_progress.current.completed = self.epoch_progress.current.processed

         # reset train dataloader and val dataloader
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
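As background for the hook change above, a hedged end-to-end sketch of the resume scenario the author describes (the module, file name, and epoch counts are illustrative only, not taken from the PR): an epoch-based run that reloads a checkpoint saved after epoch 1 should continue with epoch 2 rather than re-enter the finished epoch.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    # minimal module used only to illustrate the resume flow
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        x = torch.randn(64, 8)
        y = torch.randn(64, 1)
        return DataLoader(TensorDataset(x, y), batch_size=16)


# first run: epoch-based training for 2 epochs, checkpoint saved afterwards
trainer = pl.Trainer(max_epochs=2, enable_progress_bar=False)
trainer.fit(TinyModule())
trainer.save_checkpoint("before_resume.ckpt")

# resumed run: the intended behaviour is to continue at epoch 2 instead of
# re-entering the already finished epoch 1
trainer = pl.Trainer(max_epochs=4, enable_progress_bar=False)
trainer.fit(TinyModule(), ckpt_path="before_resume.ckpt")
```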
@@ -336,6 +335,9 @@ def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
         return self.epoch_loop._should_accumulate()

+    def _iteration_based_training(self) -> bool:
+        return self.trainer.max_steps != -1
+

 def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
     training_step_fx = getattr(trainer.lightning_module, "training_step")