From efa307c10214730dad3ea97ce6bd09839c49063c Mon Sep 17 00:00:00 2001 From: philgzl Date: Fri, 4 Jul 2025 15:44:42 +0200 Subject: [PATCH 1/4] Add resume option to ParallelStreamingDataset --- src/litdata/streaming/dataloader.py | 16 +- src/litdata/streaming/parallel.py | 14 +- tests/conftest.py | 17 +- tests/streaming/test_parallel.py | 551 +++++++++++++--------------- 4 files changed, 269 insertions(+), 329 deletions(-) diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index d83808e73..fc6271f49 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -641,10 +641,15 @@ def __init__( def __iter__(self) -> Any: if not self.restore: if ( - not isinstance(self.dataset, ParallelStreamingDataset) - or not self.dataset.is_cycling() - or self.current_epoch == 0 + isinstance(self.dataset, ParallelStreamingDataset) + and self.dataset.is_cycling() + and self.dataset.resume + and self.current_epoch != 0 ): + # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not + # want to restart at index 0 at every epoch. So we set them in restore state. + self.load_state_dict(self.state_dict()) + else: self._latest_worker_idx = 0 self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) self._worker_idx_iter = iter(self._worker_idx) @@ -686,11 +691,6 @@ def __iter__(self) -> Any: else: yield batch - # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not want to - # restart at index 0 at every epoch. So we set them in restore state. - if isinstance(self.dataset, ParallelStreamingDataset) and self.dataset.is_cycling(): - self.load_state_dict(self.state_dict()) - logger.debug(_get_log_msg({"name": "iterating_dataloader", "ph": "E"})) self.restore = False diff --git a/src/litdata/streaming/parallel.py b/src/litdata/streaming/parallel.py index 9d3e6fb2f..9c36d2c12 100644 --- a/src/litdata/streaming/parallel.py +++ b/src/litdata/streaming/parallel.py @@ -49,8 +49,9 @@ class ParallelStreamingDataset(_BaseStreamingDatasetWrapper): datasets. The parallel dataset can be configured to raise a ``StopIteration`` as soon as any of the datasets is exhausted, or - to cycle through the datasets until a given number of samples are yielded. When cycling, each epoch resumes from - where the previous one left off in the current cycle, i.e. the yielded samples are not the same across epochs. + to cycle through the datasets until a given number of samples are yielded. When cycling and using a + ``StreamingDataLoader``, the ``resume`` option can be used to either yield the same ``length`` samples in each + epoch, or to resume the dataset from where it left off in the previous epoch. New data can be generated on-the-fly from a sample from each dataset by providing a ``transform`` function. This function can take a single tuple argument containing a sample from each dataset, and optionally a dictionary of @@ -86,6 +87,7 @@ def __init__( force_override_state_dict: bool = False, transform: Optional[Transform] = None, seed: int = 42, + resume: bool = True, reset_rngs: bool = False, ) -> None: """Enable to stream data from multiple StreamingDataset in parallel. @@ -100,9 +102,12 @@ def __init__( argument a tuple containing one sample from each dataset, and optionally a dictionary of random number generators which are seeded using the current state of the dataset. seed: Seed for the random number generators provided to ``transform``. + resume: If ``True`` and ``length`` is not ``None``, tells the dataloader to resume the dataset from where it + left off in the previous epoch. If ``False``, the same ``length`` samples are yielded in each epoch. + Ignored if ``length`` is ``None`` or if the dataset is used without a ``StreamingDataLoader``. reset_rngs: If ``True``, the random number generators provided to ``transform`` are reset to their initial - state at the beginning of each epoch. Together with ``length=None`` and ``shuffle=False``, this ensures - that the same samples are yielded in each epoch. + state at the beginning of each epoch. Together with ``resume=False``, this allows to produce the same + samples in each epoch. """ self._check_datasets(datasets) @@ -132,6 +137,7 @@ def __init__( self._current_epoch = 0 self.num_workers = 1 self.batch_size = 1 + self.resume = resume if length is not None: for dataset in self._datasets: diff --git a/tests/conftest.py b/tests/conftest.py index 7390a29b7..f306a78ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ import pytest import torch.distributed -from litdata import CombinedStreamingDataset, ParallelStreamingDataset, StreamingDataset +from litdata import CombinedStreamingDataset, StreamingDataset from litdata.constants import _POLARS_AVAILABLE from litdata.streaming.cache import Cache from litdata.streaming.reader import PrepareChunksThread @@ -73,21 +73,6 @@ def combined_dataset(prepare_combined_dataset): return CombinedStreamingDataset(datasets=[dataset_1, dataset_2]) -@pytest.fixture -def parallel_dataset(tmp_path_factory, request): - tmpdir = tmp_path_factory.mktemp("data") - datasets = [str(tmpdir / f"dataset_{i}") for i in range(2)] - for dataset, num_items in zip(datasets, [48, 56]): - cache = Cache(input_dir=dataset, chunk_size=10) - for i in range(num_items): - cache[i] = i - cache.done() - cache.merge() - dataset_1 = StreamingDataset(datasets[0], shuffle=True) - dataset_2 = StreamingDataset(datasets[1], shuffle=True) - return ParallelStreamingDataset(datasets=[dataset_1, dataset_2], length=request.param), request.param - - @pytest.fixture def google_mock(monkeypatch): google = ModuleType("google") diff --git a/tests/streaming/test_parallel.py b/tests/streaming/test_parallel.py index 5ad874a62..83e48f73f 100644 --- a/tests/streaming/test_parallel.py +++ b/tests/streaming/test_parallel.py @@ -1,5 +1,4 @@ import functools -import os import sys from copy import deepcopy from unittest.mock import ANY, MagicMock @@ -10,7 +9,7 @@ from litdata.streaming.cache import Cache from litdata.streaming.dataloader import StreamingDataLoader -from litdata.streaming.dataset import Dir, StreamingDataset +from litdata.streaming.dataset import StreamingDataset from litdata.streaming.parallel import ParallelStreamingDataset @@ -321,7 +320,7 @@ def test_parallel_dataset_with_dataloader_and_one_worker(batch_size, length, exp "1": {"num_samples_yielded": num_samples_yielded[1], "num_workers": 1, "batch_size": batch_size}, }, "current_epoch": 1, - "latest_worker_idx": 0 if length is None else 1, + "latest_worker_idx": 0, "num_samples_yielded": {0: num_samples_yielded}, "num_cycles": {0: num_cycles}, } @@ -422,10 +421,27 @@ def test_dataloader_shuffle(tmp_path, shuffle): assert shuffle ^ all(torch.equal(x, y) for x, y in zip(epoch_1_batches[:3], epoch_2_batches[-3:])) -@pytest.mark.parametrize("parallel_dataset", [None, 3, float("inf")], indirect=True) -def test_parallel_dataset_dataloader_states_without_any_iterations(parallel_dataset): - parallel_dataset, _ = parallel_dataset - dataloader = StreamingDataLoader(parallel_dataset, batch_size=4) +def prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen, len1=48, len2=56, num_workers=0, batch_size=4, shuffle=True, resume=True +): + tmpdir = tmp_path_factory.mktemp("data") + datasets = [str(tmpdir / f"dataset_{i}") for i in range(2)] + for dataset, num_items in zip(datasets, [len1, len2]): + cache = Cache(input_dir=dataset, chunk_size=10) + for i in range(num_items): + cache[i] = i + cache.done() + cache.merge() + dset1 = StreamingDataset(datasets[0], shuffle=shuffle) + dset2 = StreamingDataset(datasets[1], shuffle=shuffle) + pardset = ParallelStreamingDataset(datasets=[dset1, dset2], length=parlen, resume=resume) + dloader = StreamingDataLoader(pardset, num_workers=num_workers, batch_size=batch_size) + return dset1, dset2, pardset, dloader + + +@pytest.mark.parametrize("length", [None, 3, float("inf")]) +def test_parallel_dataset_dataloader_states_without_any_iterations(tmp_path_factory, length): + _, _, _, dataloader = prepare_parallel_dataset_and_dataloder(tmp_path_factory, length) assert not dataloader.restore dataloader.load_state_dict(dataloader.state_dict()) assert not dataloader.restore @@ -434,16 +450,19 @@ def test_parallel_dataset_dataloader_states_without_any_iterations(parallel_data @pytest.mark.timeout(120) +@pytest.mark.parametrize("length", [None, 24]) @pytest.mark.parametrize("num_workers", [0, 2]) -@pytest.mark.parametrize("parallel_dataset", [None, 24], indirect=True) +@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") -def test_parallel_dataset_dataloader_states_complete_iterations(parallel_dataset, num_workers): +def test_parallel_dataset_dataloader_states_complete_iterations(tmp_path_factory, length, num_workers, batch_size): print(f"Testing with num_workers={num_workers}") - parallel_dataset, length = parallel_dataset - batch_size = 2 - - dataloader = StreamingDataLoader(parallel_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) + _, _, parallel_dataset, dataloader = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, + length, + batch_size=batch_size, + num_workers=num_workers, + ) assert len(dataloader) == -(-len(parallel_dataset) // batch_size) @@ -494,18 +513,19 @@ def test_parallel_dataset_dataloader_states_complete_iterations(parallel_dataset @pytest.mark.timeout(300) +@pytest.mark.parametrize("length", [None, 20, 48]) @pytest.mark.parametrize("num_workers", [0, 2]) +@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("break_at", [3, 7]) -@pytest.mark.parametrize("parallel_dataset", [None, 20, 48], indirect=True) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") -def test_parallel_dataset_dataloader_states_partial_iterations(parallel_dataset, num_workers, break_at): +def test_parallel_dataset_dataloader_states_partial_iterations( + tmp_path_factory, length, num_workers, batch_size, break_at +): print(f"Testing with num_workers={num_workers}, break_at={break_at}") - parallel_dataset, _ = parallel_dataset - batch_size = 2 - - # Verify dataloader state after partial last iteration - dataloader = StreamingDataLoader(parallel_dataset, batch_size=batch_size, num_workers=num_workers) + _, _, parallel_dataset, dataloader = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, length, batch_size=batch_size, num_workers=num_workers, shuffle=True + ) total_batches = len(dataloader) assert total_batches == -(-len(parallel_dataset) // batch_size) @@ -542,38 +562,146 @@ def test_parallel_dataset_dataloader_states_partial_iterations(parallel_dataset, assert samples_yielded == len(parallel_dataset), "All samples should be yielded in the second epoch." +TEST_MAPS = [ + { + "length": None, + "len1": 12, + "len2": 14, + "batch_size": 2, + "num_workers": 3, + "epoch_1_expected_num_samples_yielded": [ + {0: [2, 2]}, + {0: [2, 2], 1: [2, 2]}, + {0: [2, 2], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [4, 4]}, + ], + "epoch_1_expected_num_cycles": [ + {0: [0, 0]}, + {0: [0, 0], 1: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + ], + "epoch_1_expected_current_epoch": [1, 1, 1, 1, 1, 1], + "epoch_1_expected_dataset_1_current_epoch": [1, 1, 1, 1, 1, 1], + "epoch_1_expected_dataset_2_current_epoch": [1, 1, 1, 1, 1, 1], + "epoch_1_expected_latest_worker_idx": [0, 1, 2, 0, 1, 2], + "epoch_1_expected_dataset0_samples_yielded": [2, 4, 6, 8, 10, 12], + "epoch_1_expected_dataset1_samples_yielded": [2, 4, 6, 8, 10, 12], + "epoch_2_expected_num_samples_yielded": [ + {0: [2, 2]}, + {0: [2, 2], 1: [2, 2]}, + {0: [2, 2], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [4, 4]}, + ], + "epoch_2_expected_num_cycles": [ + {0: [0, 0]}, + {0: [0, 0], 1: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + ], + "epoch_2_expected_current_epoch": [2, 2, 2, 2, 2, 2], + "epoch_2_expected_dataset_1_current_epoch": [2, 2, 2, 2, 2, 2], + "epoch_2_expected_dataset_2_current_epoch": [2, 2, 2, 2, 2, 2], + "epoch_2_expected_latest_worker_idx": [0, 1, 2, 0, 1, 2], + "epoch_2_expected_dataset0_samples_yielded": [2, 4, 6, 8, 10, 12], + "epoch_2_expected_dataset1_samples_yielded": [2, 4, 6, 8, 10, 12], + }, + { + "length": 12, + "len1": 18, + "len2": 20, + "batch_size": 2, + "num_workers": 3, + "epoch_1_expected_num_samples_yielded": [ + {0: [2, 2]}, + {0: [2, 2], 1: [2, 2]}, + {0: [2, 2], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [4, 4]}, + ], + "epoch_1_expected_num_cycles": [ + {0: [0, 0]}, + {0: [0, 0], 1: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + ], + "epoch_1_expected_current_epoch": [1, 1, 1, 1, 1, 1], + "epoch_1_expected_dataset_1_current_epoch": [1, 1, 1, 1, 1, 1], + "epoch_1_expected_dataset_2_current_epoch": [1, 1, 1, 1, 1, 1], + "epoch_1_expected_latest_worker_idx": [0, 1, 2, 0, 1, 2], + "epoch_1_expected_dataset0_samples_yielded": [2, 4, 6, 8, 10, 12], + "epoch_1_expected_dataset1_samples_yielded": [2, 4, 6, 8, 10, 12], + "epoch_2_expected_num_samples_yielded": [ + {0: [6, 6], 1: [4, 4], 2: [4, 4]}, + {0: [6, 6], 1: [6, 6], 2: [4, 4]}, + {0: [6, 6], 1: [6, 6], 2: [6, 6]}, + {0: [2, 8], 1: [6, 6], 2: [6, 6]}, + {0: [2, 8], 1: [2, 2], 2: [6, 6]}, + {0: [2, 8], 1: [2, 2], 2: [2, 2]}, + ], + "epoch_2_expected_num_cycles": [ + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [1, 0], 1: [0, 0], 2: [0, 0]}, + {0: [1, 0], 1: [1, 1], 2: [0, 0]}, + {0: [1, 0], 1: [1, 1], 2: [1, 1]}, + ], + "epoch_2_expected_current_epoch": [2, 2, 2, 2, 2, 2], + "epoch_2_expected_dataset_1_current_epoch": [1, 1, 1, 2, 2, 2], + "epoch_2_expected_dataset_2_current_epoch": [1, 1, 1, 1, 2, 2], + "epoch_2_expected_latest_worker_idx": [0, 1, 2, 0, 1, 2], + "epoch_2_expected_dataset0_samples_yielded": [14, 16, 18, 2, 4, 6], + "epoch_2_expected_dataset1_samples_yielded": [14, 16, 18, 20, 2, 4], + }, +] + + +@pytest.mark.parametrize(TEST_MAPS[0].keys(), [x.values() for x in TEST_MAPS]) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") -def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): - data_dir_1 = str(tmp_path / "data_1") - data_dir_2 = str(tmp_path / "data_2") - cache_dir_1 = str(tmp_path / "cache_dir_1") - cache_dir_2 = str(tmp_path / "cache_dir_2") - - os.makedirs(data_dir_1) - os.makedirs(data_dir_2) - os.makedirs(cache_dir_1) - os.makedirs(cache_dir_2) - - cache = Cache(input_dir=data_dir_1, chunk_size=2) - - for i in range(12): - cache[i] = i - - cache.done() - cache.merge() - - cache = Cache(input_dir=data_dir_2, chunk_size=2) - - for i in range(14): - cache[i] = -i - - cache.done() - cache.merge() - - dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True) - dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True) - dataset = ParallelStreamingDataset(datasets=[dataset1, dataset2], length=None) - dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2) +def test_parallel_dataset_with_dataloader_2_epochs( + tmp_path_factory, + length, + len1, + len2, + batch_size, + num_workers, + epoch_1_expected_num_samples_yielded, + epoch_1_expected_num_cycles, + epoch_1_expected_current_epoch, + epoch_1_expected_dataset_1_current_epoch, + epoch_1_expected_dataset_2_current_epoch, + epoch_1_expected_latest_worker_idx, + epoch_1_expected_dataset0_samples_yielded, + epoch_1_expected_dataset1_samples_yielded, + epoch_2_expected_num_samples_yielded, + epoch_2_expected_num_cycles, + epoch_2_expected_current_epoch, + epoch_2_expected_dataset_1_current_epoch, + epoch_2_expected_dataset_2_current_epoch, + epoch_2_expected_latest_worker_idx, + epoch_2_expected_dataset0_samples_yielded, + epoch_2_expected_dataset1_samples_yielded, +): + dataset1, dataset2, _, dataloader = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, + length, + len1, + len2, + batch_size=batch_size, + num_workers=num_workers, + ) assert dataset1.current_epoch == 1 assert dataset2.current_epoch == 1 @@ -582,8 +710,8 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): "dataset": { "0": { "num_samples_yielded": 0, - "num_workers": 3, - "batch_size": 2, + "num_workers": num_workers, + "batch_size": batch_size, "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, @@ -598,8 +726,8 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): }, "1": { "num_samples_yielded": 0, - "num_workers": 3, - "batch_size": 2, + "num_workers": num_workers, + "batch_size": batch_size, "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, @@ -618,28 +746,6 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): "num_samples_yielded": {}, "num_cycles": {}, } - expected_num_samples_yielded = [ - {0: [2, 2]}, - {0: [2, 2], 1: [2, 2]}, - {0: [2, 2], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [4, 4]}, - ] - expected_num_cycles = [ - {0: [0, 0]}, - {0: [0, 0], 1: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - ] - expected_current_epoch = [1, 1, 1, 1, 1, 1] - dataset_1_current_epoch = [1, 1, 1, 1, 1, 1] - dataset_2_current_epoch = [1, 1, 1, 1, 1, 1] - expected_latest_worker_idx = [0, 1, 2, 0, 1, 2] - expected_dataset0_samples_yielded = [2, 4, 6, 8, 10, 12] - expected_dataset1_samples_yielded = [2, 4, 6, 8, 10, 12] batches_1 = [] @@ -647,14 +753,14 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): batches_1.append(batch) curr_state_dict = dataloader.state_dict() - expected_dataset_state["num_samples_yielded"] = expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = dataset_2_current_epoch[idx] + expected_dataset_state["num_samples_yielded"] = epoch_1_expected_num_samples_yielded[idx] + expected_dataset_state["num_cycles"] = epoch_1_expected_num_cycles[idx] + expected_dataset_state["current_epoch"] = epoch_1_expected_current_epoch[idx] + expected_dataset_state["latest_worker_idx"] = epoch_1_expected_latest_worker_idx[idx] + expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = epoch_1_expected_dataset0_samples_yielded[idx] + expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = epoch_1_expected_dataset1_samples_yielded[idx] + expected_dataset_state["dataset"]["0"]["current_epoch"] = epoch_1_expected_dataset_1_current_epoch[idx] + expected_dataset_state["dataset"]["1"]["current_epoch"] = epoch_1_expected_dataset_2_current_epoch[idx] assert curr_state_dict == expected_dataset_state @@ -664,45 +770,24 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): saved_dataloader_state_dict = None batches_2 = [] + save_at = 2 - expected_num_samples_yielded = [ - {0: [2, 2]}, - {0: [2, 2], 1: [2, 2]}, - {0: [2, 2], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [4, 4]}, - ] - expected_num_cycles = [ - {0: [0, 0]}, - {0: [0, 0], 1: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - ] - expected_current_epoch = [2, 2, 2, 2, 2, 2] - dataset_1_current_epoch = [2, 2, 2, 2, 2, 2] - dataset_2_current_epoch = [2, 2, 2, 2, 2, 2] - expected_latest_worker_idx = [0, 1, 2, 0, 1, 2] - expected_dataset0_samples_yielded = [2, 4, 6, 8, 10, 12] - expected_dataset1_samples_yielded = [2, 4, 6, 8, 10, 12] for idx, batch in enumerate(dataloader): batches_2.append(batch) curr_state_dict = dataloader.state_dict() - expected_dataset_state["num_samples_yielded"] = expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = dataset_2_current_epoch[idx] + expected_dataset_state["num_samples_yielded"] = epoch_2_expected_num_samples_yielded[idx] + expected_dataset_state["num_cycles"] = epoch_2_expected_num_cycles[idx] + expected_dataset_state["current_epoch"] = epoch_2_expected_current_epoch[idx] + expected_dataset_state["latest_worker_idx"] = epoch_2_expected_latest_worker_idx[idx] + expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = epoch_2_expected_dataset0_samples_yielded[idx] + expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = epoch_2_expected_dataset1_samples_yielded[idx] + expected_dataset_state["dataset"]["0"]["current_epoch"] = epoch_2_expected_dataset_1_current_epoch[idx] + expected_dataset_state["dataset"]["1"]["current_epoch"] = epoch_2_expected_dataset_2_current_epoch[idx] assert curr_state_dict == expected_dataset_state - if idx == 2: + if idx == save_at: saved_dataloader_state_dict = deepcopy(curr_state_dict) assert dataset1.current_epoch == 2 @@ -717,200 +802,64 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): assert dataloader.restore batches_23 = [] - states_23 = [] - for batch in dataloader: + for idx, batch in enumerate(dataloader): batches_23.append(batch) - states_23.append(dataloader.state_dict()) - - assert len(batches_2[3:]) == len(batches_23) - assert all(torch.equal(x1, x2) for b1, b2 in zip(batches_2[3:], batches_23) for x1, x2 in zip(b1, b2)) - assert states_23[0]["current_epoch"] == 2 - - assert not dataloader.restore - - -@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") -def test_parallel_dataset_with_dataloader_2_epochs_int_length(tmp_path): - data_dir_1 = str(tmp_path / "data_1") - data_dir_2 = str(tmp_path / "data_2") - cache_dir_1 = str(tmp_path / "cache_dir_1") - cache_dir_2 = str(tmp_path / "cache_dir_2") - os.makedirs(data_dir_1) - os.makedirs(data_dir_2) - os.makedirs(cache_dir_1) - os.makedirs(cache_dir_2) - - cache = Cache(input_dir=data_dir_1, chunk_size=2) - - for i in range(18): - cache[i] = i - - cache.done() - cache.merge() - - cache = Cache(input_dir=data_dir_2, chunk_size=2) - - for i in range(20): - cache[i] = -i - - cache.done() - cache.merge() - - dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True) - dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True) - dataset = ParallelStreamingDataset(datasets=[dataset1, dataset2], length=12) - dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2) - - assert dataset1.current_epoch == 1 - assert dataset2.current_epoch == 1 - - expected_dataset_state = { - "dataset": { - "0": { - "num_samples_yielded": 0, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 1, - "input_dir_path": ANY, - "input_dir_url": ANY, - "cache_dir_path": None, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - "1": { - "num_samples_yielded": 0, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 1, - "input_dir_path": ANY, - "input_dir_url": ANY, - "cache_dir_path": None, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - }, - "current_epoch": 1, - "latest_worker_idx": 0, - "num_samples_yielded": {}, - "num_cycles": {}, - } - expected_num_samples_yielded = [ - {0: [2, 2]}, - {0: [2, 2], 1: [2, 2]}, - {0: [2, 2], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [4, 4]}, - ] - expected_num_cycles = [ - {0: [0, 0]}, - {0: [0, 0], 1: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - ] - expected_current_epoch = [1, 1, 1, 1, 1, 1] - dataset_1_current_epoch = [1, 1, 1, 1, 1, 1] - dataset_2_current_epoch = [1, 1, 1, 1, 1, 1] - expected_latest_worker_idx = [0, 1, 2, 0, 1, 2] - expected_dataset0_samples_yielded = [2, 4, 6, 8, 10, 12] - expected_dataset1_samples_yielded = [2, 4, 6, 8, 10, 12] - - batches_1 = [] - - for idx, batch in enumerate(dataloader): - batches_1.append(batch) curr_state_dict = dataloader.state_dict() - expected_dataset_state["num_samples_yielded"] = expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = dataset_2_current_epoch[idx] + idx = idx + save_at + 1 + expected_dataset_state["num_samples_yielded"] = epoch_2_expected_num_samples_yielded[idx] + expected_dataset_state["num_cycles"] = epoch_2_expected_num_cycles[idx] + expected_dataset_state["current_epoch"] = epoch_2_expected_current_epoch[idx] + expected_dataset_state["latest_worker_idx"] = epoch_2_expected_latest_worker_idx[idx] + expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = epoch_2_expected_dataset0_samples_yielded[idx] + expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = epoch_2_expected_dataset1_samples_yielded[idx] + expected_dataset_state["dataset"]["0"]["current_epoch"] = epoch_2_expected_dataset_1_current_epoch[idx] + expected_dataset_state["dataset"]["1"]["current_epoch"] = epoch_2_expected_dataset_2_current_epoch[idx] assert curr_state_dict == expected_dataset_state - assert dataset1.current_epoch == 1 - assert dataset2.current_epoch == 1 + assert len(batches_2[save_at + 1 :]) == len(batches_23) + assert all(torch.equal(x1, x2) for b1, b2 in zip(batches_2[save_at + 1 :], batches_23) for x1, x2 in zip(b1, b2)) - saved_dataloader_state_dict = None + assert not dataloader.restore - batches_2 = [] - expected_num_samples_yielded = [ - {0: [6, 6], 1: [4, 4], 2: [4, 4]}, - {0: [6, 6], 1: [6, 6], 2: [4, 4]}, - {0: [6, 6], 1: [6, 6], 2: [6, 6]}, - {0: [2, 8], 1: [6, 6], 2: [6, 6]}, - {0: [2, 8], 1: [2, 2], 2: [6, 6]}, - {0: [2, 8], 1: [2, 2], 2: [2, 2]}, +@pytest.mark.parametrize("length", [None, 16, float("inf")]) +@pytest.mark.parametrize("resume", [False, True]) +@pytest.mark.parametrize("shuffle", [False, True]) +def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, resume, shuffle): + _, _, pardset, dloader = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume + ) + assert pardset.is_cycling() or length is None + break_at = 3 + expected_1 = [ + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([5]), torch.tensor([5])], + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([6]), torch.tensor([6])], ] - expected_num_cycles = [ - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [1, 0], 1: [0, 0], 2: [0, 0]}, - {0: [1, 0], 1: [1, 1], 2: [0, 0]}, - {0: [1, 0], 1: [1, 1], 2: [1, 1]}, + batches_1 = [] + for i, batch in enumerate(dloader): + if not shuffle: + assert all(torch.equal(x, y) for x, y in zip(batch, expected_1[i])) + batches_1.append(batch) + if i == break_at: + break + expected_2 = [ + [torch.tensor([2]), torch.tensor([2])], + [torch.tensor([7]), torch.tensor([7])], + [torch.tensor([3]), torch.tensor([3])], + [torch.tensor([8]), torch.tensor([8])], ] - expected_current_epoch = [2, 2, 2, 2, 2, 2] - dataset_1_current_epoch = [1, 1, 1, 2, 2, 2] - dataset_2_current_epoch = [1, 1, 1, 1, 2, 2] - expected_latest_worker_idx = [0, 1, 2, 0, 1, 2, 0] - expected_dataset0_samples_yielded = [14, 16, 18, 2, 4, 6] - expected_dataset1_samples_yielded = [14, 16, 18, 20, 2, 4] - for idx, batch in enumerate(dataloader): - batches_2.append(batch) - curr_state_dict = dataloader.state_dict() - - expected_dataset_state["num_samples_yielded"] = expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = dataset_2_current_epoch[idx] - - assert curr_state_dict == expected_dataset_state - - if idx == 2: - saved_dataloader_state_dict = deepcopy(curr_state_dict) - - assert dataset1.current_epoch == 2 - assert dataset2.current_epoch == 2 - - assert len(batches_1) == len(batches_2) - assert any(not torch.equal(x1, x2) for b1, b2 in zip(batches_1, batches_2) for x1, x2 in zip(b1, b2)) - - assert saved_dataloader_state_dict is not None - dataloader.load_state_dict(saved_dataloader_state_dict) - - assert dataloader.restore - - batches_23 = [] - states_23 = [] - for batch in dataloader: - batches_23.append(batch) - states_23.append(dataloader.state_dict()) - - assert len(batches_2[3:]) == len(batches_23) - assert all(torch.equal(x1, x2) for b1, b2 in zip(batches_2[3:], batches_23) for x1, x2 in zip(b1, b2)) - assert states_23[0]["current_epoch"] == 2 - - assert not dataloader.restore + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + if i == break_at: + break From b5119e8530aedb8edebda7e787e7477af3648ce1 Mon Sep 17 00:00:00 2001 From: philgzl Date: Mon, 7 Jul 2025 17:13:42 +0200 Subject: [PATCH 2/4] Add dataclass in test for readability --- tests/streaming/test_parallel.py | 273 ++++++++++++++++--------------- 1 file changed, 140 insertions(+), 133 deletions(-) diff --git a/tests/streaming/test_parallel.py b/tests/streaming/test_parallel.py index 83e48f73f..c5cf76788 100644 --- a/tests/streaming/test_parallel.py +++ b/tests/streaming/test_parallel.py @@ -1,6 +1,7 @@ import functools import sys from copy import deepcopy +from dataclasses import dataclass from unittest.mock import ANY, MagicMock import pytest @@ -562,6 +563,18 @@ def test_parallel_dataset_dataloader_states_partial_iterations( assert samples_yielded == len(parallel_dataset), "All samples should be yielded in the second epoch." +@dataclass +class ExpectedStates: + num_samples_yielded: list[dict[int, list[int]]] + num_cycles: list[dict[int, list[int]]] + latest_worker_idx: list[int] + current_epoch: list[int] + dset_1_current_epoch: list[int] + dset_2_current_epoch: list[int] + dset_1_samples_yielded: list[int] + dset_2_samples_yielded: list[int] + + TEST_MAPS = [ { "length": None, @@ -569,50 +582,54 @@ def test_parallel_dataset_dataloader_states_partial_iterations( "len2": 14, "batch_size": 2, "num_workers": 3, - "epoch_1_expected_num_samples_yielded": [ - {0: [2, 2]}, - {0: [2, 2], 1: [2, 2]}, - {0: [2, 2], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [4, 4]}, - ], - "epoch_1_expected_num_cycles": [ - {0: [0, 0]}, - {0: [0, 0], 1: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - ], - "epoch_1_expected_current_epoch": [1, 1, 1, 1, 1, 1], - "epoch_1_expected_dataset_1_current_epoch": [1, 1, 1, 1, 1, 1], - "epoch_1_expected_dataset_2_current_epoch": [1, 1, 1, 1, 1, 1], - "epoch_1_expected_latest_worker_idx": [0, 1, 2, 0, 1, 2], - "epoch_1_expected_dataset0_samples_yielded": [2, 4, 6, 8, 10, 12], - "epoch_1_expected_dataset1_samples_yielded": [2, 4, 6, 8, 10, 12], - "epoch_2_expected_num_samples_yielded": [ - {0: [2, 2]}, - {0: [2, 2], 1: [2, 2]}, - {0: [2, 2], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [4, 4]}, - ], - "epoch_2_expected_num_cycles": [ - {0: [0, 0]}, - {0: [0, 0], 1: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - ], - "epoch_2_expected_current_epoch": [2, 2, 2, 2, 2, 2], - "epoch_2_expected_dataset_1_current_epoch": [2, 2, 2, 2, 2, 2], - "epoch_2_expected_dataset_2_current_epoch": [2, 2, 2, 2, 2, 2], - "epoch_2_expected_latest_worker_idx": [0, 1, 2, 0, 1, 2], - "epoch_2_expected_dataset0_samples_yielded": [2, 4, 6, 8, 10, 12], - "epoch_2_expected_dataset1_samples_yielded": [2, 4, 6, 8, 10, 12], + "expected_states_1": ExpectedStates( + num_samples_yielded=[ + {0: [2, 2]}, + {0: [2, 2], 1: [2, 2]}, + {0: [2, 2], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [4, 4]}, + ], + num_cycles=[ + {0: [0, 0]}, + {0: [0, 0], 1: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + ], + latest_worker_idx=[0, 1, 2, 0, 1, 2], + current_epoch=[1, 1, 1, 1, 1, 1], + dset_1_current_epoch=[1, 1, 1, 1, 1, 1], + dset_2_current_epoch=[1, 1, 1, 1, 1, 1], + dset_1_samples_yielded=[2, 4, 6, 8, 10, 12], + dset_2_samples_yielded=[2, 4, 6, 8, 10, 12], + ), + "expected_states_2": ExpectedStates( + num_samples_yielded=[ + {0: [2, 2]}, + {0: [2, 2], 1: [2, 2]}, + {0: [2, 2], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [4, 4]}, + ], + num_cycles=[ + {0: [0, 0]}, + {0: [0, 0], 1: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + ], + latest_worker_idx=[0, 1, 2, 0, 1, 2], + current_epoch=[2, 2, 2, 2, 2, 2], + dset_1_current_epoch=[2, 2, 2, 2, 2, 2], + dset_2_current_epoch=[2, 2, 2, 2, 2, 2], + dset_1_samples_yielded=[2, 4, 6, 8, 10, 12], + dset_2_samples_yielded=[2, 4, 6, 8, 10, 12], + ), }, { "length": 12, @@ -620,50 +637,54 @@ def test_parallel_dataset_dataloader_states_partial_iterations( "len2": 20, "batch_size": 2, "num_workers": 3, - "epoch_1_expected_num_samples_yielded": [ - {0: [2, 2]}, - {0: [2, 2], 1: [2, 2]}, - {0: [2, 2], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [4, 4]}, - ], - "epoch_1_expected_num_cycles": [ - {0: [0, 0]}, - {0: [0, 0], 1: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - ], - "epoch_1_expected_current_epoch": [1, 1, 1, 1, 1, 1], - "epoch_1_expected_dataset_1_current_epoch": [1, 1, 1, 1, 1, 1], - "epoch_1_expected_dataset_2_current_epoch": [1, 1, 1, 1, 1, 1], - "epoch_1_expected_latest_worker_idx": [0, 1, 2, 0, 1, 2], - "epoch_1_expected_dataset0_samples_yielded": [2, 4, 6, 8, 10, 12], - "epoch_1_expected_dataset1_samples_yielded": [2, 4, 6, 8, 10, 12], - "epoch_2_expected_num_samples_yielded": [ - {0: [6, 6], 1: [4, 4], 2: [4, 4]}, - {0: [6, 6], 1: [6, 6], 2: [4, 4]}, - {0: [6, 6], 1: [6, 6], 2: [6, 6]}, - {0: [2, 8], 1: [6, 6], 2: [6, 6]}, - {0: [2, 8], 1: [2, 2], 2: [6, 6]}, - {0: [2, 8], 1: [2, 2], 2: [2, 2]}, - ], - "epoch_2_expected_num_cycles": [ - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [1, 0], 1: [0, 0], 2: [0, 0]}, - {0: [1, 0], 1: [1, 1], 2: [0, 0]}, - {0: [1, 0], 1: [1, 1], 2: [1, 1]}, - ], - "epoch_2_expected_current_epoch": [2, 2, 2, 2, 2, 2], - "epoch_2_expected_dataset_1_current_epoch": [1, 1, 1, 2, 2, 2], - "epoch_2_expected_dataset_2_current_epoch": [1, 1, 1, 1, 2, 2], - "epoch_2_expected_latest_worker_idx": [0, 1, 2, 0, 1, 2], - "epoch_2_expected_dataset0_samples_yielded": [14, 16, 18, 2, 4, 6], - "epoch_2_expected_dataset1_samples_yielded": [14, 16, 18, 20, 2, 4], + "expected_states_1": ExpectedStates( + num_samples_yielded=[ + {0: [2, 2]}, + {0: [2, 2], 1: [2, 2]}, + {0: [2, 2], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [4, 4]}, + ], + num_cycles=[ + {0: [0, 0]}, + {0: [0, 0], 1: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + ], + latest_worker_idx=[0, 1, 2, 0, 1, 2], + current_epoch=[1, 1, 1, 1, 1, 1], + dset_1_current_epoch=[1, 1, 1, 1, 1, 1], + dset_2_current_epoch=[1, 1, 1, 1, 1, 1], + dset_1_samples_yielded=[2, 4, 6, 8, 10, 12], + dset_2_samples_yielded=[2, 4, 6, 8, 10, 12], + ), + "expected_states_2": ExpectedStates( + num_samples_yielded=[ + {0: [6, 6], 1: [4, 4], 2: [4, 4]}, + {0: [6, 6], 1: [6, 6], 2: [4, 4]}, + {0: [6, 6], 1: [6, 6], 2: [6, 6]}, + {0: [2, 8], 1: [6, 6], 2: [6, 6]}, + {0: [2, 8], 1: [2, 2], 2: [6, 6]}, + {0: [2, 8], 1: [2, 2], 2: [2, 2]}, + ], + num_cycles=[ + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [1, 0], 1: [0, 0], 2: [0, 0]}, + {0: [1, 0], 1: [1, 1], 2: [0, 0]}, + {0: [1, 0], 1: [1, 1], 2: [1, 1]}, + ], + latest_worker_idx=[0, 1, 2, 0, 1, 2], + current_epoch=[2, 2, 2, 2, 2, 2], + dset_1_current_epoch=[1, 1, 1, 2, 2, 2], + dset_2_current_epoch=[1, 1, 1, 1, 2, 2], + dset_1_samples_yielded=[14, 16, 18, 2, 4, 6], + dset_2_samples_yielded=[14, 16, 18, 20, 2, 4], + ), }, ] @@ -677,22 +698,8 @@ def test_parallel_dataset_with_dataloader_2_epochs( len2, batch_size, num_workers, - epoch_1_expected_num_samples_yielded, - epoch_1_expected_num_cycles, - epoch_1_expected_current_epoch, - epoch_1_expected_dataset_1_current_epoch, - epoch_1_expected_dataset_2_current_epoch, - epoch_1_expected_latest_worker_idx, - epoch_1_expected_dataset0_samples_yielded, - epoch_1_expected_dataset1_samples_yielded, - epoch_2_expected_num_samples_yielded, - epoch_2_expected_num_cycles, - epoch_2_expected_current_epoch, - epoch_2_expected_dataset_1_current_epoch, - epoch_2_expected_dataset_2_current_epoch, - epoch_2_expected_latest_worker_idx, - epoch_2_expected_dataset0_samples_yielded, - epoch_2_expected_dataset1_samples_yielded, + expected_states_1: ExpectedStates, + expected_states_2: ExpectedStates, ): dataset1, dataset2, _, dataloader = prepare_parallel_dataset_and_dataloder( tmp_path_factory, @@ -706,7 +713,7 @@ def test_parallel_dataset_with_dataloader_2_epochs( assert dataset1.current_epoch == 1 assert dataset2.current_epoch == 1 - expected_dataset_state = { + expected_dset_state = { "dataset": { "0": { "num_samples_yielded": 0, @@ -753,16 +760,16 @@ def test_parallel_dataset_with_dataloader_2_epochs( batches_1.append(batch) curr_state_dict = dataloader.state_dict() - expected_dataset_state["num_samples_yielded"] = epoch_1_expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = epoch_1_expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = epoch_1_expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = epoch_1_expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = epoch_1_expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = epoch_1_expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = epoch_1_expected_dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = epoch_1_expected_dataset_2_current_epoch[idx] + expected_dset_state["num_samples_yielded"] = expected_states_1.num_samples_yielded[idx] + expected_dset_state["num_cycles"] = expected_states_1.num_cycles[idx] + expected_dset_state["latest_worker_idx"] = expected_states_1.latest_worker_idx[idx] + expected_dset_state["current_epoch"] = expected_states_1.current_epoch[idx] + expected_dset_state["dataset"]["0"]["current_epoch"] = expected_states_1.dset_1_current_epoch[idx] + expected_dset_state["dataset"]["1"]["current_epoch"] = expected_states_1.dset_2_current_epoch[idx] + expected_dset_state["dataset"]["0"]["num_samples_yielded"] = expected_states_1.dset_1_samples_yielded[idx] + expected_dset_state["dataset"]["1"]["num_samples_yielded"] = expected_states_1.dset_2_samples_yielded[idx] - assert curr_state_dict == expected_dataset_state + assert curr_state_dict == expected_dset_state assert dataset1.current_epoch == 1 assert dataset2.current_epoch == 1 @@ -776,16 +783,16 @@ def test_parallel_dataset_with_dataloader_2_epochs( batches_2.append(batch) curr_state_dict = dataloader.state_dict() - expected_dataset_state["num_samples_yielded"] = epoch_2_expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = epoch_2_expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = epoch_2_expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = epoch_2_expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = epoch_2_expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = epoch_2_expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = epoch_2_expected_dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = epoch_2_expected_dataset_2_current_epoch[idx] + expected_dset_state["num_samples_yielded"] = expected_states_2.num_samples_yielded[idx] + expected_dset_state["num_cycles"] = expected_states_2.num_cycles[idx] + expected_dset_state["latest_worker_idx"] = expected_states_2.latest_worker_idx[idx] + expected_dset_state["current_epoch"] = expected_states_2.current_epoch[idx] + expected_dset_state["dataset"]["0"]["current_epoch"] = expected_states_2.dset_1_current_epoch[idx] + expected_dset_state["dataset"]["1"]["current_epoch"] = expected_states_2.dset_2_current_epoch[idx] + expected_dset_state["dataset"]["0"]["num_samples_yielded"] = expected_states_2.dset_1_samples_yielded[idx] + expected_dset_state["dataset"]["1"]["num_samples_yielded"] = expected_states_2.dset_2_samples_yielded[idx] - assert curr_state_dict == expected_dataset_state + assert curr_state_dict == expected_dset_state if idx == save_at: saved_dataloader_state_dict = deepcopy(curr_state_dict) @@ -808,16 +815,16 @@ def test_parallel_dataset_with_dataloader_2_epochs( curr_state_dict = dataloader.state_dict() idx = idx + save_at + 1 - expected_dataset_state["num_samples_yielded"] = epoch_2_expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = epoch_2_expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = epoch_2_expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = epoch_2_expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = epoch_2_expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = epoch_2_expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = epoch_2_expected_dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = epoch_2_expected_dataset_2_current_epoch[idx] - - assert curr_state_dict == expected_dataset_state + expected_dset_state["num_samples_yielded"] = expected_states_2.num_samples_yielded[idx] + expected_dset_state["num_cycles"] = expected_states_2.num_cycles[idx] + expected_dset_state["latest_worker_idx"] = expected_states_2.latest_worker_idx[idx] + expected_dset_state["current_epoch"] = expected_states_2.current_epoch[idx] + expected_dset_state["dataset"]["0"]["current_epoch"] = expected_states_2.dset_1_current_epoch[idx] + expected_dset_state["dataset"]["1"]["current_epoch"] = expected_states_2.dset_2_current_epoch[idx] + expected_dset_state["dataset"]["0"]["num_samples_yielded"] = expected_states_2.dset_1_samples_yielded[idx] + expected_dset_state["dataset"]["1"]["num_samples_yielded"] = expected_states_2.dset_2_samples_yielded[idx] + + assert curr_state_dict == expected_dset_state assert len(batches_2[save_at + 1 :]) == len(batches_23) assert all(torch.equal(x1, x2) for b1, b2 in zip(batches_2[save_at + 1 :], batches_23) for x1, x2 in zip(b1, b2)) From 7bf45d4459e53264b36d63fdef46fe4355987411 Mon Sep 17 00:00:00 2001 From: philgzl Date: Mon, 7 Jul 2025 21:01:50 +0200 Subject: [PATCH 3/4] Add failing test for complete iteration with resume=False and fix it --- src/litdata/streaming/parallel.py | 9 ++++++- tests/streaming/test_parallel.py | 39 +++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/parallel.py b/src/litdata/streaming/parallel.py index 9c36d2c12..16e328717 100644 --- a/src/litdata/streaming/parallel.py +++ b/src/litdata/streaming/parallel.py @@ -104,7 +104,7 @@ def __init__( seed: Seed for the random number generators provided to ``transform``. resume: If ``True`` and ``length`` is not ``None``, tells the dataloader to resume the dataset from where it left off in the previous epoch. If ``False``, the same ``length`` samples are yielded in each epoch. - Ignored if ``length`` is ``None`` or if the dataset is used without a ``StreamingDataLoader``. + Ignored if ``length`` is ``None``. reset_rngs: If ``True``, the random number generators provided to ``transform`` are reset to their initial state at the beginning of each epoch. Together with ``resume=False``, this allows to produce the same samples in each epoch. @@ -301,6 +301,13 @@ def state_dict( for dataset_idx, dataset in enumerate(self._datasets) } + def reset_state_dict(self) -> None: + """Reset the state of the dataset.""" + super().reset_state_dict() + if self.is_cycling() and not self.resume: + for dataset in self._datasets: + dataset.set_epoch(0) + class _ParallelDatasetIterator(Iterator): def __init__( diff --git a/tests/streaming/test_parallel.py b/tests/streaming/test_parallel.py index c5cf76788..a40cce9f2 100644 --- a/tests/streaming/test_parallel.py +++ b/tests/streaming/test_parallel.py @@ -870,3 +870,42 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) if i == break_at: break + + +@pytest.mark.parametrize("length", [None, 6]) +@pytest.mark.parametrize("resume", [False, True]) +@pytest.mark.parametrize("shuffle", [False, True]) +def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, resume, shuffle): + _, _, pardset, dloader = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen=length, len1=4, len2=4, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume + ) + assert pardset.is_cycling() or length is None + expected_1 = [ + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([2]), torch.tensor([2])], + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([3]), torch.tensor([3])], + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([2]), torch.tensor([2])], + ] + batches_1 = [] + for i, batch in enumerate(dloader): + if not shuffle: + assert all(torch.equal(x, y) for x, y in zip(batch, expected_1[i])) + batches_1.append(batch) + expected_2 = [ + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([3]), torch.tensor([3])], + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([2]), torch.tensor([2])], + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([3]), torch.tensor([3])], + ] + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) From 136043939f82e7fc7c0127ccb4e2628952e5408a Mon Sep 17 00:00:00 2001 From: philgzl Date: Tue, 8 Jul 2025 08:34:46 +0200 Subject: [PATCH 4/4] Add yet another failing test and fix it --- src/litdata/streaming/parallel.py | 12 ++---- tests/streaming/test_parallel.py | 63 +++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/src/litdata/streaming/parallel.py b/src/litdata/streaming/parallel.py index 16e328717..a3231f7ad 100644 --- a/src/litdata/streaming/parallel.py +++ b/src/litdata/streaming/parallel.py @@ -160,7 +160,7 @@ def is_infinite(self) -> bool: def set_epoch(self, current_epoch: int) -> None: self._current_epoch = current_epoch - if self.is_cycling(): + if self.is_cycling() and self.resume: # do not set the epoch as cycling datasets have their own epoch counter return for dataset in self._datasets: @@ -194,6 +194,9 @@ def get_all_lens(self) -> List[int]: return [self._get_len(d) for d in self._datasets] def __iter__(self) -> Iterator[Any]: + if self.is_cycling() and not self.resume: + self.set_epoch(1) + worker_env = _WorkerEnv.detect() num_samples_yielded = None @@ -301,13 +304,6 @@ def state_dict( for dataset_idx, dataset in enumerate(self._datasets) } - def reset_state_dict(self) -> None: - """Reset the state of the dataset.""" - super().reset_state_dict() - if self.is_cycling() and not self.resume: - for dataset in self._datasets: - dataset.set_epoch(0) - class _ParallelDatasetIterator(Iterator): def __init__( diff --git a/tests/streaming/test_parallel.py b/tests/streaming/test_parallel.py index a40cce9f2..f77066a13 100644 --- a/tests/streaming/test_parallel.py +++ b/tests/streaming/test_parallel.py @@ -835,6 +835,7 @@ def test_parallel_dataset_with_dataloader_2_epochs( @pytest.mark.parametrize("length", [None, 16, float("inf")]) @pytest.mark.parametrize("resume", [False, True]) @pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, resume, shuffle): _, _, pardset, dloader = prepare_parallel_dataset_and_dataloder( tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume @@ -875,6 +876,7 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res @pytest.mark.parametrize("length", [None, 6]) @pytest.mark.parametrize("resume", [False, True]) @pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, resume, shuffle): _, _, pardset, dloader = prepare_parallel_dataset_and_dataloder( tmp_path_factory, parlen=length, len1=4, len2=4, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume @@ -909,3 +911,64 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re ) elif not resume and length is not None: assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + + +@pytest.mark.parametrize("length", [None, 18]) +@pytest.mark.parametrize("resume", [False, True]) +@pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") +def test_parallel_dataset_partial_iteration_resume_without_dataloader(tmp_path_factory, length, resume, shuffle): + _, _, pardset, _ = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume + ) + assert pardset.is_cycling() or length is None + break_at = 3 + expected = [ + [0, 0], + [1, 1], + [2, 2], + [3, 3], + ] + samples = [] + for i, sample in enumerate(pardset): + if not shuffle: + assert all(x == y for x, y in zip(sample, expected[i])) + samples.append(sample) + if i == break_at: + break + for i, sample in enumerate(pardset): + if not shuffle: + assert all(x == y for x, y in zip(sample, expected[i])) + elif not resume and length is not None: + assert all(x == y for x, y in zip(sample, samples[i])) + if i == break_at: + break + + +@pytest.mark.parametrize("length", [None, 6]) +@pytest.mark.parametrize("resume", [False, True]) +@pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") +def test_parallel_dataset_complete_iteration_resume_without_dataloader(tmp_path_factory, length, resume, shuffle): + _, _, pardset, _ = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen=length, len1=4, len2=4, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume + ) + assert pardset.is_cycling() or length is None + expected = [ + [0, 0], + [1, 1], + [2, 2], + [3, 3], + [0, 0], + [1, 1], + ] + samples = [] + for i, sample in enumerate(pardset): + if not shuffle: + assert all(x == y for x, y in zip(sample, expected[i])) + samples.append(sample) + for i, sample in enumerate(pardset): + if not shuffle: + assert all(x == y for x, y in zip(sample, expected[i])) + elif not resume and length is not None: + assert all(x == y for x, y in zip(sample, samples[i]))