Skip to content

Conversation

@LTMeyer
Copy link

@LTMeyer LTMeyer commented Oct 20, 2025

What does this PR do?

This PR fixes learning rate not being updated at the end of epoch when on_train_batch_start returns -1.

Fixes #21296

It postpones the existing raise StopIteration after the learning rate update.

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs): Related issue LR is not updated when on_train_batch_start returns -1 #21296 hasn't been discussed yet.
  • Did you read the contributor guideline, Pull Request section? Yes
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary) No
  • Did you write any new necessary tests? (not for typos and docs) Yes
  • Did you verify new and existing tests pass locally with your changes? Yes
  • Did you list all the breaking changes introduced by this pull request? None
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors) Yes

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21307.org.readthedocs.build/en/21307/

@github-actions github-actions bot added fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package labels Oct 20, 2025
Copy link
Collaborator

@SkafteNicki SkafteNicki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we please also add some testing that the changed logic then correctly updates the learning rate when response=-1?

should_skip_rest_of_epoch = response == -1
# Signal this is the last batch for the current epoch
if should_skip_rest_of_epoch:
self.batch_progress.increment_by(0, is_last_batch=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the logic here from changing from
self.batch_progress.increment_processed()
to
self.batch_progress.increment_by(0, is_last_batch=True)
?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the logic here from changing from self.batch_progress.increment_processed()
to self.batch_progress.increment_by(0, is_last_batch=True)?

batch_progress.increment_by is the only method that can set is_last_batch to True, which is required to trigger the update of lrs in case of IterableDataset.
The increment_processed only increments the counters. In case of IterableDataset, for which the expected number of batches is not known, this may not be enough to detect the epoch has ended.
Indeed, the lrs are later updated only if num_ready_batches_reached returns True. It does return True if epoch_finished_on_ready or is_last_batch are True.

Could we please also add some testing that the changed logic then correctly updates the learning rate when response=-1?

I'll do it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SkafteNicki I've just added tests. Not sure about the most appropriate location though.
10d18d9 actually added tests:

  • to check last_batch is set when rest of epoch was skipped.
  • to check lr is being updated at the end of epoch when on_train_batch_start returns -1.

I've checked that these tests were indeed failing before the changes introduced in the current PR.

When `on_train_batch_start` returns -1, the rest of the epoch is skipped.
The lr update should still happen at the end of the epoch.
- Test is_last_batch has been set correctly
- Test lr has been updated at the end of each epoch
@LTMeyer LTMeyer requested a review from SkafteNicki October 27, 2025 08:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

LR is not updated when on_train_batch_start returns -1

2 participants