diff --git a/CHANGELOG.md b/CHANGELOG.md index ea760c8fccfef..989ce1a773b4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700)) +- Added support for returning a single Callback from `LightningModule.configure_callbacks` without wrapping it into a list ([#11060](https://github.com/PyTorchLightning/pytorch-lightning/issues/11060)) + + - Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c7b6d1ced35e1..3b1f57853a164 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -21,7 +21,7 @@ import tempfile from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union import torch from torch import ScriptModule, Tensor @@ -31,6 +31,7 @@ from typing_extensions import Literal import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.progress import base as progress_base from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin @@ -1119,15 +1120,16 @@ def predicts_step(self, batch, batch_idx, dataloader_idx=0): """ return self(batch) - def configure_callbacks(self): + def configure_callbacks(self) -> Union[Sequence[Callback], Callback]: """Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()`` - gets called, the list returned here will be merged with the list of callbacks passed to the Trainer's - ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks already - present in the Trainer's callbacks list, it will take priority and replace them. In addition, Lightning - will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last. + gets called, the list or a callback returned here will be merged with the list of callbacks passed to the + Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks + already present in the Trainer's callbacks list, it will take priority and replace them. In addition, + Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks + run last. Return: - A list of callbacks which will extend the list of callbacks in the Trainer. + A callback or a list of callbacks which will extend the list of callbacks in the Trainer. Example:: diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 3065685548f15..74f55c16edad9 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -13,7 +13,7 @@ # limitations under the License. import os from datetime import timedelta -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Sequence, Union from pytorch_lightning.callbacks import ( Callback, @@ -272,6 +272,8 @@ def _attach_model_callbacks(self) -> None: model_callbacks = self.trainer._call_lightning_module_hook("configure_callbacks") if not model_callbacks: return + + model_callbacks = [model_callbacks] if not isinstance(model_callbacks, Sequence) else model_callbacks model_callback_types = {type(c) for c in model_callbacks} trainer_callback_types = {type(c) for c in self.trainer.callbacks} override_types = model_callback_types.intersection(trainer_callback_types) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 68e28890abc79..d97af8e211249 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -83,7 +83,7 @@ def test_configure_callbacks_hook_multiple_calls(tmpdir): class TestModel(BoringModel): def configure_callbacks(self): - return [model_callback_mock] + return model_callback_mock model = TestModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, enable_checkpointing=False)