Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/litdata/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B
_DEFAULT_FAST_DEV_RUN_ITEMS = 10
_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")
_DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks")

# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
Expand Down
14 changes: 11 additions & 3 deletions src/litdata/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class StreamingDataset(IterableDataset):
def __init__(
self,
input_dir: Union[str, "Dir"],
cache_dir: Optional[Union[str, "Dir"]] = None,
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = False,
drop_last: Optional[bool] = None,
Expand All @@ -61,6 +62,8 @@ def __init__(

Args:
input_dir: Path to the folder where the input data is stored.
cache_dir: Path to the folder where the cache data is stored. If not provided, the cache will be stored
in the default cache directory.
item_loader: The logic to load an item from a chunk.
shuffle: Whether to shuffle the data.
drop_last: If `True`, drops the last items to ensure that
Expand All @@ -84,12 +87,14 @@ def __init__(
raise ValueError("subsample must be a float with value between 0 and 1.")

input_dir = _resolve_dir(input_dir)
cache_dir = _resolve_dir(cache_dir)

self.input_dir = input_dir
self.cache_dir = cache_dir
self.subsampled_files: List[str] = []
self.region_of_interest: List[Tuple[int, int]] = []
self.subsampled_files, self.region_of_interest = subsample_streaming_dataset(
self.input_dir, item_loader, subsample, shuffle, seed, storage_options
self.input_dir, self.cache_dir, item_loader, subsample, shuffle, seed, storage_options
)

self.item_loader = item_loader
Expand Down Expand Up @@ -155,7 +160,8 @@ def set_epoch(self, current_epoch: int) -> None:
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
if _should_replace_path(self.input_dir.path):
cache_path = _try_create_cache_dir(
input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url
input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url,
cache_dir=self.cache_dir.path,
)
if cache_path is not None:
self.input_dir.path = cache_path
Expand Down Expand Up @@ -399,6 +405,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int
"current_epoch": self.current_epoch,
"input_dir_path": self.input_dir.path,
"input_dir_url": self.input_dir.url,
"cache_dir_path": self.cache_dir.path,
"item_loader": self.item_loader.state_dict() if self.item_loader else None,
"drop_last": self.drop_last,
"seed": self.seed,
Expand Down Expand Up @@ -438,7 +445,8 @@ def _validate_state_dict(self) -> None:
# In this case, validate the cache folder is the same.
if _should_replace_path(state["input_dir_path"]):
cache_path = _try_create_cache_dir(
input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"]
input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"],
cache_dir=state.get("cache_dir_path"),
)
if cache_path != self.input_dir.path:
raise ValueError(
Expand Down
17 changes: 12 additions & 5 deletions src/litdata/utilities/dataset_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from litdata.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME
from litdata.constants import _DEFAULT_CACHE_DIR, _DEFAULT_LIGHTNING_CACHE_DIR, _INDEX_FILENAME
from litdata.streaming.downloader import get_downloader_cls
from litdata.streaming.item_loader import BaseItemLoader, TokensLoader
from litdata.streaming.resolver import Dir, _resolve_dir
Expand All @@ -17,6 +17,7 @@

def subsample_streaming_dataset(
input_dir: Dir,
cache_dir: Optional[Dir] = None,
item_loader: Optional[BaseItemLoader] = None,
subsample: float = 1.0,
shuffle: bool = False,
Expand All @@ -39,7 +40,9 @@ def subsample_streaming_dataset(
# Make sure input_dir contains cache path and remote url
if _should_replace_path(input_dir.path):
cache_path = _try_create_cache_dir(
input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options
input_dir=input_dir.path if input_dir.path else input_dir.url,
cache_dir=cache_dir.path if cache_dir else None,
storage_options=storage_options,
)
if cache_path is not None:
input_dir.path = cache_path
Expand Down Expand Up @@ -137,7 +140,11 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s
shutil.rmtree(input_dir_hash_filepath)


def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict] = {}) -> Optional[str]:
def _try_create_cache_dir(
input_dir: Optional[str],
cache_dir: Optional[str] = None,
storage_options: Optional[Dict] = {},
) -> Optional[str]:
resolved_input_dir = _resolve_dir(input_dir)
updated_at = _read_updated_at(resolved_input_dir, storage_options)

Expand All @@ -147,13 +154,13 @@ def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Di
dir_url_hash = hashlib.md5((resolved_input_dir.url or "").encode()).hexdigest() # noqa: S324

if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
input_dir_hash_filepath = os.path.join(_DEFAULT_CACHE_DIR, dir_url_hash)
input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_CACHE_DIR, dir_url_hash)
_clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at)
cache_dir = os.path.join(input_dir_hash_filepath, updated_at)
os.makedirs(cache_dir, exist_ok=True)
return cache_dir

input_dir_hash_filepath = os.path.join("/cache", "chunks", dir_url_hash)
input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_LIGHTNING_CACHE_DIR, dir_url_hash)
_clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at)
cache_dir = os.path.join(input_dir_hash_filepath, updated_at)
os.makedirs(cache_dir, exist_ok=True)
Expand Down
28 changes: 28 additions & 0 deletions tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -410,6 +411,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -432,6 +434,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -447,6 +450,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -469,6 +473,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -484,6 +489,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -506,6 +512,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -521,6 +528,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -543,6 +551,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -558,6 +567,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -580,6 +590,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -595,6 +606,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -617,6 +629,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -632,6 +645,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 1,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -657,6 +671,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -672,6 +687,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -694,6 +710,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -709,6 +726,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -731,6 +749,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -746,6 +765,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -768,6 +788,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -783,6 +804,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -805,6 +827,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -820,6 +843,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -842,6 +866,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -857,6 +882,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -879,6 +905,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand All @@ -894,6 +921,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"current_epoch": 2,
"input_dir_path": ANY,
"input_dir_url": ANY,
"cache_dir_path": None,
"item_loader": None,
"drop_last": False,
"seed": 42,
Expand Down
20 changes: 20 additions & 0 deletions tests/utilities/test_dataset_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@ def test_try_create_cache_dir():
assert len(makedirs_mock.mock_calls) == 2


def test_try_create_cache_dir_with_custom_cache_dir(tmpdir):
    """Verify that `_try_create_cache_dir` roots the cache under a user-supplied
    ``cache_dir`` instead of the default cache location, both outside and inside
    a Lightning cloud environment (detected via LIGHTNING_* env vars)."""
    cache_dir = str(tmpdir.join("cache"))
    # Non-cloud branch: clear the environment so neither LIGHTNING_CLUSTER_ID nor
    # LIGHTNING_CLOUD_PROJECT_ID is set.
    with mock.patch.dict(os.environ, {}, clear=True):
        # "d41d8cd98f00b204e9800998ecf8427e" is md5("") — the hash of the empty
        # resolved URL for input "any". The second path component is presumably
        # the `updated_at` marker computed by the implementation — TODO confirm.
        assert os.path.join(
            cache_dir, "d41d8cd98f00b204e9800998ecf8427e", "100b8cad7cf2a56f6df78f171f97a1ec"
        ) in _try_create_cache_dir("any", cache_dir)

    # Cloud branch: both LIGHTNING_* variables present.
    with (
        mock.patch.dict("os.environ", {"LIGHTNING_CLUSTER_ID": "abc", "LIGHTNING_CLOUD_PROJECT_ID": "123"}),
        # NOTE(review): the patch target is `litdata.streaming.dataset.os.makedirs`,
        # but `_try_create_cache_dir` lives in `litdata.utilities.dataset_utilities`.
        # This works only because both modules share the same `os` module object,
        # so patching the attribute there patches it globally — consider patching
        # the defining module instead for clarity; confirm intent.
        mock.patch("litdata.streaming.dataset.os.makedirs") as makedirs_mock,
    ):
        # Different inputs must map to different cache directories.
        cache_dir_1 = _try_create_cache_dir("", cache_dir)
        cache_dir_2 = _try_create_cache_dir("ssdf", cache_dir)
        assert cache_dir_1 != cache_dir_2
        # Empty input: both the URL hash and (presumably) the updated_at component
        # collapse to md5("") — TODO confirm the second component's origin.
        assert cache_dir_1 == os.path.join(
            cache_dir, "d41d8cd98f00b204e9800998ecf8427e", "d41d8cd98f00b204e9800998ecf8427e"
        )
        # One makedirs call per _try_create_cache_dir invocation.
        assert len(makedirs_mock.mock_calls) == 2


def test_generate_roi():
my_chunks = [
{"chunk_size": 30},
Expand Down
Loading