From 752488812d17bfca6e28a3b896b04f5b54e82838 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 8 Mar 2024 10:49:25 +0000
Subject: [PATCH 1/2] update

---
 litdata/streaming/config.py      |  4 +++-
 litdata/streaming/item_loader.py | 15 +++++++++------
 litdata/streaming/reader.py      |  4 ++--
 3 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/litdata/streaming/config.py b/litdata/streaming/config.py
index 684f0f843..ef32cd843 100644
--- a/litdata/streaming/config.py
+++ b/litdata/streaming/config.py
@@ -167,7 +167,9 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
         if self._compressor is not None:
             local_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")
 
-        return local_chunkpath, *self._intervals[index.chunk_index]
+        begin = self._intervals[index.chunk_index][0]
+
+        return local_chunkpath, begin, chunk["chunk_bytes"]
 
     def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
         """Retrieves the associated chunk_index for a given chunk filename."""
diff --git a/litdata/streaming/item_loader.py b/litdata/streaming/item_loader.py
index 5cd8fdd25..4940192ff 100644
--- a/litdata/streaming/item_loader.py
+++ b/litdata/streaming/item_loader.py
@@ -52,7 +52,7 @@ def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         pass
 
     @abstractmethod
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> Any:
+    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int) -> Any:
         """Returns an item loaded from a chunk."""
         pass
 
@@ -81,18 +81,18 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
     def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         pass
 
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> bytes:
+    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int) -> bytes:
         offset = (1 + (index - begin) if index >= begin else index + 1) * 4
 
         if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
             del self._chunk_filepaths[chunk_filepath]
 
         if chunk_filepath not in self._chunk_filepaths:
-            exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
+            exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= chunk_bytes
 
             while not exists:
                 sleep(0.1)
-                exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
+                exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= chunk_bytes
 
             self._chunk_filepaths[chunk_filepath] = True
 
@@ -124,7 +124,10 @@ def deserialize(self, raw_item_data: bytes) -> "PyTree":
             data_bytes = raw_item_data[idx : idx + size]
             data.append(serializer.deserialize(data_bytes))
             idx += size
-        return tree_unflatten(data, self._config["data_spec"])
+        try:
+            return tree_unflatten(data, self._config["data_spec"])
+        except Exception:
+            breakpoint()
 
     def delete(self, chunk_index: int, chunk_filepath: str) -> None:
         if os.path.exists(chunk_filepath):
@@ -191,7 +194,7 @@ def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         if os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0:
             self._load_chunk(chunk_index, chunk_filepath)
 
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
+    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int) -> torch.Tensor:
         if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
             del self._chunk_filepaths[chunk_filepath]
diff --git a/litdata/streaming/reader.py b/litdata/streaming/reader.py
index b0aff6a98..2116d9c36 100644
--- a/litdata/streaming/reader.py
+++ b/litdata/streaming/reader.py
@@ -248,8 +248,8 @@ def read(self, index: ChunkedIndex) -> Any:
             self._last_chunk_index = index.chunk_index
 
         # Fetch the element
-        chunk_filepath, begin, _ = self.config[index]
-        item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)
+        chunk_filepath, begin, chunk_bytes = self.config[index]
+        item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes)
 
         # We need to request deletion after the latest element has been loaded.
         # Otherwise, this could trigger segmentation fault error depending on the item loader used.

From 7b8c4462a384d194398398ecf00317be013f4b71 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Fri, 8 Mar 2024 10:53:14 +0000
Subject: [PATCH 2/2] update

---
 litdata/streaming/item_loader.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/litdata/streaming/item_loader.py b/litdata/streaming/item_loader.py
index 4940192ff..6a2bbde5e 100644
--- a/litdata/streaming/item_loader.py
+++ b/litdata/streaming/item_loader.py
@@ -52,7 +52,9 @@ def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         pass
 
     @abstractmethod
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int) -> Any:
+    def load_item_from_chunk(
+        self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int
+    ) -> Any:
         """Returns an item loaded from a chunk."""
         pass
 
@@ -81,7 +83,9 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
     def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         pass
 
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int) -> bytes:
+    def load_item_from_chunk(
+        self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int
+    ) -> bytes:
         offset = (1 + (index - begin) if index >= begin else index + 1) * 4
 
         if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
@@ -124,10 +128,7 @@ def deserialize(self, raw_item_data: bytes) -> "PyTree":
             data_bytes = raw_item_data[idx : idx + size]
             data.append(serializer.deserialize(data_bytes))
             idx += size
-        try:
-            return tree_unflatten(data, self._config["data_spec"])
-        except Exception:
-            breakpoint()
+        return tree_unflatten(data, self._config["data_spec"])
 
     def delete(self, chunk_index: int, chunk_filepath: str) -> None:
         if os.path.exists(chunk_filepath):
@@ -194,7 +195,9 @@ def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         if os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0:
             self._load_chunk(chunk_index, chunk_filepath)
 
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int) -> torch.Tensor:
+    def load_item_from_chunk(
+        self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int
+    ) -> torch.Tensor:
         if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
             del self._chunk_filepaths[chunk_filepath]
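Note on the series (reviewer summary, not part of the patches above): config.py's __getitem__ now returns the chunk's full byte size ("chunk_bytes", read from the chunk metadata) alongside the chunk path and the begin index, and reader.py forwards it into load_item_from_chunk. Previously the item loaders only waited for the chunk file to be non-empty (st_size > 0), so a worker could start reading a chunk that was still being downloaded; they now poll until st_size >= chunk_bytes, i.e. until the file is fully written. The second patch is cleanup on top of the first: it re-wraps the signatures that grew past the line-length limit and removes a try/except breakpoint() debug leftover.

A minimal standalone sketch of the wait-until-complete pattern, assuming the expected size is known up front (the helper name is hypothetical, not litdata API):

    import os
    import time

    def wait_for_complete_file(path: str, expected_bytes: int, poll: float = 0.1) -> None:
        """Block until `path` exists and has reached its expected size.

        st_size > 0 only proves a download has started; comparing against the
        known byte count proves the producer finished writing the file.
        """
        while not (os.path.exists(path) and os.stat(path).st_size >= expected_bytes):
            time.sleep(poll)

One caveat: this comparison assumes the on-disk size matches the recorded chunk_bytes exactly; if chunks were stored with a different on-disk size (e.g. compressed), the check would need to compare against that size instead.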