diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index fa47b346..44cda068 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -649,6 +649,7 @@ def __iter__(self) -> Any: # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not # want to restart at index 0 at every epoch. So we set them in restore state. self.load_state_dict(self.state_dict()) + self.restore = False else: self._latest_worker_idx = 0 self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) diff --git a/tests/streaming/test_parallel.py b/tests/streaming/test_parallel.py index f77066a1..380e4792 100644 --- a/tests/streaming/test_parallel.py +++ b/tests/streaming/test_parallel.py @@ -871,9 +871,25 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) if i == break_at: break + expected_3 = [ + [torch.tensor([4]), torch.tensor([4])], + [torch.tensor([9]), torch.tensor([9])], + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([5]), torch.tensor([5])], + ] + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_3 if resume and length is not None else expected_1)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + if i == break_at: + break -@pytest.mark.parametrize("length", [None, 6]) +@pytest.mark.parametrize("length", [None, 5]) @pytest.mark.parametrize("resume", [False, True]) @pytest.mark.parametrize("shuffle", [False, True]) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") @@ -888,7 +904,6 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re [torch.tensor([1]), torch.tensor([1])], [torch.tensor([3]), torch.tensor([3])], [torch.tensor([0]), torch.tensor([0])], - [torch.tensor([2]), torch.tensor([2])], ] batches_1 = [] for i, batch in enumerate(dloader): @@ -896,18 +911,32 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re assert all(torch.equal(x, y) for x, y in zip(batch, expected_1[i])) batches_1.append(batch) expected_2 = [ + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([2]), torch.tensor([2])], + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([3]), torch.tensor([3])], + [torch.tensor([1]), torch.tensor([1])], + ] + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + expected_3 = [ [torch.tensor([1]), torch.tensor([1])], [torch.tensor([3]), torch.tensor([3])], [torch.tensor([0]), torch.tensor([0])], [torch.tensor([2]), torch.tensor([2])], [torch.tensor([1]), torch.tensor([1])], - [torch.tensor([3]), torch.tensor([3])], ] for i, batch in enumerate(dloader): if not shuffle: assert all( torch.equal(x, y) - for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i]) + for x, y in zip(batch, (expected_3 if resume and length is not None else expected_1)[i]) ) elif not resume and length is not None: assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))