From 320ce9971f45451538cfc49909412421341deaee Mon Sep 17 00:00:00 2001 From: Rob Levy Date: Tue, 23 Sep 2025 23:03:13 +0000 Subject: [PATCH 01/15] Fix: Force delete prior to force download --- src/litdata/streaming/reader.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index c186de059..f45d5efed 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -130,17 +130,18 @@ def _decrement_local_lock(self, chunk_index: int) -> int: return curr_count return 0 - def _apply_delete(self, chunk_index: int) -> None: + def _apply_delete(self, chunk_index: int, skip_lock: bool = False) -> None: """Inform the item loader of the chunk to delete.""" # TODO: Fix the can_delete method can_delete_chunk = self._config.can_delete(chunk_index) chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] - remaining_locks = self._remaining_locks(chunk_filepath) - if remaining_locks > 0: # Can't delete this, something has it - if _DEBUG: - print(f"Skip delete {chunk_filepath} by {self._rank or 0}, current lock count: {remaining_locks}") - return + if not skip_lock: + remaining_locks = self._remaining_locks(chunk_filepath) + if remaining_locks > 0: # Can't delete this, something has it + if _DEBUG: + print(f"Skip delete {chunk_filepath} by {self._rank or 0}, current lock count: {remaining_locks}") + return if _DEBUG: with open(chunk_filepath + ".tmb", "w+") as tombstone_file: @@ -203,6 +204,8 @@ def _pre_load_chunk(self, chunk_index: int) -> None: def _force_download(self) -> None: chunk_index = _get_from_queue(self._force_download_queue) if chunk_index is not None: + # force apply deletion before redownload + self._apply_delete(chunk_index, skip_lock=True) if _DEBUG: chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] print(f"Requested force download for {chunk_filepath} by {self._rank}") From 6d794ffb8f4e2a8264fa7724a7c619b958dce658 Mon Sep 17 00:00:00 2001 From: Rob Levy Date: Wed, 24 Sep 2025 02:34:10 +0000 Subject: [PATCH 02/15] fix --- src/litdata/streaming/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 14483e0ef..d8555002d 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -194,6 +194,10 @@ def try_decompress(self, local_chunkpath: str) -> None: exists = os.path.exists(local_chunkpath) and os.stat(local_chunkpath).st_size >= chunk_bytes while not exists: sleep(0.1) + # Return if the actual file exists + if os.path.exists(target_local_chunkpath): + return + # find the local compressed file exists = os.path.exists(local_chunkpath) and os.stat(local_chunkpath).st_size >= chunk_bytes if (time() - start_time) > _MAX_WAIT_TIME: From 816adc54c9f475500e98be995137abb68b4485e0 Mon Sep 17 00:00:00 2001 From: Rob Levy Date: Fri, 3 Oct 2025 11:42:44 +0000 Subject: [PATCH 03/15] update --- src/litdata/debugger.py | 178 +++++++++++++++++---------- src/litdata/streaming/compression.py | 4 +- src/litdata/streaming/config.py | 28 +---- src/litdata/streaming/dataloader.py | 2 +- src/litdata/streaming/dataset.py | 26 ++-- src/litdata/streaming/downloader.py | 86 ++++--------- src/litdata/streaming/item_loader.py | 40 +++--- src/litdata/streaming/reader.py | 53 +++++--- src/litdata/streaming/sampler.py | 1 + src/litdata/streaming/serializers.py | 66 +++++----- src/litdata/utilities/shuffle.py | 164 
++++++++++++++---------- 11 files changed, 343 insertions(+), 305 deletions(-) diff --git a/src/litdata/debugger.py b/src/litdata/debugger.py index 03fe465e8..023f59597 100644 --- a/src/litdata/debugger.py +++ b/src/litdata/debugger.py @@ -1,6 +1,6 @@ # Copyright The Lightning AI team. # Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. +# You may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 @@ -12,93 +12,141 @@ # limitations under the License. import logging -import os -import sys +import re +import threading +import time +from litdata.utilities.env import _DistributedEnv, _WorkerEnv, _is_in_dataloader_worker from functools import lru_cache +import os -from litdata.constants import _PRINT_DEBUG_LOGS -from litdata.utilities.env import _DistributedEnv, _WorkerEnv - -# Create the root logger for the library -root_logger = logging.getLogger("litdata") - +class TimedFlushFileHandler(logging.FileHandler): + """FileHandler that flushes every N seconds in a background thread.""" + def __init__(self, filename, mode='a', flush_interval=2): + super().__init__(filename, mode) + self.flush_interval = flush_interval + self._stop_event = threading.Event() + t = threading.Thread(target=self._flusher, daemon=True, name="TimedFlushFileHandler._flusher") + t.start() + + def _flusher(self): + while not self._stop_event.is_set(): + time.sleep(self.flush_interval) + self.flush() + + def close(self): + self._stop_event.set() + self.flush() + super().close() + +class EnvConfigFilter(logging.Filter): + """A logging filter that reads its configuration from environment variables.""" + def __init__(self): + super().__init__() + self.name_re = re.compile(r"name:\s*([^;]+);") + + def _get_name_from_msg(self, msg): + match = self.name_re.search(msg) + return match.group(1).strip() if match else None + + def filter(self, record): + """Determine if a log record should be processed by checking env vars.""" + is_iterating_dataset_enabled = os.getenv("LITDATA_LOG_ITERATING_DATASET", "True").lower() == "true" + is_getitem_enabled = os.getenv("LITDATA_LOG_GETITEM", "True").lower() == "true" + is_item_loader_enabled = os.getenv("LITDATA_LOG_ITEM_LOADER", "True").lower() == "true" + + log_name = self._get_name_from_msg(record.getMessage()) + + if log_name: + if not is_iterating_dataset_enabled and log_name.startswith("iterating_dataset"): + return False + if not is_getitem_enabled and log_name.startswith("getitem_dataset_for_chunk_index"): + return False + if not is_item_loader_enabled and log_name.startswith("item_loader"): + return False + + return True def get_logger_level(level: str) -> int: - """Get the log level from the level string.""" level = level.upper() if level in logging._nameToLevel: return logging._nameToLevel[level] - raise ValueError(f"Invalid log level: {level}. 
Valid levels: {list(logging._nameToLevel.keys())}.") - + raise ValueError(f"Invalid log level: {level}") class LitDataLogger: - def __init__(self, name: str): + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, name="litdata", flush_interval=2): + if hasattr(self, "logger"): + return # Already initialized + self.logger = logging.getLogger(name) + self.logger.propagate = False self.log_file, self.log_level = self.get_log_file_and_level() - self.setup_logger() + self.flush_interval = flush_interval + self._setup_logger() @staticmethod - def get_log_file_and_level() -> tuple[str, int]: + def get_log_file_and_level(): log_file = os.getenv("LITDATA_LOG_FILE", "litdata_debug.log") log_lvl = os.getenv("LITDATA_LOG_LEVEL", "DEBUG") + return log_file, get_logger_level(log_lvl) - log_lvl = get_logger_level(log_lvl) - - return log_file, log_lvl - - def setup_logger(self) -> None: - """Configures logging by adding handlers and formatting.""" - if len(self.logger.handlers) > 0: # Avoid duplicate handlers + def _setup_logger(self): + if self.logger.handlers: return - self.logger.setLevel(self.log_level) - - # Console handler - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(self.log_level) - - # File handler - file_handler = logging.FileHandler(self.log_file) - file_handler.setLevel(self.log_level) - - # Log format formatter = logging.Formatter( - "ts:%(created)s; logger_name:%(name)s; level:%(levelname)s; PID:%(process)d; TID:%(thread)d; %(message)s" + "ts:%(created)s;" + "PID:%(process)d; TID:%(thread)d; %(message)s" ) - # ENV - f"{WORLD_SIZE, GLOBAL_RANK, NNODES, LOCAL_RANK, NODE_RANK}" - console_handler.setFormatter(formatter) - file_handler.setFormatter(formatter) - - # Attach handlers - if _PRINT_DEBUG_LOGS: - self.logger.addHandler(console_handler) - self.logger.addHandler(file_handler) - - -def enable_tracer() -> None: + handler = TimedFlushFileHandler(self.log_file, flush_interval=self.flush_interval) + handler.setFormatter(formatter) + handler.setLevel(self.log_level) + self.logger.addHandler(handler) + + self.logger.filters = [f for f in self.logger.filters if not isinstance(f, EnvConfigFilter)] + self.logger.addFilter(EnvConfigFilter()) + + def get_logger(self): + return self.logger + +def enable_tracer(flush_interval: int = 5, item_loader=True, iterating_dataset=True, getitem_dataset_for_chunk_index=True) -> logging.Logger: + """ + Convenience function to enable and configure litdata logging. + This function SETS the environment variables that control the logging behavior. + """ os.environ["LITDATA_LOG_FILE"] = "litdata_debug.log" - LitDataLogger("litdata") + os.environ["LITDATA_LOG_ITEM_LOADER"] = str(item_loader) + os.environ["LITDATA_LOG_ITERATING_DATASET"] = str(iterating_dataset) + os.environ["LITDATA_LOG_GETITEM"] = str(getitem_dataset_for_chunk_index) + master_logger = LitDataLogger(flush_interval=flush_interval).get_logger() + return master_logger def _get_log_msg(data: dict) -> str: log_msg = "" - if "name" not in data or "ph" not in data: raise ValueError(f"Missing required keys in data dictionary. Required keys: 'name', 'ph'. 
Received: {data}") - env_info_data = env_info() data.update(env_info_data) - for key, value in data.items(): log_msg += f"{key}: {value};" return log_msg - -@lru_cache(maxsize=1) def env_info() -> dict: - dist_env = _DistributedEnv.detect() - worker_env = _WorkerEnv.detect() # will all threads read the same value if decorate this function with `@cache` + if _is_in_dataloader_worker(): + return _cached_env_info() + dist_env = _DistributedEnv.detect() + worker_env = _WorkerEnv.detect() return { "dist_world_size": dist_env.world_size, "dist_global_rank": dist_env.global_rank, @@ -107,17 +155,19 @@ def env_info() -> dict: "worker_rank": worker_env.rank, } +@lru_cache(maxsize=1) +def _cached_env_info() -> dict: + dist_env = _DistributedEnv.detect() + worker_env = _WorkerEnv.detect() + return { + "dist_world_size": dist_env.world_size, + "dist_global_rank": dist_env.global_rank, + "dist_num_nodes": dist_env.num_nodes, + "worker_world_size": worker_env.world_size, + "worker_rank": worker_env.rank, + } -# -> Chrome tracing colors -# url: https://chromium.googlesource.com/external/trace-viewer/+/bf55211014397cf0ebcd9e7090de1c4f84fc3ac0/tracing/tracing/ui/base/color_scheme.html - -# # ------ - - -# thread_state_iowait: {r: 182, g: 125, b: 143}, -# thread_state_running: {r: 126, g: 200, b: 148}, -# thread_state_runnable: {r: 133, g: 160, b: 210}, -# .... +# Chrome trace colors class ChromeTraceColors: PINK = "thread_state_iowait" GREEN = "thread_state_running" @@ -142,4 +192,4 @@ class ChromeTraceColors: LIGHT_RED = "cq_build_failed" MUSTARD_YELLOW = "cq_build_attempt_running" NEON_GREEN = "cq_build_attempt_passed" - DARK_RED = "cq_build_attempt_failed" + DARK_RED = "cq_build_attempt_failed" \ No newline at end of file diff --git a/src/litdata/streaming/compression.py b/src/litdata/streaming/compression.py index cae2e0935..556b49daf 100644 --- a/src/litdata/streaming/compression.py +++ b/src/litdata/streaming/compression.py @@ -62,9 +62,9 @@ def compress(self, data: bytes) -> bytes: def decompress(self, data: bytes) -> bytes: import zstd - logger.debug(_get_log_msg({"name": "Decompressing data", "ph": "B", "cname": ChromeTraceColors.MUSTARD_YELLOW})) + logger.debug(_get_log_msg({"name": "decompress", "ph": "B", "cname": ChromeTraceColors.MUSTARD_YELLOW})) decompressed_data = zstd.decompress(data) - logger.debug(_get_log_msg({"name": "Decompressed data", "ph": "E", "cname": ChromeTraceColors.MUSTARD_YELLOW})) + logger.debug(_get_log_msg({"name": "decompress", "ph": "E", "cname": ChromeTraceColors.MUSTARD_YELLOW})) return decompressed_data @classmethod diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index d8555002d..0af7513ad 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -139,7 +139,7 @@ def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) - if self._downloader is not None and not skip_lock: # We don't want to redownload the base, but we should mark # it as having been requested by something - self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", "")) + self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", ""), chunk_index) pass return @@ -147,7 +147,7 @@ def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) - return if not skip_lock: - self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", "")) + 
self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", ""), chunk_index) self._downloader.download_chunk_from_index(chunk_index) @@ -208,7 +208,10 @@ def try_decompress(self, local_chunkpath: str) -> None: # delete the files only if they were downloaded if self._downloader is not None: - os.remove(local_chunkpath) + try: + os.remove(local_chunkpath) + except FileNotFoundError: + pass data = self._compressor.decompress(data) @@ -283,15 +286,6 @@ def _get_chunk_index_from_index(self, index: int) -> tuple[int, int]: def __getitem__(self, index: ChunkedIndex) -> tuple[str, int, int]: """Find the associated chunk metadata.""" - logger.debug( - _get_log_msg( - { - "name": f"get_item_for_chunk_index_{index.chunk_index}_and_index_{index.index}", - "ph": "B", - "cname": ChromeTraceColors.LIGHT_GREEN, - } - ) - ) assert self._chunks is not None chunk = self._chunks[index.chunk_index] @@ -304,16 +298,6 @@ def __getitem__(self, index: ChunkedIndex) -> tuple[str, int, int]: filesize_bytes = chunk["chunk_bytes"] - logger.debug( - _get_log_msg( - { - "name": f"get_item_for_chunk_index_{index.chunk_index}_and_index_{index.index}", - "ph": "E", - "cname": ChromeTraceColors.LIGHT_GREEN, - } - ) - ) - return local_chunkpath, begin, filesize_bytes def _get_chunk_index_from_filename(self, chunk_filename: str) -> int: diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 07c165ef3..27205ceac 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -362,7 +362,7 @@ def wrap(*args: Any, **kwargs: Any) -> Any: if os.path.exists(output_file): os.remove(output_file) - tracer = VizTracer(output_file=output_file, verbose=0) + tracer = VizTracer(output_file=output_file, verbose=0, tracer_entries=100000000) tracer.start() result = func(*args, **kwargs) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 536d82a33..7a5599a3c 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -302,7 +302,6 @@ def get_len(self, num_workers: int, batch_size: int) -> int: def __iter__(self) -> "StreamingDataset": # When the StreamingDataset is used within map or optimize, let's refetch the distributed env. 
- logger.debug(_get_log_msg({"name": "iterating_dataset", "ph": "B"})) if os.getenv("DATA_OPTIMIZER_GLOBAL_RANK"): self.distributed_env = _DistributedEnv.detect() @@ -438,17 +437,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: _my_indices = list(range(start, stop, step)) _my_cache_indices = [ChunkedIndex(*self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices] return [self.cache[chnk_idx] for chnk_idx in _my_cache_indices] - logger.debug( - _get_log_msg( - {"name": f"getitem_dataset_for_chunk_index_{index.chunk_index}_and_index_{index.index}", "ph": "B"} - ) - ) item = self.cache[index] - logger.debug( - _get_log_msg( - {"name": f"getitem_dataset_for_chunk_index_{index.chunk_index}_and_index_{index.index}", "ph": "E"} - ) - ) if hasattr(self, "transform"): if isinstance(self.transform, list): for transform_fn in self.transform: @@ -466,7 +455,6 @@ def __next__(self) -> Any: # if they are equal, means, worker has processed all the chunks self.current_epoch += 1 self.reset_state_dict() - logger.debug(_get_log_msg({"name": "iterating_dataset", "ph": "E"})) self.on_demand_bytes = True # reset on_demand_bytes to True raise StopIteration @@ -502,16 +490,20 @@ def __next__(self) -> Any: # Get the first index index = self.upcoming_indexes.pop(0) + chunk_indexes = None if self.has_triggered_download else self.worker_chunks[self.worker_next_chunk_index - 1 :] + is_last_index = (self.worker_next_chunk_index) == self.num_chunks and len(self.upcoming_indexes) == 0 + chunk_index = self.worker_chunks[self.worker_next_chunk_index - 1] + chunk_size = self.worker_intervals[self.worker_next_chunk_index - 1][2] - self.worker_intervals[self.worker_next_chunk_index - 1][1] + # Call the `__getitem__` method. data = self.__getitem__( ChunkedIndex( index=index, - chunk_index=self.worker_chunks[self.worker_next_chunk_index - 1], + chunk_index=chunk_index, # We provide the chunks indexes only one the first - chunk_indexes=None - if self.has_triggered_download - else self.worker_chunks[self.worker_next_chunk_index - 1 :], - is_last_index=(self.worker_next_chunk_index) == self.num_chunks and len(self.upcoming_indexes) == 0, + chunk_indexes=chunk_indexes, + is_last_index=is_last_index, + chunk_size=chunk_size ) ) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index 6efad6429..eb94a7c99 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -15,13 +15,12 @@ import logging import os import shutil -import subprocess import tempfile from abc import ABC from contextlib import suppress from typing import Any, Optional from urllib import parse - +from time import time from filelock import FileLock, Timeout from litdata.constants import ( @@ -52,8 +51,7 @@ def __init__( self._chunks = chunks self._storage_options = storage_options or {} - def _increment_local_lock(self, chunkpath: str) -> None: - logger.debug(_get_log_msg({"name": f"increment_local_lock_for_{chunkpath}", "ph": "B"})) + def _increment_local_lock(self, chunkpath: str, chunk_index: int) -> None: countpath = chunkpath + ".cnt" with suppress(Timeout, FileNotFoundError), FileLock(countpath + ".lock", timeout=1): try: @@ -63,11 +61,12 @@ def _increment_local_lock(self, chunkpath: str) -> None: curr_count = 0 curr_count += 1 with open(countpath, "w+") as count_f: + logger.debug(_get_log_msg({"name": f"increment_lock_chunk_{chunk_index}_to_{curr_count}", "ph": "B"})) count_f.write(str(curr_count)) - logger.debug(_get_log_msg({"name": 
f"increment_local_lock_for_{chunkpath}", "ph": "E"})) + logger.debug(_get_log_msg({"name": f"increment_lock_chunk_{chunk_index}_to_{curr_count}", "ph": "E"})) def download_chunk_from_index(self, chunk_index: int) -> None: - logger.debug(_get_log_msg({"name": f"download_chunk_from_index_{chunk_index}", "ph": "B"})) + logger.debug(_get_log_msg({"name": f"download_chunk_{chunk_index}", "ph": "B"})) chunk_filename = self._chunks[chunk_index]["filename"] local_chunkpath = os.path.join(self._cache_dir, chunk_filename) @@ -75,7 +74,7 @@ def download_chunk_from_index(self, chunk_index: int) -> None: self.download_file(remote_chunkpath, local_chunkpath) - logger.debug(_get_log_msg({"name": f"download_chunk_from_index_{chunk_index}", "ph": "E"})) + logger.debug(_get_log_msg({"name": f"download_chunk_{chunk_index}", "ph": "E"})) def download_chunk_bytes_from_index(self, chunk_index: int, offset: int, length: int) -> bytes: chunk_filename = self._chunks[chunk_index]["filename"] @@ -118,12 +117,9 @@ def __init__( **kwargs: Any, ): super().__init__(remote_dir, cache_dir, chunks, storage_options) - self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 # check if kwargs contains session_options self.session_options = kwargs.get("session_options", {}) - - if not self._s5cmd_available or _DISABLE_S5CMD: - self._client = S3Client(storage_options=self._storage_options, session_options=self.session_options) + self._client = S3Client(storage_options=self._storage_options, session_options=self.session_options) def download_file(self, remote_filepath: str, local_filepath: str) -> None: obj = parse.urlparse(remote_filepath) @@ -138,61 +134,19 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: suppress(Timeout, FileNotFoundError), FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0), ): - if self._s5cmd_available and not _DISABLE_S5CMD: - env = None - if self._storage_options: - env = os.environ.copy() - env.update(self._storage_options) - - aws_no_sign_request = self._storage_options.get("AWS_NO_SIGN_REQUEST", "no").lower() == "yes" - # prepare the s5cmd command - no_signed_option = "--no-sign-request" if aws_no_sign_request else None - cmd_parts = ["s5cmd", no_signed_option, "cp", remote_filepath, local_filepath] - cmd = " ".join(part for part in cmd_parts if part) - - proc = subprocess.Popen( - cmd, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, + from boto3.s3.transfer import TransferConfig + + extra_args: dict[str, Any] = {} + + if not os.path.exists(local_filepath): + # Issue: https://github.com/boto/boto3/issues/3113 + self._client.client.download_file( + obj.netloc, + obj.path.lstrip("/"), + local_filepath, + ExtraArgs=extra_args, + Config=TransferConfig(use_threads=False), ) - return_code = proc.wait() - - if return_code != 0: - stderr_output = proc.stderr.read().decode().strip() if proc.stderr else "" - error_message = ( - f"Failed to execute command `{cmd}` (exit code: {return_code}). " - "This might be due to an incorrect file path, insufficient permissions, or network issues. " - "To resolve this issue, you can either:\n" - "- Pass `storage_options` with the necessary credentials and endpoint. 
\n" - "- Example:\n" - " storage_options = {\n" - ' "AWS_ACCESS_KEY_ID": "your-key",\n' - ' "AWS_SECRET_ACCESS_KEY": "your-secret",\n' - ' "S3_ENDPOINT_URL": "https://s3.example.com" (Optional if using AWS)\n' - " }\n" - "- or disable `s5cmd` by setting `DISABLE_S5CMD=1` if `storage_options` do not work.\n" - ) - if stderr_output: - error_message += ( - f"For further debugging, please check the command output below:\n{stderr_output}" - ) - raise RuntimeError(error_message) - else: - from boto3.s3.transfer import TransferConfig - - extra_args: dict[str, Any] = {} - - if not os.path.exists(local_filepath): - # Issue: https://github.com/boto/boto3/issues/3113 - self._client.client.download_file( - obj.netloc, - obj.path.lstrip("/"), - local_filepath, - ExtraArgs=extra_args, - Config=TransferConfig(use_threads=False), - ) def download_bytes(self, remote_filepath: str, offset: int, length: int, local_chunkpath: str) -> bytes: obj = parse.urlparse(remote_filepath) @@ -296,6 +250,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: if not os.path.exists(local_filepath): # Issue: https://github.com/boto/boto3/issues/3113 + t0 = time() self._client.client.download_file( obj.netloc, obj.path.lstrip("/"), @@ -303,6 +258,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: ExtraArgs=extra_args, Config=TransferConfig(use_threads=False), ) + print("DOWNLOAD TIME", time() - t0) def download_bytes(self, remote_filepath: str, offset: int, length: int, local_chunkpath: str) -> bytes: obj = parse.urlparse(remote_filepath) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 1383fe250..74e4687f2 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -20,6 +20,7 @@ from multiprocessing import Queue from time import sleep, time from typing import Any, Optional, Union +from datetime import datetime import numpy as np import torch @@ -31,6 +32,7 @@ _POLARS_AVAILABLE, _PYARROW_AVAILABLE, _TORCH_DTYPES_MAPPING, + _DEBUG, ) from litdata.debugger import ChromeTraceColors, _get_log_msg from litdata.streaming.serializers import Serializer @@ -197,9 +199,6 @@ def load_item_from_chunk( # => 3 * 4 = 12 # each takes 4 bytes # => offset = 12 # - logger.debug( - _get_log_msg({"name": f"load_item_from_chunk_for_chunk_index_{chunk_index}_and_index_{index}", "ph": "B"}) - ) offset = (1 + (index - begin) if index >= begin else index + 1) * 4 if chunk_filepath != self._chunk_filepath: @@ -212,12 +211,17 @@ def load_item_from_chunk( exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes if not requested_force_download and (time() - start_time) > _FORCE_DOWNLOAD_TIME: + if _DEBUG: + print(f"[ItemLoader] Requested force download for {chunk_filepath} at {datetime.now().isoformat()}") self.force_download(chunk_index) requested_force_download = True if (time() - start_time) > _MAX_WAIT_TIME: raise FileNotFoundError(f"The {chunk_filepath} hasn't been found.") + if time() - start_time > 5: + print("WAIT TIME", time() - start_time) + self._chunk_filepath = chunk_filepath if self._open_handle is not None: @@ -238,10 +242,6 @@ def load_item_from_chunk( else: item_data = self.deserialize(data) - logger.debug( - _get_log_msg({"name": f"load_item_from_chunk_for_chunk_index_{chunk_index}_and_index_{index}", "ph": "E"}) - ) - return item_data def _load_encrypted_data( @@ -330,18 +330,19 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None: logger.debug( 
_get_log_msg( { - "name": f"delete_chunk_for_chunk_index_{chunk_index}", + "name": f"delete_chunk_{chunk_index}", "ph": "B", "cname": ChromeTraceColors.BRIGHT_RED, } ) ) if os.path.exists(chunk_filepath): + print(f"delete_chunk_{chunk_index}") os.remove(chunk_filepath) logger.debug( _get_log_msg( { - "name": f"delete_chunk_for_chunk_index_{chunk_index}", + "name": f"delete_chunk_{chunk_index}", "ph": "E", "cname": ChromeTraceColors.BRIGHT_RED, } @@ -492,10 +493,6 @@ def load_item_from_chunk( if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath): del self._chunk_filepaths[chunk_filepath] - logger.debug( - _get_log_msg({"name": f"load_item_from_chunk_for_chunk_index_{chunk_index}_and_index_{index}", "ph": "B"}) - ) - if chunk_filepath not in self._chunk_filepaths: exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > filesize_bytes @@ -536,16 +533,13 @@ def load_item_from_chunk( # count: number of tokens to read from buffer => `self._block_size` data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) # type: ignore - logger.debug( - _get_log_msg({"name": f"load_item_from_chunk_for_chunk_index_{chunk_index}_and_index_{index}", "ph": "E"}) - ) return data def delete(self, chunk_index: int, chunk_filepath: str) -> None: logger.debug( _get_log_msg( { - "name": f"delete_chunk_for_chunk_index_{chunk_index}", + "name": f"delete_chunk_{chunk_index}", "ph": "B", "cname": ChromeTraceColors.BRIGHT_RED, } @@ -563,7 +557,7 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None: logger.debug( _get_log_msg( { - "name": f"delete_chunk_for_chunk_index_{chunk_index}", + "name": f"delete_chunk_{chunk_index}", "ph": "E", "cname": ChromeTraceColors.BRIGHT_RED, } @@ -682,9 +676,6 @@ def load_item_from_chunk( if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath): del self._chunk_filepaths[chunk_filepath] - logger.debug( - _get_log_msg({"name": f"load_item_from_chunk_for_chunk_index_{chunk_index}_and_index_{index}", "ph": "B"}) - ) if chunk_filepath not in self._chunk_filepaths: exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes @@ -701,9 +692,6 @@ def load_item_from_chunk( else: item_data = self._get_item(chunk_index, chunk_filepath, relative_index) - logger.debug( - _get_log_msg({"name": f"load_item_from_chunk_for_chunk_index_{chunk_index}_and_index_{index}", "ph": "E"}) - ) return item_data def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_index: int) -> Any: @@ -792,7 +780,7 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None: logger.debug( _get_log_msg( { - "name": f"delete_chunk_for_chunk_index_{chunk_index}", + "name": f"delete_chunk_{chunk_index}", "ph": "B", "cname": ChromeTraceColors.BRIGHT_RED, } @@ -810,7 +798,7 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None: logger.debug( _get_log_msg( { - "name": f"delete_chunk_for_chunk_index_{chunk_index}", + "name": f"delete_chunk_{chunk_index}", "ph": "E", "cname": ChromeTraceColors.BRIGHT_RED, } diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index f45d5efed..5779e7924 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -19,6 +19,7 @@ from queue import Empty, Queue from threading import Event, Thread from typing import Any, Optional, Union +from datetime import datetime import numpy as np from filelock import FileLock, Timeout @@ -55,7 +56,7 @@ def __init__( item_loader: 
BaseItemLoader, distributed_env: _DistributedEnv, max_cache_size: Optional[int] = None, - max_pre_download: int = 2, + max_pre_download: int = 5, rank: Optional[int] = None, ) -> None: super().__init__(daemon=True) @@ -64,6 +65,7 @@ def __init__( self._max_pre_download = max_pre_download self._pre_download_counter = 0 self._distributed_env = distributed_env + self._worker_env = _WorkerEnv.detect() self._chunks_index_to_be_deleted: list[int] = [] self._max_cache_size = max_cache_size @@ -80,6 +82,9 @@ def __init__( # Check whether a dataset slice fits on the node num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False + if self._worker_env.rank == 0: + print(f"Delete chunks when processed: {self._delete_chunks_when_processed}") + self._has_exited = False def download(self, chunk_indexes: list[int]) -> None: @@ -111,7 +116,6 @@ def _decrement_local_lock(self, chunk_index: int) -> int: if not os.path.exists(countpath): return 0 with open(countpath) as count_f: - logger.debug(_get_log_msg({"name": f"decrement_local_lock_for_ {chunk_filepath}", "ph": "B"})) try: curr_count = int(count_f.read().strip()) except Exception: @@ -125,8 +129,9 @@ def _decrement_local_lock(self, chunk_index: int) -> int: os.remove(countpath + ".lock") else: with open(countpath, "w+") as count_f: + logger.debug(_get_log_msg({"name": f"decrement_lock_{chunk_index}_to_{curr_count}", "ph": "B"})) count_f.write(str(curr_count)) - logger.debug(_get_log_msg({"name": f"decrement_local_lock_for_ {chunk_filepath}", "ph": "E"})) + logger.debug(_get_log_msg({"name": f"decrement_lock_{chunk_index}_to_{curr_count}", "ph": "E"})) return curr_count return 0 @@ -149,8 +154,8 @@ def _apply_delete(self, chunk_index: int, skip_lock: bool = False) -> None: self._item_loader.delete(chunk_index, chunk_filepath) - if _DEBUG: - print(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") + # if _DEBUG: + # print(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") base_name = os.path.basename(chunk_filepath) base_prefix = os.path.splitext(base_name)[0] @@ -208,7 +213,7 @@ def _force_download(self) -> None: self._apply_delete(chunk_index, skip_lock=True) if _DEBUG: chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] - print(f"Requested force download for {chunk_filepath} by {self._rank}") + print(f"[Reader] Requested force download for {chunk_filepath} by {self._rank} at {datetime.now().isoformat()}") self._config.download_chunk_from_index(chunk_index, skip_lock=True) @@ -313,6 +318,7 @@ def __init__( self._prepare_thread: Optional[PrepareChunksThread] = None self._item_loader = item_loader or PyTreeLoader() self._last_chunk_index: Optional[int] = None + self._last_chunk_size: Optional[int] = None self._chunks_queued_for_download = False self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size or 0)) self._storage_options = storage_options @@ -367,9 +373,6 @@ def setup_thread_and_download_chunk(self, index: ChunkedIndex) -> None: assert self._prepare_thread self._prepare_thread.download([index.chunk_index]) - if self._last_chunk_index is None: - self._last_chunk_index = index.chunk_index - @property def config(self) -> ChunksConfig: if self._config is None: @@ -392,9 +395,6 @@ def read(self, index: ChunkedIndex) -> Any: Prefetching should reduce the wait time to be the batch available. 
""" - logger.debug( - _get_log_msg({"name": f"reader_reading_chunk_index_{index.chunk_index}_and_index_{index.index}", "ph": "B"}) - ) if not isinstance(index, ChunkedIndex): raise ValueError("The Reader.read(...) method expects a chunked Index.") @@ -434,20 +434,41 @@ def read(self, index: ChunkedIndex) -> Any: and (self._config._remote_dir or self._config._compressor) and index.chunk_index != self._last_chunk_index and self._prepare_thread is not None + and self._last_chunk_index is not None ): - assert self._last_chunk_index is not None - # inform the chunk has been completely consumed self._prepare_thread._decrement_local_lock(self._last_chunk_index) self._prepare_thread.delete([self._last_chunk_index]) if index.chunk_index != self._last_chunk_index: + if self._last_chunk_index is not None: + # 2. Log the "End" event for the previous chunk. + print(f"read_chunk_{self._last_chunk_index}_size_{self._last_chunk_size}") + logger.debug( + _get_log_msg({ + "name": f"read_chunk_{self._last_chunk_index}_size_{self._last_chunk_size}", + "ph": "E" + }) + ) + + print(f"read_chunk_{index.chunk_index}_size_{index.chunk_size}") + + # 2. Log the "Begin" event for the NEW chunk. + logger.debug( + _get_log_msg({ + "name": f"read_chunk_{index.chunk_index}_size_{index.chunk_size}", + "ph": "B" + }) + ) + + # Close the memory-mapped file for the last chunk index if isinstance(self._item_loader, (TokensLoader, ParquetLoader)) and self._last_chunk_index is not None: self._item_loader.close(self._last_chunk_index) # track the new chunk index as the latest one self._last_chunk_index = index.chunk_index + self._last_chunk_size = index.chunk_size if index.is_last_index and self._prepare_thread: # inform the thread it is time to stop @@ -465,11 +486,9 @@ def read(self, index: ChunkedIndex) -> Any: self._prepare_thread = None self._item_loader.close(self._last_chunk_index) self._last_chunk_index = None + self._last_chunk_size = None self._chunks_queued_for_download = False - logger.debug( - _get_log_msg({"name": f"reader_reading_chunk_index_{index.chunk_index}_and_index_{index.index}", "ph": "E"}) - ) return item def read_item_bytes(self, index: ChunkedIndex, begin: int) -> bytes: diff --git a/src/litdata/streaming/sampler.py b/src/litdata/streaming/sampler.py index 3e01ac607..b8df45829 100644 --- a/src/litdata/streaming/sampler.py +++ b/src/litdata/streaming/sampler.py @@ -45,6 +45,7 @@ class ChunkedIndex: index: int chunk_index: int + chunk_size: Optional[int] = None chunk_indexes: Optional[list[int]] = None is_last_index: bool = False diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 9927fdfd9..ae631ab90 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -21,7 +21,7 @@ from copy import deepcopy from itertools import chain from typing import Any, Optional - +import struct import numpy as np import tifffile import torch @@ -232,43 +232,53 @@ def can_serialize(self, item: bytes) -> bool: class TensorSerializer(Serializer): - """The TensorSerializer serialize and deserialize tensor to and from bytes.""" + """An optimized TensorSerializer that is compatible with deepcopy/pickle.""" def __init__(self) -> None: super().__init__() self._dtype_to_indices = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()} + self._header_struct_format = ">II" + self._header_struct = struct.Struct(self._header_struct_format) def serialize(self, item: torch.Tensor) -> tuple[bytes, Optional[str]]: - dtype_indice = self._dtype_to_indices[item.dtype] - 
data = [np.uint32(dtype_indice).tobytes()] - data.append(np.uint32(len(item.shape)).tobytes()) - for dim in item.shape: - data.append(np.uint32(dim).tobytes()) - data.append(item.numpy().tobytes(order="C")) - return b"".join(data), None - + if item.device.type != "cpu": + item = item.cpu() + + dtype_indice = self._dtype_to_indices[item.dtype] + + numpy_item = item.numpy(force=True) + rank = len(numpy_item.shape) + shape_format = f">{rank}I" + header_bytes = self._header_struct.pack(dtype_indice, rank) + shape_bytes = struct.pack(shape_format, *numpy_item.shape) + data_bytes = numpy_item.tobytes() + return b"".join([header_bytes, shape_bytes, data_bytes]), None + + # ... (rest of the class remains the same) ... def deserialize(self, data: bytes) -> torch.Tensor: - dtype_indice = np.frombuffer(data[0:4], np.uint32).item() + buffer_view = memoryview(data) + dtype_indice, rank = self._header_struct.unpack_from(buffer_view, 0) dtype = _TORCH_DTYPES_MAPPING[dtype_indice] - shape_size = np.frombuffer(data[4:8], np.uint32).item() - shape = [] - for shape_idx in range(shape_size): - shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) - idx_start = 8 + 4 * shape_size - idx_end = len(data) - if idx_end > idx_start: - tensor = torch.frombuffer(data[idx_start:idx_end], dtype=dtype) - else: - assert idx_start == idx_end, "The starting index should never be greater than end ending index." - tensor = torch.empty(shape, dtype=dtype) - shape = torch.Size(shape) - if tensor.shape == shape: - return tensor - return torch.reshape(tensor, shape) - - def can_serialize(self, item: torch.Tensor) -> bool: + header_size = self._header_struct.size + shape = struct.unpack_from(f">{rank}I", buffer_view, header_size) + data_start_offset = header_size + (rank * 4) + if data_start_offset < len(buffer_view): + tensor_1d = torch.frombuffer(buffer_view[data_start_offset:], dtype=dtype) + return tensor_1d.reshape(shape) + return torch.empty(shape, dtype=dtype) + + def can_serialize(self, item: any) -> bool: return isinstance(item, torch.Tensor) and len(item.shape) != 1 + def __getstate__(self) -> dict: + state = self.__dict__.copy() + del state["_header_struct"] + return state + + def __setstate__(self, state: dict) -> None: + self.__dict__.update(state) + self._header_struct = struct.Struct(self._header_struct_format) + class NoHeaderTensorSerializer(Serializer): """The TensorSerializer serialize and deserialize tensor to and from bytes.""" diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 7fa23ee38..669c2be2e 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -12,12 +12,12 @@ # limitations under the License. 
import copy -from typing import Any +from typing import Any, List, Tuple import numpy as np from litdata.streaming.item_loader import Interval -from litdata.utilities.env import _DistributedEnv +from litdata.utilities.env import _DistributedEnv, _WorkerEnv def _intra_node_chunk_shuffle( @@ -64,82 +64,120 @@ def _group_chunks_by_nodes( def _associate_chunks_and_intervals_to_workers( distributed_env: _DistributedEnv, - indexes: Any, - chunk_intervals: list[Interval], + indexes: List[int], + chunk_intervals: List[Interval], drop_last: bool = False, num_workers: int = 1, batch_size: int = 1, -) -> tuple[list[list[int]], list[Any]]: + only_multiple_of_batch_size: bool = True, +) -> Tuple[List[List[int]], List[List[Interval]]]: + """ + Associates chunks and their intervals to workers in a distributed environment. + """ + + if only_multiple_of_batch_size: + filtered_indexes = [] + filtered_chunk_intervals = [] + for index, interval in zip(indexes, chunk_intervals): + num_items_in_chunk = interval[2] - interval[1] + if num_items_in_chunk > 0 and num_items_in_chunk % batch_size == 0: + filtered_indexes.append(index) + filtered_chunk_intervals.append(interval) + indexes = filtered_indexes + chunk_intervals = filtered_chunk_intervals + num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) + + if batch_size == 0: + raise ValueError("batch_size cannot be zero.") + max_batches = num_items // batch_size global_num_workers = distributed_env.world_size * num_workers - - num_items_per_workers: Any = [] - - for rank in range(distributed_env.world_size): - tmp_arr = [0 for _ in range(num_workers)] - - num_batches_per_rank = int(max_batches // distributed_env.world_size) - base_batches = num_batches_per_rank // num_workers - rem_batches = num_batches_per_rank % num_workers - tmp_arr = [base_batches + (1 if i < rem_batches else 0) for i in range(num_workers)] - - if rank == distributed_env.world_size - 1: - # Find how batches were associated - num_assigned_items = batch_size * (sum(num_items_per_workers) + sum(tmp_arr)) - - # Multiply with the batch_size to get the number of items - if batch_size > 1: - tmp_arr = [x * batch_size for x in tmp_arr] - num_items_per_workers = [x * batch_size for x in num_items_per_workers] - - # If there are items left to assign, let's give it the last worker - left_items = num_items - num_assigned_items - if not drop_last and left_items > 0: - tmp_arr[rem_batches % num_workers] += left_items - - num_items_per_workers.extend(tmp_arr) - else: - num_items_per_workers.extend(tmp_arr) - - chunks_per_workers: list[list[int]] = [[] for _ in range(global_num_workers)] - intervals_per_workers: list[list[list[int]]] = [[] for _ in range(global_num_workers)] - - # 4. Assign the chunk & intervals to each rank + + if global_num_workers == 0: + raise ValueError("Cannot associate chunks with zero workers.") + + # --- FIX 1: Correctly and simply calculate items per worker --- + # 1. Distribute total batches evenly among all workers globally. + num_batches_per_worker = [0] * global_num_workers + if max_batches > 0: + base_batches_per_worker = max_batches // global_num_workers + rem_batches = max_batches % global_num_workers + for i in range(global_num_workers): + num_batches_per_worker[i] = base_batches_per_worker + (1 if i < rem_batches else 0) + + # 2. Convert batch counts to item counts for each worker in one step. + num_items_per_workers = [n * batch_size for n in num_batches_per_worker] + + # 3. Add remaining items to the last worker if not dropping the last batch. 
+ rem_items = num_items % batch_size + if not drop_last and rem_items > 0: + # Assign remainder to the last worker that has items, or the very last worker + # This prevents assigning remainders to a worker that was supposed to get 0 items. + target_worker = -1 + for i in range(global_num_workers - 1, -1, -1): + if num_items_per_workers[i] > 0: + target_worker = i + break + num_items_per_workers[target_worker] += rem_items + + chunks_per_workers: List[List[int]] = [[] for _ in range(global_num_workers)] + intervals_per_workers: List[List[Interval]] = [[] for _ in range(global_num_workers)] + + # --- FIX 2: Use a single, persistent worker index --- + worker_idx = 0 + + # 4. Assign the chunk & intervals to each worker sequentially. for chunk_index, chunk_interval in zip(indexes, chunk_intervals): - rank = 0 + current_chunk_interval = chunk_interval + # Loop until the current chunk is fully assigned while True: - if rank == len(num_items_per_workers): + # Check if all workers have been filled + if worker_idx >= global_num_workers: break + + items_needed_by_worker = num_items_per_workers[worker_idx] - items_left_to_assign = num_items_per_workers[rank] - - if items_left_to_assign == 0: - rank += 1 + if items_needed_by_worker == 0: + worker_idx += 1 continue - - items_in_chunk = chunk_interval[2] - chunk_interval[1] - + + items_in_chunk = current_chunk_interval[2] - current_chunk_interval[1] + if items_in_chunk == 0: - break - - if items_in_chunk > items_left_to_assign: - chunks_per_workers[rank].append(chunk_index) - - chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = chunk_interval - - intervals_per_workers[rank].append( - [chunk_start, chunk_roi_start, chunk_roi_start + items_left_to_assign, chunk_end] + break # Move to the next chunk + + if items_in_chunk > items_needed_by_worker: + # The worker needs only a part of the current chunk + chunks_per_workers[worker_idx].append(chunk_index) + + chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = current_chunk_interval + split_point = chunk_roi_start + items_needed_by_worker + + intervals_per_workers[worker_idx].append( + (chunk_start, chunk_roi_start, split_point, chunk_end) ) - chunk_interval = Interval(chunk_start, chunk_roi_start + items_left_to_assign, chunk_roi_end, chunk_end) - num_items_per_workers[rank] = 0 - rank += 1 + + # Update the chunk interval to represent the remaining part + current_chunk_interval = (chunk_start, split_point, chunk_roi_end, chunk_end) + + num_items_per_workers[worker_idx] = 0 + worker_idx += 1 else: - chunks_per_workers[rank].append(chunk_index) - intervals_per_workers[rank].append(list(chunk_interval)) - num_items_per_workers[rank] -= items_in_chunk - break + # The worker takes the whole (remaining) chunk and may need more + chunks_per_workers[worker_idx].append(chunk_index) + intervals_per_workers[worker_idx].append(current_chunk_interval) + num_items_per_workers[worker_idx] -= items_in_chunk + break # Move to the next chunk + + worker_env = _WorkerEnv.detect() + + if worker_env.rank == 0: + print("HERE num_items_per_workers", num_items_per_workers) + for idx, internal in enumerate(intervals_per_workers): + total_items = sum(x[2] - x[1] for x in internal) + print(f"Worker {idx}: Batch Size={batch_size}, Items={total_items}, Chunks={internal}") return chunks_per_workers, intervals_per_workers From 8cede3bf1c7fbf6fc8631f600cbcf8fa983c0eec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Oct 2025 11:53:22 +0000 Subject: [PATCH 04/15] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/debugger.py | 34 +++++++++++++++--------- src/litdata/streaming/config.py | 9 ++++--- src/litdata/streaming/dataset.py | 8 +++--- src/litdata/streaming/downloader.py | 4 +-- src/litdata/streaming/item_loader.py | 8 +++--- src/litdata/streaming/reader.py | 21 ++++++--------- src/litdata/streaming/serializers.py | 11 ++++---- src/litdata/utilities/shuffle.py | 39 ++++++++++++---------------- 8 files changed, 71 insertions(+), 63 deletions(-) diff --git a/src/litdata/debugger.py b/src/litdata/debugger.py index 023f59597..bc2b1e6c5 100644 --- a/src/litdata/debugger.py +++ b/src/litdata/debugger.py @@ -12,16 +12,19 @@ # limitations under the License. import logging +import os import re import threading import time -from litdata.utilities.env import _DistributedEnv, _WorkerEnv, _is_in_dataloader_worker from functools import lru_cache -import os + +from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv + class TimedFlushFileHandler(logging.FileHandler): """FileHandler that flushes every N seconds in a background thread.""" - def __init__(self, filename, mode='a', flush_interval=2): + + def __init__(self, filename, mode="a", flush_interval=2): super().__init__(filename, mode) self.flush_interval = flush_interval self._stop_event = threading.Event() @@ -38,8 +41,10 @@ def close(self): self.flush() super().close() + class EnvConfigFilter(logging.Filter): """A logging filter that reads its configuration from environment variables.""" + def __init__(self): super().__init__() self.name_re = re.compile(r"name:\s*([^;]+);") @@ -66,12 +71,14 @@ def filter(self, record): return True + def get_logger_level(level: str) -> int: level = level.upper() if level in logging._nameToLevel: return logging._nameToLevel[level] raise ValueError(f"Invalid log level: {level}") + class LitDataLogger: _instance = None _lock = threading.Lock() @@ -85,7 +92,7 @@ def __new__(cls, *args, **kwargs): def __init__(self, name="litdata", flush_interval=2): if hasattr(self, "logger"): - return # Already initialized + return # Already initialized self.logger = logging.getLogger(name) self.logger.propagate = False @@ -103,10 +110,7 @@ def _setup_logger(self): if self.logger.handlers: return self.logger.setLevel(self.log_level) - formatter = logging.Formatter( - "ts:%(created)s;" - "PID:%(process)d; TID:%(thread)d; %(message)s" - ) + formatter = logging.Formatter("ts:%(created)s;PID:%(process)d; TID:%(thread)d; %(message)s") handler = TimedFlushFileHandler(self.log_file, flush_interval=self.flush_interval) handler.setFormatter(formatter) handler.setLevel(self.log_level) @@ -118,9 +122,11 @@ def _setup_logger(self): def get_logger(self): return self.logger -def enable_tracer(flush_interval: int = 5, item_loader=True, iterating_dataset=True, getitem_dataset_for_chunk_index=True) -> logging.Logger: - """ - Convenience function to enable and configure litdata logging. + +def enable_tracer( + flush_interval: int = 5, item_loader=True, iterating_dataset=True, getitem_dataset_for_chunk_index=True +) -> logging.Logger: + """Convenience function to enable and configure litdata logging. This function SETS the environment variables that control the logging behavior. 
""" os.environ["LITDATA_LOG_FILE"] = "litdata_debug.log" @@ -131,6 +137,7 @@ def enable_tracer(flush_interval: int = 5, item_loader=True, iterating_dataset=T master_logger = LitDataLogger(flush_interval=flush_interval).get_logger() return master_logger + def _get_log_msg(data: dict) -> str: log_msg = "" if "name" not in data or "ph" not in data: @@ -141,6 +148,7 @@ def _get_log_msg(data: dict) -> str: log_msg += f"{key}: {value};" return log_msg + def env_info() -> dict: if _is_in_dataloader_worker(): return _cached_env_info() @@ -155,6 +163,7 @@ def env_info() -> dict: "worker_rank": worker_env.rank, } + @lru_cache(maxsize=1) def _cached_env_info() -> dict: dist_env = _DistributedEnv.detect() @@ -167,6 +176,7 @@ def _cached_env_info() -> dict: "worker_rank": worker_env.rank, } + # Chrome trace colors class ChromeTraceColors: PINK = "thread_state_iowait" @@ -192,4 +202,4 @@ class ChromeTraceColors: LIGHT_RED = "cq_build_failed" MUSTARD_YELLOW = "cq_build_attempt_running" NEON_GREEN = "cq_build_attempt_passed" - DARK_RED = "cq_build_attempt_failed" \ No newline at end of file + DARK_RED = "cq_build_attempt_failed" diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 0af7513ad..1873f118a 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -18,7 +18,6 @@ from typing import Any, Optional from litdata.constants import _INDEX_FILENAME, _MAX_WAIT_TIME -from litdata.debugger import ChromeTraceColors, _get_log_msg from litdata.streaming.compression import _COMPRESSORS, Compressor from litdata.streaming.downloader import get_downloader from litdata.streaming.item_loader import BaseItemLoader, Interval, PyTreeLoader, TokensLoader @@ -139,7 +138,9 @@ def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) - if self._downloader is not None and not skip_lock: # We don't want to redownload the base, but we should mark # it as having been requested by something - self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", ""), chunk_index) + self._downloader._increment_local_lock( + local_chunkpath.replace(f".{self._compressor_name}", ""), chunk_index + ) pass return @@ -147,7 +148,9 @@ def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) - return if not skip_lock: - self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", ""), chunk_index) + self._downloader._increment_local_lock( + local_chunkpath.replace(f".{self._compressor_name}", ""), chunk_index + ) self._downloader.download_chunk_from_index(chunk_index) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 7a5599a3c..1dca6c5b4 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -21,7 +21,6 @@ from litdata import __version__ from litdata.constants import _INDEX_FILENAME -from litdata.debugger import _get_log_msg from litdata.helpers import _check_version_and_prompt_upgrade from litdata.streaming import Cache from litdata.streaming.item_loader import BaseItemLoader, ParquetLoader @@ -493,7 +492,10 @@ def __next__(self) -> Any: chunk_indexes = None if self.has_triggered_download else self.worker_chunks[self.worker_next_chunk_index - 1 :] is_last_index = (self.worker_next_chunk_index) == self.num_chunks and len(self.upcoming_indexes) == 0 chunk_index = self.worker_chunks[self.worker_next_chunk_index - 1] - chunk_size = self.worker_intervals[self.worker_next_chunk_index - 1][2] - 
self.worker_intervals[self.worker_next_chunk_index - 1][1] + chunk_size = ( + self.worker_intervals[self.worker_next_chunk_index - 1][2] + - self.worker_intervals[self.worker_next_chunk_index - 1][1] + ) # Call the `__getitem__` method. data = self.__getitem__( @@ -503,7 +505,7 @@ def __next__(self) -> Any: # We provide the chunks indexes only one the first chunk_indexes=chunk_indexes, is_last_index=is_last_index, - chunk_size=chunk_size + chunk_size=chunk_size, ) ) diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index eb94a7c99..7c28b2122 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -18,14 +18,14 @@ import tempfile from abc import ABC from contextlib import suppress +from time import time from typing import Any, Optional from urllib import parse -from time import time + from filelock import FileLock, Timeout from litdata.constants import ( _AZURE_STORAGE_AVAILABLE, - _DISABLE_S5CMD, _GOOGLE_STORAGE_AVAILABLE, _HF_HUB_AVAILABLE, _INDEX_FILENAME, diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 74e4687f2..2f343e369 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -16,23 +16,23 @@ from abc import ABC, abstractmethod from collections import defaultdict, namedtuple from copy import deepcopy +from datetime import datetime from io import BytesIO, FileIO from multiprocessing import Queue from time import sleep, time from typing import Any, Optional, Union -from datetime import datetime import numpy as np import torch from litdata.constants import ( + _DEBUG, _FORCE_DOWNLOAD_TIME, _MAX_WAIT_TIME, _NUMPY_DTYPES_MAPPING, _POLARS_AVAILABLE, _PYARROW_AVAILABLE, _TORCH_DTYPES_MAPPING, - _DEBUG, ) from litdata.debugger import ChromeTraceColors, _get_log_msg from litdata.streaming.serializers import Serializer @@ -212,7 +212,9 @@ def load_item_from_chunk( if not requested_force_download and (time() - start_time) > _FORCE_DOWNLOAD_TIME: if _DEBUG: - print(f"[ItemLoader] Requested force download for {chunk_filepath} at {datetime.now().isoformat()}") + print( + f"[ItemLoader] Requested force download for {chunk_filepath} at {datetime.now().isoformat()}" + ) self.force_download(chunk_index) requested_force_download = True diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 5779e7924..e7acc755d 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -16,10 +16,10 @@ import os import warnings from contextlib import suppress +from datetime import datetime from queue import Empty, Queue from threading import Event, Thread from typing import Any, Optional, Union -from datetime import datetime import numpy as np from filelock import FileLock, Timeout @@ -213,7 +213,9 @@ def _force_download(self) -> None: self._apply_delete(chunk_index, skip_lock=True) if _DEBUG: chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] - print(f"[Reader] Requested force download for {chunk_filepath} by {self._rank} at {datetime.now().isoformat()}") + print( + f"[Reader] Requested force download for {chunk_filepath} by {self._rank} at {datetime.now().isoformat()}" + ) self._config.download_chunk_from_index(chunk_index, skip_lock=True) @@ -445,22 +447,15 @@ def read(self, index: ChunkedIndex) -> Any: # 2. Log the "End" event for the previous chunk. 
print(f"read_chunk_{self._last_chunk_index}_size_{self._last_chunk_size}") logger.debug( - _get_log_msg({ - "name": f"read_chunk_{self._last_chunk_index}_size_{self._last_chunk_size}", - "ph": "E" - }) + _get_log_msg( + {"name": f"read_chunk_{self._last_chunk_index}_size_{self._last_chunk_size}", "ph": "E"} + ) ) print(f"read_chunk_{index.chunk_index}_size_{index.chunk_size}") # 2. Log the "Begin" event for the NEW chunk. - logger.debug( - _get_log_msg({ - "name": f"read_chunk_{index.chunk_index}_size_{index.chunk_size}", - "ph": "B" - }) - ) - + logger.debug(_get_log_msg({"name": f"read_chunk_{index.chunk_index}_size_{index.chunk_size}", "ph": "B"})) # Close the memory-mapped file for the last chunk index if isinstance(self._item_loader, (TokensLoader, ParquetLoader)) and self._last_chunk_index is not None: diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index ae631ab90..897c8719d 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -14,6 +14,7 @@ import io import os import pickle +import struct import tempfile from abc import ABC, abstractmethod from collections import OrderedDict @@ -21,7 +22,7 @@ from copy import deepcopy from itertools import chain from typing import Any, Optional -import struct + import numpy as np import tifffile import torch @@ -243,9 +244,9 @@ def __init__(self) -> None: def serialize(self, item: torch.Tensor) -> tuple[bytes, Optional[str]]: if item.device.type != "cpu": item = item.cpu() - - dtype_indice = self._dtype_to_indices[item.dtype] - + + dtype_indice = self._dtype_to_indices[item.dtype] + numpy_item = item.numpy(force=True) rank = len(numpy_item.shape) shape_format = f">{rank}I" @@ -253,7 +254,7 @@ def serialize(self, item: torch.Tensor) -> tuple[bytes, Optional[str]]: shape_bytes = struct.pack(shape_format, *numpy_item.shape) data_bytes = numpy_item.tobytes() return b"".join([header_bytes, shape_bytes, data_bytes]), None - + # ... (rest of the class remains the same) ... def deserialize(self, data: bytes) -> torch.Tensor: buffer_view = memoryview(data) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 669c2be2e..dd1c1848c 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -71,10 +71,7 @@ def _associate_chunks_and_intervals_to_workers( batch_size: int = 1, only_multiple_of_batch_size: bool = True, ) -> Tuple[List[List[int]], List[List[Interval]]]: - """ - Associates chunks and their intervals to workers in a distributed environment. - """ - + """Associates chunks and their intervals to workers in a distributed environment.""" if only_multiple_of_batch_size: filtered_indexes = [] filtered_chunk_intervals = [] @@ -87,13 +84,13 @@ def _associate_chunks_and_intervals_to_workers( chunk_intervals = filtered_chunk_intervals num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) - + if batch_size == 0: raise ValueError("batch_size cannot be zero.") - + max_batches = num_items // batch_size global_num_workers = distributed_env.world_size * num_workers - + if global_num_workers == 0: raise ValueError("Cannot associate chunks with zero workers.") @@ -123,10 +120,10 @@ def _associate_chunks_and_intervals_to_workers( chunks_per_workers: List[List[int]] = [[] for _ in range(global_num_workers)] intervals_per_workers: List[List[Interval]] = [[] for _ in range(global_num_workers)] - + # --- FIX 2: Use a single, persistent worker index --- - worker_idx = 0 - + worker_idx = 0 + # 4. 
Assign the chunk & intervals to each worker sequentially. for chunk_index, chunk_interval in zip(indexes, chunk_intervals): current_chunk_interval = chunk_interval @@ -136,32 +133,30 @@ def _associate_chunks_and_intervals_to_workers( # Check if all workers have been filled if worker_idx >= global_num_workers: break - + items_needed_by_worker = num_items_per_workers[worker_idx] if items_needed_by_worker == 0: worker_idx += 1 continue - + items_in_chunk = current_chunk_interval[2] - current_chunk_interval[1] - + if items_in_chunk == 0: - break # Move to the next chunk + break # Move to the next chunk if items_in_chunk > items_needed_by_worker: # The worker needs only a part of the current chunk chunks_per_workers[worker_idx].append(chunk_index) - + chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = current_chunk_interval split_point = chunk_roi_start + items_needed_by_worker - - intervals_per_workers[worker_idx].append( - (chunk_start, chunk_roi_start, split_point, chunk_end) - ) - + + intervals_per_workers[worker_idx].append((chunk_start, chunk_roi_start, split_point, chunk_end)) + # Update the chunk interval to represent the remaining part current_chunk_interval = (chunk_start, split_point, chunk_roi_end, chunk_end) - + num_items_per_workers[worker_idx] = 0 worker_idx += 1 else: @@ -169,7 +164,7 @@ def _associate_chunks_and_intervals_to_workers( chunks_per_workers[worker_idx].append(chunk_index) intervals_per_workers[worker_idx].append(current_chunk_interval) num_items_per_workers[worker_idx] -= items_in_chunk - break # Move to the next chunk + break # Move to the next chunk worker_env = _WorkerEnv.detect() From 9f7e1baa5591586196b44a901a6eb891eafe6d95 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 13:04:12 +0100 Subject: [PATCH 05/15] update --- src/litdata/utilities/shuffle.py | 52 ++----- tests/utilities/test_shuffle.py | 246 +++++++++++++------------------ 2 files changed, 113 insertions(+), 185 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 669c2be2e..16f1c0d98 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -69,13 +69,11 @@ def _associate_chunks_and_intervals_to_workers( drop_last: bool = False, num_workers: int = 1, batch_size: int = 1, - only_multiple_of_batch_size: bool = True, + multiple_of_batch_size_only: bool = False, ) -> Tuple[List[List[int]], List[List[Interval]]]: - """ - Associates chunks and their intervals to workers in a distributed environment. - """ - - if only_multiple_of_batch_size: + """Associates chunks and their intervals to workers in a distributed environment.""" + + if multiple_of_batch_size_only: filtered_indexes = [] filtered_chunk_intervals = [] for index, interval in zip(indexes, chunk_intervals): @@ -87,18 +85,13 @@ def _associate_chunks_and_intervals_to_workers( chunk_intervals = filtered_chunk_intervals num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) - if batch_size == 0: raise ValueError("batch_size cannot be zero.") - max_batches = num_items // batch_size global_num_workers = distributed_env.world_size * num_workers - if global_num_workers == 0: raise ValueError("Cannot associate chunks with zero workers.") - # --- FIX 1: Correctly and simply calculate items per worker --- - # 1. Distribute total batches evenly among all workers globally. 
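+    # Spread whole batches evenly across all global workers (earlier workers absorb any
+    # remainder batches), then convert the per-worker batch counts to item counts below.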
num_batches_per_worker = [0] * global_num_workers if max_batches > 0: base_batches_per_worker = max_batches // global_num_workers @@ -106,14 +99,9 @@ def _associate_chunks_and_intervals_to_workers( for i in range(global_num_workers): num_batches_per_worker[i] = base_batches_per_worker + (1 if i < rem_batches else 0) - # 2. Convert batch counts to item counts for each worker in one step. num_items_per_workers = [n * batch_size for n in num_batches_per_worker] - - # 3. Add remaining items to the last worker if not dropping the last batch. rem_items = num_items % batch_size if not drop_last and rem_items > 0: - # Assign remainder to the last worker that has items, or the very last worker - # This prevents assigning remainders to a worker that was supposed to get 0 items. target_worker = -1 for i in range(global_num_workers - 1, -1, -1): if num_items_per_workers[i] > 0: @@ -123,56 +111,40 @@ def _associate_chunks_and_intervals_to_workers( chunks_per_workers: List[List[int]] = [[] for _ in range(global_num_workers)] intervals_per_workers: List[List[Interval]] = [[] for _ in range(global_num_workers)] - - # --- FIX 2: Use a single, persistent worker index --- - worker_idx = 0 - - # 4. Assign the chunk & intervals to each worker sequentially. + worker_idx = 0 for chunk_index, chunk_interval in zip(indexes, chunk_intervals): current_chunk_interval = chunk_interval - - # Loop until the current chunk is fully assigned while True: - # Check if all workers have been filled if worker_idx >= global_num_workers: break - items_needed_by_worker = num_items_per_workers[worker_idx] - if items_needed_by_worker == 0: worker_idx += 1 continue - items_in_chunk = current_chunk_interval[2] - current_chunk_interval[1] - if items_in_chunk == 0: - break # Move to the next chunk - + break if items_in_chunk > items_needed_by_worker: - # The worker needs only a part of the current chunk chunks_per_workers[worker_idx].append(chunk_index) - chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = current_chunk_interval split_point = chunk_roi_start + items_needed_by_worker + # FIX: Ensure we always append an Interval object intervals_per_workers[worker_idx].append( - (chunk_start, chunk_roi_start, split_point, chunk_end) + Interval(chunk_start, chunk_roi_start, split_point, chunk_end) ) - - # Update the chunk interval to represent the remaining part - current_chunk_interval = (chunk_start, split_point, chunk_roi_end, chunk_end) - + # FIX: Ensure the updated interval is also an Interval object + current_chunk_interval = Interval(chunk_start, split_point, chunk_roi_end, chunk_end) num_items_per_workers[worker_idx] = 0 worker_idx += 1 else: - # The worker takes the whole (remaining) chunk and may need more chunks_per_workers[worker_idx].append(chunk_index) + # FIX: Ensure we always append an Interval object (not a list) intervals_per_workers[worker_idx].append(current_chunk_interval) num_items_per_workers[worker_idx] -= items_in_chunk - break # Move to the next chunk + break worker_env = _WorkerEnv.detect() - if worker_env.rank == 0: print("HERE num_items_per_workers", num_items_per_workers) for idx, internal in enumerate(intervals_per_workers): diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index b54cb1b44..555335d40 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -95,193 +95,149 @@ def test_group_chunks_by_nodes(): def test_associate_chunks_and_intervals_to_workers(): - indexes = [0, 1, 2, 3, 4, 5, 6, 7] - chunk_intervals = [ - Interval(0, 0, 50, 50), - 
Interval(0, 0, 50, 50), - Interval(0, 0, 50, 50), - Interval(0, 0, 50, 50), - Interval(0, 0, 50, 50), - Interval(0, 0, 50, 50), - Interval(0, 0, 50, 50), - Interval(0, 0, 50, 50), - ] + indexes = list(range(8)) + # Test Case 1: Even distribution + # 400 items / 4 global workers = 100 items/worker. + chunk_intervals = [Interval(0, 0, 50, 50)] * 8 workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(4, 1, 2), - indexes, - chunk_intervals, - drop_last=True, + _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, num_workers=1 ) - assert workers_chunks == [[0, 1], [2, 3], [4, 5], [6, 7]] assert workers_intervals == [ - [[0, 0, 50, 50], [0, 0, 50, 50]], - [[0, 0, 50, 50], [0, 0, 50, 50]], - [[0, 0, 50, 50], [0, 0, 50, 50]], - [[0, 0, 50, 50], [0, 0, 50, 50]], + [Interval(0, 0, 50, 50), Interval(0, 0, 50, 50)], + [Interval(0, 0, 50, 50), Interval(0, 0, 50, 50)], + [Interval(0, 0, 50, 50), Interval(0, 0, 50, 50)], + [Interval(0, 0, 50, 50), Interval(0, 0, 50, 50)], ] + # Test Case 2: Uneven distribution + # Total items = 422. With drop_last=True, 422 batches (size 1) / 4 workers = 105 batches/worker, remainder 2. + # So, workers get item counts of [106, 106, 105, 105]. chunk_intervals = [ - Interval(0, 0, 50, 50), - Interval(0, 0, 150, 150), - Interval(0, 0, 50, 50), - Interval(0, 0, 12, 12), - Interval(0, 0, 50, 50), - Interval(0, 0, 27, 27), - Interval(0, 0, 50, 50), - Interval(0, 0, 33, 33), + Interval(0, 0, 50, 50), Interval(0, 0, 150, 150), Interval(0, 0, 50, 50), Interval(0, 0, 12, 12), + Interval(0, 0, 50, 50), Interval(0, 0, 27, 27), Interval(0, 0, 50, 50), Interval(0, 0, 33, 33), ] - workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(4, 1, 2), - indexes, - chunk_intervals, - drop_last=True, + _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, num_workers=1 ) - assert workers_chunks == [[0, 1], [1, 2], [2, 3, 4, 5], [5, 6, 7]] - assert sum([interval[2] - interval[1] for interval in workers_intervals[0]]) == 105 - assert sum([interval[2] - interval[1] for interval in workers_intervals[1]]) == 105 - assert sum([interval[2] - interval[1] for interval in workers_intervals[2]]) == 105 - assert sum([interval[2] - interval[1] for interval in workers_intervals[3]]) == 105 - + assert sum(i.roi_end_idx - i.roi_start_idx for i in workers_intervals[0]) == 106 + assert sum(i.roi_end_idx - i.roi_start_idx for i in workers_intervals[1]) == 106 + assert sum(i.roi_end_idx - i.roi_start_idx for i in workers_intervals[2]) == 105 + assert sum(i.roi_end_idx - i.roi_start_idx for i in workers_intervals[3]) == 105 assert workers_intervals == [ - [[0, 0, 50, 50], [0, 0, 55, 150]], - [[0, 55, 150, 150], [0, 0, 10, 50]], - [[0, 10, 50, 50], [0, 0, 12, 12], [0, 0, 50, 50], [0, 0, 3, 27]], - [[0, 3, 27, 27], [0, 0, 50, 50], [0, 0, 31, 33]], + [Interval(0, 0, 50, 50), Interval(0, 0, 56, 150)], + [Interval(0, 56, 150, 150), Interval(0, 0, 12, 50)], + [Interval(0, 12, 50, 50), Interval(0, 0, 12, 12), Interval(0, 0, 50, 50), Interval(0, 0, 5, 27)], + [Interval(0, 5, 27, 27), Interval(0, 0, 50, 50), Interval(0, 0, 33, 33)], ] + # Test Case 3: Another uneven distribution + # Total items = 256. 256 / 4 workers = 64 items/worker. No remainder. 
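+    # (Chunk sizes below: 5 + 150 + 7 + 12 + 4 + 27 + 50 + 1 = 256.)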
chunk_intervals = [ - Interval(0, 0, 5, 5), - Interval(0, 0, 150, 150), - Interval(0, 0, 7, 7), - Interval(0, 0, 12, 12), - Interval(0, 0, 4, 4), - Interval(0, 0, 27, 27), - Interval(0, 0, 50, 50), - Interval(0, 0, 1, 1), + Interval(0, 0, 5, 5), Interval(0, 0, 150, 150), Interval(0, 0, 7, 7), Interval(0, 0, 12, 12), + Interval(0, 0, 4, 4), Interval(0, 0, 27, 27), Interval(0, 0, 50, 50), Interval(0, 0, 1, 1), ] - workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(4, 1, 2), - indexes, - chunk_intervals, - drop_last=True, + _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, num_workers=1 ) - assert workers_chunks == [[0, 1], [1], [1, 2, 3, 4, 5], [5, 6, 7]] - assert sum([interval[2] - interval[1] for interval in workers_intervals[0]]) == 64 - assert sum([interval[2] - interval[1] for interval in workers_intervals[1]]) == 64 - assert sum([interval[2] - interval[1] for interval in workers_intervals[2]]) == 64 - assert sum([interval[2] - interval[1] for interval in workers_intervals[3]]) == 64 + assert all(sum(i.roi_end_idx - i.roi_start_idx for i in w_intervals) == 64 for w_intervals in workers_intervals) assert workers_intervals == [ - [[0, 0, 5, 5], [0, 0, 59, 150]], - [[0, 59, 123, 150]], - [[0, 123, 150, 150], [0, 0, 7, 7], [0, 0, 12, 12], [0, 0, 4, 4], [0, 0, 14, 27]], - [[0, 14, 27, 27], [0, 0, 50, 50], [0, 0, 1, 1]], - ] - - chunk_intervals = [ - Interval(0, 0, 6, 6), - Interval(0, 0, 6, 6), - Interval(0, 0, 6, 6), - Interval(0, 0, 6, 6), + [Interval(0, 0, 5, 5), Interval(0, 0, 59, 150)], + [Interval(0, 59, 123, 150)], + [Interval(0, 123, 150, 150), Interval(0, 0, 7, 7), Interval(0, 0, 12, 12), Interval(0, 0, 4, 4), Interval(0, 0, 14, 27)], + [Interval(0, 14, 27, 27), Interval(0, 0, 50, 50), Interval(0, 0, 1, 1)], ] + # Test Case 4: Many workers, small data + # 24 items, batch_size 6 -> 4 batches. 4 batches / 8 global workers. + # First 4 workers get 1 batch (6 items) each. + chunk_intervals = [Interval(0, 0, 6, 6)] * 4 workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(1, 0, 1), range(0, 4), chunk_intervals, False, 8, 6 + _DistributedEnv(1, 0, 1), range(4), chunk_intervals, drop_last=False, num_workers=8, batch_size=6 ) - - assert workers_intervals == [[[0, 0, 6, 6]], [[0, 0, 6, 6]], [[0, 0, 6, 6]], [[0, 0, 6, 6]], [], [], [], []] assert workers_chunks == [[0], [1], [2], [3], [], [], [], []] + assert workers_intervals == [ + [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], + [], [], [], [] + ] + # Test Case 5: Multi-node + # 24 items, batch_size 6 -> 4 batches. 4 batches / 16 global workers. + # First 4 workers get 1 batch (6 items) each. workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(2, 0, 1), range(0, 4), chunk_intervals, False, 8, 6 + _DistributedEnv(2, 0, 1), range(4), chunk_intervals, drop_last=False, num_workers=8, batch_size=6 ) - - assert workers_chunks == [[0], [1], [], [], [], [], [], [], [2], [3], [], [], [], [], [], []] + assert workers_chunks == [[0], [1], [2], [3]] + [[]] * 12 assert workers_intervals == [ - [[0, 0, 6, 6]], - [[0, 0, 6, 6]], - [], - [], - [], - [], - [], - [], - [[0, 0, 6, 6]], - [[0, 0, 6, 6]], - [], - [], - [], - [], - [], - [], - ] + [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)] + ] + [[]] * 12 + # Test Case 6: Small workers, large batch + # 24 items, batch_size 8 -> 3 batches. 
3 batches / 2 global workers. + # Worker 0 gets 2 batches (16 items), worker 1 gets 1 batch (8 items). workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(1, 0, 1), range(0, 4), chunk_intervals, False, 2, 8 + _DistributedEnv(1, 0, 1), range(4), chunk_intervals, drop_last=False, num_workers=2, batch_size=8 ) assert workers_chunks == [[0, 1, 2], [2, 3]] - assert workers_intervals == [[[0, 0, 6, 6], [0, 0, 6, 6], [0, 0, 4, 6]], [[0, 4, 6, 6], [0, 0, 6, 6]]] - - chunk_intervals = [ - Interval(0, 0, 6, 6), - Interval(0, 0, 7, 7), - Interval(0, 0, 6, 6), - Interval(0, 0, 7, 8), + assert workers_intervals == [ + [Interval(0, 0, 6, 6), Interval(0, 0, 6, 6), Interval(0, 0, 4, 6)], + [Interval(0, 4, 6, 6), Interval(0, 0, 6, 6)], ] + # Test Case 7: Uneven chunks with remainder (drop_last=False) + # Total items 26, batch_size 6 -> 4 batches, 2 remainder items. 4 batches / 16 workers. + # First 4 workers get 1 batch (6 items). Last of these (worker 3) gets the 2 remainder items. + # Item counts: [6, 6, 6, 8, 0, ...]. + chunk_intervals = [Interval(0, 0, 6, 6), Interval(0, 0, 7, 7), Interval(0, 0, 6, 6), Interval(0, 0, 7, 8)] workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(2, 0, 1), range(0, 4), chunk_intervals, False, 8, 6 + _DistributedEnv(2, 0, 1), range(4), chunk_intervals, drop_last=False, num_workers=8, batch_size=6 ) - - assert sum([y[2] - y[1] for x in workers_intervals for y in x]) == 26 - assert workers_chunks == [[0], [1], [], [], [], [], [], [], [1, 2], [2, 3], [3], [], [], [], [], []] + assert sum(i.roi_end_idx - i.roi_start_idx for w in workers_intervals for i in w) == 26 + assert workers_chunks == [[0], [1], [1, 2], [2, 3]] + [[]] * 12 assert workers_intervals == [ - [[0, 0, 6, 6]], - [[0, 0, 6, 7]], - [], - [], - [], - [], - [], - [], - [[0, 6, 7, 7], [0, 0, 5, 6]], - [[0, 5, 6, 6], [0, 0, 5, 8]], - [[0, 5, 7, 8]], - [], - [], - [], - [], - [], + [Interval(0, 0, 6, 6)], + [Interval(0, 0, 6, 7)], + [Interval(0, 6, 7, 7), Interval(0, 0, 5, 6)], + [Interval(0, 5, 6, 6), Interval(0, 0, 7, 8)], + ] + [[]] * 12 + + # Test Case 8: Uneven chunks with remainder (drop_last=True) + # Total items 26, batch_size 6 -> 4 batches. Remainder is dropped. + # First 4 workers get 1 batch (6 items) each. Item counts: [6, 6, 6, 6, 0, ...]. + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( + _DistributedEnv(2, 0, 1), range(4), chunk_intervals, drop_last=True, num_workers=8, batch_size=6 + ) + assert sum(i.roi_end_idx - i.roi_start_idx for w in workers_intervals for i in w) == 24 + assert workers_chunks == [[0], [1], [1, 2], [2, 3]] + [[]] * 12 + assert workers_intervals == [ + [Interval(0, 0, 6, 6)], + [Interval(0, 0, 6, 7)], + [Interval(0, 6, 7, 7), Interval(0, 0, 5, 6)], + [Interval(0, 5, 6, 6), Interval(0, 0, 5, 8)], + ] + [[]] * 12 + + # Test Case 9: NEW - multiple_of_batch_size_only=True + # Chunks with sizes [50, 23, 40, 10, 7, 100]. batch_size=10. + # Keep chunks with sizes [50, 40, 10, 100]. Total items = 200. + # 200 items / 2 workers = 100 items/worker. 
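+    # (Chunks of size 23 and 7 are dropped by the filter: they are not exact multiples of batch_size=10.)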
+ indexes = list(range(6)) + chunk_intervals = [ + Interval(0, 0, 50, 50), Interval(0, 0, 23, 23), Interval(0, 0, 40, 40), + Interval(0, 0, 10, 10), Interval(0, 0, 7, 7), Interval(0, 0, 100, 100), ] - workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(2, 0, 1), range(0, 4), chunk_intervals, True, 8, 6 + _DistributedEnv(2, 0, 1), indexes, chunk_intervals, num_workers=1, batch_size=10, + multiple_of_batch_size_only=True ) - - assert sum([y[2] - y[1] for x in workers_intervals for y in x]) == 24 - assert workers_chunks == [[0], [1], [], [], [], [], [], [], [1, 2], [2, 3], [], [], [], [], [], []] + # Worker 0 gets 50+40+10 = 100 items. Worker 1 gets 100 items. + assert workers_chunks == [[0, 2, 3], [5]] + assert all(sum(i.roi_end_idx - i.roi_start_idx for i in w_intervals) == 100 for w_intervals in workers_intervals) assert workers_intervals == [ - [[0, 0, 6, 6]], - [[0, 0, 6, 7]], - [], - [], - [], - [], - [], - [], - [[0, 6, 7, 7], [0, 0, 5, 6]], - [[0, 5, 6, 6], [0, 0, 5, 8]], - [], - [], - [], - [], - [], - [], + [Interval(0, 0, 50, 50), Interval(0, 0, 40, 40), Interval(0, 0, 10, 10)], + [Interval(0, 0, 100, 100)], ] @@ -403,4 +359,4 @@ def test_aggregate_shared_chunks_per_rank(): def test_map_node_worker_rank_to_chunk_indexes_to_not_delete(): chunks_to_workers = {10: [2, 3, 4], 20: [1, 2, 3], 30: [3, 4], 40: [4, 5, 6]} workers_to_chunks = _map_node_worker_rank_to_chunk_indexes_to_not_delete(chunks_to_workers) - assert workers_to_chunks == {1: [20], 2: [10, 20], 3: [10, 20, 30], 4: [10, 30, 40], 5: [40], 6: [40]} + assert workers_to_chunks == {1: [20], 2: [10, 20], 3: [10, 20, 30], 4: [10, 30, 40], 5: [40], 6: [40]} \ No newline at end of file From ed0784dbd3499a74fc07c348360f9c470bb53b37 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 13:05:35 +0100 Subject: [PATCH 06/15] update --- pyproject.toml | 28 +++++++++++++++------------- src/litdata/utilities/shuffle.py | 7 ++----- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 14d9790d3..2b589896d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,20 +65,22 @@ lint.per-file-ignores."examples/**" = [ ] lint.per-file-ignores."setup.py" = [ "D100", "SIM115" ] lint.per-file-ignores."src/**" = [ - "D100", # Missing docstring in public module - "D101", # todo: Missing docstring in public class - "D102", # todo: Missing docstring in public method - "D103", # todo: Missing docstring in public function - "D104", # Missing docstring in public package - "D105", # todo: Missing docstring in magic method - "D107", # todo: Missing docstring in __init__ - "D205", # todo: 1 blank line required between summary line and description + "D100", # Missing docstring in public module + "D101", # todo: Missing docstring in public class + "D102", # todo: Missing docstring in public method + "D103", # todo: Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # todo: Missing docstring in magic method + "D107", # todo: Missing docstring in __init__ + "D205", # todo: 1 blank line required between summary line and description "D401", - "D404", # todo: First line should be in imperative mood; try rephrasing - "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. 
- "S602", # todo: `subprocess` call with `shell=True` identified, security issue - "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell` - "S607", # todo: Starting a process with a partial executable path + "D404", # todo: First line should be in imperative mood; try rephrasing + "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. + "S602", # todo: `subprocess` call with `shell=True` identified, security issue + "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell` + "S607", # todo: Starting a process with a partial executable path + "UP006", # UP006 Use `list` instead of `List` for type annotation + "UP035", # UP035 `typing.Tuple` is deprecated, use `tuple` instead ] lint.per-file-ignores."tests/**" = [ "D100", diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 16f1c0d98..ebd8e3f92 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -72,7 +72,6 @@ def _associate_chunks_and_intervals_to_workers( multiple_of_batch_size_only: bool = False, ) -> Tuple[List[List[int]], List[List[Interval]]]: """Associates chunks and their intervals to workers in a distributed environment.""" - if multiple_of_batch_size_only: filtered_indexes = [] filtered_chunk_intervals = [] @@ -128,11 +127,9 @@ def _associate_chunks_and_intervals_to_workers( chunks_per_workers[worker_idx].append(chunk_index) chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = current_chunk_interval split_point = chunk_roi_start + items_needed_by_worker - + # FIX: Ensure we always append an Interval object - intervals_per_workers[worker_idx].append( - Interval(chunk_start, chunk_roi_start, split_point, chunk_end) - ) + intervals_per_workers[worker_idx].append(Interval(chunk_start, chunk_roi_start, split_point, chunk_end)) # FIX: Ensure the updated interval is also an Interval object current_chunk_interval = Interval(chunk_start, split_point, chunk_roi_end, chunk_end) num_items_per_workers[worker_idx] = 0 From 3c53d8ec2676f3f3b084a0abbb88f934bc537477 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Oct 2025 12:06:44 +0000 Subject: [PATCH 07/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_shuffle.py | 95 +++++++++++++++++++++++---------- 1 file changed, 68 insertions(+), 27 deletions(-) diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 555335d40..efe97334b 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -115,8 +115,14 @@ def test_associate_chunks_and_intervals_to_workers(): # Total items = 422. With drop_last=True, 422 batches (size 1) / 4 workers = 105 batches/worker, remainder 2. # So, workers get item counts of [106, 106, 105, 105]. 
chunk_intervals = [ - Interval(0, 0, 50, 50), Interval(0, 0, 150, 150), Interval(0, 0, 50, 50), Interval(0, 0, 12, 12), - Interval(0, 0, 50, 50), Interval(0, 0, 27, 27), Interval(0, 0, 50, 50), Interval(0, 0, 33, 33), + Interval(0, 0, 50, 50), + Interval(0, 0, 150, 150), + Interval(0, 0, 50, 50), + Interval(0, 0, 12, 12), + Interval(0, 0, 50, 50), + Interval(0, 0, 27, 27), + Interval(0, 0, 50, 50), + Interval(0, 0, 33, 33), ] workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, num_workers=1 @@ -136,8 +142,14 @@ def test_associate_chunks_and_intervals_to_workers(): # Test Case 3: Another uneven distribution # Total items = 256. 256 / 4 workers = 64 items/worker. No remainder. chunk_intervals = [ - Interval(0, 0, 5, 5), Interval(0, 0, 150, 150), Interval(0, 0, 7, 7), Interval(0, 0, 12, 12), - Interval(0, 0, 4, 4), Interval(0, 0, 27, 27), Interval(0, 0, 50, 50), Interval(0, 0, 1, 1), + Interval(0, 0, 5, 5), + Interval(0, 0, 150, 150), + Interval(0, 0, 7, 7), + Interval(0, 0, 12, 12), + Interval(0, 0, 4, 4), + Interval(0, 0, 27, 27), + Interval(0, 0, 50, 50), + Interval(0, 0, 1, 1), ] workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, num_workers=1 @@ -147,7 +159,13 @@ def test_associate_chunks_and_intervals_to_workers(): assert workers_intervals == [ [Interval(0, 0, 5, 5), Interval(0, 0, 59, 150)], [Interval(0, 59, 123, 150)], - [Interval(0, 123, 150, 150), Interval(0, 0, 7, 7), Interval(0, 0, 12, 12), Interval(0, 0, 4, 4), Interval(0, 0, 14, 27)], + [ + Interval(0, 123, 150, 150), + Interval(0, 0, 7, 7), + Interval(0, 0, 12, 12), + Interval(0, 0, 4, 4), + Interval(0, 0, 14, 27), + ], [Interval(0, 14, 27, 27), Interval(0, 0, 50, 50), Interval(0, 0, 1, 1)], ] @@ -160,8 +178,14 @@ def test_associate_chunks_and_intervals_to_workers(): ) assert workers_chunks == [[0], [1], [2], [3], [], [], [], []] assert workers_intervals == [ - [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], - [], [], [], [] + [Interval(0, 0, 6, 6)], + [Interval(0, 0, 6, 6)], + [Interval(0, 0, 6, 6)], + [Interval(0, 0, 6, 6)], + [], + [], + [], + [], ] # Test Case 5: Multi-node @@ -171,9 +195,10 @@ def test_associate_chunks_and_intervals_to_workers(): _DistributedEnv(2, 0, 1), range(4), chunk_intervals, drop_last=False, num_workers=8, batch_size=6 ) assert workers_chunks == [[0], [1], [2], [3]] + [[]] * 12 - assert workers_intervals == [ - [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)] - ] + [[]] * 12 + assert ( + workers_intervals + == [[Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)]] + [[]] * 12 + ) # Test Case 6: Small workers, large batch # 24 items, batch_size 8 -> 3 batches. 3 batches / 2 global workers. 
@@ -197,12 +222,16 @@ def test_associate_chunks_and_intervals_to_workers(): ) assert sum(i.roi_end_idx - i.roi_start_idx for w in workers_intervals for i in w) == 26 assert workers_chunks == [[0], [1], [1, 2], [2, 3]] + [[]] * 12 - assert workers_intervals == [ - [Interval(0, 0, 6, 6)], - [Interval(0, 0, 6, 7)], - [Interval(0, 6, 7, 7), Interval(0, 0, 5, 6)], - [Interval(0, 5, 6, 6), Interval(0, 0, 7, 8)], - ] + [[]] * 12 + assert ( + workers_intervals + == [ + [Interval(0, 0, 6, 6)], + [Interval(0, 0, 6, 7)], + [Interval(0, 6, 7, 7), Interval(0, 0, 5, 6)], + [Interval(0, 5, 6, 6), Interval(0, 0, 7, 8)], + ] + + [[]] * 12 + ) # Test Case 8: Uneven chunks with remainder (drop_last=True) # Total items 26, batch_size 6 -> 4 batches. Remainder is dropped. @@ -212,12 +241,16 @@ def test_associate_chunks_and_intervals_to_workers(): ) assert sum(i.roi_end_idx - i.roi_start_idx for w in workers_intervals for i in w) == 24 assert workers_chunks == [[0], [1], [1, 2], [2, 3]] + [[]] * 12 - assert workers_intervals == [ - [Interval(0, 0, 6, 6)], - [Interval(0, 0, 6, 7)], - [Interval(0, 6, 7, 7), Interval(0, 0, 5, 6)], - [Interval(0, 5, 6, 6), Interval(0, 0, 5, 8)], - ] + [[]] * 12 + assert ( + workers_intervals + == [ + [Interval(0, 0, 6, 6)], + [Interval(0, 0, 6, 7)], + [Interval(0, 6, 7, 7), Interval(0, 0, 5, 6)], + [Interval(0, 5, 6, 6), Interval(0, 0, 5, 8)], + ] + + [[]] * 12 + ) # Test Case 9: NEW - multiple_of_batch_size_only=True # Chunks with sizes [50, 23, 40, 10, 7, 100]. batch_size=10. @@ -225,12 +258,20 @@ def test_associate_chunks_and_intervals_to_workers(): # 200 items / 2 workers = 100 items/worker. indexes = list(range(6)) chunk_intervals = [ - Interval(0, 0, 50, 50), Interval(0, 0, 23, 23), Interval(0, 0, 40, 40), - Interval(0, 0, 10, 10), Interval(0, 0, 7, 7), Interval(0, 0, 100, 100), + Interval(0, 0, 50, 50), + Interval(0, 0, 23, 23), + Interval(0, 0, 40, 40), + Interval(0, 0, 10, 10), + Interval(0, 0, 7, 7), + Interval(0, 0, 100, 100), ] workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(2, 0, 1), indexes, chunk_intervals, num_workers=1, batch_size=10, - multiple_of_batch_size_only=True + _DistributedEnv(2, 0, 1), + indexes, + chunk_intervals, + num_workers=1, + batch_size=10, + multiple_of_batch_size_only=True, ) # Worker 0 gets 50+40+10 = 100 items. Worker 1 gets 100 items. 
assert workers_chunks == [[0, 2, 3], [5]] @@ -359,4 +400,4 @@ def test_aggregate_shared_chunks_per_rank(): def test_map_node_worker_rank_to_chunk_indexes_to_not_delete(): chunks_to_workers = {10: [2, 3, 4], 20: [1, 2, 3], 30: [3, 4], 40: [4, 5, 6]} workers_to_chunks = _map_node_worker_rank_to_chunk_indexes_to_not_delete(chunks_to_workers) - assert workers_to_chunks == {1: [20], 2: [10, 20], 3: [10, 20, 30], 4: [10, 30, 40], 5: [40], 6: [40]} \ No newline at end of file + assert workers_to_chunks == {1: [20], 2: [10, 20], 3: [10, 20, 30], 4: [10, 30, 40], 5: [40], 6: [40]} From 25a7b0b2e4d3e011f3e6aea8ca808f018824c66a Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 13:40:43 +0100 Subject: [PATCH 08/15] update --- src/litdata/streaming/reader.py | 11 +- src/litdata/utilities/shuffle.py | 133 ++++++++------- tests/streaming/test_downloader.py | 199 +---------------------- tests/utilities/test_shuffle.py | 249 +++++++++++++++-------------- 4 files changed, 195 insertions(+), 397 deletions(-) diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index e7acc755d..f5f258478 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -82,8 +82,9 @@ def __init__( # Check whether a dataset slice fits on the node num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False - if self._worker_env.rank == 0: - print(f"Delete chunks when processed: {self._delete_chunks_when_processed}") + + if distributed_env.global_rank == 0 and self._worker_env.rank == 0: + print(f"Delete chunks when used: {self._delete_chunks_when_processed}") self._has_exited = False @@ -214,7 +215,8 @@ def _force_download(self) -> None: if _DEBUG: chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] print( - f"[Reader] Requested force download for {chunk_filepath} by {self._rank} at {datetime.now().isoformat()}" + f"[Reader] Requested force download for {chunk_filepath} " + f"by {self._rank} at {datetime.now().isoformat()}" ) self._config.download_chunk_from_index(chunk_index, skip_lock=True) @@ -445,15 +447,12 @@ def read(self, index: ChunkedIndex) -> Any: if index.chunk_index != self._last_chunk_index: if self._last_chunk_index is not None: # 2. Log the "End" event for the previous chunk. - print(f"read_chunk_{self._last_chunk_index}_size_{self._last_chunk_size}") logger.debug( _get_log_msg( {"name": f"read_chunk_{self._last_chunk_index}_size_{self._last_chunk_size}", "ph": "E"} ) ) - print(f"read_chunk_{index.chunk_index}_size_{index.chunk_size}") - # 2. Log the "Begin" event for the NEW chunk. logger.debug(_get_log_msg({"name": f"read_chunk_{index.chunk_index}_size_{index.chunk_size}", "ph": "B"})) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index ebd8e3f92..7fa23ee38 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -12,12 +12,12 @@ # limitations under the License. 
import copy -from typing import Any, List, Tuple +from typing import Any import numpy as np from litdata.streaming.item_loader import Interval -from litdata.utilities.env import _DistributedEnv, _WorkerEnv +from litdata.utilities.env import _DistributedEnv def _intra_node_chunk_shuffle( @@ -64,90 +64,83 @@ def _group_chunks_by_nodes( def _associate_chunks_and_intervals_to_workers( distributed_env: _DistributedEnv, - indexes: List[int], - chunk_intervals: List[Interval], + indexes: Any, + chunk_intervals: list[Interval], drop_last: bool = False, num_workers: int = 1, batch_size: int = 1, - multiple_of_batch_size_only: bool = False, -) -> Tuple[List[List[int]], List[List[Interval]]]: - """Associates chunks and their intervals to workers in a distributed environment.""" - if multiple_of_batch_size_only: - filtered_indexes = [] - filtered_chunk_intervals = [] - for index, interval in zip(indexes, chunk_intervals): - num_items_in_chunk = interval[2] - interval[1] - if num_items_in_chunk > 0 and num_items_in_chunk % batch_size == 0: - filtered_indexes.append(index) - filtered_chunk_intervals.append(interval) - indexes = filtered_indexes - chunk_intervals = filtered_chunk_intervals - +) -> tuple[list[list[int]], list[Any]]: num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) - if batch_size == 0: - raise ValueError("batch_size cannot be zero.") max_batches = num_items // batch_size global_num_workers = distributed_env.world_size * num_workers - if global_num_workers == 0: - raise ValueError("Cannot associate chunks with zero workers.") - - num_batches_per_worker = [0] * global_num_workers - if max_batches > 0: - base_batches_per_worker = max_batches // global_num_workers - rem_batches = max_batches % global_num_workers - for i in range(global_num_workers): - num_batches_per_worker[i] = base_batches_per_worker + (1 if i < rem_batches else 0) - - num_items_per_workers = [n * batch_size for n in num_batches_per_worker] - rem_items = num_items % batch_size - if not drop_last and rem_items > 0: - target_worker = -1 - for i in range(global_num_workers - 1, -1, -1): - if num_items_per_workers[i] > 0: - target_worker = i - break - num_items_per_workers[target_worker] += rem_items - chunks_per_workers: List[List[int]] = [[] for _ in range(global_num_workers)] - intervals_per_workers: List[List[Interval]] = [[] for _ in range(global_num_workers)] - worker_idx = 0 + num_items_per_workers: Any = [] + + for rank in range(distributed_env.world_size): + tmp_arr = [0 for _ in range(num_workers)] + + num_batches_per_rank = int(max_batches // distributed_env.world_size) + base_batches = num_batches_per_rank // num_workers + rem_batches = num_batches_per_rank % num_workers + tmp_arr = [base_batches + (1 if i < rem_batches else 0) for i in range(num_workers)] + + if rank == distributed_env.world_size - 1: + # Find how batches were associated + num_assigned_items = batch_size * (sum(num_items_per_workers) + sum(tmp_arr)) + + # Multiply with the batch_size to get the number of items + if batch_size > 1: + tmp_arr = [x * batch_size for x in tmp_arr] + num_items_per_workers = [x * batch_size for x in num_items_per_workers] + + # If there are items left to assign, let's give it the last worker + left_items = num_items - num_assigned_items + if not drop_last and left_items > 0: + tmp_arr[rem_batches % num_workers] += left_items + + num_items_per_workers.extend(tmp_arr) + else: + num_items_per_workers.extend(tmp_arr) + + chunks_per_workers: list[list[int]] = [[] for _ in 
range(global_num_workers)] + intervals_per_workers: list[list[list[int]]] = [[] for _ in range(global_num_workers)] + + # 4. Assign the chunk & intervals to each rank for chunk_index, chunk_interval in zip(indexes, chunk_intervals): - current_chunk_interval = chunk_interval + rank = 0 + while True: - if worker_idx >= global_num_workers: + if rank == len(num_items_per_workers): break - items_needed_by_worker = num_items_per_workers[worker_idx] - if items_needed_by_worker == 0: - worker_idx += 1 + + items_left_to_assign = num_items_per_workers[rank] + + if items_left_to_assign == 0: + rank += 1 continue - items_in_chunk = current_chunk_interval[2] - current_chunk_interval[1] + + items_in_chunk = chunk_interval[2] - chunk_interval[1] + if items_in_chunk == 0: break - if items_in_chunk > items_needed_by_worker: - chunks_per_workers[worker_idx].append(chunk_index) - chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = current_chunk_interval - split_point = chunk_roi_start + items_needed_by_worker - - # FIX: Ensure we always append an Interval object - intervals_per_workers[worker_idx].append(Interval(chunk_start, chunk_roi_start, split_point, chunk_end)) - # FIX: Ensure the updated interval is also an Interval object - current_chunk_interval = Interval(chunk_start, split_point, chunk_roi_end, chunk_end) - num_items_per_workers[worker_idx] = 0 - worker_idx += 1 + + if items_in_chunk > items_left_to_assign: + chunks_per_workers[rank].append(chunk_index) + + chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = chunk_interval + + intervals_per_workers[rank].append( + [chunk_start, chunk_roi_start, chunk_roi_start + items_left_to_assign, chunk_end] + ) + chunk_interval = Interval(chunk_start, chunk_roi_start + items_left_to_assign, chunk_roi_end, chunk_end) + num_items_per_workers[rank] = 0 + rank += 1 else: - chunks_per_workers[worker_idx].append(chunk_index) - # FIX: Ensure we always append an Interval object (not a list) - intervals_per_workers[worker_idx].append(current_chunk_interval) - num_items_per_workers[worker_idx] -= items_in_chunk + chunks_per_workers[rank].append(chunk_index) + intervals_per_workers[rank].append(list(chunk_interval)) + num_items_per_workers[rank] -= items_in_chunk break - worker_env = _WorkerEnv.detect() - if worker_env.rank == 0: - print("HERE num_items_per_workers", num_items_per_workers) - for idx, internal in enumerate(intervals_per_workers): - total_items = sum(x[2] - x[1] for x in internal) - print(f"Worker {idx}: Batch Size={batch_size}, Items={total_items}, Chunks={internal}") - return chunks_per_workers, intervals_per_workers diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py index 9f5b2cea5..4461a728e 100644 --- a/tests/streaming/test_downloader.py +++ b/tests/streaming/test_downloader.py @@ -1,10 +1,8 @@ -# ruff: noqa: S604 import contextlib import io import os -import sys from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -20,7 +18,6 @@ get_downloader, register_downloader, shutil, - subprocess, unregister_downloader, ) @@ -54,200 +51,6 @@ def test_get_downloader(tmpdir): unregister_downloader("dummy://") -def test_s3_downloader_fast(tmpdir, monkeypatch): - monkeypatch.setattr(os, "system", MagicMock(return_value=0)) - popen_mock = MagicMock() - popen_mock.wait.return_value = 0 # Simulate a successful download - monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock)) - downloader = S3Downloader(tmpdir, tmpdir, []) - 
downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt")) - popen_mock.wait.assert_called() - - -@patch("os.system") -@patch("subprocess.Popen") -def test_s3_downloader_with_s5cmd_no_storage_options(popen_mock, system_mock, tmpdir): - system_mock.return_value = 0 # Simulates s5cmd being available - process_mock = MagicMock() - process_mock.wait.return_value = 0 # Simulate a successful download - popen_mock.return_value = process_mock - - # Initialize the S3Downloader without storage options - downloader = S3Downloader("s3://random_bucket", str(tmpdir), []) - - # Action: Call the download_file method - remote_filepath = "s3://random_bucket/sample_file.txt" - local_filepath = os.path.join(tmpdir, "sample_file.txt") - downloader.download_file(remote_filepath, local_filepath) - - # Assertion: Verify subprocess.Popen was called with correct arguments and no env variables - popen_mock.assert_called_once_with( - f"s5cmd cp {remote_filepath} {local_filepath}", - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=None, - ) - process_mock.wait.assert_called_once() - - -@patch("os.system") -@patch("subprocess.Popen") -@mock.patch("litdata.streaming.downloader._DISABLE_S5CMD", True) -@mock.patch("boto3.client") -def test_s3_downloader_s5cmd_available_but_disabled(boto3_client_mock, popen_mock, system_mock, tmpdir): - system_mock.return_value = 0 # Simulates s5cmd being available - process_mock = MagicMock() - popen_mock.return_value = process_mock - - # Mock the boto3 client - boto3_client_instance = MagicMock() - boto3_client_mock.return_value = boto3_client_instance - - # Mock the download_file method to avoid NoCredentialsError - boto3_client_instance.download_file = MagicMock() - - # Mock the S3Client class to avoid creating a real boto3 client - with mock.patch("litdata.streaming.downloader.S3Client") as S3ClientMock: - S3ClientMock.return_value.client = boto3_client_instance - - # Initialize the S3Downloader - downloader = S3Downloader("s3://random_bucket", str(tmpdir), []) - - # Action: Call the download_file method - remote_filepath = "s3://random_bucket/sample_file.txt" - local_filepath = os.path.join(tmpdir, "sample_file.txt") - downloader.download_file(remote_filepath, local_filepath) - - # Assertion: Verify subprocess.Popen was not called - popen_mock.assert_not_called() - - # Assertion: Verify boto3 download_file was called - boto3_client_instance.download_file.assert_called_once() - - -@patch("os.system") -@patch("subprocess.Popen") -def test_s3_downloader_with_s5cmd_with_storage_options(popen_mock, system_mock, tmpdir): - system_mock.return_value = 0 # Simulates s5cmd being available - process_mock = MagicMock() - process_mock.wait.return_value = 0 # Simulate a successful download - popen_mock.return_value = process_mock - - storage_options = {"AWS_ACCESS_KEY_ID": "dummy_key", "AWS_SECRET_ACCESS_KEY": "dummy_secret"} - - # Initialize the S3Downloader with storage options - downloader = S3Downloader("s3://random_bucket", str(tmpdir), [], storage_options) - - # Action: Call the download_file method - remote_filepath = "s3://random_bucket/sample_file.txt" - local_filepath = os.path.join(tmpdir, "sample_file.txt") - downloader.download_file(remote_filepath, local_filepath) - - # Create expected environment variables by merging the current env with storage_options - expected_env = os.environ.copy() - expected_env.update(storage_options) - - # Assertion: Verify subprocess.Popen was called with the correct arguments and environment variables - 
popen_mock.assert_called_once_with( - f"s5cmd cp {remote_filepath} {local_filepath}", - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=expected_env, - ) - process_mock.wait.assert_called_once() - - -@patch("os.system") -@patch("subprocess.Popen") -def test_s3_downloader_with_s5cmd_with_storage_options_unsigned(popen_mock, system_mock, tmpdir): - system_mock.return_value = 0 # Simulates s5cmd being available - process_mock = MagicMock() - process_mock.wait.return_value = 0 # Simulate a successful download - popen_mock.return_value = process_mock - - storage_options = {"AWS_NO_SIGN_REQUEST": "Yes"} - - # Initialize the S3Downloader with storage options - downloader = S3Downloader("s3://random_bucket", str(tmpdir), [], storage_options) - - # Action: Call the download_file method - remote_filepath = "s3://random_bucket/sample_file.txt" - local_filepath = os.path.join(tmpdir, "sample_file.txt") - downloader.download_file(remote_filepath, local_filepath) - - # Create expected environment variables by merging the current env with storage_options - expected_env = os.environ.copy() - expected_env.update(storage_options) - - # Assertion: Verify subprocess.Popen was called with the correct arguments and environment variables - popen_mock.assert_called_once_with( - f"s5cmd --no-sign-request cp {remote_filepath} {local_filepath}", - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=expected_env, - ) - process_mock.wait.assert_called_once() - - -@pytest.mark.skipif( - shutil.which("s5cmd") is None or sys.platform == "win32", reason="s5cmd is not available or running on Windows" -) -def test_s3_downloader_with_s5cmd_with_storage_options_unsigned_pl(tmpdir): - # Set up the test environment - remote_filepath = "s3://pl-flash-data/optimized_tiny_imagenet/index.json" - local_filepath = os.path.join(tmpdir, "index.json") - - storage_options = {"AWS_NO_SIGN_REQUEST": "Yes"} - # Initialize the S3Downloader - downloader = S3Downloader("s3://pl-flash-data", str(tmpdir), [], storage_options) - - # Download the file - downloader.download_file(remote_filepath, local_filepath) - - # Verify the download - assert os.path.exists(local_filepath), "The index.json file was not downloaded successfully." - - # verify the contents of the file - with open(local_filepath) as f: - content = f.read() - assert content.startswith("{"), "The downloaded file does not appear to be a valid JSON file." 
- - -@patch("os.system") -@patch("subprocess.Popen") -def test_s3_downloader_s5cmd_error_handling(popen_mock, system_mock, tmpdir): - system_mock.return_value = 0 # Simulates s5cmd being available - process_mock = MagicMock() - process_mock.wait.return_value = 1 # Simulate a non-zero return code - process_mock.stderr.read.return_value = b"Simulated error message" - popen_mock.return_value = process_mock - - # Initialize the S3Downloader without storage options - downloader = S3Downloader("s3://random_bucket", str(tmpdir), []) - - # Action: Call the download_file method and expect a RuntimeError - remote_filepath = "s3://random_bucket/sample_file.txt" - local_filepath = os.path.join(tmpdir, "sample_file.txt") - - with pytest.raises(RuntimeError, match="Failed to execute command"): - downloader.download_file(remote_filepath, local_filepath) - - # Assertion: Verify subprocess.Popen was called with the correct arguments - popen_mock.assert_called_once_with( - f"s5cmd cp {remote_filepath} {local_filepath}", - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=None, - ) - - # Assertion: Verify the error message includes the simulated stderr output - process_mock.stderr.read.assert_called_once() - - @mock.patch("litdata.streaming.downloader.R2Client") def test_r2_downloader_fast(r2_client_mock, tmpdir): # Mock the R2Client diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index efe97334b..b54cb1b44 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -95,25 +95,33 @@ def test_group_chunks_by_nodes(): def test_associate_chunks_and_intervals_to_workers(): - indexes = list(range(8)) + indexes = [0, 1, 2, 3, 4, 5, 6, 7] + chunk_intervals = [ + Interval(0, 0, 50, 50), + Interval(0, 0, 50, 50), + Interval(0, 0, 50, 50), + Interval(0, 0, 50, 50), + Interval(0, 0, 50, 50), + Interval(0, 0, 50, 50), + Interval(0, 0, 50, 50), + Interval(0, 0, 50, 50), + ] - # Test Case 1: Even distribution - # 400 items / 4 global workers = 100 items/worker. - chunk_intervals = [Interval(0, 0, 50, 50)] * 8 workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, num_workers=1 + _DistributedEnv(4, 1, 2), + indexes, + chunk_intervals, + drop_last=True, ) + assert workers_chunks == [[0, 1], [2, 3], [4, 5], [6, 7]] assert workers_intervals == [ - [Interval(0, 0, 50, 50), Interval(0, 0, 50, 50)], - [Interval(0, 0, 50, 50), Interval(0, 0, 50, 50)], - [Interval(0, 0, 50, 50), Interval(0, 0, 50, 50)], - [Interval(0, 0, 50, 50), Interval(0, 0, 50, 50)], + [[0, 0, 50, 50], [0, 0, 50, 50]], + [[0, 0, 50, 50], [0, 0, 50, 50]], + [[0, 0, 50, 50], [0, 0, 50, 50]], + [[0, 0, 50, 50], [0, 0, 50, 50]], ] - # Test Case 2: Uneven distribution - # Total items = 422. With drop_last=True, 422 batches (size 1) / 4 workers = 105 batches/worker, remainder 2. - # So, workers get item counts of [106, 106, 105, 105]. 
chunk_intervals = [ Interval(0, 0, 50, 50), Interval(0, 0, 150, 150), @@ -124,23 +132,27 @@ def test_associate_chunks_and_intervals_to_workers(): Interval(0, 0, 50, 50), Interval(0, 0, 33, 33), ] + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, num_workers=1 + _DistributedEnv(4, 1, 2), + indexes, + chunk_intervals, + drop_last=True, ) + assert workers_chunks == [[0, 1], [1, 2], [2, 3, 4, 5], [5, 6, 7]] - assert sum(i.roi_end_idx - i.roi_start_idx for i in workers_intervals[0]) == 106 - assert sum(i.roi_end_idx - i.roi_start_idx for i in workers_intervals[1]) == 106 - assert sum(i.roi_end_idx - i.roi_start_idx for i in workers_intervals[2]) == 105 - assert sum(i.roi_end_idx - i.roi_start_idx for i in workers_intervals[3]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[0]]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[1]]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[2]]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[3]]) == 105 + assert workers_intervals == [ - [Interval(0, 0, 50, 50), Interval(0, 0, 56, 150)], - [Interval(0, 56, 150, 150), Interval(0, 0, 12, 50)], - [Interval(0, 12, 50, 50), Interval(0, 0, 12, 12), Interval(0, 0, 50, 50), Interval(0, 0, 5, 27)], - [Interval(0, 5, 27, 27), Interval(0, 0, 50, 50), Interval(0, 0, 33, 33)], + [[0, 0, 50, 50], [0, 0, 55, 150]], + [[0, 55, 150, 150], [0, 0, 10, 50]], + [[0, 10, 50, 50], [0, 0, 12, 12], [0, 0, 50, 50], [0, 0, 3, 27]], + [[0, 3, 27, 27], [0, 0, 50, 50], [0, 0, 31, 33]], ] - # Test Case 3: Another uneven distribution - # Total items = 256. 256 / 4 workers = 64 items/worker. No remainder. chunk_intervals = [ Interval(0, 0, 5, 5), Interval(0, 0, 150, 150), @@ -151,134 +163,125 @@ def test_associate_chunks_and_intervals_to_workers(): Interval(0, 0, 50, 50), Interval(0, 0, 1, 1), ] + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, num_workers=1 + _DistributedEnv(4, 1, 2), + indexes, + chunk_intervals, + drop_last=True, ) + assert workers_chunks == [[0, 1], [1], [1, 2, 3, 4, 5], [5, 6, 7]] - assert all(sum(i.roi_end_idx - i.roi_start_idx for i in w_intervals) == 64 for w_intervals in workers_intervals) + assert sum([interval[2] - interval[1] for interval in workers_intervals[0]]) == 64 + assert sum([interval[2] - interval[1] for interval in workers_intervals[1]]) == 64 + assert sum([interval[2] - interval[1] for interval in workers_intervals[2]]) == 64 + assert sum([interval[2] - interval[1] for interval in workers_intervals[3]]) == 64 assert workers_intervals == [ - [Interval(0, 0, 5, 5), Interval(0, 0, 59, 150)], - [Interval(0, 59, 123, 150)], - [ - Interval(0, 123, 150, 150), - Interval(0, 0, 7, 7), - Interval(0, 0, 12, 12), - Interval(0, 0, 4, 4), - Interval(0, 0, 14, 27), - ], - [Interval(0, 14, 27, 27), Interval(0, 0, 50, 50), Interval(0, 0, 1, 1)], + [[0, 0, 5, 5], [0, 0, 59, 150]], + [[0, 59, 123, 150]], + [[0, 123, 150, 150], [0, 0, 7, 7], [0, 0, 12, 12], [0, 0, 4, 4], [0, 0, 14, 27]], + [[0, 14, 27, 27], [0, 0, 50, 50], [0, 0, 1, 1]], + ] + + chunk_intervals = [ + Interval(0, 0, 6, 6), + Interval(0, 0, 6, 6), + Interval(0, 0, 6, 6), + Interval(0, 0, 6, 6), ] - # Test Case 4: Many workers, small data - # 24 items, batch_size 6 -> 4 batches. 4 batches / 8 global workers. 
- # First 4 workers get 1 batch (6 items) each. - chunk_intervals = [Interval(0, 0, 6, 6)] * 4 workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(1, 0, 1), range(4), chunk_intervals, drop_last=False, num_workers=8, batch_size=6 + _DistributedEnv(1, 0, 1), range(0, 4), chunk_intervals, False, 8, 6 ) + + assert workers_intervals == [[[0, 0, 6, 6]], [[0, 0, 6, 6]], [[0, 0, 6, 6]], [[0, 0, 6, 6]], [], [], [], []] assert workers_chunks == [[0], [1], [2], [3], [], [], [], []] + + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( + _DistributedEnv(2, 0, 1), range(0, 4), chunk_intervals, False, 8, 6 + ) + + assert workers_chunks == [[0], [1], [], [], [], [], [], [], [2], [3], [], [], [], [], [], []] assert workers_intervals == [ - [Interval(0, 0, 6, 6)], - [Interval(0, 0, 6, 6)], - [Interval(0, 0, 6, 6)], - [Interval(0, 0, 6, 6)], + [[0, 0, 6, 6]], + [[0, 0, 6, 6]], + [], + [], + [], + [], + [], + [], + [[0, 0, 6, 6]], + [[0, 0, 6, 6]], + [], + [], [], [], [], [], ] - # Test Case 5: Multi-node - # 24 items, batch_size 6 -> 4 batches. 4 batches / 16 global workers. - # First 4 workers get 1 batch (6 items) each. - workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(2, 0, 1), range(4), chunk_intervals, drop_last=False, num_workers=8, batch_size=6 - ) - assert workers_chunks == [[0], [1], [2], [3]] + [[]] * 12 - assert ( - workers_intervals - == [[Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)], [Interval(0, 0, 6, 6)]] + [[]] * 12 - ) - - # Test Case 6: Small workers, large batch - # 24 items, batch_size 8 -> 3 batches. 3 batches / 2 global workers. - # Worker 0 gets 2 batches (16 items), worker 1 gets 1 batch (8 items). workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(1, 0, 1), range(4), chunk_intervals, drop_last=False, num_workers=2, batch_size=8 + _DistributedEnv(1, 0, 1), range(0, 4), chunk_intervals, False, 2, 8 ) assert workers_chunks == [[0, 1, 2], [2, 3]] - assert workers_intervals == [ - [Interval(0, 0, 6, 6), Interval(0, 0, 6, 6), Interval(0, 0, 4, 6)], - [Interval(0, 4, 6, 6), Interval(0, 0, 6, 6)], - ] + assert workers_intervals == [[[0, 0, 6, 6], [0, 0, 6, 6], [0, 0, 4, 6]], [[0, 4, 6, 6], [0, 0, 6, 6]]] - # Test Case 7: Uneven chunks with remainder (drop_last=False) - # Total items 26, batch_size 6 -> 4 batches, 2 remainder items. 4 batches / 16 workers. - # First 4 workers get 1 batch (6 items). Last of these (worker 3) gets the 2 remainder items. - # Item counts: [6, 6, 6, 8, 0, ...]. - chunk_intervals = [Interval(0, 0, 6, 6), Interval(0, 0, 7, 7), Interval(0, 0, 6, 6), Interval(0, 0, 7, 8)] - workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(2, 0, 1), range(4), chunk_intervals, drop_last=False, num_workers=8, batch_size=6 - ) - assert sum(i.roi_end_idx - i.roi_start_idx for w in workers_intervals for i in w) == 26 - assert workers_chunks == [[0], [1], [1, 2], [2, 3]] + [[]] * 12 - assert ( - workers_intervals - == [ - [Interval(0, 0, 6, 6)], - [Interval(0, 0, 6, 7)], - [Interval(0, 6, 7, 7), Interval(0, 0, 5, 6)], - [Interval(0, 5, 6, 6), Interval(0, 0, 7, 8)], - ] - + [[]] * 12 - ) + chunk_intervals = [ + Interval(0, 0, 6, 6), + Interval(0, 0, 7, 7), + Interval(0, 0, 6, 6), + Interval(0, 0, 7, 8), + ] - # Test Case 8: Uneven chunks with remainder (drop_last=True) - # Total items 26, batch_size 6 -> 4 batches. Remainder is dropped. 
- # First 4 workers get 1 batch (6 items) each. Item counts: [6, 6, 6, 6, 0, ...]. workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(2, 0, 1), range(4), chunk_intervals, drop_last=True, num_workers=8, batch_size=6 - ) - assert sum(i.roi_end_idx - i.roi_start_idx for w in workers_intervals for i in w) == 24 - assert workers_chunks == [[0], [1], [1, 2], [2, 3]] + [[]] * 12 - assert ( - workers_intervals - == [ - [Interval(0, 0, 6, 6)], - [Interval(0, 0, 6, 7)], - [Interval(0, 6, 7, 7), Interval(0, 0, 5, 6)], - [Interval(0, 5, 6, 6), Interval(0, 0, 5, 8)], - ] - + [[]] * 12 + _DistributedEnv(2, 0, 1), range(0, 4), chunk_intervals, False, 8, 6 ) - # Test Case 9: NEW - multiple_of_batch_size_only=True - # Chunks with sizes [50, 23, 40, 10, 7, 100]. batch_size=10. - # Keep chunks with sizes [50, 40, 10, 100]. Total items = 200. - # 200 items / 2 workers = 100 items/worker. - indexes = list(range(6)) - chunk_intervals = [ - Interval(0, 0, 50, 50), - Interval(0, 0, 23, 23), - Interval(0, 0, 40, 40), - Interval(0, 0, 10, 10), - Interval(0, 0, 7, 7), - Interval(0, 0, 100, 100), + assert sum([y[2] - y[1] for x in workers_intervals for y in x]) == 26 + assert workers_chunks == [[0], [1], [], [], [], [], [], [], [1, 2], [2, 3], [3], [], [], [], [], []] + assert workers_intervals == [ + [[0, 0, 6, 6]], + [[0, 0, 6, 7]], + [], + [], + [], + [], + [], + [], + [[0, 6, 7, 7], [0, 0, 5, 6]], + [[0, 5, 6, 6], [0, 0, 5, 8]], + [[0, 5, 7, 8]], + [], + [], + [], + [], + [], ] + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( - _DistributedEnv(2, 0, 1), - indexes, - chunk_intervals, - num_workers=1, - batch_size=10, - multiple_of_batch_size_only=True, + _DistributedEnv(2, 0, 1), range(0, 4), chunk_intervals, True, 8, 6 ) - # Worker 0 gets 50+40+10 = 100 items. Worker 1 gets 100 items. 
- assert workers_chunks == [[0, 2, 3], [5]] - assert all(sum(i.roi_end_idx - i.roi_start_idx for i in w_intervals) == 100 for w_intervals in workers_intervals) + + assert sum([y[2] - y[1] for x in workers_intervals for y in x]) == 24 + assert workers_chunks == [[0], [1], [], [], [], [], [], [], [1, 2], [2, 3], [], [], [], [], [], []] assert workers_intervals == [ - [Interval(0, 0, 50, 50), Interval(0, 0, 40, 40), Interval(0, 0, 10, 10)], - [Interval(0, 0, 100, 100)], + [[0, 0, 6, 6]], + [[0, 0, 6, 7]], + [], + [], + [], + [], + [], + [], + [[0, 6, 7, 7], [0, 0, 5, 6]], + [[0, 5, 6, 6], [0, 0, 5, 8]], + [], + [], + [], + [], + [], + [], ] From dc2ca8aec749a7a883661bedc3fb07825d91f762 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 14:28:19 +0100 Subject: [PATCH 09/15] update --- src/litdata/streaming/reader.py | 2 +- tests/streaming/test_reader.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index f5f258478..49bb3143c 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -56,7 +56,7 @@ def __init__( item_loader: BaseItemLoader, distributed_env: _DistributedEnv, max_cache_size: Optional[int] = None, - max_pre_download: int = 5, + max_pre_download: int = 2, rank: Optional[int] = None, ) -> None: super().__init__(daemon=True) diff --git a/tests/streaming/test_reader.py b/tests/streaming/test_reader.py index 3bdaaeb7e..ccbde219a 100644 --- a/tests/streaming/test_reader.py +++ b/tests/streaming/test_reader.py @@ -169,8 +169,13 @@ def test_prepare_chunks_thread_eviction(tmpdir, monkeypatch): assert len(os.listdir(cache_dir)) == 14 thread = PrepareChunksThread( - cache._reader.config, item_loader=PyTreeLoader(), distributed_env=_DistributedEnv(1, 1, 1), max_cache_size=10000 + cache._reader.config, + item_loader=PyTreeLoader(), + distributed_env=_DistributedEnv(1, 1, 1), + max_pre_download=2, + max_cache_size=10000, ) + assert not thread._delete_chunks_when_processed thread = PrepareChunksThread( From 602c644976a10cff189bbf2931a73e1bce2ef254 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 14:32:11 +0100 Subject: [PATCH 10/15] update --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 2b589896d..16b7fb713 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,6 +168,7 @@ exclude = [ "src/litdata/imports.py", "src/litdata/imports.py", "src/litdata/processing/data_processor.py", + "src/litdata/debugger.py", ] install_types = "True" non_interactive = "True" From 826ba5f700ac5046b69a853bf8ec7e52061af8d3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 14:34:36 +0100 Subject: [PATCH 11/15] update --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 16b7fb713..81317cf38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ line-length = 120 exclude = [ ".git", "docs", + "src/litdata/debugger.py", "src/litdata/utilities/_pytree.py", ] # Enable Pyflakes `E` and `F` codes by default. 
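The hunks above repeatedly rework how `_associate_chunks_and_intervals_to_workers` splits chunk intervals across ranks. As a reading aid, here is a minimal standalone sketch of that greedy splitting rule: walk the chunks in order, give the current worker as much of a chunk as it still needs, and split a chunk whenever it holds more items than the worker's remaining budget. It is simplified on purpose (plain integer sizes instead of `Interval` objects, precomputed per-worker item budgets, no batch rounding or `drop_last` handling), and the helper name `split_chunks` is made up for this illustration rather than taken from the library.

def split_chunks(chunk_sizes, items_per_worker):
    # Greedy assignment: hand chunks to workers in order, splitting a chunk
    # whenever it holds more items than the current worker still needs.
    budgets = list(items_per_worker)
    chunks_per_worker = [[] for _ in budgets]
    rois_per_worker = [[] for _ in budgets]
    rank = 0
    for chunk_index, size in enumerate(chunk_sizes):
        start = 0
        while start < size and rank < len(budgets):
            if budgets[rank] == 0:
                rank += 1
                continue
            take = min(size - start, budgets[rank])
            chunks_per_worker[rank].append(chunk_index)
            rois_per_worker[rank].append((start, start + take))
            budgets[rank] -= take
            start += take
            if budgets[rank] == 0:
                rank += 1
    return chunks_per_worker, rois_per_worker

# For example, with chunk sizes [6, 7, 6, 7] and per-worker budgets [6, 6, 6, 8],
# the second and fourth chunks end up shared between two workers:
print(split_chunks([6, 7, 6, 7], [6, 6, 6, 8]))
# -> ([[0], [1], [1, 2], [2, 3]],
#     [[(0, 6)], [(0, 6)], [(6, 7), (0, 5)], [(5, 6), (0, 7)]])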
From 5cce5352a8152bb806bdbfe83a8791a933ebfe27 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 14:35:35 +0100 Subject: [PATCH 12/15] update --- src/litdata/streaming/reader.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 49bb3143c..de2b8c155 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -155,9 +155,6 @@ def _apply_delete(self, chunk_index: int, skip_lock: bool = False) -> None: self._item_loader.delete(chunk_index, chunk_filepath) - # if _DEBUG: - # print(f"Deleted {chunk_filepath} by {self._rank or 0}. Debug: {can_delete_chunk}") - base_name = os.path.basename(chunk_filepath) base_prefix = os.path.splitext(base_name)[0] cache_dir = os.path.dirname(chunk_filepath) From 310eed6109c51458dadf96f108e5486295e72694 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 14:37:31 +0100 Subject: [PATCH 13/15] update --- README.md | 18 ------------------ requirements/test.txt | 1 - src/litdata/constants.py | 1 - 3 files changed, 20 deletions(-) diff --git a/README.md b/README.md index 2a82ed443..82c332ffc 100644 --- a/README.md +++ b/README.md @@ -341,21 +341,11 @@ storage_options = { dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options) -# s5cmd compatible storage options for a custom S3-compatible endpoint -# Note: If s5cmd is installed, it will be used by default for S3 operations. If you prefer not to use s5cmd, you can disable it by setting the environment variable: `DISABLE_S5CMD=1` -storage_options = { - "AWS_ACCESS_KEY_ID": "your_access_key_id", - "AWS_SECRET_ACCESS_KEY": "your_secret_access_key", - "S3_ENDPOINT_URL": "your_endpoint_url", # Required only for custom endpoints -} dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options) ``` -Alternative: Using `s5cmd` for S3 Operations - - Also, you can specify a custom cache directory when initializing your dataset. This is useful when you want to store the cache in a specific location. ```python from litdata import StreamingDataset @@ -543,9 +533,6 @@ aws_storage_options={ } dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options) - -# Read data from AWS S3 using s5cmd -# Note: If s5cmd is installed, it will be used by default for S3 operations. 
If you prefer not to use s5cmd, you can disable it by setting the environment variable: `DISABLE_S5CMD=1` aws_storage_options={ "AWS_ACCESS_KEY_ID": os.environ['AWS_ACCESS_KEY_ID'], "AWS_SECRET_ACCESS_KEY": os.environ['AWS_SECRET_ACCESS_KEY'], @@ -553,11 +540,6 @@ aws_storage_options={ } dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options) -# Read Data from AWS S3 with Unsigned Request using s5cmd -aws_storage_options={ - "AWS_NO_SIGN_REQUEST": "Yes" # Required for unsigned requests - "S3_ENDPOINT_URL": os.environ['AWS_ENDPOINT_URL'], # Required only for custom endpoints -} dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options) diff --git a/requirements/test.txt b/requirements/test.txt index 9cbd62dbe..d7d07d44c 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -16,5 +16,4 @@ polars >1.0.0 lightning transformers <4.53.0 zstd -s5cmd >=0.2.0 soundfile >=0.13.0 # required for torchaudio backend diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 086f7126d..b7e7e46d0 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -55,7 +55,6 @@ _MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", "120")) _FORCE_DOWNLOAD_TIME = int(os.getenv("FORCE_DOWNLOAD_TIME", "30")) -_DISABLE_S5CMD = bool(int(os.getenv("DISABLE_S5CMD", "0"))) # DON'T CHANGE ORDER _TORCH_DTYPES_MAPPING = { From 63cddb632d6ba74f40fbcf61bd07324011c9d30b Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 14:42:59 +0100 Subject: [PATCH 14/15] update --- src/litdata/streaming/config.py | 5 ++--- src/litdata/streaming/item_loader.py | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 1873f118a..5ef1cc88e 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import contextlib import logging import os from collections import defaultdict @@ -211,10 +212,8 @@ def try_decompress(self, local_chunkpath: str) -> None: # delete the files only if they were downloaded if self._downloader is not None: - try: + with contextlib.suppress(FileNotFoundError): os.remove(local_chunkpath) - except FileNotFoundError: - pass data = self._compressor.decompress(data) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 2f343e369..8cc70aabe 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -213,7 +213,8 @@ def load_item_from_chunk( if not requested_force_download and (time() - start_time) > _FORCE_DOWNLOAD_TIME: if _DEBUG: print( - f"[ItemLoader] Requested force download for {chunk_filepath} at {datetime.now().isoformat()}" + f"[ItemLoader] Requested force download for {chunk_filepath} " + f"at {datetime.now().isoformat()}" ) self.force_download(chunk_index) requested_force_download = True From 8009c3e7beb7f966e0e15847adea87bfebe77476 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 3 Oct 2025 14:56:51 +0100 Subject: [PATCH 15/15] update --- src/litdata/streaming/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 897c8719d..b7e543be0 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -268,7 +268,7 @@ def deserialize(self, data: bytes) -> torch.Tensor: return tensor_1d.reshape(shape) return torch.empty(shape, dtype=dtype) - def can_serialize(self, item: any) -> bool: + def can_serialize(self, item: Any) -> bool: return isinstance(item, torch.Tensor) and len(item.shape) != 1 def __getstate__(self) -> dict:
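The config.py change in patch 14 replaces a try/except around a best-effort file removal with `contextlib.suppress`, the stdlib shorthand for the same thing. A standalone sketch of the two spellings (the temporary path below is made up for the example and is not a litdata cache file):

    import contextlib
    import os
    import tempfile

    path = os.path.join(tempfile.gettempdir(), "example-chunk.bin")  # hypothetical file

    # Old spelling, as removed by the patch:
    try:
        os.remove(path)
    except FileNotFoundError:
        pass

    # New spelling: identical behaviour, clearer intent.
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)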
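The final patch is a pure typing fix: `any` is the builtin function, not a type, so `item: any` confuses static checkers, while `typing.Any` is the intended "accept anything" annotation. A standalone sketch (the function bodies echo the serializer check shown in the diff; the function names are made up):

    from typing import Any

    import torch

    def can_serialize_old(item: any) -> bool:
        # Runs at runtime, but a type checker such as mypy rejects the annotation
        # (roughly: builtins.any is not valid as a type).
        return isinstance(item, torch.Tensor) and len(item.shape) != 1

    def can_serialize_new(item: Any) -> bool:
        # Same runtime behaviour, valid annotation.
        return isinstance(item, torch.Tensor) and len(item.shape) != 1

    print(can_serialize_new(torch.zeros(2, 3)))  # True: multi-dimensional tensors pass the check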