In [1]:
%pip install pyloudnorm

Collecting pyloudnorm
  Downloading pyloudnorm-0.1.1-py3-none-any.whl.metadata (5.6 kB)
Collecting future>=0.16.0 (from pyloudnorm)
  Downloading future-1.0.0-py3-none-any.whl.metadata (4.0 kB)
Downloading pyloudnorm-0.1.1-py3-none-any.whl (9.6 kB)
Downloading future-1.0.0-py3-none-any.whl (491 kB)
Installing collected packages: future, pyloudnorm

   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   ---------------------------------------- 0/2 [future]
   -------------------



In [2]:
# requirements: torch, torchaudio, librosa, pyloudnorm, soundfile, ffmpeg in PATH
import math
import os
import shutil
import subprocess
import tempfile
from pathlib import Path

import librosa
import numpy as np
import pyloudnorm as pyln
import torch
import torchaudio

# Audio front-end configuration shared by the helpers in this cell.
SR = 22_050                       # target sample rate (Hz) every file is resampled to
WIN_SECS = 10.0                   # fixed analysis window length, in seconds
N_SAMPLES = int(SR * WIN_SECS)    # same window expressed in samples (220_500)
TARGET_LUFS = -14.0               # integrated-loudness target used by lufs_normalize

def to_mp3_bytes(waveform, sr=SR, bitrate="192k"):
    """Round-trip a waveform through an MP3 encode/decode so it carries MP3 codec artifacts.

    Despite the name, this returns a decoded tensor, not encoded bytes.

    Args:
        waveform: mono tensor ``[T]`` (a ``[C, T]`` tensor is also accepted).
        sr: sample rate of ``waveform``. If the decoder reports a different
            rate, the decoded audio is resampled back to ``sr``.
        bitrate: value passed to ffmpeg's ``-b:a`` for the MP3 encode.

    Returns:
        The decoded waveform with its leading channel axis squeezed, i.e. ``[T]``
        for mono input.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
    """
    # torchaudio.save expects a [C, T] tensor; accept both [T] and [C, T].
    wav_in = waveform.detach().cpu()
    if wav_in.dim() == 1:
        wav_in = wav_in.unsqueeze(0)
    # Plain paths inside a temporary directory instead of two still-open
    # NamedTemporaryFiles: reopening a NamedTemporaryFile by name while it is
    # open is not supported on Windows, so the old approach was POSIX-only.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "in.wav")
        mp3_path = os.path.join(tmp_dir, "out.mp3")
        torchaudio.save(wav_path, wav_in, sr)
        subprocess.run(["ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
                        "-i", wav_path, "-b:a", bitrate, mp3_path], check=True)
        # Decode while the directory still exists; the tensor outlives it.
        wav, new_sr = torchaudio.load(mp3_path)
    if new_sr != sr:
        wav = torchaudio.functional.resample(wav, new_sr, sr)
    return wav.squeeze(0)  # [T] for mono

def load_wave(path):
    """Load an audio file as a mono tensor resampled to the module rate ``SR``.

    Args:
        path: anything ``str()`` turns into a path that torchaudio can read.

    Returns:
        1-D tensor ``[T]``: the mean over channels, resampled to ``SR``.
    """
    wav, file_sr = torchaudio.load(str(path))    # [C, T]
    # Downmix before resampling. Resampling is a linear filter, so
    # mean-then-resample matches resample-then-mean (up to float rounding),
    # but the expensive resample now runs on one channel instead of C.
    mono = wav.mean(dim=0)                        # [T]
    return torchaudio.functional.resample(mono, file_sr, SR)

def lufs_normalize(wav_t, target_lufs=None):
    """Scale a mono waveform to an integrated-loudness target and hard-clip to [-1, 1].

    Loudness is measured with a pyloudnorm ``Meter`` at ``SR``.

    Args:
        wav_t: 1-D CPU tensor ``[T]`` sampled at ``SR``.
        target_lufs: loudness target in LUFS. ``None`` (default) uses the
            module-level ``TARGET_LUFS``, read at call time.

    Returns:
        float32 tensor ``[T]``. If the measured loudness is not finite (pyloudnorm
        reports an all-silent clip as ``-inf``), the input values are returned
        unscaled rather than being turned into NaNs.

    Note:
        pyloudnorm may raise ``ValueError`` for clips shorter than its
        measurement block; that is not caught here.
    """
    if target_lufs is None:
        target_lufs = TARGET_LUFS
    y = wav_t.numpy().astype(np.float32)  # astype copies, so the input tensor is never aliased
    meter = pyln.Meter(SR)
    loud = meter.integrated_loudness(y)
    if not np.isfinite(loud):
        # Silence gives -inf, so the gain would be 10**(inf/20) = inf, and
        # 0 * inf = NaN would poison the downstream spectrogram.
        return torch.from_numpy(y)
    gain = 10 ** ((target_lufs - loud) / 20.0)
    # Cast back to float32: under NumPy 2 promotion rules a float64 gain
    # would otherwise silently upcast the output.
    y = np.clip(y * gain, -1.0, 1.0).astype(np.float32, copy=False)
    return torch.from_numpy(y)

def fix_window(wav_t, n_samples=None):
    """Zero-pad or truncate a waveform to exactly ``n_samples`` along its last axis.

    Args:
        wav_t: tensor whose last axis is time (``[T]`` or ``[..., T]``).
        n_samples: target length. ``None`` (default) uses the module-level
            ``N_SAMPLES``, read at call time.

    Returns:
        Tensor with the same leading shape and last-axis length ``n_samples``.
        Zeros are appended at the end when padding; truncation keeps the first
        ``n_samples`` samples.
    """
    if n_samples is None:
        n_samples = N_SAMPLES
    # Use the time-axis length, not numel(), so multi-channel input works too.
    length = wav_t.shape[-1]
    if length < n_samples:
        # pad=(left, right) applies to the last dimension only
        return torch.nn.functional.pad(wav_t, (0, n_samples - length))
    return wav_t[..., :n_samples]

def to_logmel(wav_t, n_fft=1024, hop=256, n_mels=128, fmin=20, fmax=8000):
    """Turn a mono waveform into a log-mel spectrogram rescaled to [-1, 1].

    The power mel spectrogram is converted to dB relative to its own maximum
    and then min-max scaled per call, so absolute level is not preserved.

    Returns:
        float32 tensor ``[n_mels, frames]``.
    """
    power_mel = librosa.feature.melspectrogram(
        y=wav_t.numpy(),
        sr=SR,
        n_fft=n_fft,
        hop_length=hop,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        power=2.0,
    )
    mel_db = librosa.power_to_db(power_mel, ref=np.max)
    # Per-window min-max to [0, 1]; the epsilon keeps a constant window finite...
    lo, hi = mel_db.min(), mel_db.max()
    unit = (mel_db - lo) / (hi - lo + 1e-9)
    # ...then stretch to [-1, 1].
    return torch.tensor(2.0 * unit - 1.0, dtype=torch.float32)

def process_file(path, codec_roundtrip=False):
    """Run one audio file through the full front end and return a log-mel image.

    Pipeline: load as mono at SR -> optional MP3 encode/decode round-trip ->
    loudness normalization -> pad/trim to the fixed window -> log-mel.

    Args:
        path: audio file path.
        codec_roundtrip: if True, pass the audio through ``to_mp3_bytes`` so it
            carries MP3 codec artifacts.

    Returns:
        float32 tensor ``[1, n_mels, frames]`` (a leading channel axis is added).
    """
    wav = load_wave(path)
    wav = to_mp3_bytes(wav) if codec_roundtrip else wav
    # Loudness is measured on the whole clip, before trimming to the window.
    windowed = fix_window(lufs_normalize(wav))
    return to_logmel(windowed).unsqueeze(0)


In [3]:
from torch.utils.data import Dataset, DataLoader
from glob import glob

class MelWindowDataset(Dataset):
    """Map-style dataset that yields one log-mel window per audio file.

    Args:
        files: sequence of audio file paths, indexed by position.
        codec_roundtrip: forwarded to ``process_file``. When True, each item is
            passed through an MP3 encode/decode before feature extraction.
    """

    def __init__(self, files, codec_roundtrip=False):
        self.files = files
        self.codec_roundtrip = codec_roundtrip

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        # Features are recomputed from disk on every access (no caching).
        path = self.files[i]
        return process_file(path, codec_roundtrip=self.codec_roundtrip)  # [1, 128, T]

# Example: GAN training loader over the real-audio corpus.
real_files = sorted(glob("data/REAL_audio/genres_original/**/*.wav", recursive=True))
# Fail fast with a clear message: an empty file list otherwise surfaces as an
# opaque sampler error, and a missing ffmpeg only fails inside worker processes.
assert real_files, "no .wav files found under data/REAL_audio/genres_original"
assert shutil.which("ffmpeg"), "ffmpeg must be on PATH when codec_roundtrip=True"
train_ds = MelWindowDataset(real_files, codec_roundtrip=True)  # GAN train set
train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    # NOTE(review): on spawn-based platforms (Windows/macOS), workers may fail
    # to unpickle classes defined in a notebook; use num_workers=0 there.
    num_workers=4,
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only machines.
    pin_memory=torch.cuda.is_available(),
)
