diff --git a/src/litdata/__about__.py b/src/litdata/__about__.py index 031d9f4de..2591bd7d8 100644 --- a/src/litdata/__about__.py +++ b/src/litdata/__about__.py @@ -14,7 +14,7 @@ import time -__version__ = "0.2.35" +__version__ = "0.2.36" __author__ = "Lightning AI et al." __author_email__ = "pytorch@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index beb6b3d4f..23221bf1f 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -194,10 +194,22 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None: if os.path.exists(path): os.remove(path) - elif os.path.exists(path) and "s3_connections" not in path: + elif keep_path(path) and os.path.exists(path): os.remove(path) +def keep_path(path: str) -> bool: + paths = [ + "efs_connections", + "efs_folders", + "gcs_connections", + "s3_connections", + "s3_folders", + "snowflake_connections", + ] + return all(p not in path for p in paths) + + def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None: """Upload optimised chunks from a local to remote dataset directory.""" obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path) diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 078c9fa2b..1e4cd71a1 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -86,9 +86,12 @@ def _apply_delete(self, chunk_index: int) -> None: chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] self._item_loader.delete(chunk_index, chunk_filepath) - locak_chunk_path = chunk_filepath + ".lock" - if os.path.exists(locak_chunk_path): - os.remove(locak_chunk_path) + try: + locak_chunk_path = chunk_filepath + ".lock" + if os.path.exists(locak_chunk_path): + os.remove(locak_chunk_path) + except FileNotFoundError: + pass def stop(self) -> None: """Receive the list of the chunk indices to download for the current epoch.""" diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index e7767f8a5..b1a968840 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -180,7 +180,7 @@ def _resolve_s3_folders(dir_path: str) -> Dir: if not data_connection: raise ValueError(f"We didn't find any matching data connection with the provided name `{target_name}`.") - return Dir(path=dir_path, url=data_connection[0].s3_folder.source) + return Dir(path=dir_path, url=os.path.join(data_connection[0].s3_folder.source, *dir_path.split("/")[4:])) def _resolve_datasets(dir_path: str) -> Dir: diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 55d72260d..901642623 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -98,7 +98,13 @@ def _should_replace_path(path: Optional[str]) -> bool: if path is None or path == "": return True - return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/") + return ( + path.startswith("/teamspace/datasets/") + or path.startswith("/teamspace/s3_connections/") + or path.startswith("/teamspace/s3_folders/") + or path.startswith("/teamspace/gcs_folders/") + or path.startswith("/teamspace/gcs_connections/") + ) def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict] = {}) -> str: diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index d049e45c2..3b04d1a5a 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -155,7 +155,7 @@ def test_src_resolver_s3_folders(monkeypatch, lightning_cloud_mock): expected = "s3://imagenet-bucket" assert resolver._resolve_dir("/teamspace/s3_folders/debug_folder").url == expected - + assert resolver._resolve_dir("/teamspace/s3_folders/debug_folder/a/b/c").url == expected + "/a/b/c" auth.clear() diff --git a/tests/utilities/test_dataset_utilities.py b/tests/utilities/test_dataset_utilities.py index 03d8d905c..31dc0b38c 100644 --- a/tests/utilities/test_dataset_utilities.py +++ b/tests/utilities/test_dataset_utilities.py @@ -19,6 +19,9 @@ def test_should_replace_path(): assert not _should_replace_path(".../s3__connections/...") assert _should_replace_path("/teamspace/datasets/...") assert _should_replace_path("/teamspace/s3_connections/...") + assert _should_replace_path("/teamspace/s3_folders/...") + assert _should_replace_path("/teamspace/gcs_folders/...") + assert _should_replace_path("/teamspace/gcs_connections/...") assert not _should_replace_path("something_else")