From 1a14d250ecba8f2e3d57f2eb9f3e908d61b333a9 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Wed, 29 Mar 2023 01:43:02 +0200
Subject: [PATCH] Skip length checks for non-sized iterables

---
 src/lightning/pytorch/utilities/data.py | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py
index a21c6d074839e..c94eea7398861 100644
--- a/src/lightning/pytorch/utilities/data.py
+++ b/src/lightning/pytorch/utilities/data.py
@@ -17,7 +17,6 @@
 import torch
 from lightning_utilities.core.apply_func import is_dataclass_instance
-from lightning_utilities.core.rank_zero import rank_prefixed_message
 from torch import Tensor
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
 from typing_extensions import TypeGuard
 
@@ -93,13 +92,12 @@ def has_len_all_ranks(
     strategy: "pl.strategies.Strategy",
     allow_zero_length_dataloader_with_multiple_devices: bool = False,
 ) -> TypeGuard[Sized]:
-    """Checks if a given object has ``__len__`` method implemented on all aranks."""
+    """Checks if a given object has ``__len__`` method implemented on all ranks."""
     local_length = sized_len(dataloader)
-    has_len = True
     if local_length is None:
-        # if one rank does not define a length, the reduction after would fail, default to 0
-        local_length = 0
-        has_len = False
+        # __len__ is not defined, skip these checks
+        return False
+
     total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
     if total_length == 0:
         rank_zero_warn(
@@ -108,10 +106,6 @@ def has_len_all_ranks(
         )
     if total_length > 0 and local_length == 0:
         dataloader_cls_name = type(dataloader).__name__
-        if not has_len:
-            raise RuntimeError(
-                rank_prefixed_message(f"The `{dataloader_cls_name}` does not define a length.", strategy.global_rank)
-            )
         if not allow_zero_length_dataloader_with_multiple_devices:
             raise RuntimeError(
                 f"`{dataloader_cls_name}` within local rank has zero length."
@@ -121,16 +115,15 @@ def has_len_all_ranks(
                 f"Total length of `{dataloader_cls_name}` across ranks is zero, but local rank has zero"
                 " length. Please be cautious of uneven batch length."
             )
-        has_len = False
 
-    if has_len and has_iterable_dataset(dataloader):
+    if has_iterable_dataset(dataloader):
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
             " In combination with multi-process data loading (when num_workers > 1),"
             " `__len__` could be inaccurate if each worker is not configured independently"
             " to avoid having duplicate data."
         )
-    return has_len
+    return True
 
 
 def _update_dataloader(
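
A minimal standalone sketch of the new early-exit path, for reviewers who want to see the behavior without a multi-rank setup. The `StreamingBatches` class is hypothetical, and `sized_len` below is a local stand-in for the helper the patched function calls; only the short-circuit logic mirrors the patch:

    from typing import Iterator, Optional

    def sized_len(obj: object) -> Optional[int]:
        # Local stand-in for the helper used by has_len_all_ranks:
        # return len(obj) when it is defined, otherwise None.
        try:
            return len(obj)  # type: ignore[arg-type]
        except TypeError:
            return None

    class StreamingBatches:
        # Hypothetical loader that only supports iteration; it defines no __len__.
        def __iter__(self) -> Iterator[int]:
            yield from range(3)

    assert sized_len([10, 20, 30]) == 3           # sized: the full cross-rank checks still run
    assert sized_len(StreamingBatches()) is None  # non-sized: the patched code returns False here

With the old code, a missing length was replaced by 0 so that the `strategy.reduce` call would not fail, and a `RuntimeError` was raised afterwards on that rank; returning `False` up front skips the reduction entirely and lets callers treat the loader as unsized.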