diff --git a/README.md b/README.md index 1dbe637c5..5e8f837c5 100644 --- a/README.md +++ b/README.md @@ -1554,6 +1554,100 @@ map(   +## Ultralytics (YOLO) Integration + +
+ ✅ Stream Large Datasets to Ultralytics Models with LitData + +  + +This feature enables **training Ultralytics models (like YOLO)** directly from **LitData’s optimized streaming datasets**. Now you can train on massive datasets (e.g., 500GB+) without downloading everything to disk — just stream from **S3, GCS, local paths, or HTTP(S)** with minimal overhead. + +--- + +### 🔧 How It Works + +#### **Step 1: Optimize Your Dataset (One-time Step)** + +Convert your existing Ultralytics-style dataset into an optimized streaming format: + +```python +from litdata.integrations.ultralytics import optimize_ultralytics_dataset + +optimize_ultralytics_dataset( + "coco128.yaml", # Original dataset config + "s3://some-bucket/optimized-data", # Cloud path or local directory + num_workers=4, # Number of concurrent workers + chunk_bytes="64MB", # Chunk size for streaming +) +``` + +This generates an optimized dataset and creates a new `litdata_coco128.yaml` config to use for training. + +--- + +#### **Step 2: Patch Ultralytics for Streaming** + +Before training, patch Ultralytics internals to enable LitData streaming: + +```python +from litdata.integrations.ultralytics import patch_ultralytics + +patch_ultralytics() +``` + +--- + +#### **Step 3: Train Like Usual — But Now From the Cloud ☁️** + +```python +from litdata.integrations.ultralytics import patch_ultralytics + +patch_ultralytics() + +# ------- + +from ultralytics import YOLO + +patch_ultralytics() + +model = YOLO("yolo11n.pt") +model.train(data="litdata_coco128.yaml", epochs=100, imgsz=640) +``` + +That’s it — Ultralytics now streams your data via LitData under the hood! + +--- + +### ✅ Benefits + +* 🔁 **Stream datasets of any size** — no need to fully download. +* 💾 **Save disk space** — only minimal local caching is used. +* 🧪 **Benchmark-tested** — supports both local and cloud training. +* 🧩 **Plug-and-play with Ultralytics** — zero training code changes. +* ☁️ Supports **S3, GCS, HTTP(S), and local disk** out-of-the-box. +* ✅ **Minimal code changes** to existing Ultralytics training scripts + +--- + +### 📊 Benchmarks (Lightning Studio L4 GPU) + +- On local machine: +Screenshot 2025-07-09 at 10 14 27 AM + +- On lightning studio (L4 GPU machine) +Screenshot 2025-07-11 at 12 50 11 AM + +While the performance gains aren't drastic (due to Ultralytics caching internally), this integration **unlocks all the benefits of streaming** and enables training on large-scale datasets from the cloud. + +Instead of downloading entire datasets (which can be hundreds of GBs), you can now **stream data on-the-fly from S3, GCS, HTTP(S), or even local disk** — making it ideal for training in the cloud with limited storage and more efficient utilization of resources. + +We’re also exploring a **custom LitData dataloader** built from scratch (potentially breaking GIL using Rust or multithreading). If it outperforms `torch.DataLoader`, future benchmarks could reflect significant performance boosts. 💡 + +
+ +  + ---- # Benchmarks diff --git a/requirements/test.txt b/requirements/test.txt index 5a4fdea33..65c0002dc 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -16,3 +16,4 @@ transformers <4.53.0 zstd s5cmd >=0.2.0 soundfile >=0.13.0 # required for torchaudio backend +ultralytics >=8.3.16 diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 35d9797fe..bfc1610c8 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -44,6 +44,7 @@ _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") _AV_AVAILABLE = RequirementCache("av") +_ULTRALYTICS_AVAILABLE = RequirementCache("ultralytics") _DEBUG = bool(int(os.getenv("DEBUG_LITDATA", "0"))) _PRINT_DEBUG_LOGS = bool(int(os.getenv("PRINT_DEBUG_LOGS", "0"))) diff --git a/src/litdata/integrations/__init__.py b/src/litdata/integrations/__init__.py new file mode 100644 index 000000000..27efc0815 --- /dev/null +++ b/src/litdata/integrations/__init__.py @@ -0,0 +1,12 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/litdata/integrations/ultralytics/__init__.py b/src/litdata/integrations/ultralytics/__init__.py new file mode 100644 index 000000000..f9a4fae1b --- /dev/null +++ b/src/litdata/integrations/ultralytics/__init__.py @@ -0,0 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from litdata.integrations.ultralytics.optimize import optimize_ultralytics_dataset +from litdata.integrations.ultralytics.patch import patch_ultralytics + +__all__ = ["optimize_ultralytics_dataset", "patch_ultralytics"] diff --git a/src/litdata/integrations/ultralytics/optimize.py b/src/litdata/integrations/ultralytics/optimize.py new file mode 100644 index 000000000..235614956 --- /dev/null +++ b/src/litdata/integrations/ultralytics/optimize.py @@ -0,0 +1,178 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import partial +from pathlib import Path +from typing import Optional, Union + +import yaml + +from litdata.constants import _PIL_AVAILABLE, _ULTRALYTICS_AVAILABLE +from litdata.processing.functions import optimize +from litdata.streaming.resolver import Dir, _resolve_dir + + +def _ultralytics_optimize_fn(img_path: str, img_quality: int) -> Optional[dict]: + """Optimized function for Ultralytics that reads image + label and optionally re-encodes to reduce size.""" + if not img_path.endswith((".jpg", ".jpeg", ".png")): + raise ValueError(f"Unsupported image format: {img_path}. Supported formats are .jpg, .jpeg, and .png.") + + import cv2 + + img_ext = os.path.splitext(img_path)[-1].lower() + + # Read image using OpenCV + img = cv2.imread(img_path, cv2.IMREAD_COLOR) + if img is None: + raise ValueError(f"Failed to read image: {img_path}") + + # JPEG re-encode if image is jpeg or png + if img_ext in [".jpg", ".jpeg", ".png"]: + # Reduce quality to specified value of img_quality + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), img_quality] + success, encoded = cv2.imencode(".jpg", img, encode_param) + if not success: + raise ValueError(f"JPEG encoding failed for: {img_path}") + + # Decode it back to a numpy array (OpenCV default format) + img = cv2.imdecode(encoded, cv2.IMREAD_COLOR) + + # Load the label + label = "" + label_path = img_path.replace("images", "labels").replace(img_ext, ".txt") + if os.path.isfile(label_path): + with open(label_path) as f: + label = f.read().strip() + else: + return None # skip this sample + + return { + "img": img, + "label": label, + } + + +def optimize_ultralytics_dataset( + yaml_path: str, + output_dir: str, + chunk_size: Optional[int] = None, + chunk_bytes: Optional[Union[int, str]] = None, + num_workers: int = 1, + img_quality: int = 90, + verbose: bool = False, +) -> None: + """Optimize an Ultralytics dataset by converting it into chunks and resizing images. + + Args: + yaml_path: Path to the dataset YAML file. + output_dir: Directory where the optimized dataset will be saved. + chunk_size: Number of samples per chunk. If None, no chunking is applied. + chunk_bytes: Maximum size of each chunk in bytes. If None, no size limit is applied. + num_workers: Number of worker processes to use for optimization. Defaults to 1. + img_quality: Quality of the JPEG images after optimization (0-100). Defaults to 90. + verbose: Whether to print progress messages. Defaults to False. + """ + if not _ULTRALYTICS_AVAILABLE: + raise ImportError( + "Ultralytics is not installed. Please install it with `pip install ultralytics` to use this function." + ) + if not _PIL_AVAILABLE: + raise ImportError("PIL is not installed. Please install it with `pip install pillow` to use this function.") + + # check if the YAML file exists and is a file + if not os.path.isfile(yaml_path): + raise FileNotFoundError(f"YAML file not found: {yaml_path}") + + if chunk_bytes is None and chunk_size is None: + raise ValueError("Either chunk_bytes or chunk_size must be specified.") + + if chunk_bytes is not None and chunk_size is not None: + raise ValueError("Only one of chunk_bytes or chunk_size should be specified, not both.") + + from ultralytics.data.utils import check_det_dataset + + # parse the YAML file & make sure data exists, else download it + dataset_config = check_det_dataset(yaml_path) + + output_dir = _resolve_dir(output_dir) + + mode_to_dir = {} + + for mode in ("train", "val", "test"): + if dataset_config[mode] is None: + continue + if not os.path.exists(dataset_config[mode]): + raise FileNotFoundError(f"Dataset directory not found for {mode}: {dataset_config[mode]}") + mode_output_dir = get_output_dir(output_dir, mode) + inputs = list_all_files(dataset_config[mode]) + + optimize( + fn=partial(_ultralytics_optimize_fn, img_quality=img_quality), + inputs=inputs, + output_dir=mode_output_dir.url or mode_output_dir.path or "optimized_data", + chunk_bytes=chunk_bytes, + chunk_size=chunk_size, + num_workers=num_workers, + mode="overwrite", + verbose=verbose, + ) + + mode_to_dir[mode] = mode_output_dir + print(f"Optimized {mode} dataset and saved to {mode_output_dir} ✅") + + # update the YAML file with the new paths + for mode, dir in mode_to_dir.items(): + if mode in dataset_config: + dataset_config[mode] = dir.url if dir.url else dir.path + else: + raise ValueError(f"Mode '{mode}' not found in dataset configuration.") + + # convert path to string if it's a Path object + for key, value in dataset_config.items(): + if isinstance(value, Path): + dataset_config[key] = str(value) + + # save the updated YAML file + output_yaml = Path(yaml_path).with_name("litdata_" + Path(yaml_path).name) + with open(output_yaml, "w") as f: + yaml.dump(dataset_config, f) + + +def get_output_dir(output_dir: Dir, mode: str) -> Dir: + if not isinstance(output_dir, Dir): + raise TypeError(f"Expected output_dir to be of type Dir, got {type(output_dir)} instead.") + url, path = output_dir.url, output_dir.path + if url is not None: + url = url.rstrip("/") + f"/{mode}" + if path is not None: + path = os.path.join(path, f"{mode}") + + return Dir(url=url, path=path) + + +def list_all_files(_path: str) -> list[str]: + path = Path(_path) + + if path.is_dir(): + # Recursively list all files under the directory + return [str(p) for p in path.rglob("*") if p.is_file()] + + if path.is_file() and path.suffix == ".txt": + # Read lines and return cleaned-up paths + base_dir = path.parent # use the parent of the txt file to resolve relative paths + with open(path) as f: + return [str((base_dir / line.strip()).resolve()) for line in f if line.strip()] + + else: + raise ValueError(f"Unsupported path: {path}") diff --git a/src/litdata/integrations/ultralytics/patch.py b/src/litdata/integrations/ultralytics/patch.py new file mode 100644 index 000000000..39f3f0485 --- /dev/null +++ b/src/litdata/integrations/ultralytics/patch.py @@ -0,0 +1,476 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os +from functools import partial +from typing import Any, Optional, Union + +import numpy as np +import torch + +from litdata.constants import _ULTRALYTICS_AVAILABLE +from litdata.streaming.dataloader import StreamingDataLoader +from litdata.streaming.dataset import StreamingDataset + + +def patch_ultralytics() -> None: + """Patch Ultralytics to use the LitData optimize function.""" + if not _ULTRALYTICS_AVAILABLE: + raise ImportError("Ultralytics is not available. Please install it to use this functionality.") + + from ultralytics.data.utils import check_det_dataset + + check_det_dataset.__code__ = patch_check_det_dataset.__code__ + + from ultralytics.models.yolo.detect.train import DetectionTrainer + + DetectionTrainer.plot_training_samples = patch_detection_plot_training_samples + DetectionTrainer.plot_training_labels = patch_none_function + + from ultralytics.models.yolo.detect.val import DetectionValidator + + DetectionValidator.plot_val_samples = patch_detection_plot_val_samples + DetectionValidator.plot_predictions = patch_detection_plot_predictions + + import ultralytics.data.base as base_module + import ultralytics.data.dataset as child_modules + + base_module.BaseDataset = PatchedUltralyticsBaseDataset + base_module.BaseDataset.set_rectangle = patch_none_function + child_modules.YOLODataset.__bases__ = (PatchedUltralyticsBaseDataset,) + child_modules.YOLODataset.get_labels = patch_get_labels + + from ultralytics.data.build import build_dataloader + + build_dataloader.__code__ = patch_build_dataloader.__code__ + + from ultralytics.data.augment import Compose + + Compose.__call__.__code__ = patch_compose_transform_call.__code__ + + print("✅ Ultralytics successfully patched to use LitData.") + + +if _ULTRALYTICS_AVAILABLE: + from ultralytics.data.base import BaseDataset as UltralyticsBaseDataset + from ultralytics.utils.plotting import plot_images + + class PatchedUltralyticsBaseDataset(UltralyticsBaseDataset): + def __init__(self: Any, img_path: str, classes: Optional[list[int]] = None, *args: Any, **kwargs: Any): + print("patched ultralytics dataset: 🔥") + self.litdata_dataset = img_path + self.classes = classes + super().__init__(img_path, classes=classes, *args, **kwargs) + self.streaming_dataset = StreamingDataset( + img_path, + transform=[ + partial( + ultralytics_detection_transform, + img_size=self.imgsz, + channels=self.channels, + keypoint=self.use_keypoints, + single_cls=self.single_cls, + lit_args=self.data, + ), + partial(self.transform_update_label, classes), + self.update_labels_info, + self.transforms, + ], + ) + self.ni = len(self.streaming_dataset) + self.buffer = list(range(len(self.streaming_dataset))) + + def __len__(self: Any) -> int: + """Return the length of the dataset.""" + return len(self.streaming_dataset) + + def get_image_and_label(self: Any, index: int) -> None: + # Your custom logic to load from .litdata + # e.g. use `self.litdata_dataset[index]` + raise NotImplementedError("Custom logic here") + + def get_img_files(self: Any, img_path: Union[str, list[str]]) -> list[str]: + """Let this method return an empty list to avoid errors.""" + return [] + + def get_labels(self: Any) -> list[dict[str, Any]]: + # this is used to get number of images (ni) in the BaseDataset class + return [] + + def cache_images(self: Any) -> None: + pass + + def cache_images_to_disk(self: Any, i: int) -> None: + pass + + def check_cache_disk(self: Any, safety_margin: float = 0.5) -> bool: + """Check if the cache disk is available.""" + # This method is not used in the streaming dataset, so we can return True + return True + + def check_cache_ram(self: Any, safety_margin: float = 0.5) -> bool: + """Check if the cache RAM is available.""" + # This method is not used in the streaming dataset, so we can return True + return True + + def update_labels(self: Any, *args: Any, **kwargs: Any) -> None: + """Do nothing, we will update labels when item is fetched in transform.""" + pass + + def transform_update_label( + self: Any, include_class: Optional[list[int]], label: dict, *args: Any, **kwargs: Any + ) -> dict: + """Update labels to include only specified classes. + + Args: + self: PatchedUltralyticsBaseDataset instance. + include_class (list[int], optional): list of classes to include. If None, all classes are included. + label (dict): Label to update. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). + """ + include_class_array = np.array(include_class).reshape(1, -1) + if include_class is not None: + cls = label["cls"] + bboxes = label["bboxes"] + segments = label["segments"] + keypoints = label["keypoints"] + j = (cls == include_class_array).any(1) + label["cls"] = cls[j] + label["bboxes"] = bboxes[j] + if segments: + label["segments"] = [segments[si] for si, idx in enumerate(j) if idx] + if keypoints is not None: + label["keypoints"] = keypoints[j] + if self.single_cls: + label["cls"][:, 0] = 0 + + return label + + def __getitem__(self: Any, index: int) -> dict[str, Any]: + """Return transformed label information for given index.""" + # return self.transforms(self.get_image_and_label(index)) + if not hasattr(self, "streaming_dataset"): + raise ValueError("The dataset must have a 'streaming_dataset' attribute.") + data = self.streaming_dataset[index] + + label = data["label"] + # split label on the basis of `\n` and then split each line on the basis of ` ` + # first element is class, rest are bbox coordinates + if isinstance(label, str): + label = label.split("\n") + label = [line.split(" ") for line in label if line.strip()] + + data = { + "batch_idx": torch.tensor([index], dtype=torch.int32), + "img": data["image"], + "cls": torch.Tensor([int(line[0]) for line in label]), + "bboxes": torch.Tensor([[float(coord) for coord in line[1:]] for line in label]), + "normalized": True, + "segments": [], + "keypoints": None, + "bbox_format": "xywh", + } + else: + raise ValueError("Label must be a string in YOLO format.") + + data = self.transform_update_label( + include_class=self.classes, + label=data, + ) + return self.transforms(data) + + def patch_detection_plot_training_samples(self: Any, batch: dict[str, Any], ni: int) -> None: + """Plot training samples with their annotations. + + Args: + self: DetectionTrainer instance. + batch (dict[str, Any]): dictionary containing batch data. + ni (int): Number of iterations. + """ + plot_images( + labels=batch, + images=batch["img"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) + + def patch_detection_plot_val_samples(self: Any, batch: dict[str, Any], ni: int) -> None: + """Plot validation image samples. + + Args: + self: DetectionValidator instance. + batch (dict[str, Any]): Batch containing images and annotations. + ni (int): Batch index. + """ + plot_images( + labels=batch, + images=batch["img"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def patch_detection_plot_predictions( + self: Any, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: Optional[int] = None + ) -> None: + """Plot predicted bounding boxes on input images and save the result. + + Args: + self: DetectionValidator instance. + batch (dict[str, Any]): Batch containing images and annotations. + preds (list[dict[str, torch.Tensor]]): list of predictions from the model. + ni (int): Batch index. + max_det (Optional[int]): Maximum number of detections to plot. + """ + from ultralytics.utils import ops + + # TODO: optimize this + for i, pred in enumerate(preds): + pred["batch_idx"] = torch.ones_like(pred["conf"]) * i # add batch index to predictions + keys = preds[0].keys() + max_det = max_det or self.args.max_det + batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys} + # TODO: fix this + batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4]) # convert to xywh format + plot_images( + images=batch["img"], + labels=batched_preds, + paths=None, + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def patch_check_det_dataset(dataset: str, _: bool = True) -> dict: + if not (isinstance(dataset, str) and dataset.endswith(".yaml") and os.path.isfile(dataset)): + raise ValueError("Dataset must be a string ending with '.yaml' and point to a valid file.") + + import yaml + + if not dataset.startswith("litdata_"): + dataset = "litdata_" + dataset + + if not os.path.isfile(dataset): + raise FileNotFoundError(f"Dataset file not found: {dataset}") + + # read the yaml file + with open(dataset) as file: + return yaml.safe_load(file) + + def patch_build_dataloader( + dataset: Any, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False + ) -> StreamingDataLoader: + """Create and return an InfiniteDataLoader or DataLoader for training or validation. + + Args: + dataset (Dataset): Dataset to load data from. + batch (int): Batch size for the dataloader. + workers (int): Number of worker threads for loading data. + shuffle (bool, optional): Whether to shuffle the dataset. + rank (int, optional): Process rank in distributed training. -1 for single-GPU training. + drop_last (bool, optional): Whether to drop the last incomplete batch. + + Returns: + (StreamingDataLoader): A dataloader that can be used for training or validation. + + Examples: + Create a dataloader for training + >>> dataset = YOLODataset(...) + >>> dataloader = build_dataloader(dataset, batch=16, workers=4, shuffle=True) + """ + from litdata.streaming.dataloader import StreamingDataLoader + + print("litdata is rocking⚡️") + if not hasattr(dataset, "streaming_dataset"): + raise ValueError("The dataset must have a 'streaming_dataset' attribute.") + + from ultralytics.data.utils import PIN_MEMORY + + batch = min(batch, len(dataset)) + num_devices = torch.cuda.device_count() # number of CUDA devices + num_workers = min((os.cpu_count() or 1) // max(num_devices, 1), workers) # number of workers + persistent_workers = bool(int(os.getenv("UL_PERSISTENT_WORKERS", 0))) + return StreamingDataLoader( + dataset=dataset.streaming_dataset, + batch_size=batch, + num_workers=num_workers, + persistent_workers=persistent_workers, + pin_memory=PIN_MEMORY, + collate_fn=getattr(dataset, "collate_fn", None), + drop_last=drop_last, + ) + + def patch_detection_plot_training_labels(self: Any) -> None: + """Create a labeled training plot of the YOLO model.""" + pass + + def patch_get_labels(self: Any) -> list[dict[str, Any]]: + # this is used to get number of images (ni) in the BaseDataset class + return [] + + def patch_none_function(*args: Any, **kwargs: Any) -> None: + """A placeholder function that does nothing.""" + pass + + def image_resize( + im: Any, imgsz: int, rect_mode: bool = True, augment: bool = True + ) -> tuple[Any, tuple[int, int], tuple[int, int]]: + """Resize the image to a fixed size. + + Args: + im (Any): Image to resize. + imgsz (int): Target size for resizing. + rect_mode (bool): Whether to use rectangle mode for resizing. + augment (bool): If True, data augmentation is applied. + + Returns: + tuple[Any, tuple[int, int], tuple[int, int]]: Resized image and its original dimensions. + """ + import cv2 + + # Custom logic for resizing the image + h0, w0 = im.shape[:2] # orig hw + if rect_mode: # resize long side to imgsz while maintaining aspect ratio + r = imgsz / max(h0, w0) # ratio + if r != 1: # if sizes are not equal + w, h = (min(math.ceil(w0 * r), imgsz), min(math.ceil(h0 * r), imgsz)) + im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) + elif not (h0 == w0 == imgsz): # resize by stretching image to square imgsz + im = cv2.resize(im, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR) + if im.ndim == 2: + im = im[..., None] + + return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized + + def patch_compose_transform_call(self: Any, data: Any) -> Any: + """Apply all transforms to the data, skipping mix transforms.""" + from ultralytics.data.augment import BaseMixTransform + + for t in self.transforms: + if isinstance(t, BaseMixTransform): + continue # Skip mix transforms, they are applied separately + data = t(data) + return data + + +# ------- helper transformations ------- + + +def ultralytics_detection_transform( + data: dict[str, Any], + index: int, + channels: int = 3, + img_size: int = 640, + keypoint: bool = False, + single_cls: bool = False, + lit_args: dict[str, Any] = {}, +) -> dict[str, Any]: + """Transform function for YOLO detection datasets. + + Args: + data (dict[str, Any]): Input data containing image and label. + index (int): Index of the data item. + channels (int): Number of channels in the image. + img_size (int): Target size for resizing the image. + keypoint (bool): Whether to include keypoints in the label. + single_cls (bool): Whether to use single class mode. + lit_args (dict[str, Any]): Additional arguments for the transform. + + Returns: + dict[str, Any]: Transformed data with image and label. + """ + if index is None: + raise ValueError("Index must be provided for YOLO detection transform.") + + label = data["label"] + # split label on the basis of `\n` and then split each line on the basis of ` ` + # first element is class, rest are bbox coordinates + if isinstance(label, str): + img, ori_shape, resized_shape = image_resize(data["img"], imgsz=img_size, rect_mode=True, augment=True) + ratio_pad = ( + resized_shape[0] / ori_shape[0], + resized_shape[1] / ori_shape[1], + ) # for evaluation + lb, segments, keypoint = parse_labels(label, keypoint=keypoint, single_cls=single_cls, lit_args=lit_args) + + data = { + "batch_idx": np.array([index]), + "img": img, + "cls": lb[:, 0:1], # n, 1 + "bboxes": lb[:, 1:], # n, 4 + "segments": segments, + "keypoints": keypoint, + "normalized": True, + "bbox_format": "xywh", + "ori_shape": ori_shape, + "resized_shape": resized_shape, + "ratio_pad": ratio_pad, + "channels": channels, + } + else: + raise ValueError("Label must be a string in YOLO format.") + + return data + + +def parse_labels( + labels: str, keypoint: bool = False, single_cls: bool = False, lit_args: dict[str, Any] = {} +) -> tuple[Any, Any, Any]: + from ultralytics.utils.ops import segments2boxes + + nkpt, ndim = lit_args.get("kpt_shape", (0, 0)) + num_cls: int = len(lit_args["names"]) + + segments: Any = [] + keypoints: Any = None + + lb = [x.split() for x in labels.split("\n") if len(x)] + if any(len(x) > 6 for x in lb) and (not keypoint): # is segment + classes = np.array([x[0] for x in lb], dtype=np.float32) + segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...) + lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh) + lb = np.array(lb, dtype=np.float32) + if nl := len(lb): + if keypoint: + assert lb.shape[1] == (5 + nkpt * ndim), ( + f"labels require {(5 + nkpt * ndim)} columns each, but {lb.shape[1]} columns detected" + ) + points = lb[:, 5:].reshape(-1, ndim)[:, :2] + else: + assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" + points = lb[:, 1:] + # Coordinate points check with 1% tolerance + assert points.max() <= 1.01, f"non-normalized or out of bounds coordinates {points[points > 1.01]}" + assert lb.min() >= -0.01, f"negative class labels {lb[lb < -0.01]}" + + # All labels + if single_cls: + lb[:, 0] = 0 + max_cls = lb[:, 0].max() # max label count + assert max_cls < num_cls, ( + f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. " + f"Possible class labels are 0-{num_cls - 1}" + ) + _, i = np.unique(lb, axis=0, return_index=True) + if len(i) < nl: # duplicate row check + lb = lb[i] # remove duplicates + if segments: + segments = [segments[x] for x in i] + if keypoint: + keypoints = lb[:, 5:].reshape(-1, nkpt, ndim) + if ndim == 2: + kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32) + keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3) + lb = lb[:, :5] + return lb, segments, keypoints diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 7b55172d1..a04bc127d 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -10,7 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect import logging import os from time import time @@ -197,11 +197,13 @@ def __init__( self.storage_options = storage_options self.session_options = session_options self.max_pre_download = max_pre_download + self.transform_fn_accepts_index = {} if transform is not None: transform = transform if isinstance(transform, list) else [transform] for t in transform: if not callable(t): raise ValueError(f"Transform should be a callable. Found {t}") + self.transform_fn_accepts_index[id(t)] = has_argument_named_index(t) self.transform = transform self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache @@ -444,11 +446,13 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: ) ) if hasattr(self, "transform"): - if isinstance(self.transform, list): - for transform_fn in self.transform: - item = transform_fn(item) - else: - item = self.transform(item) + transforms = self.transform if isinstance(self.transform, list) else [self.transform] + for transform_fn in transforms: + key = id(transform_fn) + if key not in self.transform_fn_accepts_index: + self.transform_fn_accepts_index[key] = has_argument_named_index(transform_fn) + + item = transform_fn(item, index=index) if self.transform_fn_accepts_index[key] else transform_fn(item) return item @@ -739,3 +743,9 @@ def _replay_chunks_sampling( break return chunks_index, indexes + + +def has_argument_named_index(func: Callable) -> bool: + """Returns True if the function has an argument named 'index'.""" + sig = inspect.signature(func) + return "index" in sig.parameters diff --git a/tests/conftest.py b/tests/conftest.py index f306a78ec..f745f5087 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -202,6 +202,37 @@ def huggingface_hub_mock(monkeypatch): return huggingface_hub +@pytest.fixture +def mock_ultralytics(monkeypatch, tmp_path): + fake_ultralytics = ModuleType("ultralytics") + fake_data = ModuleType("ultralytics.data") + fake_utils = ModuleType("ultralytics.data.utils") + + # Create fake dataset paths using tmp_path (safe on all OSes) + train_dir = tmp_path / "train" + val_dir = tmp_path / "val" + train_dir.mkdir() + val_dir.mkdir() + + def check_det_dataset(yaml_path): + return { + "train": str(train_dir), + "val": str(val_dir), + "test": None, + "names": [], + } + + fake_utils.check_det_dataset = check_det_dataset + + # Register in sys.modules + monkeypatch.setitem(sys.modules, "ultralytics", fake_ultralytics) + monkeypatch.setitem(sys.modules, "ultralytics.data", fake_data) + monkeypatch.setitem(sys.modules, "ultralytics.data.utils", fake_utils) + + fake_data.utils = fake_utils + fake_ultralytics.data = fake_data + + @pytest.fixture def huggingface_hub_fs_mock(monkeypatch, write_pq_data, tmp_path): huggingface_hub = ModuleType("huggingface_hub") diff --git a/tests/integrations/ultralytics_support/__init__.py b/tests/integrations/ultralytics_support/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integrations/ultralytics_support/test_optimize.py b/tests/integrations/ultralytics_support/test_optimize.py new file mode 100644 index 000000000..18182d4b4 --- /dev/null +++ b/tests/integrations/ultralytics_support/test_optimize.py @@ -0,0 +1,101 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock + +import pytest + +from litdata.integrations.ultralytics.optimize import get_output_dir, list_all_files, optimize_ultralytics_dataset +from litdata.streaming.resolver import Dir + + +@mock.patch("litdata.integrations.ultralytics.optimize._ULTRALYTICS_AVAILABLE", True) +@mock.patch("litdata.integrations.ultralytics.optimize.optimize") +def test_optimize_ultralytics_dataset_mocked(optimize_mock, tmp_path, mock_ultralytics): + os.makedirs(tmp_path / "images/train", exist_ok=True) + os.makedirs(tmp_path / "images/val", exist_ok=True) + for split in ["train", "val"]: + (tmp_path / f"images/{split}/img1.jpg").touch() + (tmp_path / f"labels/{split}/img1.txt").parent.mkdir(parents=True, exist_ok=True) + (tmp_path / f"labels/{split}/img1.txt").write_text("0 0.5 0.5 0.2 0.2") + + yaml_file = tmp_path / "coco8.yaml" + yaml_file.write_text( + "path: {}\ntrain: {}\nval: {}\n".format(tmp_path, tmp_path / "images" / "train", tmp_path / "images" / "val") + ) + + optimize_ultralytics_dataset(str(yaml_file), str(tmp_path / "out"), chunk_size=1, num_workers=1) + + assert optimize_mock.called + assert (tmp_path / "litdata_coco8.yaml").exists() + + +def test_get_output_dir(): + # Case 1: Both url and path provided + d1 = Dir(path="/data/output", url="s3://bucket/output/") + r1 = get_output_dir(d1, "train") + assert r1.path == os.path.join("/data/output", "train") + assert r1.url == "s3://bucket/output/train" + + # Case 2: Only url provided + d2 = Dir(url="s3://bucket/output/") + r2 = get_output_dir(d2, "val") + assert r2.path is None + assert r2.url == "s3://bucket/output/val" + + # Case 3: Only path provided + d3 = Dir(path="/data/output") + r3 = get_output_dir(d3, "test") + assert r3.url is None + assert r3.path == os.path.join("/data/output", "test") + + # Case 4: Neither url nor path provided + d4 = Dir() + r4 = get_output_dir(d4, "debug") + assert r4.url is None + assert r4.path is None + + # Case 5: Invalid type + with pytest.raises(TypeError): + get_output_dir("not_a_dir_obj", "fail") + + +def test_list_all_files_combined(tmp_path): + # --- Case 1: Directory with nested files --- + (tmp_path / "a.txt").write_text("hello") + (tmp_path / "subdir").mkdir() + (tmp_path / "subdir" / "b.txt").write_text("world") + + result = list_all_files(str(tmp_path)) + expected = {str(tmp_path / "a.txt"), str(tmp_path / "subdir" / "b.txt")} + assert set(result) == expected, "Should list all files recursively from directory" + + # --- Case 2: .txt file listing files --- + (tmp_path / "img1.jpg").touch() + (tmp_path / "img2.jpg").touch() + txt_file = tmp_path / "train.txt" + txt_file.write_text("img1.jpg\nimg2.jpg\n") + + result = list_all_files(str(txt_file)) + expected = { + str((tmp_path / "img1.jpg").resolve()), + str((tmp_path / "img2.jpg").resolve()), + } + assert set(result) == expected, ".txt path list should resolve correctly" + + # --- Case 3: Unsupported file path --- + bad_file = tmp_path / "unsupported.md" + bad_file.write_text("invalid") + + with pytest.raises(ValueError, match="Unsupported path"): + list_all_files(str(bad_file)) diff --git a/tests/integrations/ultralytics_support/test_patch.py b/tests/integrations/ultralytics_support/test_patch.py new file mode 100644 index 000000000..f2c6cc02a --- /dev/null +++ b/tests/integrations/ultralytics_support/test_patch.py @@ -0,0 +1,135 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch + +import numpy as np +import pytest + +from litdata.integrations.ultralytics.patch import parse_labels, ultralytics_detection_transform + + +def make_lit_args(**kwargs): + return { + "names": ["class0", "class1"], + "kpt_shape": (5, 3), + **kwargs, + } + + +def test_parse_labels(): + # --- Basic test --- + labels = "0 0.1 0.2 0.3 0.4\n1 0.5 0.6 0.7 0.8" + lit_args = make_lit_args() + lb, segments, keypoints = parse_labels(labels, lit_args=lit_args) + assert lb.shape == (2, 5) + assert segments == [] + assert keypoints is None + assert np.all(lb[:, 0] < len(lit_args["names"])) + + # --- Single class --- + labels = "1 0.1 0.2 0.3 0.4" + lb, _, _ = parse_labels(labels, lit_args=lit_args, single_cls=True) + assert np.all(lb[:, 0] == 0) + + # --- With segments --- + segment_label = "0 " + " ".join([str(round(0.01 * i, 3)) for i in range(14)]) # 7 xy pairs + lb, segments, _ = parse_labels(segment_label, lit_args=lit_args) + assert lb.shape == (1, 5) + assert len(segments) == 1 + + # --- With keypoints --- + keypoint_str = "0 " + " ".join(["0.1"] * 19) + lb, _, keypoints = parse_labels(keypoint_str, lit_args=lit_args, keypoint=True) + assert lb.shape == (1, 5) + assert keypoints.shape == (1, 5, 3) + + # --- Duplicate removal --- + dup_labels = "0 0.1 0.2 0.3 0.4\n0 0.1 0.2 0.3 0.4" + lb, _, _ = parse_labels(dup_labels, lit_args=lit_args) + assert lb.shape == (1, 5) + + # --- Out of bounds class --- + bad_class_label = "99 0.1 0.2 0.3 0.4" + with pytest.raises(AssertionError, match="Label class 99 exceeds"): + parse_labels(bad_class_label, lit_args=lit_args) + + # --- Coordinates out of bounds --- + bad_coords_label = "0 1.2 1.2 1.2 1.2" + with pytest.raises(AssertionError, match="non-normalized or out of bounds"): + parse_labels(bad_coords_label, lit_args=lit_args) + + # --- Negative class --- + negative_label = "-1 0.1 0.2 0.3 0.4" + with pytest.raises(AssertionError, match="negative class labels"): + parse_labels(negative_label, lit_args=lit_args) + + # --- Wrong shape --- + wrong_shape_label = "0 0.1 0.2 0.3" # Only 4 elements + with pytest.raises(AssertionError, match="labels require 5 columns"): + parse_labels(wrong_shape_label, lit_args=lit_args) + + +def test_ultralytics_detection_transform(): + dummy_img = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8) + valid_label = "1 0.1 0.2 0.3 0.4\n0 0.5 0.6 0.7 0.8" + invalid_label = 123 # not a string + + dummy_img_resized = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) + ori_shape = dummy_img.shape[:2] + resized_shape = dummy_img_resized.shape[:2] + + lit_args = {"names": ["class0", "class1"], "kpt_shape": (0, 0)} + + with ( + patch( + "litdata.integrations.ultralytics.patch.image_resize", + return_value=(dummy_img_resized, ori_shape, resized_shape), + ), + patch( + "litdata.integrations.ultralytics.patch.parse_labels", + return_value=( + np.array([[1, 0.1, 0.2, 0.3, 0.4], [0, 0.5, 0.6, 0.7, 0.8]], dtype=np.float32), + [], + None, + ), + ), + ): + # ✅ Valid transformation + data = {"img": dummy_img, "label": valid_label} + out = ultralytics_detection_transform(data, index=42, channels=3, img_size=640, lit_args=lit_args) + + assert isinstance(out, dict) + assert out["img"].shape == (640, 640, 3) + assert out["batch_idx"].item() == 42 + assert out["cls"].shape == (2, 1) + assert out["bboxes"].shape == (2, 4) + assert out["segments"] == [] + assert out["keypoints"] is None + assert out["normalized"] is True + assert out["bbox_format"] == "xywh" + assert out["ori_shape"] == ori_shape + assert out["resized_shape"] == resized_shape + assert isinstance(out["ratio_pad"], tuple) + assert out["channels"] == 3 + + # Missing index + with pytest.raises(ValueError, match="Index must be provided"): + ultralytics_detection_transform( + {"img": dummy_img, "label": valid_label}, index=None, channels=3, lit_args=lit_args + ) + + # Invalid label type + with pytest.raises(ValueError, match="Label must be a string"): + ultralytics_detection_transform( + {"img": dummy_img, "label": invalid_label}, index=0, channels=3, lit_args=lit_args + )