19 changes: 6 additions & 13 deletions src/lightning/pytorch/utilities/data.py
Expand Up @@ -17,7 +17,6 @@

import torch
from lightning_utilities.core.apply_func import is_dataclass_instance
from lightning_utilities.core.rank_zero import rank_prefixed_message
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
from typing_extensions import TypeGuard
@@ -93,13 +92,12 @@ def has_len_all_ranks(
     strategy: "pl.strategies.Strategy",
     allow_zero_length_dataloader_with_multiple_devices: bool = False,
 ) -> TypeGuard[Sized]:
-    """Checks if a given object has ``__len__`` method implemented on all aranks."""
+    """Checks if a given object has ``__len__`` method implemented on all ranks."""
     local_length = sized_len(dataloader)
-    has_len = True
     if local_length is None:
-        # if one rank does not define a length, the reduction after would fail, default to 0
-        local_length = 0
-        has_len = False
+        # __len__ is not defined, skip these checks
+        return False
+
     total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
     if total_length == 0:
         rank_zero_warn(
@@ -108,10 +106,6 @@
         )
     if total_length > 0 and local_length == 0:
         dataloader_cls_name = type(dataloader).__name__
-        if not has_len:
-            raise RuntimeError(
-                rank_prefixed_message(f"The `{dataloader_cls_name}` does not define a length.", strategy.global_rank)
-            )
         if not allow_zero_length_dataloader_with_multiple_devices:
             raise RuntimeError(
                 f"`{dataloader_cls_name}` within local rank has zero length."
@@ -121,16 +115,15 @@
                 f"Total length of `{dataloader_cls_name}` across ranks is zero, but local rank has zero"
                 " length. Please be cautious of uneven batch length."
             )
-            has_len = False
 
-    if has_len and has_iterable_dataset(dataloader):
+    if has_iterable_dataset(dataloader):
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
             " In combination with multi-process data loading (when num_workers > 1),"
             " `__len__` could be inaccurate if each worker is not configured independently"
             " to avoid having duplicate data."
         )
-    return has_len
+    return True


 def _update_dataloader(
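
A minimal sketch (not Lightning's actual implementation) of the control flow after this change, assuming a hypothetical single-rank `FakeStrategy` whose `reduce` simply returns the local value: a dataloader that does not define `__len__` now makes the check return `False` early instead of raising a rank-prefixed `RuntimeError`, and the explicit `has_len` bookkeeping disappears.

```python
from typing import Optional


def sized_len(obj: object) -> Optional[int]:
    # Mirrors the `sized_len` helper used in the diff: len() when defined, else None.
    try:
        return len(obj)  # type: ignore[arg-type]
    except TypeError:
        return None


class FakeStrategy:
    # Hypothetical single-rank stand-in for `strategy.reduce`: a "sum" over one rank
    # is just the local value.
    def reduce(self, value: int, reduce_op: str = "sum") -> int:
        return value


def has_len_all_ranks_sketch(dataloader: object, strategy: FakeStrategy) -> bool:
    local_length = sized_len(dataloader)
    if local_length is None:
        # __len__ is not defined, skip these checks (the early return introduced here)
        return False
    total_length = strategy.reduce(local_length, reduce_op="sum")
    if total_length == 0:
        print(f"Total length of `{type(dataloader).__name__}` across ranks is zero.")
    return True


class NoLen:
    def __iter__(self):
        return iter(range(3))


print(has_len_all_ranks_sketch([1, 2, 3], FakeStrategy()))  # True
print(has_len_all_ranks_sketch(NoLen(), FakeStrategy()))    # False (previously raised)
```

The real function additionally warns about zero-length dataloaders on some ranks and about `IterableDataset` objects that define `__len__`, as shown in the diff above.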