16 changes: 8 additions & 8 deletions src/litdata/streaming/dataloader.py
@@ -641,10 +641,15 @@ def __init__(
     def __iter__(self) -> Any:
         if not self.restore:
             if (
-                not isinstance(self.dataset, ParallelStreamingDataset)
-                or not self.dataset.is_cycling()
-                or self.current_epoch == 0
+                isinstance(self.dataset, ParallelStreamingDataset)
+                and self.dataset.is_cycling()
+                and self.dataset.resume
+                and self.current_epoch != 0
             ):
+                # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not
+                # want to restart at index 0 at every epoch. So we set them in restore state.
+                self.load_state_dict(self.state_dict())
+            else:
                 self._latest_worker_idx = 0
                 self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
                 self._worker_idx_iter = iter(self._worker_idx)
@@ -686,11 +691,6 @@ def __iter__(self) -> Any:
             else:
                 yield batch

-        # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not want to
-        # restart at index 0 at every epoch. So we set them in restore state.
-        if isinstance(self.dataset, ParallelStreamingDataset) and self.dataset.is_cycling():
-            self.load_state_dict(self.state_dict())
-
         logger.debug(_get_log_msg({"name": "iterating_dataloader", "ph": "E"}))
         self.restore = False

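For orientation, here is a hedged sketch (not part of the PR) of the behaviour this dataloader change targets: when the wrapped `ParallelStreamingDataset` is cycling and `resume` is enabled, the loader now applies its own `state_dict()` at the start of every epoch after the first, so iteration continues the cycle instead of restarting at index 0. The dataset paths, `length`, and `batch_size` below are hypothetical.

```python
# Hedged sketch only; dataset locations, length, and batch size are made up.
from litdata import ParallelStreamingDataset, StreamingDataLoader, StreamingDataset

ds_a = StreamingDataset("data/optimized_a")  # assumed pre-optimized dataset dirs
ds_b = StreamingDataset("data/optimized_b")

# length != None makes the wrapper cycle its inputs; resume defaults to True.
parallel_ds = ParallelStreamingDataset(datasets=[ds_a, ds_b], length=100)
loader = StreamingDataLoader(parallel_ds, batch_size=4)

for epoch in range(2):
    for batch in loader:
        pass  # train on the batch
    # With resume=True, the loader re-applies its own state_dict() when the next
    # epoch starts, so epoch 2 continues the cycle rather than restarting at index 0.
```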
19 changes: 14 additions & 5 deletions src/litdata/streaming/parallel.py
@@ -49,8 +49,9 @@ class ParallelStreamingDataset(_BaseStreamingDatasetWrapper):
     datasets.

     The parallel dataset can be configured to raise a ``StopIteration`` as soon as any of the datasets is exhausted, or
-    to cycle through the datasets until a given number of samples are yielded. When cycling, each epoch resumes from
-    where the previous one left off in the current cycle, i.e. the yielded samples are not the same across epochs.
+    to cycle through the datasets until a given number of samples are yielded. When cycling and using a
+    ``StreamingDataLoader``, the ``resume`` option can be used to either yield the same ``length`` samples in each
+    epoch, or to resume the dataset from where it left off in the previous epoch.

     New data can be generated on-the-fly from a sample from each dataset by providing a ``transform`` function. This
     function can take a single tuple argument containing a sample from each dataset, and optionally a dictionary of
@@ -86,6 +87,7 @@ def __init__(
         force_override_state_dict: bool = False,
         transform: Optional[Transform] = None,
         seed: int = 42,
+        resume: bool = True,
         reset_rngs: bool = False,
     ) -> None:
         """Enable to stream data from multiple StreamingDataset in parallel.
@@ -100,9 +102,12 @@
                 argument a tuple containing one sample from each dataset, and optionally a dictionary of random
                 number generators which are seeded using the current state of the dataset.
             seed: Seed for the random number generators provided to ``transform``.
+            resume: If ``True`` and ``length`` is not ``None``, tells the dataloader to resume the dataset from where it
+                left off in the previous epoch. If ``False``, the same ``length`` samples are yielded in each epoch.
+                Ignored if ``length`` is ``None``.
             reset_rngs: If ``True``, the random number generators provided to ``transform`` are reset to their initial
-                state at the beginning of each epoch. Together with ``length=None`` and ``shuffle=False``, this ensures
-                that the same samples are yielded in each epoch.
+                state at the beginning of each epoch. Together with ``resume=False``, this allows to produce the same
+                samples in each epoch.
         """
         self._check_datasets(datasets)

@@ -132,6 +137,7 @@ def __init__(
         self._current_epoch = 0
         self.num_workers = 1
         self.batch_size = 1
+        self.resume = resume

         if length is not None:
             for dataset in self._datasets:
@@ -154,7 +160,7 @@ def is_infinite(self) -> bool:

     def set_epoch(self, current_epoch: int) -> None:
         self._current_epoch = current_epoch
-        if self.is_cycling():
+        if self.is_cycling() and self.resume:
             # do not set the epoch as cycling datasets have their own epoch counter
             return
         for dataset in self._datasets:
@@ -188,6 +194,9 @@ def get_all_lens(self) -> List[int]:
         return [self._get_len(d) for d in self._datasets]

     def __iter__(self) -> Iterator[Any]:
+        if self.is_cycling() and not self.resume:
+            self.set_epoch(1)
+
         worker_env = _WorkerEnv.detect()

         num_samples_yielded = None
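As a usage note for the new flag documented above, a minimal sketch contrasting the two cycling modes; the paths and `length` are illustrative, not taken from the PR.

```python
# Hedged sketch only; dataset locations and length are made up.
from litdata import ParallelStreamingDataset, StreamingDataset

ds_a = StreamingDataset("data/optimized_a")
ds_b = StreamingDataset("data/optimized_b")

# Default resume=True: each epoch picks up the cycle where the previous one stopped.
resuming = ParallelStreamingDataset(datasets=[ds_a, ds_b], length=50)

# resume=False: __iter__ resets the epoch, so every epoch yields the same 50 samples.
# Combine with reset_rngs=True to also make transform outputs repeat across epochs.
repeating = ParallelStreamingDataset(
    datasets=[ds_a, ds_b], length=50, resume=False, reset_rngs=True
)
```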
17 changes: 1 addition & 16 deletions tests/conftest.py
@@ -9,7 +9,7 @@
 import pytest
 import torch.distributed

-from litdata import CombinedStreamingDataset, ParallelStreamingDataset, StreamingDataset
+from litdata import CombinedStreamingDataset, StreamingDataset
 from litdata.constants import _POLARS_AVAILABLE
 from litdata.streaming.cache import Cache
 from litdata.streaming.reader import PrepareChunksThread
@@ -73,21 +73,6 @@ def combined_dataset(prepare_combined_dataset):
     return CombinedStreamingDataset(datasets=[dataset_1, dataset_2])


-@pytest.fixture
-def parallel_dataset(tmp_path_factory, request):
-    tmpdir = tmp_path_factory.mktemp("data")
-    datasets = [str(tmpdir / f"dataset_{i}") for i in range(2)]
-    for dataset, num_items in zip(datasets, [48, 56]):
-        cache = Cache(input_dir=dataset, chunk_size=10)
-        for i in range(num_items):
-            cache[i] = i
-        cache.done()
-        cache.merge()
-    dataset_1 = StreamingDataset(datasets[0], shuffle=True)
-    dataset_2 = StreamingDataset(datasets[1], shuffle=True)
-    return ParallelStreamingDataset(datasets=[dataset_1, dataset_2], length=request.param), request.param
-
-
 @pytest.fixture
 def google_mock(monkeypatch):
     google = ModuleType("google")