diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index d83808e73..fc6271f49 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -641,10 +641,15 @@ def __init__( def __iter__(self) -> Any: if not self.restore: if ( - not isinstance(self.dataset, ParallelStreamingDataset) - or not self.dataset.is_cycling() - or self.current_epoch == 0 + isinstance(self.dataset, ParallelStreamingDataset) + and self.dataset.is_cycling() + and self.dataset.resume + and self.current_epoch != 0 ): + # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not + # want to restart at index 0 at every epoch. So we set them in restore state. + self.load_state_dict(self.state_dict()) + else: self._latest_worker_idx = 0 self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) self._worker_idx_iter = iter(self._worker_idx) @@ -686,11 +691,6 @@ def __iter__(self) -> Any: else: yield batch - # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not want to - # restart at index 0 at every epoch. So we set them in restore state. - if isinstance(self.dataset, ParallelStreamingDataset) and self.dataset.is_cycling(): - self.load_state_dict(self.state_dict()) - logger.debug(_get_log_msg({"name": "iterating_dataloader", "ph": "E"})) self.restore = False diff --git a/src/litdata/streaming/parallel.py b/src/litdata/streaming/parallel.py index 9d3e6fb2f..a3231f7ad 100644 --- a/src/litdata/streaming/parallel.py +++ b/src/litdata/streaming/parallel.py @@ -49,8 +49,9 @@ class ParallelStreamingDataset(_BaseStreamingDatasetWrapper): datasets. The parallel dataset can be configured to raise a ``StopIteration`` as soon as any of the datasets is exhausted, or - to cycle through the datasets until a given number of samples are yielded. When cycling, each epoch resumes from - where the previous one left off in the current cycle, i.e. the yielded samples are not the same across epochs. + to cycle through the datasets until a given number of samples are yielded. When cycling and using a + ``StreamingDataLoader``, the ``resume`` option can be used to either yield the same ``length`` samples in each + epoch, or to resume the dataset from where it left off in the previous epoch. New data can be generated on-the-fly from a sample from each dataset by providing a ``transform`` function. This function can take a single tuple argument containing a sample from each dataset, and optionally a dictionary of @@ -86,6 +87,7 @@ def __init__( force_override_state_dict: bool = False, transform: Optional[Transform] = None, seed: int = 42, + resume: bool = True, reset_rngs: bool = False, ) -> None: """Enable to stream data from multiple StreamingDataset in parallel. @@ -100,9 +102,12 @@ def __init__( argument a tuple containing one sample from each dataset, and optionally a dictionary of random number generators which are seeded using the current state of the dataset. seed: Seed for the random number generators provided to ``transform``. + resume: If ``True`` and ``length`` is not ``None``, tells the dataloader to resume the dataset from where it + left off in the previous epoch. If ``False``, the same ``length`` samples are yielded in each epoch. + Ignored if ``length`` is ``None``. reset_rngs: If ``True``, the random number generators provided to ``transform`` are reset to their initial - state at the beginning of each epoch. 
Together with ``length=None`` and ``shuffle=False``, this ensures - that the same samples are yielded in each epoch. + state at the beginning of each epoch. Together with ``resume=False``, this allows to produce the same + samples in each epoch. """ self._check_datasets(datasets) @@ -132,6 +137,7 @@ def __init__( self._current_epoch = 0 self.num_workers = 1 self.batch_size = 1 + self.resume = resume if length is not None: for dataset in self._datasets: @@ -154,7 +160,7 @@ def is_infinite(self) -> bool: def set_epoch(self, current_epoch: int) -> None: self._current_epoch = current_epoch - if self.is_cycling(): + if self.is_cycling() and self.resume: # do not set the epoch as cycling datasets have their own epoch counter return for dataset in self._datasets: @@ -188,6 +194,9 @@ def get_all_lens(self) -> List[int]: return [self._get_len(d) for d in self._datasets] def __iter__(self) -> Iterator[Any]: + if self.is_cycling() and not self.resume: + self.set_epoch(1) + worker_env = _WorkerEnv.detect() num_samples_yielded = None diff --git a/tests/conftest.py b/tests/conftest.py index 7390a29b7..f306a78ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ import pytest import torch.distributed -from litdata import CombinedStreamingDataset, ParallelStreamingDataset, StreamingDataset +from litdata import CombinedStreamingDataset, StreamingDataset from litdata.constants import _POLARS_AVAILABLE from litdata.streaming.cache import Cache from litdata.streaming.reader import PrepareChunksThread @@ -73,21 +73,6 @@ def combined_dataset(prepare_combined_dataset): return CombinedStreamingDataset(datasets=[dataset_1, dataset_2]) -@pytest.fixture -def parallel_dataset(tmp_path_factory, request): - tmpdir = tmp_path_factory.mktemp("data") - datasets = [str(tmpdir / f"dataset_{i}") for i in range(2)] - for dataset, num_items in zip(datasets, [48, 56]): - cache = Cache(input_dir=dataset, chunk_size=10) - for i in range(num_items): - cache[i] = i - cache.done() - cache.merge() - dataset_1 = StreamingDataset(datasets[0], shuffle=True) - dataset_2 = StreamingDataset(datasets[1], shuffle=True) - return ParallelStreamingDataset(datasets=[dataset_1, dataset_2], length=request.param), request.param - - @pytest.fixture def google_mock(monkeypatch): google = ModuleType("google") diff --git a/tests/streaming/test_parallel.py b/tests/streaming/test_parallel.py index 5ad874a62..f77066a13 100644 --- a/tests/streaming/test_parallel.py +++ b/tests/streaming/test_parallel.py @@ -1,7 +1,7 @@ import functools -import os import sys from copy import deepcopy +from dataclasses import dataclass from unittest.mock import ANY, MagicMock import pytest @@ -10,7 +10,7 @@ from litdata.streaming.cache import Cache from litdata.streaming.dataloader import StreamingDataLoader -from litdata.streaming.dataset import Dir, StreamingDataset +from litdata.streaming.dataset import StreamingDataset from litdata.streaming.parallel import ParallelStreamingDataset @@ -321,7 +321,7 @@ def test_parallel_dataset_with_dataloader_and_one_worker(batch_size, length, exp "1": {"num_samples_yielded": num_samples_yielded[1], "num_workers": 1, "batch_size": batch_size}, }, "current_epoch": 1, - "latest_worker_idx": 0 if length is None else 1, + "latest_worker_idx": 0, "num_samples_yielded": {0: num_samples_yielded}, "num_cycles": {0: num_cycles}, } @@ -422,10 +422,27 @@ def test_dataloader_shuffle(tmp_path, shuffle): assert shuffle ^ all(torch.equal(x, y) for x, y in zip(epoch_1_batches[:3], epoch_2_batches[-3:])) 
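
For context, a minimal usage sketch of the ``resume`` option that the tests below exercise (not part of the patch; the dataset paths, ``length`` value, and training loop are illustrative assumptions):

    # Sketch only: "my_data/dataset_0" and "my_data/dataset_1" are placeholder paths
    # to pre-optimized streaming datasets.
    from litdata import ParallelStreamingDataset, StreamingDataset
    from litdata.streaming.dataloader import StreamingDataLoader

    dset1 = StreamingDataset("my_data/dataset_0", shuffle=False)
    dset2 = StreamingDataset("my_data/dataset_1", shuffle=False)

    # length=6 makes the wrapper cycle the two datasets until 6 samples are yielded per epoch.
    # resume=True (the default): epoch 2 continues the cycle where epoch 1 stopped.
    # resume=False: every epoch replays the same 6 samples from the start of the cycle.
    dataset = ParallelStreamingDataset(datasets=[dset1, dset2], length=6, resume=True)
    dataloader = StreamingDataLoader(dataset, batch_size=2, num_workers=0)

    for epoch in range(2):
        for batch in dataloader:
            pass  # each batch is a sequence with one tensor per wrapped dataset
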
-@pytest.mark.parametrize("parallel_dataset", [None, 3, float("inf")], indirect=True) -def test_parallel_dataset_dataloader_states_without_any_iterations(parallel_dataset): - parallel_dataset, _ = parallel_dataset - dataloader = StreamingDataLoader(parallel_dataset, batch_size=4) +def prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen, len1=48, len2=56, num_workers=0, batch_size=4, shuffle=True, resume=True +): + tmpdir = tmp_path_factory.mktemp("data") + datasets = [str(tmpdir / f"dataset_{i}") for i in range(2)] + for dataset, num_items in zip(datasets, [len1, len2]): + cache = Cache(input_dir=dataset, chunk_size=10) + for i in range(num_items): + cache[i] = i + cache.done() + cache.merge() + dset1 = StreamingDataset(datasets[0], shuffle=shuffle) + dset2 = StreamingDataset(datasets[1], shuffle=shuffle) + pardset = ParallelStreamingDataset(datasets=[dset1, dset2], length=parlen, resume=resume) + dloader = StreamingDataLoader(pardset, num_workers=num_workers, batch_size=batch_size) + return dset1, dset2, pardset, dloader + + +@pytest.mark.parametrize("length", [None, 3, float("inf")]) +def test_parallel_dataset_dataloader_states_without_any_iterations(tmp_path_factory, length): + _, _, _, dataloader = prepare_parallel_dataset_and_dataloder(tmp_path_factory, length) assert not dataloader.restore dataloader.load_state_dict(dataloader.state_dict()) assert not dataloader.restore @@ -434,16 +451,19 @@ def test_parallel_dataset_dataloader_states_without_any_iterations(parallel_data @pytest.mark.timeout(120) +@pytest.mark.parametrize("length", [None, 24]) @pytest.mark.parametrize("num_workers", [0, 2]) -@pytest.mark.parametrize("parallel_dataset", [None, 24], indirect=True) +@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") -def test_parallel_dataset_dataloader_states_complete_iterations(parallel_dataset, num_workers): +def test_parallel_dataset_dataloader_states_complete_iterations(tmp_path_factory, length, num_workers, batch_size): print(f"Testing with num_workers={num_workers}") - parallel_dataset, length = parallel_dataset - batch_size = 2 - - dataloader = StreamingDataLoader(parallel_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) + _, _, parallel_dataset, dataloader = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, + length, + batch_size=batch_size, + num_workers=num_workers, + ) assert len(dataloader) == -(-len(parallel_dataset) // batch_size) @@ -494,18 +514,19 @@ def test_parallel_dataset_dataloader_states_complete_iterations(parallel_dataset @pytest.mark.timeout(300) +@pytest.mark.parametrize("length", [None, 20, 48]) @pytest.mark.parametrize("num_workers", [0, 2]) +@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("break_at", [3, 7]) -@pytest.mark.parametrize("parallel_dataset", [None, 20, 48], indirect=True) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") -def test_parallel_dataset_dataloader_states_partial_iterations(parallel_dataset, num_workers, break_at): +def test_parallel_dataset_dataloader_states_partial_iterations( + tmp_path_factory, length, num_workers, batch_size, break_at +): print(f"Testing with num_workers={num_workers}, break_at={break_at}") - parallel_dataset, _ = parallel_dataset - batch_size = 2 - - # Verify dataloader state after partial last iteration - dataloader = StreamingDataLoader(parallel_dataset, batch_size=batch_size, num_workers=num_workers) + _, _, parallel_dataset, dataloader = 
prepare_parallel_dataset_and_dataloder( + tmp_path_factory, length, batch_size=batch_size, num_workers=num_workers, shuffle=True + ) total_batches = len(dataloader) assert total_batches == -(-len(parallel_dataset) // batch_size) @@ -542,48 +563,162 @@ def test_parallel_dataset_dataloader_states_partial_iterations(parallel_dataset, assert samples_yielded == len(parallel_dataset), "All samples should be yielded in the second epoch." -@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") -def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): - data_dir_1 = str(tmp_path / "data_1") - data_dir_2 = str(tmp_path / "data_2") - cache_dir_1 = str(tmp_path / "cache_dir_1") - cache_dir_2 = str(tmp_path / "cache_dir_2") - - os.makedirs(data_dir_1) - os.makedirs(data_dir_2) - os.makedirs(cache_dir_1) - os.makedirs(cache_dir_2) - - cache = Cache(input_dir=data_dir_1, chunk_size=2) - - for i in range(12): - cache[i] = i - - cache.done() - cache.merge() - - cache = Cache(input_dir=data_dir_2, chunk_size=2) - - for i in range(14): - cache[i] = -i +@dataclass +class ExpectedStates: + num_samples_yielded: list[dict[int, list[int]]] + num_cycles: list[dict[int, list[int]]] + latest_worker_idx: list[int] + current_epoch: list[int] + dset_1_current_epoch: list[int] + dset_2_current_epoch: list[int] + dset_1_samples_yielded: list[int] + dset_2_samples_yielded: list[int] + + +TEST_MAPS = [ + { + "length": None, + "len1": 12, + "len2": 14, + "batch_size": 2, + "num_workers": 3, + "expected_states_1": ExpectedStates( + num_samples_yielded=[ + {0: [2, 2]}, + {0: [2, 2], 1: [2, 2]}, + {0: [2, 2], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [4, 4]}, + ], + num_cycles=[ + {0: [0, 0]}, + {0: [0, 0], 1: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + ], + latest_worker_idx=[0, 1, 2, 0, 1, 2], + current_epoch=[1, 1, 1, 1, 1, 1], + dset_1_current_epoch=[1, 1, 1, 1, 1, 1], + dset_2_current_epoch=[1, 1, 1, 1, 1, 1], + dset_1_samples_yielded=[2, 4, 6, 8, 10, 12], + dset_2_samples_yielded=[2, 4, 6, 8, 10, 12], + ), + "expected_states_2": ExpectedStates( + num_samples_yielded=[ + {0: [2, 2]}, + {0: [2, 2], 1: [2, 2]}, + {0: [2, 2], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [4, 4]}, + ], + num_cycles=[ + {0: [0, 0]}, + {0: [0, 0], 1: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + ], + latest_worker_idx=[0, 1, 2, 0, 1, 2], + current_epoch=[2, 2, 2, 2, 2, 2], + dset_1_current_epoch=[2, 2, 2, 2, 2, 2], + dset_2_current_epoch=[2, 2, 2, 2, 2, 2], + dset_1_samples_yielded=[2, 4, 6, 8, 10, 12], + dset_2_samples_yielded=[2, 4, 6, 8, 10, 12], + ), + }, + { + "length": 12, + "len1": 18, + "len2": 20, + "batch_size": 2, + "num_workers": 3, + "expected_states_1": ExpectedStates( + num_samples_yielded=[ + {0: [2, 2]}, + {0: [2, 2], 1: [2, 2]}, + {0: [2, 2], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [2, 2], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [2, 2]}, + {0: [4, 4], 1: [4, 4], 2: [4, 4]}, + ], + num_cycles=[ + {0: [0, 0]}, + {0: [0, 0], 1: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + ], + 
latest_worker_idx=[0, 1, 2, 0, 1, 2], + current_epoch=[1, 1, 1, 1, 1, 1], + dset_1_current_epoch=[1, 1, 1, 1, 1, 1], + dset_2_current_epoch=[1, 1, 1, 1, 1, 1], + dset_1_samples_yielded=[2, 4, 6, 8, 10, 12], + dset_2_samples_yielded=[2, 4, 6, 8, 10, 12], + ), + "expected_states_2": ExpectedStates( + num_samples_yielded=[ + {0: [6, 6], 1: [4, 4], 2: [4, 4]}, + {0: [6, 6], 1: [6, 6], 2: [4, 4]}, + {0: [6, 6], 1: [6, 6], 2: [6, 6]}, + {0: [2, 8], 1: [6, 6], 2: [6, 6]}, + {0: [2, 8], 1: [2, 2], 2: [6, 6]}, + {0: [2, 8], 1: [2, 2], 2: [2, 2]}, + ], + num_cycles=[ + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [0, 0], 1: [0, 0], 2: [0, 0]}, + {0: [1, 0], 1: [0, 0], 2: [0, 0]}, + {0: [1, 0], 1: [1, 1], 2: [0, 0]}, + {0: [1, 0], 1: [1, 1], 2: [1, 1]}, + ], + latest_worker_idx=[0, 1, 2, 0, 1, 2], + current_epoch=[2, 2, 2, 2, 2, 2], + dset_1_current_epoch=[1, 1, 1, 2, 2, 2], + dset_2_current_epoch=[1, 1, 1, 1, 2, 2], + dset_1_samples_yielded=[14, 16, 18, 2, 4, 6], + dset_2_samples_yielded=[14, 16, 18, 20, 2, 4], + ), + }, +] - cache.done() - cache.merge() - dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True) - dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True) - dataset = ParallelStreamingDataset(datasets=[dataset1, dataset2], length=None) - dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2) +@pytest.mark.parametrize(TEST_MAPS[0].keys(), [x.values() for x in TEST_MAPS]) +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") +def test_parallel_dataset_with_dataloader_2_epochs( + tmp_path_factory, + length, + len1, + len2, + batch_size, + num_workers, + expected_states_1: ExpectedStates, + expected_states_2: ExpectedStates, +): + dataset1, dataset2, _, dataloader = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, + length, + len1, + len2, + batch_size=batch_size, + num_workers=num_workers, + ) assert dataset1.current_epoch == 1 assert dataset2.current_epoch == 1 - expected_dataset_state = { + expected_dset_state = { "dataset": { "0": { "num_samples_yielded": 0, - "num_workers": 3, - "batch_size": 2, + "num_workers": num_workers, + "batch_size": batch_size, "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, @@ -598,8 +733,8 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): }, "1": { "num_samples_yielded": 0, - "num_workers": 3, - "batch_size": 2, + "num_workers": num_workers, + "batch_size": batch_size, "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, @@ -618,28 +753,6 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): "num_samples_yielded": {}, "num_cycles": {}, } - expected_num_samples_yielded = [ - {0: [2, 2]}, - {0: [2, 2], 1: [2, 2]}, - {0: [2, 2], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [4, 4]}, - ] - expected_num_cycles = [ - {0: [0, 0]}, - {0: [0, 0], 1: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - ] - expected_current_epoch = [1, 1, 1, 1, 1, 1] - dataset_1_current_epoch = [1, 1, 1, 1, 1, 1] - dataset_2_current_epoch = [1, 1, 1, 1, 1, 1] - expected_latest_worker_idx = [0, 1, 2, 0, 1, 2] - expected_dataset0_samples_yielded = [2, 4, 6, 8, 10, 12] - expected_dataset1_samples_yielded = [2, 4, 6, 8, 10, 12] batches_1 = [] @@ -647,16 +760,16 @@ def 
test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): batches_1.append(batch) curr_state_dict = dataloader.state_dict() - expected_dataset_state["num_samples_yielded"] = expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = dataset_2_current_epoch[idx] + expected_dset_state["num_samples_yielded"] = expected_states_1.num_samples_yielded[idx] + expected_dset_state["num_cycles"] = expected_states_1.num_cycles[idx] + expected_dset_state["latest_worker_idx"] = expected_states_1.latest_worker_idx[idx] + expected_dset_state["current_epoch"] = expected_states_1.current_epoch[idx] + expected_dset_state["dataset"]["0"]["current_epoch"] = expected_states_1.dset_1_current_epoch[idx] + expected_dset_state["dataset"]["1"]["current_epoch"] = expected_states_1.dset_2_current_epoch[idx] + expected_dset_state["dataset"]["0"]["num_samples_yielded"] = expected_states_1.dset_1_samples_yielded[idx] + expected_dset_state["dataset"]["1"]["num_samples_yielded"] = expected_states_1.dset_2_samples_yielded[idx] - assert curr_state_dict == expected_dataset_state + assert curr_state_dict == expected_dset_state assert dataset1.current_epoch == 1 assert dataset2.current_epoch == 1 @@ -664,45 +777,24 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): saved_dataloader_state_dict = None batches_2 = [] + save_at = 2 - expected_num_samples_yielded = [ - {0: [2, 2]}, - {0: [2, 2], 1: [2, 2]}, - {0: [2, 2], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [4, 4]}, - ] - expected_num_cycles = [ - {0: [0, 0]}, - {0: [0, 0], 1: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - ] - expected_current_epoch = [2, 2, 2, 2, 2, 2] - dataset_1_current_epoch = [2, 2, 2, 2, 2, 2] - dataset_2_current_epoch = [2, 2, 2, 2, 2, 2] - expected_latest_worker_idx = [0, 1, 2, 0, 1, 2] - expected_dataset0_samples_yielded = [2, 4, 6, 8, 10, 12] - expected_dataset1_samples_yielded = [2, 4, 6, 8, 10, 12] for idx, batch in enumerate(dataloader): batches_2.append(batch) curr_state_dict = dataloader.state_dict() - expected_dataset_state["num_samples_yielded"] = expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = dataset_2_current_epoch[idx] + expected_dset_state["num_samples_yielded"] = expected_states_2.num_samples_yielded[idx] + 
expected_dset_state["num_cycles"] = expected_states_2.num_cycles[idx] + expected_dset_state["latest_worker_idx"] = expected_states_2.latest_worker_idx[idx] + expected_dset_state["current_epoch"] = expected_states_2.current_epoch[idx] + expected_dset_state["dataset"]["0"]["current_epoch"] = expected_states_2.dset_1_current_epoch[idx] + expected_dset_state["dataset"]["1"]["current_epoch"] = expected_states_2.dset_2_current_epoch[idx] + expected_dset_state["dataset"]["0"]["num_samples_yielded"] = expected_states_2.dset_1_samples_yielded[idx] + expected_dset_state["dataset"]["1"]["num_samples_yielded"] = expected_states_2.dset_2_samples_yielded[idx] - assert curr_state_dict == expected_dataset_state + assert curr_state_dict == expected_dset_state - if idx == 2: + if idx == save_at: saved_dataloader_state_dict = deepcopy(curr_state_dict) assert dataset1.current_epoch == 2 @@ -717,200 +809,166 @@ def test_parallel_dataset_with_dataloader_2_epochs_none_length(tmp_path): assert dataloader.restore batches_23 = [] - states_23 = [] - for batch in dataloader: + for idx, batch in enumerate(dataloader): batches_23.append(batch) - states_23.append(dataloader.state_dict()) - - assert len(batches_2[3:]) == len(batches_23) - assert all(torch.equal(x1, x2) for b1, b2 in zip(batches_2[3:], batches_23) for x1, x2 in zip(b1, b2)) - assert states_23[0]["current_epoch"] == 2 - - assert not dataloader.restore - -@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") -def test_parallel_dataset_with_dataloader_2_epochs_int_length(tmp_path): - data_dir_1 = str(tmp_path / "data_1") - data_dir_2 = str(tmp_path / "data_2") - cache_dir_1 = str(tmp_path / "cache_dir_1") - cache_dir_2 = str(tmp_path / "cache_dir_2") - - os.makedirs(data_dir_1) - os.makedirs(data_dir_2) - os.makedirs(cache_dir_1) - os.makedirs(cache_dir_2) - - cache = Cache(input_dir=data_dir_1, chunk_size=2) - - for i in range(18): - cache[i] = i + curr_state_dict = dataloader.state_dict() - cache.done() - cache.merge() + idx = idx + save_at + 1 + expected_dset_state["num_samples_yielded"] = expected_states_2.num_samples_yielded[idx] + expected_dset_state["num_cycles"] = expected_states_2.num_cycles[idx] + expected_dset_state["latest_worker_idx"] = expected_states_2.latest_worker_idx[idx] + expected_dset_state["current_epoch"] = expected_states_2.current_epoch[idx] + expected_dset_state["dataset"]["0"]["current_epoch"] = expected_states_2.dset_1_current_epoch[idx] + expected_dset_state["dataset"]["1"]["current_epoch"] = expected_states_2.dset_2_current_epoch[idx] + expected_dset_state["dataset"]["0"]["num_samples_yielded"] = expected_states_2.dset_1_samples_yielded[idx] + expected_dset_state["dataset"]["1"]["num_samples_yielded"] = expected_states_2.dset_2_samples_yielded[idx] - cache = Cache(input_dir=data_dir_2, chunk_size=2) + assert curr_state_dict == expected_dset_state - for i in range(20): - cache[i] = -i + assert len(batches_2[save_at + 1 :]) == len(batches_23) + assert all(torch.equal(x1, x2) for b1, b2 in zip(batches_2[save_at + 1 :], batches_23) for x1, x2 in zip(b1, b2)) - cache.done() - cache.merge() - - dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True) - dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True) - dataset = ParallelStreamingDataset(datasets=[dataset1, dataset2], length=12) - dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2) + assert not dataloader.restore - assert dataset1.current_epoch == 1 - assert dataset2.current_epoch 
== 1 - expected_dataset_state = { - "dataset": { - "0": { - "num_samples_yielded": 0, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 1, - "input_dir_path": ANY, - "input_dir_url": ANY, - "cache_dir_path": None, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - "1": { - "num_samples_yielded": 0, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 1, - "input_dir_path": ANY, - "input_dir_url": ANY, - "cache_dir_path": None, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - }, - "current_epoch": 1, - "latest_worker_idx": 0, - "num_samples_yielded": {}, - "num_cycles": {}, - } - expected_num_samples_yielded = [ - {0: [2, 2]}, - {0: [2, 2], 1: [2, 2]}, - {0: [2, 2], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [2, 2], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [2, 2]}, - {0: [4, 4], 1: [4, 4], 2: [4, 4]}, - ] - expected_num_cycles = [ - {0: [0, 0]}, - {0: [0, 0], 1: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, +@pytest.mark.parametrize("length", [None, 16, float("inf")]) +@pytest.mark.parametrize("resume", [False, True]) +@pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") +def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, resume, shuffle): + _, _, pardset, dloader = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume + ) + assert pardset.is_cycling() or length is None + break_at = 3 + expected_1 = [ + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([5]), torch.tensor([5])], + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([6]), torch.tensor([6])], ] - expected_current_epoch = [1, 1, 1, 1, 1, 1] - dataset_1_current_epoch = [1, 1, 1, 1, 1, 1] - dataset_2_current_epoch = [1, 1, 1, 1, 1, 1] - expected_latest_worker_idx = [0, 1, 2, 0, 1, 2] - expected_dataset0_samples_yielded = [2, 4, 6, 8, 10, 12] - expected_dataset1_samples_yielded = [2, 4, 6, 8, 10, 12] - batches_1 = [] - - for idx, batch in enumerate(dataloader): + for i, batch in enumerate(dloader): + if not shuffle: + assert all(torch.equal(x, y) for x, y in zip(batch, expected_1[i])) batches_1.append(batch) - curr_state_dict = dataloader.state_dict() - - expected_dataset_state["num_samples_yielded"] = expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = dataset_2_current_epoch[idx] - - assert curr_state_dict == expected_dataset_state - - assert dataset1.current_epoch == 1 - assert dataset2.current_epoch == 1 - - saved_dataloader_state_dict = None + if i == break_at: + break + expected_2 = [ + [torch.tensor([2]), torch.tensor([2])], + 
[torch.tensor([7]), torch.tensor([7])], + [torch.tensor([3]), torch.tensor([3])], + [torch.tensor([8]), torch.tensor([8])], + ] + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + if i == break_at: + break - batches_2 = [] - expected_num_samples_yielded = [ - {0: [6, 6], 1: [4, 4], 2: [4, 4]}, - {0: [6, 6], 1: [6, 6], 2: [4, 4]}, - {0: [6, 6], 1: [6, 6], 2: [6, 6]}, - {0: [2, 8], 1: [6, 6], 2: [6, 6]}, - {0: [2, 8], 1: [2, 2], 2: [6, 6]}, - {0: [2, 8], 1: [2, 2], 2: [2, 2]}, +@pytest.mark.parametrize("length", [None, 6]) +@pytest.mark.parametrize("resume", [False, True]) +@pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") +def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, resume, shuffle): + _, _, pardset, dloader = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen=length, len1=4, len2=4, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume + ) + assert pardset.is_cycling() or length is None + expected_1 = [ + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([2]), torch.tensor([2])], + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([3]), torch.tensor([3])], + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([2]), torch.tensor([2])], ] - expected_num_cycles = [ - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [0, 0], 1: [0, 0], 2: [0, 0]}, - {0: [1, 0], 1: [0, 0], 2: [0, 0]}, - {0: [1, 0], 1: [1, 1], 2: [0, 0]}, - {0: [1, 0], 1: [1, 1], 2: [1, 1]}, + batches_1 = [] + for i, batch in enumerate(dloader): + if not shuffle: + assert all(torch.equal(x, y) for x, y in zip(batch, expected_1[i])) + batches_1.append(batch) + expected_2 = [ + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([3]), torch.tensor([3])], + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([2]), torch.tensor([2])], + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([3]), torch.tensor([3])], ] - expected_current_epoch = [2, 2, 2, 2, 2, 2] - dataset_1_current_epoch = [1, 1, 1, 2, 2, 2] - dataset_2_current_epoch = [1, 1, 1, 1, 2, 2] - expected_latest_worker_idx = [0, 1, 2, 0, 1, 2, 0] - expected_dataset0_samples_yielded = [14, 16, 18, 2, 4, 6] - expected_dataset1_samples_yielded = [14, 16, 18, 20, 2, 4] - for idx, batch in enumerate(dataloader): - batches_2.append(batch) - curr_state_dict = dataloader.state_dict() - - expected_dataset_state["num_samples_yielded"] = expected_num_samples_yielded[idx] - expected_dataset_state["num_cycles"] = expected_num_cycles[idx] - expected_dataset_state["current_epoch"] = expected_current_epoch[idx] - expected_dataset_state["latest_worker_idx"] = expected_latest_worker_idx[idx] - expected_dataset_state["dataset"]["0"]["num_samples_yielded"] = expected_dataset0_samples_yielded[idx] - expected_dataset_state["dataset"]["1"]["num_samples_yielded"] = expected_dataset1_samples_yielded[idx] - expected_dataset_state["dataset"]["0"]["current_epoch"] = dataset_1_current_epoch[idx] - expected_dataset_state["dataset"]["1"]["current_epoch"] = dataset_2_current_epoch[idx] - - assert curr_state_dict == expected_dataset_state - - if idx == 2: - saved_dataloader_state_dict = deepcopy(curr_state_dict) - - assert dataset1.current_epoch == 2 - assert 
dataset2.current_epoch == 2 - - assert len(batches_1) == len(batches_2) - assert any(not torch.equal(x1, x2) for b1, b2 in zip(batches_1, batches_2) for x1, x2 in zip(b1, b2)) - - assert saved_dataloader_state_dict is not None - dataloader.load_state_dict(saved_dataloader_state_dict) - - assert dataloader.restore - - batches_23 = [] - states_23 = [] - for batch in dataloader: - batches_23.append(batch) - states_23.append(dataloader.state_dict()) + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + + +@pytest.mark.parametrize("length", [None, 18]) +@pytest.mark.parametrize("resume", [False, True]) +@pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") +def test_parallel_dataset_partial_iteration_resume_without_dataloader(tmp_path_factory, length, resume, shuffle): + _, _, pardset, _ = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume + ) + assert pardset.is_cycling() or length is None + break_at = 3 + expected = [ + [0, 0], + [1, 1], + [2, 2], + [3, 3], + ] + samples = [] + for i, sample in enumerate(pardset): + if not shuffle: + assert all(x == y for x, y in zip(sample, expected[i])) + samples.append(sample) + if i == break_at: + break + for i, sample in enumerate(pardset): + if not shuffle: + assert all(x == y for x, y in zip(sample, expected[i])) + elif not resume and length is not None: + assert all(x == y for x, y in zip(sample, samples[i])) + if i == break_at: + break - assert len(batches_2[3:]) == len(batches_23) - assert all(torch.equal(x1, x2) for b1, b2 in zip(batches_2[3:], batches_23) for x1, x2 in zip(b1, b2)) - assert states_23[0]["current_epoch"] == 2 - assert not dataloader.restore +@pytest.mark.parametrize("length", [None, 6]) +@pytest.mark.parametrize("resume", [False, True]) +@pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") +def test_parallel_dataset_complete_iteration_resume_without_dataloader(tmp_path_factory, length, resume, shuffle): + _, _, pardset, _ = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen=length, len1=4, len2=4, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume + ) + assert pardset.is_cycling() or length is None + expected = [ + [0, 0], + [1, 1], + [2, 2], + [3, 3], + [0, 0], + [1, 1], + ] + samples = [] + for i, sample in enumerate(pardset): + if not shuffle: + assert all(x == y for x, y in zip(sample, expected[i])) + samples.append(sample) + for i, sample in enumerate(pardset): + if not shuffle: + assert all(x == y for x, y in zip(sample, expected[i])) + elif not resume and length is not None: + assert all(x == y for x, y in zip(sample, samples[i]))
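
For reference, a sketch of the mid-epoch save/restore flow that the tests above verify (illustrative only; the placeholder paths, ``length``, and the break point at batch 3 are assumptions, not part of the patch):

    from copy import deepcopy

    from litdata import ParallelStreamingDataset, StreamingDataset
    from litdata.streaming.dataloader import StreamingDataLoader

    dataset = ParallelStreamingDataset(
        datasets=[StreamingDataset("my_data/dataset_0"), StreamingDataset("my_data/dataset_1")],
        length=20,
        resume=True,
    )
    dataloader = StreamingDataLoader(dataset, batch_size=2, num_workers=0)

    saved_state = None
    for idx, batch in enumerate(dataloader):
        if idx == 3:
            # Snapshot the dataloader state mid-epoch.
            saved_state = deepcopy(dataloader.state_dict())
            break

    # Restoring the snapshot puts the dataloader in restore mode, so iteration
    # continues with the remaining batches of the interrupted epoch; the flag is
    # cleared once that epoch finishes.
    dataloader.load_state_dict(saved_state)
    assert dataloader.restore
    for batch in dataloader:
        pass
    assert not dataloader.restore
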