7 changes: 5 additions & 2 deletions src/litdata/streaming/dataset.py
@@ -331,7 +331,7 @@ def __next__(self) -> Any:
                 index=index,
                 chunk_index=self.worker_chunks[self.chunk_index - 1],
                 # We provide the chunk indexes only on the first
-                chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
+                chunk_indexes=None if self.has_triggered_download else self.worker_chunks[self.chunk_index - 1 :],
                 is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
             )
         )
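
Hedged reading of the change above: on the first sample after a resume, the worker used to hand the downloader its full chunk list, including chunks already consumed before the checkpoint; the new slice starts at the current chunk. A minimal sketch with hypothetical values (none of these numbers come from the PR):

# Hypothetical values, for illustration only.
worker_chunks = [4, 9, 2, 7]   # chunk ids assigned to one worker
chunk_index = 3                # restored state: the worker is on its 3rd assigned chunk

old_plan = worker_chunks                      # pre-fix: re-schedules consumed chunks 4 and 9
new_plan = worker_chunks[chunk_index - 1 :]   # post-fix: current and remaining chunks only

assert new_plan == [2, 7]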
@@ -520,9 +520,12 @@ def _replay_chunks_sampling(

     for worker_idx, intervals in workers_intervals.items():
         for interval in intervals:
-            size = interval[-1] - interval[0]
+            size = interval[2] - interval[1]
             if indexes[worker_idx] >= size:
                 indexes[worker_idx] -= size
                 chunks_index[worker_idx] += 1
+            else:
+                # We've reached the chunk where resuming needs to take place (for this worker)
+                break

     return chunks_index, indexes
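
Judging by the updated tests further down, each interval is a 4-tuple of the form (chunk_start, region_start, region_end, chunk_end), so `interval[2] - interval[1]` counts only the samples actually served from a chunk, whereas the old `interval[-1] - interval[0]` spanned the whole chunk. A self-contained sketch of the fixed replay loop, with behaviour inferred from this diff and checked against the new unit test below:

# Simplified, standalone version of the fixed _replay_chunks_sampling logic.
def replay_chunks_sampling(workers_intervals, indexes):
    chunks_index = {worker_idx: 0 for worker_idx in workers_intervals}
    for worker_idx, intervals in workers_intervals.items():
        for interval in intervals:
            size = interval[2] - interval[1]  # samples served from this chunk
            if indexes[worker_idx] >= size:
                # The whole chunk was consumed before the checkpoint.
                indexes[worker_idx] -= size
                chunks_index[worker_idx] += 1
            else:
                # Resume lands inside this chunk; keep the in-chunk offset.
                break
    return chunks_index, indexes

# Mirrors the new test case: 15 samples consumed over chunks of size 10, 10, 1, 9.
intervals = {0: [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 21, 21), (21, 21, 30, 30)]}
assert replay_chunks_sampling(intervals, {0: 15}) == ({0: 1}, {0: 5})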
68 changes: 64 additions & 4 deletions tests/streaming/test_dataset.py
@@ -14,14 +14,15 @@
 import json
 import os
 import random
+import shutil
 import sys
 from time import sleep
 from unittest import mock

 import numpy as np
 import pytest
 import torch
-from litdata import train_test_split
+from litdata import optimize, train_test_split
 from litdata.constants import _ZSTD_AVAILABLE
 from litdata.processing import functions
 from litdata.streaming import Cache
@@ -793,6 +794,57 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir):
     assert not torch.equal(batch_1, batch_2)

+
+def _simple_preprocess(_):
+    for _ in range(10):
+        yield torch.randint(0, 100, size=(10,), dtype=torch.int64)
+
+
+def _get_simulated_s3_dataloader(cache_dir, data_dir):
+    dataset = EmulateS3StreamingDataset(
+        input_dir=Dir(cache_dir, data_dir),
+        item_loader=TokensLoader(block_size=10),
+    )
+    return StreamingDataLoader(dataset, batch_size=2, num_workers=1)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows and macOS")
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
+    """Test resuming from a chunk past the first one, when subsequent chunks don't all have the same size."""
+    s3_cache_dir = str(tmpdir / "s3cache")
+    optimize_cache_dir = str(tmpdir / "optimize_cache")
+    data_dir = str(tmpdir / "optimized")
+    monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", optimize_cache_dir)
+
+    optimize(
+        fn=_simple_preprocess,
+        inputs=list(range(8)),
+        output_dir=str(tmpdir / "optimized"),
+        chunk_size=190,
+        num_workers=4,
+    )
+    assert len(os.listdir(tmpdir / "optimized")) > 0
+
+    os.mkdir(s3_cache_dir)
+    train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
+    batches_to_fetch = 16
+    batch_to_resume_from = None
+    for i, batch in enumerate(train_dataloader):
+        if i == batches_to_fetch:
+            dataloader_state = train_dataloader.state_dict()
+        if i == batches_to_fetch + 1:
+            batch_to_resume_from = batch
+            break
+
+    shutil.rmtree(s3_cache_dir)
+    os.mkdir(s3_cache_dir)
+    train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
+    train_dataloader.load_state_dict(dataloader_state)
+    # The next batch after resuming must match what we should have gotten next in the initial loop
+    assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
 def test_dataset_valid_state(tmpdir, monkeypatch):
     seed_everything(42)
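
A quick sanity check on the fixture arithmetic in the test above (the token-based reading of `chunk_size` is an assumption on my part, since the output is consumed via `TokensLoader`): each input yields 10 tensors of 10 tokens, and 8 inputs split across 4 optimize workers gives 200 tokens per worker, so `chunk_size=190` forces one full chunk plus a small tail chunk per worker, which is how the unequal chunk sizes mentioned in the docstring arise:

# Back-of-the-envelope check for the fixture above. Assumption (not from the
# PR): chunk_size counts tokens here, since the data is read with TokensLoader.
tokens_per_input = 10 * 10                 # 10 yields x 10 tokens each
inputs_per_worker = 8 // 4                 # 8 inputs over 4 optimize workers
tokens_per_worker = inputs_per_worker * tokens_per_input  # 200
full_chunks, tail = divmod(tokens_per_worker, 190)
assert (full_chunks, tail) == (1, 10)      # one 190-token chunk + a 10-token tail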
@@ -915,19 +967,27 @@ def test_replay_sampling():

 def test_replay_chunks_sampling():
     chunks_replica = range(10)
-    intervals_replica = [(i, i + 5) for i in range(0, 50, 5)]
+    intervals_replica = [(i, i, i + 5, i + 5) for i in range(0, 50, 5)]
     workers_chunks, workers_intervals = _associate_chunks_to_workers(
         _WorkerEnv(2, 0), chunks_replica, intervals_replica
     )
     assert workers_chunks == {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
     assert workers_intervals == {
-        0: [(0, 5), (10, 15), (20, 25), (30, 35), (40, 45)],
-        1: [(5, 10), (15, 20), (25, 30), (35, 40), (45, 50)],
+        0: [(0, 0, 5, 5), (10, 10, 15, 15), (20, 20, 25, 25), (30, 30, 35, 35), (40, 40, 45, 45)],
+        1: [(5, 5, 10, 10), (15, 15, 20, 20), (25, 25, 30, 30), (35, 35, 40, 40), (45, 45, 50, 50)],
     }
     assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1})
     assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3})
     assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2})
+
+    # Test that replay stops at the right chunk
+    workers_intervals = {0: [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 21, 21), (21, 21, 30, 30)]}
+    indexes = {0: 15}
+    # Replay should stop at chunk index 1, because 15 - 10 = 5, which fits into chunk idx 1
+    chunk_indexes, indexes = _replay_chunks_sampling(workers_intervals, indexes)
+    assert chunk_indexes == {0: 1}
+    assert indexes == {0: 5}


 @pytest.mark.parametrize(
     "compression",