In [None]:
import librosa
from librosa._typing import _STFTPad
import numpy as np
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

def aug_gaussian_noise(audio, snr_db, rng=None):
    """
    Add Gaussian noise to a waveform at a given SNR (dB).

    The noise is rescaled so that its RMS is exactly
    ``rms(audio) / 10**(snr_db / 20)``, i.e. the realised SNR over the whole
    array equals ``snr_db`` (not merely in expectation).

    audio: np.ndarray (or array-like) waveform of any shape, e.g. mono
        ``(n,)`` or multi-channel; one noise sample is drawn per element and
        the RMS is computed over all elements together.
    snr_db: target signal-to-noise ratio in dB
    rng: optional ``np.random.Generator``. When None, NumPy's global RNG
        (``np.random.normal``) is used, so ``np.random.seed`` still controls it.
        NOTE(review): worker processes created by fork inherit the global RNG
        state and may draw identical noise; passing an ``rng`` avoids relying on it.

    Returns a new array ``audio + noise``; an empty input yields an empty copy.
    """
    audio = np.asarray(audio)
    if audio.size == 0:
        # np.mean of an empty array is NaN (with a RuntimeWarning); nothing to add.
        return audio.copy()

    # RMS of the original signal
    rms_signal = np.sqrt(np.mean(audio**2))

    # Desired RMS of the noise (SNR expressed as an amplitude ratio, hence /20)
    snr_linear = 10**(snr_db / 20)
    rms_noise = rms_signal / snr_linear

    # Draw noise with the full input shape: len() only covers axis 0 and made
    # multi-channel arrays fail to broadcast.
    if rng is None:
        noise = np.random.normal(0, 1, audio.shape)
    else:
        noise = rng.standard_normal(audio.shape)
    rms_noise_current = np.sqrt(np.mean(noise**2))
    noise = noise * (rms_noise / rms_noise_current)

    return audio + noise

def aug_loudness_normalize(y: np.ndarray, target_db: float = -20.0) -> np.ndarray:
    """
    Rescale a waveform so its RMS level sits at ``target_db`` dBFS, then hard-clip.

    y : waveform (numpy array, float32 or float64, nominal range -1..1)
    target_db : desired RMS level in dBFS (e.g. -20 dB)

    Returns the scaled waveform, clipped to [-1, 1].
    """
    eps = 1e-12
    # eps keeps both sqrt and log10 away from zero for silent input.
    level_db = 20 * np.log10(np.sqrt(np.mean(np.square(y)) + eps) + eps)
    scale = 10 ** ((target_db - level_db) / 20)
    return np.clip(y * scale, -1.0, 1.0)



def perform_augmentations(audio, gaussian_snr=None, loudness_target_db=None):
    """
    Apply the configured waveform augmentations to ``audio`` in sequence.

    audio: waveform (np.ndarray)
    gaussian_snr: if not None, add Gaussian noise at this SNR in dB
        (via ``aug_gaussian_noise``).
    loudness_target_db: if not None, RMS-normalise the (possibly noisy)
        signal to this level in dBFS (via ``aug_loudness_normalize``).

    Returns the augmented waveform; with both options None, ``audio`` itself
    is returned unchanged.
    """
    # TODO: implement SpecAugment and Mixup.
    processed = audio
    if gaussian_snr is not None:
        processed = aug_gaussian_noise(processed, gaussian_snr)

    if loudness_target_db is not None:
        # Chain on `processed` (not `audio`) so the noise added above is kept.
        processed = aug_loudness_normalize(processed, loudness_target_db)

    return processed
    
    
    
def show_spec(spec, sr, fmax, name):
    """
    Plot a dB-scaled mel spectrogram with matplotlib and display it.

    spec: spectrogram in dB, e.g. the output of ``generate_mel_spectrogram``
    sr: sample rate used to label the time/mel axes
    fmax: highest frequency (Hz) used for the mel axis
    name: label printed before the plot is shown
    """
    import matplotlib.pyplot as plt
    # `import librosa` does not guarantee the `display` submodule is loaded
    # (older librosa releases require this explicit import).
    import librosa.display

    fig, ax = plt.subplots()
    img = librosa.display.specshow(spec, x_axis='time',
                                   y_axis='mel', sr=sr,
                                   fmax=fmax, ax=ax)
    fig.colorbar(img, ax=ax, format='%+2.0f dB')
    print(name)
    ax.set(title='Mel-frequency spectrogram')
    plt.show()
    
def generate_mel_spectrogram(audio, sr = 22050, n_fft=1024, n_mels=128, hop_length=512, win_length=None,
                             window:str='hann', center=True, pad_mode:_STFTPad='constant', power=2.0, fmin=200,
                             fmax=8000, norm='slaney',
                             show=False, name=None):
    """
    Compute the mel spectrogram of ``audio`` and convert it to decibels.

    STFT settings (n_fft, hop_length, win_length, window, center, pad_mode)
    and mel-filterbank settings (n_mels, power, fmin, fmax, norm) are passed
    straight to ``librosa.feature.melspectrogram``. The result is converted
    with ``librosa.power_to_db(..., ref=np.max)``, i.e. relative to the
    largest bin. When ``show`` is true, the result is also plotted via
    ``show_spec`` with ``name`` as its label.
    """
    stft_options = dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                        window=window, center=center, pad_mode=pad_mode)
    mel_options = dict(n_mels=n_mels, power=power, fmin=fmin, fmax=fmax, norm=norm)

    power_spec = librosa.feature.melspectrogram(y=audio, sr=sr, **stft_options, **mel_options)
    db_spec = librosa.power_to_db(power_spec, ref=np.max)

    if show:
        show_spec(db_spec, sr, fmax, name)
    return db_spec
    
def process_audio_file(path, sample_rate = 22050, n_fft=1024, n_mels=128, hop_length=512, win_length=None, 
                       window:str='hann', center=True, pad_mode:_STFTPad='constant', power=2.0, fmin=200, 
                       fmax=8000, norm='slaney', 
                       show=False, gaussian_snr=5, loudness_target_db=-20):
    """
    Load an audio file, augment the waveform, and return its dB mel spectrogram.

    path: audio file path, loaded with ``librosa.load(path, sr=sample_rate)``.
    The spectrogram parameters are forwarded to ``generate_mel_spectrogram``.
    show: plot the spectrogram (labelled "Base") when true.
    gaussian_snr: SNR in dB of the Gaussian noise added to the waveform;
        None disables it. Default 5.
    loudness_target_db: RMS target in dBFS for loudness normalisation;
        None disables it. Default -20.

    Returns the spectrogram produced by ``generate_mel_spectrogram``.
    """
    # Load the audio file
    y, sr = librosa.load(path, sr=sample_rate)

    # Perform augmentations on the raw waveform (previously hard-coded values
    # are now the parameter defaults, so callers can tune or disable them).
    y = perform_augmentations(y, gaussian_snr=gaussian_snr, loudness_target_db=loudness_target_db)

    spec = generate_mel_spectrogram(y, sr=int(sr), n_fft=n_fft, n_mels=n_mels, hop_length=hop_length, win_length=win_length, 
                                    window=window, center=center, pad_mode=pad_mode, power=power, fmin=fmin, 
                                    fmax=fmax, norm=norm, 
                                    show=show, name="Base")
    return spec
    
def safe_handle_file(path, output_dir, sample_rate = 22050, n_fft=1024, n_mels=128, hop_length=512, win_length=None, 
                       window:str='hann', center=True, pad_mode:_STFTPad='constant', power=2.0, fmin=200, 
                       fmax=8000, norm='slaney'):
    """
    Convert one audio file to a mel spectrogram and save it as ``.npy``.

    The array is written to
    ``<output_dir>/<name of the file's parent directory>/<file stem>.npy``;
    that directory is expected to exist already (the ``process_*`` drivers
    create it before submitting work). Spectrogram parameters are forwarded
    to ``process_audio_file``.

    Any exception is caught and reported instead of raised, so one bad file
    does not abort the worker pool.

    Returns ``(path, True, None)`` on success, or
    ``(path, False, repr(exception))`` on failure.
    """
    try:
        spectrogram = process_audio_file(path, sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels, hop_length=hop_length, win_length=win_length, 
                                    window=window, center=center, pad_mode=pad_mode, power=power, fmin=fmin, 
                                    fmax=fmax, norm=norm, 
                                    show=False)
        # os.path handles the platform's separator (splitting on "/" does not),
        # and splitext replaces any extension, not only ".ogg", with ".npy".
        parent_dir = os.path.basename(os.path.dirname(path))
        stem = os.path.splitext(os.path.basename(path))[0]
        np.save(os.path.join(output_dir, parent_dir, stem + ".npy"), spectrogram)
        return path, True, None
    except Exception as e:
        return path, False, repr(e)
    
def process_audio_dir(root, output_dir, sample_rate = 22050, n_fft=1024, n_mels=128, hop_length=512, win_length=None, 
                       window:str='hann', center=True, pad_mode:_STFTPad='constant', power=2.0, fmin=200, 
                       fmax=8000, norm='slaney'):
    """
    Convert every audio file below ``root`` to a mel spectrogram in parallel.

    Subdirectories of ``root`` are walked recursively. Files sitting directly
    in ``root`` itself are skipped. For each walked directory,
    ``<output_dir>/<basename of that directory>`` is created, and each file's
    result is saved there by ``safe_handle_file``.

    The spectrogram parameters are forwarded to every ``safe_handle_file``
    call. Per-file failures are printed and do not stop the run; a summary is
    printed at the end.
    """
    work_pool = []
    for dirpath, _, filenames in os.walk(root):
        if dirpath == root:
            continue
        os.makedirs(os.path.join(output_dir, os.path.basename(dirpath)), exist_ok=True)
        work_pool.extend(os.path.join(dirpath, name) for name in filenames)

    # Forward the caller's settings; previously they were accepted but dropped,
    # so every file was processed with safe_handle_file's defaults.
    settings = dict(sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels, hop_length=hop_length,
                    win_length=win_length, window=window, center=center, pad_mode=pad_mode,
                    power=power, fmin=fmin, fmax=fmax, norm=norm)

    successes = 0
    failures = 0
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as ex:
        futures = [ex.submit(safe_handle_file, p, output_dir, **settings) for p in work_pool]
        for fut in tqdm(as_completed(futures), total=len(work_pool), desc="Processing", unit="file"):
            path, ok, err = fut.result()
            if ok:
                successes += 1
            else:
                failures += 1
                print(f"[ERROR] {path}, {err}")

    print(f"Done. Successes: {successes}, Failed: {failures}")

def process_files(files, output_dir, sample_rate = 22050, n_fft=1024, n_mels=128, hop_length=512, win_length=None, 
                       window:str='hann', center=True, pad_mode:_STFTPad='constant', power=2.0, fmin=200, 
                       fmax=8000, norm='slaney'):
    """
    Convert the given audio files to mel spectrograms in parallel.

    files: iterable of audio file paths. Each result is saved by
        ``safe_handle_file`` under ``<output_dir>/<parent dir name>/``; those
        directories are created here first.
    The spectrogram parameters are forwarded to every ``safe_handle_file`` call.

    Per-file failures are printed and do not stop the run; a summary is
    printed at the end.
    """
    # Materialise once: the paths are iterated twice and len() feeds the
    # progress bar, neither of which works on a one-shot iterator.
    work_pool = list(files)

    # Create each output subdirectory once, in this process, before workers write.
    out_dirs = {os.path.join(output_dir, os.path.basename(os.path.dirname(f))) for f in work_pool}
    for out_dir in out_dirs:
        os.makedirs(out_dir, exist_ok=True)

    successes = 0
    failures = 0
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as ex:
        futures = [ex.submit(safe_handle_file, p, output_dir, sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels, hop_length=hop_length, win_length=win_length, 
                                    window=window, center=center, pad_mode=pad_mode, power=power, fmin=fmin, 
                                    fmax=fmax, norm=norm) for p in work_pool]
        for fut in tqdm(as_completed(futures), total=len(work_pool), desc="Processing", unit="file"):
            path, ok, err = fut.result()
            if ok:
                successes += 1
            else:
                failures += 1
                print(f"[ERROR] {path}, {err}")

    print(f"Done. Successes: {successes}, Failed: {failures}")

def process_df(files, output_dir, sample_rate = 22050, n_fft=1024, n_mels=128, hop_length=512, win_length=None, 
                       window:str='hann', center=True, pad_mode:_STFTPad='constant', power=2.0, fmin=200, 
                       fmax=8000, norm='slaney'):
    """
    Convert every path listed in the ``"filename"`` column of ``files``.

    files: table indexed by column name, presumably a pandas DataFrame
        (TODO confirm with callers), whose ``"filename"`` column holds audio
        file paths.
    All other parameters are forwarded unchanged to ``process_files``, which
    creates the output directories and runs the worker pool.
    """
    # Iterate the column's values, not `files` itself: iterating a DataFrame
    # yields its column labels, so the old directory-creation loop split the
    # label strings and raised IndexError before any file was processed.
    process_files(list(files["filename"]), output_dir, sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels,
                  hop_length=hop_length, win_length=win_length, window=window, center=center,
                  pad_mode=pad_mode, power=power, fmin=fmin, fmax=fmax, norm=norm)
    