diff --git a/CHANGELOG.md b/CHANGELOG.md index 96087fad69840..19111ecb3d29a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `DeviceStatsMonitor` to group metrics based on the logger's `group_separator` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254)) +- Raised `UserWarning` if evaluation is triggered with `best` ckpt and trainer is configured with multiple checkpoint callbacks ([#11274](https://github.com/PyTorchLightning/pytorch-lightning/pull/11274)) + + - `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270)) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5a6f67c9ea0d3..d89ad75411c8c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1381,16 +1381,22 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ ckpt_path = "best" if ckpt_path == "best": - # if user requests the best checkpoint but we don't have it, error + if len(self.checkpoint_callbacks) > 1: + rank_zero_warn( + f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`' + " callbacks. It will use the best checkpoint path from the first checkpoint callback." + ) + if not self.checkpoint_callback: raise MisconfigurationException( f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.' ) + if not self.checkpoint_callback.best_model_path: if self.fast_dev_run: raise MisconfigurationException( - f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do" - f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting." + f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.' 
+ f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`" ) raise MisconfigurationException( f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.' diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 52bd2305d74ca..281afae7c30f7 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -774,6 +774,21 @@ def predict_step(self, batch, *_): trainer_fn(model, ckpt_path="best") +def test_best_ckpt_evaluate_raises_warning_with_multiple_ckpt_callbacks(): + """Test that a warning is raised if best ckpt callback is used for evaluation configured with multiple + checkpoints.""" + + ckpt_callback1 = ModelCheckpoint() + ckpt_callback1.best_model_path = "foo_best_model.ckpt" + ckpt_callback2 = ModelCheckpoint() + ckpt_callback2.best_model_path = "bar_best_model.ckpt" + trainer = Trainer(callbacks=[ckpt_callback1, ckpt_callback2]) + trainer.state.fn = TrainerFn.TESTING + + with pytest.warns(UserWarning, match="best checkpoint path from the first checkpoint callback"): + trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True) + + def test_disabled_training(tmpdir): """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`.""" @@ -1799,15 +1814,11 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: trainer.fit(model, datamodule=dm) -def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - model = BoringModel() - trainer.fit(model) - - with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"): - trainer.validate() - with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"): - trainer.test() +def test_exception_when_testing_or_validating_with_fast_dev_run(): + trainer = Trainer(fast_dev_run=True) + trainer.state.fn = TrainerFn.TESTING + 
with pytest.raises(MisconfigurationException, match=r"with `fast_dev_run=True`. .* pass an exact checkpoint path"): + trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True) class TrainerStagesModel(BoringModel):