From 3e838fffa420de4453c85ccd3b09c747939e42de Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 27 Oct 2024 23:01:45 +0545 Subject: [PATCH 1/7] adds default lightning cache dir --- src/litdata/constants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litdata/constants.py b/src/litdata/constants.py index a6a714c72..f6e02fadb 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -23,6 +23,7 @@ _DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B _DEFAULT_FAST_DEV_RUN_ITEMS = 10 _DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks") +_DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks") # This is required for full pytree serialization / deserialization support _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0") From cccbde5f983efc4d42b523599db2a8ac031f25b6 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 27 Oct 2024 23:02:02 +0545 Subject: [PATCH 2/7] adds support for cache dir --- src/litdata/streaming/dataset.py | 15 ++++++++++++--- src/litdata/utilities/dataset_utilities.py | 17 ++++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index ea82ce7ae..a494bc45f 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -46,6 +46,7 @@ class StreamingDataset(IterableDataset): def __init__( self, input_dir: Union[str, "Dir"], + cache_dir: Optional[str] = None, item_loader: Optional[BaseItemLoader] = None, shuffle: bool = False, drop_last: Optional[bool] = None, @@ -61,6 +62,8 @@ def __init__( Args: input_dir: Path to the folder where the input data is stored. + cache_dir: Path to the folder where the cache data is stored. If not provided, the cache will be stored + in the default cache directory. item_loader: The logic to load an item from a chunk. shuffle: Whether to shuffle the data. drop_last: If `True`, drops the last items to ensure that @@ -84,12 +87,15 @@ def __init__( raise ValueError("subsample must be a float with value between 0 and 1.") input_dir = _resolve_dir(input_dir) + if cache_dir: + cache_dir = _resolve_dir(cache_dir) self.input_dir = input_dir + self.cache_dir = cache_dir self.subsampled_files: List[str] = [] self.region_of_interest: List[Tuple[int, int]] = [] self.subsampled_files, self.region_of_interest = subsample_streaming_dataset( - self.input_dir, item_loader, subsample, shuffle, seed, storage_options + self.input_dir, self.cache_dir, item_loader, subsample, shuffle, seed, storage_options ) self.item_loader = item_loader @@ -155,7 +161,8 @@ def set_epoch(self, current_epoch: int) -> None: def _create_cache(self, worker_env: _WorkerEnv) -> Cache: if _should_replace_path(self.input_dir.path): cache_path = _try_create_cache_dir( - input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url + input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, + cache_dir=self.cache_dir.path if self.cache_dir else None, ) if cache_path is not None: self.input_dir.path = cache_path @@ -399,6 +406,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int "current_epoch": self.current_epoch, "input_dir_path": self.input_dir.path, "input_dir_url": self.input_dir.url, + "cache_dir_path": self.cache_dir.path if self.cache_dir else None, "item_loader": self.item_loader.state_dict() if self.item_loader else None, "drop_last": self.drop_last, "seed": self.seed, @@ -438,7 +446,8 @@ def _validate_state_dict(self) -> None: # In this case, validate the cache folder is the same. if _should_replace_path(state["input_dir_path"]): cache_path = _try_create_cache_dir( - input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"] + input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"], + cache_dir=state.get("cache_dir_path"), ) if cache_path != self.input_dir.path: raise ValueError( diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index a23d9e3f2..55d72260d 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -8,7 +8,7 @@ import numpy as np -from litdata.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME +from litdata.constants import _DEFAULT_CACHE_DIR, _DEFAULT_LIGHTNING_CACHE_DIR, _INDEX_FILENAME from litdata.streaming.downloader import get_downloader_cls from litdata.streaming.item_loader import BaseItemLoader, TokensLoader from litdata.streaming.resolver import Dir, _resolve_dir @@ -17,6 +17,7 @@ def subsample_streaming_dataset( input_dir: Dir, + cache_dir: Optional[Dir] = None, item_loader: Optional[BaseItemLoader] = None, subsample: float = 1.0, shuffle: bool = False, @@ -39,7 +40,9 @@ def subsample_streaming_dataset( # Make sure input_dir contains cache path and remote url if _should_replace_path(input_dir.path): cache_path = _try_create_cache_dir( - input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options + input_dir=input_dir.path if input_dir.path else input_dir.url, + cache_dir=cache_dir.path if cache_dir else None, + storage_options=storage_options, ) if cache_path is not None: input_dir.path = cache_path @@ -137,7 +140,11 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s shutil.rmtree(input_dir_hash_filepath) -def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict] = {}) -> Optional[str]: +def _try_create_cache_dir( + input_dir: Optional[str], + cache_dir: Optional[str] = None, + storage_options: Optional[Dict] = {}, +) -> Optional[str]: resolved_input_dir = _resolve_dir(input_dir) updated_at = _read_updated_at(resolved_input_dir, storage_options) @@ -147,13 +154,13 @@ def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Di dir_url_hash = hashlib.md5((resolved_input_dir.url or "").encode()).hexdigest() # noqa: S324 if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ: - input_dir_hash_filepath = os.path.join(_DEFAULT_CACHE_DIR, dir_url_hash) + input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_CACHE_DIR, dir_url_hash) _clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at) cache_dir = os.path.join(input_dir_hash_filepath, updated_at) os.makedirs(cache_dir, exist_ok=True) return cache_dir - input_dir_hash_filepath = os.path.join("/cache", "chunks", dir_url_hash) + input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_LIGHTNING_CACHE_DIR, dir_url_hash) _clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at) cache_dir = os.path.join(input_dir_hash_filepath, updated_at) os.makedirs(cache_dir, exist_ok=True) From 3f47e7b5b02592b7a6fa71901fd7892c233bf688 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 27 Oct 2024 23:19:02 +0545 Subject: [PATCH 3/7] adds test_try_create_cache_dir_with_custom_cache_dir --- tests/utilities/test_dataset_utilities.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py index fc08e6dfd..03d8d905c 100644 --- a/tests/utilities/test_dataset_utilities.py +++ b/tests/utilities/test_dataset_utilities.py @@ -42,6 +42,26 @@ def test_try_create_cache_dir(): assert len(makedirs_mock.mock_calls) == 2 +def test_try_create_cache_dir_with_custom_cache_dir(tmpdir): + cache_dir = str(tmpdir.join("cache")) + with mock.patch.dict(os.environ, {}, clear=True): + assert os.path.join( + cache_dir, "d41d8cd98f00b204e9800998ecf8427e", "100b8cad7cf2a56f6df78f171f97a1ec" + ) in _try_create_cache_dir("any", cache_dir) + + with ( + mock.patch.dict("os.environ", {"LIGHTNING_CLUSTER_ID": "abc", "LIGHTNING_CLOUD_PROJECT_ID": "123"}), + mock.patch("litdata.streaming.dataset.os.makedirs") as makedirs_mock, + ): + cache_dir_1 = _try_create_cache_dir("", cache_dir) + cache_dir_2 = _try_create_cache_dir("ssdf", cache_dir) + assert cache_dir_1 != cache_dir_2 + assert cache_dir_1 == os.path.join( + cache_dir, "d41d8cd98f00b204e9800998ecf8427e", "d41d8cd98f00b204e9800998ecf8427e" + ) + assert len(makedirs_mock.mock_calls) == 2 + + def test_generate_roi(): my_chunks = [ {"chunk_size": 30}, From 4c7f53cd9343e73170bc3553a1067cb44ece14e0 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 27 Oct 2024 23:26:49 +0545 Subject: [PATCH 4/7] fixed types --- src/litdata/streaming/dataset.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index a494bc45f..3623cbdca 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -46,7 +46,7 @@ class StreamingDataset(IterableDataset): def __init__( self, input_dir: Union[str, "Dir"], - cache_dir: Optional[str] = None, + cache_dir: Optional[Union[str, "Dir"]] = None, item_loader: Optional[BaseItemLoader] = None, shuffle: bool = False, drop_last: Optional[bool] = None, @@ -87,8 +87,7 @@ def __init__( raise ValueError("subsample must be a float with value between 0 and 1.") input_dir = _resolve_dir(input_dir) - if cache_dir: - cache_dir = _resolve_dir(cache_dir) + cache_dir = _resolve_dir(cache_dir) self.input_dir = input_dir self.cache_dir = cache_dir @@ -162,7 +161,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache: if _should_replace_path(self.input_dir.path): cache_path = _try_create_cache_dir( input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, - cache_dir=self.cache_dir.path if self.cache_dir else None, + cache_dir=self.cache_dir.path, ) if cache_path is not None: self.input_dir.path = cache_path @@ -406,7 +405,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int "current_epoch": self.current_epoch, "input_dir_path": self.input_dir.path, "input_dir_url": self.input_dir.url, - "cache_dir_path": self.cache_dir.path if self.cache_dir else None, + "cache_dir_path": self.cache_dir.path, "item_loader": self.item_loader.state_dict() if self.item_loader else None, "drop_last": self.drop_last, "seed": self.seed, From 9dcd783750bf8c3b334600edbb45c629b9185cc3 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 27 Oct 2024 23:30:05 +0545 Subject: [PATCH 5/7] simplified --- src/litdata/utilities/dataset_utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 55d72260d..aa93567c7 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -41,7 +41,7 @@ def subsample_streaming_dataset( if _should_replace_path(input_dir.path): cache_path = _try_create_cache_dir( input_dir=input_dir.path if input_dir.path else input_dir.url, - cache_dir=cache_dir.path if cache_dir else None, + cache_dir=cache_dir.path, storage_options=storage_options, ) if cache_path is not None: From 347f5590295f12c50fc1eed7caa1c5ddea680752 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 27 Oct 2024 23:37:12 +0545 Subject: [PATCH 6/7] reverted change --- src/litdata/utilities/dataset_utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index aa93567c7..55d72260d 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -41,7 +41,7 @@ def subsample_streaming_dataset( if _should_replace_path(input_dir.path): cache_path = _try_create_cache_dir( input_dir=input_dir.path if input_dir.path else input_dir.url, - cache_dir=cache_dir.path, + cache_dir=cache_dir.path if cache_dir else None, storage_options=storage_options, ) if cache_path is not None: From 20be9e1aebe500283e6e3876f2bace67bd5d7b77 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 27 Oct 2024 23:58:41 +0545 Subject: [PATCH 7/7] adds cache_dir_path in test with statedict --- tests/streaming/test_combined.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py index fb3aa5e42..f7e8a7d92 100644 --- a/tests/streaming/test_combined.py +++ b/tests/streaming/test_combined.py @@ -395,6 +395,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -410,6 +411,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -432,6 +434,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -447,6 +450,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -469,6 +473,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -484,6 +489,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -506,6 +512,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -521,6 +528,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -543,6 +551,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -558,6 +567,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -580,6 +590,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -595,6 +606,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -617,6 +629,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -632,6 +645,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 1, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -657,6 +671,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -672,6 +687,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -694,6 +710,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -709,6 +726,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -731,6 +749,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -746,6 +765,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -768,6 +788,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -783,6 +804,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -805,6 +827,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -820,6 +843,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -842,6 +866,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -857,6 +882,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -879,6 +905,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42, @@ -894,6 +921,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "current_epoch": 2, "input_dir_path": ANY, "input_dir_url": ANY, + "cache_dir_path": None, "item_loader": None, "drop_last": False, "seed": 42,