In [1]:
import glob, random

import torch

from datasets import *

# Per-split cache directories handed to `preprocess_and_cache_lazy` further down;
# all three live under one common root.
_CACHE_ROOT = "data/dataset_cache/dmresnet"
TRAIN_CACHE = f"{_CACHE_ROOT}/train"
VAL_CACHE = f"{_CACHE_ROOT}/val"
TEST_CACHE = f"{_CACHE_ROOT}/test"

# Passed as `segment_sec` to every MelSpecDataset built below.
SEGMENT_TIME = 10

In [2]:
# Collect the training corpus and split it 80/20 into train/validation.
# The glob result is sorted first: glob's ordering depends on the filesystem, so
# without sorting, the seeded shuffle below would still produce a different split
# on another machine, and that split would then disagree with the per-index
# TRAIN_CACHE / VAL_CACHE contents.
audio_files = sorted(
    path
    for path in glob.glob(
        "data/neural-audio-fp-dataset/music/train-10k-30s/**/*.*", recursive=True
    )
    if path.endswith(".wav")
)
print(f"Total audio number: {len(audio_files)}")
# Seeds the *global* `random` module (not a private Random instance), so any
# later code that draws from `random` also starts from a known state.
random.seed(42)
random.shuffle(audio_files)
split_idx = int(0.8 * len(audio_files))
train_files = audio_files[:split_idx]
val_files = audio_files[split_idx:]
print(f"Total train files: {len(train_files)}")
print(f"Total validation files: {len(val_files)}")

Total audio number: 10000
Total train files: 8000
Total validation files: 2000


In [3]:
# MUSAN music clips (.wav only). Sorted because glob's order is filesystem-dependent,
# which would otherwise make anything indexed or sampled from this list irreproducible.
# NOTE(review): this list is not used in the cells shown here; confirm a later
# cell consumes it, otherwise this cell can be removed.
audio_files_musan = sorted(
    path
    for path in glob.glob("data/musan/music/**/*.*", recursive=True)
    if path.endswith(".wav")
)
print(f"Total musan music files: {len(audio_files_musan)}")

Total musan music files: 660


In [4]:
# Reference ("db") tracks of the test set (.wav only). Sorted so that test-dataset
# sample i always refers to the same file. The cache under TEST_CACHE is
# presumably filled index-by-index (the preprocessing log reports "sample N of M"),
# so a filesystem-dependent glob order would silently mismatch cached entries.
test_doc = sorted(
    path
    for path in glob.glob(
        "data/neural-audio-fp-dataset/music/test-query-db-500-30s/db/**/*.*",
        recursive=True,
    )
    if path.endswith(".wav")
)
print(f"Total test db: {len(test_doc)}")

Total test db: 500


In [5]:
# Background-noise pool for augmentation: MUSAN noise followed by the
# neural-audio-fp background set (.wav only). Each source is sorted because glob's
# order is filesystem-dependent, and a seeded random pick from this list is only
# reproducible if the list order is fixed.
bg_noise_musan = sorted(
    path
    for path in glob.glob("data/musan/noise/**/*.*", recursive=True)
    if path.endswith(".wav")
)
bg_noise_neural = sorted(
    path
    for path in glob.glob("data/neural-audio-fp-dataset/aug/bg/**/*.*", recursive=True)
    if path.endswith(".wav")
)
bg_noise = bg_noise_musan + bg_noise_neural
print(f"Total background noise files: {len(bg_noise)}")

Total background noise files: 3604


In [6]:
# Impulse-response files used for reverberation augmentation (.wav only). Sorted
# for the same reason as the noise pool: a fixed list order keeps seeded
# augmentation reproducible across machines.
rir_noise = sorted(
    path
    for path in glob.glob("data/neural-audio-fp-dataset/aug/ir/**/*.*", recursive=True)
    if path.endswith(".wav")
)
print(f"Total rir noise files: {len(rir_noise)}")

Total rir noise files: 440


In [7]:
# Build the three datasets. Only the training split receives `augment`
# (background-noise and impulse-response lists); val/test are constructed
# without an `augment` argument.
# Fall back to CPU when no GPU is visible. `torch.device("cuda")` itself
# succeeds on a CPU-only machine, so a hard-coded CUDA device only fails later,
# at the first CUDA operation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
augment = WaveformAugment(bg_noise_list=bg_noise, rir_noise_list=rir_noise)
train_dataset_raw = MelSpecDataset(
    train_files, split="train", segment_sec=SEGMENT_TIME, augment=augment, device=device
)
val_dataset_raw = MelSpecDataset(val_files, split="val", segment_sec=SEGMENT_TIME, device=device)
test_dataset_raw = MelSpecDataset(test_doc, split="test", segment_sec=SEGMENT_TIME, device=device)

In [8]:
# Preprocess every split and persist it under that split's cache directory.
# The helper reports finding existing cache metadata and "Resuming processing",
# so re-running this cell appears to pick up from the on-disk cache rather than
# starting over. Splits are processed one after another (train, then val, then
# test), so an earlier split's result stays bound even if a later one fails.
train_data = preprocess_and_cache_lazy(
    train_dataset_raw,
    TRAIN_CACHE,
)
val_data = preprocess_and_cache_lazy(
    val_dataset_raw,
    VAL_CACHE,
)
test_data = preprocess_and_cache_lazy(
    test_dataset_raw,
    TEST_CACHE,
)

Cache meta found. Loading sample paths from 'data/dataset_cache/dmresnet/train/meta.pt'...
Resuming processing at sample 0 of 8000.


Checkpoint reached: saved 100/8000 samples.
Checkpoint reached: saved 200/8000 samples.
Checkpoint reached: saved 300/8000 samples.
Checkpoint reached: saved 400/8000 samples.
Checkpoint reached: saved 500/8000 samples.
Checkpoint reached: saved 600/8000 samples.
Checkpoint reached: saved 700/8000 samples.
Checkpoint reached: saved 800/8000 samples.
Checkpoint reached: saved 900/8000 samples.
Checkpoint reached: saved 1000/8000 samples.
Checkpoint reached: saved 1100/8000 samples.
Checkpoint reached: saved 1200/8000 samples.
Checkpoint reached: saved 1300/8000 samples.
Checkpoint reached: saved 1400/8000 samples.
Checkpoint reached: saved 1500/8000 samples.
Checkpoint reached: saved 1600/8000 samples.
Checkpoint reached: saved 1700/8000 samples.
Checkpoint reached: saved 1800/8000 samples.
Checkpoint reached: saved 1900/8000 samples.
Checkpoint reached: saved 2000/8000 samples.
Checkpoint reached: saved 2100/8000 samples.
Checkpoint reached: saved 2200/8000 samples.
Checkpoint reached: