diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8be044ba41fbd..7cdd289e85e63 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -238,6 +238,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed support for `CombinedLoader` while checking for warning raised with eval dataloaders ([#10994](https://github.com/PyTorchLightning/pytorch-lightning/pull/10994))
+
+
 -
 
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index b4cb610fa138b..e036e7e90254d 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -17,6 +17,7 @@
 from typing import Any, Callable, Collection, List, Optional, Tuple, Union
 
 from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
+from torch.utils.data.dataset import IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 
 import pytorch_lightning as pl
@@ -297,19 +298,17 @@ def _reset_eval_dataloader(
 
         if not isinstance(dataloaders, list):
             dataloaders = [dataloaders]
 
-        for loader_i in range(len(dataloaders)):
-            loader = dataloaders[loader_i]
-
-            if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
-                rank_zero_warn(
-                    f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
-                    " it is strongly recommended that you turn this off for val/test/predict dataloaders.",
-                    category=PossibleUserWarning,
-                )
-
         if any(dl is None for dl in dataloaders):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")
 
+        for loader in dataloaders:
+            apply_to_collection(
+                loader.loaders if isinstance(loader, CombinedLoader) else loader,
+                DataLoader,
+                self._check_eval_shuffling,
+                mode=mode,
+            )
+
         # add samplers
         dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]
@@ -459,3 +458,16 @@ def replace_sampler(dataloader: DataLoader) -> DataLoader:
 
         dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler)
         return dataloader
+
+    @staticmethod
+    def _check_eval_shuffling(dataloader, mode):
+        if (
+            hasattr(dataloader, "sampler")
+            and not isinstance(dataloader.sampler, SequentialSampler)
+            and not isinstance(dataloader.dataset, IterableDataset)
+        ):
+            rank_zero_warn(
+                f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
+                " it is strongly recommended that you turn this off for val/test/predict dataloaders.",
+                category=PossibleUserWarning,
+            )
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index d771551da7dc8..76befc049abe4 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -307,10 +307,10 @@ def __len__(self) -> int:
 
 
 class CombinedLoader:
-    """Combines different dataloaders and allows sampling in parallel. Supported modes are 'min_size', which raises
-    StopIteration after the shortest loader (the one with the lowest number of batches) is done, and
-    'max_size_cycle` which raises StopIteration after the longest loader (the one with most batches) is done, while
-    cycling through the shorter loaders.
+    """Combines different dataloaders and allows sampling in parallel. Supported modes are ``"min_size"``, which
+    raises StopIteration after the shortest loader (the one with the lowest number of batches) is done, and
+    ``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is done,
+    while cycling through the shorter loaders.
 
     Examples:
         >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4),
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index e191f681d209d..74c074aef2110 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -19,9 +19,12 @@
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.data import _update_dataloader
 from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
@@ -334,3 +337,27 @@ def test_pre_made_batches():
     loader = DataLoader(RandomDataset(32, 10), batch_size=None)
     trainer = Trainer(fast_dev_run=1)
     trainer.predict(LoaderTestModel(), loader)
+
+
+@pytest.mark.parametrize(
+    "val_dl",
+    [
+        DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
+        CombinedLoader(DataLoader(dataset=RandomDataset(32, 64), shuffle=True)),
+        CombinedLoader(
+            [DataLoader(dataset=RandomDataset(32, 64)), DataLoader(dataset=RandomDataset(32, 64), shuffle=True)]
+        ),
+        CombinedLoader(
+            {
+                "dl1": DataLoader(dataset=RandomDataset(32, 64)),
+                "dl2": DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
+            }
+        ),
+    ],
+)
+def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
+    trainer = Trainer()
+    model = BoringModel()
+    trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
+    with pytest.warns(PossibleUserWarning, match="recommended .* turn this off for val/test/predict"):
+        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
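For readers of this change, the behaviour exercised by the new test can also be reproduced outside pytest. Below is a minimal sketch (assuming an editable checkout of pytorch-lightning at this revision, so that `tests.helpers` is importable); it shows that a shuffled eval `DataLoader` nested inside a `CombinedLoader` now triggers the `PossibleUserWarning`, which previously only fired for bare `DataLoader`s.

```python
import warnings

from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests.helpers import BoringModel, RandomDataset

# A shuffled validation dataloader hidden inside a CombinedLoader: before this
# fix, the shuffle check only inspected top-level DataLoaders, so no warning
# was emitted for this case.
val_dl = CombinedLoader(
    {
        "dl1": DataLoader(dataset=RandomDataset(32, 64)),
        "dl2": DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
    }
)

trainer = Trainer()
model = BoringModel()
trainer._data_connector.attach_data(model, val_dataloaders=val_dl)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)

# With this change, the nested "dl2" loader is flagged.
assert any(issubclass(w.category, PossibleUserWarning) for w in caught)
```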