LR updates are called at the end of a skipped epoch #21307
base: master
Conversation
Could we please also add some testing that the changed logic correctly updates the learning rate when `response == -1`?
```python
should_skip_rest_of_epoch = response == -1
# Signal this is the last batch for the current epoch
if should_skip_rest_of_epoch:
    self.batch_progress.increment_by(0, is_last_batch=True)
```
What is the logic here for changing from `self.batch_progress.increment_processed()` to `self.batch_progress.increment_by(0, is_last_batch=True)`?
> What is the logic here for changing from `self.batch_progress.increment_processed()` to `self.batch_progress.increment_by(0, is_last_batch=True)`?
`batch_progress.increment_by` is the only method that can set `is_last_batch` to `True`, which is required to trigger the update of the lrs in the case of an `IterableDataset`.
`increment_processed` only increments the counters. With an `IterableDataset`, for which the expected number of batches is not known, this may not be enough to detect that the epoch has ended.
Indeed, the lrs are later updated only if `num_ready_batches_reached` returns `True`, and it returns `True` if either `epoch_finished_on_ready` or `is_last_batch` is `True`, as sketched below.
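For context, a minimal sketch of that condition as described in this thread; the attribute names (`current.ready`, `num_training_batches`) are paraphrased assumptions, not the verbatim Lightning source:

```python
def num_ready_batches_reached(self) -> bool:
    # Sketch only; exact names may differ in the Lightning internals.
    # With an IterableDataset the expected batch count is unknown, so
    # epoch_finished_on_ready can never become True; only is_last_batch,
    # set via batch_progress.increment_by(0, is_last_batch=True), can
    # signal the end of the epoch and thereby trigger the lr update.
    epoch_finished_on_ready = self.batch_progress.current.ready == self.num_training_batches
    return epoch_finished_on_ready or self.batch_progress.is_last_batch
```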
> Could we please also add some testing that the changed logic correctly updates the learning rate when `response == -1`?
I'll do it.
@SkafteNicki I've just added tests. Not sure about the most appropriate location though.
10d18d9 actually added tests:
- to check `last_batch` is set when the rest of the epoch was skipped,
- to check the lr is updated at the end of the epoch when `on_train_batch_start` returns -1.
I've checked that these tests were indeed failing before the changes introduced in the current PR.
When `on_train_batch_start` returns -1, the rest of the epoch is skipped. The lr update should still happen at the end of the epoch.
- Test `is_last_batch` has been set correctly
- Test the lr has been updated at the end of each epoch
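For reference, a minimal sketch of what such a test could look like; the class names, the hook signature details, and the exact assertion here are illustrative assumptions, not necessarily what 10d18d9 added:

```python
import pytest
import torch
from torch.utils.data import DataLoader, IterableDataset
from lightning.pytorch import LightningModule, Trainer


class StreamDataset(IterableDataset):
    """An IterableDataset: the number of batches per epoch is not known upfront."""

    def __iter__(self):
        for _ in range(4):
            yield torch.randn(1)


class SkippingModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 tells Lightning to skip the rest of the epoch.
        if batch_idx >= 1:
            return -1

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
        return [optimizer], [scheduler]


def test_lr_updated_when_rest_of_epoch_is_skipped(tmp_path):
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=2,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )
    trainer.fit(SkippingModel(), DataLoader(StreamDataset(), batch_size=None))
    # StepLR should have halved the lr once per epoch: 0.1 -> 0.05 -> 0.025.
    # Before this fix, the scheduler step at the end of a skipped epoch was missed.
    assert trainer.optimizers[0].param_groups[0]["lr"] == pytest.approx(0.025)
```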
What does this PR do?
This PR fixes the learning rate not being updated at the end of an epoch when `on_train_batch_start` returns -1.
Fixes #21296
It postpones the existing `raise StopIteration` until after the learning rate update.
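For illustration, a hedged sketch of that reordering; this is simplified pseudocode of the epoch loop under the assumptions above, not the verbatim diff:

```python
# Before: skipping raised immediately, so the end-of-epoch lr update never ran.
if response == -1:
    raise StopIteration

# After: first mark the batch as the epoch's last one so the usual
# end-of-epoch bookkeeping (including the lr scheduler step) runs,
# and only then stop iterating.
should_skip_rest_of_epoch = response == -1
if should_skip_rest_of_epoch:
    self.batch_progress.increment_by(0, is_last_batch=True)
# ... end-of-epoch processing, including the lr update, happens here ...
if should_skip_rest_of_epoch:
    raise StopIteration
```

Before submitting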
`on_train_batch_start` returns -1 #21296 hasn't been discussed yet.

PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines.
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--21307.org.readthedocs.build/en/21307/