From 88ff5274bc47e2477ebff9991fd24c231b837181 Mon Sep 17 00:00:00 2001 From: Viatcheslav Gurev Date: Wed, 27 Mar 2024 22:02:52 -0400 Subject: [PATCH 1/5] Fixed bug of missing call to setup of serializer --- src/litdata/streaming/item_loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 6a2bbde5e..f0e6ce614 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -126,6 +126,7 @@ def deserialize(self, raw_item_data: bytes) -> "PyTree": for size, data_format in zip(sizes, self._config["data_format"]): serializer = self._serializers[self._data_format_to_key(data_format)] data_bytes = raw_item_data[idx : idx + size] + serializer.setup(data_format) data.append(serializer.deserialize(data_bytes)) idx += size return tree_unflatten(data, self._config["data_spec"]) From a72a8d6890bb23cbf0e7994b1315844ec0b2441c Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 11 Apr 2024 09:14:59 +0100 Subject: [PATCH 2/5] update --- src/litdata/streaming/item_loader.py | 23 ++++++++++++++--------- tests/streaming/test_item_loader.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 9 deletions(-) create mode 100644 tests/streaming/test_item_loader.py diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index f0e6ce614..8c7f26db8 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -38,6 +38,20 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer]) self._chunks = chunks self._serializers = serializers + # setup the serializers on restart + for data_format in self._config["data_format"]: + serializer = self._serializers[self._data_format_to_key(data_format)] + serializer.setup(data_format) + + @functools.lru_cache(maxsize=128) + def _data_format_to_key(self, data_format: str) -> str: + if ":" in data_format: + serialier, serializer_sub_type = data_format.split(":") + if serializer_sub_type in self._serializers: + return serializer_sub_type + return serialier + return data_format + def state_dict(self) -> Dict: return {} @@ -109,15 +123,6 @@ def load_item_from_chunk( return self.deserialize(data) - @functools.lru_cache(maxsize=128) - def _data_format_to_key(self, data_format: str) -> str: - if ":" in data_format: - serialier, serializer_sub_type = data_format.split(":") - if serializer_sub_type in self._serializers: - return serializer_sub_type - return serialier - return data_format - def deserialize(self, raw_item_data: bytes) -> "PyTree": """Deserialize the raw bytes into their python equivalent.""" idx = len(self._config["data_format"]) * 4 diff --git a/tests/streaming/test_item_loader.py b/tests/streaming/test_item_loader.py new file mode 100644 index 000000000..e17fe705f --- /dev/null +++ b/tests/streaming/test_item_loader.py @@ -0,0 +1,12 @@ +from unittest.mock import MagicMock + +from litdata.streaming.item_loader import PyTreeLoader + + +def test_serializer_setup(): + config_mock = MagicMock() + config_mock.__getitem__.return_value = ["fake:12"] + serializer_mock = MagicMock() + item_loader = PyTreeLoader() + item_loader.setup(config_mock, [], {"fake": serializer_mock}) + serializer_mock.setup._mock_mock_calls[0].args[0] == "fake:12" From 13703fbdca0840ae226417152a3c9daad6b009ee Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 11 Apr 2024 09:16:10 +0100 Subject: [PATCH 3/5] update --- src/litdata/streaming/item_loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 8c7f26db8..faeebe521 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -131,7 +131,6 @@ def deserialize(self, raw_item_data: bytes) -> "PyTree": for size, data_format in zip(sizes, self._config["data_format"]): serializer = self._serializers[self._data_format_to_key(data_format)] data_bytes = raw_item_data[idx : idx + size] - serializer.setup(data_format) data.append(serializer.deserialize(data_bytes)) idx += size return tree_unflatten(data, self._config["data_spec"]) From 443c74ca6447d15b92d5751f64cd054a604c8b80 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 11 Apr 2024 09:16:54 +0100 Subject: [PATCH 4/5] update --- src/litdata/streaming/item_loader.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index faeebe521..87cc5c4b8 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -37,9 +37,10 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer]) self._config = config self._chunks = chunks self._serializers = serializers + self._data_format = self._config["data_format"] # setup the serializers on restart - for data_format in self._config["data_format"]: + for data_format in self._data_format: serializer = self._serializers[self._data_format_to_key(data_format)] serializer.setup(data_format) @@ -125,10 +126,10 @@ def load_item_from_chunk( def deserialize(self, raw_item_data: bytes) -> "PyTree": """Deserialize the raw bytes into their python equivalent.""" - idx = len(self._config["data_format"]) * 4 + idx = len(self._data_format) * 4 sizes = np.frombuffer(raw_item_data[:idx], np.uint32) data = [] - for size, data_format in zip(sizes, self._config["data_format"]): + for size, data_format in zip(sizes, self._data_format): serializer = self._serializers[self._data_format_to_key(data_format)] data_bytes = raw_item_data[idx : idx + size] data.append(serializer.deserialize(data_bytes)) From 7e254817f4b08f9ba7aae7e9328eb2038183874e Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 11 Apr 2024 09:18:40 +0100 Subject: [PATCH 5/5] update --- src/litdata/streaming/item_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 87cc5c4b8..04b9b2b8c 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -38,6 +38,7 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer]) self._chunks = chunks self._serializers = serializers self._data_format = self._config["data_format"] + self._shift_idx = len(self._data_format) * 4 # setup the serializers on restart for data_format in self._data_format: @@ -126,7 +127,7 @@ def load_item_from_chunk( def deserialize(self, raw_item_data: bytes) -> "PyTree": """Deserialize the raw bytes into their python equivalent.""" - idx = len(self._data_format) * 4 + idx = self._shift_idx sizes = np.frombuffer(raw_item_data[:idx], np.uint32) data = [] for size, data_format in zip(sizes, self._data_format):