Merged
2 changes: 2 additions & 0 deletions litdata/constants.py
@@ -28,6 +28,8 @@
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.64")
_BOTO3_AVAILABLE = RequirementCache("boto3")
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
4 changes: 2 additions & 2 deletions litdata/streaming/compression.py
@@ -14,9 +14,9 @@
from abc import ABC, abstractmethod
from typing import Dict, TypeVar

from lightning_utilities.core.imports import RequirementCache, requires
from lightning_utilities.core.imports import requires

_ZSTD_AVAILABLE = RequirementCache("zstd")
from litdata.constants import _ZSTD_AVAILABLE

if _ZSTD_AVAILABLE:
import zstd
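For context, _COMPRESSORS (imported by config.py below) maps a compression name to a Compressor instance. A minimal sketch of a zstd-backed entry, assuming the base class only requires compress/decompress; the class name, default level, and registration shown here are illustrative, not necessarily the file's actual contents:

from typing import Dict

from litdata.constants import _ZSTD_AVAILABLE
from litdata.streaming.compression import Compressor

if _ZSTD_AVAILABLE:
    import zstd


class ZSTDCompressor(Compressor):
    """Sketch of a zstd-backed compressor; the default level is an assumption."""

    def __init__(self, level: int = 4) -> None:
        if not _ZSTD_AVAILABLE:
            raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
        self.level = level

    def compress(self, data: bytes) -> bytes:
        return zstd.compress(data, self.level)

    def decompress(self, data: bytes) -> bytes:
        return zstd.decompress(data)


# Hypothetical registry: the key becomes the "compression" value in index.json
# and the ".zstd" suffix that config.py strips after decompression.
_COMPRESSORS: Dict[str, Compressor] = {"zstd": ZSTDCompressor(4)}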
37 changes: 36 additions & 1 deletion litdata/streaming/config.py
@@ -16,6 +16,7 @@
from typing import Any, Dict, List, Optional, Tuple

from litdata.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from litdata.streaming.compression import _COMPRESSORS, Compressor
from litdata.streaming.downloader import get_downloader_cls
from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
from litdata.streaming.sampler import ChunkedIndex
@@ -66,19 +67,47 @@ def __init__(
if remote_dir:
self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks)

self._compressor_name = self._config["compression"]
self._compressor: Optional[Compressor] = None

if self._compressor_name in _COMPRESSORS:
self._compressor = _COMPRESSORS[self._compressor_name]

def download_chunk_from_index(self, chunk_index: int) -> None:
chunk_filename = self._chunks[chunk_index]["filename"]

local_chunkpath = os.path.join(self._cache_dir, chunk_filename)

if os.path.exists(local_chunkpath):
self.try_decompress(local_chunkpath)
return

if self._downloader is None:
raise RuntimeError("The downloader should be defined.")

self._downloader.download_chunk_from_index(chunk_index)

self.try_decompress(local_chunkpath)

def try_decompress(self, local_chunkpath: str) -> None:
if self._compressor is None:
return

target_local_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")

if os.path.exists(target_local_chunkpath):
return

with open(local_chunkpath, "rb") as f:
data = f.read()

os.remove(local_chunkpath)

data = self._compressor.decompress(data)

with open(target_local_chunkpath, "wb") as f:
f.write(data)

@property
def intervals(self) -> List[Tuple[int, int]]:
if self._intervals is None:
@@ -132,7 +161,13 @@ def _get_chunk_index_from_index(self, index: int) -> int:
def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
"""Find the associated chunk metadata."""
chunk = self._chunks[index.chunk_index]
return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]

local_chunkpath = os.path.join(self._cache_dir, chunk["filename"])

if self._compressor is not None:
local_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")

return local_chunkpath, *self._intervals[index.chunk_index]

def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
"""Retrieves the associated chunk_index for a given chunk filename."""
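Taken together, the reader-side changes make compressed datasets stream transparently: the compressed chunk (e.g. chunk-0-0.bin.zstd) is downloaded or found locally, try_decompress writes the plain .bin file next to it and removes the compressed copy, and __getitem__ hands out the decompressed path. A minimal end-to-end usage sketch, using the same optimize/StreamingDataset entry points as the tests added below (the sample function, chunk filename, and output directory are illustrative):

from litdata.processing.functions import optimize
from litdata.streaming import StreamingDataset


def make_sample(index: int) -> dict:
    # illustrative payload; any structure the built-in serializers understand works
    return {"index": index, "payload": bytes([index % 256]) * 128}


if __name__ == "__main__":
    optimize(
        fn=make_sample,
        inputs=list(range(1000)),
        output_dir="/tmp/compressed_dataset",  # illustrative path
        num_workers=1,
        chunk_bytes="64MB",
        compression="zstd",  # chunk files are written with a .zstd suffix
    )

    dataset = StreamingDataset(input_dir="/tmp/compressed_dataset")
    print(dataset[0])  # first access decompresses the chunk once, then it is reused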
13 changes: 12 additions & 1 deletion litdata/streaming/item_loader.py
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
from abc import ABC, abstractmethod
from time import sleep
@@ -101,15 +102,25 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
begin, end = np.frombuffer(pair, np.uint32)
fp.seek(begin)
data = fp.read(end - begin)

return self.deserialize(data)

@functools.lru_cache(maxsize=128)
def _data_format_to_key(self, data_format: str) -> str:
if ":" in data_format:
serializer_name, serializer_sub_type = data_format.split(":")
if serializer_sub_type in self._serializers:
return serializer_sub_type
return serializer_name
return data_format

def deserialize(self, raw_item_data: bytes) -> "PyTree":
"""Deserialize the raw bytes into their python equivalent."""
idx = len(self._config["data_format"]) * 4
sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
data = []
for size, data_format in zip(sizes, self._config["data_format"]):
serializer = self._serializers[data_format]
serializer = self._serializers[self._data_format_to_key(data_format)]
data_bytes = raw_item_data[idx : idx + size]
data.append(serializer.deserialize(data_bytes))
idx += size
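The new _data_format_to_key helper resolves composite format strings such as "video:mp4": an extension-specific serializer wins if one is registered, otherwise the generic family serializer is used, and plain formats pass through unchanged. A standalone illustration of that lookup rule (the registry contents here are hypothetical):

def data_format_to_key(data_format: str, serializers: dict) -> str:
    # "video:mp4" -> "mp4" if an mp4-specific serializer is registered, else "video"
    if ":" in data_format:
        name, sub_type = data_format.split(":")
        return sub_type if sub_type in serializers else name
    return data_format


registry = {"video": object(), "file": object(), "wav": object(), "int": object()}
assert data_format_to_key("video:mp4", registry) == "video"  # no mp4-specific serializer
assert data_format_to_key("file:wav", registry) == "wav"     # wav-specific serializer exists
assert data_format_to_key("int", registry) == "int"          # plain formats pass through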
2 changes: 1 addition & 1 deletion litdata/streaming/reader.py
@@ -229,7 +229,7 @@ def read(self, index: ChunkedIndex) -> Any:
if self._config is None and self._try_load_config() is None:
raise Exception("The reader index isn't defined.")

if self._config and self._config._remote_dir:
if self._config and (self._config._remote_dir or self._config._compressor):
# Create and start the prepare chunks thread
if self._prepare_thread is None and self._config:
self._prepare_thread = PrepareChunksThread(
8 changes: 5 additions & 3 deletions litdata/streaming/serializers.py
@@ -282,7 +282,8 @@ class FileSerializer(Serializer):
def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
_, file_extension = os.path.splitext(filepath)
with open(filepath, "rb") as f:
return f.read(), file_extension.replace(".", "").lower()
file_extension = file_extension.replace(".", "").lower()
return f.read(), f"file:{file_extension}"

def deserialize(self, data: bytes) -> Any:
return data
@@ -292,12 +293,13 @@ def can_serialize(self, data: Any) -> bool:


class VideoSerializer(Serializer):
_EXTENSIONS = ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv", "wav")
_EXTENSIONS = ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv")

def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
_, file_extension = os.path.splitext(filepath)
with open(filepath, "rb") as f:
return f.read(), file_extension.replace(".", "").lower()
file_extension = file_extension.replace(".", "").lower()
return f.read(), f"video:{file_extension}"

def deserialize(self, data: bytes) -> Any:
if not _TORCH_VISION_AVAILABLE:
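With these serializer changes, the format string recorded for each sample carries both the serializer family and the original file extension, which is exactly what _data_format_to_key above consumes. A quick sketch of the new return values, assuming the referenced local files exist:

from litdata.streaming.serializers import FileSerializer, VideoSerializer

# hypothetical local files; serialize() returns (raw bytes, "<family>:<extension>")
data, fmt = FileSerializer().serialize("speech.wav")
assert fmt == "file:wav"

data, fmt = VideoSerializer().serialize("clip.mp4")
assert fmt == "video:mp4"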
79 changes: 78 additions & 1 deletion tests/processing/test_data_processor.py
@@ -12,6 +12,7 @@
from lightning import seed_everything
from lightning_utilities.core.imports import RequirementCache

from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE
from litdata.processing import data_processor as data_processor_module
from litdata.processing import functions
from litdata.processing.data_processor import (
@@ -28,7 +29,7 @@
_wait_for_file_to_exist,
)
from litdata.processing.functions import LambdaDataTransformRecipe, map, optimize
from litdata.streaming import resolver
from litdata.streaming import StreamingDataset, resolver
from litdata.streaming.cache import Cache, Dir

_PIL_AVAILABLE = RequirementCache("PIL")
@@ -1059,3 +1060,79 @@ def test_empty_optimize(tmpdir):
)

assert os.listdir(tmpdir) == ["index.json"]


def create_synthetic_audio_bytes(index) -> dict:
from io import BytesIO

import torchaudio

# create a dummy audio waveform
data = torch.randn((1, 16000))

# convert tensor to bytes
with BytesIO() as f:
torchaudio.save(f, data, 16000, format="wav")
data = f.getvalue()

data = {"content": data}
return data


@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio', 'zstd']")
@pytest.mark.parametrize("compression", [None, "zstd"])
def test_load_torch_audio(tmpdir, compression):
seed_everything(42)

import torchaudio

optimize(
fn=create_synthetic_audio_bytes,
inputs=list(range(100)),
output_dir=str(tmpdir),
num_workers=1,
chunk_bytes="64MB",
compression=compression,
)

dataset = StreamingDataset(input_dir=str(tmpdir))
sample = dataset[0]
tensor = torchaudio.load(sample["content"])
assert tensor[0].shape == torch.Size([1, 16000])
assert tensor[1] == 16000


def create_synthetic_audio_file(filepath) -> str:
import torchaudio

# create a dummy audio waveform
data = torch.randn((1, 16000))

# write the tensor to a wav file
with open(filepath, "wb") as f:
torchaudio.save(f, data, 16000, format="wav")

return filepath


@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio', 'zstd']")
@pytest.mark.parametrize("compression", [None])
def test_load_torch_audio_from_wav_file(tmpdir, compression):
seed_everything(42)

import torchaudio

optimize(
fn=create_synthetic_audio_file,
inputs=[os.path.join(tmpdir, f"{i}.wav") for i in range(5)],
output_dir=str(tmpdir),
num_workers=1,
chunk_bytes="64MB",
compression=compression,
)

dataset = StreamingDataset(input_dir=str(tmpdir))
sample = dataset[0]
tensor = torchaudio.load(sample)
assert tensor[0].shape == torch.Size([1, 16000])
assert tensor[1] == 16000
14 changes: 7 additions & 7 deletions tests/streaming/test_serializer.py
@@ -214,16 +214,16 @@ def test_assert_no_header_numpy_serializer():
def test_wav_deserialization(tmpdir):
from torch.hub import download_url_to_file

video_file = os.path.join(tmpdir, "video.wav")
key = "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa E501
video_file = os.path.join(tmpdir, "video.mp4")
key = "tutorial-assets/mptestsrc.mp4" # E501
download_url_to_file(f"https://download.pytorch.org/torchaudio/{key}", video_file)

serializer = VideoSerializer()
assert serializer.can_serialize(video_file)
data, name = serializer.serialize(video_file)
assert len(data) / 1024 / 1024 == 0.10380172729492188
assert name == "wav"
assert len(data) / 1024 / 1024 == 0.2262248992919922
assert name == "video:mp4"
vframes, aframes, info = serializer.deserialize(data)
assert vframes.shape == torch.Size([0, 1, 1, 3])
assert aframes.shape == torch.Size([1, 54400])
assert info == {"audio_fps": 16000}
assert vframes.shape == torch.Size([301, 512, 512, 3])
assert aframes.shape == torch.Size([1, 0])
assert info == {"video_fps": 25.0}