From 78efb8bca0822e73c48c49fb88644511c88d355b Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 17:21:08 +0200 Subject: [PATCH 1/4] resolve issues --- pytorch_lightning/trainer/supporters.py | 46 +++++++++++++++++++++---- tests/trainer/test_supporters.py | 42 +++++++++++++++++++++- 2 files changed, 81 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 058202fea42cc..4bf7c32e0a994 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -14,8 +14,9 @@ import os from collections.abc import Iterable, Iterator, Mapping, Sequence +from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import torch from torch import Tensor @@ -170,12 +171,29 @@ def to_disk(self) -> None: torch.save(outputs, fp) +@dataclass +class SharedCycleIteratorState: + + mode: str = "max_size_cycle" + dataloaders: List[DataLoader] = field(default_factory=lambda: []) + has_finished: Dict[int, bool] = field(default_factory=lambda: {}) + + def reset(self) -> None: + for dataloader in self.dataloaders: + self.has_finished[id(dataloader)] = False + + @property + def done(self) -> bool: + decision_fn = all if self.mode == "max_size_cycle" else any + return decision_fn(self.has_finished.values()) + + class CycleIterator: """ Iterator for restarting a dataloader if it runs out of samples """ - def __init__(self, loader: Any, length: Optional[int] = None): + def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycleIteratorState = None): """ Args: loader: the loader to restart for cyclic (and optionally infinite) sampling @@ -185,6 +203,10 @@ def __init__(self, loader: Any, length: Optional[int] = None): if length is None: length = float("inf") + self.state = state + if state: + state.dataloaders.append(loader) + self.length = length self.loader = loader self._loader_iter = None @@ -205,22 +227,30 @@ def __next__(self) -> Any: """ Fetches the next batch from internal dataloader and restarts it if necessary - Returns: Any: the resulting batch - Raises: StopIteration: if more then :attr:`length` batches have been returned """ # Note: if self.length is `inf`, then the iterator will never stop - if self.counter >= self.__len__(): + if self.counter >= self.__len__() or (self.state is not None and self.state.done): raise StopIteration try: return next(self._loader_iter) except StopIteration: + + if self.state is not None: + # inform the shared state this loader has completed + self.state.has_finished[id(self.loader)] = True + + # check if iteration should be stopped. + if self.state.done: + raise StopIteration + self._loader_iter = iter(self.loader) + return next(self._loader_iter) finally: @@ -468,10 +498,14 @@ def _wrap_loaders_max_size_cycle(self) -> Any: # multiple loaders if isinstance(self.loaders, (Sequence, Mapping)): + state = SharedCycleIteratorState() + self.loaders = apply_to_collection( - self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping) + self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping) ) + state.reset() + def __iter__(self) -> Any: """ Create and return an iterator, `CombinedLoaderIterator`, for the combined loader. 
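For orientation, here is a minimal, self-contained sketch of the coordination mechanism the supporters.py diff above introduces: every CycleIterator registers its loader with a SharedCycleIteratorState, marks itself finished when its underlying iterator raises StopIteration, and the shared `done` property decides (via `all` for "max_size_cycle", `any` for "min_size") whether the whole group should stop instead of cycling forever. The class names follow the diff, but the code is a simplified illustration rather than the Lightning implementation, and it omits the refinements added in PATCH 4 below (the `has_reset` guard and the single-dataloader special case).

    from dataclasses import dataclass, field
    from typing import Any, Dict, Iterator, List


    @dataclass
    class SharedCycleIteratorState:
        mode: str = "max_size_cycle"
        dataloaders: List[Any] = field(default_factory=list)
        has_finished: Dict[int, bool] = field(default_factory=dict)

        def reset(self) -> None:
            # called once all loaders have been registered
            for dataloader in self.dataloaders:
                self.has_finished[id(dataloader)] = False

        @property
        def done(self) -> bool:
            # "max_size_cycle": stop only once *all* loaders finished a full pass.
            # "min_size": stop as soon as *any* loader finished.
            decision_fn = all if self.mode == "max_size_cycle" else any
            return decision_fn(self.has_finished.values())


    class CycleIterator:
        def __init__(self, loader: Any, state: SharedCycleIteratorState) -> None:
            self.loader = loader
            self.state = state
            state.dataloaders.append(loader)
            self._loader_iter: Iterator = iter([])

        def __iter__(self) -> "CycleIterator":
            self._loader_iter = iter(self.loader)
            return self

        def __next__(self) -> Any:
            if self.state.done:
                raise StopIteration
            try:
                return next(self._loader_iter)
            except StopIteration:
                # this loader is exhausted: record it, then either stop the whole
                # group (if the shared state says so) or restart it and keep cycling
                self.state.has_finished[id(self.loader)] = True
                if self.state.done:
                    raise
                self._loader_iter = iter(self.loader)
                return next(self._loader_iter)


    if __name__ == "__main__":
        state = SharedCycleIteratorState(mode="max_size_cycle")
        short = CycleIterator(range(3), state)
        longer = CycleIterator(range(5), state)
        state.reset()
        batches = list(zip(short, longer))
        print(len(batches))  # 5 -> the shorter loader is cycled until the longer one ends
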
diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index a4bb622477d36..8b8b7d808493a 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -20,7 +20,7 @@ from torch.utils.data import DataLoader, TensorDataset from torch.utils.data.dataset import Dataset, IterableDataset from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import Sampler +from torch.utils.data.sampler import Sampler, SequentialSampler from pytorch_lightning import Trainer from pytorch_lightning.trainer.supporters import ( @@ -59,6 +59,7 @@ def test_tensor_running_accum_reset(): def test_cycle_iterator(): """Test the cycling function of `CycleIterator`""" + iterator = CycleIterator(range(100), 1000) assert len(iterator) == 1000 for idx, item in enumerate(iterator): @@ -216,6 +217,45 @@ def test_combined_loader_sequence_min_size(): assert idx == len(combined_loader) - 1 +class TestIterableDataset(IterableDataset): + def __init__(self, size: int = 10): + self.size = size + + def __iter__(self): + self.sampler = SequentialSampler(range(self.size)) + self.sampler_iter = iter(self.sampler) + return self + + def __next__(self): + return next(self.sampler_iter) + + +@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"]) +@pytest.mark.parametrize("use_multiple_dataloaders", [False, True]) +def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders): + """Test `CombinedLoader` of mode 'min_size' given sequence loaders""" + if use_multiple_dataloaders: + loaders = [ + torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2), + torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2), + ] + else: + loaders = [ + torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2), + ] + + combined_loader = CombinedLoader(loaders, mode) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, Sequence) + assert len(item) == 2 if use_multiple_dataloaders else 1 + + if mode == "max_size_cycle": + assert combined_loader.loaders[0].state.done + expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5 + assert (expected - 1) == idx, (mode, use_multiple_dataloaders) + + def test_combined_loader_sequence_max_size_cycle(): """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders""" loaders = [ From 4682188c92253820ace27ea470e5ad69d942ab17 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 17:30:34 +0200 Subject: [PATCH 2/4] add an extra test --- tests/trainer/test_supporters.py | 33 ++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 8b8b7d808493a..14e745352428b 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -256,6 +256,39 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader assert (expected - 1) == idx, (mode, use_multiple_dataloaders) +@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]]) +def test_combined_loader_sequence_with_map_and_iterable(lengths): + class MyIterableDataset(IterableDataset): + def __init__(self, size: int = 10): + self.size = size + + def __iter__(self): + self.sampler = SequentialSampler(range(self.size)) + self.iter_sampler = iter(self.sampler) + return self + + def __next__(self): + return next(self.iter_sampler) + + class MyMapDataset(Dataset): + def __init__(self, size: int = 10): + self.size = size + + def 
__getitem__(self, index): + return index + + def __len__(self): + return self.size + + x, y = lengths + loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))] + dataloader = CombinedLoader(loaders, mode="max_size_cycle") + counter = 0 + for _ in dataloader: + counter += 1 + assert counter == max(x, y) + + def test_combined_loader_sequence_max_size_cycle(): """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders""" loaders = [ From d77c77f0aa1133e3ce7e0adc5e9c4548ffb8975d Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 17:35:08 +0200 Subject: [PATCH 3/4] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 384fb6a20e1a0..0b86b175e7c4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault-tolerant training: * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) + * Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889)) ### Changed @@ -156,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861)) +- Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889)) + + ## [1.4.0] - 2021-07-27 ### Added From 5e917d141dafa4c9b3eb45ff7d78f9515b3bffde Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 18:53:46 +0200 Subject: [PATCH 4/4] update --- pytorch_lightning/trainer/supporters.py | 29 ++++++++++++++++--------- tests/trainer/test_supporters.py | 7 +++++- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 4bf7c32e0a994..6c87ac6c3e9b6 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -177,13 +177,19 @@ class SharedCycleIteratorState: mode: str = "max_size_cycle" dataloaders: List[DataLoader] = field(default_factory=lambda: []) has_finished: Dict[int, bool] = field(default_factory=lambda: {}) + has_reset: bool = False def reset(self) -> None: for dataloader in self.dataloaders: self.has_finished[id(dataloader)] = False + self.has_reset = True @property def done(self) -> bool: + if not self.has_reset: + raise MisconfigurationException("Please, call reset once all dataloaders have been added.") + if len(self.dataloaders) == 1: + return False decision_fn = all if self.mode == "max_size_cycle" else any return decision_fn(self.has_finished.values()) @@ -203,10 +209,15 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle if length is None: length = float("inf") - self.state = state - if state: + if not state: + state = SharedCycleIteratorState() + state.dataloaders.append(loader) + state.reset() + else: state.dataloaders.append(loader) + self.state = state + self.length = length self.loader = loader self._loader_iter = None @@ -233,7 +244,7 @@ def __next__(self) -> Any: StopIteration: if more then :attr:`length` batches have been returned """ # Note: if self.length is `inf`, then the iterator will never stop - if self.counter >= self.__len__() or (self.state is not None and 
self.state.done): + if self.counter >= self.__len__() or self.state.done: raise StopIteration try: @@ -241,16 +252,14 @@ def __next__(self) -> Any: except StopIteration: - if self.state is not None: - # inform the shared state this loader has completed - self.state.has_finished[id(self.loader)] = True + # inform the shared state this loader has completed + self.state.has_finished[id(self.loader)] = True - # check if iteration should be stopped. - if self.state.done: - raise StopIteration + # check if iteration should be stopped. + if self.state.done: + raise StopIteration self._loader_iter = iter(self.loader) - return next(self._loader_iter) finally: diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 14e745352428b..e8e5d0be10c35 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -246,12 +246,17 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader combined_loader = CombinedLoader(loaders, mode) + has_break = False + for idx, item in enumerate(combined_loader): assert isinstance(item, Sequence) assert len(item) == 2 if use_multiple_dataloaders else 1 + if not use_multiple_dataloaders and idx == 4: + has_break = True + break if mode == "max_size_cycle": - assert combined_loader.loaders[0].state.done + assert combined_loader.loaders[0].state.done == (not has_break) expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5 assert (expected - 1) == idx, (mode, use_multiple_dataloaders)
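
Finally, a condensed usage sketch of the behaviour the new tests pin down: two dataloaders over iterable-style datasets of different lengths, combined with mode="max_size_cycle", now terminate once the longer loader is exhausted (10 batches here) instead of looping indefinitely, while "min_size" stops at the shorter one (5 batches). RangeIterableDataset is an illustrative stand-in for the TestIterableDataset defined in the tests; the CombinedLoader call matches how the tests use it, but the snippet itself is not part of the patch.

    from torch.utils.data import DataLoader, IterableDataset
    from torch.utils.data.sampler import SequentialSampler

    from pytorch_lightning.trainer.supporters import CombinedLoader


    class RangeIterableDataset(IterableDataset):
        """Yields 0..size-1 through a SequentialSampler, like the test dataset above."""

        def __init__(self, size: int) -> None:
            self.size = size

        def __iter__(self):
            return iter(SequentialSampler(range(self.size)))


    loaders = [
        DataLoader(RangeIterableDataset(10), batch_size=2),  # 5 batches
        DataLoader(RangeIterableDataset(20), batch_size=2),  # 10 batches
    ]

    combined = CombinedLoader(loaders, mode="max_size_cycle")
    num_batches = sum(1 for _ in combined)
    print(num_batches)  # 10; with mode="min_size" the loop would stop after 5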