diff --git a/CHANGELOG.md b/CHANGELOG.md
index e4bc87b64390b..5219c569cfcb1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -104,6 +104,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed failure when `DataLoader(batch_size=None)` is passed ([#10345](https://github.com/PyTorchLightning/pytorch-lightning/issues/10345))
 
+- Fixed `CombinedLoader` in `max_size_cycle` mode not receiving a `DistributedSampler` for its dataloaders ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
+
+
 - Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388))
 
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index c41d80b903d4e..a954e0f1d1b68 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -28,7 +28,7 @@
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.trainer.supporters import CombinedLoader
+from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.auto_restart import (
@@ -136,14 +136,22 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn
         if isinstance(dataloader, CombinedLoader):
             # apply `prepare_dataloader` on all the collection of loaders
             dataloader.loaders = apply_to_collection(
-                dataloader.loaders, DataLoader, self.prepare_dataloader, shuffle, mode=mode
+                dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode
             )
+            # the length needs to be recomputed across all dataloaders in case of special behavior
+            dataloader._apply_cycle_iterator_length()
             return dataloader
 
         # don't do anything if it's not a dataloader
-        if not isinstance(dataloader, DataLoader):
+        if not isinstance(dataloader, (DataLoader, CycleIterator)):
             return dataloader
 
+        cycle_iterator: Optional[CycleIterator] = None
+
+        if isinstance(dataloader, CycleIterator):
+            cycle_iterator = dataloader
+            dataloader = dataloader.loader
+
         if (
             _fault_tolerant_training()  # injects components to track the state
             or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
@@ -153,6 +161,10 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn
             sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
             dataloader = self._update_dataloader(dataloader, sampler, mode=mode)
 
+        if cycle_iterator is not None:
+            cycle_iterator.loader = dataloader
+            return cycle_iterator
+
         return dataloader
 
     def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 816f4da38f5b9..6e2e51e82bbf1 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -457,6 +457,19 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
             )
             state.reset()
 
+    def _apply_cycle_iterator_length(self) -> None:
+        """When the mode is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to
+        all dataloaders."""
+        if self.mode != "max_size_cycle":
+            return
+
+        def set_len(cycle_iterator: CycleIterator, length: int) -> None:
+            cycle_iterator.length = length
+
+        all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader))
+        max_length = _nested_calc_num_data(all_lengths, max)
+        apply_to_collection(self.loaders, CycleIterator, set_len, length=max_length)
+
     def __iter__(self) -> Any:
         """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader."""
 
@@ -473,11 +486,12 @@ def __getstate__patch__(*_):
         return iterator
 
     @staticmethod
-    def _calc_num_batches(loaders: Any) -> Union[int, float]:
+    def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
         """Compute the length (aka the number of batches) of `CombinedLoader`.
 
         Args:
             loaders: a collections of loaders.
+            mode: Mode used by the `CombinedLoader`.
 
         Returns:
             length: the minimum length of loaders
@@ -486,10 +500,10 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:
         if isinstance(all_lengths, (int, float)):
             return all_lengths
 
-        return _nested_calc_num_data(all_lengths, min)
+        return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)
 
     def __len__(self) -> int:
-        return self._calc_num_batches(self.loaders)
+        return self._calc_num_batches(self.loaders, mode=self.mode)
 
     @staticmethod
     def _shutdown_workers_and_reset_iterator(dataloader) -> None:
diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py
index 204f3079f544b..e4598550c24fb 100644
--- a/tests/trainer/test_supporters.py
+++ b/tests/trainer/test_supporters.py
@@ -33,8 +33,10 @@
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler
+from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
+from tests.helpers.boring_model import RandomDataset
 
 
 def test_tensor_running_accum_reset():
@@ -379,3 +381,56 @@ def _assert_dataset(loader):
             assert isinstance(d, CustomDataset)
 
         apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset)
+
+
+@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
+def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp, tmpdir):
+    """This test makes sure the distributed sampler is properly injected into the dataloaders when using
+    `CombinedLoader` with ddp and `max_size_cycle` mode."""
+    trainer = Trainer(strategy="ddp", accelerator="auto", devices=2, replace_sampler_ddp=replace_sampler_ddp)
+
+    dataloader = CombinedLoader(
+        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
+    )
+    dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
+    assert len(dataloader) == (4 if replace_sampler_ddp else 8)
+
+    for a_length in [6, 8, 10]:
+        dataloader = CombinedLoader(
+            {
+                "a": DataLoader(range(a_length), batch_size=1),
+                "b": DataLoader(range(8), batch_size=1),
+            },
+            mode="max_size_cycle",
+        )
+
+        length = max(a_length, 8)
+        assert len(dataloader) == length
+        dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
+        assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
+        if replace_sampler_ddp:
+            last_batch = list(dataloader)[-1]
+            if a_length == 6:
+                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
+            elif a_length == 8:
+                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
+            elif a_length == 10:
+                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}
+
+    class InfiniteDataset(IterableDataset):
+        def __iter__(self):
+            while True:
+                yield 1
+
+    dataloader = CombinedLoader(
+        {
+            "a": DataLoader(InfiniteDataset(), batch_size=1),
+            "b": DataLoader(range(8), batch_size=1),
+        },
+        mode="max_size_cycle",
+    )
+    assert get_len(dataloader) == float("inf")
+    assert len(dataloader.loaders["b"].loader) == 8
+    dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
+    assert len(dataloader.loaders["b"].loader) == (4 if replace_sampler_ddp else 8)
+    assert get_len(dataloader) == float("inf")
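
Not part of the patch itself: a minimal sketch of the length semantics this change targets, assuming a pytorch_lightning build that includes the diff above. With `mode="max_size_cycle"`, `len()` of a `CombinedLoader` follows its longest dataloader (shorter ones are cycled), while `mode="min_size"` still stops at the shortest; the `DistributedSampler` injection itself additionally requires a `Trainer` configured with a ddp strategy, as exercised in the test above.

# Illustrative sketch (not from the patch): CombinedLoader length semantics,
# assuming a pytorch_lightning version that contains the change above.
from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader

loaders = {
    "a": DataLoader(range(6), batch_size=1),  # 6 batches
    "b": DataLoader(range(8), batch_size=1),  # 8 batches
}

# `max_size_cycle`: the combined length follows the longest loader ("b");
# the shorter loader "a" is cycled so that "b" is consumed completely.
assert len(CombinedLoader(loaders, mode="max_size_cycle")) == 8

# `min_size`: iteration still stops with the shortest loader ("a").
assert len(CombinedLoader(loaders, mode="min_size")) == 6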