From 7371a5a14bc87543a254511b945e342e51a620ee Mon Sep 17 00:00:00 2001
From: Deependu
Date: Mon, 14 Jul 2025 08:38:38 +0000
Subject: [PATCH 01/10] feat: add support to pass transform fn kwargs

---
 src/litdata/streaming/dataset.py           | 40 +++++++++++++---
 src/litdata/utilities/dataset_utilities.py |  9 +++-
 tests/streaming/test_dataset.py            | 53 ++++++++++++++++++++++
 tests/utilities/test_dataset_utilities.py  | 18 ++++++++
 4 files changed, 112 insertions(+), 8 deletions(-)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 2522437be..7dcccc522 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -10,7 +10,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import os
 from time import time
@@ -29,7 +28,12 @@
 from litdata.streaming.sampler import ChunkedIndex
 from litdata.streaming.serializers import Serializer
 from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
-from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset
+from litdata.utilities.dataset_utilities import (
+    _should_replace_path,
+    _try_create_cache_dir,
+    fn_accepts_kwargs,
+    subsample_streaming_dataset,
+)
 from litdata.utilities.encryption import Encryption
 from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
 from litdata.utilities.format import _convert_bytes_to_int
@@ -62,7 +66,10 @@ def __init__(
         max_pre_download: int = 2,
         index_path: Optional[str] = None,
         force_override_state_dict: bool = False,
-        transform: Optional[Callable] = None,
+        transform: Optional[Union[Callable, List[Callable]]] = None,
+        transform_kwargs: Optional[Dict[str, Any]] = None,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -89,7 +96,10 @@
             If `index_path` is a directory, the function will look for `index.json` within it.
             If `index_path` is a full file path, it will use that directly.
             force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
-            transform: Optional transformation function to apply to each item in the dataset.
+            transform: Optional transformation function or list of functions to apply to each item in the dataset.
+            transform_kwargs: Keyword arguments for the transformation function.
+            args: Additional positional arguments.
+            kwargs: Additional keyword arguments.
         """
         _check_version_and_prompt_upgrade(__version__)
@@ -198,9 +208,12 @@
         self.session_options = session_options
         self.max_pre_download = max_pre_download
         if transform is not None:
-            if not callable(transform):
-                raise ValueError(f"Transform should be a callable. Found {transform}")
+            transform = transform if isinstance(transform, list) else [transform]
+            for t in transform:
+                if not callable(t):
+                    raise ValueError(f"Transform should be a callable. Found {t}")
             self.transform = transform
+        self.transform_kwargs = transform_kwargs or {}
         self._on_demand_bytes = True  # true by default, when iterating, turn this off to store the chunks in the cache
 
     @property
@@ -441,7 +454,20 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
                 {"name": f"getitem_dataset_for_chunk_index_{index.chunk_index}_and_index_{index.index}", "ph": "E"}
             )
         )
-        return self.transform(item) if hasattr(self, "transform") else item
+        if hasattr(self, "transform"):
+            local_transform_kwargs = self.transform_kwargs.copy()
+            local_transform_kwargs["index"] = index.index
+
+            # check if transform function accepts kwargs
+            if isinstance(self.transform, list):
+                for transform_fn in self.transform:
+                    accepts_kwargs = fn_accepts_kwargs(transform_fn)
+                    item = transform_fn(item, **local_transform_kwargs) if accepts_kwargs else transform_fn(item)
+            else:
+                accepts_kwargs = fn_accepts_kwargs(self.transform)
+                item = self.transform(item, **local_transform_kwargs) if accepts_kwargs else self.transform(item)
+
+        return item
 
     def __next__(self) -> Any:
         # check if we have reached the end of the dataset (i.e., all the chunks have been processed)
diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py
index 5bbb5b250..dc29bdb35 100644
--- a/src/litdata/utilities/dataset_utilities.py
+++ b/src/litdata/utilities/dataset_utilities.py
@@ -1,11 +1,12 @@
 import hashlib
+import inspect
 import json
 import math
 import os
 import shutil
 import tempfile
 import time
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import numpy as np
@@ -341,3 +342,9 @@ def copy_index_to_cache_index_filepath(index_path: str, cache_index_filepath: st
         raise FileNotFoundError(f"Index file not found: {index_path}")
     # Copy the file to cache_index_filepath
     shutil.copy(index_path, cache_index_filepath)
+
+
+def fn_accepts_kwargs(_fn: Callable) -> bool:
+    """Check if a function accepts keyword arguments."""
+    signature = inspect.signature(_fn)
+    return any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values())
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 13bdd90c3..166cfe341 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -1695,6 +1695,59 @@ def transform_fn(x, *args, **kwargs):
     assert item == i * 2, f"Expected {i * 2}, got {item}"
 
 
+@pytest.mark.parametrize("shuffle", [True, False])
+def test_dataset_multiple_transform(tmpdir, shuffle):
+    """Test if the dataset transform is applied correctly."""
+    # Create a simple dataset
+    # Create directories for cache and data
+    cache_dir = os.path.join(tmpdir, "cache_dir")
+    data_dir = os.path.join(tmpdir, "data_dir")
+    os.makedirs(cache_dir)
+    os.makedirs(data_dir)
+
+    # Create a dataset with 100 items, 20 items per chunk
+    cache = Cache(str(data_dir), chunk_size=20)
+    for i in range(100):
+        cache[i] = i
+    cache.done()
+    cache.merge()
+
+    # Define two simple transform functions
+    def transform_fn_1(x):
+        """A simple transform function that doubles the input."""
+        return x * 2
+
+    def transform_fn_2(x, **kwargs):
+        """A simple transform function that adds `extra_num` to the input."""
+        extra_num = kwargs.get("extra_num", 0)
+        return x + extra_num
+
+    dataset = StreamingDataset(
+        data_dir,
+        cache_dir=str(cache_dir),
+        shuffle=shuffle,
+        transform=[transform_fn_1, transform_fn_2],
+        transform_kwargs={"extra_num": 100},
+    )
+    dataset_length = len(dataset)
+    assert dataset_length == 100
+
+    # ACT
+    # Stream through the entire dataset and store the results
+    complete_data = []
+    for data in dataset:
+        assert data is not None
+        complete_data.append(data)
+
+    if shuffle:
+        complete_data.sort()
+
+    # ASSERT
+    # Verify that the transform is applied correctly
+    for i, item in enumerate(complete_data):
+        assert item == i * 2 + 100, f"Expected {i * 2 + 100}, got {item}"
+
+
 @pytest.mark.parametrize("shuffle", [True, False])
 def test_dataset_transform_inheritance(tmpdir, shuffle):
     """Test if the dataset transform is applied correctly."""
diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py
index 60dc4b39b..55cc9b667 100644
--- a/tests/utilities/test_dataset_utilities.py
+++ b/tests/utilities/test_dataset_utilities.py
@@ -11,6 +11,7 @@
     _should_replace_path,
     _try_create_cache_dir,
     adapt_mds_shards_to_chunks,
+    fn_accepts_kwargs,
     generate_roi,
     get_default_cache_dir,
     load_index_file,
@@ -109,3 +110,20 @@ def test_get_default_cache_dir():
         importlib.reload(litdata.constants)
         importlib.reload(litdata.utilities.dataset_utilities)
         assert litdata.utilities.dataset_utilities.get_default_cache_dir() == "/custom/cache/dir"
+
+
+def test_fn_accepts_kwargs():
+    """Check if a function accepts keyword arguments."""
+
+    def func_with_kwargs(a, b=1, *args, **kwargs):
+        return a + b
+
+    def func_without_kwargs(a, b=1):
+        return a + b
+
+    assert fn_accepts_kwargs(func_with_kwargs)
+    assert not fn_accepts_kwargs(func_without_kwargs)
+
+    # Test with a lambda function
+    lambda_func = lambda x, y=2: x + y
+    assert not fn_accepts_kwargs(lambda_func)

From 117f0392366ca6973aca5c92874e20ba135a6cf1 Mon Sep 17 00:00:00 2001
From: Deependu
Date: Mon, 14 Jul 2025 09:04:33 +0000
Subject: [PATCH 02/10] update readme

---
 README.md | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 69 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 2be60537b..b73ef40f7 100644
--- a/README.md
+++ b/README.md
@@ -935,7 +935,7 @@ if __name__ == "__main__":
 
 Transform datasets on-the-fly while streaming them, allowing for efficient data processing without the need to store intermediate results.
 
-- You can use the `transform` argument in `StreamingDataset` to apply a transformation function to each sample as it is streamed.
+- You can use the `transform` argument in `StreamingDataset` to apply a `transformation function` or a list of `transformation function` to each sample as it is streamed.
 
 ```python
 # Define a simple transform function
@@ -983,6 +983,74 @@ class StreamingDatasetWithTransform(StreamingDataset):
 
 dataset = StreamingDatasetWithTransform(data_dir, cache_dir=str(cache_dir), shuffle=shuffle)
 ```
+#### Passing keyword arguments to your transform function(s)
+
+You can now pass custom keyword arguments to your transform function(s) using the `transform_kwargs` argument in `StreamingDataset`.
+
+This allows for more flexible and dynamic preprocessing, especially useful `when transforms depend on external configuration, state, or shared dataset attributes`.
+
+```python
+def transform_fn_1(x, **kwargs):
+    """Apply a custom transform using additional arguments."""
+    some_val = kwargs["key1"]
+    index = kwargs["index"]  # Automatically included
+    return transform_logic_1(index, some_val, x)
+
+def transform_fn_2(x, **kwargs):
+    """Apply a second transform using shared config."""
+    some_val = kwargs["key2"]
+    return transform_logic_2(some_val, x)
+
+dataset = StreamingDataset(
+    data_dir,
+    cache_dir=str(cache_dir),
+    shuffle=shuffle,
+    transform=[transform_fn_1, transform_fn_2],
+    transform_kwargs={"key1": "value1", "key2": "value2"}
+)
+```
+
+> â„šī¸ **Note:** `transform_kwargs` will always include the `index` key automatically.
+
+##### 💡 Why this is useful
+
+Traditionally, transform functions are self-contained. But in real-world pipelines, transforms often depend on shared context like:
+
+* Dataset-specific configurations (`img_size`, `augmentation`, `task_type`)
+* Class methods (e.g., parsing logic, tokenizers, decoders)
+* External tools (like label mappers or precomputed metadata)
+
+Instead of wrapping these in closures or using global variables, you can now **pass them cleanly via `transform_kwargs`**.
+
+##### đŸ“Ļ Real-world use case: Integrating Ultralytics datasets
+
+Ultralytics datasets have their own logic for parsing labels and transforming images. You can wrap that logic inside a `StreamingDataset` by:
+
+1. Subclassing `StreamingDataset`
+2. Extracting needed attributes from a preconfigured Ultralytics dataset
+3. Passing them to transform functions using `transform_kwargs`
+
+#### Example
+
+```python
+class StreamingUltralyticsDataset(StreamingDataset):
+    def __init__(self, yolo_dataset, *args, **kwargs):
+        self.yolo_dataset = yolo_dataset  # Ultralytics dataset instance
+
+        super().__init__(
+            *args,
+            transform=[ultralytics.parse_labels, ultralytics.postprocess],
+            transform_kwargs={
+                "names": yolo_dataset.names,
+                "is_coco": yolo_dataset.is_coco_format,
+                "imgsz": yolo_dataset.imgsz,
+            },
+            **kwargs
+        )
+```
+
+> 🔁 Each transform receives all `transform_kwargs` (including `index`), making it easy to pass dynamic state without breaking encapsulation.
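The mechanics that PATCH 01 adds and PATCH 02 documents reduce to one dispatch rule: forward `transform_kwargs` (plus the current sample's `index`) only to transforms that declare a `**kwargs` parameter, and call everything else as `fn(item)`. The following is a minimal, self-contained sketch of that rule; `accepts_var_keyword` and `apply_transforms` are illustrative names, not litdata APIs:

```python
import inspect
from typing import Any, Callable


def accepts_var_keyword(fn: Callable) -> bool:
    """Return True if `fn` declares a **kwargs (VAR_KEYWORD) parameter."""
    return any(
        p.kind == inspect.Parameter.VAR_KEYWORD
        for p in inspect.signature(fn).parameters.values()
    )


def apply_transforms(item: Any, transforms: list, **transform_kwargs: Any) -> Any:
    """Chain transforms, forwarding kwargs only to functions that can accept them."""
    for fn in transforms:
        item = fn(item, **transform_kwargs) if accepts_var_keyword(fn) else fn(item)
    return item


def double(x):
    return x * 2


def shift(x, **kwargs):
    return x + kwargs.get("extra_num", 0)


# double ignores the kwargs; shift receives them: (3 * 2) + 100 == 106
assert apply_transforms(3, [double, shift], extra_num=100) == 106
```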
From 663e71531920c4a889d3c29a78e03b76a9d18830 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 14 Jul 2025 09:10:05 +0000
Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/litdata/utilities/dataset_utilities.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py
index 59075ff51..5de6d1cf6 100644
--- a/src/litdata/utilities/dataset_utilities.py
+++ b/src/litdata/utilities/dataset_utilities.py
@@ -6,7 +6,7 @@
 import shutil
 import tempfile
 import time
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 
 import numpy as np

From 89b9f3952a7ff4bd38e3b6991cbff4fbf6f44c32 Mon Sep 17 00:00:00 2001
From: Deependu
Date: Mon, 14 Jul 2025 09:12:12 +0000
Subject: [PATCH 04/10] fix pre-commit

---
 src/litdata/streaming/dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index e87b12a82..f38bc5021 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -66,8 +66,8 @@ def __init__(
         max_pre_download: int = 2,
         index_path: Optional[str] = None,
         force_override_state_dict: bool = False,
-        transform: Optional[Union[Callable, List[Callable]]] = None,
-        transform_kwargs: Optional[Dict[str, Any]] = None,
+        transform: Optional[Union[Callable, list[Callable]]] = None,
+        transform_kwargs: Optional[dict[str, Any]] = None,
         *args: Any,
         **kwargs: Any,

From 13aa9fd1b6cb0c25b5d3cb87e5d6064089f58580 Mon Sep 17 00:00:00 2001
From: Deependu
Date: Mon, 14 Jul 2025 09:13:20 +0000
Subject: [PATCH 05/10] update

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b73ef40f7..fb169f46a 100644
--- a/README.md
+++ b/README.md
@@ -935,7 +935,7 @@ if __name__ == "__main__":
 
 Transform datasets on-the-fly while streaming them, allowing for efficient data processing without the need to store intermediate results.
 
-- You can use the `transform` argument in `StreamingDataset` to apply a `transformation function` or a list of `transformation function` to each sample as it is streamed.
+- You can use the `transform` argument in `StreamingDataset` to apply a `transformation function` or `a list of transformation function` to each sample as it is streamed.
 
 ```python
 # Define a simple transform function

From 7617713d1f2bf6932124573861c6e0e7653d1a7b Mon Sep 17 00:00:00 2001
From: Deependu
Date: Mon, 14 Jul 2025 09:15:04 +0000
Subject: [PATCH 06/10] update

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index fb169f46a..274511582 100644
--- a/README.md
+++ b/README.md
@@ -983,9 +983,9 @@ class StreamingDatasetWithTransform(StreamingDataset):
 
 dataset = StreamingDatasetWithTransform(data_dir, cache_dir=str(cache_dir), shuffle=shuffle)
 ```
-#### Passing keyword arguments to your transform function(s)
+#### 🚀 Passing keyword arguments to your transform function(s)
 
-You can now pass custom keyword arguments to your transform function(s) using the `transform_kwargs` argument in `StreamingDataset`.
+You can pass custom keyword arguments to your transform function(s) using the `transform_kwargs` argument in `StreamingDataset`.
 
 This allows for more flexible and dynamic preprocessing, especially useful `when transforms depend on external configuration, state, or shared dataset attributes`.

From eaba7a59808f995d2a40f1cda0e10b44ddc726d0 Mon Sep 17 00:00:00 2001
From: Deependu
Date: Mon, 14 Jul 2025 09:20:13 +0000
Subject: [PATCH 07/10] update

---
 README.md                                  |  2 +-
 src/litdata/streaming/dataset.py           |  6 +++---
 src/litdata/utilities/dataset_utilities.py |  2 +-
 tests/utilities/test_dataset_utilities.py  | 10 +++++-----
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 274511582..067a86e3a 100644
--- a/README.md
+++ b/README.md
@@ -935,7 +935,7 @@ if __name__ == "__main__":
 
 Transform datasets on-the-fly while streaming them, allowing for efficient data processing without the need to store intermediate results.
 
-- You can use the `transform` argument in `StreamingDataset` to apply a `transformation function` or `a list of transformation function` to each sample as it is streamed.
+- You can use the `transform` argument in `StreamingDataset` to apply a `transformation function` or `a list of transformation functions` to each sample as it is streamed.
 
 ```python
 # Define a simple transform function
diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index f38bc5021..d594a5d26 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -31,7 +31,7 @@
 from litdata.utilities.dataset_utilities import (
     _should_replace_path,
     _try_create_cache_dir,
-    fn_accepts_kwargs,
+    function_accepts_kwargs,
     subsample_streaming_dataset,
 )
@@ -461,10 +461,10 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
             # check if transform function accepts kwargs
             if isinstance(self.transform, list):
                 for transform_fn in self.transform:
-                    accepts_kwargs = fn_accepts_kwargs(transform_fn)
+                    accepts_kwargs = function_accepts_kwargs(transform_fn)
                     item = transform_fn(item, **local_transform_kwargs) if accepts_kwargs else transform_fn(item)
             else:
-                accepts_kwargs = fn_accepts_kwargs(self.transform)
+                accepts_kwargs = function_accepts_kwargs(self.transform)
                 item = self.transform(item, **local_transform_kwargs) if accepts_kwargs else self.transform(item)
 
         return item
diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py
index 5de6d1cf6..9c9906f64 100644
--- a/src/litdata/utilities/dataset_utilities.py
+++ b/src/litdata/utilities/dataset_utilities.py
@@ -344,7 +344,7 @@ def copy_index_to_cache_index_filepath(index_path: str, cache_index_filepath: st
     shutil.copy(index_path, cache_index_filepath)
 
 
-def fn_accepts_kwargs(_fn: Callable) -> bool:
+def function_accepts_kwargs(_fn: Callable) -> bool:
     """Check if a function accepts keyword arguments."""
     signature = inspect.signature(_fn)
     return any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values())
diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py
index 55cc9b667..495496ba9 100644
--- a/tests/utilities/test_dataset_utilities.py
+++ b/tests/utilities/test_dataset_utilities.py
@@ -11,7 +11,7 @@
     _should_replace_path,
     _try_create_cache_dir,
     adapt_mds_shards_to_chunks,
-    fn_accepts_kwargs,
+    function_accepts_kwargs,
     generate_roi,
     get_default_cache_dir,
     load_index_file,
@@ -112,7 +112,7 @@ def test_get_default_cache_dir():
     importlib.reload(litdata.constants)
     importlib.reload(litdata.utilities.dataset_utilities)
     assert litdata.utilities.dataset_utilities.get_default_cache_dir() == "/custom/cache/dir"
 
 
-def test_fn_accepts_kwargs():
+def test_function_accepts_kwargs():
     """Check if a function accepts keyword arguments."""
 
     def func_with_kwargs(a, b=1, *args, **kwargs):
         return a + b
 
@@ -121,9 +121,9 @@ def func_with_kwargs(a, b=1, *args, **kwargs):
     def func_without_kwargs(a, b=1):
         return a + b
 
-    assert fn_accepts_kwargs(func_with_kwargs)
-    assert not fn_accepts_kwargs(func_without_kwargs)
+    assert function_accepts_kwargs(func_with_kwargs)
+    assert not function_accepts_kwargs(func_without_kwargs)
 
     # Test with a lambda function
     lambda_func = lambda x, y=2: x + y
-    assert not fn_accepts_kwargs(lambda_func)
+    assert not function_accepts_kwargs(lambda_func)

From 94ba197eb92c79172eb06bcfd9951bd9b130c5d8 Mon Sep 17 00:00:00 2001
From: Deependu
Date: Mon, 14 Jul 2025 12:13:40 +0000
Subject: [PATCH 08/10] remove transform-kwargs

---
 src/litdata/streaming/dataset.py           | 17 ++---------------
 src/litdata/utilities/dataset_utilities.py |  9 +--------
 tests/streaming/test_dataset.py            |  7 +++----
 tests/utilities/test_dataset_utilities.py  | 18 ------------------
 4 files changed, 6 insertions(+), 45 deletions(-)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index d594a5d26..db30f8ac3 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -31,7 +31,6 @@
 from litdata.utilities.dataset_utilities import (
     _should_replace_path,
     _try_create_cache_dir,
-    function_accepts_kwargs,
     subsample_streaming_dataset,
 )
@@ -67,9 +66,6 @@ def __init__(
         index_path: Optional[str] = None,
         force_override_state_dict: bool = False,
         transform: Optional[Union[Callable, list[Callable]]] = None,
-        transform_kwargs: Optional[dict[str, Any]] = None,
-        *args: Any,
-        **kwargs: Any,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -97,9 +93,6 @@
             If `index_path` is a full file path, it will use that directly.
             force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
             transform: Optional transformation function or list of functions to apply to each item in the dataset.
-            transform_kwargs: Keyword arguments for the transformation function.
-            args: Additional positional arguments.
-            kwargs: Additional keyword arguments.
         """
         _check_version_and_prompt_upgrade(__version__)
@@ -213,7 +206,6 @@ def __init__(
             for t in transform:
                 if not callable(t):
                     raise ValueError(f"Transform should be a callable. Found {t}")
             self.transform = transform
-        self.transform_kwargs = transform_kwargs or {}
         self._on_demand_bytes = True  # true by default, when iterating, turn this off to store the chunks in the cache
@@ -455,17 +447,12 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
             )
         )
         if hasattr(self, "transform"):
-            local_transform_kwargs = self.transform_kwargs.copy()
-            local_transform_kwargs["index"] = index.index
-
             # check if transform function accepts kwargs
             if isinstance(self.transform, list):
                 for transform_fn in self.transform:
-                    accepts_kwargs = function_accepts_kwargs(transform_fn)
-                    item = transform_fn(item, **local_transform_kwargs) if accepts_kwargs else transform_fn(item)
+                    item = transform_fn(item)
             else:
-                accepts_kwargs = function_accepts_kwargs(self.transform)
-                item = self.transform(item, **local_transform_kwargs) if accepts_kwargs else self.transform(item)
+                item = self.transform(item)
 
         return item
diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py
index 9c9906f64..6e5bc651f 100644
--- a/src/litdata/utilities/dataset_utilities.py
+++ b/src/litdata/utilities/dataset_utilities.py
@@ -1,12 +1,11 @@
 import hashlib
-import inspect
 import json
 import math
 import os
 import shutil
 import tempfile
 import time
-from typing import Any, Callable, Optional
+from typing import Any, Optional
 
 import numpy as np
@@ -342,9 +341,3 @@ def copy_index_to_cache_index_filepath(index_path: str, cache_index_filepath: st
         raise FileNotFoundError(f"Index file not found: {index_path}")
     # Copy the file to cache_index_filepath
     shutil.copy(index_path, cache_index_filepath)
-
-
-def function_accepts_kwargs(_fn: Callable) -> bool:
-    """Check if a function accepts keyword arguments."""
-    signature = inspect.signature(_fn)
-    return any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values())
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index c0cbd449c..027b73f9a 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -17,6 +17,7 @@
 import random
 import shutil
 import sys
+from functools import partial
 from time import sleep
 from typing import Any, Optional
 from unittest import mock
@@ -1717,17 +1718,15 @@ def transform_fn_1(x):
         """A simple transform function that doubles the input."""
         return x * 2
 
-    def transform_fn_2(x, **kwargs):
+    def transform_fn_2(x, extra_num):
         """A simple transform function that adds `extra_num` to the input."""
-        extra_num = kwargs.get("extra_num", 0)
         return x + extra_num
 
     dataset = StreamingDataset(
         data_dir,
         cache_dir=str(cache_dir),
         shuffle=shuffle,
-        transform=[transform_fn_1, transform_fn_2],
-        transform_kwargs={"extra_num": 100},
+        transform=[transform_fn_1, partial(transform_fn_2, extra_num=100)],
     )
     dataset_length = len(dataset)
     assert dataset_length == 100
diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py
index 495496ba9..60dc4b39b 100644
--- a/tests/utilities/test_dataset_utilities.py
+++ b/tests/utilities/test_dataset_utilities.py
@@ -11,7 +11,6 @@
     _should_replace_path,
     _try_create_cache_dir,
     adapt_mds_shards_to_chunks,
-    function_accepts_kwargs,
     generate_roi,
     get_default_cache_dir,
     load_index_file,
@@ -110,20 +109,3 @@ def test_get_default_cache_dir():
     importlib.reload(litdata.constants)
     importlib.reload(litdata.utilities.dataset_utilities)
     assert litdata.utilities.dataset_utilities.get_default_cache_dir() == "/custom/cache/dir"
-
-
-def test_function_accepts_kwargs():
"""Check if a function accepts keyword arguments.""" - - def func_with_kwargs(a, b=1, *args, **kwargs): - return a + b - - def func_without_kwargs(a, b=1): - return a + b - - assert function_accepts_kwargs(func_with_kwargs) - assert not function_accepts_kwargs(func_without_kwargs) - - # Test with a lambda function - lambda_func = lambda x, y=2: x + y - assert not function_accepts_kwargs(lambda_func) From ee17a19d01f67c94e536d6d2372cd2c42cd92ca2 Mon Sep 17 00:00:00 2001 From: Deependu Date: Mon, 14 Jul 2025 12:15:30 +0000 Subject: [PATCH 09/10] update --- README.md | 67 -------------------------------- src/litdata/streaming/dataset.py | 2 +- 2 files changed, 1 insertion(+), 68 deletions(-) diff --git a/README.md b/README.md index 067a86e3a..2b7131fa7 100644 --- a/README.md +++ b/README.md @@ -981,76 +981,9 @@ class StreamingDatasetWithTransform(StreamingDataset): dataset = StreamingDatasetWithTransform(data_dir, cache_dir=str(cache_dir), shuffle=shuffle) -``` - -#### 🚀 Passing keyword arguments to your transform function(s) - -You can pass custom keyword arguments to your transform function(s) using the `transform_kwargs` argument in `StreamingDataset`. - -This allows for more flexible and dynamic preprocessing, especially useful `when transforms depend on external configuration, state, or shared dataset attributes`. - -```python -def transform_fn_1(x, **kwargs): - """Apply a custom transform using additional arguments.""" - some_val = kwargs["key1"] - index = kwargs["index"] # Automatically included - return transform_logic_1(index, some_val, x) - -def transform_fn_2(x, **kwargs): - """Apply a second transform using shared config.""" - some_val = kwargs["key2"] - return transform_logic_2(some_val, x) - -dataset = StreamingDataset( - data_dir, - cache_dir=str(cache_dir), - shuffle=shuffle, - transform=[transform_fn_1, transform_fn_2], - transform_kwargs={"key1": "value1", "key2": "value2"} -) -``` -> â„šī¸ **Note:** `transform_kwargs` will always include the `index` key automatically. -##### 💡 Why this is useful - -Traditionally, transform functions are self-contained. But in real-world pipelines, transforms often depend on shared context like: - -* Dataset-specific configurations (`img_size`, `augmentation`, `task_type`) -* Class methods (e.g., parsing logic, tokenizers, decoders) -* External tools (like label mappers or precomputed metadata) - -Instead of wrapping these in closures or using global variables, you can now **pass them cleanly via `transform_kwargs`**. - -##### đŸ“Ļ Real-world use case: Integrating Ultralytics datasets - -Ultralytics datasets have their own logic for parsing labels and transforming images. You can wrap that logic inside a `StreamingDataset` by: - -1. Subclassing `StreamingDataset` -2. Extracting needed attributes from a preconfigured Ultralytics dataset -3. Passing them to transform functions using `transform_kwargs` - -#### Example - -```python -class StreamingUltralyticsDataset(StreamingDataset): - def __init__(self, yolo_dataset, *args, **kwargs): - self.yolo_dataset = yolo_dataset # Ultralytics dataset instance - - super().__init__( - *args, - transform=[ultralytics.parse_labels, ultralytics.postprocess], - transform_kwargs={ - "names": yolo_dataset.names, - "is_coco": yolo_dataset.is_coco_format, - "imgsz": yolo_dataset.imgsz, - }, - **kwargs - ) ``` - -> 🔁 Each transform receives all `transform_kwargs` (including `index`), making it easy to pass dynamic state without breaking encapsulation. -
diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index db30f8ac3..3987ce8ba 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -10,6 +10,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 import os
 from time import time
@@ -447,7 +448,6 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
             )
         )
         if hasattr(self, "transform"):
-            # check if transform function accepts kwargs
             if isinstance(self.transform, list):
                 for transform_fn in self.transform:
                     item = transform_fn(item)

From c79dbd69434ce7b116bf5bdbdc4c6296a0480847 Mon Sep 17 00:00:00 2001
From: Deependu
Date: Mon, 14 Jul 2025 12:17:32 +0000
Subject: [PATCH 10/10] update

---
 README.md                        | 5 ++---
 src/litdata/streaming/dataset.py | 6 +-----
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 2b7131fa7..1dbe637c5 100644
--- a/README.md
+++ b/README.md
@@ -953,7 +953,7 @@ def transform_fn(x, *args, **kwargs):
     return torch_transform(x)  # Apply the transform to the input image
 
 # Create dataset with appropriate configuration
-dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=shuffle, transform=transform_fn)
+dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=shuffle, transform=[transform_fn])
 ```
 
 Or, you can create a subclass of `StreamingDataset` and override its `transform` method to apply custom transformations to each sample.
@@ -981,9 +981,8 @@ class StreamingDatasetWithTransform(StreamingDataset):
 
 dataset = StreamingDatasetWithTransform(data_dir, cache_dir=str(cache_dir), shuffle=shuffle)
-
-
-```
+```
+
diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 3987ce8ba..7b55172d1 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -29,11 +29,7 @@
 from litdata.streaming.sampler import ChunkedIndex
 from litdata.streaming.serializers import Serializer
 from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
-from litdata.utilities.dataset_utilities import (
-    _should_replace_path,
-    _try_create_cache_dir,
-    subsample_streaming_dataset,
-)
+from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset
 from litdata.utilities.encryption import Encryption
 from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
 from litdata.utilities.format import _convert_bytes_to_int
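Where the series lands: after PATCH 08 through PATCH 10, the library never inspects signatures; every transform is invoked as `fn(item)`, and per-transform parameters are bound ahead of time, as the updated `test_dataset_multiple_transform` does with `functools.partial`. A minimal sketch of that final convention (the function names and the commented `StreamingDataset` call are illustrative):

```python
from functools import partial


def scale(x):
    """Double the input."""
    return x * 2


def shift(x, extra_num):
    """Add a constant that was bound in advance."""
    return x + extra_num


# Bind the parameter up front; the dataset can then call each transform as fn(item),
# e.g. StreamingDataset(data_dir, cache_dir=cache_dir, transform=transforms).
transforms = [scale, partial(shift, extra_num=100)]

item = 3
for fn in transforms:
    item = fn(item)
assert item == 106  # (3 * 2) + 100, matching the test's expectation
```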