Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ if __name__ == "__main__":

Transform datasets on-the-fly while streaming them, allowing for efficient data processing without the need to store intermediate results.

- You can use the `transform` argument in `StreamingDataset` to apply a transformation function to each sample as it is streamed.
- You can use the `transform` argument in `StreamingDataset` to apply either a single transformation function or a list of transformation functions to each sample as it is streamed.

```python
# Define a simple transform function
Expand All @@ -953,7 +953,7 @@ def transform_fn(x, *args, **kwargs):
return torch_transform(x) # Apply the transform to the input image

# Create dataset with appropriate configuration
dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=shuffle, transform=transform_fn)
dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=shuffle, transform=[transform_fn])
```

Or, you can create a subclass of `StreamingDataset` and override its `transform` method to apply custom transformations to each sample.
Expand Down
19 changes: 14 additions & 5 deletions src/litdata/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
max_pre_download: int = 2,
index_path: Optional[str] = None,
force_override_state_dict: bool = False,
transform: Optional[Callable] = None,
transform: Optional[Union[Callable, list[Callable]]] = None,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

Expand All @@ -89,7 +89,7 @@ def __init__(
If `index_path` is a directory, the function will look for `index.json` within it.
If `index_path` is a full file path, it will use that directly.
force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
transform: Optional transformation function to apply to each item in the dataset.
transform: Optional transformation function or list of functions to apply to each item in the dataset.
"""
_check_version_and_prompt_upgrade(__version__)

Expand Down Expand Up @@ -198,8 +198,10 @@ def __init__(
self.session_options = session_options
self.max_pre_download = max_pre_download
if transform is not None:
if not callable(transform):
raise ValueError(f"Transform should be a callable. Found {transform}")
transform = transform if isinstance(transform, list) else [transform]
for t in transform:
if not callable(t):
raise ValueError(f"Transform should be a callable. Found {t}")
self.transform = transform
self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache

Expand Down Expand Up @@ -441,7 +443,14 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
{"name": f"getitem_dataset_for_chunk_index_{index.chunk_index}_and_index_{index.index}", "ph": "E"}
)
)
return self.transform(item) if hasattr(self, "transform") else item
if hasattr(self, "transform"):
if isinstance(self.transform, list):
for transform_fn in self.transform:
item = transform_fn(item)
else:
item = self.transform(item)

return item

def __next__(self) -> Any:
# check if we have reached the end of the dataset (i.e., all the chunks have been processed)
Expand Down
52 changes: 52 additions & 0 deletions tests/streaming/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import random
import shutil
import sys
from functools import partial
from time import sleep
from typing import Any, Optional
from unittest import mock
Expand Down Expand Up @@ -1695,6 +1696,57 @@ def transform_fn(x, *args, **kwargs):
assert item == i * 2, f"Expected {i * 2}, got {item}"


@pytest.mark.parametrize("shuffle", [True, False])
def test_dataset_multiple_transform(tmpdir, shuffle):
    """Verify that a list of transforms is applied to every sample, in list order.

    Writes 100 integer samples, streams them back through two chained
    transforms (double, then add 100), and checks each item equals
    ``i * 2 + 100`` — which only holds if the transforms run in order.
    """
    # ARRANGE
    # Create separate directories for the streaming cache and the optimized data.
    cache_dir = os.path.join(tmpdir, "cache_dir")
    data_dir = os.path.join(tmpdir, "data_dir")
    os.makedirs(cache_dir)
    os.makedirs(data_dir)

    # Create a dataset with 100 items, 20 items per chunk.
    cache = Cache(str(data_dir), chunk_size=20)
    for i in range(100):
        cache[i] = i
    cache.done()
    cache.merge()

    # Define two simple transform functions; the second takes an extra
    # argument, bound via functools.partial to exercise non-plain callables.
    def transform_fn_1(x):
        """A simple transform function that doubles the input."""
        return x * 2

    def transform_fn_2(x, extra_num):
        """A simple transform function that adds ``extra_num`` to the input."""
        return x + extra_num

    dataset = StreamingDataset(
        data_dir,
        cache_dir=str(cache_dir),
        shuffle=shuffle,
        transform=[transform_fn_1, partial(transform_fn_2, extra_num=100)],
    )
    dataset_length = len(dataset)
    assert dataset_length == 100

    # ACT
    # Stream through the entire dataset and store the results.
    complete_data = []
    for data in dataset:
        assert data is not None
        complete_data.append(data)

    # Sort so the per-index assertion below is valid for shuffled streams too.
    if shuffle:
        complete_data.sort()

    # ASSERT
    # Each item must be doubled first, then offset by 100 — order matters:
    # the reverse order would yield (i + 100) * 2 instead.
    for i, item in enumerate(complete_data):
        assert item == i * 2 + 100, f"Expected {i * 2 + 100}, got {item}"


@pytest.mark.parametrize("shuffle", [True, False])
def test_dataset_transform_inheritance(tmpdir, shuffle):
"""Test if the dataset transform is applied correctly."""
Expand Down
Loading