SequentialMode and dataloader_iter improvements (#16784)
carmocca committed Feb 16, 2023
1 parent ad698f0 commit 746c734
Showing 6 changed files with 106 additions and 47 deletions.
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646))


- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))
- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743), [#16784](https://github.com/Lightning-AI/lightning/pull/16784))

### Changed

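The CHANGELOG entry above is easiest to read alongside the usage added to `CombinedLoader`'s docstring later in this commit. A minimal sketch of consuming the new `"sequential"` mode (loader sizes are illustrative; the import path follows this diff):

```python
from torch.utils.data import DataLoader
from lightning.pytorch.trainer.supporters import CombinedLoader

iterables = {"a": DataLoader(range(6), batch_size=4), "b": DataLoader(range(15), batch_size=5)}
combined_loader = CombinedLoader(iterables, "sequential")

# The iterables are exhausted one after the other; each element is a
# (batch, batch_idx, dataloader_idx) triple and batch_idx restarts per iterable.
for batch, batch_idx, dataloader_idx in combined_loader:
    print(batch, batch_idx, dataloader_idx)
```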
11 changes: 6 additions & 5 deletions src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py
@@ -114,11 +114,12 @@ def advance(
Raises:
StopIteration: If the current batch is None
"""
if not isinstance(data_fetcher, _DataLoaderIterDataFetcher):
batch_idx = self.batch_progress.current.ready
batch = next(data_fetcher)
else:
batch_idx, batch = next(data_fetcher)
batch_idx = (
data_fetcher.fetched
if isinstance(data_fetcher, _DataLoaderIterDataFetcher)
else self.batch_progress.current.ready
)
batch = next(data_fetcher)
self.batch_progress.is_last_batch = data_fetcher.done

dataloader_idx = kwargs.get("dataloader_idx", 0)
7 changes: 2 additions & 5 deletions src/lightning/pytorch/loops/epoch/training_epoch_loop.py
@@ -186,11 +186,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
# we are going to train first so the val loop does not need to restart
self.val_loop.restarting = False

if not isinstance(data_fetcher, _DataLoaderIterDataFetcher):
batch_idx = self.batch_idx + 1
batch = next(data_fetcher)
else:
batch_idx, batch = next(data_fetcher)
batch_idx = data_fetcher.fetched if isinstance(data_fetcher, _DataLoaderIterDataFetcher) else self.batch_idx + 1
batch = next(data_fetcher)
self.batch_progress.is_last_batch = data_fetcher.done

trainer = self.trainer
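Both epoch loops now call `next(data_fetcher)` unconditionally and only switch where `batch_idx` comes from: for a `_DataLoaderIterDataFetcher`, it is the fetcher's `fetched` count, which is also the value the user-facing hook receives. A hedged sketch of the hook flavor these loops feed, based on the `training_step(self, dataloader_iter, batch_idx)` signature quoted in `fetchers.py` below (the module and loss are illustrative, not part of this commit):

```python
import torch
from lightning.pytorch import LightningModule


class ManualFetchModel(LightningModule):  # illustrative name
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, dataloader_iter, batch_idx):
        # With the `dataloader_iter` flavor, the batch is pulled manually;
        # `batch_idx` is supplied by the loop from the fetcher's bookkeeping.
        batch = next(dataloader_iter)
        loss = self.layer(batch).sum()
        return loss
```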
33 changes: 24 additions & 9 deletions src/lightning/pytorch/loops/fetchers.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple
from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union

from torch.utils.data.dataloader import DataLoader

from lightning.fabric.utilities.data import has_len
from lightning.pytorch.trainer.supporters import _shutdown_workers_and_reset_iterator, CombinedLoader
from lightning.pytorch.trainer.supporters import _Sequential, _shutdown_workers_and_reset_iterator, CombinedLoader
from lightning.pytorch.utilities.exceptions import MisconfigurationException


@@ -175,20 +175,35 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:

def __iter__(self) -> "_DataLoaderIterDataFetcher":
super().__iter__()
iterator = self.dataloader_iter
assert iterator is not None
self.iterator = iter(_DataFetcherWrapper(self))
return self

def __next__(self) -> Tuple[int, Iterator]:
if not self.done:
return self.fetched, self.iterator
raise StopIteration
def __next__(self) -> Union["_DataFetcherWrapper", Tuple["_DataFetcherWrapper", int, int]]:
if self.done:
raise StopIteration
assert isinstance(self.iterator, _DataFetcherWrapper)
if self._is_sequential:
sequential_mode = self.dataloader._iterator
assert isinstance(sequential_mode, _Sequential)
batch_idx = sequential_mode._idx
dataloader_idx = sequential_mode._iterator_idx
return self.iterator, batch_idx, dataloader_idx
return self.iterator

@property
def _is_sequential(self) -> bool:
return isinstance(self.dataloader, CombinedLoader) and self.dataloader._mode == "sequential"


class _DataFetcherWrapper(Iterator):
def __init__(self, data_fetcher: _DataLoaderIterDataFetcher) -> None:
self.data_fetcher = data_fetcher

def __next__(self) -> Any:
return super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__()
out = super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__()
if self.data_fetcher._is_sequential:
# avoid breaking change with sequential mode and dataloader_iter. this is okay because
# dataloader_iter + sequential + multiple dataloaders is not supported so the `*_step(..., batch_idx)` value
# and the batch_index we are excluding here will match
return out[0]
return out
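To make the `out[0]` stripping above concrete: in `"sequential"` mode the underlying iterator yields `(batch, batch_idx, dataloader_idx)` triples, and `_DataFetcherWrapper` hands back only the batch, so `next(dataloader_iter)` inside a `*_step` hook keeps its previous shape. A small sketch using the private `_Sequential` helper (import path as used in this diff; values are illustrative):

```python
from lightning.pytorch.trainer.supporters import _Sequential

# What a sequential CombinedLoader iterator produces:
seq = _Sequential([["a", "b"], ["c"]])
assert list(seq) == [("a", 0, 0), ("b", 1, 0), ("c", 0, 1)]

# _DataFetcherWrapper keeps only element 0 of each triple, so user code
# calling next(dataloader_iter) still receives just the batch ("a", "b", ...).
```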
68 changes: 46 additions & 22 deletions src/lightning/pytorch/trainer/supporters.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar
from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar, Union

from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter
from typing_extensions import Self, TypedDict
@@ -74,27 +74,47 @@ def __next__(self) -> List:
return [next(it) for it in self.iterators]


class _Sequential(_ModeIterator[Tuple[int, Any]]):
def __init__(self, iterables: List[Iterable]) -> None:
class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None:
super().__init__(iterables)
self._iterator_idx = 0 # what would be dataloader_idx
self._idx = 0 # what would be batch_idx
self.limits = limits

def __next__(self) -> Tuple[int, Any]:
@property
def limits(self) -> Optional[List[Union[int, float]]]:
"""Optional limits per iterator."""
return self._limits

@limits.setter
def limits(self, limits: Optional[List[Union[int, float]]]) -> None:
if limits is not None and len(limits) != len(self.iterables):
raise ValueError(
f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(self.iterables)})"
)
self._limits = limits

def __next__(self) -> Tuple[Any, int, int]:
n = len(self.iterators)
if n == 0:
if n == 0 or self._iterator_idx >= n:
raise StopIteration

# if limits are set, go to the correct iterator
if self.limits is not None:
while self.limits[self._iterator_idx] <= self._idx:
self._use_next_iterator()
if self._iterator_idx >= n:
raise StopIteration

try:
out = next(self.iterators[self._iterator_idx])
index = self._idx
self._idx += 1
# the return is enumerated by default
return index, out
# batch, batch_idx, dataloader_idx
return out, index, self._iterator_idx
except StopIteration:
self._iterator_idx += 1
self._idx = 0
if self._iterator_idx >= n:
raise
# try the next iterator
self._use_next_iterator()
return self.__next__()

def __iter__(self) -> Self: # type: ignore[valid-type]
@@ -108,6 +128,10 @@ def reset(self) -> None:
self._iterator_idx = 0
self._idx = 0

def _use_next_iterator(self) -> None:
self._iterator_idx += 1
self._idx = 0


class _CombinationMode(TypedDict):
fn: Callable[[List[int]], int]
@@ -170,28 +194,28 @@ class CombinedLoader(Iterable):
>>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
>>> len(combined_loader)
3
>>> for item in combined_loader:
... print(item)
>>> for batch in combined_loader:
... print(batch)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
{'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}
>>> combined_loader = CombinedLoader(iterables, 'min_size')
>>> len(combined_loader)
2
>>> for item in combined_loader:
... print(item)
>>> for batch in combined_loader:
... print(batch)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
>>> combined_loader = CombinedLoader(iterables, 'sequential')
>>> len(combined_loader)
5
>>> for item in combined_loader:
... print(*item)
0 tensor([0, 1, 2, 3])
1 tensor([4, 5])
0 tensor([0, 1, 2, 3, 4])
1 tensor([5, 6, 7, 8, 9])
2 tensor([10, 11, 12, 13, 14])
>>> for batch, batch_idx, dataloader_idx in combined_loader:
... print(f"{batch} {batch_idx=} {dataloader_idx=}")
tensor([0, 1, 2, 3]) batch_idx=0 dataloader_idx=0
tensor([4, 5]) batch_idx=1 dataloader_idx=0
tensor([0, 1, 2, 3, 4]) batch_idx=0 dataloader_idx=1
tensor([5, 6, 7, 8, 9]) batch_idx=1 dataloader_idx=1
tensor([10, 11, 12, 13, 14]) batch_idx=2 dataloader_idx=1
"""

def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
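The other addition in this file is the `limits` property on `_Sequential`, which caps how many batches are drawn from each iterable before the iterator moves on; the new tests below exercise the same behaviour. A minimal sketch (values are illustrative):

```python
from lightning.pytorch.trainer.supporters import _Sequential

# At most 1 batch from the first iterable, no cap on the second.
seq = _Sequential([["a", "b", "c"], ["d", "e"]], limits=[1, float("inf")])
assert list(seq) == [("a", 0, 0), ("d", 0, 1), ("e", 1, 1)]

# A limits list whose length does not match the number of iterables raises a ValueError.
```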
32 changes: 27 additions & 5 deletions tests/tests_pytorch/trainer/test_supporters.py
@@ -122,13 +122,14 @@ def test_combined_loader_modes():
combined_loader = CombinedLoader(iterables, "sequential")
assert combined_loader._iterator is None
assert len(combined_loader) == sum_len
for total_idx, (idx, item) in enumerate(combined_loader):
for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader):
assert isinstance(combined_loader._iterator, _Sequential)
assert isinstance(idx, int)
assert isinstance(batch_idx, int)
assert isinstance(item, Tensor)
assert idx == lengths[-1] - 1
assert total_idx == sum_len - 1
assert total_idx == len(combined_loader) - 1
assert dataloader_idx == len(iterables) - 1

iterables = list(iterables.values())

@@ -156,13 +157,14 @@ def test_combined_loader_modes():
combined_loader = CombinedLoader(iterables, "sequential")
assert combined_loader._iterator is None
assert len(combined_loader) == sum_len
for total_idx, (idx, item) in enumerate(combined_loader):
for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader):
assert isinstance(combined_loader._iterator, _Sequential)
assert isinstance(idx, int)
assert isinstance(batch_idx, int)
assert isinstance(item, Tensor)
assert idx == lengths[-1] - 1
assert total_idx == sum_len - 1
assert total_idx == len(combined_loader) - 1
assert dataloader_idx == len(iterables) - 1


def test_combined_loader_raises():
@@ -205,7 +207,6 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
has_break = False
for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
assert len(item) == 2 if use_multiple_dataloaders else 1
if not use_multiple_dataloaders and idx == 4:
has_break = True
break
@@ -221,6 +222,27 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
assert idx == expected - 1


@pytest.mark.parametrize(
("limits", "expected"),
[
(None, [("a", 0, 0), ("b", 1, 0), ("c", 2, 0), ("d", 0, 1), ("e", 1, 1)]),
([1, 0], [("a", 0, 0)]),
([0, float("inf")], [("d", 0, 1), ("e", 1, 1)]),
([1, 1], [("a", 0, 0), ("d", 0, 1)]),
],
)
def test_sequential_mode_limits(limits, expected):
iterable1 = ["a", "b", "c"]
iterable2 = ["d", "e"]
iterator = _Sequential([iterable1, iterable2], limits)
assert list(iterator) == expected


def test_sequential_mode_limits_raises():
with pytest.raises(ValueError, match=r"number of limits \(0\) and number of iterables \(2\)"):
_Sequential([0, 1], [])


@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
def test_combined_loader_sequence_with_map_and_iterable(lengths):
class MyIterableDataset(IterableDataset):
