Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/litdata/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,7 @@ def __iter__(self) -> Any:
# For ParallelStreamingDataset with _length != None, we want to cycle the wrapped datasets, i.e. we do not
# want to restart at index 0 at every epoch, so we set them in restore state.
self.load_state_dict(self.state_dict())
self.restore = False
else:
self._latest_worker_idx = 0
self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
Expand Down
37 changes: 33 additions & 4 deletions tests/streaming/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,9 +871,25 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
if i == break_at:
break
expected_3 = [
[torch.tensor([4]), torch.tensor([4])],
[torch.tensor([9]), torch.tensor([9])],
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([5]), torch.tensor([5])],
]
for i, batch in enumerate(dloader):
if not shuffle:
assert all(
torch.equal(x, y)
for x, y in zip(batch, (expected_3 if resume and length is not None else expected_1)[i])
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
if i == break_at:
break


@pytest.mark.parametrize("length", [None, 6])
@pytest.mark.parametrize("length", [None, 5])
@pytest.mark.parametrize("resume", [False, True])
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI")
Expand All @@ -888,26 +904,39 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([3]), torch.tensor([3])],
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([2]), torch.tensor([2])],
]
batches_1 = []
for i, batch in enumerate(dloader):
if not shuffle:
assert all(torch.equal(x, y) for x, y in zip(batch, expected_1[i]))
batches_1.append(batch)
expected_2 = [
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([2]), torch.tensor([2])],
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([3]), torch.tensor([3])],
[torch.tensor([1]), torch.tensor([1])],
]
for i, batch in enumerate(dloader):
if not shuffle:
assert all(
torch.equal(x, y)
for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i])
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
expected_3 = [
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([3]), torch.tensor([3])],
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([2]), torch.tensor([2])],
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([3]), torch.tensor([3])],
]
for i, batch in enumerate(dloader):
if not shuffle:
assert all(
torch.equal(x, y)
for x, y in zip(batch, (expected_2 if resume and length is not None else expected_1)[i])
for x, y in zip(batch, (expected_3 if resume and length is not None else expected_1)[i])
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
Expand Down
Loading