# 🧪 Spectral Affinity: Audio Restoration (The Boutique Lab)

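This notebook implements a professional-grade restoration pipeline tuned specifically for AI-generated audio (e.g., Suno, Udio).
It combines **High-Performance Parallel Matching** (Matchering reference matching on the CPU) with **Boutique Mastering** (spectral de-harshing on the GPU, followed by mid/side processing, saturation and limiting).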

---

In [None]:
!pip install -q pedalboard matchering numpy scipy torchaudio tqdm joblib

In [None]:
import os
import glob
import time
import numpy as np
import scipy.signal as signal
import torch
import torchaudio
import matchering as mg
from tqdm.auto import tqdm
from joblib import Parallel, delayed
from pedalboard import Pedalboard, Compressor, Distortion, Gain, HighpassFilter, LowpassFilter, HighShelfFilter, Limiter
from pedalboard.io import AudioFile
from IPython.display import FileLink

# --- SETTINGS ---
# Source tracks, and the mastering reference every track is matched against.
INPUT_DIR = "/kaggle/input/datasets/danieldobles/ost-songs-a"
REF_FILE = "/kaggle/input/datasets/danieldobles/ost-songs-a/REF.flac"
# Final masters, and the intermediate Matchering renders from Phase 1.
OUTPUT_DIR = "/kaggle/working/mastered_tracks"
TEMP_MATCH_DIR = "/kaggle/working/temp_matched"

# Torch device for the spectral DSP: CUDA when visible, otherwise CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(f"üöÄ Optimization Mode: {'GPU (CUDA)' if DEVICE == 'cuda' else 'CPU'} for DSP.")
# ----------------

def spectral_deharsh_gpu(audio_tensor, sample_rate, threshold_ratio=1.4):
    """Tame narrow spectral peaks ("harshness") with an STFT-domain gain mask.

    Each STFT bin is compared with the average magnitude of the 31 frequency
    bins around it in the same frame. Bins louder than ``threshold_ratio``
    times that local envelope are pulled down towards the envelope, by at
    most ~6 dB (gain floor 0.5). All other bins pass through unchanged.

    Runs on whatever device ``audio_tensor`` lives on (CUDA or CPU).

    Args:
        audio_tensor: 1-D waveform tensor ``(samples,)``, or a 2-D batch
            ``(channels, samples)`` whose rows are processed independently.
        sample_rate: Unused; kept for interface compatibility.
        threshold_ratio: How far above the local envelope a bin must be
            before it is attenuated.

    Returns:
        Tensor with the same shape as ``audio_tensor``.
    """
    n_fft = 2048
    window = torch.hann_window(n_fft, device=audio_tensor.device, dtype=audio_tensor.dtype)
    Zxx = torch.stft(audio_tensor, n_fft=n_fft, return_complex=True, window=window, center=True)
    mag = torch.abs(Zxx)

    # avg_pool2d works on (N, C, H, W): fold any batch dims into N and smooth
    # along the frequency axis (H) only, with a 31-bin moving average.
    n_freq, n_frames = mag.shape[-2:]
    mag_4d = mag.reshape(-1, 1, n_freq, n_frames)
    # count_include_pad=False: with the default (True) the zero padding drags
    # the envelope down in the lowest/highest 15 bins, so bass and top-end
    # content was flagged as "harsh" and cut even when the spectrum was flat.
    envelope = torch.nn.functional.avg_pool2d(
        mag_4d,
        kernel_size=(31, 1),
        stride=1,
        padding=(15, 0),
        count_include_pad=False,
    ).reshape(mag.shape)

    mask = mag > (envelope * threshold_ratio)

    # Pull flagged bins toward the envelope, never by more than half (~-6 dB).
    reduction = torch.clamp(envelope / (mag + 1e-6), 0.5, 1.0)
    gain_map = torch.where(mask, reduction, torch.ones_like(mag))

    Zxx_clean = Zxx * gain_map
    return torch.istft(Zxx_clean, n_fft=n_fft, window=window, length=audio_tensor.shape[-1], center=True)

# --- DSP Helpers (NumPy/Pedalboard for M/S and Saturation) --- 
def mono_maker(audio_side, sample_rate, cutoff_hz=120):
    """High-pass the side channel so everything below ``cutoff_hz`` ends up mono.

    Uses a causal 4th-order Butterworth high-pass (second-order sections).
    """
    highpass = signal.butter(N=4, Wn=cutoff_hz, btype='highpass', fs=sample_rate, output='sos')
    filtered = signal.sosfilt(highpass, audio_side)
    return filtered

def ms_encode(audio_lr):
    """Convert a two-channel signal to mid/side.

    Channel 0 is treated as left and channel 1 as right. Returns
    ``(mid, side)`` with ``mid = (L + R) / 2`` and ``side = (L - R) / 2``.
    """
    left = audio_lr[0]
    right = audio_lr[1]
    return (left + right) * 0.5, (left - right) * 0.5

def ms_decode(mid, side):
    """Rebuild left/right from mid/side and stack them as ``[left, right]``.

    Inverse of ``ms_encode``: ``L = M + S`` and ``R = M - S``.
    """
    return np.stack((mid + side, mid - side))

def transient_shaper_mid(mid_signal, sample_rate, punch=1.4):
    """Emphasise attacks in the mid channel.

    Two zero-phase envelope followers run on the rectified signal: a fast
    one (1st-order 40 Hz low-pass) and a slow one (1st-order 5 Hz low-pass).
    Where the fast envelope exceeds the slow one by more than 5 %, the
    sample is scaled by ``(fast / slow) ** (punch - 1)``. The gain is then
    limited to [1, 2], so the shaper only ever adds level (at most +6 dB).
    """
    rectified = np.abs(mid_signal)

    def _follow(cutoff_hz):
        # Zero-phase (forward-backward) 1st-order low-pass of the rectified signal.
        lowpass = signal.butter(1, cutoff_hz, 'low', fs=sample_rate, output='sos')
        return signal.sosfiltfilt(lowpass, rectified)

    fast_env = _follow(40)
    slow_env = _follow(5)
    ratio = fast_env / (slow_env + 1e-8)

    # Only samples clearly above the slow envelope get the power-law boost.
    gain = np.ones_like(ratio)
    attacks = ratio > 1.05
    gain[attacks] = ratio[attacks] ** (punch - 1.0)
    return mid_signal * np.clip(gain, 1.0, 2.0)

def saturate_side(side_signal, sample_rate, drive=4.0):
    """Add harmonic colour to the side channel.

    Chain: 300 Hz high-pass -> ``Distortion(drive_db=drive)`` -> -1 dB gain
    trim. The 1-D side signal gets a leading axis so it is passed to
    Pedalboard as a single-channel 2-D array. The result is squeezed back
    down afterwards.
    """
    chain = Pedalboard([
        HighpassFilter(300),
        Distortion(drive_db=drive),
        Gain(-1),
    ])
    as_channels = np.expand_dims(side_signal, 0)
    processed = chain(as_channels, sample_rate)
    return processed.squeeze()

def boutique_master_hybrid(audio_path, output_path, device_str):
    """Master one (reference-matched) file and write the result to ``output_path``.

    Chain: spectral de-harsh (on ``device_str``) -> M/S encode -> side
    high-pass at 120 Hz -> mid transient shaping -> side saturation ->
    M/S decode -> ``Limiter(threshold_db=-1.0)`` -> write.

    Args:
        audio_path: File to process (loaded with ``torchaudio.load``).
        output_path: Destination path, written with pedalboard's ``AudioFile``.
        device_str: Torch device for the de-harshing step, e.g. ``"cuda"`` or ``"cpu"``.
    """
    # 1. Load audio.
    waveform, sr = torchaudio.load(audio_path)

    # The M/S stage needs exactly two channels: duplicate a mono file,
    # and keep only the first two channels of anything wider.
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)
    waveform = waveform[:2]

    # 2. De-harshing on the requested device. The torch STFT path runs on CPU
    # too, so the CPU case applies the same processing instead of skipping it.
    waveform = waveform.to(device_str)
    clean_l = spectral_deharsh_gpu(waveform[0], sr)
    clean_r = spectral_deharsh_gpu(waveform[1], sr)
    waveform_clean = torch.stack([clean_l, clean_r]).cpu().numpy()

    # 3. M/S processing and saturation on CPU (NumPy/SciPy/Pedalboard).
    mid, side = ms_encode(waveform_clean)
    side = mono_maker(side, sr, 120)
    mid = transient_shaper_mid(mid, sr, 1.4)
    side = saturate_side(side, sr, 4.0)
    stereo = ms_decode(mid, side)

    # 4. Final limiter.
    final = Pedalboard([Limiter(threshold_db=-1.0)])(stereo, sr)

    # 5. Save (channel count taken from the processed array).
    with AudioFile(output_path, 'w', sr, final.shape[0]) as f:
        f.write(final)

def process_matchering_task(path, ref_file, temp_dir):
    """Match one track to ``ref_file`` with Matchering.

    Renders via ``mg.pcm24`` to ``<temp_dir>/matched_<basename>.wav``.
    Returns that path on success, or ``None`` if any step raised. Progress
    and elapsed time are printed for each track.
    """
    track_name = os.path.basename(path)
    rendered_path = os.path.join(temp_dir, f"matched_{track_name}.wav")
    t0 = time.time()

    def _took():
        return time.time() - t0

    try:
        print(f"  ‚ñ∂Ô∏è Processing: {track_name}", flush=True)
        # CPU-heavy step.
        mg.process(target=path, reference=ref_file, results=[mg.pcm24(rendered_path)])
        print(f"  ‚úÖ Done: {track_name} ({_took():.2f}s)", flush=True)
        return rendered_path
    except Exception as err:
        # Report and skip this track; the caller filters out the None results.
        print(f"  ‚ùå Failed: {track_name} ({_took():.2f}s) - {err}", flush=True)
        return None

# Printed once every definition in this cell has executed.
print("‚úÖ Hybrid Engine Ready: GPU De-Harshing + Parallel Matchering.")

### 🚀 Phase 1: High-Speed Parallel Reference Matching (CPU-Bound)

In [None]:
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(TEMP_MATCH_DIR, exist_ok=True)

# Collect target tracks in a stable (sorted) order; glob's own order depends
# on the filesystem. The reference file lives in INPUT_DIR, so exclude it
# explicitly in case it shares one of the target extensions.
ref_abs_path = os.path.abspath(REF_FILE)
file_paths = sorted(
    p
    for pattern in ("*.mp3", "*.wav")
    for p in glob.glob(os.path.join(INPUT_DIR, pattern))
    if os.path.abspath(p) != ref_abs_path
)
print(f"üî• Starting PHASE 1: Parallel Matching for {len(file_paths)} tracks...")

# Run Matchering concurrently across tracks (n_jobs=-1: one job per CPU core).
matched_files = Parallel(n_jobs=-1, backend="threading")(
    delayed(process_matchering_task)(p, REF_FILE, TEMP_MATCH_DIR)
    for p in file_paths
)

# process_matchering_task returns None for tracks that failed; drop them.
matched_files = [f for f in matched_files if f is not None]
print(f"‚úÖ Phase 1 Complete. {len(matched_files)} tracks ready for DSP.")

### ⚡ Phase 2: Boutique DSP & De-Harshing (GPU-Accelerated)

In [None]:
print(f"‚ö° Starting PHASE 2: GPU Restoration on {DEVICE.upper()}...")

MATCH_PREFIX = "matched_"
failed_tracks = []

for i, temp_path in enumerate(matched_files, start=1):
    # Strip only the leading prefix added in Phase 1; str.replace would also
    # delete "matched_" wherever it appears inside the original filename.
    temp_name = os.path.basename(temp_path)
    fname = temp_name[len(MATCH_PREFIX):] if temp_name.startswith(MATCH_PREFIX) else temp_name
    final_path = os.path.join(OUTPUT_DIR, f"Mastered_{fname}")

    start_time = time.time()
    try:
        print(f"[{i}/{len(matched_files)}] üéõÔ∏è Mastering: {fname}...", end=" ", flush=True)
        boutique_master_hybrid(temp_path, final_path, DEVICE)
        elapsed = time.time() - start_time
        print(f"Done ({elapsed:.2f}s)")
    except Exception as e:
        # Keep going with the remaining tracks, but remember the failure.
        failed_tracks.append(fname)
        print(f"‚ùå DSP Failed: {e}")

# Only claim full success when every track actually made it through.
if failed_tracks:
    print(f"WARNING: {len(failed_tracks)} of {len(matched_files)} tracks failed: {', '.join(failed_tracks)}")
else:
    print("‚ú® All tracks restored successfully.")

### 📦 Final Download

In [None]:
# Improved Download Cell with path checks
import shutil

os.chdir('/kaggle/working')
print("üì¶ Zipping files...")

# Start from a clean slate. `zip -r` updates an existing archive in place and
# keeps stale entries. A leftover archive would also make the existence check
# below report success even when this run produced nothing.
if os.path.exists('restoration_results.zip'):
    os.remove('restoration_results.zip')
# Archive with the standard library (no dependency on an external `zip`
# binary). Entries are stored under "mastered_tracks/", as `zip -r` did.
if os.path.isdir('mastered_tracks'):
    shutil.make_archive('restoration_results', 'zip', root_dir='.', base_dir='mastered_tracks')

if os.path.exists('restoration_results.zip'):
    size_mb = os.path.getsize('restoration_results.zip') / 1e6
    print(f"‚úÖ ZIP created! Size: {size_mb:.2f} MB")
    display(FileLink('restoration_results.zip'))
else:
    print("‚ùå Error: Zip file was not created. Checking output directory...")
    if os.path.exists('mastered_tracks'):
        print(f"Files in 'mastered_tracks': {len(os.listdir('mastered_tracks'))}")
    else:
        print("'mastered_tracks' directory does not exist.")