From ee7ef64fe64c240c61191d21113685aa263d4525 Mon Sep 17 00:00:00 2001 From: Gili Tzabari Date: Wed, 13 Oct 2021 10:00:22 -0400 Subject: [PATCH 01/31] del iterator on_run_end() --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 3e1b88a2d41c3..d37a82b2b8fcd 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -148,7 +148,9 @@ def on_run_end(self) -> EPOCH_OUTPUT: outputs = self.outputs # free memory self.outputs = [] + del self._dataloader_iter self._dataloader_iter = None + del self._data_fetcher self._data_fetcher = None return outputs From c2b155ae68a1fbaf89504b1001e5c35300c35aa7 Mon Sep 17 00:00:00 2001 From: Gili Tzabari Date: Wed, 13 Oct 2021 10:04:52 -0400 Subject: [PATCH 02/31] del iterator on_run_end() --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index fe3a2dc7431cc..05dde40e199f5 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -268,6 +268,7 @@ def on_run_end(self) -> None: if self._num_ready_batches_reached(): self.update_lr_schedulers("epoch", update_plateau_schedulers=True) + del self._dataloader_iter self._dataloader_iter = None # if fault tolerant is enabled and process has been notified, exit. From 4ac6c865f619e7bbc4124516101e39d09af732ef Mon Sep 17 00:00:00 2001 From: Gili Tzabari Date: Wed, 13 Oct 2021 10:09:38 -0400 Subject: [PATCH 03/31] Don't bother deleting data_fetcher --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index d37a82b2b8fcd..5ba37c77fa85c 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -150,7 +150,6 @@ def on_run_end(self) -> EPOCH_OUTPUT: self.outputs = [] del self._dataloader_iter self._dataloader_iter = None - del self._data_fetcher self._data_fetcher = None return outputs From b587849eb2fa4948717500006fcb7c0c9d9ac9a2 Mon Sep 17 00:00:00 2001 From: Gili Tzabari Date: Wed, 13 Oct 2021 10:14:11 -0400 Subject: [PATCH 04/31] Updated changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba1be2433463f..22ab0bf3601fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -475,7 +475,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310)) -- Fixed freeing datafetchers during teardown ([#9387](https://github.com/PyTorchLightning/pytorch-lightning/pull/9387)) +- Fixed freeing datafetchers during teardown ([#9387](https://github.com/PyTorchLightning/pytorch-lightning/pull/9387)) ([#9915](https://github.com/PyTorchLightning/pytorch-lightning/pull/9915)) - Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349)) From 3ff212160524654b3448e06131e51f62bde34b21 Mon Sep 17 00:00:00 2001 From: Gili Tzabari Date: Wed, 13 Oct 2021 10:15:18 -0400 Subject: [PATCH 05/31] Fixed typo in changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 22ab0bf3601fa..3ec31913be918 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -475,13 +475,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310)) -- Fixed freeing datafetchers during teardown ([#9387](https://github.com/PyTorchLightning/pytorch-lightning/pull/9387)) ([#9915](https://github.com/PyTorchLightning/pytorch-lightning/pull/9915)) +- Fixed freeing datafetchers during teardown ([#9387](https://github.com/PyTorchLightning/pytorch-lightning/pull/9387)) - Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349)) -- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386)) +- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386)) ([#9915](https://github.com/PyTorchLightning/pytorch-lightning/pull/9915)) - Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432)) From 44378639e3af65be1488e9765dcd4b55026bc205 Mon Sep 17 00:00:00 2001 From: Gili Tzabari Date: Thu, 14 Oct 2021 17:47:23 -0400 Subject: [PATCH 06/31] Added comment per PR review --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 5ba37c77fa85c..ef2a28532f5ca 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -148,6 +148,8 @@ def on_run_end(self) -> EPOCH_OUTPUT: outputs = self.outputs # free memory self.outputs = [] + # manually delete the DataLoader as PyTorch shuts down any persistent workers on `__del__` + # https://github.com/pytorch/pytorch/issues/64766#issuecomment-930467482 del self._dataloader_iter self._dataloader_iter = None self._data_fetcher = None From 2f391a8c9db3e3476874d4000ce977162c5e124e Mon Sep 17 00:00:00 2001 From: Gili Tzabari Date: Thu, 14 Oct 2021 17:48:07 -0400 Subject: [PATCH 07/31] Added comment per PR review --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 
05dde40e199f5..480f6e8e65a99 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -268,6 +268,8 @@ def on_run_end(self) -> None: if self._num_ready_batches_reached(): self.update_lr_schedulers("epoch", update_plateau_schedulers=True) + # manually delete the DataLoader as PyTorch shuts down any persistent workers on `__del__` + # https://github.com/pytorch/pytorch/issues/64766#issuecomment-930467482 del self._dataloader_iter self._dataloader_iter = None From 949092ee010e1e33dfeeb0bec39f4a7da3598779 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 15 Oct 2021 18:08:16 +0200 Subject: [PATCH 08/31] Add failing test --- tests/loops/test_training_loop.py | 43 +++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index d491db3bbc91c..7d2aac52ec0ee 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import gc +from unittest import mock import pytest import torch +from torch.utils.data import DataLoader +from kk import RandomDataset from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.loops import TrainingEpochLoop from tests.helpers import BoringModel @@ -142,3 +147,41 @@ def training_step_end(self, outputs): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) trainer.fit(model) + + +@mock.patch("torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers") +def test_training_loop_workers_are_shutdown(shutdown_mock, tmpdir): + # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` + # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance + train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=True) + + class TestLoop(TrainingEpochLoop): + def on_run_end(self): + # this works - but this is the `enumerate` object, not the actual iterator + referrers = gc.get_referrers(self._dataloader_iter) + assert len(referrers) == 1, referrers + + # this fails - there are 2 referrers + referrers = gc.get_referrers(train_dataloader._iterator) + assert len(referrers) == 1, referrers + + out = super().on_run_end() + + # no referrers after destruction + referrers = gc.get_referrers(train_dataloader._iterator) + assert len(referrers) == 0, referrers + + shutdown_mock.assert_called_once() + shutdown_mock.reset_mock() + + assert self._dataloader_iter is None + return out + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=0, max_epochs=2) + + epoch_loop = TestLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) + epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop) + trainer.fit_loop.connect(epoch_loop) + + trainer.fit(model, train_dataloader) From 06e20a230af4335ebf8a488198f8b380e8773036 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Oct 2021 16:10:31 +0000 Subject: [PATCH 09/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/loops/test_training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 7d2aac52ec0ee..030639cd36283 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -16,9 +16,9 @@ import pytest import torch +from kk import RandomDataset from torch.utils.data import DataLoader -from kk import RandomDataset from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.loops import TrainingEpochLoop from tests.helpers import BoringModel From 183b98872f8b2d5d0c6cb5ed152b9a205f7332f3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 25 Oct 2021 20:04:02 +0100 Subject: [PATCH 10/31] resolve iterator reference --- .../loops/epoch/evaluation_epoch_loop.py | 2 +- .../loops/epoch/training_epoch_loop.py | 6 ++--- pytorch_lightning/loops/fit_loop.py | 4 ++- pytorch_lightning/trainer/supporters.py | 5 ++++ pytorch_lightning/utilities/fetching.py | 8 +++++- tests/loops/test_training_loop.py | 25 ++++++++----------- 6 files changed, 30 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index ef2a28532f5ca..fc0982e58775f 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -150,7 +150,7 @@ def on_run_end(self) -> EPOCH_OUTPUT: self.outputs = [] # manually delete the DataLoader as PyTorch shuts down any persistent workers on `__del__` # https://github.com/pytorch/pytorch/issues/64766#issuecomment-930467482 - del self._dataloader_iter + self.trainer.val_dataloaders = None self._dataloader_iter = None self._data_fetcher = None return outputs diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 480f6e8e65a99..7e3be04992641 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -268,9 +268,9 @@ def on_run_end(self) -> None: if self._num_ready_batches_reached(): self.update_lr_schedulers("epoch", update_plateau_schedulers=True) - # manually delete the DataLoader as PyTorch shuts down any persistent workers on `__del__` - # https://github.com/pytorch/pytorch/issues/64766#issuecomment-930467482 - del self._dataloader_iter + # delete any persistent workers. + self.trainer.train_dataloader.reset() + self.trainer.data_connector.train_data_fetcher.teardown() self._dataloader_iter = None # if fault tolerant is enabled and process has been notified, exit. 
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 0d16c978fb374..5d08aeb3ad5da 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -188,7 +188,9 @@ def on_advance_start(self) -> None: if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch: self.trainer.reset_train_dataloader(model) - if callable(getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)): + if self.trainer.train_dataloader and callable( + getattr(self.trainer.train_dataloader.sampler, "set_epoch", None) + ): # set seed for distributed sampler (enables shuffling for each epoch) self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 87e5f9f4f7bd8..1115b4eb95e42 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -491,6 +491,11 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]: def __len__(self) -> int: return self._calc_num_batches(self.loaders) + def reset(self): + if self._iterator: + self._iterator._loader_iters = None + self._iterator = None + class CombinedLoaderIterator: """Custom Iterator returning data from multple loaders, and allows sampling in parallel.""" diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 689c2bff8e5b3..ddc9d743c8792 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -204,12 +204,18 @@ def __next__(self): def reset(self) -> None: self.batches: List = [] - self.dataloader: Optional[Iterable] self.fetched: int = 0 self.done: bool = False def teardown(self) -> None: self.reset() + if isinstance(self.dataloader, CombinedLoader): + self.dataloader.loaders._iterator._loader_iters = None + self.dataloader.loaders._iterator = None + if isinstance(self.dataloader, DataLoader): + self.dataloader._iterator = None + self.dataloader = None + self.dataloader_iter = None class DataFetcher(AbstractDataFetcher): diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 030639cd36283..90aaf36c88b64 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -12,16 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import gc -from unittest import mock import pytest import torch -from kk import RandomDataset from torch.utils.data import DataLoader from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.loops import TrainingEpochLoop -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset def test_outputs_format(tmpdir): @@ -149,8 +147,8 @@ def training_step_end(self, outputs): trainer.fit(model) -@mock.patch("torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers") -def test_training_loop_workers_are_shutdown(shutdown_mock, tmpdir): +# @mock.patch("torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers") +def test_training_loop_workers_are_shutdown(tmpdir): # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=True) @@ -159,25 +157,24 @@ class TestLoop(TrainingEpochLoop): def on_run_end(self): # this works - but this is the `enumerate` object, not the actual iterator referrers = gc.get_referrers(self._dataloader_iter) - assert len(referrers) == 1, referrers + assert len(referrers) == 1 # this fails - there are 2 referrers referrers = gc.get_referrers(train_dataloader._iterator) - assert len(referrers) == 1, referrers + assert len(referrers) == 2 + del referrers + iterator = train_dataloader._iterator out = super().on_run_end() + assert self._dataloader_iter is None - # no referrers after destruction - referrers = gc.get_referrers(train_dataloader._iterator) - assert len(referrers) == 0, referrers - - shutdown_mock.assert_called_once() - shutdown_mock.reset_mock() + referrers = gc.get_referrers(iterator) + assert len(referrers) == 0 - assert self._dataloader_iter is None return out model = BoringModel() + model.validation_step = None trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=0, max_epochs=2) epoch_loop = TestLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) From 20e756043fcadf6f2d53ba9c2331c4b48959238c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 25 Oct 2021 20:04:58 +0100 Subject: [PATCH 11/31] resolve iterator reference --- tests/loops/test_training_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 90aaf36c88b64..e2cbffee5ce64 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -174,7 +174,6 @@ def on_run_end(self): return out model = BoringModel() - model.validation_step = None trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=0, max_epochs=2) epoch_loop = TestLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) From 32347302802b6422e9cf53e76f12c13655900e75 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 25 Oct 2021 20:16:41 +0100 Subject: [PATCH 12/31] update --- .../loops/dataloader/evaluation_loop.py | 2 ++ .../trainer/connectors/data_connector.py | 11 ++++--- tests/loops/test_training_loop.py | 32 +++++++++++++------ 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 92c7d36cfd0d8..829b8f51c2275 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ 
-139,6 +139,8 @@ def on_run_end(self) -> List[_OUT_DICT]: # enable train mode again self._on_evaluation_model_train() + self.trainer.data_connector.teardown(self.trainer.state.stage) + return eval_loop_results def teardown(self) -> None: diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 589906d2bb0e4..b7201fc4e5156 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -16,6 +16,7 @@ from typing import Callable, Iterable, Optional, Union import pytorch_lightning as pl +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import ( @@ -245,17 +246,17 @@ def detach_data(model: "pl.LightningModule") -> None: if isinstance(loader, _PatchDataLoader): loader.unpatch(model) - def teardown(self) -> None: - if self.train_data_fetcher: + def teardown(self, stage: Optional[RunningStage] = None) -> None: + if (stage is None or stage == RunningStage.TRAINING) and self.train_data_fetcher: self.train_data_fetcher.teardown() self.train_data_fetcher = None - if self.validate_data_fetcher: + if (stage is None or stage == RunningStage.VALIDATING) and self.validate_data_fetcher: self.validate_data_fetcher.teardown() self.validate_data_fetcher = None - if self.test_data_fetcher: + if (stage is None or stage == RunningStage.TESTING) and self.test_data_fetcher: self.test_data_fetcher.teardown() self.test_data_fetcher = None - if self.sanity_check_data_fetcher: + if (stage is None or stage == RunningStage.SANITY_CHECKING) and self.sanity_check_data_fetcher: self.sanity_check_data_fetcher.teardown() self.sanity_check_data_fetcher = None diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index e2cbffee5ce64..1df1ef30e5e46 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -18,7 +18,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.loops import TrainingEpochLoop +from pytorch_lightning.loops import EvaluationLoop, TrainingEpochLoop from tests.helpers import BoringModel, RandomDataset @@ -147,13 +147,11 @@ def training_step_end(self, outputs): trainer.fit(model) -# @mock.patch("torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers") -def test_training_loop_workers_are_shutdown(tmpdir): - # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` - # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance - train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=True) +def test_dataloader_workers_are_shutdown_properly(tmpdir): + train_dataloader = DataLoader(RandomDataset(32, 4), num_workers=1, persistent_workers=True) + val_dataloader = DataLoader(RandomDataset(32, 4), num_workers=1, persistent_workers=True) - class TestLoop(TrainingEpochLoop): + class TestTrainingEpochLoopLoop(TrainingEpochLoop): def on_run_end(self): # this works - but this is the `enumerate` object, not the actual iterator referrers = gc.get_referrers(self._dataloader_iter) @@ -173,11 +171,25 @@ def on_run_end(self): return out + class TestEvaluationLoop(EvaluationLoop): + def on_run_end(self): + iterator = val_dataloader._iterator + out = super().on_run_end() + referrers = 
gc.get_referrers(iterator) + assert len(referrers) == 0 + + return out + model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=0, max_epochs=2) + model.validation_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2, num_sanity_val_steps=0 + ) - epoch_loop = TestLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) + epoch_loop = TestTrainingEpochLoopLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) + val_loop = TestEvaluationLoop() epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop) trainer.fit_loop.connect(epoch_loop) + epoch_loop.connect(epoch_loop.batch_loop, val_loop) - trainer.fit(model, train_dataloader) + trainer.fit(model, train_dataloader, val_dataloader) From deb4f250fa27257126bec3da21b0325a21fc5caa Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 25 Oct 2021 20:18:59 +0100 Subject: [PATCH 13/31] update --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index fc0982e58775f..3e1b88a2d41c3 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -148,9 +148,6 @@ def on_run_end(self) -> EPOCH_OUTPUT: outputs = self.outputs # free memory self.outputs = [] - # manually delete the DataLoader as PyTorch shuts down any persistent workers on `__del__` - # https://github.com/pytorch/pytorch/issues/64766#issuecomment-930467482 - self.trainer.val_dataloaders = None self._dataloader_iter = None self._data_fetcher = None return outputs From 9032c2c272dfcde973e0f1db213d0cdc2236de23 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 26 Oct 2021 01:40:32 +0200 Subject: [PATCH 14/31] Use mock in tests --- .../loops/dataloader/evaluation_loop.py | 2 +- .../loops/epoch/training_epoch_loop.py | 2 +- tests/loops/test_training_loop.py | 67 +++++++++---------- 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 8c6cbdfa5f070..2d9ac1afc37e5 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -139,7 +139,7 @@ def on_run_end(self) -> List[_OUT_DICT]: # enable train mode again self._on_evaluation_model_train() - self.trainer.data_connector.teardown(self.trainer.state.stage) + self.trainer._data_connector.teardown(self.trainer.state.stage) return eval_loop_results diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 7ca89a728920f..9f4fb9c310578 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -300,7 +300,7 @@ def on_run_end(self) -> None: # delete any persistent workers. self.trainer.train_dataloader.reset() - self.trainer.data_connector.train_data_fetcher.teardown() + self.trainer._data_connector.train_data_fetcher.teardown() self._dataloader_iter = None # if fault tolerant is enabled and process has been notified, exit. 
diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 81147d71f285d..9dcfac243718d 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import gc +from unittest import mock import pytest import torch @@ -147,49 +147,48 @@ def training_step_end(self, outputs): trainer.fit(model) -def test_dataloader_workers_are_shutdown_properly(tmpdir): - train_dataloader = DataLoader(RandomDataset(32, 4), num_workers=1, persistent_workers=True) - val_dataloader = DataLoader(RandomDataset(32, 4), num_workers=1, persistent_workers=True) +@pytest.mark.parametrize("persistent_workers", (True, False)) +def test_training_loop_workers_are_shutdown(tmpdir, persistent_workers): + # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` + # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance + train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) - class TestTrainingEpochLoopLoop(TrainingEpochLoop): + class TestLoop(TrainingEpochLoop): def on_run_end(self): - # this works - but this is the `enumerate` object, not the actual iterator - referrers = gc.get_referrers(self._dataloader_iter) - assert len(referrers) == 1 + with mock.patch( + "torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers" + ) as shutdown_mock: + out = super().on_run_end() + shutdown_mock.assert_called_once() + return out - # this fails - there are 2 referrers - referrers = gc.get_referrers(train_dataloader._iterator) - assert len(referrers) == 2 - del referrers + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=0, max_epochs=2) - iterator = train_dataloader._iterator - out = super().on_run_end() - assert self._dataloader_iter is None + epoch_loop = TestLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) + epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop) + trainer.fit_loop.connect(epoch_loop) - referrers = gc.get_referrers(iterator) - assert len(referrers) == 0 + trainer.fit(model, train_dataloader) - return out - class TestEvaluationLoop(EvaluationLoop): - def on_run_end(self): - iterator = val_dataloader._iterator - out = super().on_run_end() - referrers = gc.get_referrers(iterator) - assert len(referrers) == 0 +@pytest.mark.parametrize("persistent_workers", (True, False)) +def test_evaluation_loop_workers_are_shutdown(tmpdir, persistent_workers): + val_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) + class TestLoop(EvaluationLoop): + def on_run_end(self): + with mock.patch( + "torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers" + ) as shutdown_mock: + out = super().on_run_end() + shutdown_mock.assert_called_once() return out model = BoringModel() - model.validation_epoch_end = None - trainer = Trainer( - default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2, num_sanity_val_steps=0 - ) + trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=2) - epoch_loop = TestTrainingEpochLoopLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) - val_loop = TestEvaluationLoop() - 
epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop) - trainer.fit_loop.connect(epoch_loop) - epoch_loop.connect(epoch_loop.batch_loop, val_loop) + val_loop = TestLoop() + trainer.fit_loop.epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, val_loop) - trainer.fit(model, train_dataloader, val_dataloader) + trainer.validate(model, val_dataloaders=val_dataloader) From a673c6ac56d3a767802703c63b84528a58fe1b2f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 26 Oct 2021 01:43:38 +0200 Subject: [PATCH 15/31] Remove evaluation test that does not fail in master --- tests/loops/test_training_loop.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 9dcfac243718d..88c2706d75145 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -18,7 +18,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.loops import EvaluationLoop, TrainingEpochLoop +from pytorch_lightning.loops import TrainingEpochLoop from tests.helpers import BoringModel, RandomDataset @@ -170,25 +170,3 @@ def on_run_end(self): trainer.fit_loop.connect(epoch_loop) trainer.fit(model, train_dataloader) - - -@pytest.mark.parametrize("persistent_workers", (True, False)) -def test_evaluation_loop_workers_are_shutdown(tmpdir, persistent_workers): - val_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) - - class TestLoop(EvaluationLoop): - def on_run_end(self): - with mock.patch( - "torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers" - ) as shutdown_mock: - out = super().on_run_end() - shutdown_mock.assert_called_once() - return out - - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=2) - - val_loop = TestLoop() - trainer.fit_loop.epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, val_loop) - - trainer.validate(model, val_dataloaders=val_dataloader) From 95cfcdf551072ec614cb72dce27ed718f6538d6f Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 17:36:44 +0100 Subject: [PATCH 16/31] update --- .../loops/dataloader/evaluation_loop.py | 9 ++++--- .../loops/epoch/training_epoch_loop.py | 2 +- pytorch_lightning/loops/fit_loop.py | 2 +- .../trainer/connectors/data_connector.py | 10 +++---- pytorch_lightning/trainer/supporters.py | 2 ++ pytorch_lightning/utilities/fetching.py | 6 ++--- tests/loops/test_evaluation_loop.py | 27 ++++++++++++++++++- tests/loops/test_training_loop.py | 1 + 8 files changed, 44 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 2d9ac1afc37e5..dcff6f5ed74a6 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -19,6 +19,7 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection +from pytorch_lightning.utilities.fetching import AbstractDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -34,6 +35,7 @@ def __init__(self): self._results = ResultCollection(training=False) 
self._max_batches: Optional[Union[int, Sequence[int]]] = None self._has_run: bool = False + self.data_fetcher: Optional[AbstractDataFetcher] = None @property def num_dataloaders(self) -> int: @@ -101,7 +103,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader_idx: int = self.current_dataloader_idx dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader) - dataloader = self.trainer._data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx) + self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader( + dataloader, dataloader_idx=dataloader_idx + ) dl_max_batches = self._max_batches[dataloader_idx] dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders) @@ -119,6 +123,7 @@ def on_run_end(self) -> List[_OUT_DICT]: # free memory self.outputs = [] + self.data_fetcher.reset() # with a single dataloader don't pass a 2D list if len(outputs) > 0 and self.num_dataloaders == 1: @@ -139,8 +144,6 @@ def on_run_end(self) -> List[_OUT_DICT]: # enable train mode again self._on_evaluation_model_train() - self.trainer._data_connector.teardown(self.trainer.state.stage) - return eval_loop_results def teardown(self) -> None: diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 9f4fb9c310578..eb6417462dd08 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -300,7 +300,7 @@ def on_run_end(self) -> None: # delete any persistent workers. self.trainer.train_dataloader.reset() - self.trainer._data_connector.train_data_fetcher.teardown() + self.trainer._data_connector.train_data_fetcher.reset() self._dataloader_iter = None # if fault tolerant is enabled and process has been notified, exit. 
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index c611f428e6a4c..e368c8361dde9 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -195,7 +195,7 @@ def on_advance_start(self) -> None: self.trainer.reset_train_dataloader(model) self._is_fresh_start_epoch = False - if self.trainer.train_dataloader and callable( + if self.trainer.train_dataloader is not None and callable( getattr(self.trainer.train_dataloader.sampler, "set_epoch", None) ): # set seed for distributed sampler (enables shuffling for each epoch) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b48d230884e0f..3f50b042873b1 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -240,17 +240,17 @@ def attach_datamodule( if hasattr(datamodule, "data_pipeline"): model.data_pipeline = datamodule.data_pipeline - def teardown(self, stage: Optional[RunningStage] = None) -> None: - if (stage is None or stage == RunningStage.TRAINING) and self.train_data_fetcher: + def teardown(self) -> None: + if self.train_data_fetcher: self.train_data_fetcher.teardown() self.train_data_fetcher = None - if (stage is None or stage == RunningStage.VALIDATING) and self.validate_data_fetcher: + if self.validate_data_fetcher: self.validate_data_fetcher.teardown() self.validate_data_fetcher = None - if (stage is None or stage == RunningStage.TESTING) and self.test_data_fetcher: + if self.test_data_fetcher: self.test_data_fetcher.teardown() self.test_data_fetcher = None - if (stage is None or stage == RunningStage.SANITY_CHECKING) and self.sanity_check_data_fetcher: + if self.sanity_check_data_fetcher: self.sanity_check_data_fetcher.teardown() self.sanity_check_data_fetcher = None diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 1115b4eb95e42..907948ccced91 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -494,6 +494,8 @@ def __len__(self) -> int: def reset(self): if self._iterator: self._iterator._loader_iters = None + if self.loaders: + self.loaders._iterator = None self._iterator = None diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index ddc9d743c8792..3b1f6d301d6d9 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -206,16 +206,14 @@ def reset(self) -> None: self.batches: List = [] self.fetched: int = 0 self.done: bool = False + self.dataloader_iter = None def teardown(self) -> None: self.reset() if isinstance(self.dataloader, CombinedLoader): - self.dataloader.loaders._iterator._loader_iters = None - self.dataloader.loaders._iterator = None + self.dataloader.reset() if isinstance(self.dataloader, DataLoader): self.dataloader._iterator = None - self.dataloader = None - self.dataloader_iter = None class DataFetcher(AbstractDataFetcher): diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index 2b67dec18de34..b838ad5c3c87b 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -13,11 +13,12 @@ # limitations under the License. 
from unittest import mock +import pytest import torch from torch.utils.data import DataLoader from pytorch_lightning import Trainer -from pytorch_lightning.loops import EvaluationEpochLoop +from pytorch_lightning.loops import EvaluationEpochLoop, EvaluationLoop from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -130,3 +131,27 @@ def on_advance_end(self): trainer.test_loop.connect(TestLoop()) trainer.test(model) assert did_assert + + +@pytest.mark.parametrize("persistent_workers", (True, False)) +def test_evaluation_workers_are_shutdown(tmpdir, persistent_workers): + # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` + # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance + train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=0) + val_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) + + class ValLoop(EvaluationLoop): + def on_run_end(self): + with mock.patch( + "torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers" + ) as shutdown_mock: + out = super().on_run_end() + shutdown_mock.assert_called_once() + return out + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2) + val_loop = ValLoop() + trainer.fit_loop.epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, val_loop) + val_loop.trainer = trainer + trainer.fit(model, train_dataloader, val_dataloader) diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 88c2706d75145..515c20ea3bd71 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -19,6 +19,7 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.loops import TrainingEpochLoop +from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from tests.helpers import BoringModel, RandomDataset From 07af7392d07417fe41992230de463692f9196937 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 18:32:12 +0100 Subject: [PATCH 17/31] update --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 1 + pytorch_lightning/utilities/fetching.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index dcff6f5ed74a6..3d29713d38da6 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -124,6 +124,7 @@ def on_run_end(self) -> List[_OUT_DICT]: # free memory self.outputs = [] self.data_fetcher.reset() + self.data_fetcher = None # with a single dataloader don't pass a 2D list if len(outputs) > 0 and self.num_dataloaders == 1: diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 3b1f6d301d6d9..476b197d9def6 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -206,14 +206,14 @@ def reset(self) -> None: self.batches: List = [] self.fetched: int = 0 self.done: bool = False - self.dataloader_iter = None - - def teardown(self) -> None: - self.reset() if isinstance(self.dataloader, CombinedLoader): self.dataloader.reset() if isinstance(self.dataloader, DataLoader): self.dataloader._iterator = None + self.dataloader_iter = None + + def teardown(self) -> None: 
+ self.reset() class DataFetcher(AbstractDataFetcher): From 4f00eab1cbd328b27c9eb02be66ae2f9d9eb421d Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 18:32:29 +0100 Subject: [PATCH 18/31] drop dataloader --- pytorch_lightning/utilities/fetching.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 476b197d9def6..499cd8e0e017e 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -214,6 +214,7 @@ def reset(self) -> None: def teardown(self) -> None: self.reset() + self.dataloader = None class DataFetcher(AbstractDataFetcher): From 0da7b7009848357e6a7922a82966bd5e3746234f Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 18:56:10 +0100 Subject: [PATCH 19/31] update --- pytorch_lightning/trainer/supporters.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 907948ccced91..0044edaf56c57 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -491,11 +491,15 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]: def __len__(self) -> int: return self._calc_num_batches(self.loaders) + @staticmethod + def _reset(dataloader) -> None: + dataloader._iterator = None + def reset(self): if self._iterator: self._iterator._loader_iters = None if self.loaders: - self.loaders._iterator = None + apply_to_collection(self.loaders, DataLoader, self._reset) self._iterator = None From c81c056fb256b1110ffac6f8e0087d8915d65e60 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 19:15:23 +0100 Subject: [PATCH 20/31] add extra check --- pytorch_lightning/trainer/supporters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 0044edaf56c57..6dffa18a42ee0 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -498,7 +498,7 @@ def _reset(dataloader) -> None: def reset(self): if self._iterator: self._iterator._loader_iters = None - if self.loaders: + if self.loaders is not None: apply_to_collection(self.loaders, DataLoader, self._reset) self._iterator = None From 41bc58b6da91ed2a1422e05aaff6cfd196955cfa Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 19:34:20 +0100 Subject: [PATCH 21/31] update --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 3d29713d38da6..63a26bfb4bac9 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -123,8 +123,9 @@ def on_run_end(self) -> List[_OUT_DICT]: # free memory self.outputs = [] - self.data_fetcher.reset() - self.data_fetcher = None + if isinstance(self.data_fetcher, AbstractDataFetcher): + self.data_fetcher.reset() + self.data_fetcher = None # with a single dataloader don't pass a 2D list if len(outputs) > 0 and self.num_dataloaders == 1: From 15da2b12471d8bc476a1f9f9ff50e7aedaf0cc9b Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 19:39:16 +0100 Subject: [PATCH 22/31] update --- pytorch_lightning/trainer/connectors/data_connector.py | 1 - tests/loops/test_training_loop.py | 1 - 2 files changed, 2 deletions(-) diff --git 
a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 3f50b042873b1..9b6f97f1ebec4 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -17,7 +17,6 @@ from typing import Iterable, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import ( diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 515c20ea3bd71..88c2706d75145 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -19,7 +19,6 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.loops import TrainingEpochLoop -from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from tests.helpers import BoringModel, RandomDataset From 4c55761617d2c51f43aabebf271bfc5617b37067 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 19:56:03 +0100 Subject: [PATCH 23/31] update --- tests/loops/test_evaluation_loop.py | 1 + tests/loops/test_training_loop.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index b838ad5c3c87b..569262f0be1b7 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -133,6 +133,7 @@ def on_advance_end(self): assert did_assert +@RunIf(min_torch="1.8.0") @pytest.mark.parametrize("persistent_workers", (True, False)) def test_evaluation_workers_are_shutdown(tmpdir, persistent_workers): # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 88c2706d75145..24d621893f496 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -20,6 +20,7 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.loops import TrainingEpochLoop from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf def test_outputs_format(tmpdir): @@ -147,6 +148,7 @@ def training_step_end(self, outputs): trainer.fit(model) +@RunIf(min_torch="1.8.0") @pytest.mark.parametrize("persistent_workers", (True, False)) def test_training_loop_workers_are_shutdown(tmpdir, persistent_workers): # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` From 1afa66ef6c52f7e06233ae528de40357f967f008 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 20:42:08 +0100 Subject: [PATCH 24/31] update --- tests/loops/test_evaluation_loop.py | 16 ++-------------- tests/loops/test_training_loop.py | 18 +----------------- 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index 569262f0be1b7..f86285d31b4ed 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -18,7 +18,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import Trainer -from pytorch_lightning.loops import EvaluationEpochLoop, EvaluationLoop +from pytorch_lightning.loops import EvaluationEpochLoop from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -140,19 
+140,7 @@ def test_evaluation_workers_are_shutdown(tmpdir, persistent_workers): # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=0) val_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) - - class ValLoop(EvaluationLoop): - def on_run_end(self): - with mock.patch( - "torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers" - ) as shutdown_mock: - out = super().on_run_end() - shutdown_mock.assert_called_once() - return out - model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2) - val_loop = ValLoop() - trainer.fit_loop.epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, val_loop) - val_loop.trainer = trainer trainer.fit(model, train_dataloader, val_dataloader) + assert val_dataloader._iterator is None diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 24d621893f496..38168b47d4f5b 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -11,14 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest import mock - import pytest import torch from torch.utils.data import DataLoader from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.loops import TrainingEpochLoop from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -155,20 +152,7 @@ def test_training_loop_workers_are_shutdown(tmpdir, persistent_workers): # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) - class TestLoop(TrainingEpochLoop): - def on_run_end(self): - with mock.patch( - "torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers" - ) as shutdown_mock: - out = super().on_run_end() - shutdown_mock.assert_called_once() - return out - model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=0, max_epochs=2) - - epoch_loop = TestLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) - epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop) - trainer.fit_loop.connect(epoch_loop) - trainer.fit(model, train_dataloader) + assert train_dataloader._iterator is None From bb6ec3b7e4d520c6ee9b75a29e6368f5c79a694b Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 20:55:23 +0100 Subject: [PATCH 25/31] delete iterator only on end --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 5 ----- pytorch_lightning/loops/epoch/training_epoch_loop.py | 5 ----- pytorch_lightning/loops/fit_loop.py | 1 + 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 63a26bfb4bac9..6140bd60d6a7f 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -19,7 +19,6 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result 
import _OUT_DICT, ResultCollection -from pytorch_lightning.utilities.fetching import AbstractDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -35,7 +34,6 @@ def __init__(self): self._results = ResultCollection(training=False) self._max_batches: Optional[Union[int, Sequence[int]]] = None self._has_run: bool = False - self.data_fetcher: Optional[AbstractDataFetcher] = None @property def num_dataloaders(self) -> int: @@ -123,9 +121,6 @@ def on_run_end(self) -> List[_OUT_DICT]: # free memory self.outputs = [] - if isinstance(self.data_fetcher, AbstractDataFetcher): - self.data_fetcher.reset() - self.data_fetcher = None # with a single dataloader don't pass a 2D list if len(outputs) > 0 and self.num_dataloaders == 1: diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index eb6417462dd08..091156fda1d23 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -298,11 +298,6 @@ def on_run_end(self) -> None: if self._num_ready_batches_reached(): self.update_lr_schedulers("epoch", update_plateau_schedulers=True) - # delete any persistent workers. - self.trainer.train_dataloader.reset() - self.trainer._data_connector.train_data_fetcher.reset() - self._dataloader_iter = None - # if fault tolerant is enabled and process has been notified, exit. self.trainer._exit_gracefully_on_signal() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index e368c8361dde9..b6e7b168187eb 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -247,6 +247,7 @@ def on_run_end(self) -> None: def teardown(self) -> None: self.epoch_loop.teardown() + self.trainer._data_connector.teardown() def _should_accumulate(self) -> bool: """Whether the gradients should be accumulated.""" From fb592a45653dfbf13e303fe9f2eee52a9fb4e51e Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 20:58:28 +0100 Subject: [PATCH 26/31] delete iterator only on end --- pytorch_lightning/trainer/supporters.py | 4 +++- pytorch_lightning/utilities/fetching.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 6dffa18a42ee0..dddc1614b78fe 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -19,7 +19,7 @@ import torch from torch.utils.data import Dataset -from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader +from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader from torch.utils.data.dataset import IterableDataset from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections @@ -493,6 +493,8 @@ def __len__(self) -> int: @staticmethod def _reset(dataloader) -> None: + if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter): + dataloader._iterator._shutdown_workers() dataloader._iterator = None def reset(self): diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 499cd8e0e017e..a90a108efcb50 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -20,7 +20,7 @@ from typing import Any, Callable, Generator, List, Optional, Tuple import torch -from torch.utils.data.dataloader import DataLoader 
+from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader import pytorch_lightning as pl from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator @@ -209,6 +209,8 @@ def reset(self) -> None: if isinstance(self.dataloader, CombinedLoader): self.dataloader.reset() if isinstance(self.dataloader, DataLoader): + if isinstance(self.dataloader._iterator, _MultiProcessingDataLoaderIter): + self.dataloader._iterator._shutdown_workers() self.dataloader._iterator = None self.dataloader_iter = None From b16456892d1d5c032af25d4c9b383225284109e2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 21:06:01 +0100 Subject: [PATCH 27/31] update --- tests/loops/test_evaluation_loop.py | 21 +++++++++++++++++++-- tests/loops/test_training_loop.py | 22 ++++++++++++++++++++-- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index f86285d31b4ed..61eb7ed21eea9 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -15,7 +15,7 @@ import pytest import torch -from torch.utils.data import DataLoader +from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader from pytorch_lightning import Trainer from pytorch_lightning.loops import EvaluationEpochLoop @@ -138,9 +138,26 @@ def on_advance_end(self): def test_evaluation_workers_are_shutdown(tmpdir, persistent_workers): # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance + has_shutdown_workers = False + + class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter): + def _shutdown_workers(self): + nonlocal has_shutdown_workers + has_shutdown_workers = True + super()._shutdown_workers() + + class TestDataLoader(DataLoader): + def _get_iterator(self): + if self.num_workers == 0: + return super()._get_iterator() + else: + self.check_worker_number_rationality() + return _TestMultiProcessingDataLoaderIter(self) + train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=0) - val_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) + val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2) trainer.fit(model, train_dataloader, val_dataloader) + assert has_shutdown_workers assert val_dataloader._iterator is None diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 38168b47d4f5b..bc387d71761c3 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -13,7 +13,7 @@ # limitations under the License. 
import pytest import torch -from torch.utils.data import DataLoader +from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader from pytorch_lightning import seed_everything, Trainer from tests.helpers import BoringModel, RandomDataset @@ -150,9 +150,27 @@ def training_step_end(self, outputs): def test_training_loop_workers_are_shutdown(tmpdir, persistent_workers): # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance - train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) + + has_shutdown_workers = False + + class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter): + def _shutdown_workers(self): + nonlocal has_shutdown_workers + has_shutdown_workers = True + super()._shutdown_workers() + + class TestDataLoader(DataLoader): + def _get_iterator(self): + if self.num_workers == 0: + return super()._get_iterator() + else: + self.check_worker_number_rationality() + return _TestMultiProcessingDataLoaderIter(self) + + train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=0, max_epochs=2) trainer.fit(model, train_dataloader) + assert has_shutdown_workers assert train_dataloader._iterator is None From 33aa5312dab858c1eaf9eb74b99696cb24a0eaec Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Oct 2021 21:24:02 +0100 Subject: [PATCH 28/31] update --- pytorch_lightning/trainer/supporters.py | 2 +- pytorch_lightning/utilities/fetching.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index dddc1614b78fe..059055db33eda 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -493,7 +493,7 @@ def __len__(self) -> int: @staticmethod def _reset(dataloader) -> None: - if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter): + if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter): dataloader._iterator._shutdown_workers() dataloader._iterator = None diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index a90a108efcb50..de0fbbfcefc4d 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -209,7 +209,9 @@ def reset(self) -> None: if isinstance(self.dataloader, CombinedLoader): self.dataloader.reset() if isinstance(self.dataloader, DataLoader): - if isinstance(self.dataloader._iterator, _MultiProcessingDataLoaderIter): + if hasattr(self.dataloader, "_iterator") and isinstance( + self.dataloader._iterator, _MultiProcessingDataLoaderIter + ): self.dataloader._iterator._shutdown_workers() self.dataloader._iterator = None self.dataloader_iter = None From 6b2a9f6e40ca8b67126cba044f85a93dbc598c58 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 27 Oct 2021 08:48:40 +0100 Subject: [PATCH 29/31] remove dataloader delete --- pytorch_lightning/utilities/fetching.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index de0fbbfcefc4d..22d6fef852ea8 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -218,7 +218,6 @@ def reset(self) -> None: def 
teardown(self) -> None: self.reset() - self.dataloader = None class DataFetcher(AbstractDataFetcher): From 189a1219e94f908cb4f4695c9b09d7d22dd66226 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 27 Oct 2021 13:19:36 +0100 Subject: [PATCH 30/31] remove un-necessary --- pytorch_lightning/loops/fit_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index b6e7b168187eb..e368c8361dde9 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -247,7 +247,6 @@ def on_run_end(self) -> None: def teardown(self) -> None: self.epoch_loop.teardown() - self.trainer._data_connector.teardown() def _should_accumulate(self) -> bool: """Whether the gradients should be accumulated.""" From 257eb3a4eb0345116dfcb048abd0e7b2fcd96357 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 29 Oct 2021 13:20:10 +0100 Subject: [PATCH 31/31] resolve on comments --- pytorch_lightning/trainer/supporters.py | 4 +-- pytorch_lightning/utilities/fetching.py | 8 ++---- tests/loops/test_evaluation_loop.py | 33 +--------------------- tests/loops/test_loops.py | 37 ++++++++++++++++++++++++- tests/loops/test_training_loop.py | 35 +---------------------- 5 files changed, 42 insertions(+), 75 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 059055db33eda..816f4da38f5b9 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -492,7 +492,7 @@ def __len__(self) -> int: return self._calc_num_batches(self.loaders) @staticmethod - def _reset(dataloader) -> None: + def _shutdown_workers_and_reset_iterator(dataloader) -> None: if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter): dataloader._iterator._shutdown_workers() dataloader._iterator = None @@ -501,7 +501,7 @@ def reset(self): if self._iterator: self._iterator._loader_iters = None if self.loaders is not None: - apply_to_collection(self.loaders, DataLoader, self._reset) + apply_to_collection(self.loaders, DataLoader, self._shutdown_workers_and_reset_iterator) self._iterator = None diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 22d6fef852ea8..fd9baf3e9c4f1 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -20,7 +20,7 @@ from typing import Any, Callable, Generator, List, Optional, Tuple import torch -from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader +from torch.utils.data.dataloader import DataLoader import pytorch_lightning as pl from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator @@ -209,11 +209,7 @@ def reset(self) -> None: if isinstance(self.dataloader, CombinedLoader): self.dataloader.reset() if isinstance(self.dataloader, DataLoader): - if hasattr(self.dataloader, "_iterator") and isinstance( - self.dataloader._iterator, _MultiProcessingDataLoaderIter - ): - self.dataloader._iterator._shutdown_workers() - self.dataloader._iterator = None + CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader) self.dataloader_iter = None def teardown(self) -> None: diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index 61eb7ed21eea9..d6b2c15553fb9 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -13,9 +13,8 @@ # limitations under the License. 
from unittest import mock -import pytest import torch -from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader +from torch.utils.data.dataloader import DataLoader from pytorch_lightning import Trainer from pytorch_lightning.loops import EvaluationEpochLoop @@ -131,33 +130,3 @@ def on_advance_end(self): trainer.test_loop.connect(TestLoop()) trainer.test(model) assert did_assert - - -@RunIf(min_torch="1.8.0") -@pytest.mark.parametrize("persistent_workers", (True, False)) -def test_evaluation_workers_are_shutdown(tmpdir, persistent_workers): - # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` - # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance - has_shutdown_workers = False - - class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter): - def _shutdown_workers(self): - nonlocal has_shutdown_workers - has_shutdown_workers = True - super()._shutdown_workers() - - class TestDataLoader(DataLoader): - def _get_iterator(self): - if self.num_workers == 0: - return super()._get_iterator() - else: - self.check_worker_number_rationality() - return _TestMultiProcessingDataLoaderIter(self) - - train_dataloader = DataLoader(RandomDataset(32, 64), num_workers=0) - val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2) - trainer.fit(model, train_dataloader, val_dataloader) - assert has_shutdown_workers - assert val_dataloader._iterator is None diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index a1efa838e9e64..dd390ab4939d5 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -20,7 +20,7 @@ import pytest import torch -from torch.utils.data import DataLoader +from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader from pl_examples.bug_report_model import RandomDataset from pytorch_lightning import LightningModule, Trainer @@ -909,3 +909,38 @@ def val_dataloader(self): expected[val_batch_progress]["total"]["ready"] += 1 expected[val_batch_progress]["total"]["started"] += 1 assert state_dict_after_restart[val_batch_progress] == expected[val_batch_progress] + + +@RunIf(min_torch="1.8.0") +@pytest.mark.parametrize("persistent_workers", (True, False)) +def test_workers_are_shutdown(tmpdir, persistent_workers): + # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` + # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance + + class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter): + def __init__(self, *args, dataloader: DataLoader, **kwargs): + super().__init__(*args, **kwargs) + self.dataloader = dataloader + + def _shutdown_workers(self): + setattr(self.dataloader, "has_shutdown_workers", True) + super()._shutdown_workers() + + class TestDataLoader(DataLoader): + def _get_iterator(self): + if self.num_workers == 0: + return super()._get_iterator() + else: + self.check_worker_number_rationality() + return _TestMultiProcessingDataLoaderIter(self, dataloader=self) + + train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) + val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2) + 
trainer.fit(model, train_dataloader, val_dataloader) + assert train_dataloader.has_shutdown_workers + assert val_dataloader.has_shutdown_workers + assert train_dataloader._iterator is None + assert val_dataloader._iterator is None diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index bc387d71761c3..86801f56266c6 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -13,11 +13,9 @@ # limitations under the License. import pytest import torch -from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader from pytorch_lightning import seed_everything, Trainer -from tests.helpers import BoringModel, RandomDataset -from tests.helpers.runif import RunIf +from tests.helpers import BoringModel def test_outputs_format(tmpdir): @@ -143,34 +141,3 @@ def training_step_end(self, outputs): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) trainer.fit(model) - - -@RunIf(min_torch="1.8.0") -@pytest.mark.parametrize("persistent_workers", (True, False)) -def test_training_loop_workers_are_shutdown(tmpdir, persistent_workers): - # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` - # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance - - has_shutdown_workers = False - - class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter): - def _shutdown_workers(self): - nonlocal has_shutdown_workers - has_shutdown_workers = True - super()._shutdown_workers() - - class TestDataLoader(DataLoader): - def _get_iterator(self): - if self.num_workers == 0: - return super()._get_iterator() - else: - self.check_worker_number_rationality() - return _TestMultiProcessingDataLoaderIter(self) - - train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) - - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=0, max_epochs=2) - trainer.fit(model, train_dataloader) - assert has_shutdown_workers - assert train_dataloader._iterator is None
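
For reference, below is a minimal standalone sketch of the technique this series converges on: explicitly shutting down a DataLoader's persistent workers and clearing its cached iterator instead of waiting for garbage collection. It is not part of any patch above; it assumes PyTorch >= 1.7 (for `persistent_workers`) and leans on the same private internals the patches themselves touch (`DataLoader._iterator` and `_MultiProcessingDataLoaderIter._shutdown_workers`). The dataset and helper names here are illustrative only.

    import torch
    from torch.utils.data import DataLoader, Dataset
    from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter


    class RandomDataset(Dataset):
        """Small stand-in dataset, analogous to the `RandomDataset` helper used in the tests."""

        def __init__(self, size: int, length: int):
            self.data = torch.randn(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)


    def shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
        # Mirrors the shape of `CombinedLoader._shutdown_workers_and_reset_iterator` in the final patch:
        # if the loader holds a multiprocessing iterator (i.e. persistent workers are alive),
        # terminate the worker processes explicitly, then drop the cached iterator.
        if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
            dataloader._iterator._shutdown_workers()
        dataloader._iterator = None


    if __name__ == "__main__":
        loader = DataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=True)
        for _ in loader:  # the first full pass spawns the worker and keeps it alive on `loader._iterator`
            pass
        shutdown_workers_and_reset_iterator(loader)
        assert loader._iterator is None

With `persistent_workers=False` the sketch still runs: the DataLoader never caches an iterator on `_iterator`, so the helper is effectively a no-op.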