del iterator on_run_end() (#9915)
cowwoc committed Oct 29, 2021
1 parent e4eb61d · commit a967b6e
Showing 9 changed files with 63 additions and 10 deletions.

CHANGELOG.md (2 changes: 1 addition & 1 deletion)
@@ -595,7 +595,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))
 
 
-- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))
+- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386)) ([#9915](https://github.com/PyTorchLightning/pytorch-lightning/pull/9915))
 
 
 - Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))

pytorch_lightning/loops/dataloader/evaluation_loop.py (4 changes: 3 additions & 1 deletion)
@@ -101,7 +101,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
 
         dataloader_idx: int = self.current_dataloader_idx
         dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
-        dataloader = self.trainer._data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
+        self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
+            dataloader, dataloader_idx=dataloader_idx
+        )
         dl_max_batches = self._max_batches[dataloader_idx]
 
         dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
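
The key change here is that the loop now keeps a reference to the profiled fetcher in `self.data_fetcher` instead of treating it as a local, so the iterator can be freed once evaluation finishes. A minimal sketch of how such a handle might be released at teardown; the hook body below is illustrative, not the verbatim Lightning implementation:

class EvaluationLoopSketch:
    """Illustrative only: releasing the fetcher handle stored by advance()."""

    data_fetcher = None  # set in advance(), as in the hunk above

    def teardown(self) -> None:
        if self.data_fetcher is not None:
            # AbstractDataFetcher.teardown() calls reset(), which shuts down
            # multiprocessing workers (see the fetching.py hunk below)
            self.data_fetcher.teardown()
            self.data_fetcher = None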

pytorch_lightning/loops/epoch/training_epoch_loop.py (2 changes: 0 additions & 2 deletions)
@@ -306,8 +306,6 @@ def on_run_end(self) -> None:
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=True)
 
-        self._dataloader_iter = None
-
         # if fault tolerant is enabled and process has been notified, exit.
         self.trainer._exit_gracefully_on_signal()
 
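Removing `self._dataloader_iter = None` undoes the approach from #9386: dropping the reference left worker cleanup to `__del__` and the garbage collector, while this commit moves cleanup into the explicit shutdown paths in supporters.py and fetching.py below. A rough, assumed illustration of why reference-dropping alone is fragile:

import gc

from torch.utils.data import DataLoader

if __name__ == "__main__":  # guard needed on platforms that spawn workers
    loader = DataLoader(range(8), num_workers=1)
    it = iter(loader)  # _MultiProcessingDataLoaderIter: starts a worker process
    it = None          # worker teardown now depends on __del__ timing
    gc.collect()       # reference cycles can delay cleanup until a collection pass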

pytorch_lightning/loops/fit_loop.py (4 changes: 3 additions & 1 deletion)
@@ -209,7 +209,9 @@ def on_advance_start(self) -> None:
             self.trainer.reset_train_dataloader(model)
         self._is_fresh_start_epoch = False
 
-        if callable(getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)):
+        if self.trainer.train_dataloader is not None and callable(
+            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
+        ):
             # set seed for distributed sampler (enables shuffling for each epoch)
             self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
 
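The new `is not None` guard matters because `trainer.train_dataloader` may not be attached at this point (presumably possible once iterators are freed and reloaded between epochs). For background, `set_epoch` is how `DistributedSampler` reshuffles differently each epoch; a minimal sketch, using explicit `num_replicas=1, rank=0` so it runs without a process group:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
for epoch in range(2):
    sampler.set_epoch(epoch)  # seeds the shuffle with the epoch number
    for (batch,) in DataLoader(dataset, sampler=sampler, batch_size=4):
        pass  # sample order differs between the two epochs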

pytorch_lightning/trainer/supporters.py (15 changes: 14 additions & 1 deletion)
@@ -19,7 +19,7 @@
 
 import torch
 from torch.utils.data import Dataset
-from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
 
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
@@ -491,6 +491,19 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:
     def __len__(self) -> int:
         return self._calc_num_batches(self.loaders)
 
+    @staticmethod
+    def _shutdown_workers_and_reset_iterator(dataloader) -> None:
+        if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
+            dataloader._iterator._shutdown_workers()
+        dataloader._iterator = None
+
+    def reset(self):
+        if self._iterator:
+            self._iterator._loader_iters = None
+        if self.loaders is not None:
+            apply_to_collection(self.loaders, DataLoader, self._shutdown_workers_and_reset_iterator)
+        self._iterator = None
+
 
 class CombinedLoaderIterator:
     """Custom Iterator returning data from multiple loaders, and allows sampling in parallel."""
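The new `reset` walks every wrapped `DataLoader` via `apply_to_collection` and shuts down cached multiprocessing iterators deterministically rather than waiting for garbage collection. A small usage sketch of the static helper on a bare loader; the toy dataset and `persistent_workers=True` are assumptions chosen so that the `DataLoader` caches its iterator on `_iterator`:

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader

if __name__ == "__main__":  # guard needed on platforms that spawn workers
    loader = DataLoader(range(64), num_workers=2, persistent_workers=True)
    next(iter(loader))  # starts workers; the iterator is cached on loader._iterator

    CombinedLoader._shutdown_workers_and_reset_iterator(loader)
    assert loader._iterator is None  # workers shut down, cached iterator cleared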

pytorch_lightning/utilities/fetching.py (6 changes: 5 additions & 1 deletion)
@@ -204,9 +204,13 @@ def __next__(self):
 
     def reset(self) -> None:
         self.batches: List = []
-        self.dataloader: Optional[Iterable]
         self.fetched: int = 0
         self.done: bool = False
+        if isinstance(self.dataloader, CombinedLoader):
+            self.dataloader.reset()
+        if isinstance(self.dataloader, DataLoader):
+            CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
+        self.dataloader_iter = None
 
     def teardown(self) -> None:
         self.reset()
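Together with the supporters.py change, `reset` now routes cleanup by input type before dropping its own iterator reference. The resulting call chain, summarized as a sketch:

# Cleanup chain when a fetcher is torn down, per this diff:
#
# fetcher.teardown()
#   -> fetcher.reset()
#        -> self.dataloader.reset()       # if the input is a CombinedLoader
#        -> CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
#                                          # if the input is a plain DataLoader
#        -> self.dataloader_iter = None    # always, afterwards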

tests/loops/test_evaluation_loop.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@
 from unittest import mock
 
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import DataLoader
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loops import EvaluationEpochLoop

tests/loops/test_loops.py (37 changes: 36 additions & 1 deletion)
@@ -20,7 +20,7 @@
 
 import pytest
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader
 
 from pl_examples.bug_report_model import RandomDataset
 from pytorch_lightning import LightningModule, Trainer
@@ -909,3 +909,38 @@ def val_dataloader(self):
     expected[val_batch_progress]["total"]["ready"] += 1
     expected[val_batch_progress]["total"]["started"] += 1
     assert state_dict_after_restart[val_batch_progress] == expected[val_batch_progress]
+
+
+@RunIf(min_torch="1.8.0")
+@pytest.mark.parametrize("persistent_workers", (True, False))
+def test_workers_are_shutdown(tmpdir, persistent_workers):
+    # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
+    # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance
+
+    class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
+        def __init__(self, *args, dataloader: DataLoader, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.dataloader = dataloader
+
+        def _shutdown_workers(self):
+            setattr(self.dataloader, "has_shutdown_workers", True)
+            super()._shutdown_workers()
+
+    class TestDataLoader(DataLoader):
+        def _get_iterator(self):
+            if self.num_workers == 0:
+                return super()._get_iterator()
+            else:
+                self.check_worker_number_rationality()
+                return _TestMultiProcessingDataLoaderIter(self, dataloader=self)
+
+    train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
+    val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
+
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2)
+    trainer.fit(model, train_dataloader, val_dataloader)
+    assert train_dataloader.has_shutdown_workers
+    assert val_dataloader.has_shutdown_workers
+    assert train_dataloader._iterator is None
+    assert val_dataloader._iterator is None
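
The `persistent_workers` parametrization covers both caching behaviors: with `True`, the `DataLoader` caches its iterator on `_iterator` and `CombinedLoader.reset` must shut it down; with `False`, no iterator is cached and shutdown happens when the fetcher releases its own reference. That reading of the two paths is an inference from the diff, not stated in it. To run only this test through pytest's Python entry point:

import pytest

pytest.main(["tests/loops/test_loops.py", "-k", "test_workers_are_shutdown", "-v"])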

tests/loops/test_training_loop.py (1 change: 0 additions & 1 deletion)
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import pytest
 import torch
 
