# DDSP Timbre Grower - VECTORIZED COMPLETE 🚀

**THE FIX**: Vectorized HarmonicOscillator and FilteredNoiseGenerator

**Root cause**: Python `for` loops inside the oscillator and noise generator launched hundreds of small GPU kernels one after another

**Solution**: Vectorize all loops for parallel GPU execution

**Expected**: ~0.5-1s per epoch, 8-17 minutes for 1000 epochs per file

**Output**: For EACH scale-tone file, a "growing" sequence in which harmonics are added one at a time (stage *n* uses the first *n* harmonics)

In [None]:
# 1. Setup
!pip install torch librosa soundfile matplotlib scipy tqdm -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import Audio, display
import glob
from pathlib import Path
import time
import os

# Pick the compute device once; every later cell moves tensors/models here.
_cuda_available = torch.cuda.is_available()
device = 'cuda' if _cuda_available else 'cpu'
print(f"Using device: {device}")
if _cuda_available:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")

In [None]:
# 2. Upload scale tone files
from google.colab import files
uploaded = files.upload()

os.makedirs('scale_tones', exist_ok=True)
for filename in uploaded.keys():
    # os.replace overwrites an existing file of the same name on every
    # platform, so re-uploading a tone is safe.
    os.replace(filename, os.path.join('scale_tones', filename))

# Match the extension case-insensitively: glob('*.wav') is case-sensitive on
# Linux and would silently skip uploads such as "C4.WAV".
audio_files = sorted(
    str(p) for p in Path('scale_tones').iterdir()
    if p.is_file() and p.suffix.lower() == '.wav'
)
print(f"\n✅ Uploaded {len(audio_files)} scale tone files")
for f in audio_files:
    print(f"   - {Path(f).name}")

In [None]:
# 3. VECTORIZED DDSP Components (THE FIX!)

class HarmonicOscillatorVectorized(nn.Module):
    """Additive harmonic oscillator, vectorized over harmonics.

    Renders ``sum_k a_k(t) * sin(k * phi(t))`` for k = 1..n_harmonics, where
    ``phi`` is the running phase of the fundamental. All harmonics are
    evaluated with a single broadcast ``sin`` call instead of a Python loop.

    Args:
        sample_rate: Output sample rate in Hz.
        n_harmonics: Number of harmonics, including the fundamental.
        hop_length: Audio samples per conditioning frame; frame-rate inputs
            are linearly upsampled by this factor. Defaults to 512, the value
            that used to be hard-coded.
    """

    def __init__(self, sample_rate=22050, n_harmonics=64, hop_length=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_harmonics = n_harmonics
        self.hop_length = hop_length

        # Harmonic multipliers [1, 2, ..., n_harmonics]; a buffer so it
        # follows the module across .to(device).
        self.register_buffer(
            'harmonic_numbers',
            torch.arange(1, n_harmonics + 1, dtype=torch.float32)
        )

    def forward(self, f0_hz, harmonic_amplitudes):
        """Synthesize harmonic audio.

        Args:
            f0_hz: Fundamental frequency per frame in Hz, shape [batch, n_frames].
            harmonic_amplitudes: Per-frame harmonic amplitudes,
                shape [batch, n_frames, n_harmonics].

        Returns:
            Audio tensor of shape [batch, n_frames * hop_length].
        """
        n_frames = f0_hz.shape[1]
        n_samples = n_frames * self.hop_length

        # Upsample frame-rate controls to audio rate.
        f0_upsampled = F.interpolate(
            f0_hz.unsqueeze(1), size=n_samples, mode='linear', align_corners=True
        ).squeeze(1)  # [batch, samples]

        amps_upsampled = F.interpolate(
            harmonic_amplitudes.transpose(1, 2), size=n_samples,
            mode='linear', align_corners=True
        ).transpose(1, 2)  # [batch, samples, n_harmonics]

        # Anti-aliasing: a harmonic at or above Nyquist would fold back into
        # the audible band as an inharmonic partial, so zero its amplitude.
        harmonic_freqs = f0_upsampled.unsqueeze(-1) * self.harmonic_numbers
        below_nyquist = (harmonic_freqs < self.sample_rate / 2.0).to(amps_upsampled.dtype)
        amps_upsampled = amps_upsampled * below_nyquist

        # Running phase of the fundamental, in cycles. Accumulate in float64
        # and wrap to [0, 1): an unbounded float32 sum loses sub-cycle
        # precision after a few seconds, and the error is multiplied by k for
        # each harmonic. Wrapping whole cycles is exact because k is an integer.
        cycles = torch.cumsum(f0_upsampled / self.sample_rate, dim=1, dtype=torch.float64)
        cycles = torch.remainder(cycles, 1.0).to(f0_upsampled.dtype)
        phase = 2 * torch.pi * cycles  # [batch, samples]

        # Single sin() call for ALL harmonics: [batch, samples, n_harmonics].
        harmonic_signals = torch.sin(phase.unsqueeze(-1) * self.harmonic_numbers)

        # Weight each harmonic and sum them into one waveform.
        return (harmonic_signals * amps_upsampled).sum(dim=-1)


class FilteredNoiseGeneratorVectorized(nn.Module):
    """Filtered white-noise source, vectorized over frequency bins.

    White noise is shaped in the frequency domain. The filter response is a
    normalized mix of Gaussian bumps in log frequency centered on
    ``n_filter_banks`` log-spaced frequencies between 20 Hz and Nyquist.

    Args:
        sample_rate: Output sample rate in Hz.
        n_filter_banks: Number of filter-bank center frequencies.
        hop_length: Audio samples per conditioning frame. Defaults to 512,
            the value that used to be hard-coded.
    """

    def __init__(self, sample_rate=22050, n_filter_banks=64, hop_length=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_filter_banks = n_filter_banks
        self.hop_length = hop_length

        # Log-spaced bank centers from 20 Hz up to Nyquist.
        self.register_buffer(
            'filter_freqs',
            torch.logspace(
                torch.log10(torch.tensor(20.0)),
                torch.log10(torch.tensor(sample_rate / 2.0)),
                n_filter_banks
            )
        )

    def forward(self, filter_magnitudes):
        """Generate filtered noise.

        Args:
            filter_magnitudes: Per-frame bank magnitudes,
                shape [batch, n_frames, n_filter_banks].

        Returns:
            Noise tensor of shape [batch, n_frames * hop_length].
        """
        batch_size, n_frames, _ = filter_magnitudes.shape
        n_samples = n_frames * self.hop_length
        dev = filter_magnitudes.device

        # White noise and its one-sided spectrum.
        noise = torch.randn(batch_size, n_samples, device=dev, dtype=filter_magnitudes.dtype)
        noise_fft = torch.fft.rfft(noise, dim=1)
        freqs = torch.fft.rfftfreq(n_samples, 1 / self.sample_rate).to(dev)

        # Upsample filter magnitudes to audio rate.
        filter_magnitudes_upsampled = F.interpolate(
            filter_magnitudes.transpose(1, 2), size=n_samples,
            mode='linear', align_corners=True
        ).transpose(1, 2)  # [batch, samples, n_filter_banks]

        # Gaussian weight of every FFT bin against every bank center, in log
        # frequency, normalized so each bin's weights sum to ~1.
        log_freqs = torch.log(freqs + 1e-7).unsqueeze(-1)  # [n_freqs, 1]
        log_filter_freqs = torch.log(self.filter_freqs + 1e-7).unsqueeze(0)  # [1, n_filter_banks]
        distances = torch.abs(log_freqs - log_filter_freqs)  # [n_freqs, n_filter_banks]
        weights = torch.exp(-distances**2 / 0.5)
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-7)

        # NOTE: magnitudes are averaged over time, so one static filter is
        # applied to the whole clip; per-frame variation is discarded.
        filter_mags_time_avg = filter_magnitudes_upsampled.mean(dim=1)  # [batch, n_filter_banks]
        filter_response = torch.einsum('fk,bk->bf', weights, filter_mags_time_avg)

        # Apply the filter and return to the time domain.
        filtered_fft = noise_fft * filter_response
        return torch.fft.irfft(filtered_fft, n=n_samples, dim=1)


class DDSPSynthesizer(nn.Module):
    """Neural network that maps features to synthesis parameters.

    A 2-layer GRU reads the per-frame concatenation ``[f0, loudness, mfcc]``.
    Two MLP heads then predict non-negative harmonic amplitudes (Softplus) and
    noise-filter magnitudes in (0, 1) (Sigmoid). The harmonic amplitudes are
    finally scaled by the frame loudness, converted from dB to linear amplitude.

    Args:
        n_harmonics: Width of the harmonic-amplitude output.
        n_filter_banks: Width of the noise-filter output.
        hidden_size: GRU / MLP hidden width.
        n_mfcc: Number of MFCC coefficients per input frame.
    """

    def __init__(self, n_harmonics=64, n_filter_banks=64, hidden_size=512, n_mfcc=30):
        super().__init__()
        input_size = 1 + 1 + n_mfcc  # f0 + loudness + MFCCs
        # NOTE(review): f0 (Hz) and loudness (dB) are fed in unnormalized;
        # scaling them to roughly unit range would likely help the GRU.
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True, dropout=0.1)

        self.harmonic_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_harmonics),
            nn.Softplus()
        )

        self.noise_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_filter_banks),
            nn.Sigmoid()
        )

    def forward(self, f0, loudness, mfcc):
        """Predict synthesis controls.

        Args:
            f0: [batch, frames, 1] fundamental frequency.
            loudness: [batch, frames, 1] loudness in dB.
            mfcc: [batch, frames, n_mfcc].

        Returns:
            (harmonic_amplitudes [batch, frames, n_harmonics],
             filter_magnitudes [batch, frames, n_filter_banks])
        """
        x = torch.cat([f0, loudness, mfcc], dim=-1)
        x, _ = self.gru(x)
        harmonic_amplitudes = self.harmonic_head(x)
        filter_magnitudes = self.noise_head(x)
        # dB -> linear amplitude is 10 ** (dB / 20). The previous exp(dB / 20)
        # used the wrong base (about 10 ** (dB / 46)), so quiet frames were
        # far too loud.
        loudness_scale = torch.pow(10.0, loudness / 20.0)
        harmonic_amplitudes = harmonic_amplitudes * loudness_scale
        return harmonic_amplitudes, filter_magnitudes


class DDSPModelVectorized(nn.Module):
    """End-to-end DDSP model: control network plus two synthesizers.

    ``DDSPSynthesizer`` predicts per-frame harmonic amplitudes and noise
    filter magnitudes. The harmonic oscillator and the filtered-noise
    generator render those controls to audio, and a learnable scalar,
    squashed by a sigmoid, sets the harmonic/noise mix.
    """

    def __init__(self, sample_rate=22050, n_harmonics=64, n_filter_banks=64, hidden_size=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.synthesizer = DDSPSynthesizer(n_harmonics, n_filter_banks, hidden_size)
        self.harmonic_osc = HarmonicOscillatorVectorized(sample_rate, n_harmonics)
        self.noise_gen = FilteredNoiseGeneratorVectorized(sample_rate, n_filter_banks)
        # Logit of the harmonic share of the mix; sigmoid(0.8) ~= 0.69 at init.
        self.harmonic_noise_ratio = nn.Parameter(torch.tensor(0.8))

    def forward(self, f0, loudness, mfcc):
        """Return ``(mixed_audio, harmonic_audio, noise_audio)`` for the given features."""
        amps, filt = self.synthesizer(f0, loudness, mfcc)
        harmonic_part = self.harmonic_osc(f0.squeeze(-1), amps)
        noise_part = self.noise_gen(filt)
        mix = torch.sigmoid(self.harmonic_noise_ratio)
        mixed = mix * harmonic_part + (1 - mix) * noise_part
        return mixed, harmonic_part, noise_part


class MultiScaleSpectralLoss(nn.Module):
    """Legacy spectral loss (kept for compatibility).

    Mean over FFT sizes of the L1 distance between log-magnitude STFTs
    (Hann window, hop = fft_size // 4).

    Args:
        fft_sizes: Iterable of FFT sizes. Defaults to (2048, 1024, 512, 256).
            The values are copied, so instances never share or mutate a
            common default list.

    Raises:
        ValueError: if ``fft_sizes`` is empty (the average would divide by zero).
    """

    def __init__(self, fft_sizes=None):
        super().__init__()
        if fft_sizes is None:
            fft_sizes = [2048, 1024, 512, 256]
        self.fft_sizes = list(fft_sizes)
        if not self.fft_sizes:
            raise ValueError("fft_sizes must contain at least one FFT size")

    def forward(self, pred_audio, target_audio):
        """Return the scalar loss between ``pred_audio`` and ``target_audio``."""
        total_loss = 0.0
        for fft_size in self.fft_sizes:
            # One window per resolution, shared by both STFTs.
            window = torch.hann_window(fft_size, device=pred_audio.device)
            hop = fft_size // 4
            pred_stft = torch.stft(
                pred_audio, n_fft=fft_size, hop_length=hop,
                window=window, return_complex=True
            )
            target_stft = torch.stft(
                target_audio, n_fft=fft_size, hop_length=hop,
                window=window, return_complex=True
            )
            pred_log_mag = torch.log(torch.abs(pred_stft) + 1e-5)
            target_log_mag = torch.log(torch.abs(target_stft) + 1e-5)
            total_loss += F.l1_loss(pred_log_mag, target_log_mag)
        return total_loss / len(self.fft_sizes)


class MultiResolutionSTFTLoss(nn.Module):
    """Multi-resolution STFT loss for perceptual audio quality (PRIORITY 1 IMPROVEMENT).

    For each (fft, hop, window) resolution it computes:
      * spectral convergence ``|| |S| - |S_hat| ||_F / || |S| ||_F`` on STFT
        *magnitudes*, and
      * the L1 distance between log magnitudes.
    Both terms are averaged over resolutions and summed.

    Magnitudes are used for spectral convergence on purpose. Comparing complex
    STFTs also penalizes phase, which the random-noise branch can never match
    and which is largely inaudible.

    Args:
        fft_sizes: FFT sizes. Defaults to (2048, 1024, 512, 256, 128).
        hop_sizes: Hop sizes; defaults to ``fft_size // 4`` for each size.
        win_sizes: Hann window lengths; defaults to the FFT sizes.

    Raises:
        ValueError: if ``fft_sizes`` is empty or the three lists differ in length.
    """

    def __init__(self, fft_sizes=None, hop_sizes=None, win_sizes=None):
        super().__init__()
        # Copy every list so instances never share a mutable default.
        fft_sizes = [2048, 1024, 512, 256, 128] if fft_sizes is None else list(fft_sizes)
        hop_sizes = [s // 4 for s in fft_sizes] if hop_sizes is None else list(hop_sizes)
        win_sizes = list(fft_sizes) if win_sizes is None else list(win_sizes)

        if not fft_sizes:
            raise ValueError("fft_sizes must contain at least one FFT size")
        # zip() would silently drop resolutions while the average still
        # divided by len(fft_sizes).
        if not (len(fft_sizes) == len(hop_sizes) == len(win_sizes)):
            raise ValueError("fft_sizes, hop_sizes and win_sizes must have equal lengths")

        self.fft_sizes = fft_sizes
        self.hop_sizes = hop_sizes
        self.win_sizes = win_sizes

    def stft(self, x, fft_size, hop_size, win_size):
        """Compute a complex STFT of ``x`` with a Hann window of ``win_size``."""
        window = torch.hann_window(win_size, device=x.device, dtype=x.dtype)
        return torch.stft(
            x, n_fft=fft_size, hop_length=hop_size,
            win_length=win_size, window=window,
            return_complex=True
        )

    def forward(self, pred_audio, target_audio):
        """Return the scalar loss between ``pred_audio`` and ``target_audio``."""
        spectral_convergence = 0.0
        magnitude_loss = 0.0

        for fft_size, hop_size, win_size in zip(
            self.fft_sizes, self.hop_sizes, self.win_sizes
        ):
            pred_mag = torch.abs(self.stft(pred_audio, fft_size, hop_size, win_size))
            target_mag = torch.abs(self.stft(target_audio, fft_size, hop_size, win_size))

            # Spectral convergence on magnitudes (overall spectral shape).
            spectral_convergence += torch.norm(
                target_mag - pred_mag, p='fro'
            ) / (torch.norm(target_mag, p='fro') + 1e-8)

            # Log magnitude loss (perceptually weighted).
            magnitude_loss += F.l1_loss(
                torch.log(pred_mag + 1e-5),
                torch.log(target_mag + 1e-5)
            )

        # Average across all resolutions.
        n_resolutions = len(self.fft_sizes)
        return spectral_convergence / n_resolutions + magnitude_loss / n_resolutions

print("✅ VECTORIZED model defined with improved perceptual loss")

In [None]:
# 4. Feature extraction

def extract_features(audio_path, sample_rate=22050, hop_length=512):
    """Load an audio file and compute frame-level conditioning features.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        sample_rate: Audio is resampled to this rate (Hz) and mixed to mono.
        hop_length: Samples between feature frames. This must equal the hop
            used by the synthesizer (512), or frame count and audio length
            drift apart. It used to be derived as ``int(sample_rate / 43.066)``,
            which gives 512 at 22050 Hz but 1024 at 44100 Hz.

    Returns:
        A dict with the following keys. The three frame-level arrays are
        trimmed to a common length.
            'f0': YIN pitch in Hz, with NaN mapped to 0 and clamped to >= 0.
            'loudness': RMS in dB relative to 1.0.
            'mfcc': 30 coefficients per frame, shape [frames, 30].
            'audio': The full mono waveform.
            'n_frames': The common frame count.
    """
    audio, sr = librosa.load(audio_path, sr=sample_rate, mono=True)

    f0_yin = librosa.yin(audio, fmin=librosa.note_to_hz('C2'), fmax=librosa.note_to_hz('C7'), sr=sr, hop_length=hop_length)
    f0_yin = np.nan_to_num(f0_yin, nan=0.0)
    f0_yin = np.maximum(f0_yin, 0.0)

    loudness = librosa.feature.rms(y=audio, frame_length=2048, hop_length=hop_length)[0]
    loudness_db = librosa.amplitude_to_db(loudness, ref=1.0)

    mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=30, hop_length=hop_length).T

    # The extractors can disagree by a frame; keep only the common prefix.
    min_len = min(len(f0_yin), len(loudness_db), len(mfcc))

    return {
        'f0': f0_yin[:min_len],
        'loudness': loudness_db[:min_len],
        'mfcc': mfcc[:min_len],
        'audio': audio,
        'n_frames': min_len
    }

print("✅ Feature extraction ready")

In [None]:
# 5. Growing stage generation functions

class ADSREnvelope:
    """Piecewise-linear attack/decay/sustain/release amplitude envelope.

    Args:
        sr: Sample rate in Hz, used to convert segment durations to samples.
    """

    def __init__(self, sr=22050):
        self.sr = sr

    def generate(self, duration, attack=0.05, decay=0.1, sustain_level=0.7, release=0.2):
        """Return a float64 envelope of ``int(duration * sr)`` samples.

        Segments, with all times in seconds:
          * 0 -> 1 over ``attack``;
          * 1 -> ``sustain_level`` over ``decay``;
          * a constant sustain for the remainder;
          * ``sustain_level`` -> 0 over ``release``.

        If the segments are longer than ``duration``, the envelope is truncated,
        so the release may be cut off. Any shortfall is zero-padded.
        """
        n_samples = int(duration * self.sr)
        attack_samples = int(attack * self.sr)
        decay_samples = int(decay * self.sr)
        release_samples = int(release * self.sr)
        sustain_samples = max(0, n_samples - attack_samples - decay_samples - release_samples)

        # Build each segment as an array and join once. Previously, list.extend
        # turned every sample into a Python float object. np.linspace with
        # num=0 is empty, so zero-length segments simply drop out.
        envelope = np.concatenate([
            np.linspace(0, 1, max(0, attack_samples)),
            np.linspace(1, sustain_level, max(0, decay_samples)),
            np.full(sustain_samples, sustain_level, dtype=float),
            np.linspace(sustain_level, 0, max(0, release_samples)),
        ])

        if len(envelope) < n_samples:
            envelope = np.pad(envelope, (0, n_samples - len(envelope)))
        elif len(envelope) > n_samples:
            envelope = envelope[:n_samples]
        return envelope


def synthesize_stage(model, features, active_harmonics, duration=0.5, device='cpu'):
    """Synthesize one stage with progressive harmonic activation.

    The stage is a steady note conditioned on clip-level averages of
    ``features``: the mean voiced f0, mean loudness, and mean MFCC vector.
    Every frame receives identical inputs.

    Args:
        model: Trained model exposing ``synthesizer``, ``harmonic_osc``,
            ``noise_gen`` and ``harmonic_noise_ratio``.
        features: Dict from ``extract_features``; uses 'f0', 'loudness', 'mfcc'.
        active_harmonics: Per-harmonic 0/1 mask, one entry per harmonic.
        duration: Stage length in seconds.
        device: Torch device for the synthesis tensors.

    Returns:
        1-D numpy array of ``int(duration * sample_rate)`` samples, shaped by
        an ADSR envelope and peak-normalized to 0.8 (unless silent).
    """
    # Follow the model's rate/hop when it exposes them; the fallbacks are the
    # values that used to be hard-coded here.
    sample_rate = getattr(model, 'sample_rate', 22050)
    hop_length = getattr(model.harmonic_osc, 'hop_length', 512)
    target_samples = int(duration * sample_rate)
    # Round UP. Floor division rendered fewer samples than the stage needs,
    # so the tail was zero-padded and the release ended with a click.
    n_frames = max(1, int(np.ceil(target_samples / hop_length)))

    voiced_f0 = features['f0'][features['f0'] > 0]
    f0_mean = float(voiced_f0.mean()) if voiced_f0.size else float('nan')
    if not np.isfinite(f0_mean):
        f0_mean = 220.0  # no usable pitch: fall back to A3

    f0_tensor = torch.full((1, n_frames, 1), f0_mean, device=device)
    loudness_tensor = torch.full((1, n_frames, 1), float(features['loudness'].mean()), device=device)
    mfcc_mean = torch.tensor(features['mfcc'].mean(axis=0), dtype=torch.float32, device=device)
    mfcc_tensor = mfcc_mean.unsqueeze(0).unsqueeze(0).expand(1, n_frames, -1)

    with torch.no_grad():
        harmonic_amplitudes, filter_magnitudes = model.synthesizer(f0_tensor, loudness_tensor, mfcc_tensor)

        # Keep only the active harmonics (mask broadcast over batch and frames).
        harmonic_mask = torch.tensor(active_harmonics, dtype=torch.float32, device=device).view(1, 1, -1)
        masked_harmonics = harmonic_amplitudes * harmonic_mask

        # Render and mix exactly as the model's forward pass does.
        harmonic_audio = model.harmonic_osc(f0_tensor.squeeze(-1), masked_harmonics)
        noise_audio = model.noise_gen(filter_magnitudes)
        ratio = torch.sigmoid(model.harmonic_noise_ratio)
        pred_audio = ratio * harmonic_audio + (1 - ratio) * noise_audio

    audio = pred_audio.squeeze().cpu().numpy()
    if len(audio) < target_samples:
        audio = np.pad(audio, (0, target_samples - len(audio)))
    else:
        audio = audio[:target_samples]

    # Apply ADSR envelope.
    audio = audio * ADSREnvelope(sr=sample_rate).generate(duration)

    # Peak-normalize to 0.8.
    peak = np.abs(audio).max()
    if peak > 0:
        audio = audio / peak * 0.8

    return audio


def generate_stages(strategy='linear', n_harmonics=64):
    """Generate stage masks.

    With ``'linear'``, stage n (1-based) is a float array of length
    ``n_harmonics`` whose first n entries are 1.0 and the rest 0.0, giving
    ``n_harmonics`` stages in total.

    Raises:
        ValueError: for an unknown ``strategy``. Previously this silently
            returned ``[]``, which crashed later when the stages were
            concatenated.
    """
    if strategy != 'linear':
        raise ValueError(f"Unknown stage strategy {strategy!r}; supported: 'linear'")
    # Row i of a lower-triangular ones matrix has ones in columns 0..i.
    # Copy each row so the returned masks are independent arrays.
    return [row.copy() for row in np.tril(np.ones((n_harmonics, n_harmonics)))]

print("✅ Growing stage functions ready")

In [None]:
# 6. Train and generate growing stages for EACH scale tone

N_EPOCHS = 1000
STAGE_DURATION = 0.5
SILENCE_DURATION = 0.2
SAMPLE_RATE = 22050  # rate used by extract_features() and the model defaults

os.makedirs('outputs', exist_ok=True)

print(f"\n🎯 Processing {len(audio_files)} scale tone files")
print(f"   Training: {N_EPOCHS} epochs per file")
print(f"   Expected: 10-15 minutes per file")
print(f"   🆕 USING IMPROVED PERCEPTUAL LOSS\n")

total_start = time.time()

for file_idx, audio_file in enumerate(audio_files):
    print(f"\n{'='*70}")
    print(f"File {file_idx + 1}/{len(audio_files)}: {Path(audio_file).name}")
    print(f"{'='*70}\n")

    # Extract features
    print("📊 Extracting features...")
    features = extract_features(audio_file)
    print(f"   F0: {features['f0'][features['f0']>0].mean():.1f} Hz")
    print(f"   Frames: {features['n_frames']}\n")

    # Add batch (and channel) dims: [1, frames, 1] / [1, frames, n_mfcc] / [1, samples].
    f0 = torch.tensor(features['f0'], dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
    loudness = torch.tensor(features['loudness'], dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
    mfcc = torch.tensor(features['mfcc'], dtype=torch.float32).unsqueeze(0).to(device)
    target_audio = torch.tensor(features['audio'], dtype=torch.float32).unsqueeze(0).to(device)

    # Create model
    print("🏗️  Training DDSP model...")
    model = DDSPModelVectorized().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # IMPROVED: Use MultiResolutionSTFTLoss instead of MultiScaleSpectralLoss
    loss_fn = MultiResolutionSTFTLoss().to(device)

    # Training
    losses = []
    best_loss = float('inf')
    file_start = time.time()

    for epoch in tqdm(range(N_EPOCHS), desc="Training"):
        model.train()
        optimizer.zero_grad()

        pred_audio, _, _ = model(f0, loudness, mfcc)

        # Rendered length is frames * hop and can differ from the file length.
        min_len = min(pred_audio.shape[1], target_audio.shape[1])
        pred_audio_trim = pred_audio[:, :min_len]
        target_audio_trim = target_audio[:, :min_len]

        # Perceptual STFT loss plus a time-domain L1 term (weight 0.5).
        stft_loss = loss_fn(pred_audio_trim, target_audio_trim)
        time_loss = F.l1_loss(pred_audio_trim, target_audio_trim)
        total_loss = 1.0 * stft_loss + 0.5 * time_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_value = total_loss.item()
        losses.append(loss_value)

        if loss_value < best_loss:
            best_loss = loss_value

        if epoch == 0:
            print(f"\n   First epoch: {time.time() - file_start:.1f}s (Loss: {loss_value:.6f})\n")

    train_time = time.time() - file_start
    print(f"\n✅ Training complete! Time: {train_time/60:.1f}min, Loss: {best_loss:.6f}\n")

    # Test reconstruction
    model.eval()
    with torch.no_grad():
        pred_audio, _, _ = model(f0, loudness, mfcc)
    pred_audio_np = pred_audio.squeeze().cpu().numpy()

    print("🎵 Reconstructed audio:")
    display(Audio(pred_audio_np, rate=SAMPLE_RATE))

    # Generate growing stages: one stage per harmonic the model actually has.
    print(f"\n🌱 Generating growing stages...")
    stages = generate_stages('linear', model.harmonic_osc.n_harmonics)
    audio_segments = []
    silence = np.zeros(int(SILENCE_DURATION * SAMPLE_RATE))

    for i, stage in enumerate(tqdm(stages, desc="Stages")):
        stage_audio = synthesize_stage(model, features, stage, STAGE_DURATION, device)
        audio_segments.append(stage_audio)
        if i < len(stages) - 1:
            audio_segments.append(silence)

    full_audio = np.concatenate(audio_segments)

    print(f"\n✅ Generated {len(stages)} growing stages")
    print(f"   Duration: {len(full_audio)/SAMPLE_RATE:.2f}s\n")

    print("🎧 Growing stages:")
    display(Audio(full_audio, rate=SAMPLE_RATE))

    # Save files
    file_stem = Path(audio_file).stem
    sf.write(f'outputs/{file_stem}_reconstructed.wav', pred_audio_np, SAMPLE_RATE)
    sf.write(f'outputs/{file_stem}_grown.wav', full_audio, SAMPLE_RATE)

    torch.save({
        'model_state_dict': model.state_dict(),
        'features': features,
        'losses': losses,
        'best_loss': best_loss,
    }, f'outputs/{file_stem}_model.pt')

    print(f"💾 Saved:")
    print(f"   - {file_stem}_reconstructed.wav")
    print(f"   - {file_stem}_grown.wav")
    print(f"   - {file_stem}_model.pt")

total_time = time.time() - total_start

print(f"\n\n{'='*70}")
print(f"🎉 ALL SCALE TONES COMPLETE!")
print(f"{'='*70}")
print(f"Total time: {total_time/60:.1f} minutes")
# Guard the average: with no input files it would raise ZeroDivisionError.
if audio_files:
    print(f"Average per file: {total_time/60/len(audio_files):.1f} minutes")
print(f"\nAll files saved in: outputs/")

In [None]:
# 7. Download all results

from google.colab import files
import zipfile

# Bundle every file in outputs/ into one flat archive (entries named by basename).
archive_path = 'ddsp_grown_scale_tones.zip'
with zipfile.ZipFile(archive_path, 'w') as archive:
    for output_path in glob.glob('outputs/*'):
        archive.write(output_path, arcname=Path(output_path).name)

print("📦 Downloading results...")
files.download(archive_path)

print("\n✅ Complete! All scale tones have been grown and downloaded.")