In [2]:
import os
import random
import torch
import torchaudio
import numpy as np
import librosa
import soundfile as sf

# Target sample rate
TARGET_SAMPLE_RATE = 16000

# Function to load all MUSAN noise files
def load_musan_noises(musan_dir):
    """Collect the paths of all WAV files in the MUSAN corpus.

    The ``noise``, ``music`` and ``speech`` folders under *musan_dir* are
    searched recursively, because the MUSAN release stores its recordings in
    nested source folders (for example ``noise/free-sound``) rather than
    directly inside the category folders.

    Args:
        musan_dir: Root directory of the MUSAN corpus.

    Returns:
        Sorted list of paths to every ``.wav`` file found (extension matched
        case-insensitively). Empty when no category folder exists or none
        of them contains WAV files.
    """
    noise_files = []
    for subdir in ('noise', 'music', 'speech'):
        subpath = os.path.join(musan_dir, subdir)
        # os.walk yields nothing for a missing directory, so absent
        # categories are skipped without an explicit existence check.
        for root, _dirs, files in os.walk(subpath):
            noise_files.extend(
                os.path.join(root, name)
                for name in files
                if name.lower().endswith('.wav')
            )
    # Sorting makes the result independent of filesystem listing order.
    return sorted(noise_files)

# Function to add random MUSAN background noise
def add_background_noise(signal, sample_rate, noise_files, snr_db_range=(0, 15)):
    """Mix a randomly chosen noise recording into *signal* at a random SNR.

    Args:
        signal: Waveform tensor indexed as (channels, samples).
        sample_rate: Sample rate of *signal* in Hz; the noise is resampled
            to it when its own rate differs.
        noise_files: Paths of candidate noise recordings. When empty,
            *signal* is returned unchanged.
        snr_db_range: (low, high) bounds in dB of the uniformly sampled
            signal-to-noise ratio.

    Returns:
        The noisy waveform scaled so that its peak absolute value is 1, with
        the same shape as *signal*. *signal* itself is returned when there
        is nothing to mix (no noise files, or a silent signal or silent
        noise excerpt).
    """
    if not noise_files:
        return signal

    # Randomly select a noise file
    noise_path = random.choice(noise_files)
    noise, noise_sr = torchaudio.load(noise_path)

    # Resample noise to the signal's rate if needed
    if noise_sr != sample_rate:
        noise = torchaudio.transforms.Resample(noise_sr, sample_rate)(noise)

    # Collapse the noise to one channel when the channel counts differ, so
    # the addition below broadcasts onto the signal's shape instead of
    # widening a mono signal to the noise's channel count.
    if noise.shape[0] != signal.shape[0]:
        noise = noise.mean(dim=0, keepdim=True)

    # Trim or pad noise to match signal length (random offset either way)
    sig_len = signal.shape[1]
    noise_len = noise.shape[1]
    if noise_len > sig_len:
        start = random.randint(0, noise_len - sig_len)
        noise = noise[:, start:start + sig_len]
    elif noise_len < sig_len:
        pad_left = random.randint(0, sig_len - noise_len)
        pad_right = sig_len - noise_len - pad_left
        noise = torch.nn.functional.pad(noise, (pad_left, pad_right))

    # Compute SNR and mix
    snr_db = random.uniform(*snr_db_range)
    signal_power = torch.mean(signal ** 2)
    noise_power = torch.mean(noise ** 2)
    # A silent signal or silent noise excerpt would turn the scale factor
    # into 0/0 or x/0 and fill the output with NaN/inf; skip mixing instead.
    if signal_power == 0 or noise_power == 0:
        return signal
    factor = torch.sqrt(signal_power / (noise_power * 10 ** (snr_db / 10)))
    augmented = signal + noise * factor

    # Normalize to avoid clipping, guarding against an all-zero mix.
    peak = torch.max(torch.abs(augmented))
    if peak == 0:
        return augmented
    return augmented / peak

# Function to apply time stretch using librosa
def time_stretch(signal, sample_rate, rate):
    """Change the speed of *signal* with librosa's time-stretch effect.

    Args:
        signal: Waveform tensor, either (channels, samples) or 1-D (samples,).
        sample_rate: Unused; kept so existing callers keep working.
        rate: Stretch factor passed to ``librosa.effects.time_stretch``
            (per the librosa docs, above 1 speeds the audio up and below 1
            slows it down).

    Returns:
        float32 tensor of shape (channels, stretched_samples). A 1-D input
        comes back with a leading channel axis of size 1.
    """
    # Stretch each channel on its own: flattening a multi-channel tensor
    # would concatenate the channels end to end into one long waveform.
    channels = signal.numpy().reshape(-1, signal.shape[-1])
    stretched = np.stack(
        [librosa.effects.time_stretch(channel, rate=rate) for channel in channels]
    )
    # Convert back to a torch tensor, keeping the (channels, samples) layout
    return torch.tensor(stretched, dtype=torch.float32)

# Function to apply gain manually (replacing apply_gain)
def apply_manual_gain(signal, gain_db):
    # Convert dB to linear gain
    gain_linear = 10 ** (gain_db / 20)
    return signal * gain_linear

# Function to apply effective augmentations
def augment_audio(audio_path, musan_noises, output_dir):
    """Apply a random chain of augmentations to one file and save the result.

    Each step is applied independently with the stated probability:
    MUSAN background noise (50%), pitch shift of -4 to +4 semitones (50%),
    time stretch at 0.8-1.2x speed (50%), gain of -6 to +6 dB (50%) and
    low-amplitude Gaussian noise (30%).

    Args:
        audio_path: Path of the input audio file. Multi-channel input is
            downmixed to mono, and other sample rates are resampled to
            ``TARGET_SAMPLE_RATE``.
        musan_noises: Paths of candidate background-noise recordings; may be
            empty, in which case the background-noise step is a no-op.
        output_dir: Directory for the augmented file; created if missing.

    Returns:
        Path of the written file, ``aug_<input file name>`` in *output_dir*.
    """
    # Load audio (assumed to be 16000 Hz)
    signal, orig_sr = torchaudio.load(audio_path)

    # Downmix multi-channel input: the save step flattens the tensor into a
    # single channel, which would otherwise place the channels one after
    # another in the output file.
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)

    if orig_sr != TARGET_SAMPLE_RATE:
        print(f"Warning: {audio_path} has sample rate {orig_sr}. Expected 16000 Hz. Resampling.")
        signal = torchaudio.transforms.Resample(orig_sr, TARGET_SAMPLE_RATE)(signal)

    # Add MUSAN background noise (50% chance)
    if random.random() < 0.5:
        signal = add_background_noise(signal, TARGET_SAMPLE_RATE, musan_noises)

    # Pitch shift (-4 to +4 semitones, 50% chance)
    if random.random() < 0.5:
        pitch_shift = random.uniform(-4, 4)
        signal = torchaudio.functional.pitch_shift(signal, TARGET_SAMPLE_RATE, n_steps=pitch_shift)

    # Time stretch (0.8–1.2 speed, 50% chance)
    if random.random() < 0.5:
        rate = random.uniform(0.8, 1.2)
        signal = time_stretch(signal, TARGET_SAMPLE_RATE, rate)

    # Gain (-6 to +6 dB, 50% chance)
    if random.random() < 0.5:
        gain_db = random.uniform(-6, 6)
        signal = apply_manual_gain(signal, gain_db)

    # Gaussian noise (low amplitude, 30% chance)
    if random.random() < 0.3:
        noise = torch.normal(0, random.uniform(0.001, 0.015), signal.shape)
        signal = signal + noise
        # Normalize, but never divide by a zero peak (would produce NaN).
        peak = torch.max(torch.abs(signal))
        if peak > 0:
            signal = signal / peak

    # Positive gain can push samples past full scale; bound them explicitly
    # rather than relying on how the writer treats out-of-range values.
    signal = torch.clamp(signal, -1.0, 1.0)

    # Save augmented file
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.basename(audio_path)
    output_path = os.path.join(output_dir, f"aug_{base_name}")
    # Convert to numpy for soundfile
    signal_np = signal.numpy().flatten()
    sf.write(output_path, signal_np, TARGET_SAMPLE_RATE)
    return output_path

# Main script
def main(input_dir, musan_dir, output_dir):
    """Augment every ``.wav`` file in *input_dir* and save it to *output_dir*.

    Args:
        input_dir: Directory whose ``.wav`` files are augmented (not
            searched recursively).
        musan_dir: Root of the MUSAN corpus used for background noise.
        output_dir: Directory for the augmented files; created if missing.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    musan_noises = load_musan_noises(musan_dir)
    if not musan_noises:
        print("Warning: No MUSAN noise files found. Skipping background noise augmentation.")

    # Sorted so files are processed in a reproducible order.
    wav_files = sorted(f for f in os.listdir(input_dir) if f.endswith('.wav'))
    for wav_file in wav_files:
        audio_path = os.path.join(input_dir, wav_file)
        augmented_path = augment_audio(audio_path, musan_noises, output_dir)
        print(f"Augmented and saved: {augmented_path}")


# Example usage (replace with your paths)
if __name__ == "__main__":
    # Hard-coded, machine-specific paths; edit them before running elsewhere.
    MUSAN_DIR = "/Users/sethwright/Downloads/musan"  # Root of the MUSAN corpus
    input_dir = "/Users/sethwright/Documents/audio-model/data/negative_1"  # Folder with your WAV files

    output_dir = "/Users/sethwright/Documents/audio-model/data/augmentedneg_data"   # Where to save augmented files
    # main() takes (input_dir, musan_dir, output_dir), in that order.
    main(input_dir, MUSAN_DIR, output_dir)

Augmented and saved: /Users/sethwright/Documents/audio-model/data/augmentedneg_data/aug_test_recording1383.wav
Augmented and saved: /Users/sethwright/Documents/audio-model/data/augmentedneg_data/aug_test_recording1397.wav
Augmented and saved: /Users/sethwright/Documents/audio-model/data/augmentedneg_data/aug_test_recording1340.wav
Augmented and saved: /Users/sethwright/Documents/audio-model/data/augmentedneg_data/aug_test_recording1426.wav
Augmented and saved: /Users/sethwright/Documents/audio-model/data/augmentedneg_data/aug_test_recording1432.wav
Augmented and saved: /Users/sethwright/Documents/audio-model/data/augmentedneg_data/aug_test_recording1354.wav
Augmented and saved: /Users/sethwright/Documents/audio-model/data/augmentedneg_data/aug_test_recording1368.wav
Augmented and saved: /Users/sethwright/Documents/audio-model/data/augmentedneg_data/aug_test_recording1181.wav
Augmented and saved: /Users/sethwright/Documents/audio-model/data/augmentedneg_data/aug_test_recording1195.wav
A