Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ av >=14.0.0
coverage ==7.10.*
cryptography==45.0.7
mosaicml-streaming==0.11.0
torchaudio>=2.7.0,<2.9
pytest ==8.4.*
pytest-asyncio>=1.0.0
pytest-cov ==7.0.0
Expand All @@ -16,4 +15,4 @@ polars >1.0.0
lightning
transformers <4.57.0
zstd
soundfile >=0.13.0 # required for torchaudio backend
soundfile >=0.13.0
1 change: 0 additions & 1 deletion src/litdata/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_BOTO3_AVAILABLE = RequirementCache("boto3")
_FSSPEC_AVAILABLE = RequirementCache("fsspec")
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")
_CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography")
_GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
Expand Down
55 changes: 29 additions & 26 deletions tests/processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import random
import sys
import tempfile
from contextlib import suppress
from functools import partial
from io import BytesIO
Expand All @@ -16,7 +17,7 @@
import torch
from lightning_utilities.core.imports import RequirementCache

from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE
from litdata.constants import _ZSTD_AVAILABLE
from litdata.processing import data_processor as data_processor_module
from litdata.processing import functions
from litdata.processing.data_processor import (
Expand Down Expand Up @@ -1127,29 +1128,28 @@ def test_empty_optimize(tmpdir, inputs):


def create_synthetic_audio_bytes(index) -> dict:
from io import BytesIO

import torchaudio
import soundfile as sf

# load dummy audio as bytes
data = torch.randn((1, 16000))
data = torch.randn((1, 16000)).numpy().squeeze() # shape (16000,)

# convert tensor to bytes
with BytesIO() as f:
torchaudio.save(f, data, 16000, format="wav")
data = f.getvalue()
# convert array to bytes
with tempfile.NamedTemporaryFile(suffix=".wav") as tmp:
sf.write(tmp.name, data, 16000, format="WAV")
with open(tmp.name, "rb") as f:
data = f.read()

return {"content": data}


@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio']")
@pytest.mark.skipif(
condition=not _ZSTD_AVAILABLE or sys.platform == "win32", reason="Requires: ['zstd'] or Windows not supported"
)
@pytest.mark.parametrize("compression", [None, "zstd"])
def test_load_torch_audio(tmpdir, compression):
def test_load_audio_bytes_optimize_and_stream(tmpdir, compression):
seed_everything(42)

import torchaudio

torchaudio.set_audio_backend("soundfile")
import soundfile as sf

optimize(
fn=create_synthetic_audio_bytes,
Expand All @@ -1164,30 +1164,32 @@ def test_load_torch_audio(tmpdir, compression):
sample = dataset[0]
buffer = BytesIO(sample["content"])
buffer.seek(0)
tensor, sample_rate = torchaudio.load(buffer, format="wav")
data, sample_rate = sf.read(buffer)
tensor = torch.from_numpy(data).unsqueeze(0)
assert tensor.shape == torch.Size([1, 16000])
assert sample_rate == 16000


def create_synthetic_audio_file(filepath) -> dict:
import torchaudio
import soundfile as sf

# load dummy audio as bytes
data = torch.randn((1, 16000))
data = torch.randn((1, 16000)).numpy().squeeze()

# convert tensor to bytes
with open(filepath, "wb") as f:
torchaudio.save(f, data, 16000, format="wav")
# convert array to bytes
sf.write(filepath, data, 16000, format="WAV")

return filepath


@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio']")
@pytest.mark.skipif(
condition=not _ZSTD_AVAILABLE or sys.platform == "win32", reason="Requires: ['zstd'] or Windows not supported"
)
@pytest.mark.parametrize("compression", [None])
def test_load_torch_audio_from_wav_file(tmpdir, compression):
def test_load_audio_file_optimize_and_stream(tmpdir, compression):
seed_everything(42)

import torchaudio
import soundfile as sf

optimize(
fn=create_synthetic_audio_file,
Expand All @@ -1200,9 +1202,10 @@ def test_load_torch_audio_from_wav_file(tmpdir, compression):

dataset = StreamingDataset(input_dir=str(tmpdir))
sample = dataset[0]
tensor = torchaudio.load(sample)
assert tensor[0].shape == torch.Size([1, 16000])
assert tensor[1] == 16000
data, sample_rate = sf.read(sample)
tensor = torch.from_numpy(data).unsqueeze(0)
assert tensor.shape == torch.Size([1, 16000])
assert sample_rate == 16000


def test_is_path_valid_in_studio(monkeypatch, tmpdir):
Expand Down
Loading