From 5f8ea890b70fe6ce46b4c8294c776c3250737ab5 Mon Sep 17 00:00:00 2001
From: Josh Wills
Date: Fri, 8 Aug 2025 10:36:54 -0700
Subject: [PATCH] streaming(reader): clean up chunk lock files by prefix
 during delete; add non-local lock cleanup test
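
The old cleanup removed only the two hard-coded suffixes ".lock" and
".cnt.lock", so any other lock file sitting next to a chunk survived
_apply_delete(). The reader now globs on the chunk's basename prefix
inside the cache directory and removes every match. A sketch of the
matching, with an illustrative cache path (the lock names below come
from the downloader's FileLock and the old suffix list):

    import os

    chunk_filepath = "/cache/chunk-0-0.bin"  # illustrative path
    base_prefix = os.path.splitext(os.path.basename(chunk_filepath))[0]
    assert base_prefix == "chunk-0-0"
    # glob.glob("/cache/chunk-0-0*.lock") then matches, e.g.:
    #   chunk-0-0.bin.lock      <- taken by the downloader while copying
    #   chunk-0-0.bin.cnt.lock  <- the old ".cnt.lock" suffix

The new test registers a LocalDownloader subclass under a custom
"s3+local://" scheme that deliberately keeps its ".lock" file after a
download, then drives BinaryReader with max_cache_size=1 to force chunk
eviction and asserts that no "chunk-*.lock" files remain.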
---
 src/litdata/streaming/reader.py      | 15 +++---
 tests/streaming/test_lock_cleanup.py | 81 ++++++++++++++++++++++++++++
 2 files changed, 89 insertions(+), 7 deletions(-)
 create mode 100644 tests/streaming/test_lock_cleanup.py

diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py
index 61a15bd5b..c186de059 100644
--- a/src/litdata/streaming/reader.py
+++ b/src/litdata/streaming/reader.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import glob
 import logging
 import os
 import warnings
@@ -150,13 +151,13 @@ def _apply_delete(self, chunk_index: int) -> None:
 
         if _DEBUG:
             print(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}")
-        for lock_extension in [".lock", ".cnt.lock"]:
-            try:
-                locak_chunk_path = chunk_filepath + lock_extension
-                if os.path.exists(locak_chunk_path):
-                    os.remove(locak_chunk_path)
-            except FileNotFoundError:
-                pass
+        base_name = os.path.basename(chunk_filepath)
+        base_prefix = os.path.splitext(base_name)[0]
+        cache_dir = os.path.dirname(chunk_filepath)
+        pattern = os.path.join(cache_dir, f"{base_prefix}*.lock")
+        for lock_path in glob.glob(pattern):
+            with suppress(FileNotFoundError, PermissionError):
+                os.remove(lock_path)
 
     def stop(self) -> None:
         """Receive the list of the chunk indices to download for the current epoch."""
diff --git a/tests/streaming/test_lock_cleanup.py b/tests/streaming/test_lock_cleanup.py
new file mode 100644
index 000000000..c9deb5e12
--- /dev/null
+++ b/tests/streaming/test_lock_cleanup.py
@@ -0,0 +1,81 @@
+import os
+import shutil
+from contextlib import suppress
+
+import pytest
+from filelock import FileLock, Timeout
+
+from litdata.constants import _ZSTD_AVAILABLE
+from litdata.streaming.cache import Cache
+from litdata.streaming.config import ChunkedIndex
+from litdata.streaming.downloader import LocalDownloader, register_downloader, unregister_downloader
+from litdata.streaming.reader import BinaryReader
+from litdata.streaming.resolver import Dir
+
+
+class LocalDownloaderNoLockCleanup(LocalDownloader):
+    """A LocalDownloader variant that does NOT remove the `.lock` file after download.
+
+    This simulates the behavior of non-local downloaders, where the lock file persists
+    on disk until the reader's cleanup runs. Used to verify the centralized lock cleanup.
+    """
+
+    def download_file(self, remote_filepath: str, local_filepath: str) -> None:  # type: ignore[override]
+        # Strip the custom scheme used for testing to map onto the local filesystem
+        if remote_filepath.startswith("s3+local://"):
+            remote_filepath = remote_filepath.replace("s3+local://", "")
+        if not os.path.exists(remote_filepath):
+            raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
+
+        with (
+            suppress(Timeout, FileNotFoundError),
+            FileLock(local_filepath + ".lock", timeout=0),
+        ):
+            if remote_filepath == local_filepath or os.path.exists(local_filepath):
+                return
+            temp_file_path = local_filepath + ".tmp"
+            shutil.copy(remote_filepath, temp_file_path)
+            os.rename(temp_file_path, local_filepath)
+            # Intentionally do NOT remove `local_filepath + ".lock"` here
+
+
+@pytest.mark.skipif(not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")
+def test_reader_lock_cleanup_with_nonlocal_like_downloader(tmpdir):
+    cache_dir = os.path.join(tmpdir, "cache_dir")
+    remote_dir = os.path.join(tmpdir, "remote_dir")
+    os.makedirs(cache_dir, exist_ok=True)
+
+    # Build a small compressed dataset
+    cache = Cache(input_dir=Dir(path=cache_dir, url=None), chunk_size=3, compression="zstd")
+    for i in range(10):
+        cache[i] = i
+    cache.done()
+    cache.merge()
+
+    # Copy to a "remote" directory
+    shutil.copytree(cache_dir, remote_dir)
+
+    # Use a custom scheme that we register to our test downloader
+    prefix = "s3+local://"
+    remote_url = prefix + remote_dir
+
+    # Register the downloader and ensure we unregister it afterwards
+    register_downloader(prefix, LocalDownloaderNoLockCleanup, overwrite=True)
+    try:
+        # Fresh cache dir for reading
+        shutil.rmtree(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
+
+        reader = BinaryReader(cache_dir=cache_dir, remote_input_dir=remote_url, compression="zstd", max_cache_size=1)
+
+        # Iterate across enough samples to trigger multiple chunk downloads and deletions
+        for i in range(10):
+            idx = reader._get_chunk_index_from_index(i)
+            chunk_idx = ChunkedIndex(index=idx[0], chunk_index=idx[1], is_last_index=(i == 9))
+            reader.read(chunk_idx)
+
+        # At the end, no chunk-related lock files should remain
+        leftover_locks = [f for f in os.listdir(cache_dir) if f.endswith(".lock") and f.startswith("chunk-")]
+        assert leftover_locks == []
+    finally:
+        unregister_downloader(prefix)