4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fault-tolerant training:
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
* Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))


### Changed
@@ -156,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))


- Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))


## [1.4.0] - 2021-07-27

### Added
55 changes: 49 additions & 6 deletions pytorch_lightning/trainer/supporters.py
@@ -14,8 +14,9 @@

import os
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -170,12 +171,35 @@ def to_disk(self) -> None:
torch.save(outputs, fp)


@dataclass
class SharedCycleIteratorState:

mode: str = "max_size_cycle"
Comment on lines +175 to +177

Contributor suggested change (adding a docstring):

class SharedCycleIteratorState:
    """A state shared between all CycleIterators in a CombinedLoader.

    With a shared state, the iterators can decide to terminate based on the state of all others.
    If the mode is *max_size_cycle*, all iterators need to have finished before the combined loading is considered
    finished, and otherwise any iterator finishing early will lead to all iterators ending early.
    """

    mode: str = "max_size_cycle"

"What do you think of these docs? Did I get it right?"

dataloaders: List[DataLoader] = field(default_factory=lambda: [])
has_finished: Dict[int, bool] = field(default_factory=lambda: {})
has_reset: bool = False

def reset(self) -> None:
for dataloader in self.dataloaders:
self.has_finished[id(dataloader)] = False
self.has_reset = True

@property
def done(self) -> bool:
if not self.has_reset:
raise MisconfigurationException("Please, call reset once all dataloaders have been added.")
Contributor suggested change:

raise MisconfigurationException("Please call reset once all dataloaders have been added.")

"with the comma in there it sounds a bit passive-aggressive 🤣"

if len(self.dataloaders) == 1:
return False
decision_fn = all if self.mode == "max_size_cycle" else any
return decision_fn(self.has_finished.values())


class CycleIterator:
"""
Iterator for restarting a dataloader if it runs out of samples
"""

def __init__(self, loader: Any, length: Optional[int] = None):
def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycleIteratorState = None):
"""
Args:
loader: the loader to restart for cyclic (and optionally infinite) sampling
@@ -185,6 +209,15 @@ def __init__(self, loader: Any, length: Optional[int] = None):
if length is None:
length = float("inf")

if not state:
state = SharedCycleIteratorState()
state.dataloaders.append(loader)
state.reset()
else:
state.dataloaders.append(loader)

self.state = state

self.length = length
self.loader = loader
self._loader_iter = None
@@ -205,21 +238,27 @@ def __next__(self) -> Any:
"""
Fetches the next batch from internal dataloader and restarts
it if necessary

Returns:
Any: the resulting batch

Raises:
StopIteration: if more than :attr:`length` batches have been returned
"""
# Note: if self.length is `inf`, then the iterator will never stop
if self.counter >= self.__len__():
if self.counter >= self.__len__() or self.state.done:
raise StopIteration

try:
return next(self._loader_iter)

except StopIteration:

# inform the shared state this loader has completed
self.state.has_finished[id(self.loader)] = True

# check if iteration should be stopped.
if self.state.done:
raise StopIteration

self._loader_iter = iter(self.loader)
return next(self._loader_iter)

@@ -468,10 +507,14 @@ def _wrap_loaders_max_size_cycle(self) -> Any:

# multiple loaders
if isinstance(self.loaders, (Sequence, Mapping)):
state = SharedCycleIteratorState()

self.loaders = apply_to_collection(
self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping)
self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
)

state.reset()

def __iter__(self) -> Any:
"""
Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.
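For reference, a minimal sketch (not part of this diff) of how the shared state is wired up by hand, mirroring what `_wrap_loaders_max_size_cycle` does above. The two `DataLoader`s and the `length=8` cap are illustrative assumptions; the import path follows the test file below.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.trainer.supporters import CycleIterator, SharedCycleIteratorState

loader_a = DataLoader(TensorDataset(torch.arange(4.0)), batch_size=1)  # 4 batches
loader_b = DataLoader(TensorDataset(torch.arange(8.0)), batch_size=1)  # 8 batches

# one state object is shared by every CycleIterator in the collection
state = SharedCycleIteratorState(mode="max_size_cycle")
cycle_a = CycleIterator(loader_a, length=8, state=state)
cycle_b = CycleIterator(loader_b, length=8, state=state)
state.reset()  # call reset once all dataloaders have been added

# In "max_size_cycle" mode, `state.done` only becomes True once every loader has
# raised StopIteration at least once, so the shorter loader keeps cycling until
# the longest one is exhausted: 8 joint batches here instead of looping forever.
assert len(list(zip(cycle_a, cycle_b))) == 8
```

In `"min_size"` mode the same wiring stops as soon as the shorter loader finishes, since the `done` property switches from `all` to `any` over the per-loader flags.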
80 changes: 79 additions & 1 deletion tests/trainer/test_supporters.py
@@ -20,7 +20,7 @@
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import Sampler
from torch.utils.data.sampler import Sampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import (
@@ -59,6 +59,7 @@ def test_tensor_running_accum_reset():

def test_cycle_iterator():
"""Test the cycling function of `CycleIterator`"""

iterator = CycleIterator(range(100), 1000)
assert len(iterator) == 1000
for idx, item in enumerate(iterator):
@@ -216,6 +217,83 @@ def test_combined_loader_sequence_min_size():
assert idx == len(combined_loader) - 1


class TestIterableDataset(IterableDataset):
def __init__(self, size: int = 10):
self.size = size

def __iter__(self):
self.sampler = SequentialSampler(range(self.size))
self.sampler_iter = iter(self.sampler)
return self

def __next__(self):
return next(self.sampler_iter)


@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"])
@pytest.mark.parametrize("use_multiple_dataloaders", [False, True])
def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
"""Test `CombinedLoader` of mode 'min_size' given sequence loaders"""
if use_multiple_dataloaders:
loaders = [
torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2),
]
else:
loaders = [
torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
]

combined_loader = CombinedLoader(loaders, mode)

has_break = False

for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
assert len(item) == (2 if use_multiple_dataloaders else 1)
if not use_multiple_dataloaders and idx == 4:
has_break = True
break

if mode == "max_size_cycle":
assert combined_loader.loaders[0].state.done == (not has_break)
expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5
assert (expected - 1) == idx, (mode, use_multiple_dataloaders)


@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
def test_combined_loader_sequence_with_map_and_iterable(lengths):
class MyIterableDataset(IterableDataset):
def __init__(self, size: int = 10):
self.size = size

def __iter__(self):
self.sampler = SequentialSampler(range(self.size))
self.iter_sampler = iter(self.sampler)
return self

def __next__(self):
return next(self.iter_sampler)

class MyMapDataset(Dataset):
def __init__(self, size: int = 10):
self.size = size

def __getitem__(self, index):
return index

def __len__(self):
return self.size

x, y = lengths
loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
dataloader = CombinedLoader(loaders, mode="max_size_cycle")
counter = 0
for _ in dataloader:
counter += 1
assert counter == max(x, y)


def test_combined_loader_sequence_max_size_cycle():
"""Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders"""
loaders = [
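To round off, a user-facing sketch (also not part of this PR) of the behaviour that `test_combined_loader_sequence_with_map_and_iterable` pins down; the `RangeIterable` and `RangeMap` helpers are hypothetical stand-ins for the datasets defined in the test.

```python
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

from pytorch_lightning.trainer.supporters import CombinedLoader


class RangeIterable(IterableDataset):
    """Iterable-style dataset without __len__, like the iterable datasets in the tests."""

    def __init__(self, size: int):
        self.size = size

    def __iter__(self):
        return iter(range(self.size))


class RangeMap(Dataset):
    """Map-style dataset with a known length."""

    def __init__(self, size: int):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return index


loaders = [DataLoader(RangeIterable(4)), DataLoader(RangeMap(6))]
combined = CombinedLoader(loaders, mode="max_size_cycle")

# With the shared state in place, iteration stops after max(4, 6) == 6 batches
# (the shorter iterable loader is cycled) instead of looping indefinitely.
assert sum(1 for _ in combined) == 6
```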