diff --git a/litdata/constants.py b/litdata/constants.py
index 60e43dda2..90bdd0c24 100644
--- a/litdata/constants.py
+++ b/litdata/constants.py
@@ -28,6 +28,8 @@
 _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
 _LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.64")
 _BOTO3_AVAILABLE = RequirementCache("boto3")
+_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
+_ZSTD_AVAILABLE = RequirementCache("zstd")

 # DON'T CHANGE ORDER
 _TORCH_DTYPES_MAPPING = {
diff --git a/litdata/streaming/compression.py b/litdata/streaming/compression.py
index bd6f5efb6..75a479451 100644
--- a/litdata/streaming/compression.py
+++ b/litdata/streaming/compression.py
@@ -14,9 +14,9 @@
 from abc import ABC, abstractmethod
 from typing import Dict, TypeVar

-from lightning_utilities.core.imports import RequirementCache, requires
+from lightning_utilities.core.imports import requires

-_ZSTD_AVAILABLE = RequirementCache("zstd")
+from litdata.constants import _ZSTD_AVAILABLE

 if _ZSTD_AVAILABLE:
     import zstd
diff --git a/litdata/streaming/config.py b/litdata/streaming/config.py
index befb30196..684f0f843 100644
--- a/litdata/streaming/config.py
+++ b/litdata/streaming/config.py
@@ -16,6 +16,7 @@
 from typing import Any, Dict, List, Optional, Tuple

 from litdata.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
+from litdata.streaming.compression import _COMPRESSORS, Compressor
 from litdata.streaming.downloader import get_downloader_cls
 from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
 from litdata.streaming.sampler import ChunkedIndex
@@ -66,12 +67,19 @@ def __init__(
         if remote_dir:
             self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks)

+        self._compressor_name = self._config["compression"]
+        self._compressor: Optional[Compressor] = None
+
+        if self._compressor_name in _COMPRESSORS:
+            self._compressor = _COMPRESSORS[self._compressor_name]
+
     def download_chunk_from_index(self, chunk_index: int) -> None:
         chunk_filename = self._chunks[chunk_index]["filename"]

         local_chunkpath = os.path.join(self._cache_dir, chunk_filename)

         if os.path.exists(local_chunkpath):
+            self.try_decompress(local_chunkpath)
             return

         if self._downloader is None:
@@ -79,6 +87,27 @@ def download_chunk_from_index(self, chunk_index: int) -> None:

         self._downloader.download_chunk_from_index(chunk_index)

+        self.try_decompress(local_chunkpath)
+
+    def try_decompress(self, local_chunkpath: str) -> None:
+        if self._compressor is None:
+            return
+
+        target_local_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")
+
+        if os.path.exists(target_local_chunkpath):
+            return
+
+        with open(local_chunkpath, "rb") as f:
+            data = f.read()
+
+        os.remove(local_chunkpath)
+
+        data = self._compressor.decompress(data)
+
+        with open(target_local_chunkpath, "wb") as f:
+            f.write(data)
+
     @property
     def intervals(self) -> List[Tuple[int, int]]:
         if self._intervals is None:
@@ -132,7 +161,13 @@ def _get_chunk_index_from_index(self, index: int) -> int:
     def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
         """Find the associated chunk metadata."""
         chunk = self._chunks[index.chunk_index]
-        return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]
+
+        local_chunkpath = os.path.join(self._cache_dir, chunk["filename"])
+
+        if self._compressor is not None:
+            local_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")
+
+        return local_chunkpath, *self._intervals[index.chunk_index]

     def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
         """Retrieves the associated chunk_index for a given chunk filename."""
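Note: the `try_decompress` flow added to `ChunksConfig` above can be read in isolation. A minimal standalone sketch, assuming the `zstd` package and a compressor-infixed chunk filename such as `chunk-0-0.zstd.bin`; the helper name `decompress_chunk` is illustrative and not part of the patch:

    import os

    import zstd  # pip install zstd

    def decompress_chunk(local_chunkpath: str, compressor_name: str = "zstd") -> str:
        # Strip the compression marker to recover the target chunk path,
        # e.g. "chunk-0-0.zstd.bin" -> "chunk-0-0.bin".
        target = local_chunkpath.replace(f".{compressor_name}", "")
        if target == local_chunkpath or os.path.exists(target):
            return target  # not compressed, or already decompressed
        with open(local_chunkpath, "rb") as f:
            data = zstd.decompress(f.read())
        os.remove(local_chunkpath)  # drop the compressed copy to save disk
        with open(target, "wb") as f:
            f.write(data)
        return target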
diff --git a/litdata/streaming/item_loader.py b/litdata/streaming/item_loader.py
index b578e3bf0..5cd8fdd25 100644
--- a/litdata/streaming/item_loader.py
+++ b/litdata/streaming/item_loader.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
 import os
 from abc import ABC, abstractmethod
 from time import sleep
@@ -101,15 +102,25 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
             begin, end = np.frombuffer(pair, np.uint32)
             fp.seek(begin)
             data = fp.read(end - begin)
+
         return self.deserialize(data)

+    @functools.lru_cache(maxsize=128)
+    def _data_format_to_key(self, data_format: str) -> str:
+        if ":" in data_format:
+            serializer_name, serializer_sub_type = data_format.split(":")
+            if serializer_sub_type in self._serializers:
+                return serializer_sub_type
+            return serializer_name
+        return data_format
+
     def deserialize(self, raw_item_data: bytes) -> "PyTree":
         """Deserialize the raw bytes into their python equivalent."""
         idx = len(self._config["data_format"]) * 4
         sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
         data = []
         for size, data_format in zip(sizes, self._config["data_format"]):
-            serializer = self._serializers[data_format]
+            serializer = self._serializers[self._data_format_to_key(data_format)]
             data_bytes = raw_item_data[idx : idx + size]
             data.append(serializer.deserialize(data_bytes))
             idx += size
diff --git a/litdata/streaming/reader.py b/litdata/streaming/reader.py
index ad63175cc..b0aff6a98 100644
--- a/litdata/streaming/reader.py
+++ b/litdata/streaming/reader.py
@@ -229,7 +229,7 @@ def read(self, index: ChunkedIndex) -> Any:
         if self._config is None and self._try_load_config() is None:
             raise Exception("The reader index isn't defined.")

-        if self._config and self._config._remote_dir:
+        if self._config and (self._config._remote_dir or self._config._compressor):
             # Create and start the prepare chunks thread
             if self._prepare_thread is None and self._config:
                 self._prepare_thread = PrepareChunksThread(
diff --git a/litdata/streaming/serializers.py b/litdata/streaming/serializers.py
index 700251d81..ccc37869c 100644
--- a/litdata/streaming/serializers.py
+++ b/litdata/streaming/serializers.py
@@ -282,7 +282,8 @@ class FileSerializer(Serializer):
     def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
         _, file_extension = os.path.splitext(filepath)
         with open(filepath, "rb") as f:
-            return f.read(), file_extension.replace(".", "").lower()
+            file_extension = file_extension.replace(".", "").lower()
+            return f.read(), f"file:{file_extension}"

     def deserialize(self, data: bytes) -> Any:
         return data
@@ -292,12 +293,13 @@ def can_serialize(self, data: Any) -> bool:


 class VideoSerializer(Serializer):
-    _EXTENSIONS = ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv", "wav")
+    _EXTENSIONS = ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv")

     def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
         _, file_extension = os.path.splitext(filepath)
         with open(filepath, "rb") as f:
-            return f.read(), file_extension.replace(".", "").lower()
+            file_extension = file_extension.replace(".", "").lower()
+            return f.read(), f"video:{file_extension}"

     def deserialize(self, data: bytes) -> Any:
         if not _TORCH_VISION_AVAILABLE:
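Note: with the serializer changes above, `data_format` entries are now namespaced as "<serializer>:<extension>" (e.g. "file:wav", "video:mp4"), and `_data_format_to_key` resolves them to the most specific registered serializer. A minimal sketch of that lookup, with an illustrative `serializers` dict standing in for `self._serializers`:

    def data_format_to_key(data_format: str, serializers: dict) -> str:
        # "video:mp4" resolves to "mp4" if a dedicated serializer is
        # registered, else falls back to "video"; un-namespaced formats
        # such as "bytes" pass through unchanged.
        if ":" in data_format:
            name, sub_type = data_format.split(":")
            return sub_type if sub_type in serializers else name
        return data_format

    assert data_format_to_key("video:mp4", {"video": ..., "file": ...}) == "video"
    assert data_format_to_key("file:wav", {"wav": ..., "file": ...}) == "wav"
    assert data_format_to_key("bytes", {"bytes": ...}) == "bytes"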
diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py
index 7af1271a1..911997cd7 100644
--- a/tests/processing/test_data_processor.py
+++ b/tests/processing/test_data_processor.py
@@ -12,6 +12,7 @@
 from lightning import seed_everything
 from lightning_utilities.core.imports import RequirementCache

+from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE
 from litdata.processing import data_processor as data_processor_module
 from litdata.processing import functions
 from litdata.processing.data_processor import (
@@ -28,7 +29,7 @@
     _wait_for_file_to_exist,
 )
 from litdata.processing.functions import LambdaDataTransformRecipe, map, optimize
-from litdata.streaming import resolver
+from litdata.streaming import StreamingDataset, resolver
 from litdata.streaming.cache import Cache, Dir

 _PIL_AVAILABLE = RequirementCache("PIL")
@@ -1059,3 +1060,79 @@ def test_empty_optimize(tmpdir):
     )

     assert os.listdir(tmpdir) == ["index.json"]
+
+
+def create_synthetic_audio_bytes(index) -> dict:
+    from io import BytesIO
+
+    import torchaudio
+
+    # create a dummy audio tensor
+    data = torch.randn((1, 16000))
+
+    # encode the tensor as WAV bytes
+    with BytesIO() as f:
+        torchaudio.save(f, data, 16000, format="wav")
+        data = f.getvalue()
+
+    data = {"content": data}
+    return data
+
+
+@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio', 'zstd']")
+@pytest.mark.parametrize("compression", [None, "zstd"])
+def test_load_torch_audio(tmpdir, compression):
+    seed_everything(42)
+
+    import torchaudio
+
+    optimize(
+        fn=create_synthetic_audio_bytes,
+        inputs=list(range(100)),
+        output_dir=str(tmpdir),
+        num_workers=1,
+        chunk_bytes="64MB",
+        compression=compression,
+    )
+
+    dataset = StreamingDataset(input_dir=str(tmpdir))
+    sample = dataset[0]
+    tensor = torchaudio.load(sample["content"])
+    assert tensor[0].shape == torch.Size([1, 16000])
+    assert tensor[1] == 16000
+
+
+def create_synthetic_audio_file(filepath) -> str:
+    import torchaudio
+
+    # create a dummy audio tensor
+    data = torch.randn((1, 16000))
+
+    # write the tensor to a WAV file
+    with open(filepath, "wb") as f:
+        torchaudio.save(f, data, 16000, format="wav")
+
+    return filepath
+
+
+@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio', 'zstd']")
+@pytest.mark.parametrize("compression", [None])
+def test_load_torch_audio_from_wav_file(tmpdir, compression):
+    seed_everything(42)
+
+    import torchaudio
+
+    optimize(
+        fn=create_synthetic_audio_file,
+        inputs=[os.path.join(tmpdir, f"{i}.wav") for i in range(5)],
+        output_dir=str(tmpdir),
+        num_workers=1,
+        chunk_bytes="64MB",
+        compression=compression,
+    )
+
+    dataset = StreamingDataset(input_dir=str(tmpdir))
+    sample = dataset[0]
+    tensor = torchaudio.load(sample)
+    assert tensor[0].shape == torch.Size([1, 16000])
+    assert tensor[1] == 16000
diff --git a/tests/streaming/test_serializer.py b/tests/streaming/test_serializer.py
index bd5fb002f..3771f95e2 100644
--- a/tests/streaming/test_serializer.py
+++ b/tests/streaming/test_serializer.py
@@ -214,16 +214,16 @@ def test_assert_no_header_numpy_serializer():

 def test_wav_deserialization(tmpdir):
     from torch.hub import download_url_to_file

-    video_file = os.path.join(tmpdir, "video.wav")
-    key = "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"  # noqa E501
+    video_file = os.path.join(tmpdir, "video.mp4")
+    key = "tutorial-assets/mptestsrc.mp4"
     download_url_to_file(f"https://download.pytorch.org/torchaudio/{key}", video_file)
     serializer = VideoSerializer()
     assert serializer.can_serialize(video_file)
     data, name = serializer.serialize(video_file)
-    assert len(data) / 1024 / 1024 == 0.10380172729492188
-    assert name == "wav"
+    assert len(data) / 1024 / 1024 == 0.2262248992919922
+    assert name == "video:mp4"
     vframes, aframes, info = serializer.deserialize(data)
-    assert vframes.shape == torch.Size([0, 1, 1, 3])
-    assert aframes.shape == torch.Size([1, 54400])
-    assert info == {"audio_fps": 16000}
+    assert vframes.shape == torch.Size([301, 512, 512, 3])
+    assert aframes.shape == torch.Size([1, 0])
+    assert info == {"video_fps": 25.0}
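Note: end to end, the new tests above exercise the feature roughly as follows. A minimal usage sketch; the output path and sample payload are illustrative:

    from litdata.processing.functions import optimize
    from litdata.streaming import StreamingDataset

    def fn(index: int) -> dict:
        return {"index": index}

    # Write an optimized dataset whose chunks are zstd-compressed on disk.
    optimize(
        fn=fn,
        inputs=list(range(100)),
        output_dir="/tmp/compressed_dataset",
        num_workers=1,
        chunk_bytes="64MB",
        compression="zstd",  # requires the `zstd` package
    )

    # Chunks are transparently decompressed when streamed back.
    dataset = StreamingDataset(input_dir="/tmp/compressed_dataset")
    sample = dataset[0]  # -> {"index": 0}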