
Raise a warning if evaluation is triggered with best ckpt in case of multiple checkpoint callbacks (#11274)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
rohitgr7 and carmocca committed Jan 4, 2022
1 parent 650c710 commit 7eab379
Showing 3 changed files with 32 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `DeviceStatsMonitor` to group metrics based on the logger's `group_separator` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))


+- Raised `UserWarning` if evaluation is triggered with `best` ckpt and trainer is configured with multiple checkpoint callbacks ([#11274](https://github.com/PyTorchLightning/pytorch-lightning/pull/11274))


- `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270))


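To make the new CHANGELOG entry concrete, here is a minimal sketch (not part of the commit) of the setup that now emits the warning. `LitModel` is a hypothetical placeholder for any `LightningModule`, and the monitored metrics are assumptions:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Two checkpoint callbacks, each tracking its own "best" model (hypothetical metrics).
ckpt_acc = ModelCheckpoint(monitor="val_acc", mode="max")
ckpt_loss = ModelCheckpoint(monitor="val_loss", mode="min")

trainer = Trainer(max_epochs=3, callbacks=[ckpt_acc, ckpt_loss])
trainer.fit(LitModel())  # LitModel: placeholder LightningModule, not defined here

# With two ModelCheckpoint callbacks, ckpt_path="best" is ambiguous. After this commit,
# the call below emits a UserWarning and falls back to the best path tracked by the
# first checkpoint callback in the list (ckpt_acc here).
trainer.test(ckpt_path="best")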
12 changes: 9 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -1381,16 +1381,22 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
            ckpt_path = "best"

        if ckpt_path == "best":
            # if user requests the best checkpoint but we don't have it, error
+            if len(self.checkpoint_callbacks) > 1:
+                rank_zero_warn(
+                    f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`'
+                    " callbacks. It will use the best checkpoint path from first checkpoint callback."
+                )
+
            if not self.checkpoint_callback:
                raise MisconfigurationException(
                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
                )

            if not self.checkpoint_callback.best_model_path:
                if self.fast_dev_run:
                    raise MisconfigurationException(
-                        f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do"
-                        f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
+                        f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
+                        f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
                    )
                raise MisconfigurationException(
                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
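If the fallback to the first callback's best path is not what you want, the warning can be avoided by passing an explicit checkpoint path. A small sketch, reusing the hypothetical ckpt_acc / ckpt_loss callbacks from the sketch above:

# Name the checkpoint you actually mean instead of relying on the "best" alias.
trainer.test(ckpt_path=ckpt_loss.best_model_path)

# Similarly, with fast_dev_run=True no checkpoints are written, so (per the updated
# error message) an exact path must be passed, e.g. one produced by an earlier run:
trainer.validate(ckpt_path="lightning_logs/version_0/checkpoints/epoch=2-step=29.ckpt")  # hypothetical path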
29 changes: 20 additions & 9 deletions tests/trainer/test_trainer.py
@@ -774,6 +774,21 @@ def predict_step(self, batch, *_):
    trainer_fn(model, ckpt_path="best")


+def test_best_ckpt_evaluate_raises_warning_with_multiple_ckpt_callbacks():
+    """Test that a warning is raised if best ckpt callback is used for evaluation configured with multiple
+    checkpoints."""
+
+    ckpt_callback1 = ModelCheckpoint()
+    ckpt_callback1.best_model_path = "foo_best_model.ckpt"
+    ckpt_callback2 = ModelCheckpoint()
+    ckpt_callback2.best_model_path = "bar_best_model.ckpt"
+    trainer = Trainer(callbacks=[ckpt_callback1, ckpt_callback2])
+    trainer.state.fn = TrainerFn.TESTING
+
+    with pytest.warns(UserWarning, match="best checkpoint path from first checkpoint callback"):
+        trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True)
+
+
def test_disabled_training(tmpdir):
"""Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""

@@ -1799,15 +1814,11 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None:
    trainer.fit(model, datamodule=dm)


-def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    model = BoringModel()
-    trainer.fit(model)
-
-    with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"):
-        trainer.validate()
-    with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
-        trainer.test()
+def test_exception_when_testing_or_validating_with_fast_dev_run():
+    trainer = Trainer(fast_dev_run=True)
+    trainer.state.fn = TrainerFn.TESTING
+    with pytest.raises(MisconfigurationException, match=r"with `fast_dev_run=True`. .* pass an exact checkpoint path"):
+        trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True)


class TrainerStagesModel(BoringModel):
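The new unit test drives the name-mangled private method `trainer._Trainer__set_ckpt_path` directly so that no model has to be fitted. For reference, an end-to-end variant could assert the same warning through the public API. This is only a sketch: it assumes the repository's `BoringModel` helper is importable from `tests.helpers` and that a short fit produces a checkpoint for each callback:

import os

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers import BoringModel  # assumed import path for the test helper


def test_best_ckpt_warning_through_public_api(tmpdir):
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[
            ModelCheckpoint(dirpath=os.path.join(str(tmpdir), "a")),
            ModelCheckpoint(dirpath=os.path.join(str(tmpdir), "b")),
        ],
    )
    trainer.fit(model)

    # Two ModelCheckpoint callbacks plus ckpt_path="best" should trigger the new UserWarning.
    with pytest.warns(UserWarning, match="best checkpoint path from first checkpoint callback"):
        trainer.validate(model, ckpt_path="best")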
