
del iterator on_run_end() #9915

Merged
merged 33 commits into master from del-iterators on Oct 29, 2021
33 commits
ee7ef64
del iterator on_run_end()
cowwoc Oct 13, 2021
c2b155a
del iterator on_run_end()
cowwoc Oct 13, 2021
4ac6c86
Don't bother deleting data_fetcher
cowwoc Oct 13, 2021
b587849
Updated changelog
cowwoc Oct 13, 2021
3ff2121
Fixed typo in changelog
cowwoc Oct 13, 2021
4437863
Added comment per PR review
cowwoc Oct 14, 2021
2f391a8
Added comment per PR review
cowwoc Oct 14, 2021
949092e
Add failing test
carmocca Oct 15, 2021
06e20a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2021
183b988
resolve iterator reference
tchaton Oct 25, 2021
20e7560
resolve iterator reference
tchaton Oct 25, 2021
3234730
update
tchaton Oct 25, 2021
deb4f25
update
tchaton Oct 25, 2021
590a262
merge
tchaton Oct 25, 2021
9032c2c
Use mock in tests
carmocca Oct 25, 2021
a673c6a
Remove evaluation test that does not fail in master
carmocca Oct 25, 2021
95cfcdf
update
tchaton Oct 26, 2021
07af739
update
tchaton Oct 26, 2021
4f00eab
drop dataloader
tchaton Oct 26, 2021
0da7b70
update
tchaton Oct 26, 2021
c81c056
add extra check
tchaton Oct 26, 2021
41bc58b
update
tchaton Oct 26, 2021
15da2b1
update
tchaton Oct 26, 2021
4c55761
update
tchaton Oct 26, 2021
1afa66e
update
tchaton Oct 26, 2021
bb6ec3b
delete iterator only on end
tchaton Oct 26, 2021
fb592a4
delete iterator only on end
tchaton Oct 26, 2021
b164568
update
tchaton Oct 26, 2021
33aa531
update
tchaton Oct 26, 2021
6b2a9f6
remove dataloader delete
tchaton Oct 27, 2021
189a121
remove un-necessary
tchaton Oct 27, 2021
f11a21d
Merge branch 'master' into del-iterators
tchaton Oct 27, 2021
257eb3a
resolve on comments
tchaton Oct 29, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -588,7 +588,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))


- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))
- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386)) ([#9915](https://github.com/PyTorchLightning/pytorch-lightning/pull/9915))


- Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))
4 changes: 3 additions & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -101,7 +101,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

        dataloader_idx: int = self.current_dataloader_idx
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
        dataloader = self.trainer._data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
        self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
            dataloader, dataloader_idx=dataloader_idx
        )
        dl_max_batches = self._max_batches[dataloader_idx]

        dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
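The substantive change in this hunk: the profiled dataloader wrapper is now also kept on the loop as `self.data_fetcher`, so a later teardown step has an explicit handle for freeing its iterator. Below is a minimal, self-contained sketch of that pattern (`LoopSketch`, `advance` and `teardown` are hypothetical names, not the Lightning API):

from typing import Iterable, Optional


class LoopSketch:
    """Toy loop that keeps a handle on its data fetcher so teardown can release it explicitly."""

    def __init__(self) -> None:
        self.data_fetcher: Optional[Iterable] = None

    def advance(self, dataloader: Iterable) -> None:
        # keep a reference to the wrapper that owns the (possibly multiprocessing) iterator
        self.data_fetcher = dataloader
        for _ in self.data_fetcher:
            pass

    def teardown(self) -> None:
        # free the iterator deterministically instead of waiting for garbage collection
        reset = getattr(self.data_fetcher, "reset", None)
        if callable(reset):
            reset()
        self.data_fetcher = None


loop = LoopSketch()
loop.advance(range(3))
loop.teardown()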
2 changes: 0 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -306,8 +306,6 @@ def on_run_end(self) -> None:
        if self._num_ready_batches_reached():
            self.update_lr_schedulers("epoch", update_plateau_schedulers=True)

        self._dataloader_iter = None

        # if fault tolerant is enabled and process has been notified, exit.
        self.trainer._exit_gracefully_on_signal()

4 changes: 3 additions & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -209,7 +209,9 @@ def on_advance_start(self) -> None:
            self.trainer.reset_train_dataloader(model)
        self._is_fresh_start_epoch = False

        if callable(getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)):
        if self.trainer.train_dataloader is not None and callable(
            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
        ):
            # set seed for distributed sampler (enables shuffling for each epoch)
            self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)

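The added is-not-None check matters because `trainer.train_dataloader` can still be unset at this point; `set_epoch` is then only forwarded when the sampler actually defines it. The same guard pattern in isolation, as a sketch (`maybe_set_epoch` is a hypothetical helper, not Lightning API):

def maybe_set_epoch(train_dataloader, current_epoch: int) -> None:
    # Skip silently when the dataloader has not been built yet, or when its sampler
    # has no set_epoch method (only samplers such as DistributedSampler define one).
    if train_dataloader is not None and callable(getattr(train_dataloader.sampler, "set_epoch", None)):
        # re-seed the sampler so shuffling differs between epochs
        train_dataloader.sampler.set_epoch(current_epoch)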
15 changes: 14 additions & 1 deletion pytorch_lightning/trainer/supporters.py
@@ -19,7 +19,7 @@

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
@@ -491,6 +491,19 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:
    def __len__(self) -> int:
        return self._calc_num_batches(self.loaders)

    @staticmethod
    def _shutdown_workers_and_reset_iterator(dataloader) -> None:
        if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
            dataloader._iterator._shutdown_workers()
        dataloader._iterator = None

    def reset(self):
        if self._iterator:
            self._iterator._loader_iters = None
        if self.loaders is not None:
            apply_to_collection(self.loaders, DataLoader, self._shutdown_workers_and_reset_iterator)
        self._iterator = None


class CombinedLoaderIterator:
    """Custom Iterator returning data from multiple loaders, and allows sampling in parallel."""
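Background for this hunk: a `DataLoader` created with `persistent_workers=True` caches its `_MultiProcessingDataLoaderIter` in `dataloader._iterator`, which keeps the worker processes alive until that iterator is shut down. The standalone sketch below (illustrative setup, not code from this PR) mirrors what `_shutdown_workers_and_reset_iterator` does for a single `DataLoader`; `reset()` then applies the same cleanup to every wrapped loader via `apply_to_collection`.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter


def shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
    # `_iterator` and `_shutdown_workers` are private PyTorch attributes; this is the
    # same mechanism the helper in the hunk above relies on.
    iterator = getattr(dataloader, "_iterator", None)
    if isinstance(iterator, _MultiProcessingDataLoaderIter):
        iterator._shutdown_workers()  # join/terminate the worker processes
    dataloader._iterator = None  # drop the cached iterator so it can be garbage collected


if __name__ == "__main__":
    loader = DataLoader(TensorDataset(torch.randn(8, 2)), num_workers=2, persistent_workers=True)
    next(iter(loader))  # spawns workers; persistent_workers keeps `_iterator` alive afterwards
    shutdown_workers_and_reset_iterator(loader)
    assert loader._iterator is None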
6 changes: 5 additions & 1 deletion pytorch_lightning/utilities/fetching.py
@@ -204,9 +204,13 @@ def __next__(self):

    def reset(self) -> None:
        self.batches: List = []
        self.dataloader: Optional[Iterable]
        self.fetched: int = 0
        self.done: bool = False
        if isinstance(self.dataloader, CombinedLoader):
            self.dataloader.reset()
        if isinstance(self.dataloader, DataLoader):
            CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
        self.dataloader_iter = None

    def teardown(self) -> None:
        self.reset()
2 changes: 1 addition & 1 deletion tests/loops/test_evaluation_loop.py
@@ -14,7 +14,7 @@
from unittest import mock

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.loops import EvaluationEpochLoop
37 changes: 36 additions & 1 deletion tests/loops/test_loops.py
@@ -20,7 +20,7 @@

import pytest
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader

from pl_examples.bug_report_model import RandomDataset
from pytorch_lightning import LightningModule, Trainer
@@ -909,3 +909,38 @@ def val_dataloader(self):
    expected[val_batch_progress]["total"]["ready"] += 1
    expected[val_batch_progress]["total"]["started"] += 1
    assert state_dict_after_restart[val_batch_progress] == expected[val_batch_progress]


@RunIf(min_torch="1.8.0")
@pytest.mark.parametrize("persistent_workers", (True, False))
def test_workers_are_shutdown(tmpdir, persistent_workers):
    # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
    # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance

    class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
        def __init__(self, *args, dataloader: DataLoader, **kwargs):
            super().__init__(*args, **kwargs)
            self.dataloader = dataloader

        def _shutdown_workers(self):
            setattr(self.dataloader, "has_shutdown_workers", True)
            super()._shutdown_workers()

    class TestDataLoader(DataLoader):
        def _get_iterator(self):
            if self.num_workers == 0:
                return super()._get_iterator()
            else:
                self.check_worker_number_rationality()
                return _TestMultiProcessingDataLoaderIter(self, dataloader=self)

    train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
    val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2)
    trainer.fit(model, train_dataloader, val_dataloader)
    assert train_dataloader.has_shutdown_workers
    assert val_dataloader.has_shutdown_workers
    assert train_dataloader._iterator is None
    assert val_dataloader._iterator is None
1 change: 0 additions & 1 deletion tests/loops/test_training_loop.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
