From 8d27d066a07231bb99e9566505529e4264895a98 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 14 Dec 2021 21:44:52 +0530
Subject: [PATCH 1/5] add support for returning callback from
 LightningModule.configure_callbacks

---
 pytorch_lightning/core/lightning.py           | 14 ++++++++------
 .../trainer/connectors/callback_connector.py  |  5 ++++-
 tests/callbacks/test_callbacks.py             |  2 +-
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 13e12a11f97aa..0a834c3499d55 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -21,7 +21,7 @@
 import tempfile
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union

 import torch
 from torch import ScriptModule, Tensor
@@ -31,6 +31,7 @@
 from typing_extensions import Literal

 import pytorch_lightning as pl
+from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.callbacks.progress import base as progress_base
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
@@ -1119,12 +1120,13 @@ def predicts_step(self, batch, batch_idx, dataloader_idx=0):
         """
         return self(batch)

-    def configure_callbacks(self):
+    def configure_callbacks(self) -> Optional[Union[Sequence[Callback], Callback]]:
         """Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()``
-        gets called, the list returned here will be merged with the list of callbacks passed to the Trainer's
-        ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks already
-        present in the Trainer's callbacks list, it will take priority and replace them. In addition, Lightning
-        will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last.
+        gets called, the list or a callback returned here will be merged with the list of callbacks passed to the
+        Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks
+        already present in the Trainer's callbacks list, it will take priority and replace them. In addition,
+        Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
+        run last.

         Return:
             A list of callbacks which will extend the list of callbacks in the Trainer.
diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index 0c2ccd73a42a0..71947e4924dac 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from datetime import timedelta
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union

 from pytorch_lightning.callbacks import (
     Callback,
@@ -271,6 +271,9 @@ def _attach_model_callbacks(self) -> None:
         model_callbacks = self.trainer._call_lightning_module_hook("configure_callbacks")
         if not model_callbacks:
             return
+
+        model_callbacks = [model_callbacks] if not isinstance(model_callbacks, Sequence) else model_callbacks
+        model_callbacks = list(model_callbacks)
 model_callback_types = {type(c) for c in model_callbacks}
 trainer_callback_types = {type(c) for c in self.trainer.callbacks}
 override_types = model_callback_types.intersection(trainer_callback_types)
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 68e28890abc79..d97af8e211249 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -83,7 +83,7 @@ def test_configure_callbacks_hook_multiple_calls(tmpdir):

     class TestModel(BoringModel):
         def configure_callbacks(self):
-            return [model_callback_mock]
+            return model_callback_mock

     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, enable_checkpointing=False)
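
To make the intent of this patch concrete: a `LightningModule` may now hand back a bare
callback where a one-element list was previously required. A minimal sketch (the model and
the `EarlyStopping` arguments are illustrative, not part of the patch):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping


    class LitModel(pl.LightningModule):
        def configure_callbacks(self):
            # Before this change the hook had to return a list:
            #     return [EarlyStopping(monitor="val_loss")]
            return EarlyStopping(monitor="val_loss")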

From cf8d02a2a36eaaf23e20adbee278c4d01a331dea Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 14 Dec 2021 21:46:52 +0530
Subject: [PATCH 2/5] chlog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c562604e73512..ce5561ca26cf1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))

+- Added support for returning a Callback from `LightningModule.configure_callbacks` ([#11060](https://github.com/PyTorchLightning/pytorch-lightning/issues/11060))
+
+
 ### Changed

 - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))

From d3498ac8ce84cc3e2e33f86087ed312c50a8d084 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Wed, 15 Dec 2021 01:49:12 +0530
Subject: [PATCH 3/5] not optional

---
 pytorch_lightning/core/lightning.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 0a834c3499d55..a287a1cb3022f 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1120,7 +1120,7 @@ def predicts_step(self, batch, batch_idx, dataloader_idx=0):
         """
         return self(batch)

-    def configure_callbacks(self) -> Optional[Union[Sequence[Callback], Callback]]:
+    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
         """Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()``
         gets called, the list or a callback returned here will be merged with the list of callbacks passed to the
         Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks
@@ -1129,7 +1129,7 @@ def configure_callbacks(self) -> Optional[Union[Sequence[Callback], Callback]]:
         run last.

         Return:
-            A list of callbacks which will extend the list of callbacks in the Trainer.
+            A callback or a list of callbacks which will extend the list of callbacks in the Trainer.

         Example::
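
On the Trainer side, the connector change reduces to one normalization step before the
existing type-based override logic. A standalone sketch of that behavior
(`normalize_model_callbacks` is a hypothetical name; the body mirrors
`_attach_model_callbacks` at this point in the series, including the `list` copy that the
final patch drops):

    from typing import List, Sequence, Union

    from pytorch_lightning.callbacks import Callback, EarlyStopping


    def normalize_model_callbacks(model_callbacks: Union[Sequence[Callback], Callback]) -> List[Callback]:
        # A bare Callback is not a Sequence, so it gets wrapped; lists and tuples pass through.
        if not isinstance(model_callbacks, Sequence):
            model_callbacks = [model_callbacks]
        return list(model_callbacks)


    early_stopping = EarlyStopping(monitor="val_loss")
    assert normalize_model_callbacks(early_stopping) == [early_stopping]
    assert normalize_model_callbacks((early_stopping,)) == [early_stopping]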

From 7090fb2231e954738c5f10d131011ed3d70b8e21 Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Sat, 18 Dec 2021 03:47:09 +0100
Subject: [PATCH 4/5] Update CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3014daaffa25e..989ce1a773b4c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))

-- Added support for returning a Callback from `LightningModule.configure_callbacks` ([#11060](https://github.com/PyTorchLightning/pytorch-lightning/issues/11060))
+- Added support for returning a single Callback from `LightningModule.configure_callbacks` without wrapping it into a list ([#11060](https://github.com/PyTorchLightning/pytorch-lightning/issues/11060))

 - Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))

From b448a3177372ca6b695db4b406597843ee105791 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Sat, 18 Dec 2021 15:47:07 +0530
Subject: [PATCH 5/5] Apply suggestions from code review

---
 pytorch_lightning/trainer/connectors/callback_connector.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index c208a4f73607b..74f55c16edad9 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -274,7 +274,6 @@ def _attach_model_callbacks(self) -> None:
             return

         model_callbacks = [model_callbacks] if not isinstance(model_callbacks, Sequence) else model_callbacks
-        model_callbacks = list(model_callbacks)
         model_callback_types = {type(c) for c in model_callbacks}
         trainer_callback_types = {type(c) for c in self.trainer.callbacks}
         override_types = model_callback_types.intersection(trainer_callback_types)
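
The updated test in the first patch exercises exactly this path with a mock callback; a
condensed sketch of the pattern (the `BoringModel` import path is assumed from the repo's
test helpers):

    from unittest.mock import MagicMock

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback
    from tests.helpers import BoringModel

    model_callback_mock = MagicMock(spec=Callback)


    class TestModel(BoringModel):
        def configure_callbacks(self):
            # Returned bare: a MagicMock with spec=Callback is not a Sequence,
            # so the connector wraps it into a list before merging.
            return model_callback_mock


    trainer = Trainer(fast_dev_run=True, enable_checkpointing=False)
    trainer.fit(TestModel())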