From 18327c122975f28b39e111e0207e95042c7e33b2 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 5 Jul 2025 10:10:42 +0530 Subject: [PATCH 01/41] add verbose option in optimize_fn --- src/litdata/processing/data_processor.py | 66 +++++++++++++++--------- src/litdata/processing/functions.py | 5 +- 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index b01989b00..61227ca1b 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -1089,6 +1089,7 @@ def __init__( start_method: Optional[str] = None, storage_options: Dict[str, Any] = {}, keep_data_ordered: bool = True, + verbose: bool = True, ): """Provides an efficient way to process data across multiple machine into chunks to make training faster. @@ -1115,6 +1116,7 @@ def __init__( inside an interactive shell like Ipython. storage_options: Storage options for the cloud provider. keep_data_ordered: Whether to use a shared queue for the workers or not. + verbose: Whether to print the progress of the workers. Defaults to True. """ # spawn doesn't work in IPython start_method = start_method or ("fork" if in_notebook() else "spawn") @@ -1124,7 +1126,8 @@ def __init__( msg += "Tip: Libraries relying on lock can hang with `fork`. To use `spawn` in notebooks, " msg += "move your code to files and import it within the notebook." - print(msg) + if verbose: + print(msg) multiprocessing.set_start_method(start_method, force=True) @@ -1166,9 +1169,13 @@ def __init__( if self.output_dir: # Ensure the output dir is the same across all nodes self.output_dir = broadcast_object("output_dir", self.output_dir, rank=_get_node_rank()) - print(f"Storing the files under {self.output_dir.path if self.output_dir.path else self.output_dir.url}") + if verbose: + print( + f"Storing the files under {self.output_dir.path if self.output_dir.path else self.output_dir.url}" + ) self.random_seed = random_seed + self.verbose = verbose def run(self, data_recipe: DataRecipe) -> None: """Triggers the data recipe processing over your dataset.""" @@ -1179,7 +1186,8 @@ def run(self, data_recipe: DataRecipe) -> None: self._cleanup_checkpoints() t0 = time() - print(f"Setup started with fast_dev_run={self.fast_dev_run}.") + if self.verbose: + print(f"Setup started with fast_dev_run={self.fast_dev_run}.") # Force random seed to be fixed random.seed(self.random_seed) @@ -1231,7 +1239,8 @@ def run(self, data_recipe: DataRecipe) -> None: if isinstance(user_items, list) else "Using a Queue to process items on demand." ) - print(f"Setup finished in {round(time() - t0, 3)} seconds. {msg}") + if self.verbose: + print(f"Setup finished in {round(time() - t0, 3)} seconds. {msg}") if self.use_checkpoint: if isinstance(user_items, multiprocessing.queues.Queue): @@ -1244,49 +1253,56 @@ def run(self, data_recipe: DataRecipe) -> None: # Checkpoint feature is not supported for generators for now. raise ValueError("Checkpoint feature is not supported for generators, yet.") # get the last checkpoint details - print("Resuming from last saved checkpoint...") + if self.verbose: + print("Resuming from last saved checkpoint...") self._load_checkpoint_config(workers_user_items) assert isinstance(self.checkpoint_next_index, list) if all(self.checkpoint_next_index[i] == 0 for i in range(self.num_workers)): # save the current configuration in the checkpoints.json file - print("No checkpoints found. 
Saving current configuration...") + if self.verbose: + print("No checkpoints found. Saving current configuration...") self._save_current_config(workers_user_items) else: # load the last checkpoint details assert isinstance(self.checkpoint_next_index, list) workers_user_items = [w[self.checkpoint_next_index[i] :] for i, w in enumerate(workers_user_items)] - print("Checkpoints loaded successfully.") + if self.verbose: + print("Checkpoints loaded successfully.") if self.fast_dev_run and not isinstance(user_items, multiprocessing.queues.Queue): assert isinstance(workers_user_items, list) items_to_keep = self.fast_dev_run if isinstance(self.fast_dev_run, int) else _DEFAULT_FAST_DEV_RUN_ITEMS workers_user_items = [w[:items_to_keep] for w in workers_user_items] - print(f"Fast dev run is enabled. Limiting to {items_to_keep} items per process.") + if self.verbose: + print(f"Fast dev run is enabled. Limiting to {items_to_keep} items per process.") self._cleanup_cache() num_items = sum([len(items) for items in workers_user_items]) if workers_user_items is not None else -1 - if workers_user_items is not None: - print( - f"Starting {self.num_workers} workers with {num_items} items." - f" The progress bar is only updated when a worker finishes." - ) - else: - print(f"Starting {self.num_workers} workers with a Queue to process items on demand.") + if self.verbose: + if workers_user_items is not None: + print( + f"Starting {self.num_workers} workers with {num_items} items." + f" The progress bar is only updated when a worker finishes." + ) + else: + print(f"Starting {self.num_workers} workers with a Queue to process items on demand.") if self.input_dir is None and self.src_resolver is not None and self.input_dir: self.input_dir = self.src_resolver(self.input_dir) - print(f"The remote_dir is `{self.input_dir}`.") + if self.verbose: + print(f"The remote_dir is `{self.input_dir}`.") signal.signal(signal.SIGINT, self._signal_handler) self._create_process_workers(data_recipe, workers_user_items) - print("Workers are ready ! Starting data processing...") + if self.verbose: + print("Workers are ready ! Starting data processing...") current_total = 0 if _TQDM_AVAILABLE: @@ -1306,7 +1322,8 @@ def run(self, data_recipe: DataRecipe) -> None: total_num_items = len(user_items) if isinstance(user_items, list) else -1 while True: - flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None) + if self.verbose: + flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None) # Exit early if all the workers are done. # This means either there were some kinda of errors, or optimize function was very small. @@ -1315,7 +1332,8 @@ def run(self, data_recipe: DataRecipe) -> None: error = self.error_queue.get(timeout=0.01) self._exit_on_error(error) except Empty: - print("All workers are done. Exiting!") + if self.verbose: + print("All workers are done. 
Exiting!") break try: @@ -1349,13 +1367,15 @@ def run(self, data_recipe: DataRecipe) -> None: with open("status.json", "w") as f: json.dump({"progress": str(100 * current_total * num_nodes / total_num_items) + "%"}, f) - flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None) + if self.verbose: + flush_msg_queue(self.msg_queue, pbar if _TQDM_AVAILABLE else None) if _TQDM_AVAILABLE: pbar.clear() pbar.close() - print("Workers are finished.") + if self.verbose: + print("Workers are finished.") size = len(workers_user_items) if workers_user_items is not None else None result = data_recipe._done(size, self.delete_cached_files, self.output_dir) @@ -1375,8 +1395,8 @@ def run(self, data_recipe: DataRecipe) -> None: num_chunks=result.num_chunks, num_bytes_per_chunk=result.num_bytes_per_chunk, ) - - print("Finished data processing!") + if self.verbose: + print("Finished data processing!") if self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe): # clean up checkpoints self._cleanup_checkpoints() diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 664159b78..93b0fcb48 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -410,6 +410,7 @@ def optimize( optimize_dns: Optional[bool] = None, storage_options: Dict[str, Any] = {}, keep_data_ordered: bool = True, + verbose: bool = True, ) -> None: """This function converts a dataset into chunks, possibly in a distributed way. @@ -453,6 +454,7 @@ def optimize( workload and reduce idle time when some workers finish early. This may lead to unordered processing of items. If True, each worker processes a statically assigned subset of items in order. + verbose: Whether to print the progress of the optimization. Defaults to True. """ _check_version_and_prompt_upgrade(__version__) @@ -491,7 +493,7 @@ def optimize( "Only https://lightning.ai/ supports multiple nodes or selecting a machine.Create an account to try it out." ) - if not _IS_IN_STUDIO: + if not _IS_IN_STUDIO and verbose: print( "Create an account on https://lightning.ai/ to optimize your data faster " "using multiple nodes and large machines." 
@@ -563,6 +565,7 @@ def optimize( start_method=start_method, storage_options=storage_options, keep_data_ordered=keep_data_ordered, + verbose=verbose, ) with optimize_dns_context(optimize_dns if optimize_dns is not None else False): From 202b88c788926a5eef500541e89d7c7243f296f6 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 5 Jul 2025 13:31:27 +0530 Subject: [PATCH 02/41] optimize yolo dataset --- src/litdata/constants.py | 1 + src/litdata/streaming/dataset.py | 1 + src/litdata/streaming/serializers.py | 1 + src/litdata/support/__init__.py | 12 ++ src/litdata/support/ultralytics/__init__.py | 12 ++ src/litdata/support/ultralytics/optimize.py | 150 ++++++++++++++++++++ 6 files changed, 177 insertions(+) create mode 100644 src/litdata/support/__init__.py create mode 100644 src/litdata/support/ultralytics/__init__.py create mode 100644 src/litdata/support/ultralytics/optimize.py diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 63322c0fd..8b4888b84 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -45,6 +45,7 @@ _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") _AV_AVAILABLE = RequirementCache("av") +_ULTRALYTICS_AVAILABLE = RequirementCache("ultralytics") _DEBUG = bool(int(os.getenv("DEBUG_LITDATA", "0"))) _PRINT_DEBUG_LOGS = bool(int(os.getenv("PRINT_DEBUG_LOGS", "0"))) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 2522437be..ad16b2a78 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -418,6 +418,7 @@ def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) self.worker_next_chunk_index += 1 def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: + print(f"get item: {index=}") if self.cache is None: self.worker_env = _WorkerEnv.detect() self.cache = self._create_cache(worker_env=self.worker_env) diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index abf2c2596..3d545839a 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -364,6 +364,7 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: return pickle.dumps(item), None def deserialize(self, data: bytes) -> Any: + print(f"pickle deserialize: {data=}") return pickle.loads(data) # noqa: S301 def can_serialize(self, _: Any) -> bool: diff --git a/src/litdata/support/__init__.py b/src/litdata/support/__init__.py new file mode 100644 index 000000000..27efc0815 --- /dev/null +++ b/src/litdata/support/__init__.py @@ -0,0 +1,12 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/litdata/support/ultralytics/__init__.py b/src/litdata/support/ultralytics/__init__.py new file mode 100644 index 000000000..27efc0815 --- /dev/null +++ b/src/litdata/support/ultralytics/__init__.py @@ -0,0 +1,12 @@ +# Copyright The Lightning AI team. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py new file mode 100644 index 000000000..99214e7f9 --- /dev/null +++ b/src/litdata/support/ultralytics/optimize.py @@ -0,0 +1,150 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path +from typing import Optional, Union + +import yaml + +from litdata.constants import _PIL_AVAILABLE, _ULTRALYTICS_AVAILABLE +from litdata.processing.functions import optimize +from litdata.streaming.resolver import Dir, _resolve_dir + + +def _ultralytics_optimize_fn(img_path: str): + """Internal function that will be passed to the `optimize` function.""" + # from PIL import Image + + # img = Image.open(img_path) + # if not img_path.endswith((".jpg", ".jpeg", ".png")): + # raise ValueError(f"Unsupported image format: {img_path}. Supported formats are .jpg, .jpeg, and .png.") + + img_ext = os.path.splitext(img_path)[-1].lower() # get the file extension + + label = "" + label_path = img_path.replace("images", "labels").replace(img_ext, ".txt") + + # read label file if it exists, else raise an error + if os.path.isfile(label_path): + with open(label_path) as f: + for line in f: + label += line.strip() + "\n" + # line = line.strip().split(" ") + # line_data = [int(line[0])] + [float(x) for x in line[1:]] + # label.append(line_data) + else: + raise FileNotFoundError(f"Label file not found: {label_path}") + + return { + "image": img_path, + "label": label, + } + + +def optimize_ultralytics_dataset( + yaml_path: str, + output_dir: str, + chunk_size: Optional[int] = None, + chunk_bytes: Optional[Union[int, str]] = None, + num_workers: int = 1, + verbose: bool = False, +) -> None: + """Optimize an Ultralytics dataset by converting it into chunks and resizing images. + + Args: + yaml_path: Path to the dataset YAML file. + output_dir: Directory where the optimized dataset will be saved. + chunk_size: Number of samples per chunk. If None, no chunking is applied. + chunk_bytes: Maximum size of each chunk in bytes. If None, no size limit is applied. + num_workers: Number of worker processes to use for optimization. Defaults to 1. + verbose: Whether to print progress messages. Defaults to False. + """ + if not _ULTRALYTICS_AVAILABLE: + raise ImportError( + "Ultralytics is not installed. Please install it with `pip install ultralytics` to use this function." 
+ ) + if not _PIL_AVAILABLE: + raise ImportError("PIL is not installed. Please install it with `pip install pillow` to use this function.") + + # check if the YAML file exists and is a file + if not os.path.isfile(yaml_path): + raise FileNotFoundError(f"YAML file not found: {yaml_path}") + + if chunk_bytes is None and chunk_size is None: + raise ValueError("Either chunk_bytes or chunk_size must be specified.") + + if chunk_bytes is not None and chunk_size is not None: + raise ValueError("Only one of chunk_bytes or chunk_size should be specified, not both.") + + from ultralytics.data.utils import check_det_dataset + + # parse the YAML file & make sure data exists, else download it + dataset_config = check_det_dataset(yaml_path) + print(f"checked dataset structure: {dataset_config=}") + + output_dir = _resolve_dir(output_dir) + + mode_to_dir = {} + + for mode in ("train", "val", "test"): + if dataset_config[mode] is None: + continue + if not os.path.exists(dataset_config[mode]): + raise FileNotFoundError(f"Dataset directory not found for {mode}: {dataset_config[mode]}") + mode_output_dir = get_output_dir(output_dir, mode) + inputs = list_all_files(dataset_config[mode]) + + optimize( + fn=_ultralytics_optimize_fn, + inputs=inputs, + output_dir=mode_output_dir, + chunk_bytes=chunk_bytes, + chunk_size=chunk_size, + num_workers=num_workers, + mode="overwrite", + verbose=verbose, + ) + + mode_to_dir[mode] = mode_output_dir + print(f"Optimized {mode} dataset and saved to {mode_output_dir} ✅") + + # update the YAML file with the new paths + for mode, dir in mode_to_dir.items(): + if mode in dataset_config: + dataset_config[mode] = dir.url if dir.url else dir.path + else: + raise ValueError(f"Mode '{mode}' not found in dataset configuration.") + + # save the updated YAML file + with open("litdata_" + yaml_path, "w") as f: + yaml.dump(dataset_config, f) + + +def get_output_dir(output_dir: Dir, mode: str) -> Dir: + if not isinstance(output_dir, Dir): + raise TypeError(f"Expected output_dir to be of type Dir, got {type(output_dir)} instead.") + print(f"Using output_dir: {output_dir}") + url, path = output_dir.url, output_dir.path + if url is not None: + url = url.rstrip("/") + f"/{mode}" + if path is not None: + path = os.path.join(path, f"{mode}") + + updated_output_dir = Dir(url=url, path=path) + print(f"Updated output_dir for mode '{mode}': {updated_output_dir}") + return updated_output_dir + + +def list_all_files(path: str) -> list[str]: + return [str(p) for p in Path(path).rglob("*") if p.is_file()] From fd01f6102c74140a7aab24d36b0454198a8d4eb8 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 5 Jul 2025 13:34:54 +0530 Subject: [PATCH 03/41] update --- src/litdata/support/ultralytics/optimize.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py index 99214e7f9..a2b6bb630 100644 --- a/src/litdata/support/ultralytics/optimize.py +++ b/src/litdata/support/ultralytics/optimize.py @@ -24,11 +24,11 @@ def _ultralytics_optimize_fn(img_path: str): """Internal function that will be passed to the `optimize` function.""" - # from PIL import Image + from PIL import Image - # img = Image.open(img_path) - # if not img_path.endswith((".jpg", ".jpeg", ".png")): - # raise ValueError(f"Unsupported image format: {img_path}. 
Supported formats are .jpg, .jpeg, and .png.") + img = Image.open(img_path) + if not img_path.endswith((".jpg", ".jpeg", ".png")): + raise ValueError(f"Unsupported image format: {img_path}. Supported formats are .jpg, .jpeg, and .png.") img_ext = os.path.splitext(img_path)[-1].lower() # get the file extension @@ -47,7 +47,7 @@ def _ultralytics_optimize_fn(img_path: str): raise FileNotFoundError(f"Label file not found: {label_path}") return { - "image": img_path, + "image": img, "label": label, } From 6f9d631dc0139a2907d92216c0987017317d0d70 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 5 Jul 2025 16:43:31 +0530 Subject: [PATCH 04/41] update --- src/litdata/streaming/dataset.py | 1 - src/litdata/support/ultralytics/optimize.py | 5 +++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index ad16b2a78..2522437be 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -418,7 +418,6 @@ def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) self.worker_next_chunk_index += 1 def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: - print(f"get item: {index=}") if self.cache is None: self.worker_env = _WorkerEnv.detect() self.cache = self._create_cache(worker_env=self.worker_env) diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py index a2b6bb630..61c227cc2 100644 --- a/src/litdata/support/ultralytics/optimize.py +++ b/src/litdata/support/ultralytics/optimize.py @@ -126,6 +126,11 @@ def optimize_ultralytics_dataset( else: raise ValueError(f"Mode '{mode}' not found in dataset configuration.") + # convert path to string if it's a Path object + for key, value in dataset_config.items(): + if isinstance(value, Path): + dataset_config[key] = str(value) + # save the updated YAML file with open("litdata_" + yaml_path, "w") as f: yaml.dump(dataset_config, f) From 1560225913ffd7903c434e57f30d956665753c79 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 5 Jul 2025 16:54:40 +0530 Subject: [PATCH 05/41] update --- src/litdata/streaming/serializers.py | 1 - src/litdata/support/ultralytics/optimize.py | 7 ++----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 3d545839a..abf2c2596 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -364,7 +364,6 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: return pickle.dumps(item), None def deserialize(self, data: bytes) -> Any: - print(f"pickle deserialize: {data=}") return pickle.loads(data) # noqa: S301 def can_serialize(self, _: Any) -> bool: diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py index 61c227cc2..59bfbabf3 100644 --- a/src/litdata/support/ultralytics/optimize.py +++ b/src/litdata/support/ultralytics/optimize.py @@ -38,11 +38,8 @@ def _ultralytics_optimize_fn(img_path: str): # read label file if it exists, else raise an error if os.path.isfile(label_path): with open(label_path) as f: - for line in f: - label += line.strip() + "\n" - # line = line.strip().split(" ") - # line_data = [int(line[0])] + [float(x) for x in line[1:]] - # label.append(line_data) + # don't convert to lists, as labels might've different lengths and hence config won't be same for all images + label = f.read().strip() # read the entire file content as a single string 
else: raise FileNotFoundError(f"Label file not found: {label_path}") From 2ab318439e71f883cd0194fb2867f19a05e97efb Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sun, 6 Jul 2025 10:06:26 +0530 Subject: [PATCH 06/41] patching works. verified for check_det_dataset of ultralytics --- src/litdata/support/ultralytics/__init__.py | 3 ++ src/litdata/support/ultralytics/patch.py | 44 +++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 src/litdata/support/ultralytics/patch.py diff --git a/src/litdata/support/ultralytics/__init__.py b/src/litdata/support/ultralytics/__init__.py index 27efc0815..59b7e02f0 100644 --- a/src/litdata/support/ultralytics/__init__.py +++ b/src/litdata/support/ultralytics/__init__.py @@ -10,3 +10,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from litdata.support.ultralytics.patch import patch_ultralytics + +__all__ = ["patch_ultralytics"] diff --git a/src/litdata/support/ultralytics/patch.py b/src/litdata/support/ultralytics/patch.py new file mode 100644 index 000000000..7ac6204fa --- /dev/null +++ b/src/litdata/support/ultralytics/patch.py @@ -0,0 +1,44 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Dict + +from litdata.constants import _ULTRALYTICS_AVAILABLE + + +def patch_ultralytics(): + """Patch Ultralytics to use the LitData optimize function.""" + if not _ULTRALYTICS_AVAILABLE: + raise ImportError("Ultralytics is not available. 
Please install it to use this functionality.") + + import sys + + if "ultralytics" in sys.modules: + raise RuntimeError("patch_ultralytics() must be called before importing 'ultralytics'") + + from ultralytics.data.utils import check_det_dataset + + check_det_dataset.__code__ = patch_check_det_dataset.__code__ + + +def patch_check_det_dataset(dataset: str, _: bool = True) -> Dict: + if not (isinstance(dataset, str) and dataset.endswith(".yaml") and os.path.isfile(dataset)): + raise ValueError("Dataset must be a string ending with '.yaml' and point to a valid file.") + + import yaml + + # read the yaml file + with open(dataset) as file: + data = yaml.safe_load(file) + print(f"patch successful for {dataset}") + return data From 678f048e43c57ea70a2a231dade459a6abd84ab9 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 7 Jul 2025 08:40:52 +0530 Subject: [PATCH 07/41] ready to patch ultralytics now --- src/litdata/support/ultralytics/optimize.py | 6 +++--- src/litdata/support/ultralytics/patch.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py index 59bfbabf3..51663051f 100644 --- a/src/litdata/support/ultralytics/optimize.py +++ b/src/litdata/support/ultralytics/optimize.py @@ -13,7 +13,7 @@ import os from pathlib import Path -from typing import Optional, Union +from typing import Dict, Optional, Union import yaml @@ -22,7 +22,7 @@ from litdata.streaming.resolver import Dir, _resolve_dir -def _ultralytics_optimize_fn(img_path: str): +def _ultralytics_optimize_fn(img_path: str) -> Dict: """Internal function that will be passed to the `optimize` function.""" from PIL import Image @@ -105,7 +105,7 @@ def optimize_ultralytics_dataset( optimize( fn=_ultralytics_optimize_fn, inputs=inputs, - output_dir=mode_output_dir, + output_dir=mode_output_dir.url or mode_output_dir.path or "optimized_data", chunk_bytes=chunk_bytes, chunk_size=chunk_size, num_workers=num_workers, diff --git a/src/litdata/support/ultralytics/patch.py b/src/litdata/support/ultralytics/patch.py index 7ac6204fa..ec97702ba 100644 --- a/src/litdata/support/ultralytics/patch.py +++ b/src/litdata/support/ultralytics/patch.py @@ -16,7 +16,7 @@ from litdata.constants import _ULTRALYTICS_AVAILABLE -def patch_ultralytics(): +def patch_ultralytics() -> None: """Patch Ultralytics to use the LitData optimize function.""" if not _ULTRALYTICS_AVAILABLE: raise ImportError("Ultralytics is not available. Please install it to use this functionality.") From da75d4268706a1ea9d00e86fa1a909ca3b5393fe Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 7 Jul 2025 16:41:04 +0530 Subject: [PATCH 08/41] getting closer --- src/litdata/streaming/dataset.py | 30 ++- src/litdata/support/ultralytics/optimize.py | 7 +- src/litdata/support/ultralytics/patch.py | 279 +++++++++++++++++++- 3 files changed, 301 insertions(+), 15 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 2522437be..c6be00b1a 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -10,7 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import inspect import logging import os from time import time @@ -62,7 +62,7 @@ def __init__( max_pre_download: int = 2, index_path: Optional[str] = None, force_override_state_dict: bool = False, - transform: Optional[Callable] = None, + transform: Optional[Union[Callable, List[Callable]]] = None, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -89,7 +89,7 @@ def __init__( If `index_path` is a directory, the function will look for `index.json` within it. If `index_path` is a full file path, it will use that directly. force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict. - transform: Optional transformation function to apply to each item in the dataset. + transform: Optional transformation function or list of functions to apply to each item in the dataset. """ _check_version_and_prompt_upgrade(__version__) @@ -198,8 +198,10 @@ def __init__( self.session_options = session_options self.max_pre_download = max_pre_download if transform is not None: - if not callable(transform): - raise ValueError(f"Transform should be a callable. Found {transform}") + transform = transform if isinstance(transform, list) else [transform] + for t in transform: + if not callable(t): + raise ValueError(f"Transform should be a callable. Found {t}") self.transform = transform self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache @@ -441,7 +443,23 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: {"name": f"getitem_dataset_for_chunk_index_{index.chunk_index}_and_index_{index.index}", "ph": "E"} ) ) - return self.transform(item) if hasattr(self, "transform") else item + if hasattr(self, "transform"): + if isinstance(self.transform, list): + for transform_fn in self.transform: + sig = inspect.signature(transform_fn) + if any( + p.kind in (p.KEYWORD_ONLY, p.VAR_KEYWORD) or p.name == "index" for p in sig.parameters.values() + ): + item = transform_fn(item, index=index.index) + else: + item = transform_fn(item) + else: + sig = inspect.signature(self.transform) + if any(p.kind in (p.KEYWORD_ONLY, p.VAR_KEYWORD) or p.name == "index" for p in sig.parameters.values()): + item = self.transform(item, index=index.index) + else: + item = self.transform(item) + return item def __next__(self) -> Any: # check if we have reached the end of the dataset (i.e., all the chunks have been processed) diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py index 51663051f..04e3f847f 100644 --- a/src/litdata/support/ultralytics/optimize.py +++ b/src/litdata/support/ultralytics/optimize.py @@ -41,6 +41,7 @@ def _ultralytics_optimize_fn(img_path: str) -> Dict: # don't convert to lists, as labels might've different lengths and hence config won't be same for all images label = f.read().strip() # read the entire file content as a single string else: + return None raise FileNotFoundError(f"Label file not found: {label_path}") return { @@ -88,7 +89,6 @@ def optimize_ultralytics_dataset( # parse the YAML file & make sure data exists, else download it dataset_config = check_det_dataset(yaml_path) - print(f"checked dataset structure: {dataset_config=}") output_dir = _resolve_dir(output_dir) @@ -136,16 +136,13 @@ def optimize_ultralytics_dataset( def get_output_dir(output_dir: Dir, mode: str) -> Dir: if not isinstance(output_dir, Dir): raise TypeError(f"Expected output_dir to be of type Dir, got 
{type(output_dir)} instead.") - print(f"Using output_dir: {output_dir}") url, path = output_dir.url, output_dir.path if url is not None: url = url.rstrip("/") + f"/{mode}" if path is not None: path = os.path.join(path, f"{mode}") - updated_output_dir = Dir(url=url, path=path) - print(f"Updated output_dir for mode '{mode}': {updated_output_dir}") - return updated_output_dir + return Dir(url=url, path=path) def list_all_files(path: str) -> list[str]: diff --git a/src/litdata/support/ultralytics/patch.py b/src/litdata/support/ultralytics/patch.py index ec97702ba..8503482d0 100644 --- a/src/litdata/support/ultralytics/patch.py +++ b/src/litdata/support/ultralytics/patch.py @@ -11,9 +11,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Dict +from functools import partial +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch from litdata.constants import _ULTRALYTICS_AVAILABLE +from litdata.streaming.dataloader import StreamingDataLoader +from litdata.streaming.dataset import StreamingDataset def patch_ultralytics() -> None: @@ -21,15 +27,37 @@ def patch_ultralytics() -> None: if not _ULTRALYTICS_AVAILABLE: raise ImportError("Ultralytics is not available. Please install it to use this functionality.") - import sys + # import sys - if "ultralytics" in sys.modules: - raise RuntimeError("patch_ultralytics() must be called before importing 'ultralytics'") + # if "ultralytics" in sys.modules: + # raise RuntimeError("patch_ultralytics() must be called before importing 'ultralytics'") + # Patch detection dataset loading from ultralytics.data.utils import check_det_dataset check_det_dataset.__code__ = patch_check_det_dataset.__code__ + # Patch training visualizer (optional, but useful) + from ultralytics.models.yolo.detect.train import DetectionTrainer + + DetectionTrainer.plot_training_samples = patch_detection_plot_training_samples + DetectionTrainer.plot_training_labels = patch_none_function + + # Patch BaseDataset globally + import ultralytics.data.base as base_module + import ultralytics.data.dataset as child_modules + + base_module.BaseDataset = PatchedUltralyticsBaseDataset + base_module.BaseDataset.set_rectangle = patch_none_function + child_modules.YOLODataset.__bases__ = (PatchedUltralyticsBaseDataset,) + child_modules.YOLODataset.get_labels = patch_get_labels + + from ultralytics.data.build import build_dataloader + + build_dataloader.__code__ = patch_build_dataloader.__code__ + + print("✅ Ultralytics successfully patched to use LitData.") + def patch_check_det_dataset(dataset: str, _: bool = True) -> Dict: if not (isinstance(dataset, str) and dataset.endswith(".yaml") and os.path.isfile(dataset)): @@ -42,3 +70,246 @@ def patch_check_det_dataset(dataset: str, _: bool = True) -> Dict: data = yaml.safe_load(file) print(f"patch successful for {dataset}") return data + + +def patch_build_dataloader( + dataset: Any, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False +) -> StreamingDataLoader: + """Create and return an InfiniteDataLoader or DataLoader for training or validation. + + Args: + dataset (Dataset): Dataset to load data from. + batch (int): Batch size for the dataloader. + workers (int): Number of worker threads for loading data. + shuffle (bool, optional): Whether to shuffle the dataset. + rank (int, optional): Process rank in distributed training. -1 for single-GPU training. 
+ drop_last (bool, optional): Whether to drop the last incomplete batch. + + Returns: + (StreamingDataLoader): A dataloader that can be used for training or validation. + + Examples: + Create a dataloader for training + >>> dataset = YOLODataset(...) + >>> dataloader = build_dataloader(dataset, batch=16, workers=4, shuffle=True) + """ + from litdata.streaming.dataloader import StreamingDataLoader + + print("litdata is rocking⚡️") + if not hasattr(dataset, "streaming_dataset"): + raise ValueError("The dataset must have a 'streaming_dataset' attribute.") + + from ultralytics.data.utils import PIN_MEMORY + + batch = min(batch, len(dataset)) + num_devices = torch.cuda.device_count() # number of CUDA devices + num_workers = min(os.cpu_count() // max(num_devices, 1), workers) # number of workers + return StreamingDataLoader( + dataset=dataset.streaming_dataset, + batch_size=batch, + num_workers=num_workers, + pin_memory=PIN_MEMORY, + collate_fn=getattr(dataset, "collate_fn", None), + drop_last=drop_last, + ) + + +class TransformedStreamingDataset(StreamingDataset): + def transform(self, x, *args, **kwargs): + """Apply transformations to the data. + + Args: + x: Data to transform. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + Transformed data. + """ + ... + + +if _ULTRALYTICS_AVAILABLE: + from ultralytics.data.base import BaseDataset as UltralyticsBaseDataset + from ultralytics.utils.plotting import plot_images + + class PatchedUltralyticsBaseDataset(UltralyticsBaseDataset): + def __init__(self, img_path: str, classes: Optional[List[int]] = None, *args, **kwargs): + print("patched ultralytics dataset: 🔥") + self.litdata_dataset = img_path + self.classes = classes + super().__init__(img_path, classes=classes, *args, **kwargs) + self.streaming_dataset = TransformedStreamingDataset( + img_path, + transform=[ + ultralytics_detection_transform, + partial(self.transform_update_label, classes), + ], + ) + self.ni = len(self.streaming_dataset) + self.buffer = list(range(len(self.streaming_dataset))) + + def __len__(self): + """Return the length of the dataset.""" + return len(self.streaming_dataset) + + def get_image_and_label(self, index): + # Your custom logic to load from .litdata + # e.g. use `self.litdata_dataset[index]` + raise NotImplementedError("Custom logic here") + + def get_img_files(self, img_path: Union[str, List[str]]) -> List[str]: + """Let this method return an empty list to avoid errors.""" + return [] + + def get_labels(self) -> List[Dict[str, Any]]: + # this is used to get number of images (ni) in the BaseDataset class + return [] + + def cache_images(self) -> None: + pass + + def cache_images_to_disk(self, i: int) -> None: + pass + + def check_cache_disk(self, safety_margin: float = 0.5) -> bool: + """Check if the cache disk is available.""" + # This method is not used in the streaming dataset, so we can return True + return True + + def check_cache_ram(self, safety_margin: float = 0.5) -> bool: + """Check if the cache RAM is available.""" + # This method is not used in the streaming dataset, so we can return True + return True + + def update_labels(self, *args, **kwargs): + """Do nothing, we will update labels when item is fetched in transform.""" + pass + + def transform_update_label(self, include_class: Optional[List[int]], label: Dict, *args, **kwargs) -> Dict: + """Update labels to include only specified classes. + + Args: + self: PatchedUltralyticsBaseDataset instance. 
+ include_class (List[int], optional): List of classes to include. If None, all classes are included. + label (Dict): Label to update. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). + """ + include_class_array = np.array(include_class).reshape(1, -1) + if include_class is not None: + cls = label["cls"] + bboxes = label["bboxes"] + segments = label["segments"] + keypoints = label["keypoints"] + j = (cls == include_class_array).any(1) + label["cls"] = cls[j] + label["bboxes"] = bboxes[j] + if segments: + label["segments"] = [segments[si] for si, idx in enumerate(j) if idx] + if keypoints is not None: + label["keypoints"] = keypoints[j] + if self.single_cls: + label["cls"][:, 0] = 0 + + return label + + def __getitem__(self, index: int) -> Dict[str, Any]: + """Return transformed label information for given index.""" + # return self.transforms(self.get_image_and_label(index)) + if not hasattr(self, "streaming_dataset"): + raise ValueError("The dataset must have a 'streaming_dataset' attribute.") + data = self.streaming_dataset[index] + + label = data["label"] + # split label on the basis of `\n` and then split each line on the basis of ` ` + # first element is class, rest are bbox coordinates + if isinstance(label, str): + label = label.split("\n") + label = [line.split(" ") for line in label if line.strip()] + + data = { + "batch_idx": torch.Tensor([index], dtype=torch.int32), # ← add this! + "img": data["image"], + "cls": torch.Tensor([int(line[0]) for line in label]), + "bboxes": torch.Tensor([[float(coord) for coord in line[1:]] for line in label]), + "normalized": True, + "segments": [], + "keypoints": None, + "bbox_format": "xywh", + } + else: + raise ValueError("Label must be a string in YOLO format.") + + data = self.transform_update_label( + include_class=self.classes, + label=data, + ) + print("-" * 100) + return self.transforms(data) + + def patch_detection_plot_training_samples(self, batch: Dict[str, Any], ni: int) -> None: + """Plot training samples with their annotations. + + Args: + self: DetectionTrainer instance. + batch (Dict[str, Any]): Dictionary containing batch data. + ni (int): Number of iterations. + """ + plot_images( + labels=batch, + images=batch["img"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) + + def patch_detection_plot_training_labels(self) -> None: + """Create a labeled training plot of the YOLO model.""" + pass + + def patch_get_labels(self) -> List[Dict[str, Any]]: + # this is used to get number of images (ni) in the BaseDataset class + return [] + + def patch_none_function(*args, **kwargs): + """A placeholder function that does nothing.""" + pass + +# ------- helper transformations ------- + + +def ultralytics_detection_transform(data: Dict[str, Any], index: int) -> Dict[str, Any]: + """Transform function for YOLO detection datasets. + + Args: + data (Dict[str, Any]): Input data containing image and label. + index (int): Index of the data item. + + Returns: + Dict[str, Any]: Transformed data with image and label. 
+ """ + if index is None: + raise ValueError("Index must be provided for YOLO detection transform.") + label = data["label"] + # split label on the basis of `\n` and then split each line on the basis of ` ` + # first element is class, rest are bbox coordinates + if isinstance(label, str): + label = label.split("\n") + label = [line.split(" ") for line in label if line.strip()] + print(f"label={label}") + + data = { + "batch_idx": torch.Tensor([index]), # ← add this! + "img": data["image"], + "cls": torch.Tensor([int(line[0]) for line in label]), + "bboxes": torch.Tensor([[float(coord) for coord in line[1:]] for line in label]), + "normalized": True, + "segments": [], + "keypoints": None, + "bbox_format": "xywh", + } + print(f"{data=}") + else: + raise ValueError("Label must be a string in YOLO format.") + + return data From 8667e526b19d0aea435b4deb0e3784150739e3dc Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 7 Jul 2025 16:45:14 +0530 Subject: [PATCH 09/41] update --- src/litdata/support/ultralytics/optimize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py index 04e3f847f..2a6e8f02f 100644 --- a/src/litdata/support/ultralytics/optimize.py +++ b/src/litdata/support/ultralytics/optimize.py @@ -22,7 +22,7 @@ from litdata.streaming.resolver import Dir, _resolve_dir -def _ultralytics_optimize_fn(img_path: str) -> Dict: +def _ultralytics_optimize_fn(img_path: str) -> Optional[Dict]: """Internal function that will be passed to the `optimize` function.""" from PIL import Image From b2f5677b982c297e9827eb0ff9b8ba3d950381fe Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 8 Jul 2025 14:34:31 +0530 Subject: [PATCH 10/41] yolo model train end to end --- src/litdata/streaming/dataset.py | 31 +++-- src/litdata/support/ultralytics/optimize.py | 12 +- src/litdata/support/ultralytics/patch.py | 141 ++++++++++++++++++-- 3 files changed, 155 insertions(+), 29 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index c6be00b1a..2a4afd3d2 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -63,6 +63,9 @@ def __init__( index_path: Optional[str] = None, force_override_state_dict: bool = False, transform: Optional[Union[Callable, List[Callable]]] = None, + transform_kwargs: Optional[Dict[str, Any]] = None, + *args: Any, + **kwargs: Any, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -90,6 +93,9 @@ def __init__( If `index_path` is a full file path, it will use that directly. force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict. transform: Optional transformation function or list of functions to apply to each item in the dataset. + transform_kwargs: Keyword arguments for the transformation function. + args: Additional positional arguments. + kwargs: Additional keyword arguments. """ _check_version_and_prompt_upgrade(__version__) @@ -203,6 +209,7 @@ def __init__( if not callable(t): raise ValueError(f"Transform should be a callable. 
Found {t}") self.transform = transform + self.transform_kwargs = transform_kwargs or {} self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache @property @@ -444,21 +451,21 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: ) ) if hasattr(self, "transform"): + self.transform_kwargs["index"] = index.index if isinstance(self.transform, list): for transform_fn in self.transform: - sig = inspect.signature(transform_fn) - if any( - p.kind in (p.KEYWORD_ONLY, p.VAR_KEYWORD) or p.name == "index" for p in sig.parameters.values() - ): - item = transform_fn(item, index=index.index) - else: - item = transform_fn(item) + signature = inspect.signature(transform_fn) + accepts_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values() + ) + item = transform_fn(item, **self.transform_kwargs) if accepts_kwargs else transform_fn(item) else: - sig = inspect.signature(self.transform) - if any(p.kind in (p.KEYWORD_ONLY, p.VAR_KEYWORD) or p.name == "index" for p in sig.parameters.values()): - item = self.transform(item, index=index.index) - else: - item = self.transform(item) + # check if transform function accepts kwargs + signature = inspect.signature(transform_fn) + accepts_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values() + ) + item = self.transform(item, **self.transform_kwargs) if accepts_kwargs else self.transform(item) return item def __next__(self) -> Any: diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py index 2a6e8f02f..4737888fd 100644 --- a/src/litdata/support/ultralytics/optimize.py +++ b/src/litdata/support/ultralytics/optimize.py @@ -24,9 +24,13 @@ def _ultralytics_optimize_fn(img_path: str) -> Optional[Dict]: """Internal function that will be passed to the `optimize` function.""" - from PIL import Image + # from PIL import Image + # from torchvision.io import read_image + import cv2 - img = Image.open(img_path) + # img = Image.open(img_path) + # img = read_image(img_path) + img = cv2.imread(img_path) if not img_path.endswith((".jpg", ".jpeg", ".png")): raise ValueError(f"Unsupported image format: {img_path}. Supported formats are .jpg, .jpeg, and .png.") @@ -45,7 +49,7 @@ def _ultralytics_optimize_fn(img_path: str) -> Optional[Dict]: raise FileNotFoundError(f"Label file not found: {label_path}") return { - "image": img, + "img": img, "label": label, } @@ -127,7 +131,7 @@ def optimize_ultralytics_dataset( for key, value in dataset_config.items(): if isinstance(value, Path): dataset_config[key] = str(value) - + dataset_config[""] # save the updated YAML file with open("litdata_" + yaml_path, "w") as f: yaml.dump(dataset_config, f) diff --git a/src/litdata/support/ultralytics/patch.py b/src/litdata/support/ultralytics/patch.py index 8503482d0..4b8d91eba 100644 --- a/src/litdata/support/ultralytics/patch.py +++ b/src/litdata/support/ultralytics/patch.py @@ -10,9 +10,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import math import os from functools import partial -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -43,6 +44,11 @@ def patch_ultralytics() -> None: DetectionTrainer.plot_training_samples = patch_detection_plot_training_samples DetectionTrainer.plot_training_labels = patch_none_function + from ultralytics.models.yolo.detect.val import DetectionValidator + + DetectionValidator.plot_val_samples = patch_detection_plot_val_samples + DetectionValidator.plot_predictions = patch_detection_plot_predictions + # Patch BaseDataset globally import ultralytics.data.base as base_module import ultralytics.data.dataset as child_modules @@ -56,6 +62,10 @@ def patch_ultralytics() -> None: build_dataloader.__code__ = patch_build_dataloader.__code__ + from ultralytics.data.augment import Compose + + Compose.__call__.__code__ = patch_compose_transform_call.__code__ + print("✅ Ultralytics successfully patched to use LitData.") @@ -65,11 +75,15 @@ def patch_check_det_dataset(dataset: str, _: bool = True) -> Dict: import yaml + if not dataset.startswith("litdata_"): + dataset = "litdata_" + dataset + + if not os.path.isfile(dataset): + raise FileNotFoundError(f"Dataset file not found: {dataset}") + # read the yaml file with open(dataset) as file: - data = yaml.safe_load(file) - print(f"patch successful for {dataset}") - return data + return yaml.safe_load(file) def patch_build_dataloader( @@ -144,7 +158,10 @@ def __init__(self, img_path: str, classes: Optional[List[int]] = None, *args, ** transform=[ ultralytics_detection_transform, partial(self.transform_update_label, classes), + self.update_labels_info, + self.transforms, ], + transform_kwargs={"img_size": self.imgsz, "channels": self.channels}, ) self.ni = len(self.streaming_dataset) self.buffer = list(range(len(self.streaming_dataset))) @@ -245,7 +262,6 @@ def __getitem__(self, index: int) -> Dict[str, Any]: include_class=self.classes, label=data, ) - print("-" * 100) return self.transforms(data) def patch_detection_plot_training_samples(self, batch: Dict[str, Any], ni: int) -> None: @@ -263,6 +279,53 @@ def patch_detection_plot_training_samples(self, batch: Dict[str, Any], ni: int) on_plot=self.on_plot, ) + def patch_detection_plot_val_samples(self: Any, batch: Dict[str, Any], ni: int) -> None: + """Plot validation image samples. + + Args: + self: DetectionValidator instance. + batch (Dict[str, Any]): Batch containing images and annotations. + ni (int): Batch index. + """ + plot_images( + labels=batch, + images=batch["img"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def patch_detection_plot_predictions( + self: Any, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int, max_det: Optional[int] = None + ) -> None: + """Plot predicted bounding boxes on input images and save the result. + + Args: + self: DetectionValidator instance. + batch (Dict[str, Any]): Batch containing images and annotations. + preds (List[Dict[str, torch.Tensor]]): List of predictions from the model. + ni (int): Batch index. + max_det (Optional[int]): Maximum number of detections to plot. 
+ """ + from ultralytics.utils import ops + + # TODO: optimize this + for i, pred in enumerate(preds): + pred["batch_idx"] = torch.ones_like(pred["conf"]) * i # add batch index to predictions + keys = preds[0].keys() + max_det = max_det or self.args.max_det + batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys} + # TODO: fix this + batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4]) # convert to xywh format + plot_images( + images=batch["img"], + labels=batched_preds, + paths=None, + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + def patch_detection_plot_training_labels(self) -> None: """Create a labeled training plot of the YOLO model.""" pass @@ -275,40 +338,92 @@ def patch_none_function(*args, **kwargs): """A placeholder function that does nothing.""" pass + def image_resize( + im: Any, imgsz: int, rect_mode: bool = True, augment: bool = True + ) -> Tuple[Any, Tuple[int, int], Tuple[int, int]]: + """Resize the image to a fixed size. + + Args: + im (Any): Image to resize. + imgsz (int): Target size for resizing. + rect_mode (bool): Whether to use rectangle mode for resizing. + augment (bool): If True, data augmentation is applied. + + Returns: + Tuple[Any, Tuple[int, int], Tuple[int, int]]: Resized image and its original dimensions. + """ + import cv2 + + # Custom logic for resizing the image + h0, w0 = im.shape[:2] # orig hw + if rect_mode: # resize long side to imgsz while maintaining aspect ratio + r = imgsz / max(h0, w0) # ratio + if r != 1: # if sizes are not equal + w, h = (min(math.ceil(w0 * r), imgsz), min(math.ceil(h0 * r), imgsz)) + im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) + elif not (h0 == w0 == imgsz): # resize by stretching image to square imgsz + im = cv2.resize(im, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR) + if im.ndim == 2: + im = im[..., None] + + return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized + + def patch_compose_transform_call(self: Any, data: Any) -> Any: + """Apply all transforms to the data, skipping mix transforms.""" + from ultralytics.data.augment import BaseMixTransform + + for t in self.transforms: + if isinstance(t, BaseMixTransform): + continue # Skip mix transforms, they are applied separately + data = t(data) + return data + # ------- helper transformations ------- -def ultralytics_detection_transform(data: Dict[str, Any], index: int) -> Dict[str, Any]: +def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: """Transform function for YOLO detection datasets. Args: data (Dict[str, Any]): Input data containing image and label. - index (int): Index of the data item. + kwargs (Dict[str, Any]): Additional keyword arguments, including the index of the data item. Returns: Dict[str, Any]: Transformed data with image and label. 
""" + index = kwargs.get("index") + channels = kwargs.get("channels", 3) # default to 3 channels (RGB) if index is None: raise ValueError("Index must be provided for YOLO detection transform.") + label = data["label"] # split label on the basis of `\n` and then split each line on the basis of ` ` # first element is class, rest are bbox coordinates if isinstance(label, str): label = label.split("\n") label = [line.split(" ") for line in label if line.strip()] - print(f"label={label}") + img, ori_shape, resized_shape = image_resize( + data["img"], imgsz=kwargs.get("img_size", 640), rect_mode=True, augment=True + ) + ratio_pad = ( + resized_shape[0] / ori_shape[0], + resized_shape[1] / ori_shape[1], + ) # for evaluation data = { - "batch_idx": torch.Tensor([index]), # ← add this! - "img": data["image"], - "cls": torch.Tensor([int(line[0]) for line in label]), - "bboxes": torch.Tensor([[float(coord) for coord in line[1:]] for line in label]), + "batch_idx": np.array([index]), # ← add this! + "img": img, + "cls": np.array([[int(line[0])] for line in label]), + "bboxes": np.array([[float(coord) for coord in line[1:]] for line in label]), "normalized": True, "segments": [], "keypoints": None, "bbox_format": "xywh", + "ori_shape": ori_shape, + "resized_shape": resized_shape, + "ratio_pad": ratio_pad, + "channels": channels, } - print(f"{data=}") else: raise ValueError("Label must be a string in YOLO format.") From 39c7bf3efe60de514881e798a8276409f0ba5a64 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 9 Jul 2025 10:48:46 +0530 Subject: [PATCH 11/41] update --- src/litdata/streaming/dataset.py | 2 +- src/litdata/support/ultralytics/patch.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 2a4afd3d2..10b5eae09 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -461,7 +461,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: item = transform_fn(item, **self.transform_kwargs) if accepts_kwargs else transform_fn(item) else: # check if transform function accepts kwargs - signature = inspect.signature(transform_fn) + signature = inspect.signature(self.transform) accepts_kwargs = any( param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values() ) diff --git a/src/litdata/support/ultralytics/patch.py b/src/litdata/support/ultralytics/patch.py index 4b8d91eba..cc4e13d1b 100644 --- a/src/litdata/support/ultralytics/patch.py +++ b/src/litdata/support/ultralytics/patch.py @@ -118,10 +118,13 @@ def patch_build_dataloader( batch = min(batch, len(dataset)) num_devices = torch.cuda.device_count() # number of CUDA devices num_workers = min(os.cpu_count() // max(num_devices, 1), workers) # number of workers + num_workers = int(os.getenv("UL_NUM_WORKERS", num_workers)) # get from environment variable if set + persistent_workers = bool(int(os.getenv("UL_PERSISTENT_WORKERS", 0))) return StreamingDataLoader( dataset=dataset.streaming_dataset, batch_size=batch, num_workers=num_workers, + persistent_workers=persistent_workers, pin_memory=PIN_MEMORY, collate_fn=getattr(dataset, "collate_fn", None), drop_last=drop_last, From 3159a07c9b23db0b905c11f41906a1d03e854e61 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 9 Jul 2025 10:57:27 +0530 Subject: [PATCH 12/41] fix mypy errors --- src/litdata/support/ultralytics/patch.py | 56 +++++++++--------------- 1 file changed, 21 insertions(+), 35 deletions(-) diff --git 
a/src/litdata/support/ultralytics/patch.py b/src/litdata/support/ultralytics/patch.py index cc4e13d1b..ca1eaea9f 100644 --- a/src/litdata/support/ultralytics/patch.py +++ b/src/litdata/support/ultralytics/patch.py @@ -117,8 +117,7 @@ def patch_build_dataloader( batch = min(batch, len(dataset)) num_devices = torch.cuda.device_count() # number of CUDA devices - num_workers = min(os.cpu_count() // max(num_devices, 1), workers) # number of workers - num_workers = int(os.getenv("UL_NUM_WORKERS", num_workers)) # get from environment variable if set + num_workers = min((os.cpu_count() or 1) // max(num_devices, 1), workers) # number of workers persistent_workers = bool(int(os.getenv("UL_PERSISTENT_WORKERS", 0))) return StreamingDataLoader( dataset=dataset.streaming_dataset, @@ -131,32 +130,17 @@ def patch_build_dataloader( ) -class TransformedStreamingDataset(StreamingDataset): - def transform(self, x, *args, **kwargs): - """Apply transformations to the data. - - Args: - x: Data to transform. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - Transformed data. - """ - ... - - if _ULTRALYTICS_AVAILABLE: from ultralytics.data.base import BaseDataset as UltralyticsBaseDataset from ultralytics.utils.plotting import plot_images class PatchedUltralyticsBaseDataset(UltralyticsBaseDataset): - def __init__(self, img_path: str, classes: Optional[List[int]] = None, *args, **kwargs): + def __init__(self: Any, img_path: str, classes: Optional[List[int]] = None, *args: Any, **kwargs: Any): print("patched ultralytics dataset: 🔥") self.litdata_dataset = img_path self.classes = classes super().__init__(img_path, classes=classes, *args, **kwargs) - self.streaming_dataset = TransformedStreamingDataset( + self.streaming_dataset = StreamingDataset( img_path, transform=[ ultralytics_detection_transform, @@ -169,44 +153,46 @@ def __init__(self, img_path: str, classes: Optional[List[int]] = None, *args, ** self.ni = len(self.streaming_dataset) self.buffer = list(range(len(self.streaming_dataset))) - def __len__(self): + def __len__(self: Any) -> int: """Return the length of the dataset.""" return len(self.streaming_dataset) - def get_image_and_label(self, index): + def get_image_and_label(self: Any, index: int) -> None: # Your custom logic to load from .litdata # e.g. 
use `self.litdata_dataset[index]` raise NotImplementedError("Custom logic here") - def get_img_files(self, img_path: Union[str, List[str]]) -> List[str]: + def get_img_files(self: Any, img_path: Union[str, List[str]]) -> List[str]: """Let this method return an empty list to avoid errors.""" return [] - def get_labels(self) -> List[Dict[str, Any]]: + def get_labels(self: Any) -> List[Dict[str, Any]]: # this is used to get number of images (ni) in the BaseDataset class return [] - def cache_images(self) -> None: + def cache_images(self: Any) -> None: pass - def cache_images_to_disk(self, i: int) -> None: + def cache_images_to_disk(self: Any, i: int) -> None: pass - def check_cache_disk(self, safety_margin: float = 0.5) -> bool: + def check_cache_disk(self: Any, safety_margin: float = 0.5) -> bool: """Check if the cache disk is available.""" # This method is not used in the streaming dataset, so we can return True return True - def check_cache_ram(self, safety_margin: float = 0.5) -> bool: + def check_cache_ram(self: Any, safety_margin: float = 0.5) -> bool: """Check if the cache RAM is available.""" # This method is not used in the streaming dataset, so we can return True return True - def update_labels(self, *args, **kwargs): + def update_labels(self: Any, *args: Any, **kwargs: Any) -> None: """Do nothing, we will update labels when item is fetched in transform.""" pass - def transform_update_label(self, include_class: Optional[List[int]], label: Dict, *args, **kwargs) -> Dict: + def transform_update_label( + self: Any, include_class: Optional[List[int]], label: Dict, *args: Any, **kwargs: Any + ) -> Dict: """Update labels to include only specified classes. Args: @@ -234,7 +220,7 @@ def transform_update_label(self, include_class: Optional[List[int]], label: Dict return label - def __getitem__(self, index: int) -> Dict[str, Any]: + def __getitem__(self: Any, index: int) -> Dict[str, Any]: """Return transformed label information for given index.""" # return self.transforms(self.get_image_and_label(index)) if not hasattr(self, "streaming_dataset"): @@ -249,7 +235,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: label = [line.split(" ") for line in label if line.strip()] data = { - "batch_idx": torch.Tensor([index], dtype=torch.int32), # ← add this! + "batch_idx": torch.tensor([index], dtype=torch.int32), # ← add this! "img": data["image"], "cls": torch.Tensor([int(line[0]) for line in label]), "bboxes": torch.Tensor([[float(coord) for coord in line[1:]] for line in label]), @@ -267,7 +253,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: ) return self.transforms(data) - def patch_detection_plot_training_samples(self, batch: Dict[str, Any], ni: int) -> None: + def patch_detection_plot_training_samples(self: Any, batch: Dict[str, Any], ni: int) -> None: """Plot training samples with their annotations. 
Args: @@ -329,15 +315,15 @@ def patch_detection_plot_predictions( on_plot=self.on_plot, ) # pred - def patch_detection_plot_training_labels(self) -> None: + def patch_detection_plot_training_labels(self: Any) -> None: """Create a labeled training plot of the YOLO model.""" pass - def patch_get_labels(self) -> List[Dict[str, Any]]: + def patch_get_labels(self: Any) -> List[Dict[str, Any]]: # this is used to get number of images (ni) in the BaseDataset class return [] - def patch_none_function(*args, **kwargs): + def patch_none_function(*args: Any, **kwargs: Any) -> None: """A placeholder function that does nothing.""" pass From 08b0683ce43574e859b2efe58298074f27eb48d6 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 9 Jul 2025 10:59:29 +0530 Subject: [PATCH 13/41] update --- src/litdata/support/ultralytics/patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/support/ultralytics/patch.py b/src/litdata/support/ultralytics/patch.py index ca1eaea9f..49167e41e 100644 --- a/src/litdata/support/ultralytics/patch.py +++ b/src/litdata/support/ultralytics/patch.py @@ -235,7 +235,7 @@ def __getitem__(self: Any, index: int) -> Dict[str, Any]: label = [line.split(" ") for line in label if line.strip()] data = { - "batch_idx": torch.tensor([index], dtype=torch.int32), # ← add this! + "batch_idx": torch.tensor([index], dtype=torch.int32), "img": data["image"], "cls": torch.Tensor([int(line[0]) for line in label]), "bboxes": torch.Tensor([[float(coord) for coord in line[1:]] for line in label]), From ee8d1795d54f4e8d4181e4bec70204aacd87e012 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 9 Jul 2025 19:31:39 +0530 Subject: [PATCH 14/41] update --- src/litdata/support/ultralytics/optimize.py | 40 ++++++++++++--------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py index 4737888fd..373d7d6be 100644 --- a/src/litdata/support/ultralytics/optimize.py +++ b/src/litdata/support/ultralytics/optimize.py @@ -23,30 +23,38 @@ def _ultralytics_optimize_fn(img_path: str) -> Optional[Dict]: - """Internal function that will be passed to the `optimize` function.""" - # from PIL import Image - # from torchvision.io import read_image - import cv2 - - # img = Image.open(img_path) - # img = read_image(img_path) - img = cv2.imread(img_path) + """Optimized function for Ultralytics that reads image + label and optionally re-encodes to reduce size.""" if not img_path.endswith((".jpg", ".jpeg", ".png")): raise ValueError(f"Unsupported image format: {img_path}. 
Supported formats are .jpg, .jpeg, and .png.") - img_ext = os.path.splitext(img_path)[-1].lower() # get the file extension + import cv2 + + img_ext = os.path.splitext(img_path)[-1].lower() + + # Read image using OpenCV + img = cv2.imread(img_path, cv2.IMREAD_COLOR) + if img is None: + raise ValueError(f"Failed to read image: {img_path}") + # JPEG re-encode if image is jpeg or png + if img_ext in [".jpg", ".jpeg", ".png"]: + # Reduce quality to 90% + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 90] + success, encoded = cv2.imencode(".jpg", img, encode_param) + if not success: + raise ValueError(f"JPEG encoding failed for: {img_path}") + + # Decode it back to a numpy array (OpenCV default format) + img = cv2.imdecode(encoded, cv2.IMREAD_COLOR) + + # Load the label label = "" label_path = img_path.replace("images", "labels").replace(img_ext, ".txt") - - # read label file if it exists, else raise an error if os.path.isfile(label_path): with open(label_path) as f: - # don't convert to lists, as labels might've different lengths and hence config won't be same for all images - label = f.read().strip() # read the entire file content as a single string + label = f.read().strip() else: - return None - raise FileNotFoundError(f"Label file not found: {label_path}") + return None # skip this sample return { "img": img, @@ -131,7 +139,7 @@ def optimize_ultralytics_dataset( for key, value in dataset_config.items(): if isinstance(value, Path): dataset_config[key] = str(value) - dataset_config[""] + # save the updated YAML file with open("litdata_" + yaml_path, "w") as f: yaml.dump(dataset_config, f) From 90d828ec136d3200d02fb79a816a5b548119fc70 Mon Sep 17 00:00:00 2001 From: Deependu Date: Thu, 10 Jul 2025 09:11:14 +0000 Subject: [PATCH 15/41] despacito --- src/litdata/support/ultralytics/optimize.py | 17 ++++- src/litdata/support/ultralytics/patch.py | 71 +++++++++++++++++++-- 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/support/ultralytics/optimize.py index 373d7d6be..97932f04c 100644 --- a/src/litdata/support/ultralytics/optimize.py +++ b/src/litdata/support/ultralytics/optimize.py @@ -157,5 +157,18 @@ def get_output_dir(output_dir: Dir, mode: str) -> Dir: return Dir(url=url, path=path) -def list_all_files(path: str) -> list[str]: - return [str(p) for p in Path(path).rglob("*") if p.is_file()] +def list_all_files(_path: str) -> list[str]: + path = Path(_path) + + if path.is_dir(): + # Recursively list all files under the directory + return [str(p) for p in path.rglob("*") if p.is_file()] + + if path.is_file() and path.suffix == ".txt": + # Read lines and return cleaned-up paths + base_dir = path.parent # use the parent of the txt file to resolve relative paths + with open(path) as f: + return [str((base_dir / line.strip()).resolve()) for line in f if line.strip()] + + else: + raise ValueError(f"Unsupported path: {path}") diff --git a/src/litdata/support/ultralytics/patch.py b/src/litdata/support/ultralytics/patch.py index 49167e41e..c041b6550 100644 --- a/src/litdata/support/ultralytics/patch.py +++ b/src/litdata/support/ultralytics/patch.py @@ -148,7 +148,15 @@ def __init__(self: Any, img_path: str, classes: Optional[List[int]] = None, *arg self.update_labels_info, self.transforms, ], - transform_kwargs={"img_size": self.imgsz, "channels": self.channels}, + transform_kwargs={ + "img_size": self.imgsz, + "channels": self.channels, + "segment": self.use_segments, + "use_keypoints": self.use_keypoints, + "use_obb": 
self.use_obb, + "lit_args": self.data, + "single_cls": self.single_cls, + }, ) self.ni = len(self.streaming_dataset) self.buffer = list(range(len(self.streaming_dataset))) @@ -389,8 +397,6 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict # split label on the basis of `\n` and then split each line on the basis of ` ` # first element is class, rest are bbox coordinates if isinstance(label, str): - label = label.split("\n") - label = [line.split(" ") for line in label if line.strip()] img, ori_shape, resized_shape = image_resize( data["img"], imgsz=kwargs.get("img_size", 640), rect_mode=True, augment=True ) @@ -398,15 +404,16 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict resized_shape[0] / ori_shape[0], resized_shape[1] / ori_shape[1], ) # for evaluation + lb, segments, keypoint = parse_labels(label, **kwargs) data = { "batch_idx": np.array([index]), # ← add this! "img": img, - "cls": np.array([[int(line[0])] for line in label]), - "bboxes": np.array([[float(coord) for coord in line[1:]] for line in label]), + "cls": lb[:, 0:1], # n, 1 + "bboxes": lb[:, 1:], # n, 4 + "segments": segments, + "keypoints": keypoint, "normalized": True, - "segments": [], - "keypoints": None, "bbox_format": "xywh", "ori_shape": ori_shape, "resized_shape": resized_shape, @@ -417,3 +424,53 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict raise ValueError("Label must be a string in YOLO format.") return data + + +def parse_labels(labels: str, **kwargs: Any): + from ultralytics.utils.ops import segments2boxes + + keypoint = kwargs.get("keypoint", False) + single_cls = kwargs.get("single_cls", False) + data = kwargs["lit_args"] + nkpt, ndim = data.get("kpt_shape", (0, 0)) + num_cls: int = len(data["names"]) + + segments, keypoints = [], None + + lb = [x.split() for x in labels.split("\n") if len(x)] + if any(len(x) > 6 for x in lb) and (not keypoint): # is segment + classes = np.array([x[0] for x in lb], dtype=np.float32) + segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...) + lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh) + lb = np.array(lb, dtype=np.float32) + if nl := len(lb): + if keypoint: + assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each" + points = lb[:, 5:].reshape(-1, ndim)[:, :2] + else: + assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" + points = lb[:, 1:] + # Coordinate points check with 1% tolerance + assert points.max() <= 1.01, f"non-normalized or out of bounds coordinates {points[points > 1.01]}" + assert lb.min() >= -0.01, f"negative class labels {lb[lb < -0.01]}" + + # All labels + if single_cls: + lb[:, 0] = 0 + max_cls = lb[:, 0].max() # max label count + assert max_cls < num_cls, ( + f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. 
" + f"Possible class labels are 0-{num_cls - 1}" + ) + _, i = np.unique(lb, axis=0, return_index=True) + if len(i) < nl: # duplicate row check + lb = lb[i] # remove duplicates + if segments: + segments = [segments[x] for x in i] + if keypoint: + keypoints = lb[:, 5:].reshape(-1, nkpt, ndim) + if ndim == 2: + kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32) + keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3) + lb = lb[:, :5] + return lb, segments, keypoints From 5ce76d6b0cb48b123c40c8330eb1c58d21eb2d9c Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 10 Jul 2025 15:09:05 +0530 Subject: [PATCH 16/41] update --- src/litdata/{support => integrations}/__init__.py | 0 .../{support => integrations}/ultralytics/__init__.py | 5 +++-- .../{support => integrations}/ultralytics/optimize.py | 0 src/litdata/{support => integrations}/ultralytics/patch.py | 0 4 files changed, 3 insertions(+), 2 deletions(-) rename src/litdata/{support => integrations}/__init__.py (100%) rename src/litdata/{support => integrations}/ultralytics/__init__.py (72%) rename src/litdata/{support => integrations}/ultralytics/optimize.py (100%) rename src/litdata/{support => integrations}/ultralytics/patch.py (100%) diff --git a/src/litdata/support/__init__.py b/src/litdata/integrations/__init__.py similarity index 100% rename from src/litdata/support/__init__.py rename to src/litdata/integrations/__init__.py diff --git a/src/litdata/support/ultralytics/__init__.py b/src/litdata/integrations/ultralytics/__init__.py similarity index 72% rename from src/litdata/support/ultralytics/__init__.py rename to src/litdata/integrations/ultralytics/__init__.py index 59b7e02f0..f9a4fae1b 100644 --- a/src/litdata/support/ultralytics/__init__.py +++ b/src/litdata/integrations/ultralytics/__init__.py @@ -10,6 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from litdata.support.ultralytics.patch import patch_ultralytics +from litdata.integrations.ultralytics.optimize import optimize_ultralytics_dataset +from litdata.integrations.ultralytics.patch import patch_ultralytics -__all__ = ["patch_ultralytics"] +__all__ = ["optimize_ultralytics_dataset", "patch_ultralytics"] diff --git a/src/litdata/support/ultralytics/optimize.py b/src/litdata/integrations/ultralytics/optimize.py similarity index 100% rename from src/litdata/support/ultralytics/optimize.py rename to src/litdata/integrations/ultralytics/optimize.py diff --git a/src/litdata/support/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py similarity index 100% rename from src/litdata/support/ultralytics/patch.py rename to src/litdata/integrations/ultralytics/patch.py From cbe7ef44e721156a5d8a678391c347ec3716e3dd Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 10 Jul 2025 23:42:11 +0530 Subject: [PATCH 17/41] write tests --- .../integrations/ultralytics/optimize.py | 3 +- src/litdata/integrations/ultralytics/patch.py | 2 +- tests/conftest.py | 31 ++++++ .../ultralytics_support/__init__.py | 0 .../ultralytics_support/test_optimize.py | 101 ++++++++++++++++++ tests/streaming/test_dataset.py | 53 +++++++++ 6 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 tests/integrations/ultralytics_support/__init__.py create mode 100644 tests/integrations/ultralytics_support/test_optimize.py diff --git a/src/litdata/integrations/ultralytics/optimize.py b/src/litdata/integrations/ultralytics/optimize.py index 97932f04c..e991ca5a4 100644 --- a/src/litdata/integrations/ultralytics/optimize.py +++ b/src/litdata/integrations/ultralytics/optimize.py @@ -141,7 +141,8 @@ def optimize_ultralytics_dataset( dataset_config[key] = str(value) # save the updated YAML file - with open("litdata_" + yaml_path, "w") as f: + output_yaml = Path(yaml_path).with_name("litdata_" + Path(yaml_path).name) + with open(output_yaml, "w") as f: yaml.dump(dataset_config, f) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index c041b6550..da214937f 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -426,7 +426,7 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict return data -def parse_labels(labels: str, **kwargs: Any): +def parse_labels(labels: str, **kwargs: Any) -> Tuple[np.ndarray, List[np.ndarray], Optional[np.ndarray]]: from ultralytics.utils.ops import segments2boxes keypoint = kwargs.get("keypoint", False) diff --git a/tests/conftest.py b/tests/conftest.py index f306a78ec..f745f5087 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -202,6 +202,37 @@ def huggingface_hub_mock(monkeypatch): return huggingface_hub +@pytest.fixture +def mock_ultralytics(monkeypatch, tmp_path): + fake_ultralytics = ModuleType("ultralytics") + fake_data = ModuleType("ultralytics.data") + fake_utils = ModuleType("ultralytics.data.utils") + + # Create fake dataset paths using tmp_path (safe on all OSes) + train_dir = tmp_path / "train" + val_dir = tmp_path / "val" + train_dir.mkdir() + val_dir.mkdir() + + def check_det_dataset(yaml_path): + return { + "train": str(train_dir), + "val": str(val_dir), + "test": None, + "names": [], + } + + fake_utils.check_det_dataset = check_det_dataset + + # Register in sys.modules + monkeypatch.setitem(sys.modules, "ultralytics", fake_ultralytics) + monkeypatch.setitem(sys.modules, "ultralytics.data", fake_data) + 
monkeypatch.setitem(sys.modules, "ultralytics.data.utils", fake_utils) + + fake_data.utils = fake_utils + fake_ultralytics.data = fake_data + + @pytest.fixture def huggingface_hub_fs_mock(monkeypatch, write_pq_data, tmp_path): huggingface_hub = ModuleType("huggingface_hub") diff --git a/tests/integrations/ultralytics_support/__init__.py b/tests/integrations/ultralytics_support/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integrations/ultralytics_support/test_optimize.py b/tests/integrations/ultralytics_support/test_optimize.py new file mode 100644 index 000000000..18182d4b4 --- /dev/null +++ b/tests/integrations/ultralytics_support/test_optimize.py @@ -0,0 +1,101 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock + +import pytest + +from litdata.integrations.ultralytics.optimize import get_output_dir, list_all_files, optimize_ultralytics_dataset +from litdata.streaming.resolver import Dir + + +@mock.patch("litdata.integrations.ultralytics.optimize._ULTRALYTICS_AVAILABLE", True) +@mock.patch("litdata.integrations.ultralytics.optimize.optimize") +def test_optimize_ultralytics_dataset_mocked(optimize_mock, tmp_path, mock_ultralytics): + os.makedirs(tmp_path / "images/train", exist_ok=True) + os.makedirs(tmp_path / "images/val", exist_ok=True) + for split in ["train", "val"]: + (tmp_path / f"images/{split}/img1.jpg").touch() + (tmp_path / f"labels/{split}/img1.txt").parent.mkdir(parents=True, exist_ok=True) + (tmp_path / f"labels/{split}/img1.txt").write_text("0 0.5 0.5 0.2 0.2") + + yaml_file = tmp_path / "coco8.yaml" + yaml_file.write_text( + "path: {}\ntrain: {}\nval: {}\n".format(tmp_path, tmp_path / "images" / "train", tmp_path / "images" / "val") + ) + + optimize_ultralytics_dataset(str(yaml_file), str(tmp_path / "out"), chunk_size=1, num_workers=1) + + assert optimize_mock.called + assert (tmp_path / "litdata_coco8.yaml").exists() + + +def test_get_output_dir(): + # Case 1: Both url and path provided + d1 = Dir(path="/data/output", url="s3://bucket/output/") + r1 = get_output_dir(d1, "train") + assert r1.path == os.path.join("/data/output", "train") + assert r1.url == "s3://bucket/output/train" + + # Case 2: Only url provided + d2 = Dir(url="s3://bucket/output/") + r2 = get_output_dir(d2, "val") + assert r2.path is None + assert r2.url == "s3://bucket/output/val" + + # Case 3: Only path provided + d3 = Dir(path="/data/output") + r3 = get_output_dir(d3, "test") + assert r3.url is None + assert r3.path == os.path.join("/data/output", "test") + + # Case 4: Neither url nor path provided + d4 = Dir() + r4 = get_output_dir(d4, "debug") + assert r4.url is None + assert r4.path is None + + # Case 5: Invalid type + with pytest.raises(TypeError): + get_output_dir("not_a_dir_obj", "fail") + + +def test_list_all_files_combined(tmp_path): + # --- Case 1: Directory with nested files --- + (tmp_path / "a.txt").write_text("hello") + (tmp_path / "subdir").mkdir() + (tmp_path / "subdir" / 
"b.txt").write_text("world") + + result = list_all_files(str(tmp_path)) + expected = {str(tmp_path / "a.txt"), str(tmp_path / "subdir" / "b.txt")} + assert set(result) == expected, "Should list all files recursively from directory" + + # --- Case 2: .txt file listing files --- + (tmp_path / "img1.jpg").touch() + (tmp_path / "img2.jpg").touch() + txt_file = tmp_path / "train.txt" + txt_file.write_text("img1.jpg\nimg2.jpg\n") + + result = list_all_files(str(txt_file)) + expected = { + str((tmp_path / "img1.jpg").resolve()), + str((tmp_path / "img2.jpg").resolve()), + } + assert set(result) == expected, ".txt path list should resolve correctly" + + # --- Case 3: Unsupported file path --- + bad_file = tmp_path / "unsupported.md" + bad_file.write_text("invalid") + + with pytest.raises(ValueError, match="Unsupported path"): + list_all_files(str(bad_file)) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 13bdd90c3..dc7e1d746 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -1695,6 +1695,59 @@ def transform_fn(x, *args, **kwargs): assert item == i * 2, f"Expected {i * 2}, got {item}" +@pytest.mark.parametrize("shuffle", [True, False]) +def test_dataset_multiple_transform(tmpdir, shuffle): + """Test if the dataset transform is applied correctly.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + # Define two simple transform function + def transform_fn_1(x, *args, **kwargs): + """A simple transform function that doubles the input.""" + return x * 2 + + def transform_fn_2(x, *args, **kwargs): + """A simple transform function that adds one to the input.""" + extra_num = kwargs.get("extra_num", 0) + return x + extra_num + + dataset = StreamingDataset( + data_dir, + cache_dir=str(cache_dir), + shuffle=shuffle, + transform=[transform_fn_1, transform_fn_2], + transform_kwargs={"extra_num": 100}, + ) + dataset_length = len(dataset) + assert dataset_length == 100 + + # ACT + # Stream through the entire dataset and store the results + complete_data = [] + for data in dataset: + assert data is not None + complete_data.append(data) + + if shuffle: + complete_data.sort() + + # ASSERT + # Verify that the transform is applied correctly + for i, item in enumerate(complete_data): + assert item == i * 2 + 100, f"Expected {i * 2 + 100}, got {item}" + + @pytest.mark.parametrize("shuffle", [True, False]) def test_dataset_transform_inheritance(tmpdir, shuffle): """Test if the dataset transform is applied correctly.""" From c5e177dcb387e4ce5fe472e91c6580f321ad0720 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 00:11:30 +0530 Subject: [PATCH 18/41] update --- src/litdata/integrations/ultralytics/patch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index da214937f..fbf7a9d3b 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -17,6 +17,7 @@ import numpy as np import torch +from numpy.typing import NDArray from litdata.constants import _ULTRALYTICS_AVAILABLE from litdata.streaming.dataloader import StreamingDataLoader 
@@ -426,7 +427,9 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict return data -def parse_labels(labels: str, **kwargs: Any) -> Tuple[np.ndarray, List[np.ndarray], Optional[np.ndarray]]: +def parse_labels( + labels: str, **kwargs: Any +) -> Tuple[NDArray[np.float32], List[NDArray[np.float32]], Optional[NDArray[np.float32]]]: from ultralytics.utils.ops import segments2boxes keypoint = kwargs.get("keypoint", False) From 9c8842a1d2a768062babc63d85c6765fc7ea3060 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 00:19:00 +0530 Subject: [PATCH 19/41] update --- src/litdata/integrations/ultralytics/patch.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index fbf7a9d3b..70cdf528c 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -427,9 +427,7 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict return data -def parse_labels( - labels: str, **kwargs: Any -) -> Tuple[NDArray[np.float32], List[NDArray[np.float32]], Optional[NDArray[np.float32]]]: +def parse_labels(labels: str, **kwargs: Any) -> Tuple[NDArray[np.float32], List[NDArray[np.float32]], Any]: from ultralytics.utils.ops import segments2boxes keypoint = kwargs.get("keypoint", False) From 74f88f397ef9cb7190fc43d65351a1530b87d915 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 00:24:46 +0530 Subject: [PATCH 20/41] update --- src/litdata/integrations/ultralytics/patch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index 70cdf528c..74edaa975 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -17,7 +17,6 @@ import numpy as np import torch -from numpy.typing import NDArray from litdata.constants import _ULTRALYTICS_AVAILABLE from litdata.streaming.dataloader import StreamingDataLoader @@ -427,7 +426,7 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict return data -def parse_labels(labels: str, **kwargs: Any) -> Tuple[NDArray[np.float32], List[NDArray[np.float32]], Any]: +def parse_labels(labels: str, **kwargs: Any) -> Tuple[np.ndarray, Union[np.ndarray, List], Optional[np.ndarray]]: from ultralytics.utils.ops import segments2boxes keypoint = kwargs.get("keypoint", False) From 2c00abf4c5770f2bb4413cbeed67bd132435efe5 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 00:30:25 +0530 Subject: [PATCH 21/41] update --- src/litdata/integrations/ultralytics/patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index 74edaa975..93b1b6cda 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -426,7 +426,7 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict return data -def parse_labels(labels: str, **kwargs: Any) -> Tuple[np.ndarray, Union[np.ndarray, List], Optional[np.ndarray]]: +def parse_labels(labels: str, **kwargs: Any) -> Tuple[np.ndarray, Union[np.ndarray, List[Any]], Optional[Any]]: from ultralytics.utils.ops import segments2boxes keypoint = kwargs.get("keypoint", False) From 5d8a7041f52e417ed8af346626a023d4b167993c Mon Sep 17 00:00:00 2001 
From: Deependu Jha Date: Fri, 11 Jul 2025 00:33:57 +0530 Subject: [PATCH 22/41] update --- src/litdata/integrations/ultralytics/patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index 93b1b6cda..eaaa8a3fb 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -426,7 +426,7 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict return data -def parse_labels(labels: str, **kwargs: Any) -> Tuple[np.ndarray, Union[np.ndarray, List[Any]], Optional[Any]]: +def parse_labels(labels: str, **kwargs: Any) -> Tuple[Any, Any, Any]: from ultralytics.utils.ops import segments2boxes keypoint = kwargs.get("keypoint", False) From 867d8d00ad324868e2cda7d9afb493681b962bbd Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 00:40:54 +0530 Subject: [PATCH 23/41] update --- src/litdata/integrations/ultralytics/patch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index eaaa8a3fb..c68dd9f03 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -435,7 +435,8 @@ def parse_labels(labels: str, **kwargs: Any) -> Tuple[Any, Any, Any]: nkpt, ndim = data.get("kpt_shape", (0, 0)) num_cls: int = len(data["names"]) - segments, keypoints = [], None + segments: Any = [] + keypoints: Any = None lb = [x.split() for x in labels.split("\n") if len(x)] if any(len(x) > 6 for x in lb) and (not keypoint): # is segment From 74aa0ae7668e2b61861adf30bc8c81fc4c93c117 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 01:17:06 +0530 Subject: [PATCH 24/41] test-cov --- tests/processing/test_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index 111b1ccea..1007da67b 100644 --- a/tests/processing/test_functions.py +++ b/tests/processing/test_functions.py @@ -124,7 +124,8 @@ def random_image(index): @pytest.mark.skipif(sys.platform == "win32", reason="too slow") -def test_optimize_append_overwrite(tmpdir): +@pytest.mark.parametrize("verbose", [False, True]) +def test_optimize_append_overwrite(tmpdir, verbose): output_dir = str(tmpdir / "output_dir") optimize( @@ -133,6 +134,7 @@ def test_optimize_append_overwrite(tmpdir): num_workers=1, output_dir=output_dir, chunk_bytes="64MB", + verbose=verbose, ) ds = StreamingDataset(output_dir) From 5f85ccfba3fbb65d38f9b3b57906b3d90f5d6c1c Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 01:24:01 +0530 Subject: [PATCH 25/41] update --- src/litdata/integrations/ultralytics/patch.py | 129 +++++++++--------- 1 file changed, 61 insertions(+), 68 deletions(-) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index c68dd9f03..bb6980f5f 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -28,12 +28,6 @@ def patch_ultralytics() -> None: if not _ULTRALYTICS_AVAILABLE: raise ImportError("Ultralytics is not available. 
Please install it to use this functionality.") - # import sys - - # if "ultralytics" in sys.modules: - # raise RuntimeError("patch_ultralytics() must be called before importing 'ultralytics'") - - # Patch detection dataset loading from ultralytics.data.utils import check_det_dataset check_det_dataset.__code__ = patch_check_det_dataset.__code__ @@ -69,67 +63,6 @@ def patch_ultralytics() -> None: print("✅ Ultralytics successfully patched to use LitData.") -def patch_check_det_dataset(dataset: str, _: bool = True) -> Dict: - if not (isinstance(dataset, str) and dataset.endswith(".yaml") and os.path.isfile(dataset)): - raise ValueError("Dataset must be a string ending with '.yaml' and point to a valid file.") - - import yaml - - if not dataset.startswith("litdata_"): - dataset = "litdata_" + dataset - - if not os.path.isfile(dataset): - raise FileNotFoundError(f"Dataset file not found: {dataset}") - - # read the yaml file - with open(dataset) as file: - return yaml.safe_load(file) - - -def patch_build_dataloader( - dataset: Any, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False -) -> StreamingDataLoader: - """Create and return an InfiniteDataLoader or DataLoader for training or validation. - - Args: - dataset (Dataset): Dataset to load data from. - batch (int): Batch size for the dataloader. - workers (int): Number of worker threads for loading data. - shuffle (bool, optional): Whether to shuffle the dataset. - rank (int, optional): Process rank in distributed training. -1 for single-GPU training. - drop_last (bool, optional): Whether to drop the last incomplete batch. - - Returns: - (StreamingDataLoader): A dataloader that can be used for training or validation. - - Examples: - Create a dataloader for training - >>> dataset = YOLODataset(...) 
- >>> dataloader = build_dataloader(dataset, batch=16, workers=4, shuffle=True) - """ - from litdata.streaming.dataloader import StreamingDataLoader - - print("litdata is rocking⚡️") - if not hasattr(dataset, "streaming_dataset"): - raise ValueError("The dataset must have a 'streaming_dataset' attribute.") - - from ultralytics.data.utils import PIN_MEMORY - - batch = min(batch, len(dataset)) - num_devices = torch.cuda.device_count() # number of CUDA devices - num_workers = min((os.cpu_count() or 1) // max(num_devices, 1), workers) # number of workers - persistent_workers = bool(int(os.getenv("UL_PERSISTENT_WORKERS", 0))) - return StreamingDataLoader( - dataset=dataset.streaming_dataset, - batch_size=batch, - num_workers=num_workers, - persistent_workers=persistent_workers, - pin_memory=PIN_MEMORY, - collate_fn=getattr(dataset, "collate_fn", None), - drop_last=drop_last, - ) - - if _ULTRALYTICS_AVAILABLE: from ultralytics.data.base import BaseDataset as UltralyticsBaseDataset from ultralytics.utils.plotting import plot_images @@ -321,7 +254,66 @@ def patch_detection_plot_predictions( fname=self.save_dir / f"val_batch{ni}_pred.jpg", names=self.names, on_plot=self.on_plot, - ) # pred + ) + + def patch_check_det_dataset(dataset: str, _: bool = True) -> Dict: + if not (isinstance(dataset, str) and dataset.endswith(".yaml") and os.path.isfile(dataset)): + raise ValueError("Dataset must be a string ending with '.yaml' and point to a valid file.") + + import yaml + + if not dataset.startswith("litdata_"): + dataset = "litdata_" + dataset + + if not os.path.isfile(dataset): + raise FileNotFoundError(f"Dataset file not found: {dataset}") + + # read the yaml file + with open(dataset) as file: + return yaml.safe_load(file) + + def patch_build_dataloader( + dataset: Any, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False + ) -> StreamingDataLoader: + """Create and return an InfiniteDataLoader or DataLoader for training or validation. + + Args: + dataset (Dataset): Dataset to load data from. + batch (int): Batch size for the dataloader. + workers (int): Number of worker threads for loading data. + shuffle (bool, optional): Whether to shuffle the dataset. + rank (int, optional): Process rank in distributed training. -1 for single-GPU training. + drop_last (bool, optional): Whether to drop the last incomplete batch. + + Returns: + (StreamingDataLoader): A dataloader that can be used for training or validation. + + Examples: + Create a dataloader for training + >>> dataset = YOLODataset(...) 
+ >>> dataloader = build_dataloader(dataset, batch=16, workers=4, shuffle=True) + """ + from litdata.streaming.dataloader import StreamingDataLoader + + print("litdata is rocking⚡️") + if not hasattr(dataset, "streaming_dataset"): + raise ValueError("The dataset must have a 'streaming_dataset' attribute.") + + from ultralytics.data.utils import PIN_MEMORY + + batch = min(batch, len(dataset)) + num_devices = torch.cuda.device_count() # number of CUDA devices + num_workers = min((os.cpu_count() or 1) // max(num_devices, 1), workers) # number of workers + persistent_workers = bool(int(os.getenv("UL_PERSISTENT_WORKERS", 0))) + return StreamingDataLoader( + dataset=dataset.streaming_dataset, + batch_size=batch, + num_workers=num_workers, + persistent_workers=persistent_workers, + pin_memory=PIN_MEMORY, + collate_fn=getattr(dataset, "collate_fn", None), + drop_last=drop_last, + ) def patch_detection_plot_training_labels(self: Any) -> None: """Create a labeled training plot of the YOLO model.""" @@ -375,6 +367,7 @@ def patch_compose_transform_call(self: Any, data: Any) -> Any: data = t(data) return data + # ------- helper transformations ------- From 3b2f398fe79f1f3d1cc316cf972fe5a0affd830e Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 01:29:54 +0530 Subject: [PATCH 26/41] add readme --- README.md | 78 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/README.md b/README.md index 2be60537b..34b55919d 100644 --- a/README.md +++ b/README.md @@ -1516,6 +1516,84 @@ if __name__ == "__main__":   +
+ 🚀 Stream Large Datasets to Ultralytics Models with LitData + +  + +This feature enables **training Ultralytics models (like YOLO)** directly from **LitData’s optimized streaming datasets**. Now you can train on massive datasets (e.g., 500GB+) without downloading everything to disk — just stream from **S3, GCS, local paths, or HTTP(S)** with minimal overhead. + +--- + +### 🔧 How It Works + +#### **Step 1: Optimize Your Dataset (One-time Step)** + +Convert your existing Ultralytics-style dataset into an optimized streaming format: + +```python +from litdata.integrations.ultralytics import optimize_ultralytics_dataset + +optimize_ultralytics_dataset( + "coco128.yaml", # Original dataset config + "s3://some-bucket/optimized-data", # Cloud path or local directory + num_workers=4, # Number of concurrent workers + chunk_bytes="64MB", # Chunk size for streaming +) +``` + +This generates an optimized dataset and creates a new `litdata_coco128.yaml` config to use for training. + +--- + +#### **Step 2: Patch Ultralytics for Streaming** + +Before training, patch Ultralytics internals to enable LitData streaming: + +```python +from litdata.integrations.ultralytics import patch_ultralytics + +patch_ultralytics() +``` + +--- + +#### **Step 3: Train Like Usual — But Now From the Cloud ☁️** + +```python +from ultralytics import YOLO + +patch_ultralytics() + +model = YOLO("yolo11n.pt") +model.train(data="litdata_coco128.yaml", epochs=100, imgsz=640) +``` + +That’s it — Ultralytics now streams your data via LitData under the hood! + +--- + +### ✅ Benefits + +* 🔁 **Stream datasets of any size** — no need to fully download. +* 💾 **Save disk space** — only minimal local caching is used. +* 🧪 **Benchmark-tested** — supports both local and cloud training. +* 🧩 **Plug-and-play with Ultralytics** — zero training code changes. +* ☁️ Supports **S3, GCS, HTTP(S), and local disk** out-of-the-box. + +--- + +### 📊 Benchmarks (Lightning Studio L4 GPU) + + +While the performance gains aren't drastic (due to Ultralytics caching internally), this integration **unlocks all the benefits of streaming** and enables training on large-scale datasets from the cloud. + +We’re also exploring a **custom LitData dataloader** built from scratch (potentially breaking GIL using Rust or multithreading). If it outperforms `torch.DataLoader`, future benchmarks could reflect significant performance boosts. 💡 + +
+ +  + ## Features for transforming datasets From 8d06209e5159515b8d9b59b8f05bd3de73062b5a Mon Sep 17 00:00:00 2001 From: Deependu Date: Fri, 11 Jul 2025 01:36:17 +0530 Subject: [PATCH 27/41] Update README.md --- README.md | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 34b55919d..259e3622c 100644 --- a/README.md +++ b/README.md @@ -202,7 +202,7 @@ ld.map( ## Features for optimizing and streaming datasets for model training
- ✅ Stream large cloud datasets + ✅ ge cloud datasets   Use data stored on the cloud without needing to download it all to your computer, saving time and space. @@ -1514,10 +1514,8 @@ if __name__ == "__main__":
-  -
- 🚀 Stream Large Datasets to Ultralytics Models with LitData + ✅ Stream Large Datasets to Ultralytics Models with LitData   @@ -1561,6 +1559,12 @@ patch_ultralytics() #### **Step 3: Train Like Usual — But Now From the Cloud ☁️** ```python +from litdata.integrations.ultralytics import patch_ultralytics + +patch_ultralytics() + +# ------- + from ultralytics import YOLO patch_ultralytics() @@ -1585,6 +1589,11 @@ That’s it — Ultralytics now streams your data via LitData under the hood! ### 📊 Benchmarks (Lightning Studio L4 GPU) +- On local machine: +Screenshot 2025-07-09 at 10 14 27 AM + +- On lightning studio (L4 GPU machine) +Screenshot 2025-07-11 at 12 50 11 AM While the performance gains aren't drastic (due to Ultralytics caching internally), this integration **unlocks all the benefits of streaming** and enables training on large-scale datasets from the cloud. From e1eefc51e1698d5e65e232d0b8f32015466d18bd Mon Sep 17 00:00:00 2001 From: Deependu Date: Fri, 11 Jul 2025 01:37:56 +0530 Subject: [PATCH 28/41] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 259e3622c..6974e43d9 100644 --- a/README.md +++ b/README.md @@ -202,7 +202,7 @@ ld.map( ## Features for optimizing and streaming datasets for model training
- ✅ ge cloud datasets + ✅ Stream large cloud datasets   Use data stored on the cloud without needing to download it all to your computer, saving time and space. From b000474847fc1127ab83242c16dffe879c815a1a Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 01:39:01 +0530 Subject: [PATCH 29/41] remove redundant comment --- src/litdata/integrations/ultralytics/patch.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index bb6980f5f..bcdfa3c33 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -32,7 +32,6 @@ def patch_ultralytics() -> None: check_det_dataset.__code__ = patch_check_det_dataset.__code__ - # Patch training visualizer (optional, but useful) from ultralytics.models.yolo.detect.train import DetectionTrainer DetectionTrainer.plot_training_samples = patch_detection_plot_training_samples @@ -43,7 +42,6 @@ def patch_ultralytics() -> None: DetectionValidator.plot_val_samples = patch_detection_plot_val_samples DetectionValidator.plot_predictions = patch_detection_plot_predictions - # Patch BaseDataset globally import ultralytics.data.base as base_module import ultralytics.data.dataset as child_modules From 0f2a4c4c2ac52433f3fc2cfba2f74251cd24ab5a Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 11 Jul 2025 09:22:03 +0530 Subject: [PATCH 30/41] test-cov --- requirements/test.txt | 1 + src/litdata/integrations/ultralytics/patch.py | 4 +- .../ultralytics_support/test_patch.py | 130 ++++++++++++++++++ 3 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 tests/integrations/ultralytics_support/test_patch.py diff --git a/requirements/test.txt b/requirements/test.txt index 5a4fdea33..65c0002dc 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -16,3 +16,4 @@ transformers <4.53.0 zstd s5cmd >=0.2.0 soundfile >=0.13.0 # required for torchaudio backend +ultralytics >=8.3.16 diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index bcdfa3c33..07fc155c0 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -437,7 +437,9 @@ def parse_labels(labels: str, **kwargs: Any) -> Tuple[Any, Any, Any]: lb = np.array(lb, dtype=np.float32) if nl := len(lb): if keypoint: - assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each" + assert lb.shape[1] == (5 + nkpt * ndim), ( + f"labels require {(5 + nkpt * ndim)} columns each, but {lb.shape[1]} columns detected" + ) points = lb[:, 5:].reshape(-1, ndim)[:, :2] else: assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" diff --git a/tests/integrations/ultralytics_support/test_patch.py b/tests/integrations/ultralytics_support/test_patch.py new file mode 100644 index 000000000..0746a8c4f --- /dev/null +++ b/tests/integrations/ultralytics_support/test_patch.py @@ -0,0 +1,130 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch + +import numpy as np +import pytest + +from litdata.integrations.ultralytics.patch import parse_labels, ultralytics_detection_transform + + +def make_lit_args(**kwargs): + return { + "names": ["class0", "class1"], + "kpt_shape": (5, 3), + **kwargs, + } + + +def test_parse_labels(): + # --- Basic test --- + labels = "0 0.1 0.2 0.3 0.4\n1 0.5 0.6 0.7 0.8" + lit_args = make_lit_args() + lb, segments, keypoints = parse_labels(labels, lit_args=lit_args) + assert lb.shape == (2, 5) + assert segments == [] + assert keypoints is None + assert np.all(lb[:, 0] < len(lit_args["names"])) + + # --- Single class --- + labels = "1 0.1 0.2 0.3 0.4" + lb, _, _ = parse_labels(labels, lit_args=lit_args, single_cls=True) + assert np.all(lb[:, 0] == 0) + + # --- With segments --- + segment_label = "0 " + " ".join([str(round(0.01 * i, 3)) for i in range(14)]) # 7 xy pairs + lb, segments, _ = parse_labels(segment_label, lit_args=lit_args) + assert lb.shape == (1, 5) + assert len(segments) == 1 + + # --- With keypoints --- + keypoint_str = "0 " + " ".join(["0.1"] * 19) + lb, _, keypoints = parse_labels(keypoint_str, lit_args=lit_args, keypoint=True) + assert lb.shape == (1, 5) + assert keypoints.shape == (1, 5, 3) + + # --- Duplicate removal --- + dup_labels = "0 0.1 0.2 0.3 0.4\n0 0.1 0.2 0.3 0.4" + lb, _, _ = parse_labels(dup_labels, lit_args=lit_args) + assert lb.shape == (1, 5) + + # --- Out of bounds class --- + bad_class_label = "99 0.1 0.2 0.3 0.4" + with pytest.raises(AssertionError, match="Label class 99 exceeds"): + parse_labels(bad_class_label, lit_args=lit_args) + + # --- Coordinates out of bounds --- + bad_coords_label = "0 1.2 1.2 1.2 1.2" + with pytest.raises(AssertionError, match="non-normalized or out of bounds"): + parse_labels(bad_coords_label, lit_args=lit_args) + + # --- Negative class --- + negative_label = "-1 0.1 0.2 0.3 0.4" + with pytest.raises(AssertionError, match="negative class labels"): + parse_labels(negative_label, lit_args=lit_args) + + # --- Wrong shape --- + wrong_shape_label = "0 0.1 0.2 0.3" # Only 4 elements + with pytest.raises(AssertionError, match="labels require 5 columns"): + parse_labels(wrong_shape_label, lit_args=lit_args) + + +def test_ultralytics_detection_transform(): + dummy_img = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8) + valid_label = "1 0.1 0.2 0.3 0.4\n0 0.5 0.6 0.7 0.8" + invalid_label = 123 # not a string + + dummy_img_resized = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) + ori_shape = dummy_img.shape[:2] + resized_shape = dummy_img_resized.shape[:2] + + lit_args = {"names": ["class0", "class1"], "kpt_shape": (0, 0)} + + with patch( + "litdata.integrations.ultralytics.patch.image_resize", + return_value=(dummy_img_resized, ori_shape, resized_shape), + ), patch( + "litdata.integrations.ultralytics.patch.parse_labels", + return_value=( + np.array([[1, 0.1, 0.2, 0.3, 0.4], [0, 0.5, 0.6, 0.7, 0.8]], dtype=np.float32), + [], + None, + ), + ): + # ✅ Valid transformation + data = {"img": dummy_img, "label": valid_label} + out = ultralytics_detection_transform(data, index=42, channels=3, img_size=640, lit_args=lit_args) + + assert isinstance(out, dict) + assert out["img"].shape == (640, 640, 3) + assert out["batch_idx"].item() == 42 + assert out["cls"].shape == (2, 1) + assert out["bboxes"].shape == (2, 4) + assert out["segments"] == [] + assert out["keypoints"] is None + assert out["normalized"] is 
True + assert out["bbox_format"] == "xywh" + assert out["ori_shape"] == ori_shape + assert out["resized_shape"] == resized_shape + assert isinstance(out["ratio_pad"], tuple) + assert out["channels"] == 3 + + # Missing index + with pytest.raises(ValueError, match="Index must be provided"): + ultralytics_detection_transform({"img": dummy_img, "label": valid_label}, channels=3, lit_args=lit_args) + + # Invalid label type + with pytest.raises(ValueError, match="Label must be a string"): + ultralytics_detection_transform( + {"img": dummy_img, "label": invalid_label}, index=0, channels=3, lit_args=lit_args + ) From 8bb10b4ff3178e507c2f5b4bb90e96aadf1b2e0f Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 12 Jul 2025 11:45:38 +0530 Subject: [PATCH 31/41] update readme --- README.md | 80 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 6974e43d9..07e103324 100644 --- a/README.md +++ b/README.md @@ -1514,6 +1514,45 @@ if __name__ == "__main__":
+## Features for transforming datasets + +
+ ✅ Parallelize data transformations (map) +  + +Apply the same change to different parts of the dataset at once to save time and effort. + +The `map` operator can be used to apply a function over a list of inputs. + +Here is an example where the `map` operator is used to apply a `resize_image` function over a folder of large images. + +```python +from litdata import map +from PIL import Image + +# Note: Inputs could also refer to files on s3 directly. +input_dir = "my_large_images" +inputs = [os.path.join(input_dir, f) for f in os.listdir(input_dir)] + +# The resize image takes one of the input (image_path) and the output directory. +# Files written to output_dir are persisted. +def resize_image(image_path, output_dir): + output_image_path = os.path.join(output_dir, os.path.basename(image_path)) + Image.open(image_path).resize((224, 224)).save(output_image_path) + +map( + fn=resize_image, + inputs=inputs, + output_dir="s3://my-bucket/my_resized_images", +) +``` + +
+ +  + +## Ultralytics (YOLO) Integration +
✅ Stream Large Datasets to Ultralytics Models with LitData @@ -1584,6 +1623,7 @@ That’s it — Ultralytics now streams your data via LitData under the hood! * 🧪 **Benchmark-tested** — supports both local and cloud training. * 🧩 **Plug-and-play with Ultralytics** — zero training code changes. * ☁️ Supports **S3, GCS, HTTP(S), and local disk** out-of-the-box. +* ✅ **Minimal code changes** to existing Ultralytics training scripts --- @@ -1597,45 +1637,9 @@ That’s it — Ultralytics now streams your data via LitData under the hood! While the performance gains aren't drastic (due to Ultralytics caching internally), this integration **unlocks all the benefits of streaming** and enables training on large-scale datasets from the cloud. -We’re also exploring a **custom LitData dataloader** built from scratch (potentially breaking GIL using Rust or multithreading). If it outperforms `torch.DataLoader`, future benchmarks could reflect significant performance boosts. 💡 +Instead of downloading entire datasets (which can be hundreds of GBs), you can now **stream data on-the-fly from S3, GCS, HTTP(S), or even local disk** — making it ideal for training in the cloud with limited storage and more efficient utilization of resources. -
- -  - - -## Features for transforming datasets - -
- ✅ Parallelize data transformations (map) -  - -Apply the same change to different parts of the dataset at once to save time and effort. - -The `map` operator can be used to apply a function over a list of inputs. - -Here is an example where the `map` operator is used to apply a `resize_image` function over a folder of large images. - -```python -from litdata import map -from PIL import Image - -# Note: Inputs could also refer to files on s3 directly. -input_dir = "my_large_images" -inputs = [os.path.join(input_dir, f) for f in os.listdir(input_dir)] - -# The resize image takes one of the input (image_path) and the output directory. -# Files written to output_dir are persisted. -def resize_image(image_path, output_dir): - output_image_path = os.path.join(output_dir, os.path.basename(image_path)) - Image.open(image_path).resize((224, 224)).save(output_image_path) - -map( - fn=resize_image, - inputs=inputs, - output_dir="s3://my-bucket/my_resized_images", -) -``` +We’re also exploring a **custom LitData dataloader** built from scratch (potentially breaking GIL using Rust or multithreading). If it outperforms `torch.DataLoader`, future benchmarks could reflect significant performance boosts. 💡
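A quick way to sanity-check an optimized split before handing it to Ultralytics is to open it directly with litdata's streaming classes. The sketch below is illustrative and not part of the patches: it assumes the "train" split was written under the example output location `s3://some-bucket/optimized-data` used in the README above, and it relies only on `StreamingDataset` as used elsewhere in this PR.

```python
from litdata import StreamingDataset

# Hypothetical location: the "train" split written by optimize_ultralytics_dataset.
dataset = StreamingDataset("s3://some-bucket/optimized-data/train")
print(len(dataset))  # number of optimized samples

sample = dataset[0]  # dict with "img" (image array) and "label" (YOLO-format string)
print(sample["img"].shape, sample["label"].splitlines()[:2])
```

Actual training still goes through `patch_ultralytics()` and the generated `litdata_*.yaml`, as shown in the README section above.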
From 6113878556effa0fc0e1b961582a73037a80088c Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Sat, 12 Jul 2025 11:48:26 +0530 Subject: [PATCH 32/41] update --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 07e103324..97ca4044b 100644 --- a/README.md +++ b/README.md @@ -1514,6 +1514,9 @@ if __name__ == "__main__":
+  + + ## Features for transforming datasets
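For context on the `transform` / `transform_kwargs` plumbing that the following `dataset.py` change refines: transforms passed to `StreamingDataset` may accept just the item, or also `**kwargs`, in which case the dataset forwards the user-supplied `transform_kwargs` together with the per-item `index`. A minimal sketch under those assumptions (the path and function names are illustrative):

```python
from litdata import StreamingDataset

def scale(item, **kwargs):
    # Receives transform_kwargs plus the injected per-item "index".
    return item * kwargs.get("factor", 1)

def tag_with_index(item, **kwargs):
    # Transforms in the list run in order, each seeing the previous output.
    return {"value": item, "index": kwargs["index"]}

dataset = StreamingDataset(
    "optimized_dir",                    # illustrative path to an optimized dataset
    transform=[scale, tag_with_index],  # applied in order per sample
    transform_kwargs={"factor": 2},     # merged with {"index": <item index>} at access time
)
```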
From 36c46f145a3cab08e8d82fc4c2530be750a84335 Mon Sep 17 00:00:00 2001 From: Deependu Date: Sat, 12 Jul 2025 11:53:05 +0530 Subject: [PATCH 33/41] Update src/litdata/streaming/dataset.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/litdata/streaming/dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 10b5eae09..e3e317484 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -451,21 +451,22 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: ) ) if hasattr(self, "transform"): - self.transform_kwargs["index"] = index.index + local_transform_kwargs = self.transform_kwargs.copy() + local_transform_kwargs["index"] = index.index if isinstance(self.transform, list): for transform_fn in self.transform: signature = inspect.signature(transform_fn) accepts_kwargs = any( param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values() ) - item = transform_fn(item, **self.transform_kwargs) if accepts_kwargs else transform_fn(item) + item = transform_fn(item, **local_transform_kwargs) if accepts_kwargs else transform_fn(item) else: # check if transform function accepts kwargs signature = inspect.signature(self.transform) accepts_kwargs = any( param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values() ) - item = self.transform(item, **self.transform_kwargs) if accepts_kwargs else self.transform(item) + item = self.transform(item, **local_transform_kwargs) if accepts_kwargs else self.transform(item) return item def __next__(self) -> Any: From be4aab7e823d78df51d3d90d0b8bd2d4cb53943d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Jul 2025 05:50:14 +0000 Subject: [PATCH 34/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/dataset.py | 1 - .../ultralytics_support/test_patch.py | 21 +++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 0ec7aebff..94ec3eba4 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -10,7 +10,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect import logging import os from time import time diff --git a/tests/integrations/ultralytics_support/test_patch.py b/tests/integrations/ultralytics_support/test_patch.py index 0746a8c4f..83224d12d 100644 --- a/tests/integrations/ultralytics_support/test_patch.py +++ b/tests/integrations/ultralytics_support/test_patch.py @@ -90,15 +90,18 @@ def test_ultralytics_detection_transform(): lit_args = {"names": ["class0", "class1"], "kpt_shape": (0, 0)} - with patch( - "litdata.integrations.ultralytics.patch.image_resize", - return_value=(dummy_img_resized, ori_shape, resized_shape), - ), patch( - "litdata.integrations.ultralytics.patch.parse_labels", - return_value=( - np.array([[1, 0.1, 0.2, 0.3, 0.4], [0, 0.5, 0.6, 0.7, 0.8]], dtype=np.float32), - [], - None, + with ( + patch( + "litdata.integrations.ultralytics.patch.image_resize", + return_value=(dummy_img_resized, ori_shape, resized_shape), + ), + patch( + "litdata.integrations.ultralytics.patch.parse_labels", + return_value=( + np.array([[1, 0.1, 0.2, 0.3, 0.4], [0, 0.5, 0.6, 0.7, 0.8]], dtype=np.float32), + [], + None, + ), ), ): # ✅ Valid transformation From 5d315a0b3f5f7f1a55eef5f7b3a525ae28e4fc22 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 15 Jul 2025 11:20:56 +0530 Subject: [PATCH 35/41] update --- tests/processing/test_functions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index 0b20e230f..dc675f628 100644 --- a/tests/processing/test_functions.py +++ b/tests/processing/test_functions.py @@ -124,8 +124,7 @@ def random_image(index): @pytest.mark.skipif(sys.platform == "win32", reason="too slow") -@pytest.mark.parametrize("verbose", [False, True]) -def test_optimize_append_overwrite(tmpdir, verbose): +def test_optimize_append_overwrite(tmpdir): output_dir = str(tmpdir / "output_dir") optimize( @@ -134,7 +133,6 @@ def test_optimize_append_overwrite(tmpdir, verbose): num_workers=1, output_dir=output_dir, chunk_bytes="64MB", - verbose=verbose, ) ds = StreamingDataset(output_dir) From b16d4aa6d9d05406acbe4e8427f8ad4d88ca0a58 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 15 Jul 2025 13:22:14 +0530 Subject: [PATCH 36/41] update pr --- src/litdata/integrations/ultralytics/patch.py | 100 ++++++++++-------- src/litdata/streaming/dataset.py | 22 ++-- 2 files changed, 69 insertions(+), 53 deletions(-) diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py index 07fc155c0..39f3f0485 100644 --- a/src/litdata/integrations/ultralytics/patch.py +++ b/src/litdata/integrations/ultralytics/patch.py @@ -13,7 +13,7 @@ import math import os from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import numpy as np import torch @@ -66,7 +66,7 @@ def patch_ultralytics() -> None: from ultralytics.utils.plotting import plot_images class PatchedUltralyticsBaseDataset(UltralyticsBaseDataset): - def __init__(self: Any, img_path: str, classes: Optional[List[int]] = None, *args: Any, **kwargs: Any): + def __init__(self: Any, img_path: str, classes: Optional[list[int]] = None, *args: Any, **kwargs: Any): print("patched ultralytics dataset: 🔥") self.litdata_dataset = img_path self.classes = classes @@ -74,20 +74,18 @@ def __init__(self: Any, img_path: str, classes: Optional[List[int]] = None, *arg self.streaming_dataset = StreamingDataset( img_path, transform=[ - ultralytics_detection_transform, + partial( + 
ultralytics_detection_transform, + img_size=self.imgsz, + channels=self.channels, + keypoint=self.use_keypoints, + single_cls=self.single_cls, + lit_args=self.data, + ), partial(self.transform_update_label, classes), self.update_labels_info, self.transforms, ], - transform_kwargs={ - "img_size": self.imgsz, - "channels": self.channels, - "segment": self.use_segments, - "use_keypoints": self.use_keypoints, - "use_obb": self.use_obb, - "lit_args": self.data, - "single_cls": self.single_cls, - }, ) self.ni = len(self.streaming_dataset) self.buffer = list(range(len(self.streaming_dataset))) @@ -101,11 +99,11 @@ def get_image_and_label(self: Any, index: int) -> None: # e.g. use `self.litdata_dataset[index]` raise NotImplementedError("Custom logic here") - def get_img_files(self: Any, img_path: Union[str, List[str]]) -> List[str]: + def get_img_files(self: Any, img_path: Union[str, list[str]]) -> list[str]: """Let this method return an empty list to avoid errors.""" return [] - def get_labels(self: Any) -> List[Dict[str, Any]]: + def get_labels(self: Any) -> list[dict[str, Any]]: # this is used to get number of images (ni) in the BaseDataset class return [] @@ -130,14 +128,14 @@ def update_labels(self: Any, *args: Any, **kwargs: Any) -> None: pass def transform_update_label( - self: Any, include_class: Optional[List[int]], label: Dict, *args: Any, **kwargs: Any - ) -> Dict: + self: Any, include_class: Optional[list[int]], label: dict, *args: Any, **kwargs: Any + ) -> dict: """Update labels to include only specified classes. Args: self: PatchedUltralyticsBaseDataset instance. - include_class (List[int], optional): List of classes to include. If None, all classes are included. - label (Dict): Label to update. + include_class (list[int], optional): list of classes to include. If None, all classes are included. + label (dict): Label to update. *args: Additional positional arguments (unused). **kwargs: Additional keyword arguments (unused). """ @@ -159,7 +157,7 @@ def transform_update_label( return label - def __getitem__(self: Any, index: int) -> Dict[str, Any]: + def __getitem__(self: Any, index: int) -> dict[str, Any]: """Return transformed label information for given index.""" # return self.transforms(self.get_image_and_label(index)) if not hasattr(self, "streaming_dataset"): @@ -192,12 +190,12 @@ def __getitem__(self: Any, index: int) -> Dict[str, Any]: ) return self.transforms(data) - def patch_detection_plot_training_samples(self: Any, batch: Dict[str, Any], ni: int) -> None: + def patch_detection_plot_training_samples(self: Any, batch: dict[str, Any], ni: int) -> None: """Plot training samples with their annotations. Args: self: DetectionTrainer instance. - batch (Dict[str, Any]): Dictionary containing batch data. + batch (dict[str, Any]): dictionary containing batch data. ni (int): Number of iterations. """ plot_images( @@ -207,12 +205,12 @@ def patch_detection_plot_training_samples(self: Any, batch: Dict[str, Any], ni: on_plot=self.on_plot, ) - def patch_detection_plot_val_samples(self: Any, batch: Dict[str, Any], ni: int) -> None: + def patch_detection_plot_val_samples(self: Any, batch: dict[str, Any], ni: int) -> None: """Plot validation image samples. Args: self: DetectionValidator instance. - batch (Dict[str, Any]): Batch containing images and annotations. + batch (dict[str, Any]): Batch containing images and annotations. ni (int): Batch index. 
""" plot_images( @@ -224,14 +222,14 @@ def patch_detection_plot_val_samples(self: Any, batch: Dict[str, Any], ni: int) ) def patch_detection_plot_predictions( - self: Any, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int, max_det: Optional[int] = None + self: Any, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: Optional[int] = None ) -> None: """Plot predicted bounding boxes on input images and save the result. Args: self: DetectionValidator instance. - batch (Dict[str, Any]): Batch containing images and annotations. - preds (List[Dict[str, torch.Tensor]]): List of predictions from the model. + batch (dict[str, Any]): Batch containing images and annotations. + preds (list[dict[str, torch.Tensor]]): list of predictions from the model. ni (int): Batch index. max_det (Optional[int]): Maximum number of detections to plot. """ @@ -254,7 +252,7 @@ def patch_detection_plot_predictions( on_plot=self.on_plot, ) - def patch_check_det_dataset(dataset: str, _: bool = True) -> Dict: + def patch_check_det_dataset(dataset: str, _: bool = True) -> dict: if not (isinstance(dataset, str) and dataset.endswith(".yaml") and os.path.isfile(dataset)): raise ValueError("Dataset must be a string ending with '.yaml' and point to a valid file.") @@ -317,7 +315,7 @@ def patch_detection_plot_training_labels(self: Any) -> None: """Create a labeled training plot of the YOLO model.""" pass - def patch_get_labels(self: Any) -> List[Dict[str, Any]]: + def patch_get_labels(self: Any) -> list[dict[str, Any]]: # this is used to get number of images (ni) in the BaseDataset class return [] @@ -327,7 +325,7 @@ def patch_none_function(*args: Any, **kwargs: Any) -> None: def image_resize( im: Any, imgsz: int, rect_mode: bool = True, augment: bool = True - ) -> Tuple[Any, Tuple[int, int], Tuple[int, int]]: + ) -> tuple[Any, tuple[int, int], tuple[int, int]]: """Resize the image to a fixed size. Args: @@ -337,7 +335,7 @@ def image_resize( augment (bool): If True, data augmentation is applied. Returns: - Tuple[Any, Tuple[int, int], Tuple[int, int]]: Resized image and its original dimensions. + tuple[Any, tuple[int, int], tuple[int, int]]: Resized image and its original dimensions. """ import cv2 @@ -369,18 +367,29 @@ def patch_compose_transform_call(self: Any, data: Any) -> Any: # ------- helper transformations ------- -def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: +def ultralytics_detection_transform( + data: dict[str, Any], + index: int, + channels: int = 3, + img_size: int = 640, + keypoint: bool = False, + single_cls: bool = False, + lit_args: dict[str, Any] = {}, +) -> dict[str, Any]: """Transform function for YOLO detection datasets. Args: - data (Dict[str, Any]): Input data containing image and label. - kwargs (Dict[str, Any]): Additional keyword arguments, including the index of the data item. + data (dict[str, Any]): Input data containing image and label. + index (int): Index of the data item. + channels (int): Number of channels in the image. + img_size (int): Target size for resizing the image. + keypoint (bool): Whether to include keypoints in the label. + single_cls (bool): Whether to use single class mode. + lit_args (dict[str, Any]): Additional arguments for the transform. Returns: - Dict[str, Any]: Transformed data with image and label. + dict[str, Any]: Transformed data with image and label. 
""" - index = kwargs.get("index") - channels = kwargs.get("channels", 3) # default to 3 channels (RGB) if index is None: raise ValueError("Index must be provided for YOLO detection transform.") @@ -388,17 +397,15 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict # split label on the basis of `\n` and then split each line on the basis of ` ` # first element is class, rest are bbox coordinates if isinstance(label, str): - img, ori_shape, resized_shape = image_resize( - data["img"], imgsz=kwargs.get("img_size", 640), rect_mode=True, augment=True - ) + img, ori_shape, resized_shape = image_resize(data["img"], imgsz=img_size, rect_mode=True, augment=True) ratio_pad = ( resized_shape[0] / ori_shape[0], resized_shape[1] / ori_shape[1], ) # for evaluation - lb, segments, keypoint = parse_labels(label, **kwargs) + lb, segments, keypoint = parse_labels(label, keypoint=keypoint, single_cls=single_cls, lit_args=lit_args) data = { - "batch_idx": np.array([index]), # ← add this! + "batch_idx": np.array([index]), "img": img, "cls": lb[:, 0:1], # n, 1 "bboxes": lb[:, 1:], # n, 4 @@ -417,14 +424,13 @@ def ultralytics_detection_transform(data: Dict[str, Any], **kwargs: Any) -> Dict return data -def parse_labels(labels: str, **kwargs: Any) -> Tuple[Any, Any, Any]: +def parse_labels( + labels: str, keypoint: bool = False, single_cls: bool = False, lit_args: dict[str, Any] = {} +) -> tuple[Any, Any, Any]: from ultralytics.utils.ops import segments2boxes - keypoint = kwargs.get("keypoint", False) - single_cls = kwargs.get("single_cls", False) - data = kwargs["lit_args"] - nkpt, ndim = data.get("kpt_shape", (0, 0)) - num_cls: int = len(data["names"]) + nkpt, ndim = lit_args.get("kpt_shape", (0, 0)) + num_cls: int = len(lit_args["names"]) segments: Any = [] keypoints: Any = None diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 94ec3eba4..f85db3f41 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -10,6 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging import os from time import time @@ -198,11 +199,12 @@ def __init__( self.max_pre_download = max_pre_download if transform is not None: transform = transform if isinstance(transform, list) else [transform] + self.transform_fn_accepts_index = {} for t in transform: if not callable(t): raise ValueError(f"Transform should be a callable. 
Found {t}") + self.transform_fn_accepts_index[id(t)] = has_argument_named_index(t) self.transform = transform - self.transform_kwargs = transform_kwargs or {} self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache @property @@ -444,11 +446,13 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: ) ) if hasattr(self, "transform"): - if isinstance(self.transform, list): - for transform_fn in self.transform: - item = transform_fn(item) - else: - item = self.transform(item) + transforms = self.transform if isinstance(self.transform, list) else [self.transform] + for transform_fn in transforms: + key = id(transform_fn) + if key not in self.transform_fn_accepts_index: + self.transform_fn_accepts_index[key] = has_argument_named_index(transform_fn) + + item = transform_fn(item, index=index) if self.transform_fn_accepts_index[key] else transform_fn(item) return item @@ -739,3 +743,9 @@ def _replay_chunks_sampling( break return chunks_index, indexes + + +def has_argument_named_index(func): + """Returns True if the function has an argument named 'index'.""" + sig = inspect.signature(func) + return "index" in sig.parameters From 9ca3dbd5d1c3e6c266d6127944fc6f1b7bbff9a2 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 15 Jul 2025 14:06:45 +0530 Subject: [PATCH 37/41] update --- src/litdata/integrations/ultralytics/optimize.py | 4 ++-- tests/integrations/ultralytics_support/test_patch.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/litdata/integrations/ultralytics/optimize.py b/src/litdata/integrations/ultralytics/optimize.py index e991ca5a4..1b8c9e124 100644 --- a/src/litdata/integrations/ultralytics/optimize.py +++ b/src/litdata/integrations/ultralytics/optimize.py @@ -13,7 +13,7 @@ import os from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Union import yaml @@ -22,7 +22,7 @@ from litdata.streaming.resolver import Dir, _resolve_dir -def _ultralytics_optimize_fn(img_path: str) -> Optional[Dict]: +def _ultralytics_optimize_fn(img_path: str) -> Optional[dict]: """Optimized function for Ultralytics that reads image + label and optionally re-encodes to reduce size.""" if not img_path.endswith((".jpg", ".jpeg", ".png")): raise ValueError(f"Unsupported image format: {img_path}. 
Supported formats are .jpg, .jpeg, and .png.") diff --git a/tests/integrations/ultralytics_support/test_patch.py b/tests/integrations/ultralytics_support/test_patch.py index 83224d12d..f2c6cc02a 100644 --- a/tests/integrations/ultralytics_support/test_patch.py +++ b/tests/integrations/ultralytics_support/test_patch.py @@ -124,7 +124,9 @@ def test_ultralytics_detection_transform(): # Missing index with pytest.raises(ValueError, match="Index must be provided"): - ultralytics_detection_transform({"img": dummy_img, "label": valid_label}, channels=3, lit_args=lit_args) + ultralytics_detection_transform( + {"img": dummy_img, "label": valid_label}, index=None, channels=3, lit_args=lit_args + ) # Invalid label type with pytest.raises(ValueError, match="Label must be a string"): From d9a319f3b94fb9d1e92572d34f5574c0b140e9f3 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 15 Jul 2025 18:12:42 +0530 Subject: [PATCH 38/41] update --- src/litdata/streaming/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index f85db3f41..33ee4312e 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -197,9 +197,9 @@ def __init__( self.storage_options = storage_options self.session_options = session_options self.max_pre_download = max_pre_download + self.transform_fn_accepts_index = {} if transform is not None: transform = transform if isinstance(transform, list) else [transform] - self.transform_fn_accepts_index = {} for t in transform: if not callable(t): raise ValueError(f"Transform should be a callable. Found {t}") From bc760b2e82aa9f3f4940f5ce37b48f98e9b7ea71 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 15 Jul 2025 20:07:32 +0530 Subject: [PATCH 39/41] update --- src/litdata/streaming/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 33ee4312e..926713f3a 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -745,7 +745,7 @@ def _replay_chunks_sampling( return chunks_index, indexes -def has_argument_named_index(func): +def has_argument_named_index(func) -> bool: """Returns True if the function has an argument named 'index'.""" sig = inspect.signature(func) return "index" in sig.parameters From 928d6a35e1ef6be265dcc7b7df0b271c06b1a131 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Wed, 16 Jul 2025 01:06:30 +0530 Subject: [PATCH 40/41] update --- src/litdata/streaming/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 926713f3a..a04bc127d 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -745,7 +745,7 @@ def _replay_chunks_sampling( return chunks_index, indexes -def has_argument_named_index(func) -> bool: +def has_argument_named_index(func: Callable) -> bool: """Returns True if the function has an argument named 'index'.""" sig = inspect.signature(func) return "index" in sig.parameters From eeb25fd9f2eed3704f3558e650cfb655d25efd23 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Thu, 17 Jul 2025 15:35:35 +0530 Subject: [PATCH 41/41] Refactor image optimization function to accept customizable image quality parameter --- src/litdata/integrations/ultralytics/optimize.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/litdata/integrations/ultralytics/optimize.py 
b/src/litdata/integrations/ultralytics/optimize.py index 1b8c9e124..235614956 100644 --- a/src/litdata/integrations/ultralytics/optimize.py +++ b/src/litdata/integrations/ultralytics/optimize.py @@ -12,6 +12,7 @@ # limitations under the License. import os +from functools import partial from pathlib import Path from typing import Optional, Union @@ -22,7 +23,7 @@ from litdata.streaming.resolver import Dir, _resolve_dir -def _ultralytics_optimize_fn(img_path: str) -> Optional[dict]: +def _ultralytics_optimize_fn(img_path: str, img_quality: int) -> Optional[dict]: """Optimized function for Ultralytics that reads image + label and optionally re-encodes to reduce size.""" if not img_path.endswith((".jpg", ".jpeg", ".png")): raise ValueError(f"Unsupported image format: {img_path}. Supported formats are .jpg, .jpeg, and .png.") @@ -38,8 +39,8 @@ def _ultralytics_optimize_fn(img_path: str) -> Optional[dict]: # JPEG re-encode if image is jpeg or png if img_ext in [".jpg", ".jpeg", ".png"]: - # Reduce quality to 90% - encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 90] + # Reduce quality to specified value of img_quality + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), img_quality] success, encoded = cv2.imencode(".jpg", img, encode_param) if not success: raise ValueError(f"JPEG encoding failed for: {img_path}") @@ -68,6 +69,7 @@ def optimize_ultralytics_dataset( chunk_size: Optional[int] = None, chunk_bytes: Optional[Union[int, str]] = None, num_workers: int = 1, + img_quality: int = 90, verbose: bool = False, ) -> None: """Optimize an Ultralytics dataset by converting it into chunks and resizing images. @@ -78,6 +80,7 @@ def optimize_ultralytics_dataset( chunk_size: Number of samples per chunk. If None, no chunking is applied. chunk_bytes: Maximum size of each chunk in bytes. If None, no size limit is applied. num_workers: Number of worker processes to use for optimization. Defaults to 1. + img_quality: Quality of the JPEG images after optimization (0-100). Defaults to 90. verbose: Whether to print progress messages. Defaults to False. """ if not _ULTRALYTICS_AVAILABLE: @@ -115,7 +118,7 @@ def optimize_ultralytics_dataset( inputs = list_all_files(dataset_config[mode]) optimize( - fn=_ultralytics_optimize_fn, + fn=partial(_ultralytics_optimize_fn, img_quality=img_quality), inputs=inputs, output_dir=mode_output_dir.url or mode_output_dir.path or "optimized_data", chunk_bytes=chunk_bytes,
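
The final patch threads `img_quality` into the per-item function with `functools.partial`, since `optimize` only passes each input item to `fn`. Below is a minimal usage sketch of that pattern; the `compress_image` helper, the local folder, and the output directory are illustrative assumptions, while `optimize` and its `fn`, `inputs`, `output_dir`, `chunk_bytes`, and `num_workers` arguments match those exercised in the tests earlier in this series.

```python
# Sketch only: `compress_image`, the input folder, and the output directory are
# hypothetical; the partial(...) binding mirrors how the patch forwards
# `img_quality` into the function handed to `optimize`.
import os
from functools import partial

import cv2
from litdata import optimize


def compress_image(img_path: str, img_quality: int) -> dict:
    # Read the image and re-encode it as JPEG at the requested quality,
    # the same re-encoding step used in `_ultralytics_optimize_fn`.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    ok, encoded = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), img_quality])
    if not ok:
        raise ValueError(f"JPEG encoding failed for: {img_path}")
    return {"img": encoded.tobytes(), "path": os.path.basename(img_path)}


if __name__ == "__main__":
    input_dir = "my_large_images"  # hypothetical local folder of images
    inputs = [os.path.join(input_dir, f) for f in os.listdir(input_dir)]

    optimize(
        fn=partial(compress_image, img_quality=75),  # bind the extra argument up front
        inputs=inputs,
        output_dir="optimized_images",
        chunk_bytes="64MB",
        num_workers=1,
    )
```

Binding the extra argument up front keeps the per-item function a plain, picklable callable that can be shipped to the worker processes, while callers can still tune the quality per run.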