
Commit 47a8c7a

feat(litdata/raw): Implement remote and local index caching for StreamingRawDataset (#666)
1 parent e572b84 commit 47a8c7a

4 files changed: +299 −33 lines

README.md

Lines changed: 14 additions & 0 deletions
@@ -290,6 +290,20 @@ for item in loader:
     pass
 ```
 
+**Smart Index Caching**
+
+`StreamingRawDataset` automatically caches the file index for instant startup. The initial scan builds and caches the index; subsequent runs load it instantly.
+
+**Two-Level Cache:**
+- **Local:** Stored in your cache directory for instant access
+- **Remote:** Automatically saved to cloud storage (e.g., `s3://bucket/files/index.json.zstd`) for reuse
+
+**Force Rebuild:**
+```python
+# When dataset files have changed
+dataset = StreamingRawDataset("s3://bucket/files/", recompute_index=True)
+```
+
 </details>
 
 <details>
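To connect the README snippet above with the constructor change below, here is a minimal usage sketch of the caching workflow. The import path, bucket name, cache directory, and credential keys are placeholders assumed for illustration, not taken from the commit; `cache_dir`, `storage_options`, and `recompute_index` are the constructor arguments shown in the diffs.

```python
from litdata.raw.dataset import StreamingRawDataset  # import path assumed from the file layout

# First run: scans the bucket, builds the index, caches it under cache_dir and
# uploads index.json.zstd next to the data (bucket and credentials are placeholders).
dataset = StreamingRawDataset(
    "s3://my-bucket/files/",
    cache_dir="/tmp/litdata-cache",
    storage_options={"key": "...", "secret": "..."},  # forwarded to fsspec
)

# Subsequent runs reuse the cached index; force a fresh scan after the data changes.
dataset = StreamingRawDataset("s3://my-bucket/files/", recompute_index=True)
```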

src/litdata/raw/dataset.py

Lines changed: 10 additions & 2 deletions
@@ -108,6 +108,7 @@ def __init__(
         indexer: Optional[BaseIndexer] = None,
         storage_options: Optional[dict] = None,
         cache_files: bool = False,
+        recompute_index: bool = False,
         transform: Optional[Callable[[Union[bytes, list[bytes]]], Any]] = None,
     ):
         """Initialize StreamingRawDataset.
@@ -118,8 +119,12 @@ def __init__(
             indexer: Custom file indexer (default: FileIndexer).
             storage_options: Cloud storage options.
             cache_files: Whether to cache files locally (default: False).
+            recompute_index: Whether to recompute the index (default: False).
+                If True, forces a re-scan of the input directory and rebuilds the index,
+                ignoring any cached index files. This is useful when the dataset
+                structure or files on the remote storage have changed.
             transform: A function to apply to each item. It will receive `bytes` for single-file
-                items or `List[bytes]` for grouped items.
+                items or `List[bytes]` for grouped items.
         """
         self.input_dir = _resolve_dir(input_dir)
         self.cache_manager = CacheManager(self.input_dir, cache_dir, storage_options, cache_files)
@@ -129,7 +134,10 @@ def __init__(
 
         # Discover all files in the input directory.
         self.files: list[FileMetadata] = self.indexer.build_or_load_index(
-            str(self.input_dir.path or self.input_dir.url), self.cache_manager.cache_dir, storage_options
+            str(self.input_dir.path or self.input_dir.url),
+            self.cache_manager.cache_dir,
+            storage_options,
+            recompute_index,
         )
         logger.info(f"Discovered {len(self.files)} files.")
 
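The `transform` argument documented in the docstring above receives the raw `bytes` of each file (or `list[bytes]` for grouped items). As a hedged illustration only, a hypothetical decode step; the PIL dependency and the image use case are assumptions, not part of this commit:

```python
import io

from PIL import Image  # assumed dependency for this illustration

from litdata.raw.dataset import StreamingRawDataset  # import path assumed from the file layout


def decode_image(data: bytes) -> Image.Image:
    # Each indexed file is handed to the transform as raw bytes; decode it here.
    return Image.open(io.BytesIO(data))


dataset = StreamingRawDataset("s3://my-bucket/images/", transform=decode_image)
```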
src/litdata/raw/indexer.py

Lines changed: 143 additions & 30 deletions
@@ -13,6 +13,7 @@
 
 import json
 import logging
+import os
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -24,6 +25,7 @@
 
 logger = logging.getLogger(__name__)
 _SUPPORTED_PROVIDERS = ("s3", "gs", "azure")
+_INDEX_FILENAME = "index.json.zstd"
 
 
 @dataclass
@@ -49,38 +51,120 @@ def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any
         """Discover dataset files and return their metadata."""
 
     def build_or_load_index(
-        self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]]
+        self,
+        input_dir: str,
+        cache_dir: str,
+        storage_options: Optional[dict[str, Any]],
+        recompute_index: bool = False,
     ) -> list[FileMetadata]:
-        """Build or load a ZSTD-compressed index of file metadata."""
+        """Loads or builds a ZSTD-compressed index of dataset file metadata.
+        This method attempts to load an existing index from local or remote cache, or builds a new one if needed.
+        Use `recompute_index=True` to force rebuilding the index from the input directory.
+
+        Args:
+            input_dir: Path to the dataset root directory.
+            cache_dir: Directory for storing the index cache.
+            storage_options: Optional storage backend options.
+            recompute_index: If True, always rebuild the index.
+
+        Returns:
+            List of FileMetadata objects for discovered files.
+
+        Raises:
+            ModuleNotFoundError: If required dependencies are missing.
+            ValueError: If no files are found in the input directory.
+        """
         if not _ZSTD_AVAILABLE:
             raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
 
-        import zstd
+        if not _FSSPEC_AVAILABLE:
+            raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))
 
-        index_path = Path(cache_dir) / "index.json.zstd"
+        parsed_url = urlparse(input_dir)
+        if parsed_url.scheme and parsed_url.scheme not in _SUPPORTED_PROVIDERS:
+            raise ValueError(
+                f"Unsupported input directory scheme: `{parsed_url.scheme}`. "
+                f"Supported schemes are: {_SUPPORTED_PROVIDERS}"
+            )
 
-        # Try loading cached index if it exists
-        if index_path.exists():
-            try:
-                with open(index_path, "rb") as f:
-                    compressed_data = f.read()
-                metadata = json.loads(zstd.decompress(compressed_data).decode("utf-8"))
+        if not recompute_index:
+            files = self._load_index_from_cache(input_dir, cache_dir, storage_options)
+            if files:
+                return files
 
-                return [FileMetadata.from_dict(file_data) for file_data in metadata["files"]]
-            except (FileNotFoundError, json.JSONDecodeError, zstd.ZstdError, KeyError) as e:
-                logger.warning(f"Failed to load cached index from {index_path}: {e}")
+        return self._build_and_cache_index(input_dir, cache_dir, storage_options)
 
-        # Build fresh index
-        logger.info(f"Building index for {input_dir} at {index_path}")
+    def _load_index_from_cache(
+        self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]]
+    ) -> Optional[list[FileMetadata]]:
+        """Tries to load the index from local or remote cache."""
+        # 1. Try to load index from local cache.
+        local_index_path = Path(cache_dir) / _INDEX_FILENAME
+        if local_index_path.exists():
+            logger.info(f"Loading index from local cache: {local_index_path}")
+            files = self._load_index_file(str(local_index_path))
+            if files:
+                logger.info(f"Loaded index from local cache: {local_index_path}")
+                return files
+
+        # 2. If not found, try remote cache.
+        remote_index_path = os.path.join(input_dir, _INDEX_FILENAME)
+        try:
+            self._download_from_cloud(remote_index_path, str(local_index_path), storage_options)
+            files = self._load_index_file(str(local_index_path))
+            if files:
+                logger.info(f"Loaded index from remote cache: {remote_index_path}")
+                return files
+        except FileNotFoundError:
+            logger.warning(f"Remote index not found at {remote_index_path}")
+        except Exception as e:
+            logger.error(f"Failed to download or load remote index: {e}")
+
+        return None
+
+    def _build_and_cache_index(
+        self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]]
+    ) -> list[FileMetadata]:
+        """Builds a new index and caches it locally and remotely."""
+        local_index_path = Path(cache_dir) / _INDEX_FILENAME
+        logger.info(f"Building index for {input_dir} at {local_index_path}")
         files = self.discover_files(input_dir, storage_options)
         if not files:
             raise ValueError(f"No files found in {input_dir}")
 
-        # Cache the index with ZSTD compression
-        # TODO: upload the index to cloud storage
+        self._save_index_file(str(local_index_path), files, input_dir)
+
+        # Upload to remote cache
+        remote_index_path = os.path.join(input_dir, _INDEX_FILENAME)
+        try:
+            self._upload_to_cloud(str(local_index_path), remote_index_path, storage_options)
+            logger.info(f"Uploaded index to remote cache: {remote_index_path}")
+        except Exception as e:
+            logger.warning(f"Failed to upload index to remote cache: {e}")
+
+        logger.info(f"Built index with {len(files)} files from {input_dir} at {local_index_path}")
+        return files
+
+    def _load_index_file(self, index_path: str) -> Optional[list[FileMetadata]]:
+        """Loads and decodes an index file."""
+        import zstd
+
+        try:
+            with open(index_path, "rb") as f:
+                compressed_data = f.read()
+            metadata = json.loads(zstd.decompress(compressed_data).decode("utf-8"))
+            return [FileMetadata.from_dict(file_data) for file_data in metadata["files"]]
+        except (FileNotFoundError, json.JSONDecodeError, zstd.ZstdError, KeyError) as e:
+            logger.warning(f"Failed to load index from local cache at `{index_path}`: {e}")
+            return None
+
+    def _save_index_file(self, index_path: str, files: list[FileMetadata], source: str) -> None:
+        """Encodes and saves an index file."""
+        import zstd
+
         try:
             metadata = {
-                "source": input_dir,
+                "source": source,
                 "files": [file.to_dict() for file in files],
                 "created_at": time.time(),
             }
@@ -89,8 +173,35 @@ def build_or_load_index(
         except (OSError, zstd.ZstdError) as e:
             logger.warning(f"Error caching index to {index_path}: {e}")
 
-        logger.info(f"Built index with {len(files)} files from {input_dir} at {index_path}")
-        return files
+    def _download_from_cloud(
+        self,
+        remote_path: str,
+        local_path: str,
+        storage_options: Optional[dict[str, Any]],
+    ) -> None:
+        """Downloads a file from cloud storage."""
+        if not _FSSPEC_AVAILABLE:
+            raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))
+        import fsspec
+
+        parsed_url = urlparse(remote_path)
+        fs = fsspec.filesystem(parsed_url.scheme, **(storage_options or {}))
+        fs.get(remote_path, local_path)
+
+    def _upload_to_cloud(
+        self,
+        local_path: str,
+        remote_path: str,
+        storage_options: Optional[dict[str, Any]],
+    ) -> None:
+        """Uploads a file to cloud storage."""
+        if not _FSSPEC_AVAILABLE:
+            raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))
+        import fsspec
+
+        parsed_url = urlparse(remote_path)
+        fs = fsspec.filesystem(parsed_url.scheme, **(storage_options or {}))
+        fs.put(local_path, remote_path)
 
 
 class FileIndexer(BaseIndexer):
@@ -107,21 +218,20 @@ def __init__(
     def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
         """Discover dataset files and return their metadata."""
         parsed_url = urlparse(input_dir)
+        if parsed_url.scheme and parsed_url.scheme not in _SUPPORTED_PROVIDERS:
+            raise ValueError(
+                f"Unsupported input directory scheme: `{parsed_url.scheme}`. "
+                f"Supported schemes are: {_SUPPORTED_PROVIDERS}"
+            )
 
         if parsed_url.scheme in _SUPPORTED_PROVIDERS:  # Cloud storage
             return self._discover_cloud_files(input_dir, storage_options)
 
-        if not parsed_url.scheme:  # Local filesystem
-            return self._discover_local_files(input_dir)
-
-        raise ValueError(
-            f"Unsupported input directory scheme: {parsed_url.scheme}. Supported schemes are: {_SUPPORTED_PROVIDERS}"
-        )
+        # Local filesystem
+        return self._discover_local_files(input_dir)
 
     def _discover_cloud_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
         """Recursively list files in a cloud storage bucket."""
-        if not _FSSPEC_AVAILABLE:
-            raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))
         import fsspec
 
         obj = urlparse(input_dir)
@@ -173,6 +283,9 @@ def _discover_local_files(self, input_dir: str) -> list[FileMetadata]:
         return metadatas
 
     def _should_include_file(self, file_path: str) -> bool:
-        """Return True if file matches allowed extensions."""
-        file_ext = Path(file_path).suffix.lower()
+        """Return True if file matches allowed extensions and is not an index file."""
+        path = Path(file_path)
+        if path.name == _INDEX_FILENAME:
+            return False
+        file_ext = path.suffix.lower()
         return not self.extensions or file_ext in self.extensions
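For reference, the index written by `_save_index_file` above is a ZSTD-compressed JSON document with `source`, `files`, and `created_at` keys. A small sketch for inspecting a cached index, assuming the `zstd` package is installed and using a placeholder cache path:

```python
import json
from pathlib import Path

import zstd  # same compression module the indexer imports

index_path = Path("/tmp/litdata-cache/index.json.zstd")  # placeholder cache location
metadata = json.loads(zstd.decompress(index_path.read_bytes()).decode("utf-8"))

print(metadata["source"])               # input directory the index was built from
print(metadata["created_at"])           # build timestamp (time.time())
print(len(metadata["files"]), "files")  # one entry per discovered file
```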
