diff --git a/CHANGELOG.md b/CHANGELOG.md
index 35404e85a5816..57e717a91b017 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -203,9 +203,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bug where data-loading functions where not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
 
+
 - Fixed a bug in the binary search mode of auto batch size scaling where exception was thrown if the first trainer run resulted in OOM ([#8954](https://github.com/PyTorchLightning/pytorch-lightning/pull/8954))
 
+- Fixed not setting a default value for `max_epochs` if `max_time` was specified on the `Trainer` constructor ([#9072](https://github.com/PyTorchLightning/pytorch-lightning/pull/9072))
+
+
 
 ## [1.4.3] - 2021-08-17
 
 - Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 739e34aa249b5..c80c368aaad4e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -378,8 +378,8 @@ def __init__(
         self.tuner = Tuner(self)
 
         fit_loop = FitLoop(
-            min_epochs=(1 if (min_epochs is None and min_steps is None) else min_epochs),
-            max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs),
+            min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
+            max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs),
         )
         training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
         training_batch_loop = TrainingBatchLoop()
diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py
index 44d3c305bb1ac..c7b636d3f843a 100644
--- a/tests/callbacks/test_timer.py
+++ b/tests/callbacks/test_timer.py
@@ -42,6 +42,11 @@ def on_fit_start(self):
         trainer.fit(TestModel())
     assert "callbacks list already contains a Timer" in caplog.text
 
+    seconds = 1
+    trainer = Trainer(max_time=dict(seconds=seconds))
+    assert trainer.max_epochs is None
+    assert trainer.max_steps is None
+
 
 @pytest.mark.parametrize(
     "duration,expected",
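
Below is a minimal sketch (not part of the patch) of the behavior this change establishes, assuming a pytorch_lightning build that includes this fix; it mirrors the assertions added to tests/callbacks/test_timer.py above.

from pytorch_lightning import Trainer

# With only max_time given, the 1-epoch / 1000-epoch defaults are no longer
# forced, so the Timer callback installed for max_time decides when training stops.
trainer = Trainer(max_time=dict(seconds=30))
assert trainer.max_epochs is None
assert trainer.max_steps is None

# Passing an explicit limit alongside max_time still behaves as before.
trainer = Trainer(max_time=dict(minutes=5), max_epochs=10)
assert trainer.max_epochs == 10

# Omitting max_time, max_epochs, and max_steps keeps the old 1000-epoch default.
trainer = Trainer()
assert trainer.max_epochs == 1000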