From 861d6b9578bf2b3993b90b880a9c9126f093e9b7 Mon Sep 17 00:00:00 2001
From: Justus Schock
Date: Fri, 13 Aug 2021 09:33:10 +0200
Subject: [PATCH 1/4] detach loaders after run

---
 .../trainer/connectors/data_connector.py | 33 +++++++++--
 pytorch_lightning/trainer/trainer.py     | 10 ++++
 tests/trainer/test_data_loading.py       | 59 +++++++++++++++++++
 3 files changed, 96 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index c6d471fad04b1..0ab082d32e4de 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Union
+from typing import Callable, Optional, Union

 import pytorch_lightning as pl
 from pytorch_lightning.trainer.supporters import prefetch_iterator
@@ -117,19 +117,23 @@ def attach_dataloaders(
         # functions to overwrite with these implementations
         if train_dataloaders is not None:
             self.trainer.train_dataloader = None
-            model.train_dataloader = _PatchDataLoader(train_dataloaders)
+            train_dataloader = _PatchDataLoader(train_dataloaders, "train")
+            train_dataloader.patch(model)

         if val_dataloaders is not None:
             self.trainer.val_dataloaders = None
-            model.val_dataloader = _PatchDataLoader(val_dataloaders)
+            val_dataloader = _PatchDataLoader(val_dataloaders, "val")
+            val_dataloader.patch(model)

         if test_dataloaders is not None:
             self.trainer.test_dataloaders = None
-            model.test_dataloader = _PatchDataLoader(test_dataloaders)
+            test_dataloader = _PatchDataLoader(test_dataloaders, "test")
+            test_dataloader.patch(model)

         if predict_dataloaders is not None:
             self.trainer.predict_dataloaders = None
-            model.predict_dataloader = _PatchDataLoader(predict_dataloaders)
+            predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict")
+            predict_dataloader.patch(model)

     def attach_datamodule(
         self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
@@ -157,6 +161,13 @@
         if hasattr(datamodule, "data_pipeline"):
             model.data_pipeline = datamodule.data_pipeline

+    @staticmethod
+    def detach_data(model: "pl.LightningModule") -> None:
+        for stage in ("train", "val", "test", "predict"):
+            loader = getattr(model, f"{stage}_dataloader", None)
+            if isinstance(loader, _PatchDataLoader):
+                loader.unpatch(model)
+

 class _PatchDataLoader:
     r"""
@@ -167,13 +178,23 @@ class _PatchDataLoader:
         dataloader: Dataloader object to return when called.
     """

-    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
+    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage: str) -> None:
         self.dataloader = dataloader
         # cannot pickle __code__ so cannot verify if PatchDataloader
         # exists which shows dataloader methods have been overwritten.
         # so, we hack it by using the string representation
         self.patch_loader_code = str(self.__call__.__code__)
+        self._old_loader: Optional[Callable] = None
+        self.stage = stage

     def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         return self.dataloader
+
+    def patch(self, model: "pl.LightningModule") -> None:
+        self._old_loader = getattr(model, self.stage + "_dataloader")
+        setattr(model, self.stage + "_dataloader", self)
+
+    def unpatch(self, model: "pl.LightningModule") -> None:
+        setattr(model, self.stage + "_dataloader", self._old_loader)
+        self._old_loader = None
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index c9eca4da38c4b..8ca342338443a 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -536,6 +536,8 @@ def fit(

         self._run(model)

+        self.data_connector.detach_data(model)
+
         assert self.state.stopped
         self.training = False

@@ -612,6 +614,8 @@ def validate(
         # run validate
         results = self._run(model)

+        self.data_connector.detach_data(model)
+
         assert self.state.stopped
         self.validating = False

@@ -691,6 +695,8 @@ def test(
         # run test
         results = self._run(model)

+        self.data_connector.detach_data(model)
+
         assert self.state.stopped
         self.testing = False

@@ -762,6 +768,8 @@ def predict(

         results = self._run(model)

+        self.data_connector.detach_data(model)
+
         assert self.state.stopped
         self.predicting = False

@@ -824,6 +832,8 @@ def tune(

         result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs)

+        self.data_connector.detach_data(model)
+
         assert self.state.stopped
         self.tuning = False

diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index c8be6727cd7ed..e9d5d3cc047cb 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -254,3 +254,62 @@ class CustomSampler(Sampler):
     dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
     with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
         trainer.auto_add_sampler(dataloader, shuffle=True)
+
+
+def test_loader_detaching():
+    """Checks that the loaders have been reset after the entrypoint"""
+
+    class LoaderTestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert len(model.train_dataloader()) == 10
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            assert len(model.val_dataloader()) == 10
+            return super().validation_step(batch, batch_idx)
+
+        def test_step(self, batch, batch_idx):
+            assert len(model.test_dataloader()) == 10
+            return super().test_step(batch, batch_idx)
+
+        def predict_step(self, batch, batch_idx, dataloader_idx=None):
+            assert len(model.predict_dataloader()) == 10
+            return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
+
+    loader = DataLoader(RandomDataset(32, 10), batch_size=1)
+
+    model = LoaderTestModel()
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer = Trainer(fast_dev_run=1)
+    trainer.fit(model, loader, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.validate(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.predict(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.test(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64

From 2e2bd327064e7cd83e1c9e20861cbc8fcfc5dbc5 Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Fri, 13 Aug 2021 13:22:45 +0200
Subject: [PATCH 2/4] chlog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 384fb6a20e1a0..763c061190901 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -123,6 +123,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Restored original loaders if they were replaced by an entrypoint ([#8885](https://github.com/PyTorchLightning/pytorch-lightning/pull/8885))
+

 - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))

From 5b7eff5eb2e53b7e53e7a6374747446a80440227 Mon Sep 17 00:00:00 2001
From: Justus Schock
Date: Fri, 13 Aug 2021 17:18:42 +0200
Subject: [PATCH 3/4] move to teardown

---
 pytorch_lightning/trainer/callback_hook.py |  3 +++
 pytorch_lightning/trainer/trainer.py       | 10 ----------
 2 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 5aac1acb6c572..3999b38ea06f7 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -31,6 +31,7 @@ class TrainerCallbackHookMixin(ABC):
     # the proper values/initialisation should be done in child class
     callbacks: List[Callback] = []
     lightning_module: "pl.LightningModule"
+    data_connector: "pl.trainer.connectors.DataConnector"

     def on_before_accelerator_backend_setup(self) -> None:
         """Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
@@ -52,6 +53,8 @@ def teardown(self, stage: Optional[str] = None) -> None:
         for callback in self.callbacks:
             callback.teardown(self, self.lightning_module, stage=stage)

+        self.data_connector.detach_data(self.lightning_module)
+
     def on_init_start(self):
         """Called when the trainer initialization begins, model has not yet been set."""
         for callback in self.callbacks:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 8ca342338443a..c9eca4da38c4b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -536,8 +536,6 @@ def fit(

         self._run(model)

-        self.data_connector.detach_data(model)
-
         assert self.state.stopped
         self.training = False

@@ -614,8 +612,6 @@ def validate(
         # run validate
         results = self._run(model)

-        self.data_connector.detach_data(model)
-
         assert self.state.stopped
         self.validating = False

@@ -695,8 +691,6 @@ def test(
         # run test
         results = self._run(model)

-        self.data_connector.detach_data(model)
-
         assert self.state.stopped
         self.testing = False

@@ -768,8 +762,6 @@ def predict(

         results = self._run(model)

-        self.data_connector.detach_data(model)
-
         assert self.state.stopped
         self.predicting = False

@@ -832,8 +824,6 @@ def tune(

         result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs)
-        self.data_connector.detach_data(model)
-
         assert self.state.stopped
         self.tuning = False

From a1753abe341cb2da67f6b94434d068f6583e27e9 Mon Sep 17 00:00:00 2001
From: Justus Schock
Date: Fri, 13 Aug 2021 17:29:15 +0200
Subject: [PATCH 4/4] move again

---
 pytorch_lightning/trainer/callback_hook.py | 3 ---
 pytorch_lightning/trainer/trainer.py       | 2 ++
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 3999b38ea06f7..5aac1acb6c572 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -31,7 +31,6 @@ class TrainerCallbackHookMixin(ABC):
     # the proper values/initialisation should be done in child class
     callbacks: List[Callback] = []
     lightning_module: "pl.LightningModule"
-    data_connector: "pl.trainer.connectors.DataConnector"

     def on_before_accelerator_backend_setup(self) -> None:
         """Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
@@ -53,8 +52,6 @@ def teardown(self, stage: Optional[str] = None) -> None:
         for callback in self.callbacks:
             callback.teardown(self, self.lightning_module, stage=stage)

-        self.data_connector.detach_data(self.lightning_module)
-
     def on_init_start(self):
         """Called when the trainer initialization begins, model has not yet been set."""
         for callback in self.callbacks:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index c9eca4da38c4b..7b279bfbe7fc9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1209,6 +1209,8 @@ def _call_teardown_hook(self) -> None:
         if self.datamodule is not None:
             self.datamodule.teardown(stage=fn)

+        self.data_connector.detach_data(self.lightning_module)
+
         self.teardown(stage=fn)
         self.lightning_module.teardown(stage=fn)
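Reviewer note (not part of the patches): the sketch below illustrates the patch/unpatch round trip that `_PatchDataLoader` performs and that `test_loader_detaching` asserts. The names `_ToyPatchDataLoader` and `DummyModule` are hypothetical stand-ins invented for this illustration only; the real class lives in `pytorch_lightning/trainer/connectors/data_connector.py` and is driven by `DataConnector.detach_data` during teardown.

from typing import Any, Callable, Optional


class _ToyPatchDataLoader:
    """Callable that temporarily replaces a model's ``<stage>_dataloader`` method."""

    def __init__(self, dataloader: Any, stage: str) -> None:
        self.dataloader = dataloader
        self.stage = stage
        self._old_loader: Optional[Callable] = None

    def __call__(self) -> Any:
        return self.dataloader

    def patch(self, model: Any) -> None:
        # remember the original bound method, then shadow it on the instance
        self._old_loader = getattr(model, self.stage + "_dataloader")
        setattr(model, self.stage + "_dataloader", self)

    def unpatch(self, model: Any) -> None:
        # put the original method back and drop the stored reference
        setattr(model, self.stage + "_dataloader", self._old_loader)
        self._old_loader = None


class DummyModule:
    def train_dataloader(self):
        return ["original batch"]


module = DummyModule()
patcher = _ToyPatchDataLoader(dataloader=["patched batch"] * 10, stage="train")

patcher.patch(module)
assert len(module.train_dataloader()) == 10  # the passed-in loader shadows the method

patcher.unpatch(module)
assert module.train_dataloader() == ["original batch"]  # original method restored

The same swap-and-restore is what makes the length checks in `test_loader_detaching` flip between 10 (inside the run, while patched) and 64 (after the entrypoint returns and teardown has detached the loaders).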