From 010e84632118f215f497ac166a77adba990d6a7e Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 3 Nov 2021 00:23:11 +0100
Subject: [PATCH] Add `Loop.replace`

---
 docs/source/extensions/loops.rst              | 20 +++++++---
 pytorch_lightning/loops/base.py               | 33 +++++++++++++++-
 .../loops/batch/training_batch_loop.py        |  2 +-
 .../loops/epoch/evaluation_epoch_loop.py      |  7 +---
 .../loops/epoch/training_epoch_loop.py        |  8 ++--
 pytorch_lightning/loops/fit_loop.py           |  6 +--
 tests/loops/test_evaluation_loop.py           |  2 +-
 tests/loops/test_loops.py                     | 38 +++++++++++++++++--
 8 files changed, 87 insertions(+), 29 deletions(-)

diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst
index 9291fca4819d2..267f637be63a1 100644
--- a/docs/source/extensions/loops.rst
+++ b/docs/source/extensions/loops.rst
@@ -267,17 +267,25 @@ run (optional)
 Subloops
 --------

-When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:
+When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.replace` method:

 .. code-block:: python

-    # Step 1: create your loop
-    my_epoch_loop = MyEpochLoop()
+    # This takes care of properly instantiating the new Loop and setting all references
+    trainer.fit_loop.replace(MyEpochLoop)
+    # Trainer runs the fit loop with your new epoch loop!
+    trainer.fit(model)

-    # Step 2: use connect()
-    trainer.fit_loop.connect(epoch_loop=my_epoch_loop)
+Alternatively, for more fine-grained control, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:

-    # Trainer runs the fit loop with your new epoch loop!
+.. code-block:: python
+
+    # Optional: stitch back the trainer arguments
+    epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
+    # Optional: connect children loops as they might have existing state
+    epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
+    # Instantiate and connect the loop.
+    trainer.fit_loop.connect(epoch_loop=epoch_loop)
     trainer.fit(model)

 More about the built-in loops and how they are composed is explained in the next section.
diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index 38b0d652e5d2f..37d63fca885aa 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, Optional, TypeVar
+from typing import Any, Dict, Generic, Optional, Type, TypeVar

 from deprecate import void
 from torchmetrics import Metric
@@ -99,6 +99,35 @@ def connect(self, **kwargs: "Loop") -> None:

         Linked loops should form a tree.
         """

+    def replace(self, loop_cls: Type["Loop"]) -> "Loop":
+        # find the target
+        for name, old_loop in self.__dict__.items():
+            if issubclass(loop_cls, type(old_loop)):
+                break
+        else:
+            raise MisconfigurationException(
+                f"Did not find an attribute with the same parent class as `{loop_cls.__name__}`"
+            )
+        # compare the signatures
+        old_parameters = inspect.signature(old_loop.__class__.__init__).parameters
+        current_parameters = inspect.signature(loop_cls.__init__).parameters
+        if old_parameters != current_parameters:
+            raise MisconfigurationException(
+                f"`{self.__class__.__name__}.replace({loop_cls.__name__})` can only be used if the `__init__`"
+                f" signatures match but `{old_loop.__class__.__name__}` does not."
+            )
+        # instantiate the loop
+        kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"}
+        loop = loop_cls(**kwargs)
+        # connect subloops
+        kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)}
+        loop.connect(**kwargs)
+        # set the trainer reference
+        loop.trainer = self.trainer
+        # connect to self
+        self.connect(**{name: loop})
+        return loop
+
     def on_skip(self) -> Optional[Any]:
         """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index c1d800c42d853..7ed199e56be13 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -48,7 +48,7 @@ def done(self) -> bool:
         return len(self._remaining_splits) == 0

     def connect(
-        self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
+        self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None
     ) -> None:
         if optimizer_loop is not None:
             self.optimizer_loop = optimizer_loop
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index b4660c96a0989..c8018b63f8a66 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -45,16 +45,13 @@ def __init__(self) -> None:
         self._num_dataloaders: Optional[int] = None
         self._dataloader_iter: Optional[Iterator] = None
         self._data_fetcher: Optional[DataFetcher] = None
-        self._dataloader_state_dict: Dict[str, Any] = None
+        self._dataloader_state_dict: Dict[str, Any] = {}

     @property
     def done(self) -> bool:
         """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
         return self.batch_progress.current.completed >= self._dl_max_batches

-    def connect(self, **kwargs: "Loop") -> None:
-        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
-
     def reset(self) -> None:
         """Resets the loop's internal state."""
         self._dl_max_batches = None
@@ -181,7 +178,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
         if not self.trainer.sanity_checking and self._dataloader_state_dict:
             reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
-            self._dataloader_state_dict = None
+            self._dataloader_state_dict = {}

     def _num_completed_batches_reached(self) -> bool:
         epoch_finished_on_completed = self.batch_progress.current.completed == self._dl_max_batches
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 21d89a8be8b52..f80c02ea48de0 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -61,8 +61,8 @@ def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
         self.batch_progress = BatchProgress()
         self.scheduler_progress = SchedulerProgress()

-        self.batch_loop: Optional[TrainingBatchLoop] = None
-        self.val_loop: Optional["loops.EvaluationLoop"] = None
+        self.batch_loop = TrainingBatchLoop()
+        self.val_loop = loops.EvaluationLoop()
         self._results = ResultCollection(training=True)
         self._outputs: _OUTPUTS_TYPE = []

@@ -106,7 +106,7 @@ def done(self) -> bool:

     def connect(
         self,
-        batch_loop: TrainingBatchLoop = None,
+        batch_loop: Optional[TrainingBatchLoop] = None,
         val_loop: Optional["loops.EvaluationLoop"] = None,
     ) -> None:
         """Optionally connect a custom batch or validation loop to this training epoch loop."""
@@ -117,8 +117,6 @@ def connect(

     def reset(self) -> None:
         """Resets the internal state of the loop for a new run."""
-        assert self.batch_loop is not None
-        assert self.batch_loop.optimizer_loop is not None
         if self.restarting:
             self.batch_progress.reset_on_restart()
             self.scheduler_progress.reset_on_restart()
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index df6634c963851..4040d08d4f3dd 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -48,7 +48,7 @@ def __init__(
         self.max_epochs = max_epochs
         self.min_epochs = min_epochs
-        self.epoch_loop: Optional[TrainingEpochLoop] = None
+        self.epoch_loop = TrainingEpochLoop()
         self.epoch_progress = Progress()

         self._is_fresh_start_epoch: bool = True

@@ -128,15 +128,11 @@ def running_loss(self) -> TensorRunningAccum:
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""
-        assert self.epoch_loop.batch_loop is not None
-        assert self.epoch_loop.batch_loop.optimizer_loop is not None
         return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

     @_skip_backward.setter
     def _skip_backward(self, value: bool) -> None:
         """Determines whether the loop will skip backward during automatic optimization."""
-        assert self.epoch_loop.batch_loop is not None
-        assert self.epoch_loop.batch_loop.optimizer_loop is not None
         self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

     @property
diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py
index d6b2c15553fb9..1507817357299 100644
--- a/tests/loops/test_evaluation_loop.py
+++ b/tests/loops/test_evaluation_loop.py
@@ -127,6 +127,6 @@ def on_advance_end(self):
     assert not is_overridden("test_epoch_end", model)

     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3)
-    trainer.test_loop.connect(TestLoop())
+    trainer.test_loop.replace(TestLoop)
     trainer.test(model)
     assert did_assert
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index dd390ab4939d5..212138915efba 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -22,12 +22,12 @@
 import torch
 from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader

-from pl_examples.bug_report_model import RandomDataset
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loops import Loop, TrainingBatchLoop
+from pytorch_lightning.loops import Loop, PredictionEpochLoop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.trainer.progress import BaseProgress
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf


@@ -62,7 +62,7 @@ def test_connect_loops_direct(loop_name):

     trainer = Trainer()

-    # trainer.loop = loop
+    # trainer.loop_name = loop
     setattr(trainer, loop_name, loop)
     assert loop.trainer is trainer

@@ -103,6 +103,36 @@ def test_connect_subloops(tmpdir):
     assert new_batch_loop.trainer is trainer


+def test_replace_loops():
+    class TestLoop(TrainingEpochLoop):
+        def __init__(self, foo):
+            super().__init__()
+
+    trainer = Trainer(min_steps=123, max_steps=321)
+
+    with pytest.raises(
+        MisconfigurationException, match=r"FitLoop.replace\(TestLoop\)`.*`__init__`.*`TrainingEpochLoop`"
+    ):
+        trainer.fit_loop.replace(TestLoop)
+
+    with pytest.raises(MisconfigurationException, match="Did not find.*same parent class as `PredictionEpochLoop`"):
+        trainer.fit_loop.replace(PredictionEpochLoop)
+
+    class TestLoop(TrainingEpochLoop):
+        ...
+
+    old_loop = trainer.fit_loop.epoch_loop
+    new_loop = trainer.fit_loop.replace(TestLoop)
+
+    assert isinstance(new_loop, TestLoop)
+    assert trainer.fit_loop.epoch_loop is new_loop
+    assert new_loop.min_steps == 123
+    assert new_loop.max_steps == 321
+    assert new_loop.batch_loop is old_loop.batch_loop
+    assert new_loop.val_loop is old_loop.val_loop
+    assert new_loop.trainer is trainer
+
+
 class CustomException(Exception):
     pass
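
A minimal usage sketch based on the documentation and tests in this patch (not itself part of the diff). `MyEpochLoop` and `model` are illustrative placeholders; any `LightningModule` works in place of `model`.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import TrainingEpochLoop


    class MyEpochLoop(TrainingEpochLoop):
        # Keep the parent's `__init__` signature: `replace` re-instantiates the loop with
        # the old loop's `__init__` arguments and raises if the signatures do not match.
        def on_advance_end(self) -> None:
            print(f"completed {self.batch_progress.current.completed} batches")
            super().on_advance_end()


    trainer = Trainer(max_epochs=1)

    # `replace` finds the matching attribute (`epoch_loop`), instantiates `MyEpochLoop`,
    # reconnects the existing `batch_loop`/`val_loop`, and sets the trainer reference.
    new_loop = trainer.fit_loop.replace(MyEpochLoop)
    assert trainer.fit_loop.epoch_loop is new_loop

    trainer.fit(model)  # `model` is a placeholder for any LightningModule, as in the docs above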