diff --git a/CHANGELOG.md b/CHANGELOG.md
index 384fb6a20e1a0..763c061190901 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -123,6 +123,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Restored original loaders if replaced by entrypoint ([#8885](https://github.com/PyTorchLightning/pytorch-lightning/pull/8885))
+
 - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
 
 
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index c6d471fad04b1..0ab082d32e4de 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.supporters import prefetch_iterator
@@ -117,19 +117,23 @@ def attach_dataloaders(
         # functions to overwrite with these implementations
         if train_dataloaders is not None:
             self.trainer.train_dataloader = None
-            model.train_dataloader = _PatchDataLoader(train_dataloaders)
+            train_dataloader = _PatchDataLoader(train_dataloaders, "train")
+            train_dataloader.patch(model)
 
         if val_dataloaders is not None:
             self.trainer.val_dataloaders = None
-            model.val_dataloader = _PatchDataLoader(val_dataloaders)
+            val_dataloader = _PatchDataLoader(val_dataloaders, "val")
+            val_dataloader.patch(model)
 
         if test_dataloaders is not None:
             self.trainer.test_dataloaders = None
-            model.test_dataloader = _PatchDataLoader(test_dataloaders)
+            test_dataloader = _PatchDataLoader(test_dataloaders, "test")
+            test_dataloader.patch(model)
 
         if predict_dataloaders is not None:
             self.trainer.predict_dataloaders = None
-            model.predict_dataloader = _PatchDataLoader(predict_dataloaders)
+            predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict")
+            predict_dataloader.patch(model)
 
     def attach_datamodule(
         self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
@@ -157,6 +161,13 @@ def attach_datamodule(
         if hasattr(datamodule, "data_pipeline"):
             model.data_pipeline = datamodule.data_pipeline
 
+    @staticmethod
+    def detach_data(model: "pl.LightningModule") -> None:
+        for stage in ("train", "val", "test", "predict"):
+            loader = getattr(model, f"{stage}_dataloader", None)
+            if isinstance(loader, _PatchDataLoader):
+                loader.unpatch(model)
+
 
 class _PatchDataLoader:
     r"""
@@ -167,13 +178,23 @@ class _PatchDataLoader:
         dataloader: Dataloader object to return when called.
     """
 
-    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
+    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage: str) -> None:
         self.dataloader = dataloader
         # cannot pickle __code__ so cannot verify if PatchDataloader
         # exists which shows dataloader methods have been overwritten.
         # so, we hack it by using the string representation
         self.patch_loader_code = str(self.__call__.__code__)
+        self._old_loader: Optional[Callable] = None
+        self.stage = stage
 
     def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         return self.dataloader
+
+    def patch(self, model: "pl.LightningModule") -> None:
+        self._old_loader = getattr(model, self.stage + "_dataloader")
+        setattr(model, self.stage + "_dataloader", self)
+
+    def unpatch(self, model: "pl.LightningModule") -> None:
+        setattr(model, self.stage + "_dataloader", self._old_loader)
+        self._old_loader = None
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index c9eca4da38c4b..7b279bfbe7fc9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1209,6 +1209,8 @@ def _call_teardown_hook(self) -> None:
         if self.datamodule is not None:
             self.datamodule.teardown(stage=fn)
 
+        self.data_connector.detach_data(self.lightning_module)
+
         self.teardown(stage=fn)
         self.lightning_module.teardown(stage=fn)
 
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index c8be6727cd7ed..e9d5d3cc047cb 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -254,3 +254,62 @@ class CustomSampler(Sampler):
     dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
     with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
         trainer.auto_add_sampler(dataloader, shuffle=True)
+
+
+def test_loader_detaching():
+    """Checks that the loader has been reset after the entrypoint."""
+
+    class LoaderTestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert len(model.train_dataloader()) == 10
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            assert len(model.val_dataloader()) == 10
+            return super().validation_step(batch, batch_idx)
+
+        def test_step(self, batch, batch_idx):
+            assert len(model.test_dataloader()) == 10
+            return super().test_step(batch, batch_idx)
+
+        def predict_step(self, batch, batch_idx, dataloader_idx=None):
+            assert len(model.predict_dataloader()) == 10
+            return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
+
+    loader = DataLoader(RandomDataset(32, 10), batch_size=1)
+
+    model = LoaderTestModel()
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer = Trainer(fast_dev_run=1)
+    trainer.fit(model, loader, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.validate(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.predict(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.test(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
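
Note (not part of the patch): the change above replaces direct assignment of `_PatchDataLoader` onto the model with explicit `patch`/`unpatch` calls, so `_call_teardown_hook` can hand the original `*_dataloader` methods back to the `LightningModule` once the entrypoint finishes. Below is a minimal, self-contained sketch of that save-and-restore pattern; `_PatchedLoader` and `ToyModel` are illustrative names only, not Lightning APIs.

    from typing import Any, Callable, Optional


    class _PatchedLoader:
        """Illustrative stand-in: returns a fixed loader and remembers what it replaced."""

        def __init__(self, dataloader: Any, stage: str) -> None:
            self.dataloader = dataloader
            self.stage = stage
            self._old_loader: Optional[Callable] = None

        def __call__(self) -> Any:
            return self.dataloader

        def patch(self, model: Any) -> None:
            # remember the original bound method, then shadow it on the instance
            self._old_loader = getattr(model, self.stage + "_dataloader")
            setattr(model, self.stage + "_dataloader", self)

        def unpatch(self, model: Any) -> None:
            # restore the original method and drop the saved reference
            setattr(model, self.stage + "_dataloader", self._old_loader)
            self._old_loader = None


    class ToyModel:
        def train_dataloader(self):
            return ["original"]


    model = ToyModel()
    patched = _PatchedLoader(["override"], "train")
    patched.patch(model)
    assert model.train_dataloader() == ["override"]  # the entrypoint sees the injected loader
    patched.unpatch(model)
    assert model.train_dataloader() == ["original"]  # teardown restores the user's method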