In [24]:
import os
import glob
import random
import numpy as np
import soundfile as sf
import torchaudio
from tqdm import tqdm

# --- Configuration ----------------------------------------------------------
DATA_DIR = "50_speakers_audio_data"  # one sub-directory of *.wav files per speaker
OUTPUT_DIR = "mix_dataset"
SAMPLE_RATE = 8000   # Hz; utterances are resampled to this rate
MIX_DURATION = 4.0   # seconds per mixture
SNR_RANGE = (-5, 5)  # dB
SEED = 42            # fixes the speaker split and every later random draw

# Number of mixtures to generate for each split.
NUM_MIX = {
    "train": 5000,
    "valid": 500,
    "test": 1000,
}

# Seed before the first random call so the speaker split is reproducible.
random.seed(SEED)

# Only directories count as speakers; a stray file named "Speaker..." would
# otherwise end up in a split with no utterances.
speakers = sorted(
    d for d in os.listdir(DATA_DIR)
    if d.startswith("Speaker") and os.path.isdir(os.path.join(DATA_DIR, d))
)
if len(speakers) < 50:
    raise ValueError(
        f"Expected at least 50 speaker directories in {DATA_DIR!r}, found {len(speakers)}"
    )

# Speaker-disjoint 40 / 5 / 5 split of the shuffled speaker list.
random.shuffle(speakers)
SPEAKER_SPLIT = {
    "train": speakers[0:40],
    "valid": speakers[40:45],
    "test":  speakers[45:50],
}

# Create <OUTPUT_DIR>/<split>/{mix,s1,s2} up front.
for split in ["train", "valid", "test"]:
    for sub in ["mix", "s1", "s2"]:
        os.makedirs(os.path.join(OUTPUT_DIR, split, sub), exist_ok=True)


def normalize_audio(x):
    """Scale ``x`` by the inverse of its root-mean-square level.

    A small constant (1e-9) is added to the mean power before the square
    root so that an all-zero input does not cause a division by zero.
    """
    mean_power = np.mean(x**2)
    rms = np.sqrt(mean_power + 1e-9)
    return x / rms


def resample_audio(path, sample_rate):
    """Load an audio file and return it as a mono 1-D NumPy array.

    Args:
        path: Path of the audio file, passed to ``torchaudio.load``.
        sample_rate: Target sampling rate; the waveform is resampled only
            when the file's own rate differs.

    Returns:
        A 1-D NumPy array of samples at ``sample_rate``.
    """
    waveform, orig_sr = torchaudio.load(path)  # (channels, frames)
    if waveform.shape[0] > 1:
        # Down-mix multi-channel audio. Without this, squeeze(0) below is a
        # no-op and the caller receives a 2-D array whose len() is the
        # channel count, so the file is wrongly treated as "too short".
        waveform = waveform.mean(dim=0, keepdim=True)
    if orig_sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_sr, sample_rate)
        waveform = resampler(waveform)
    return waveform.squeeze(0).numpy()


def create_mix(s1, s2, snr, max_peak=0.9):
    """Mix two sources at a given signal-to-noise ratio.

    Both sources are RMS-normalised, then ``s2`` is attenuated (or boosted)
    so that ``s1`` sits ``snr`` dB above it, and the two are summed.

    Unit-RMS signals have peaks well above 1.0, which are hard-clipped when
    written to a fixed-point WAV file. To avoid that, the mixture and both
    sources are scaled by one common gain whenever any of them exceeds
    ``max_peak``. A common gain keeps ``mix == s1 + s2`` and leaves the SNR
    unchanged.

    Args:
        s1: First source (NumPy array).
        s2: Second source (NumPy array, same length as ``s1``).
        snr: Level of ``s1`` relative to ``s2`` in dB.
        max_peak: Largest absolute sample value allowed in the outputs.
            Pass ``None`` to skip peak limiting and get the unit-RMS
            ``s1`` back unchanged.

    Returns:
        Tuple ``(mix, s1, s2)`` of the mixture and the two scaled sources.
    """
    s1 = normalize_audio(s1)
    s2 = normalize_audio(s2) * 10 ** (-snr / 20)
    mix = s1 + s2

    if max_peak is not None and mix.size:
        peak = max(np.max(np.abs(mix)), np.max(np.abs(s1)), np.max(np.abs(s2)))
        if peak > max_peak:
            # Plain Python float so the arrays keep their original dtype.
            gain = float(max_peak / peak)
            mix, s1, s2 = mix * gain, s1 * gain, s2 * gain
    return mix, s1, s2


def get_all_utts(speakers):
    """Collect every ``*.wav`` file of the given speakers.

    Args:
        speakers: Iterable of speaker directory names under ``DATA_DIR``.

    Returns:
        List of ``(speaker, wav_path)`` tuples, grouped by speaker in the
        order the speakers were given.
    """
    return [
        (spk, wav_path)
        for spk in speakers
        for wav_path in glob.glob(os.path.join(DATA_DIR, spk, "*.wav"))
    ]


def generate_mixes(split):
    """Generate the two-speaker mixtures of one split and write them to disk.

    Each mixture pairs random utterances from two different speakers of the
    split, crops a random ``MIX_DURATION``-second window from each, mixes
    them at an SNR drawn uniformly from ``SNR_RANGE`` and writes
    ``mix/mix_XXXX.wav``, ``s1/s1_XXXX.wav`` and ``s2/s2_XXXX.wav`` under
    ``OUTPUT_DIR/<split>``.

    A pair that cannot be read, or whose utterances are shorter than the
    crop length, is replaced by a fresh draw instead of being dropped. The
    split therefore reaches ``NUM_MIX[split]`` mixtures with contiguous
    indices, unless a bounded number of attempts is exhausted first.

    Args:
        split: Key into ``SPEAKER_SPLIT`` / ``NUM_MIX``.

    Returns:
        Number of mixtures actually written.

    Raises:
        ValueError: If mixtures are requested but fewer than two speakers
            of the split have any ``*.wav`` file.
    """
    target = NUM_MIX[split]
    num_samples = int(MIX_DURATION * SAMPLE_RATE)
    base_path = os.path.join(OUTPUT_DIR, split)

    # Sorted so draws do not depend on filesystem order. Speakers without
    # any wav are dropped because random.choice() fails on an empty list.
    utts_by_spk = {}
    for spk in SPEAKER_SPLIT[split]:
        wavs = sorted(glob.glob(os.path.join(DATA_DIR, spk, "*.wav")))
        if wavs:
            utts_by_spk[spk] = wavs
    split_speakers = list(utts_by_spk)
    if target and len(split_speakers) < 2:
        raise ValueError(
            f"Split {split!r} needs at least 2 speakers with audio, "
            f"found {len(split_speakers)}"
        )

    written = 0
    attempts = 0
    # Bound the retries so a corpus of only-too-short clips cannot loop forever.
    max_attempts = 5 * target

    with tqdm(total=target, desc=f"Generating {split} mixtures") as progress:
        while written < target and attempts < max_attempts:
            attempts += 1
            spk1, spk2 = random.sample(split_speakers, 2)
            utt1_path = random.choice(utts_by_spk[spk1])
            utt2_path = random.choice(utts_by_spk[spk2])

            try:
                s1 = resample_audio(utt1_path, SAMPLE_RATE)
                s2 = resample_audio(utt2_path, SAMPLE_RATE)
            except Exception as e:
                # Best effort: report the unreadable pair and draw another.
                print(f"Error reading {utt1_path} or {utt2_path}: {e}")
                continue

            if len(s1) < num_samples or len(s2) < num_samples:
                continue

            s1_start = random.randint(0, len(s1) - num_samples)
            s2_start = random.randint(0, len(s2) - num_samples)
            s1_crop = s1[s1_start:s1_start + num_samples]
            s2_crop = s2[s2_start:s2_start + num_samples]

            snr = random.uniform(*SNR_RANGE)
            mix, src1, src2 = create_mix(s1_crop, s2_crop, snr)

            # Index by mixtures written so file numbering has no gaps.
            sf.write(f"{base_path}/mix/mix_{written:04d}.wav", mix, SAMPLE_RATE)
            sf.write(f"{base_path}/s1/s1_{written:04d}.wav", src1, SAMPLE_RATE)
            sf.write(f"{base_path}/s2/s2_{written:04d}.wav", src2, SAMPLE_RATE)
            written += 1
            progress.update(1)

    if written < target:
        print(
            f"Warning: only {written}/{target} {split} mixtures written "
            f"after {attempts} attempts"
        )
    return written

In [None]:
# Build every split, in train -> valid -> test order, when this code is
# executed directly (as it is inside a notebook cell).
if __name__ == "__main__":
    for split in ("train", "valid", "test"):
        generate_mixes(split)

Generating train mixtures: 100%|██████████| 5000/5000 [05:11<00:00, 16.05it/s]
Generating valid mixtures: 100%|██████████| 500/500 [00:31<00:00, 15.96it/s]
Generating test mixtures: 100%|██████████| 1000/1000 [01:02<00:00, 15.90it/s]
