Merged
Commits
22 commits
dc06c65
Prepare for cloud/local index caching
bhimrazy Aug 1, 2025
b02a3e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 1, 2025
9c04ca0
Merge branch 'main' into feat/store-index-in-raw-dataset
bhimrazy Aug 4, 2025
792be1d
Merge branch 'main' into feat/store-index-in-raw-dataset
bhimrazy Aug 4, 2025
ecfb732
ref: Enhance file indexing with local and remote caching mechanisms
bhimrazy Aug 10, 2025
38cb6ac
ref: Exclude index files from inclusion in file indexing
bhimrazy Aug 10, 2025
fbb087c
Improve index loading documentation and streamline remote cache handl…
bhimrazy Aug 10, 2025
0530a8e
ref: Validate input directory scheme in BaseIndexer and streamline lo…
bhimrazy Aug 10, 2025
f6e1a2a
Validate input directory scheme in FileIndexer and raise error for un…
bhimrazy Aug 10, 2025
a90b910
Add tests for handling unsupported input directory schemes in FileInd…
bhimrazy Aug 10, 2025
ff1d58a
Add tests for building and loading remote index with caching in FileI…
bhimrazy Aug 10, 2025
fc93dad
Add test to ensure index file is excluded during recompute
bhimrazy Aug 10, 2025
44f736b
Enhance test description for index file exclusion during recompute
bhimrazy Aug 10, 2025
2ae9325
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2025
4bbd8cb
Merge branch 'main' into feat/store-index-in-raw-dataset
bhimrazy Aug 10, 2025
36acd45
Add documentation for Smart Index Caching in StreamingRawDataset
bhimrazy Aug 10, 2025
0c9673b
Refine description of Smart Index Caching in StreamingRawDataset for …
bhimrazy Aug 10, 2025
06eb0b1
Add Windows compatibility checks for remote index tests
bhimrazy Aug 10, 2025
ed0e532
Add Windows compatibility check for recompute index test
bhimrazy Aug 10, 2025
a31ee84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2025
ec6d6a6
Apply suggestions
bhimrazy Aug 11, 2025
26263c5
Apply suggestions
bhimrazy Aug 11, 2025
14 changes: 14 additions & 0 deletions README.md
@@ -290,6 +290,20 @@ for item in loader:
pass
```

**Smart Index Caching**

`StreamingRawDataset` automatically caches the file index for instant startup. The first scan builds and caches the index; subsequent runs load it instantly, as sketched below.

**Two-Level Cache:**
- **Local:** Stored in your cache directory for instant access
- **Remote:** Automatically saved to cloud storage (e.g., `s3://bucket/files/index.json.zstd`) for reuse
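
A minimal sketch of the default flow (the import path is an assumption; adjust it to the package's public API):

```python
from litdata import StreamingRawDataset  # assumed import path

# First run: scans s3://bucket/files/, builds the index, and caches it
# locally and remotely at s3://bucket/files/index.json.zstd.
dataset = StreamingRawDataset("s3://bucket/files/")

# Later runs: the cached index is loaded instantly; no directory re-scan.
dataset = StreamingRawDataset("s3://bucket/files/")
```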

**Force Rebuild:**
```python
# When dataset files have changed
dataset = StreamingRawDataset("s3://bucket/files/", recompute_index=True)
```

</details>

<details>
12 changes: 10 additions & 2 deletions src/litdata/raw/dataset.py
@@ -108,6 +108,7 @@ def __init__(
indexer: Optional[BaseIndexer] = None,
storage_options: Optional[dict] = None,
cache_files: bool = False,
recompute_index: bool = False,
transform: Optional[Callable[[Union[bytes, list[bytes]]], Any]] = None,
):
"""Initialize StreamingRawDataset.
@@ -118,8 +119,12 @@ def __init__(
indexer: Custom file indexer (default: FileIndexer).
storage_options: Cloud storage options.
cache_files: Whether to cache files locally (default: False).
recompute_index: Whether to recompute the index (default: False).
If True, forces a re-scan of the input directory and rebuilds the index,
ignoring any cached index files. This is useful when the dataset
structure or files on the remote storage have changed.
transform: A function to apply to each item. It will receive `bytes` for single-file
items or `List[bytes]` for grouped items.
items or `List[bytes]` for grouped items.
"""
self.input_dir = _resolve_dir(input_dir)
self.cache_manager = CacheManager(self.input_dir, cache_dir, storage_options, cache_files)
@@ -129,7 +134,10 @@ def __init__(

# Discover all files in the input directory.
self.files: list[FileMetadata] = self.indexer.build_or_load_index(
str(self.input_dir.path or self.input_dir.url), self.cache_manager.cache_dir, storage_options
str(self.input_dir.path or self.input_dir.url),
self.cache_manager.cache_dir,
storage_options,
recompute_index,
)
logger.info(f"Discovered {len(self.files)} files.")

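For orientation, a minimal sketch of the indexer call the constructor now makes, mirroring the `build_or_load_index` signature in this diff; the bucket name and cache directory are placeholders, and the default `FileIndexer()` constructor is assumed to be sufficient:

```python
from litdata.raw.indexer import FileIndexer  # module path taken from this diff

indexer = FileIndexer()  # assumes default constructor arguments
# Reuse a cached index if one exists (locally, then remotely); otherwise
# scan the directory and cache the rebuilt index in both places.
files = indexer.build_or_load_index(
    "s3://my-bucket/dataset/",        # placeholder input_dir
    cache_dir="/tmp/litdata-cache",   # placeholder cache_dir
    storage_options=None,
    recompute_index=False,            # True forces a fresh scan
)
print(f"indexed {len(files)} files")
```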
173 changes: 143 additions & 30 deletions src/litdata/raw/indexer.py
@@ -13,6 +13,7 @@

import json
import logging
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -24,6 +25,7 @@

logger = logging.getLogger(__name__)
_SUPPORTED_PROVIDERS = ("s3", "gs", "azure")
_INDEX_FILENAME = "index.json.zstd"


@dataclass
@@ -49,38 +51,120 @@ def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any]
"""Discover dataset files and return their metadata."""

def build_or_load_index(
self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]]
self,
input_dir: str,
cache_dir: str,
storage_options: Optional[dict[str, Any]],
recompute_index: bool = False,
) -> list[FileMetadata]:
"""Build or load a ZSTD-compressed index of file metadata."""
"""Loads or builds a ZSTD-compressed index of dataset file metadata.
This method attempts to load an existing index from local or remote cache, or builds a new one if needed.
Use `recompute_index=True` to force rebuilding the index from the input directory.

Args:
input_dir: Path to the dataset root directory.
cache_dir: Directory for storing the index cache.
storage_options: Optional storage backend options.
recompute_index: If True, always rebuild the index.

Returns:
List of FileMetadata objects for discovered files.

Raises:
ModuleNotFoundError: If required dependencies are missing.
ValueError: If no files are found in the input directory.
"""
if not _ZSTD_AVAILABLE:
raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))

import zstd
if not _FSSPEC_AVAILABLE:
raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))

index_path = Path(cache_dir) / "index.json.zstd"
parsed_url = urlparse(input_dir)
if parsed_url.scheme and parsed_url.scheme not in _SUPPORTED_PROVIDERS:
raise ValueError(
f"Unsupported input directory scheme: `{parsed_url.scheme}`. "
f"Supported schemes are: {_SUPPORTED_PROVIDERS}"
)

# Try loading cached index if it exists
if index_path.exists():
try:
with open(index_path, "rb") as f:
compressed_data = f.read()
metadata = json.loads(zstd.decompress(compressed_data).decode("utf-8"))
if not recompute_index:
files = self._load_index_from_cache(input_dir, cache_dir, storage_options)
if files:
return files

return [FileMetadata.from_dict(file_data) for file_data in metadata["files"]]
except (FileNotFoundError, json.JSONDecodeError, zstd.ZstdError, KeyError) as e:
logger.warning(f"Failed to load cached index from {index_path}: {e}")
return self._build_and_cache_index(input_dir, cache_dir, storage_options)

# Build fresh index
logger.info(f"Building index for {input_dir} at {index_path}")
def _load_index_from_cache(
self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]]
) -> Optional[list[FileMetadata]]:
"""Tries to load the index from local or remote cache."""
# 1. Try to load index from local cache.
local_index_path = Path(cache_dir) / _INDEX_FILENAME
if local_index_path.exists():
logger.info(f"Loading index from local cache: {local_index_path}")
files = self._load_index_file(str(local_index_path))
if files:
logger.info(f"Loaded index from local cache: {local_index_path}")
return files

# 2. If not found, try remote cache.
remote_index_path = os.path.join(input_dir, _INDEX_FILENAME)
try:
self._download_from_cloud(remote_index_path, str(local_index_path), storage_options)
files = self._load_index_file(str(local_index_path))
if files:
logger.info(f"Loaded index from remote cache: {remote_index_path}")
return files
except FileNotFoundError:
logger.warning(f"Remote index not found at {remote_index_path}")
except Exception as e:
logger.error(f"Failed to download or load remote index: {e}")

return None

def _build_and_cache_index(
self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]]
) -> list[FileMetadata]:
"""Builds a new index and caches it locally and remotely."""
local_index_path = Path(cache_dir) / _INDEX_FILENAME
logger.info(f"Building index for {input_dir} at {local_index_path}")
files = self.discover_files(input_dir, storage_options)
if not files:
raise ValueError(f"No files found in {input_dir}")

# Cache the index with ZSTD compression
# TODO: upload the index to cloud storage
self._save_index_file(str(local_index_path), files, input_dir)

# Upload to remote cache
remote_index_path = os.path.join(input_dir, _INDEX_FILENAME)
try:
self._upload_to_cloud(str(local_index_path), remote_index_path, storage_options)
logger.info(f"Uploaded index to remote cache: {remote_index_path}")
except Exception as e:
logger.warning(f"Failed to upload index to remote cache: {e}")

logger.info(f"Built index with {len(files)} files from {input_dir} at {local_index_path}")
return files

def _load_index_file(self, index_path: str) -> Optional[list[FileMetadata]]:
"""Loads and decodes an index file."""
import zstd

try:
with open(index_path, "rb") as f:
compressed_data = f.read()
metadata = json.loads(zstd.decompress(compressed_data).decode("utf-8"))
return [FileMetadata.from_dict(file_data) for file_data in metadata["files"]]
except (FileNotFoundError, json.JSONDecodeError, zstd.ZstdError, KeyError) as e:
logger.warning(f"Failed to load index from local cache at `{index_path}`: {e}. ")
return None

def _save_index_file(self, index_path: str, files: list[FileMetadata], source: str) -> None:
"""Encodes and saves an index file."""
import zstd

try:
metadata = {
"source": input_dir,
"source": source,
"files": [file.to_dict() for file in files],
"created_at": time.time(),
}
@@ -89,8 +173,35 @@ def build_or_load_index(
except (OSError, zstd.ZstdError) as e:
logger.warning(f"Error caching index to {index_path}: {e}")

logger.info(f"Built index with {len(files)} files from {input_dir} at {index_path}")
return files
def _download_from_cloud(
self,
remote_path: str,
local_path: str,
storage_options: Optional[dict[str, Any]],
) -> None:
"""Downloads a file from cloud storage."""
if not _FSSPEC_AVAILABLE:
raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))
import fsspec

parsed_url = urlparse(remote_path)
fs = fsspec.filesystem(parsed_url.scheme, **(storage_options or {}))
fs.get(remote_path, local_path)

def _upload_to_cloud(
self,
local_path: str,
remote_path: str,
storage_options: Optional[dict[str, Any]],
) -> None:
"""Uploads a file to cloud storage."""
if not _FSSPEC_AVAILABLE:
raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))
import fsspec

parsed_url = urlparse(remote_path)
fs = fsspec.filesystem(parsed_url.scheme, **(storage_options or {}))
fs.put(local_path, remote_path)


class FileIndexer(BaseIndexer):
Expand All @@ -107,21 +218,20 @@ def __init__(
def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
"""Discover dataset files and return their metadata."""
parsed_url = urlparse(input_dir)
if parsed_url.scheme and parsed_url.scheme not in _SUPPORTED_PROVIDERS:
raise ValueError(
f"Unsupported input directory scheme: `{parsed_url.scheme}`. "
f"Supported schemes are: {_SUPPORTED_PROVIDERS}"
)

if parsed_url.scheme in _SUPPORTED_PROVIDERS: # Cloud storage
return self._discover_cloud_files(input_dir, storage_options)

if not parsed_url.scheme: # Local filesystem
return self._discover_local_files(input_dir)

raise ValueError(
f"Unsupported input directory scheme: {parsed_url.scheme}. Supported schemes are: {_SUPPORTED_PROVIDERS}"
)
# Local filesystem
return self._discover_local_files(input_dir)

def _discover_cloud_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
"""Recursively list files in a cloud storage bucket."""
if not _FSSPEC_AVAILABLE:
raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))
import fsspec

obj = urlparse(input_dir)
@@ -173,6 +283,9 @@ def _discover_local_files(self, input_dir: str) -> list[FileMetadata]:
return metadatas

def _should_include_file(self, file_path: str) -> bool:
"""Return True if file matches allowed extensions."""
file_ext = Path(file_path).suffix.lower()
"""Return True if file matches allowed extensions and is not an index file."""
path = Path(file_path)
if path.name == _INDEX_FILENAME:
return False
file_ext = path.suffix.lower()
return not self.extensions or file_ext in self.extensions
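
For reference, a sketch of inspecting a cached index by hand, assuming the `zstd` package used in this diff and the `source`/`files`/`created_at` layout written by `_save_index_file`; the cache path is a placeholder:

```python
import json

import zstd  # same binding the indexer imports

# Decode a locally cached index (path is a placeholder).
with open("/tmp/litdata-cache/index.json.zstd", "rb") as f:
    metadata = json.loads(zstd.decompress(f.read()).decode("utf-8"))

print(metadata["source"])        # input_dir the index was built from
print(metadata["created_at"])    # UNIX timestamp of the build
print(len(metadata["files"]))    # one serialized FileMetadata dict per file
```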