45 changes: 39 additions & 6 deletions src/litdata/streaming/dataset.py
@@ -62,6 +62,7 @@ def __init__(
index_path: Optional[str] = None,
force_override_state_dict: bool = False,
transform: Optional[Union[Callable, list[Callable]]] = None,
is_multisample: bool = False,
Collaborator: How will you know what `sample_count` the user wants?

Suggested change
is_multisample: bool = False,
sample_count: int = 1,

Author: This is better. I'll add this.

) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -89,6 +90,7 @@ def __init__(
If `index_path` is a full file path, it will use that directly.
force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
transform: Optional transformation function or list of functions to apply to each item in the dataset.
is_multisample: If True, each underlying item is exposed once per transform in the list, so indices map to (item, transform) pairs.
"""
_check_version_and_prompt_upgrade(__version__)
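For context, a minimal usage sketch of the flag documented above; the dataset URI and the transforms are hypothetical, but the semantics follow the docstring: with a list of three transforms, every underlying item is surfaced three times, once per transform.

    from litdata import StreamingDataset

    transforms = [lambda x: x * 2, lambda x: x + 3, lambda x: x]
    ds = StreamingDataset("s3://my-bucket/optimized", transform=transforms, is_multisample=True)
    # len(ds) is 3x the base length; ds[0], ds[1], ds[2] are item 0 passed
    # through transforms[0], transforms[1], transforms[2] respectively.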

@@ -209,6 +211,9 @@ def __init__(
raise ValueError(f"Transform should be a callable. Found {t}")
self.transform = transform
self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache
self.is_multisample = is_multisample
if self.is_multisample and not transform:
raise ValueError("When using `is_multisample=True`, `transform` must be a list of callables.")

@property
def on_demand_bytes(self) -> bool:
@@ -282,7 +287,8 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)

def __len__(self) -> int:
return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1)
original_len = self.get_len(self.num_workers, self.batch_size if self.batch_size else 1)
# only a list of transforms multiplies the length; a bare callable leaves it unchanged
factor = len(self.transform) if self.is_multisample and isinstance(self.transform, list) else 1
return original_len * factor

def set_batch_size(self, batch_size: int) -> None:
self.batch_size = batch_size
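A quick worked check of the `__len__` expansion above (standalone arithmetic, matching the tests added below): 100 base items with a list of 3 transforms expose 300 addressable samples, while a single callable leaves the length at 100.

    base_len, num_transforms = 100, 3
    assert base_len * num_transforms == 300  # list-of-transforms case
    assert base_len * 1 == 100               # bare-callable case, factor 1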
@@ -323,8 +329,13 @@ def __iter__(self) -> "StreamingDataset":
self.worker_chunks = workers_chunks[worker_rank]
self.worker_intervals = workers_intervals[worker_rank]

# each underlying item yields one sample per transform (factor 1 for a bare callable)
self.multisample_factor = len(self.transform) if self.is_multisample and isinstance(self.transform, list) else 1

# The max number of samples to return from `__next__` (in worker)
self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)
self.stop_length = (
sum(interval[2] - interval[1] for interval in self.worker_intervals) * self.multisample_factor
)

# Handle restart
if self._state_dict:
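The per-worker stop condition above scales the same way. A standalone sketch with made-up intervals, assuming interval[1] and interval[2] are the start and end of a chunk's usable range:

    # Two chunks of 20 base samples each, 3 transforms per item:
    worker_intervals = [(0, 0, 20, 20), (1, 20, 40, 40)]
    multisample_factor = 3
    stop_length = sum(i[2] - i[1] for i in worker_intervals) * multisample_factor
    assert stop_length == 120  # this worker emits 120 expanded samples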
@@ -407,7 +418,8 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])

# replay the indexes for the current chunks
interval = self.worker_intervals[self.worker_next_chunk_index]
current_indexes = np.arange(interval[1], interval[2])
# expand the interval bounds by the multisample factor (no-op when the factor is 1)
current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor)

# re-shuffle the indexes
current_indexes = self.shuffler(
@@ -424,6 +436,21 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])
self.worker_next_chunk_index += 1

def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
# Deflate index for multisample case
if self.is_multisample and isinstance(self.transform, list):  # a bare callable needs no deflation
if not self.transform:
raise ValueError("When using `is_multisample=True`, `transform` must be a list of callables.")
if not all(callable(fn) for fn in self.transform):
raise ValueError("All elements in `transform` must be callable when using `is_multisample=True`.")
if isinstance(index, int):
sample_idx = index % len(self.transform)
index = index // len(self.transform)
elif isinstance(index, ChunkedIndex):
sample_idx = index.index % len(self.transform)
index.index = index.index // len(self.transform)
else:
raise ValueError("Slices are not supported when using `is_multisample=True`.")

if self.cache is None:
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)
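The deflation at the top of `__getitem__` maps each expanded index back to an (item, transform) pair. A standalone sketch of that mapping (not litdata code; `divmod` is equivalent to the separate `//` and `%` above):

    def deflate(expanded_index: int, num_transforms: int) -> tuple[int, int]:
        """Map an expanded index to (base item index, transform index)."""
        return divmod(expanded_index, num_transforms)

    assert deflate(7, 3) == (2, 1)  # expanded sample 7 -> item 2 via transform[1]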
@@ -437,16 +464,21 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
_my_cache_indices = [ChunkedIndex(*self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices]
return [self.cache[chnk_idx] for chnk_idx in _my_cache_indices]
item = self.cache[index]

if hasattr(self, "transform"):
if isinstance(self.transform, list):
for transform_fn in self.transform:
item = transform_fn(item)
if not self.is_multisample:
for transform_fn in self.transform:
item = transform_fn(item)
else:
item = self.transform[sample_idx](item) # apply the specific transform for multisample
else:
item = self.transform(item)

return item

def __next__(self) -> Any:
# print(self.worker_next_chunk_index, self.num_chunks)
# check if we have reached the end of the dataset (i.e., all the chunks have been processed)
if self.global_index >= self.stop_length:
# global_index: total number of samples processed by the current worker across all chunks
@@ -476,7 +508,8 @@ def __next__(self) -> Any:

# `next_worker_chunks_index` is the index of the chunk that we will be working on now
interval = self.worker_intervals[self.worker_next_chunk_index]
current_indexes = np.arange(interval[1], interval[2])

current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor)

assert self.shuffler is not None
assert self.num_chunks is not None
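Both `_resume` and `__next__` now build their index ranges in the expanded space. A small standalone check of that arithmetic, with made-up interval bounds:

    import numpy as np

    # Base interval [20, 40) with 3 transforms becomes expanded [60, 120):
    factor = 3
    expanded = np.arange(20 * factor, 40 * factor)
    assert len(expanded) == (40 - 20) * factor  # 60 expanded indices
    # __getitem__ deflates them back onto items 20..39:
    assert (expanded // factor).min() == 20 and (expanded // factor).max() == 39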
111 changes: 111 additions & 0 deletions tests/streaming/test_dataset.py
@@ -1813,3 +1813,114 @@ def transform(self, x, *args, **kwargs):
# Verify that the transform is applied correctly
for i, item in enumerate(complete_data):
assert item == i * 2, f"Expected {i * 2}, got {item}"


def test_dataset_multisample(tmpdir):
"""Test if the dataset transform is applied correctly."""
# Create a simple dataset
# Create directories for cache and data
cache_dir = os.path.join(tmpdir, "cache_dir")
data_dir = os.path.join(tmpdir, "data_dir")
os.makedirs(cache_dir)
os.makedirs(data_dir)

# Create a dataset with 100 items, 20 items per chunk
cache = Cache(str(data_dir), chunk_size=20)
for i in range(100):
cache[i] = i
cache.done()
cache.merge()

# Define simple transform functions
def transform_fn_double(x, *args, **kwargs):
"""A simple transform function that doubles the input."""
return x * 2

def transform_fn_add(x):
"""A simple transform function that adds 3 to the input."""
return x + 3

def transform_fn_identity(x):
"""A simple transform function that returns the input as is."""
return x

dataset = StreamingDataset(
data_dir,
cache_dir=str(cache_dir),
shuffle=False,
transform=[transform_fn_double, transform_fn_add, transform_fn_identity],
is_multisample=True,
)
dataset_length = len(dataset)
assert dataset_length == 300

# ASSERT
# Verify that the transform functions are applied correctly
for i, item in enumerate(dataset):
assert item is not None
if i % 3 == 0:
assert item == (i // len(dataset.transform)) * 2, (
f"Expected {(i // len(dataset.transform)) * 2}, got {item}"
)
elif i % 3 == 1:
assert item == (i // len(dataset.transform)) + 3, (
f"Expected {(i // len(dataset.transform)) + 3}, got {item}"
)
else:
assert item == (i // len(dataset.transform)), f"Expected {(i // len(dataset.transform))}, got {item}"
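For reference, the interleaving this loop asserts can be written compactly; with the same three transforms expressed as lambdas, the first six items come out as [0, 3, 0, 2, 4, 1]:

    transforms = [lambda x: x * 2, lambda x: x + 3, lambda x: x]
    expected = [transforms[i % 3](i // 3) for i in range(300)]
    assert expected[:6] == [0, 3, 0, 2, 4, 1]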


def test_dataset_multisample_single_transform(tmpdir):
"""Test if the dataset transform is applied correctly."""
# Create a simple dataset
# Create directories for cache and data
cache_dir = os.path.join(tmpdir, "cache_dir")
data_dir = os.path.join(tmpdir, "data_dir")
os.makedirs(cache_dir)
os.makedirs(data_dir)

# Create a dataset with 100 items, 20 items per chunk
cache = Cache(str(data_dir), chunk_size=20)
for i in range(100):
cache[i] = i
cache.done()
cache.merge()

# Define simple transform functions
def transform_fn_double(x, *args, **kwargs):
"""A simple transform function that doubles the input."""
return x * 2

dataset = StreamingDataset(
data_dir, cache_dir=str(cache_dir), shuffle=False, transform=transform_fn_double, is_multisample=True
)
dataset_length = len(dataset)
assert dataset_length == 100

# ASSERT
# Verify that the transform function is applied correctly
for i, item in enumerate(dataset):
assert item is not None
assert item == (i * 2), f"Expected {(i * 2)}, got {item}"


def test_dataset_multisample_nonlist_transform_error(tmpdir):
"""Test if the dataset raises an error when is_multisample is True but transform is not a list."""
# Create directories for cache and data
cache_dir = os.path.join(tmpdir, "cache_dir")
data_dir = os.path.join(tmpdir, "data_dir")
os.makedirs(cache_dir)
os.makedirs(data_dir)

# Create a dataset with 100 items, 20 items per chunk
cache = Cache(str(data_dir), chunk_size=20)
for i in range(100):
cache[i] = i
cache.done()
cache.merge()

# ASSERT
# Verify that ValueError is raised when transform is not given
with pytest.raises(ValueError, match="When using `is_multisample=True`, `transform` must be a list of callables."):
StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, is_multisample=True)