diff --git a/litdata/streaming/config.py b/litdata/streaming/config.py
index 684f0f843..ef32cd843 100644
--- a/litdata/streaming/config.py
+++ b/litdata/streaming/config.py
@@ -167,7 +167,9 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
         if self._compressor is not None:
             local_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")
 
-        return local_chunkpath, *self._intervals[index.chunk_index]
+        begin = self._intervals[index.chunk_index][0]
+
+        return local_chunkpath, begin, chunk["chunk_bytes"]
 
     def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
         """Retrieves the associated chunk_index for a given chunk filename."""
diff --git a/litdata/streaming/item_loader.py b/litdata/streaming/item_loader.py
index 5cd8fdd25..6a2bbde5e 100644
--- a/litdata/streaming/item_loader.py
+++ b/litdata/streaming/item_loader.py
@@ -52,7 +52,9 @@ def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         pass
 
     @abstractmethod
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> Any:
+    def load_item_from_chunk(
+        self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int
+    ) -> Any:
         """Returns an item loaded from a chunk."""
         pass
 
@@ -81,18 +83,20 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
     def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         pass
 
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> bytes:
+    def load_item_from_chunk(
+        self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int
+    ) -> bytes:
         offset = (1 + (index - begin) if index >= begin else index + 1) * 4
 
         if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
             del self._chunk_filepaths[chunk_filepath]
 
         if chunk_filepath not in self._chunk_filepaths:
-            exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
+            exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= chunk_bytes
 
             while not exists:
                 sleep(0.1)
-                exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
+                exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= chunk_bytes
 
             self._chunk_filepaths[chunk_filepath] = True
 
@@ -191,7 +195,9 @@ def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         if os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0:
             self._load_chunk(chunk_index, chunk_filepath)
 
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
+    def load_item_from_chunk(
+        self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int
+    ) -> torch.Tensor:
         if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
             del self._chunk_filepaths[chunk_filepath]
 
diff --git a/litdata/streaming/reader.py b/litdata/streaming/reader.py
index b0aff6a98..2116d9c36 100644
--- a/litdata/streaming/reader.py
+++ b/litdata/streaming/reader.py
@@ -248,8 +248,8 @@ def read(self, index: ChunkedIndex) -> Any:
             self._last_chunk_index = index.chunk_index
 
         # Fetch the element
-        chunk_filepath, begin, _ = self.config[index]
-        item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)
+        chunk_filepath, begin, chunk_bytes = self.config[index]
+        item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes)
 
         # We need to request deletion after the latest element has been loaded.
         # Otherwise, this could trigger segmentation fault error depending on the item loader used.