Fix fit loop restart logic to enable resume using the checkpoint #12821
Conversation
def on_run_start(self) -> None:  # type: ignore[override]
    """Calls the ``on_train_start`` hook."""
    # update the current_epoch in-case of checkpoint reload
    if not self._iteration_based_training():
During restart, if training is not iteration-based, we need to update the current epoch so that training starts from a fresh epoch rather than the old one, for the case where the checkpoint being reloaded was saved before on_train_end.
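A minimal standalone sketch of that idea (not the actual FitLoop code; EpochProgress and on_restart are illustrative names): when an epoch-based run resumes from a checkpoint that was saved before on_train_end, the epoch counter is advanced so the loop begins a fresh epoch instead of re-entering the recorded one.

from dataclasses import dataclass


@dataclass
class EpochProgress:
    """Toy stand-in for the loop's epoch progress tracking."""
    processed: int = 0  # epochs whose batches were fully consumed
    completed: int = 0  # epochs for which on_train_epoch_end has finished


def on_restart(progress: EpochProgress, iteration_based: bool) -> None:
    """Advance the epoch counter when resuming an epoch-based run."""
    if not iteration_based:
        # the checkpoint was written before the epoch was marked completed,
        # so align `completed` with `processed` to start the next epoch cleanly
        progress.completed = progress.processed


# checkpoint saved after processing epoch 2 but before marking it completed
progress = EpochProgress(processed=3, completed=2)
on_restart(progress, iteration_based=False)
assert progress.completed == 3  # the resumed run begins a fresh epoch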
Does this make the following 2 lines unnecessary?
Since on_run_start runs before done, and stop_epochs is only valid under "not iteration-based".
yes! good catch! will update
but even if that's the case, do we increment epoch_progress during iteration-based training? not sure. need to check
> Does this make the following 2 lines unnecessary?
maybe... will check. I remember a test was failing: https://github.com/PyTorchLightning/pytorch-lightning/runs/6101153350?check_suite_focus=true
but I guess it shouldn't, since we increment current.completed after every epoch.
Okay, it looks like done is called within skip, so we need to keep it.
Hey @rohitgr7
I don't understand this logic. There is no true iteration-based training in Lightning; we always have epochs. We may restart from a completed epoch or from an incomplete epoch, regardless of how the max_* flags on the trainer are set.
If a user is training with max_steps=7 and restarts from a checkpoint saved at step=4, we need to start from step 5. Although I guess the dataloaders are re-iterated from the beginning in that case.
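A hedged end-to-end illustration of that scenario (TinyModel and the checkpoint filename are made up for the example and are not part of this PR): the first run stops after 4 optimizer steps and saves a checkpoint; resuming it with max_steps=7 should continue from step 5 rather than replaying steps 1-4.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def loader():
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    return DataLoader(TensorDataset(x, y), batch_size=2)


# first run: stop after 4 optimizer steps and save a checkpoint
trainer = pl.Trainer(max_steps=4, logger=False, enable_checkpointing=False)
trainer.fit(TinyModel(), loader())
trainer.save_checkpoint("step4.ckpt")  # hypothetical path

# resumed run: with max_steps=7 this should train steps 5-7 only
resumed = pl.Trainer(max_steps=7, logger=False, enable_checkpointing=False)
resumed.fit(TinyModel(), loader(), ckpt_path="step4.ckpt")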
This explanation is not satisfying. It doesn't answer how max_steps has anything to do with the way we restart.
You could also have max_epochs=1 where the epoch size is 7 (equivalent to max_steps=7) and you would still restore the checkpoint on step 4 the exact same way.
max_steps / max_epochs are the stopping conditions. Why they should affect the way we restart is beyond my understanding.
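For concreteness, a small sketch of the equivalence being pointed out here (the 7 comes from the example above): with a 7-batch training epoch, both configurations stop after 7 optimizer steps, so a checkpoint taken at step 4 looks the same in either case.

import pytorch_lightning as pl

# iteration-based stopping condition: stop after 7 optimizer steps
trainer_steps = pl.Trainer(max_steps=7)

# epoch-based stopping condition: with a 7-batch epoch this also stops after 7 steps
trainer_epochs = pl.Trainer(max_epochs=1)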
> You could also have max_epochs=1 where the epoch size is 7 (equivalent to max_steps=7) and you would still restore the checkpoint on step 4 the exact same way.
Yes, it does, but in that case we do start from step=4, just on an entirely new epoch. The reason we need to separate this a little is that then we don't need to update the current epoch. Although I just noticed that we do update current_epoch at the end even if the user is training based on max_steps. Do you think we should remove this condition and keep incrementing the current epoch on each restart even when training is based on max_steps?
Whatever the fix is, it will not be conditioned on the max_steps / max_epochs trainer flags at all. I don't know how to solve this problem yet; we need to brainstorm what should be fixed.
All I know right now is that this PR did something weird 😅; that can't be the fix.
def on_run_start(self) -> None:  # type: ignore[override]
    """Calls the ``on_train_start`` hook."""
    # update the current_epoch in-case of checkpoint reload
"in-case" → "in case"
…ckpoint - #12821 Summary: patch fix the PR Lightning-AI/pytorch-lightning#12821 Reviewed By: hudeven, rayhou0710 Differential Revision: D36193410 fbshipit-source-id: 0adf2d3e5202ed85d8d0a305906df9be2ee696c3
…ckpoint - #12821 (#1249) Summary: Pull Request resolved: #1249 patch fix the PR Lightning-AI/pytorch-lightning#12821 Reviewed By: hudeven, rayhou0710 Differential Revision: D36193410 fbshipit-source-id: 6594c4daf8fe5be1eaec72d42c45789f5da36125
What does this PR do?
Fixes #12724
Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Did you have fun?
Make sure you had fun coding 🙃
cc @Borda @tchaton @rohitgr7 @carmocca @justusschock @ananthsub @ninginthecloud