From 620fc7bad570d06858fb4702138c00b43187679d Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 5 Nov 2021 09:45:16 -0400 Subject: [PATCH 01/13] solve combinedloader --- pytorch_lightning/trainer/data_loading.py | 18 +++++++++++--- pytorch_lightning/trainer/supporters.py | 29 +++++++++++++++++++---- tests/trainer/test_supporters.py | 24 +++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index c41d80b903d4e..40d507413ffda 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -28,7 +28,7 @@ from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( @@ -136,14 +136,22 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn if isinstance(dataloader, CombinedLoader): # apply `prepare_dataloader` on all the collection of loaders dataloader.loaders = apply_to_collection( - dataloader.loaders, DataLoader, self.prepare_dataloader, shuffle, mode=mode + dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode ) + # the length need to recomputed across all dataloaders in case of special behaviour. + dataloader._apply_cycle_iterator_length() return dataloader # don't do anything if it's not a dataloader - if not isinstance(dataloader, DataLoader): + if not isinstance(dataloader, (DataLoader, CycleIterator)): return dataloader + cycle_iterator: Optional[CycleIterator] = None + + if isinstance(dataloader, CycleIterator): + cycle_iterator = dataloader + dataloader = dataloader.loader + if ( _fault_tolerant_training() # injects components to track the state or self._requires_distributed_sampler(dataloader) # sets the distributed sampler @@ -153,6 +161,10 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode) dataloader = self._update_dataloader(dataloader, sampler, mode=mode) + if cycle_iterator: + cycle_iterator.loader = dataloader + return cycle_iterator + return dataloader def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler: diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 816f4da38f5b9..daff3376c4118 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -28,7 +28,7 @@ patch_dataloader_iterator, reload_dataloader_state_dict, ) -from pytorch_lightning.utilities.data import get_len +from pytorch_lightning.utilities.data import get_len, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -217,6 +217,13 @@ def __next__(self) -> Any: def __len__(self) -> Union[int, float]: return self.length + @staticmethod + def get_len(cycle_iterator: "CycleIterator") -> Union[int, float]: + if 
has_len(cycle_iterator.loader): + return len(cycle_iterator.loader) + + return float("inf") + class CombinedDataset: """Combine multiple datasets and compute their statistics.""" @@ -457,6 +464,19 @@ def _wrap_loaders_max_size_cycle(self) -> Any: ) state.reset() + def _apply_cycle_iterator_length(self): + if self.mode == "max_size_cycle": + all_lengths = apply_to_collection( + self.loaders, CycleIterator, CycleIterator.get_len, wrong_dtype=(Sequence, Mapping) + ) + length = _nested_calc_num_data(all_lengths, max) + + def _apply_fn(cycle_iterator: CycleIterator) -> None: + nonlocal length + cycle_iterator.length = length + + apply_to_collection(self.loaders, CycleIterator, _apply_fn) + def __iter__(self) -> Any: """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.""" @@ -473,11 +493,12 @@ def __getstate__patch__(*_): return iterator @staticmethod - def _calc_num_batches(loaders: Any) -> Union[int, float]: + def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]: """Compute the length (aka the number of batches) of `CombinedLoader`. Args: loaders: a collections of loaders. + mode: Mode used by the CombinedDataloader Returns: length: the minimum length of loaders @@ -486,10 +507,10 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]: if isinstance(all_lengths, (int, float)): return all_lengths - return _nested_calc_num_data(all_lengths, min) + return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min) def __len__(self) -> int: - return self._calc_num_batches(self.loaders) + return self._calc_num_batches(self.loaders, mode=self.mode) @staticmethod def _shutdown_workers_and_reset_iterator(dataloader) -> None: diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 204f3079f544b..7019fefd5f7cf 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -35,6 +35,7 @@ from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 +from tests.helpers.boring_model import RandomDataset def test_tensor_running_accum_reset(): @@ -379,3 +380,26 @@ def _assert_dataset(loader): assert isinstance(d, CustomDataset) apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset) + + dataloader = CombinedLoader( + {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)}, + ) + + trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp) + dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) + assert len(dataloader) == 4 if replace_sampler_ddp else 8 + + for length in [6, 8, 10]: + dataloader = CombinedLoader( + { + "a": DataLoader(RandomDataset(32, length), batch_size=1), + "b": DataLoader(RandomDataset(32, 8), batch_size=1), + }, + mode="max_size_cycle", + ) + + length = max(length, 8) + assert len(dataloader) == length + trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp) + dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) + assert len(dataloader) == length // 2 if replace_sampler_ddp else length From d39351eaf339f1effb63ff423c59651857bb5f39 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 5 Nov 2021 09:53:51 -0400 Subject: [PATCH 02/13] update --- tests/trainer/test_supporters.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 
deletions(-) diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 7019fefd5f7cf..711b9fb3f02fb 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -389,17 +389,25 @@ def _assert_dataset(loader): dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) assert len(dataloader) == 4 if replace_sampler_ddp else 8 - for length in [6, 8, 10]: + for a_length in [6, 8, 10]: dataloader = CombinedLoader( { - "a": DataLoader(RandomDataset(32, length), batch_size=1), - "b": DataLoader(RandomDataset(32, 8), batch_size=1), + "a": DataLoader(range(a_length), batch_size=1), + "b": DataLoader(range(8), batch_size=1), }, mode="max_size_cycle", ) - length = max(length, 8) + length = max(a_length, 8) assert len(dataloader) == length trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp) dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) assert len(dataloader) == length // 2 if replace_sampler_ddp else length + if replace_sampler_ddp: + batches = [batch for batch in dataloader] + if a_length == 6: + assert batches[-1] == {"a": torch.tensor([0]), "b": torch.tensor([6])} + elif a_length == 8: + assert batches[-1] == {"a": torch.tensor([6]), "b": torch.tensor([6])} + elif a_length == 10: + assert batches[-1] == {"a": torch.tensor([8]), "b": torch.tensor([0])} From 515dbaa98faa22404efc88b8bb37b8638eff2748 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 5 Nov 2021 09:56:25 -0400 Subject: [PATCH 03/13] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09cced9ed9bd7..f7fc4e343d844 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,6 +86,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed failure when `DataLoader(batch_size=None)` is passed ([#10345](https://github.com/PyTorchLightning/pytorch-lightning/issues/10345)) +- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374)) + + - From c36b282d9bdf6d83ac44f5595834685682e117b2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 9 Nov 2021 09:56:18 +0000 Subject: [PATCH 04/13] update on comments --- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/trainer/supporters.py | 34 +++++++++++------------ tests/trainer/test_supporters.py | 12 ++++++++ 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 40d507413ffda..435c907957a89 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -138,7 +138,7 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn dataloader.loaders = apply_to_collection( dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode ) - # the length need to recomputed across all dataloaders in case of special behaviour. + # the length need to recomputed across all dataloaders in case of special behavior. 
dataloader._apply_cycle_iterator_length() return dataloader diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index daff3376c4118..6017fd645eaf6 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -28,7 +28,7 @@ patch_dataloader_iterator, reload_dataloader_state_dict, ) -from pytorch_lightning.utilities.data import get_len, has_len +from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -217,13 +217,6 @@ def __next__(self) -> Any: def __len__(self) -> Union[int, float]: return self.length - @staticmethod - def get_len(cycle_iterator: "CycleIterator") -> Union[int, float]: - if has_len(cycle_iterator.loader): - return len(cycle_iterator.loader) - - return float("inf") - class CombinedDataset: """Combine multiple datasets and compute their statistics.""" @@ -464,18 +457,23 @@ def _wrap_loaders_max_size_cycle(self) -> Any: ) state.reset() - def _apply_cycle_iterator_length(self): - if self.mode == "max_size_cycle": - all_lengths = apply_to_collection( - self.loaders, CycleIterator, CycleIterator.get_len, wrong_dtype=(Sequence, Mapping) - ) - length = _nested_calc_num_data(all_lengths, max) + def _apply_cycle_iterator_length(self) -> None: + """When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to + all dataloaders.""" + if self.mode != "max_size_cycle": + return + + def get_len(cycle_iterator: CycleIterator) -> Union[float, int]: + return len(cycle_iterator.loader) + + all_lengths = apply_to_collection(self.loaders, CycleIterator, get_len, wrong_dtype=(Sequence, Mapping)) + length = _nested_calc_num_data(all_lengths, max) - def _apply_fn(cycle_iterator: CycleIterator) -> None: - nonlocal length - cycle_iterator.length = length + def _apply_fn(cycle_iterator: CycleIterator) -> None: + nonlocal length + cycle_iterator.length = length - apply_to_collection(self.loaders, CycleIterator, _apply_fn) + apply_to_collection(self.loaders, CycleIterator, _apply_fn) def __iter__(self) -> Any: """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.""" diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 711b9fb3f02fb..a7e48c2c84415 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -381,6 +381,18 @@ def _assert_dataset(loader): apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset) + +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7") +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"}) +@mock.patch("torch.cuda.device_count", return_value=2) +@mock.patch("torch.cuda.is_available", return_value=True) +@pytest.mark.parametrize("replace_sampler_ddp", [False, True]) +def test_combined_data_loader_with_max_size_cycle_and_ddp( + cuda_available_mock, device_count_mock, replace_sampler_ddp, tmpdir +): + """This test makes sure distributed sampler has been properly injected in dataloaders when using + CombinedLoader.""" + dataloader = CombinedLoader( {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)}, ) From 67ab70fd6d4336ffdb4f304add74db3259310619 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 9 Nov 2021 10:08:10 +0000 Subject: [PATCH 05/13] resolve iterable dataset support --- 
pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/trainer/supporters.py | 8 +++++--- tests/trainer/test_supporters.py | 20 ++++++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 435c907957a89..a954e0f1d1b68 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -161,7 +161,7 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode) dataloader = self._update_dataloader(dataloader, sampler, mode=mode) - if cycle_iterator: + if cycle_iterator is not None: cycle_iterator.loader = dataloader return cycle_iterator diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 6017fd645eaf6..5c90fbbbed1ef 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -463,10 +463,12 @@ def _apply_cycle_iterator_length(self) -> None: if self.mode != "max_size_cycle": return - def get_len(cycle_iterator: CycleIterator) -> Union[float, int]: - return len(cycle_iterator.loader) + def get_cycle_iterator_len(cycle_iterator: CycleIterator) -> Union[float, int]: + return get_len(cycle_iterator.loader) - all_lengths = apply_to_collection(self.loaders, CycleIterator, get_len, wrong_dtype=(Sequence, Mapping)) + all_lengths = apply_to_collection( + self.loaders, CycleIterator, get_cycle_iterator_len, wrong_dtype=(Sequence, Mapping) + ) length = _nested_calc_num_data(all_lengths, max) def _apply_fn(cycle_iterator: CycleIterator) -> None: diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index a7e48c2c84415..280688f2c85fd 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -33,6 +33,7 @@ ) from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler +from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 from tests.helpers.boring_model import RandomDataset @@ -423,3 +424,22 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp( assert batches[-1] == {"a": torch.tensor([6]), "b": torch.tensor([6])} elif a_length == 10: assert batches[-1] == {"a": torch.tensor([8]), "b": torch.tensor([0])} + + class InfiniteDataset(IterableDataset): + def __iter__(self): + while True: + yield 1 + + dataloader = CombinedLoader( + { + "a": DataLoader(InfiniteDataset(), batch_size=1), + "b": DataLoader(range(8), batch_size=1), + }, + mode="max_size_cycle", + ) + assert get_len(dataloader) == float("inf") + assert len(dataloader.loaders["b"].loader) == 8 + trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp) + dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) + assert len(dataloader.loaders["b"].loader) == 4 if replace_sampler_ddp else 8 + assert get_len(dataloader) == float("inf") From 1c8cb4cc95968bba6f50575264121f877931d2d1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 9 Nov 2021 10:09:54 +0000 Subject: [PATCH 06/13] update test description --- tests/trainer/test_supporters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_supporters.py 
b/tests/trainer/test_supporters.py index 280688f2c85fd..967c33f1f55e9 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -391,8 +391,8 @@ def _assert_dataset(loader): def test_combined_data_loader_with_max_size_cycle_and_ddp( cuda_available_mock, device_count_mock, replace_sampler_ddp, tmpdir ): - """This test makes sure distributed sampler has been properly injected in dataloaders when using - CombinedLoader.""" + """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader + with ddp and `max_size_cycle` mode.""" dataloader = CombinedLoader( {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)}, From fad8bd2957b2f866daf8e2456b6518e8ed9a46ae Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 9 Nov 2021 13:22:00 +0000 Subject: [PATCH 07/13] update --- pytorch_lightning/trainer/supporters.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 5c90fbbbed1ef..1e5373cf2e3ec 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -463,11 +463,8 @@ def _apply_cycle_iterator_length(self) -> None: if self.mode != "max_size_cycle": return - def get_cycle_iterator_len(cycle_iterator: CycleIterator) -> Union[float, int]: - return get_len(cycle_iterator.loader) - all_lengths = apply_to_collection( - self.loaders, CycleIterator, get_cycle_iterator_len, wrong_dtype=(Sequence, Mapping) + self.loaders, CycleIterator, lambda c: get_len(c.loader), wrong_dtype=(Sequence, Mapping) ) length = _nested_calc_num_data(all_lengths, max) From 1cad0fab3f45aaade9f46b2cd2158ccc9477b04d Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 9 Nov 2021 14:54:34 +0000 Subject: [PATCH 08/13] update on comments --- pytorch_lightning/trainer/supporters.py | 12 +++++------- tests/trainer/test_supporters.py | 1 - 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 1e5373cf2e3ec..2492cb372e72a 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -463,16 +463,14 @@ def _apply_cycle_iterator_length(self) -> None: if self.mode != "max_size_cycle": return - all_lengths = apply_to_collection( - self.loaders, CycleIterator, lambda c: get_len(c.loader), wrong_dtype=(Sequence, Mapping) - ) - length = _nested_calc_num_data(all_lengths, max) + all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader)) - def _apply_fn(cycle_iterator: CycleIterator) -> None: - nonlocal length + def _apply_fn(cycle_iterator: CycleIterator, length) -> None: cycle_iterator.length = length - apply_to_collection(self.loaders, CycleIterator, _apply_fn) + apply_to_collection( + self.loaders, CycleIterator, partial(_apply_fn, length=_nested_calc_num_data(all_lengths, max)) + ) def __iter__(self) -> Any: """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.""" diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 967c33f1f55e9..c663f0b0f55bd 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -383,7 +383,6 @@ def _assert_dataset(loader): apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset) -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7") 
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"}) @mock.patch("torch.cuda.device_count", return_value=2) @mock.patch("torch.cuda.is_available", return_value=True) From 519c984b7228800f465a78791cfa16ec1d97c64f Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 9 Nov 2021 14:55:08 +0000 Subject: [PATCH 09/13] update --- pytorch_lightning/trainer/supporters.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 2492cb372e72a..6e4b7ecaea171 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -468,9 +468,8 @@ def _apply_cycle_iterator_length(self) -> None: def _apply_fn(cycle_iterator: CycleIterator, length) -> None: cycle_iterator.length = length - apply_to_collection( - self.loaders, CycleIterator, partial(_apply_fn, length=_nested_calc_num_data(all_lengths, max)) - ) + length = _nested_calc_num_data(all_lengths, max) + apply_to_collection(self.loaders, CycleIterator, partial(_apply_fn, length=length)) def __iter__(self) -> Any: """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.""" From 8101bf00c8f8b082cae81b5aab43d7df7fa81ae3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Nov 2021 18:03:10 +0100 Subject: [PATCH 10/13] Accelerator auto --- tests/trainer/test_supporters.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index c663f0b0f55bd..e4598550c24fb 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -383,21 +383,15 @@ def _assert_dataset(loader): apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset) -@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"}) -@mock.patch("torch.cuda.device_count", return_value=2) -@mock.patch("torch.cuda.is_available", return_value=True) @pytest.mark.parametrize("replace_sampler_ddp", [False, True]) -def test_combined_data_loader_with_max_size_cycle_and_ddp( - cuda_available_mock, device_count_mock, replace_sampler_ddp, tmpdir -): +def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp, tmpdir): """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader with ddp and `max_size_cycle` mode.""" + trainer = Trainer(strategy="ddp", accelerator="auto", devices=2, replace_sampler_ddp=replace_sampler_ddp) dataloader = CombinedLoader( {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)}, ) - - trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp) dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) assert len(dataloader) == 4 if replace_sampler_ddp else 8 @@ -412,17 +406,16 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp( length = max(a_length, 8) assert len(dataloader) == length - trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp) dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) assert len(dataloader) == length // 2 if replace_sampler_ddp else length if replace_sampler_ddp: - batches = [batch for batch in dataloader] + last_batch = list(dataloader)[-1] if a_length == 6: - assert batches[-1] == {"a": torch.tensor([0]), "b": torch.tensor([6])} + assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])} 
elif a_length == 8: - assert batches[-1] == {"a": torch.tensor([6]), "b": torch.tensor([6])} + assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])} elif a_length == 10: - assert batches[-1] == {"a": torch.tensor([8]), "b": torch.tensor([0])} + assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])} class InfiniteDataset(IterableDataset): def __iter__(self): @@ -438,7 +431,6 @@ def __iter__(self): ) assert get_len(dataloader) == float("inf") assert len(dataloader.loaders["b"].loader) == 8 - trainer = Trainer(strategy="ddp", gpus=2, replace_sampler_ddp=replace_sampler_ddp) dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) assert len(dataloader.loaders["b"].loader) == 4 if replace_sampler_ddp else 8 assert get_len(dataloader) == float("inf") From 506cafe8e6b7191c5065644a5c5342096aaefc00 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Nov 2021 18:04:10 +0100 Subject: [PATCH 11/13] Address review --- pytorch_lightning/trainer/supporters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 6e4b7ecaea171..a77e89a698533 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -465,11 +465,11 @@ def _apply_cycle_iterator_length(self) -> None: all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader)) - def _apply_fn(cycle_iterator: CycleIterator, length) -> None: + def set_len(cycle_iterator: CycleIterator, length: int) -> None: cycle_iterator.length = length length = _nested_calc_num_data(all_lengths, max) - apply_to_collection(self.loaders, CycleIterator, partial(_apply_fn, length=length)) + apply_to_collection(self.loaders, CycleIterator, set_len, length=length) def __iter__(self) -> Any: """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.""" From 95c19bf53fe6a6e84c449f3ea8feac8ede521605 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Nov 2021 18:06:38 +0100 Subject: [PATCH 12/13] Refactor --- pytorch_lightning/trainer/supporters.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index a77e89a698533..a185b45392c6d 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -463,13 +463,15 @@ def _apply_cycle_iterator_length(self) -> None: if self.mode != "max_size_cycle": return - all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader)) + def get_len(cycle_iterator: CycleIterator) -> int: + return get_len(cycle_iterator.loader) def set_len(cycle_iterator: CycleIterator, length: int) -> None: cycle_iterator.length = length - length = _nested_calc_num_data(all_lengths, max) - apply_to_collection(self.loaders, CycleIterator, set_len, length=length) + all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader)) + max_length = _nested_calc_num_data(all_lengths, max) + apply_to_collection(self.loaders, CycleIterator, set_len, length=max_length) def __iter__(self) -> Any: """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.""" From bd7203a9a14382133faebc930240fb843945ee4c Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 9 Nov 2021 19:29:45 +0000 Subject: [PATCH 13/13] update --- pytorch_lightning/trainer/supporters.py | 3 --- 1 file changed, 3 deletions(-) diff --git 
a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index a185b45392c6d..6e2e51e82bbf1 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -463,9 +463,6 @@ def _apply_cycle_iterator_length(self) -> None: if self.mode != "max_size_cycle": return - def get_len(cycle_iterator: CycleIterator) -> int: - return get_len(cycle_iterator.loader) - def set_len(cycle_iterator: CycleIterator, length: int) -> None: cycle_iterator.length = length
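
Taken together, the patches make `CombinedLoader` in `max_size_cycle` mode recompute its cycle length after the trainer has injected a `DistributedSampler`: each dataloader is wrapped in a `CycleIterator`, `prepare_dataloader` is applied to the wrapped loaders, and `_apply_cycle_iterator_length` then re-assigns the maximum of the new per-loader lengths to every wrapper. Below is a minimal, self-contained sketch of that idea, assuming nothing from Lightning; `ToyCycleIterator` and `apply_cycle_iterator_length` are illustrative stand-ins, not the real `CycleIterator` / `CombinedLoader` API.

# Minimal, self-contained sketch of the behaviour the patches implement (toy names,
# not the Lightning API): in "max_size_cycle" mode each loader is cycled up to a
# shared length, and that length must be recomputed after the loaders change size.
from typing import Iterable, Union


class ToyCycleIterator:
    """Cycles `loader` until `length` batches have been produced (stand-in for CycleIterator)."""

    def __init__(self, loader: Iterable, length: Union[int, float] = float("inf")) -> None:
        self.loader = loader
        self.length = length

    def __iter__(self):
        produced = 0
        while produced < self.length:
            for batch in self.loader:
                if produced >= self.length:
                    return
                yield batch
                produced += 1

    def __len__(self) -> Union[int, float]:
        return self.length


def apply_cycle_iterator_length(wrapped: dict) -> None:
    # Mirror of `_apply_cycle_iterator_length`: take the max length across the
    # wrapped loaders and re-assign it to every wrapper.
    max_len = max(len(c.loader) for c in wrapped.values())
    for c in wrapped.values():
        c.length = max_len


# Before any sampler is injected: "a" has 6 samples, "b" has 8 -> cycle length 8.
loaders = {"a": ToyCycleIterator(list(range(6))), "b": ToyCycleIterator(list(range(8)))}
apply_cycle_iterator_length(loaders)
assert len(loaders["a"]) == len(loaders["b"]) == 8

# Simulate a DistributedSampler split across 2 ranks: each loader keeps every other
# sample, so the shared cycle length must be recomputed down to 4.
loaders = {"a": ToyCycleIterator(list(range(6))[::2]), "b": ToyCycleIterator(list(range(8))[::2])}
apply_cycle_iterator_length(loaders)
assert len(loaders["a"]) == len(loaders["b"]) == 4
print("max_size_cycle length recomputed after the split")

The second half mirrors what the DDP test checks: splitting each dataset across two ranks roughly halves every loader, and the shared cycle length has to follow them down instead of staying at the stale, pre-split maximum.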