diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 9d4323548cb7e..ea0e8f76461da 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -162,6 +162,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue to avoid the impact of sanity check on `reload_dataloaders_every_n_epochs` for validation ([#13964](https://github.com/Lightning-AI/lightning/pull/13964)) +- Fixed `Trainer.estimated_stepping_batches` when maximum number of epochs is not set ([#14317](https://github.com/Lightning-AI/lightning/pull/14317)) + + ## [1.7.2] - 2022-08-17 ### Added diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 08fade4021a8b..963c44dde21b9 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -2729,8 +2729,8 @@ def configure_optimizers(self): ) # infinite training - if self.max_epochs == -1 and self.max_steps == -1: - return float("inf") + if self.max_epochs == -1: + return float("inf") if self.max_steps == -1 else self.max_steps if self.train_dataloader is None: rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.") diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 177d2034a0273..72c07ec0790c2 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -94,9 +94,9 @@ def test_num_stepping_batches_infinite_training(): assert trainer.estimated_stepping_batches == float("inf") -def test_num_stepping_batches_with_max_steps(): +@pytest.mark.parametrize("max_steps", [2, 100]) +def test_num_stepping_batches_with_max_steps(max_steps): """Test stepping batches with `max_steps`.""" - max_steps = 2 trainer = Trainer(max_steps=max_steps) model = BoringModel() trainer.fit(model)