Merged
27 changes: 27 additions & 0 deletions README.md
@@ -234,7 +234,34 @@ for batch in loader:
```

> Use `StreamingRawDataset` to stream your data as-is. Use `StreamingDataset` for the fastest streaming after optimizing your data.


You can also customize how files are grouped by subclassing `StreamingRawDataset` and overriding the `setup` method. This is useful for pairing related files (e.g., an image with its mask, or an audio clip with its transcript) or for any other custom grouping logic, as in the example below.

```python
from pathlib import Path
from typing import Union

from torch.utils.data import DataLoader

from litdata.streaming.raw_dataset import FileMetadata, StreamingRawDataset

class SegmentationRawDataset(StreamingRawDataset):
    def setup(self, files: list[FileMetadata]) -> Union[list[FileMetadata], list[list[FileMetadata]]]:
        # Group files by prefix, extension, or any rule you need, and return a
        # list of groups, where each group is a list of FileMetadata.
        # Example: pair every image with the mask that shares its file stem,
        # assuming `images/` and `masks/` folders under the input directory.
        images = {Path(f.path).stem: f for f in files if "/images/" in f.path}
        masks = {Path(f.path).stem: f for f in files if "/masks/" in f.path}
        return [[images[stem], masks[stem]] for stem in sorted(images) if stem in masks]

# Initialize the custom dataset
dataset = SegmentationRawDataset("s3://bucket/files/")
loader = DataLoader(dataset, batch_size=32)
for batch in loader:
    # Each item in the batch is a pair: [image_bytes, mask_bytes]
    pass
```
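
Because each item here is a group, any `transform` passed to the dataset receives that group's contents as a `list[bytes]` (single-file items receive plain `bytes`). Below is a minimal sketch of a decoding transform for the image/mask pairs above; the PIL/NumPy decoding is only illustrative and not part of LitData:

```python
import io

import numpy as np
from PIL import Image

def decode_pair(pair: list[bytes]):
    # `pair` holds the raw bytes of one [image, mask] group, in the order
    # returned by `setup`.
    image = np.asarray(Image.open(io.BytesIO(pair[0])).convert("RGB"))
    mask = np.asarray(Image.open(io.BytesIO(pair[1])).convert("L"))
    return image, mask

dataset = SegmentationRawDataset("s3://bucket/files/", transform=decode_pair)
```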

</details>

<details>
<summary> ✅ Stream large cloud datasets</summary>
&nbsp;
137 changes: 79 additions & 58 deletions src/litdata/streaming/raw_dataset.py
@@ -10,7 +10,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
import logging
import os
@@ -210,8 +209,6 @@ def __init__(

self.storage_options = storage_options or {}
self._downloader: Optional[Downloader] = None
self._loop = None
self._closed = False

@property
def downloader(self) -> Downloader:
@@ -244,23 +241,6 @@ def get_local_path(self, remote_file_path: str) -> str:
local_path.parent.mkdir(parents=True, exist_ok=True)
return str(local_path)

def download_file_sync(self, file_path: str) -> bytes:
"""Download file synchronously and return content."""
# TODO: To add a local cache to avoid redundant downloads if cache_files is True.
# if self.cache_files:
# local_path = self.get_local_path(file_path)
# if os.path.exists(local_path):
# with open(local_path, "rb") as f:
# return f.read()

# Download to BytesIO
file_obj = io.BytesIO()
try:
self.downloader.download_fileobj(file_path, file_obj)
return file_obj.getvalue()
except Exception as e:
raise RuntimeError(f"Error downloading file {file_path}: {e}") from e

async def download_file_async(self, file_path: str) -> bytes:
"""Asynchronously download and return file content."""
if self.cache_files:
@@ -275,15 +255,13 @@ async def download_file_async(self, file_path: str) -> bytes:


class StreamingRawDataset(Dataset):
"""Streaming dataset for raw files with cloud support, fast indexing, and local caching.
"""Base class for streaming raw datasets.

Supports any folder structure and automatically indexes individual files.
This class provides the core functionality for streaming raw data from a remote or local source,
including file discovery, caching, and asynchronous downloading.

Features:
- `__getitem__` for single-item access
- `__getitems__` for efficient batch downloads
- Automatic local caching with directory structure preservation
- Minimal memory usage with lazy loading
To create a custom dataset, subclass this class and override the `setup` method
to define the structure of your dataset items from the list of all discovered files.
"""

def __init__(
@@ -293,64 +271,107 @@ def __init__(
indexer: Optional[BaseIndexer] = None,
storage_options: Optional[dict] = None,
cache_files: bool = False,
transform: Optional[Callable[[Any], Any]] = None,
transform: Optional[Callable[[Union[bytes, list[bytes]]], Any]] = None,
):
"""Initialize StreamingRawDataset.

Args:
input_dir: Path to dataset root (e.g. s3://bucket/dataset/)
cache_dir: Directory for caching files (optional)
indexer: Custom file indexer (default: FileIndexer)
storage_options: Cloud storage options
cache_files: Whether to cache files locally (default: False)
transform: A function to apply to each downloaded item.
input_dir: Path to dataset root (e.g., 's3://bucket/dataset/').
cache_dir: Directory for caching files (optional).
indexer: Custom file indexer (default: FileIndexer).
storage_options: Cloud storage options.
cache_files: Whether to cache files locally (default: False).
transform: A function to apply to each item. It will receive `bytes` for single-file
items or `List[bytes]` for grouped items.
"""
# Resolve directories
self.input_dir = _resolve_dir(input_dir)
self.cache_manager = CacheManager(self.input_dir, cache_dir, storage_options, cache_files)

# Configuration
self.indexer = indexer or FileIndexer()
self.storage_options = storage_options or {}
self.transform = transform

# Discover files and build index
self.files = self.indexer.build_or_load_index(
self.cache_manager._input_dir_path, self.cache_manager.cache_dir, storage_options
# Discover all files in the input directory.
self.files: list[FileMetadata] = self.indexer.build_or_load_index(
str(self.input_dir.path or self.input_dir.url), self.cache_manager.cache_dir, storage_options
)
# TODO: Grouping of files as needed by user, e.g., by image, label, etc.
logger.info(f"Discovered {len(self.files)} files.")

logger.info(f"Initialized StreamingRawDataset with {len(self.files)} files")
# Transform the flat list of files into the desired item structure.
self.items: Union[list[FileMetadata], list[list[FileMetadata]]] = self.setup(self.files)
if not isinstance(self.items, list):
raise TypeError(f"The setup method must return a list, but returned {type(self.items)}")
logger.info(f"Dataset setup with {len(self.items)} items.")

def setup(self, files: list[FileMetadata]) -> Union[list[FileMetadata], list[list[FileMetadata]]]:
"""Define the structure of the dataset from the list of discovered files.

Override this method in a subclass to group or filter files into final dataset items.

Args:
files: A list of all `FileMetadata` objects discovered in the `input_dir`.

Returns:
The final structure of the dataset, which can be:
- `List[FileMetadata]`: Each `FileMetadata` object is treated as a single item.
- `List[List[FileMetadata]]`: Each inner list of `FileMetadata` objects is treated as a single item.
"""
return files

@lru_cache(maxsize=1)
def __len__(self) -> int:
"""Return dataset size."""
return len(self.files)
"""Return the number of items in the dataset."""
return len(self.items)

def __getitem__(self, index: int) -> Any:
"""Get single item by index."""
if index < 0 or index >= len(self):
raise IndexError(f"Index {index} out of range")

file_path = self.files[index].path
# TODO: Use common asynchronous download method
data = self.cache_manager.download_file_sync(file_path)
return self.transform(data) if self.transform else data
"""Get a single item by index."""
if not (0 <= index < len(self)):
raise IndexError(f"Index {index} out of range for dataset with length {len(self)}")

item = self.items[index]
if isinstance(item, FileMetadata):
return asyncio.run(self._download_and_process_item(item.path))
if isinstance(item, list):
file_paths = [fm.path for fm in item]
return asyncio.run(self._download_and_process_group(file_paths))
raise TypeError(f"Dataset items must be of type FileMetadata or List[FileMetadata], but found {type(item)}")

def __getitems__(self, indices: list[int]) -> list[Any]:
"""Asynchronously download multiple items by index."""
"""Asynchronously download a batch of items by indices."""
# asyncio.run() handles loop creation, execution, and teardown cleanly.
return asyncio.run(self._download_batch(indices))

async def _download_batch(self, indices: list[int]) -> list[Any]:
"""Asynchronously download and transform items."""
file_paths = [self.files[index].path for index in indices]
coros = [self._process_item(path) for path in file_paths]
"""Asynchronously download and process items."""
batch_items = [self.items[i] for i in indices]
coros = []
for item in batch_items:
if isinstance(item, FileMetadata):
coros.append(self._download_and_process_item(item.path))
elif isinstance(item, list):
file_paths = [fm.path for fm in item]
coros.append(self._download_and_process_group(file_paths))
else:
raise TypeError(
f"Dataset items must be of type FileMetadata or List[FileMetadata], but found {type(item)}"
)
return await asyncio.gather(*coros)

async def _process_item(self, file_path: str) -> Any:
async def _download_and_process_group(self, file_paths: list[str]) -> Any:
"""Download all files in a group, then apply the transform."""
download_coros = [self.cache_manager.download_file_async(path) for path in file_paths]
group_data: list[bytes] = await asyncio.gather(*download_coros)

if self.transform:
# The transform receives a list of bytes, corresponding to the list structure
# of the item defined in setup(). This is true even if the list has only one element.
return await asyncio.to_thread(self.transform, group_data)
return group_data

async def _download_and_process_item(self, file_path: str) -> Any:
"""Download a single file and apply the transform."""
data = await self.cache_manager.download_file_async(file_path)
data: bytes = await self.cache_manager.download_file_async(file_path)
if self.transform:
# The transform receives a single bytes object, corresponding to the
# single FileMetadata object structure of the item.
return await asyncio.to_thread(self.transform, data)
return data