From 20cea9c6450bb46d3021c632927542f1f51b941d Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Tue, 21 Oct 2025 14:24:07 +0545 Subject: [PATCH 1/2] refactor: remove torchaudio dependency and update audio processing to use soundfile --- requirements/test.txt | 3 +- src/litdata/constants.py | 1 - tests/processing/test_data_processor.py | 51 ++++++++++++------------- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index f8edd079..31bb2966 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -2,7 +2,6 @@ av >=14.0.0 coverage ==7.10.* cryptography==45.0.7 mosaicml-streaming==0.11.0 -torchaudio>=2.7.0,<2.9 pytest ==8.4.* pytest-asyncio>=1.0.0 pytest-cov ==7.0.0 @@ -16,4 +15,4 @@ polars >1.0.0 lightning transformers <4.57.0 zstd -soundfile >=0.13.0 # required for torchaudio backend +soundfile >=0.13.0 diff --git a/src/litdata/constants.py b/src/litdata/constants.py index b7e7e46d..a167441f 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -35,7 +35,6 @@ _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer") _BOTO3_AVAILABLE = RequirementCache("boto3") _FSSPEC_AVAILABLE = RequirementCache("fsspec") -_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio") _ZSTD_AVAILABLE = RequirementCache("zstd") _CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography") _GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage") diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index b201bcae..3ef56b5b 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -3,6 +3,7 @@ import os import random import sys +import tempfile from contextlib import suppress from functools import partial from io import BytesIO @@ -16,7 +17,7 @@ import torch from lightning_utilities.core.imports import RequirementCache -from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE +from litdata.constants import _ZSTD_AVAILABLE from litdata.processing import data_processor as data_processor_module from litdata.processing import functions from litdata.processing.data_processor import ( @@ -1127,29 +1128,26 @@ def test_empty_optimize(tmpdir, inputs): def create_synthetic_audio_bytes(index) -> dict: - from io import BytesIO - - import torchaudio + import soundfile as sf # load dummy audio as bytes - data = torch.randn((1, 16000)) + data = torch.randn((1, 16000)).numpy().squeeze() # shape (16000,) - # convert tensor to bytes - with BytesIO() as f: - torchaudio.save(f, data, 16000, format="wav") - data = f.getvalue() + # convert array to bytes + with tempfile.NamedTemporaryFile(suffix=".wav") as tmp: + sf.write(tmp.name, data, 16000, format="WAV") + with open(tmp.name, "rb") as f: + data = f.read() return {"content": data} -@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio']") +@pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']") @pytest.mark.parametrize("compression", [None, "zstd"]) -def test_load_torch_audio(tmpdir, compression): +def test_load_audio_bytes_optimize_and_stream(tmpdir, compression): seed_everything(42) - import torchaudio - - torchaudio.set_audio_backend("soundfile") + import soundfile as sf optimize( fn=create_synthetic_audio_bytes, @@ -1164,30 +1162,30 @@ def test_load_torch_audio(tmpdir, compression): sample = dataset[0] buffer = BytesIO(sample["content"]) buffer.seek(0) - tensor, sample_rate = torchaudio.load(buffer, format="wav") + data, sample_rate = sf.read(buffer) + tensor = torch.from_numpy(data).unsqueeze(0) assert tensor.shape == torch.Size([1, 16000]) assert sample_rate == 16000 def create_synthetic_audio_file(filepath) -> dict: - import torchaudio + import soundfile as sf # load dummy audio as bytes - data = torch.randn((1, 16000)) + data = torch.randn((1, 16000)).numpy().squeeze() - # convert tensor to bytes - with open(filepath, "wb") as f: - torchaudio.save(f, data, 16000, format="wav") + # convert array to bytes + sf.write(filepath, data, 16000, format="WAV") return filepath -@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio']") +@pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']") @pytest.mark.parametrize("compression", [None]) -def test_load_torch_audio_from_wav_file(tmpdir, compression): +def test_load_audio_file_optimize_and_stream(tmpdir, compression): seed_everything(42) - import torchaudio + import soundfile as sf optimize( fn=create_synthetic_audio_file, @@ -1200,9 +1198,10 @@ def test_load_torch_audio_from_wav_file(tmpdir, compression): dataset = StreamingDataset(input_dir=str(tmpdir)) sample = dataset[0] - tensor = torchaudio.load(sample) - assert tensor[0].shape == torch.Size([1, 16000]) - assert tensor[1] == 16000 + data, sample_rate = sf.read(sample) + tensor = torch.from_numpy(data).unsqueeze(0) + assert tensor.shape == torch.Size([1, 16000]) + assert sample_rate == 16000 def test_is_path_valid_in_studio(monkeypatch, tmpdir): From 56f02161589664f222c4dccb83bdd84531db931d Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Tue, 21 Oct 2025 15:26:18 +0545 Subject: [PATCH 2/2] test: update skip condition to include Windows platform --- tests/processing/test_data_processor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/processing/test_data_processor.py b/tests/processing/test_data_processor.py index 3ef56b5b..80ec1b99 100644 --- a/tests/processing/test_data_processor.py +++ b/tests/processing/test_data_processor.py @@ -1142,7 +1142,9 @@ def create_synthetic_audio_bytes(index) -> dict: return {"content": data} -@pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']") +@pytest.mark.skipif( + condition=not _ZSTD_AVAILABLE or sys.platform == "win32", reason="Requires: ['zstd'] or Windows not supported" +) @pytest.mark.parametrize("compression", [None, "zstd"]) def test_load_audio_bytes_optimize_and_stream(tmpdir, compression): seed_everything(42) @@ -1180,7 +1182,9 @@ def create_synthetic_audio_file(filepath) -> dict: return filepath -@pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']") +@pytest.mark.skipif( + condition=not _ZSTD_AVAILABLE or sys.platform == "win32", reason="Requires: ['zstd'] or Windows not supported" +) @pytest.mark.parametrize("compression", [None]) def test_load_audio_file_optimize_and_stream(tmpdir, compression): seed_everything(42)