3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -152,6 +152,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue to ensure all the checkpoint states are saved in a common filepath with `DeepspeedStrategy` ([#12887](https://github.com/PyTorchLightning/pytorch-lightning/pull/12887))


- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))


- Fixed an issue with recursive invocation of the DDP configuration in the HPU parallel plugin ([#12912](https://github.com/PyTorchLightning/pytorch-lightning/pull/12912))


3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
@@ -432,7 +432,7 @@ def _request_dataloader(
self.trainer._call_lightning_module_hook("on_" + hook, pl_module=model)
with _replace_dataloader_init_method():
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
# attributes on the instance in case the dataloader needs to be re-instantiated later by Ligtning
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
dataloader = source.dataloader()
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
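For context, the comment above refers to capturing the arguments passed to `DataLoader.__init__`. The following is a minimal sketch of that idea, not Lightning's actual `_replace_dataloader_init_method` implementation; the `_saved_args`/`_saved_kwargs` attribute names are hypothetical:

```python
import functools

from torch.utils.data import DataLoader


def _capture_init_args(init):
    """Wrap ``DataLoader.__init__`` so the call arguments are stored on the instance."""

    @functools.wraps(init)
    def wrapper(obj, *args, **kwargs):
        # hypothetical attribute names, not the ones Lightning uses
        obj._saved_args, obj._saved_kwargs = args, kwargs
        init(obj, *args, **kwargs)

    return wrapper


# a context manager would patch __init__ on entry and restore it on exit;
# try/finally mirrors that here
original_init = DataLoader.__init__
DataLoader.__init__ = _capture_init_args(original_init)
try:
    loader = DataLoader(range(10), batch_size=2)
    # the captured arguments are enough to build an equivalent loader later
    rebuilt = DataLoader(*loader._saved_args, **loader._saved_kwargs)
finally:
    DataLoader.__init__ = original_init
```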
@@ -467,6 +467,7 @@ def replace_sampler(dataloader: DataLoader) -> DataLoader:

@staticmethod
def _check_eval_shuffling(dataloader, mode):
# limit this warning to samplers that are assigned automatically when shuffle is set
if _is_dataloader_shuffled(dataloader):
rank_zero_warn(
f"Your `{mode.dataloader_prefix}_dataloader`'s sampler has shuffling enabled,"
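To illustrate what "default samplers" means here: a plain PyTorch `DataLoader` picks its sampler from the `shuffle` flag, and that automatically assigned sampler is the case the narrowed warning targets. A short sketch, independent of Lightning's code:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# shuffle left at its default (False): a SequentialSampler is assigned automatically
assert isinstance(DataLoader(dataset).sampler, SequentialSampler)

# shuffle=True: a RandomSampler is assigned automatically, the case the warning targets
assert isinstance(DataLoader(dataset, shuffle=True).sampler, RandomSampler)

# a sampler passed explicitly via `sampler=...` is kept as-is and is not a "default" sampler
```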
23 changes: 16 additions & 7 deletions pytorch_lightning/utilities/data.py
@@ -21,7 +21,7 @@
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler

import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
@@ -389,9 +389,18 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) ->
return dl_kwargs


def _is_dataloader_shuffled(dataloader: DataLoader):
return (
hasattr(dataloader, "sampler")
and not isinstance(dataloader.sampler, SequentialSampler)
and not isinstance(dataloader.dataset, IterableDataset)
)
def _is_dataloader_shuffled(dataloader: object) -> bool:
if hasattr(dataloader, "shuffle"):
# this attribute is not part of PyTorch's DataLoader, but could have been set by
# our `_replace_dataloader_init_method` context manager
return dataloader.shuffle
if isinstance(dataloader.dataset, IterableDataset):
# shuffling is useless with iterable datasets
return False
if not hasattr(dataloader, "sampler"):
# shuffling is enabled via a sampler. No sampler, no shuffling
return False
sampler = dataloader.sampler
if isinstance(sampler, SequentialSampler):
return False
return isinstance(sampler, RandomSampler)
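For illustration, this is how the new branches behave on plain PyTorch loaders. Outside of Lightning's init-capturing context manager, a `DataLoader` has no `shuffle` attribute, so the check falls through to the dataset and sampler branches. `_is_dataloader_shuffled` is a private helper, so this is a sketch for understanding rather than a supported API:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

from pytorch_lightning.utilities.data import _is_dataloader_shuffled

dataset = TensorDataset(torch.arange(8))

assert not _is_dataloader_shuffled(DataLoader(dataset))            # SequentialSampler -> False
assert _is_dataloader_shuffled(DataLoader(dataset, shuffle=True))  # RandomSampler -> True


class _Stream(IterableDataset):
    def __iter__(self):
        return iter(range(8))


assert not _is_dataloader_shuffled(DataLoader(_Stream()))  # iterable dataset -> always False
```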
3 changes: 2 additions & 1 deletion tests/helpers/utils.py
@@ -13,6 +13,7 @@
# limitations under the License.
import functools
import os
import re
import traceback
from contextlib import contextmanager
from typing import Optional, Type
@@ -126,7 +127,7 @@ def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Option
return
else:
for w in record.list:
if w.category is expected_warning and match in w.message.args[0]:
if w.category is expected_warning and re.compile(match).search(w.message.args[0]):
break
else:
return
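With `match` now compiled as a regular expression (searched, rather than tested as a plain substring), a hypothetical use of the helper as a context manager inside a test might look like this; the warning text is illustrative only:

```python
import warnings

from tests.helpers.utils import no_warning_call

# passes: the only warning emitted does not match the regex pattern
with no_warning_call(UserWarning, match=r"sampler has shuffling enabled"):
    warnings.warn("an unrelated warning", UserWarning)

# regex metacharacters such as alternation now work too, e.g.
# no_warning_call(UserWarning, match=r"shuffling (enabled|turned on)")
```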