"sequential" mode for CombinedLoader #16743

Merged
merged 5 commits into from Feb 14, 2023
Changes from 3 commits
2 changes: 1 addition & 1 deletion examples/pl_servable_module/production.py
@@ -81,7 +81,7 @@ def serialize(self, tensor: torch.Tensor) -> int:
class ProductionReadyModel(LitModule, ServableModule):
def configure_payload(self):
# 1: Access the train dataloader and load a single sample.
image, _ = self.trainer.train_dataloader.loaders.dataset[0]
image, _ = self.trainer.train_dataloader.iterables.dataset[0]

# 2: Convert the image into a PIL Image to bytes and encode it with base64
pil_image = T.ToPILImage()(image)
4 changes: 2 additions & 2 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -242,7 +242,7 @@ def _prepare_dataloader(
- Wrapping the dataloader based on strategy-specific logic
"""
if isinstance(dataloader, CombinedLoader):
for i, dl in enumerate(dataloader._loaders_flattened):
for i, dl in enumerate(dataloader._flattened):
dataloader._update_index(self._prepare_dataloader(dl, shuffle=shuffle, mode=mode), i)
return dataloader
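
A minimal standalone sketch of the flatten / mutate / unflatten pattern this branch relies on, using `torch.utils._pytree` purely for illustration (it may not be the exact helper used internally): each flattened leaf gets "prepared" and the nested collection is rebuilt from the edited list, which is what `_update_index` does for the `CombinedLoader`.

```python
# Illustrative only: flatten a nested collection, replace the leaves, rebuild it.
from torch.utils._pytree import tree_flatten, tree_unflatten

iterables = {"a": range(6), "b": (range(3), range(4))}
flat, spec = tree_flatten(iterables)

# pretend each leaf gets "prepared" (e.g. wrapped for a distributed strategy)
flat = [list(x) for x in flat]

rebuilt = tree_unflatten(flat, spec)
print(rebuilt)  # {'a': [0, 1, 2, 3, 4, 5], 'b': ([0, 1, 2], [0, 1, 2, 3])}
```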

@@ -344,7 +344,7 @@ def _reset_eval_dataloader(

for loader in dataloaders:
apply_to_collection(
loader.loaders if isinstance(loader, CombinedLoader) else loader,
loader.iterables if isinstance(loader, CombinedLoader) else loader,
DataLoader,
self._check_eval_shuffling,
mode=mode,
121 changes: 84 additions & 37 deletions src/lightning/pytorch/trainer/supporters.py
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Iterable, Iterator, List, Literal, Optional, Sized, Type, TypeVar
from collections.abc import Iterable
from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar

from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter
from typing_extensions import Self, TypedDict

from lightning.fabric.utilities.data import sized_len
@@ -73,6 +74,41 @@ def __next__(self) -> List:
return [next(it) for it in self.iterators]


class _Sequential(_ModeIterator[Tuple[int, Any]]):
def __init__(self, iterables: List[Iterable]) -> None:
super().__init__(iterables)
self._iterator_idx = 0 # what would be dataloader_idx
self._idx = 0 # what would be batch_idx

def __next__(self) -> Tuple[int, Any]:
n = len(self.iterators)
if n == 0:
raise StopIteration
try:
out = next(self.iterators[self._iterator_idx])
index = self._idx
self._idx += 1
# the return is enumerated by default
return index, out
except StopIteration:
self._iterator_idx += 1
self._idx = 0
if self._iterator_idx >= n:
raise
return self.__next__()

def __iter__(self) -> Self: # type: ignore[valid-type]
super().__iter__()
self._iterator_idx = 0
self._idx = 0
return self

def reset(self) -> None:
super().reset()
self._iterator_idx = 0
self._idx = 0
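
A standalone sketch of the semantics `_Sequential` implements, assuming plain Python iterables: each iterable is exhausted in turn and every item is paired with its index within the current iterable (what would become `batch_idx`).

```python
from typing import Any, Iterable, Iterator, List, Tuple


def sequential(iterables: List[Iterable]) -> Iterator[Tuple[int, Any]]:
    """Simplified stand-in for `_Sequential`: consume iterables one after another."""
    for iterable in iterables:
        for batch_idx, item in enumerate(iterable):
            yield batch_idx, item


print(list(sequential([range(2), range(10, 13)])))
# [(0, 0), (1, 1), (0, 10), (1, 11), (2, 12)]
```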


class _CombinationMode(TypedDict):
fn: Callable[[List[int]], int]
iterator: Type[_ModeIterator]
@@ -81,9 +117,10 @@ class _CombinationMode(TypedDict):
_supported_modes = {
"min_size": _CombinationMode(fn=min, iterator=_MinSize),
"max_size_cycle": _CombinationMode(fn=max, iterator=_MaxSizeCycle),
"sequential": _CombinationMode(fn=sum, iterator=_Sequential),
}

_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle"]
_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle", "sequential"]
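
The choice of `fn=sum` for the new mode is what makes `len()` add up the individual lengths, while the existing modes keep `min`/`max`. A quick check against the doctest further below (iterables with 2 and 3 batches):

```python
lengths = [2, 3]          # batches per iterable, as in the docstring example
assert min(lengths) == 2  # "min_size"
assert max(lengths) == 3  # "max_size_cycle"
assert sum(lengths) == 5  # "sequential"
```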


class _CombinedDataset(Sized):
@@ -114,86 +151,96 @@ def __len__(self) -> int:


class CombinedLoader(Iterable):
"""Combines different dataloaders and allows sampling in parallel.
"""Combines different iterables under custom sampling modes.

Args:
loaders: the loaders to sample from. Can be all kind of collection
iterables: the iterables to sample from. Can be any kind of collection
mode:
* ``"min_size"``, which raises StopIteration after the shortest loader (the one with the lowest number of
batches) is done.
* ``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is
done, while cycling through the shorter loaders.
* ``"min_size"``, which raises StopIteration after the shortest iterable (the one with the lowest number of
items) is done.
* ``"max_size_cycle"`` which raises StopIteration after the longest iterable (the one with most items) is
done, while cycling through rest of the iterables.
* ``"sequential"`` will consume ecah iterable sequentially, and returns a tuple with the associated index
from each iterable.

Examples:
>>> loaders = {'a': DataLoader(range(6), batch_size=4),
... 'b': DataLoader(range(15), batch_size=5)}
>>> combined_loader = CombinedLoader(loaders, 'max_size_cycle')
>>> from torch.utils.data import DataLoader
>>> iterables = {'a': DataLoader(range(6), batch_size=4),
... 'b': DataLoader(range(15), batch_size=5)}
>>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
>>> len(combined_loader)
3
>>> for item in combined_loader:
... print(item)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
{'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}
>>> combined_loader = CombinedLoader(loaders, 'min_size')
>>> combined_loader = CombinedLoader(iterables, 'min_size')
>>> len(combined_loader)
2
>>> for item in combined_loader:
... print(item)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
>>> combined_loader = CombinedLoader(iterables, 'sequential')
>>> len(combined_loader)
5
>>> for item in combined_loader:
... print(*item)
0 tensor([0, 1, 2, 3])
1 tensor([4, 5])
0 tensor([0, 1, 2, 3, 4])
1 tensor([5, 6, 7, 8, 9])
2 tensor([10, 11, 12, 13, 14])
"""

def __init__(self, loaders: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
if mode not in _supported_modes:
raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_supported_modes)}.")
# TODO(carmocca): rename loaders to iterables
self._loaders = loaders
self._loaders_flattened, self._loaders_spec = _tree_flatten(loaders)
self._iterables = iterables
self._flattened, self._spec = _tree_flatten(iterables)

# TODO(carmocca): doing this might not be necessary
datasets = _map_and_unflatten(
lambda x: getattr(x, "dataset", None), self._loaders_flattened, self._loaders_spec
)
datasets = _map_and_unflatten(lambda x: getattr(x, "dataset", None), self._flattened, self._spec)
# could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
self.dataset = _CombinedDataset(datasets, mode)

self._mode = mode
self._iterator: Optional[_ModeIterator] = None

@property
def loaders(self) -> Any:
"""Return the original collection of loaders."""
return self._loaders
def iterables(self) -> Any:
"""Return the original collection of iterables."""
return self._iterables

@property
def sampler(self) -> Any:
"""Return a collections of samplers extracted from loaders."""
return _map_and_unflatten(lambda x: getattr(x, "sampler", None), self._loaders_flattened, self._loaders_spec)
"""Return a collections of samplers extracted from iterables."""
return _map_and_unflatten(lambda x: getattr(x, "sampler", None), self._flattened, self._spec)

@property
def batch_sampler(self) -> Any:
"""Return a collections of batch samplers extracted from loaders."""
return _map_and_unflatten(
lambda x: getattr(x, "batch_sampler", None), self._loaders_flattened, self._loaders_spec
)
"""Return a collections of batch samplers extracted from iterables."""
return _map_and_unflatten(lambda x: getattr(x, "batch_sampler", None), self._flattened, self._spec)

def __next__(self) -> Any:
assert self._iterator is not None
out = next(self._iterator)
return tree_unflatten(out, self._loaders_spec)
if isinstance(self._iterator, _Sequential):
return out
return tree_unflatten(out, self._spec)
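
The special case exists because the other modes yield one element per flattened iterable, which can be restored to the original collection structure, whereas `_Sequential` yields a flat `(batch_idx, batch)` tuple from a single iterable at a time, so there is nothing to unflatten. A sketch of the two shapes, assuming plain ranges are accepted as iterables:

```python
parallel = CombinedLoader({"a": range(2), "b": range(10, 12)}, "min_size")
print(next(iter(parallel)))    # {'a': 0, 'b': 10}

sequential = CombinedLoader({"a": range(2), "b": range(10, 12)}, "sequential")
print(next(iter(sequential)))  # (0, 0)
```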

def __iter__(self) -> Self: # type: ignore[valid-type]
cls = _supported_modes[self._mode]["iterator"]
iterator = cls(self._loaders_flattened)
iterator = cls(self._flattened)
iter(iterator)
self._iterator = iterator
return self

def __len__(self) -> int:
"""Compute the number of batches."""
lengths = []
for dl in self._loaders_flattened:
for dl in self._flattened:
length = sized_len(dl)
if length is None:
raise NotImplementedError(f"`{type(dl).__name__}` does not define `__len__`")
@@ -205,16 +252,16 @@ def reset(self) -> None:
if self._iterator is not None:
self._iterator.reset()
self._iterator = None
for loader in self._loaders_flattened:
_shutdown_workers_and_reset_iterator(loader)
for iterable in self._flattened:
_shutdown_workers_and_reset_iterator(iterable)

def _update_index(self, dataloader: Iterable, index: int) -> None:
# mutation needs to be done using this method to avoid stale references
self._loaders_flattened[index] = dataloader
self._loaders = tree_unflatten(self._loaders_flattened, self._loaders_spec)
self._flattened[index] = dataloader
self._iterables = tree_unflatten(self._flattened, self._spec)
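
A small sketch of why the mutation goes through `_update_index`: replacing a leaf only in the flattened list would leave the nested `iterables` view pointing at the old object, so the collection is rebuilt from the flattened leaves (plain ranges used for illustration, and a private method is called only to show the effect):

```python
combined = CombinedLoader({"a": range(2), "b": range(3)})
combined._update_index(range(5), 1)          # replace the second flattened leaf
assert combined.iterables["b"] == range(5)   # the nested view was rebuilt as well
```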


def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
if hasattr(dataloader, "_iterator"):
if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
dataloader._iterator._shutdown_workers()
11 changes: 4 additions & 7 deletions src/lightning/pytorch/trainer/trainer.py
@@ -296,11 +296,8 @@ def __init__(
enable_model_summary: Whether to enable model summarization by default.
Default: ``True``.

multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
reload when reaching the minimum length of datasets.
Default: ``"max_size_cycle"``.
multiple_trainloader_mode: How to iterate over the train dataloaders when multiple iterables are provided.
See :class:`lightning.pytorch.trainer.supporters.CombinedLoader`.
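
A hedged sketch of what this delegation means in practice: returning a collection of loaders from `train_dataloader` makes the trainer wrap it in a `CombinedLoader` with the chosen mode (the class name below is illustrative; `training_step`/`configure_optimizers` are omitted):

```python
import lightning.pytorch as pl
from torch.utils.data import DataLoader


class MultiLoaderModel(pl.LightningModule):
    def train_dataloader(self):
        # the trainer wraps this dict in a CombinedLoader using `multiple_trainloader_mode`
        return {
            "a": DataLoader(range(6), batch_size=4),
            "b": DataLoader(range(15), batch_size=5),
        }


# trainer = pl.Trainer(multiple_trainloader_mode="max_size_cycle")
# each training_step batch is then {'a': batch_a, 'b': batch_b}
```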

inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
evaluation (``validate``/``test``/``predict``).
@@ -1252,7 +1249,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
mode=RunningStage.TRAINING,
)
loaders = (
self.train_dataloader.loaders
self.train_dataloader.iterables
if isinstance(self.train_dataloader, CombinedLoader)
else self.train_dataloader
)
@@ -1263,7 +1260,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
# add worker_init_fn for correct seeding in worker processes
apply_to_collection(loaders, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
# wrap the sequence of train iterables in a CombinedLoader object for computing the num_training_batches
if not isinstance(self.train_dataloader, CombinedLoader):
self.train_dataloader = CombinedLoader(loaders, self._data_connector.multiple_trainloader_mode)

6 changes: 3 additions & 3 deletions tests/tests_pytorch/accelerators/test_ipu.py
@@ -378,7 +378,7 @@ def train_dataloader(self):

assert isinstance(trainer.strategy, IPUStrategy)
assert trainer.strategy.training_opts is other_options
dataloader = trainer.train_dataloader.loaders
dataloader = trainer.train_dataloader.iterables
assert dataloader is model.poptorch_dataloader # exact object, was not recreated
# dataloader uses the options in the model, not the strategy
assert dataloader.options is model_options
@@ -406,7 +406,7 @@ def test_manual_poptorch_opts(tmpdir):
assert trainer.strategy.training_opts == training_opts
assert trainer.strategy.inference_opts == inference_opts

dataloader = trainer.train_dataloader.loaders
dataloader = trainer.train_dataloader.iterables
assert isinstance(dataloader, poptorch.DataLoader)
assert dataloader.options == training_opts
assert trainer.num_devices > 1 # testing this only makes sense in a distributed setting
@@ -440,7 +440,7 @@ def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
val_dataloader = trainer.val_dataloaders[0]
train_dataloader = trainer.train_dataloader
assert isinstance(train_dataloader, CombinedLoader)
train_dataloader = train_dataloader.loaders
train_dataloader = train_dataloader.iterables
assert isinstance(val_dataloader, poptorch.DataLoader)
assert isinstance(train_dataloader, poptorch.DataLoader)
assert train_dataloader.options.replication_factor == 2
@@ -116,7 +116,7 @@ def on_fit_start(self):

def on_train_end(self):
def _get_warning_msg():
dl = self.trainer.train_dataloader.loaders
dl = self.trainer.train_dataloader.iterables
if hasattr(dl, "persistent_workers"):
if self.num_workers == 0:
warn_str = "Consider setting num_workers>0 and persistent_workers=True"
@@ -295,7 +295,7 @@ def __iter__(self):

class LoaderTestModel(BoringModel):
def training_step(self, batch, batch_idx):
assert len(self.trainer.train_dataloader.loaders) == 10
assert len(self.trainer.train_dataloader.iterables) == 10
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
4 changes: 2 additions & 2 deletions tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -74,7 +74,7 @@ def val_dataloader(self):
with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
trainer.fit(model)

assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
assert isinstance(trainer.train_dataloader.iterables.sampler, SequentialSampler)
assert isinstance(trainer.val_dataloaders[0].sampler, SequentialSampler)


@@ -161,6 +161,6 @@ def test_distributed_sampler_with_overfit_batches():
trainer.strategy._lightning_module = model
trainer._data_connector.attach_dataloaders(model)
trainer.reset_train_dataloader()
train_sampler = trainer.train_dataloader.loaders.sampler
train_sampler = trainer.train_dataloader.iterables.sampler
assert isinstance(train_sampler, DistributedSampler)
assert train_sampler.shuffle is False
10 changes: 5 additions & 5 deletions tests/tests_pytorch/trainer/test_dataloaders.py
@@ -160,7 +160,7 @@ def test_train_dataloader_passed_to_fit(tmpdir):
fit_options = dict(train_dataloaders=train_loader)
trainer.fit(model, **fit_options)
assert trainer.num_training_batches == 2
assert trainer.train_dataloader.loaders == train_loader
assert trainer.train_dataloader.iterables == train_loader

assert trainer.state.finished, f"Training failed with {trainer.state}"

@@ -836,7 +836,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir):
[("min_size", 16), ("max_size_cycle", 64)],
)
def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches):
"""Integration test for multiple train loaders."""
"""Integration test for multiple train iterables."""

class CustomBoringModel(BoringModel):
def train_dataloader(self):
@@ -1178,12 +1178,12 @@ def test_dataloaders_reset_and_attach(tmpdir):

# 1st fit
trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1)
assert trainer.train_dataloader.loaders.dataset is dataloader_0.dataset
assert trainer.train_dataloader.iterables.dataset is dataloader_0.dataset
assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset
# 2nd fit
trainer.fit_loop.max_steps += 1
trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
assert trainer.train_dataloader.loaders.dataset is dataloader_2.dataset
assert trainer.train_dataloader.iterables.dataset is dataloader_2.dataset
assert trainer.val_dataloaders[0].dataset is dataloader_3.dataset

# 1st validate
@@ -1316,7 +1316,7 @@ def train_dataloader(self):
return DataLoaderWrapper(loader)

def on_train_batch_start(self, batch, batch_idx: int) -> None:
assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper)
assert isinstance(self.trainer.train_dataloader.iterables, DataLoaderWrapper)
self.on_train_batch_start_called = True

def val_dataloader(self):