Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/litdata/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer])
self._config = config
self._chunks = chunks
self._serializers = serializers
self._data_format = self._config["data_format"]
self._shift_idx = len(self._data_format) * 4

# setup the serializers on restart
for data_format in self._data_format:
serializer = self._serializers[self._data_format_to_key(data_format)]
serializer.setup(data_format)

@functools.lru_cache(maxsize=128)
def _data_format_to_key(self, data_format: str) -> str:
if ":" in data_format:
serialier, serializer_sub_type = data_format.split(":")
if serializer_sub_type in self._serializers:
return serializer_sub_type
return serialier
return data_format

def state_dict(self) -> Dict:
    """Return this loader's state, which is always an empty dictionary."""
    return dict()
Expand Down Expand Up @@ -109,21 +125,12 @@ def load_item_from_chunk(

return self.deserialize(data)

@functools.lru_cache(maxsize=128)
def _data_format_to_key(self, data_format: str) -> str:
if ":" in data_format:
serialier, serializer_sub_type = data_format.split(":")
if serializer_sub_type in self._serializers:
return serializer_sub_type
return serialier
return data_format

def deserialize(self, raw_item_data: bytes) -> "PyTree":
"""Deserialize the raw bytes into their python equivalent."""
idx = len(self._config["data_format"]) * 4
idx = self._shift_idx
sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
data = []
for size, data_format in zip(sizes, self._config["data_format"]):
for size, data_format in zip(sizes, self._data_format):
serializer = self._serializers[self._data_format_to_key(data_format)]
data_bytes = raw_item_data[idx : idx + size]
data.append(serializer.deserialize(data_bytes))
Expand Down
12 changes: 12 additions & 0 deletions tests/streaming/test_item_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from unittest.mock import MagicMock

from litdata.streaming.item_loader import PyTreeLoader


def test_serializer_setup():
    """Check that ``PyTreeLoader.setup`` passes the data format to its serializer."""
    config_mock = MagicMock()
    # Any key lookup on the config returns a single format with a sub type.
    config_mock.__getitem__.return_value = ["fake:12"]
    serializer_mock = MagicMock()

    item_loader = PyTreeLoader()
    item_loader.setup(config_mock, [], {"fake": serializer_mock})

    # Use the public mock assertion: a bare ``==`` expression is evaluated and
    # discarded, so it would let the test pass even if setup were never called.
    serializer_mock.setup.assert_called_once_with("fake:12")