Skip to content

Commit

Permalink
Resolve compression, add support for torchaudio (#19503)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Feb 21, 2024
1 parent 2394e2f commit 39a86f8
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/lightning/data/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.64")
_BOTO3_AVAILABLE = RequirementCache("boto3")
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/data/streaming/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from abc import ABC, abstractmethod
from typing import Dict, TypeVar

from lightning_utilities.core.imports import RequirementCache, requires
from lightning_utilities.core.imports import requires

_ZSTD_AVAILABLE = RequirementCache("zstd")
from lightning.data.constants import _ZSTD_AVAILABLE

if _ZSTD_AVAILABLE:
import zstd
Expand Down
37 changes: 36 additions & 1 deletion src/lightning/data/streaming/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Dict, List, Optional, Tuple

from lightning.data.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.compression import _COMPRESSORS, Compressor
from lightning.data.streaming.downloader import get_downloader_cls
from lightning.data.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
from lightning.data.streaming.sampler import ChunkedIndex
Expand Down Expand Up @@ -66,19 +67,47 @@ def __init__(
if remote_dir:
self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks)

self._compressor_name = self._config["compression"]
self._compressor: Optional[Compressor] = None

if self._compressor_name in _COMPRESSORS:
self._compressor = _COMPRESSORS[self._compressor_name]

def download_chunk_from_index(self, chunk_index: int) -> None:
chunk_filename = self._chunks[chunk_index]["filename"]

local_chunkpath = os.path.join(self._cache_dir, chunk_filename)

if os.path.exists(local_chunkpath):
self.try_decompress(local_chunkpath)
return

if self._downloader is None:
raise RuntimeError("The downloader should be defined.")

self._downloader.download_chunk_from_index(chunk_index)

self.try_decompress(local_chunkpath)

def try_decompress(self, local_chunkpath: str) -> None:
if self._compressor is None:
return

target_local_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")

if os.path.exists(target_local_chunkpath):
return

with open(local_chunkpath, "rb") as f:
data = f.read()

os.remove(local_chunkpath)

data = self._compressor.decompress(data)

with open(target_local_chunkpath, "wb") as f:
f.write(data)

@property
def intervals(self) -> List[Tuple[int, int]]:
if self._intervals is None:
Expand Down Expand Up @@ -132,7 +161,13 @@ def _get_chunk_index_from_index(self, index: int) -> int:
def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
"""Find the associated chunk metadata."""
chunk = self._chunks[index.chunk_index]
return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]

local_chunkpath = os.path.join(self._cache_dir, chunk["filename"])

if self._compressor is not None:
local_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")

return local_chunkpath, *self._intervals[index.chunk_index]

def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
"""Retrieves the associated chunk_index for a given chunk filename."""
Expand Down
13 changes: 12 additions & 1 deletion src/lightning/data/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
from abc import ABC, abstractmethod
from time import sleep
Expand Down Expand Up @@ -101,15 +102,25 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
begin, end = np.frombuffer(pair, np.uint32)
fp.seek(begin)
data = fp.read(end - begin)

return self.deserialize(data)

@functools.lru_cache(maxsize=128)
def _data_format_to_key(self, data_format: str) -> str:
if ":" in data_format:
serialier, serializer_sub_type = data_format.split(":")
if serializer_sub_type in self._serializers:
return serializer_sub_type
return serialier
return data_format

def deserialize(self, raw_item_data: bytes) -> "PyTree":
"""Deserialize the raw bytes into their python equivalent."""
idx = len(self._config["data_format"]) * 4
sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
data = []
for size, data_format in zip(sizes, self._config["data_format"]):
serializer = self._serializers[data_format]
serializer = self._serializers[self._data_format_to_key(data_format)]
data_bytes = raw_item_data[idx : idx + size]
data.append(serializer.deserialize(data_bytes))
idx += size
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def read(self, index: ChunkedIndex) -> Any:
if self._config is None and self._try_load_config() is None:
raise Exception("The reader index isn't defined.")

if self._config and self._config._remote_dir:
if self._config and (self._config._remote_dir or self._config._compressor):
# Create and start the prepare chunks thread
if self._prepare_thread is None and self._config:
self._prepare_thread = PrepareChunksThread(
Expand Down
8 changes: 5 additions & 3 deletions src/lightning/data/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ class FileSerializer(Serializer):
def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
_, file_extension = os.path.splitext(filepath)
with open(filepath, "rb") as f:
return f.read(), file_extension.replace(".", "").lower()
file_extension = file_extension.replace(".", "").lower()
return f.read(), f"file:{file_extension}"

def deserialize(self, data: bytes) -> Any:
return data
Expand All @@ -292,12 +293,13 @@ def can_serialize(self, data: Any) -> bool:


class VideoSerializer(Serializer):
_EXTENSIONS = ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv", "wav")
_EXTENSIONS = ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv")

def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
_, file_extension = os.path.splitext(filepath)
with open(filepath, "rb") as f:
return f.read(), file_extension.replace(".", "").lower()
file_extension = file_extension.replace(".", "").lower()
return f.read(), f"video:{file_extension}"

def deserialize(self, data: bytes) -> Any:
if not _TORCH_VISION_AVAILABLE:
Expand Down
79 changes: 78 additions & 1 deletion tests/tests_data/processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
import torch
from lightning import seed_everything
from lightning.data.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE
from lightning.data.processing import data_processor as data_processor_module
from lightning.data.processing import functions
from lightning.data.processing.data_processor import (
Expand All @@ -26,7 +27,7 @@
_wait_for_file_to_exist,
)
from lightning.data.processing.functions import LambdaDataTransformRecipe, map, optimize
from lightning.data.streaming import resolver
from lightning.data.streaming import StreamingDataset, resolver
from lightning.data.streaming.cache import Cache, Dir
from lightning_utilities.core.imports import RequirementCache

Expand Down Expand Up @@ -1058,3 +1059,79 @@ def test_empty_optimize(tmpdir):
)

assert os.listdir(tmpdir) == ["index.json"]


def create_synthetic_audio_bytes(index) -> dict:
from io import BytesIO

import torchaudio

# load dummy audio as bytes
data = torch.randn((1, 16000))

# convert tensor to bytes
with BytesIO() as f:
torchaudio.save(f, data, 16000, format="wav")
data = f.getvalue()

data = {"content": data}
return data


@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio']")
@pytest.mark.parametrize("compression", [None, "zstd"])
def test_load_torch_audio(tmpdir, compression):
seed_everything(42)

import torchaudio

optimize(
fn=create_synthetic_audio_bytes,
inputs=list(range(100)),
output_dir=str(tmpdir),
num_workers=1,
chunk_bytes="64MB",
compression=compression,
)

dataset = StreamingDataset(input_dir=str(tmpdir))
sample = dataset[0]
tensor = torchaudio.load(sample["content"])
assert tensor[0].shape == torch.Size([1, 16000])
assert tensor[1] == 16000


def create_synthetic_audio_file(filepath) -> dict:
import torchaudio

# load dummy audio as bytes
data = torch.randn((1, 16000))

# convert tensor to bytes
with open(filepath, "wb") as f:
torchaudio.save(f, data, 16000, format="wav")

return filepath


@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio']")
@pytest.mark.parametrize("compression", [None])
def test_load_torch_audio_from_wav_file(tmpdir, compression):
seed_everything(42)

import torchaudio

optimize(
fn=create_synthetic_audio_file,
inputs=[os.path.join(tmpdir, f"{i}.wav") for i in range(5)],
output_dir=str(tmpdir),
num_workers=1,
chunk_bytes="64MB",
compression=compression,
)

dataset = StreamingDataset(input_dir=str(tmpdir))
sample = dataset[0]
tensor = torchaudio.load(sample)
assert tensor[0].shape == torch.Size([1, 16000])
assert tensor[1] == 16000
14 changes: 7 additions & 7 deletions tests/tests_data/streaming/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,16 @@ def test_assert_no_header_numpy_serializer():
def test_wav_deserialization(tmpdir):
from torch.hub import download_url_to_file

video_file = os.path.join(tmpdir, "video.wav")
key = "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa E501
video_file = os.path.join(tmpdir, "video.mp4")
key = "tutorial-assets/mptestsrc.mp4" # E501
download_url_to_file(f"https://download.pytorch.org/torchaudio/{key}", video_file)

serializer = VideoSerializer()
assert serializer.can_serialize(video_file)
data, name = serializer.serialize(video_file)
assert len(data) / 1024 / 1024 == 0.10380172729492188
assert name == "wav"
assert len(data) / 1024 / 1024 == 0.2262248992919922
assert name == "video:mp4"
vframes, aframes, info = serializer.deserialize(data)
assert vframes.shape == torch.Size([0, 1, 1, 3])
assert aframes.shape == torch.Size([1, 54400])
assert info == {"audio_fps": 16000}
assert vframes.shape == torch.Size([301, 512, 512, 3])
assert aframes.shape == torch.Size([1, 0])
assert info == {"video_fps": 25.0}

0 comments on commit 39a86f8

Please sign in to comment.