# DDSP Timbre Grower - Complete Implementation

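This notebook trains a lightweight DDSP (Differentiable Digital Signal Processing) model on a single target recording. It then generates discrete "growing" stages in which the model's harmonics are switched on one at a time, from the fundamental upward.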

**Runtime**: Enable GPU in Colab for 5-10x speedup!

**Workflow**:
1. Upload target audio (violin.wav)
2. Install dependencies
3. Define DDSP components
4. Train model (~5-10 min on GPU)
5. Generate discrete growing stages
6. Download results

## 1. Setup & Dependencies

In [None]:
# Install dependencies
!pip install torch librosa soundfile matplotlib scipy tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
from scipy import signal as scipy_signal
from tqdm import tqdm
from IPython.display import Audio, display

# Check for GPU: use CUDA whenever PyTorch can see a device, else fall back to CPU.
_use_cuda = torch.cuda.is_available()
device = 'cuda' if _use_cuda else 'cpu'
print(f"Using device: {device}")
if _use_cuda:
    # Report which GPU the runtime assigned (device index 0).
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Upload Target Audio

Upload your `violin.wav` file, or any other audio file librosa can read. The file is resampled to 22,050 Hz, and stereo input is downmixed to mono.

In [None]:
# For Google Colab: open the upload dialog; the result is a dict keyed by filename.
from google.colab import files
uploaded = files.upload()

# Fail with an actionable message instead of an IndexError when the dialog
# was cancelled or no file was chosen.
if not uploaded:
    raise RuntimeError("No file was uploaded; re-run this cell and choose an audio file.")

# Get the uploaded filename (the first one if several were selected).
audio_path = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"Multiple files uploaded; using the first: {audio_path}")
print(f"Uploaded: {audio_path}")

## 3. DDSP Core Components

In [None]:
class HarmonicOscillator(nn.Module):
    """Differentiable harmonic oscillator for additive synthesis.

    Renders a sum of sinusoids at integer multiples (1, 2, ..., n_harmonics)
    of a time-varying fundamental, each with its own time-varying amplitude.

    Args:
        sample_rate: Output sample rate in Hz.
        n_harmonics: Number of harmonics to render.
        hop_length: Audio samples per control frame; frame-rate inputs are
            linearly upsampled by this factor. Must match the hop used when
            extracting features (512 in this notebook).
    """

    def __init__(self, sample_rate=22050, n_harmonics=64, hop_length=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_harmonics = n_harmonics
        self.hop_length = hop_length

    def forward(self, f0_hz, harmonic_amplitudes):
        """Synthesize audio from frame-rate controls.

        Args:
            f0_hz: ``(batch, n_frames)`` fundamental frequency in Hz.
            harmonic_amplitudes: ``(batch, n_frames, >= n_harmonics)`` linear
                amplitudes; only the first ``n_harmonics`` channels are used.

        Returns:
            ``(batch, n_frames * hop_length)`` audio tensor.
        """
        n_frames = f0_hz.shape[1]
        n_samples = n_frames * self.hop_length

        # Upsample f0 and amplitudes from frame rate to audio rate.
        f0_upsampled = F.interpolate(
            f0_hz.unsqueeze(1), size=n_samples, mode='linear', align_corners=True
        ).squeeze(1)  # (B, N)
        amplitudes = F.interpolate(
            harmonic_amplitudes[..., :self.n_harmonics].transpose(1, 2), size=n_samples,
            mode='linear', align_corners=True
        ).transpose(1, 2)  # (B, N, H)

        harmonic_numbers = torch.arange(
            1, self.n_harmonics + 1, device=f0_hz.device, dtype=f0_upsampled.dtype
        )

        # Silence harmonics at or above Nyquist: they would otherwise alias back
        # into the audible band (e.g. 2 kHz * 64 harmonics far exceeds 11 kHz).
        harmonic_freqs = f0_upsampled.unsqueeze(-1) * harmonic_numbers  # (B, N, H)
        amplitudes = amplitudes * (harmonic_freqs < self.sample_rate / 2).to(amplitudes.dtype)

        # Integrate frequency to phase, measured in cycles. Accumulate in float64
        # and keep only the fractional cycle: a float32 running sum loses
        # precision on long clips, and the error is multiplied by the harmonic index.
        cycles = torch.cumsum(f0_upsampled.double() / self.sample_rate, dim=1)
        phase = 2 * np.pi * torch.remainder(cycles, 1.0).to(f0_upsampled.dtype)  # (B, N)

        # sin(2*pi*k*frac(c)) == sin(2*pi*k*c) for integer k, so wrapping is exact.
        harmonic_signals = torch.sin(phase.unsqueeze(-1) * harmonic_numbers)  # (B, N, H)
        return (harmonic_signals * amplitudes).sum(dim=-1)


class FilteredNoiseGenerator(nn.Module):
    """Differentiable, time-varying filtered noise generator.

    White noise is analysed with an STFT (Hann window, ``n_fft = 2 * hop_length``,
    hop ``hop_length``). Each STFT frame is multiplied by a magnitude response
    built from a bank of log-spaced Gaussian bands (20 Hz to Nyquist), and the
    result is resynthesised with an inverse STFT.

    Args:
        sample_rate: Output sample rate in Hz.
        n_filter_banks: Number of log-spaced filter bands.
        hop_length: Audio samples per control frame (512 in this notebook).
    """

    def __init__(self, sample_rate=22050, n_filter_banks=64, hop_length=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_filter_banks = n_filter_banks
        self.hop_length = hop_length

        self.register_buffer(
            'filter_freqs',
            torch.logspace(np.log10(20), np.log10(sample_rate / 2), n_filter_banks)
        )

    def _bank_weights(self, bin_freqs):
        """Return ``(n_bins, n_filter_banks)`` Gaussian weights, normalised per bin."""
        # Distance between each FFT bin and each band centre, in log-frequency.
        log_bins = torch.log(bin_freqs + 1e-7).unsqueeze(1)
        log_banks = torch.log(self.filter_freqs + 1e-7).unsqueeze(0)
        weights = torch.exp(-(log_bins - log_banks) ** 2 / 0.5)
        return weights / (weights.sum(dim=1, keepdim=True) + 1e-7)

    def forward(self, filter_magnitudes):
        """Render filtered noise.

        Args:
            filter_magnitudes: ``(batch, n_frames, n_filter_banks)`` per-frame
                band gains.

        Returns:
            ``(batch, n_frames * hop_length)`` noise whose spectral envelope
            follows ``filter_magnitudes`` frame by frame.
        """
        batch_size, n_frames, _ = filter_magnitudes.shape
        device = filter_magnitudes.device
        hop = self.hop_length
        n_fft = 2 * hop  # 50% overlap with a Hann window is invertible by istft
        n_samples = n_frames * hop

        window = torch.hann_window(n_fft, device=device)
        noise = torch.randn(batch_size, n_samples, device=device)
        # Constant padding (rather than reflect) keeps very short inputs valid.
        noise_stft = torch.stft(
            noise, n_fft=n_fft, hop_length=hop, window=window,
            center=True, pad_mode='constant', return_complex=True
        )  # (B, n_bins, T)

        bin_freqs = torch.fft.rfftfreq(n_fft, d=1.0 / self.sample_rate).to(device)
        weights = self._bank_weights(bin_freqs)  # (n_bins, n_banks)

        # Resample the control frames onto the STFT frame grid, then map band
        # gains to per-bin gains: (n_bins, K) @ (B, K, T) -> (B, n_bins, T).
        n_stft_frames = noise_stft.shape[-1]
        band_gains = F.interpolate(
            filter_magnitudes.transpose(1, 2), size=n_stft_frames,
            mode='linear', align_corners=True
        )
        response = torch.matmul(weights, band_gains)

        return torch.istft(
            noise_stft * response, n_fft=n_fft, hop_length=hop, window=window,
            center=True, length=n_samples
        )


class DDSPSynthesizer(nn.Module):
    """Neural network that maps frame-wise features to synthesis parameters.

    A 2-layer GRU reads the per-frame concatenation ``[f0, loudness, mfcc]``.
    Two MLP heads then produce non-negative harmonic amplitudes (Softplus) and
    noise filter-bank magnitudes in (0, 1) (Sigmoid).

    Args:
        n_harmonics: Harmonic amplitude channels per frame.
        n_filter_banks: Noise filter-bank channels per frame.
        hidden_size: GRU / MLP width.
        n_mfcc: Number of MFCC coefficients in the ``mfcc`` input.
    """

    def __init__(self, n_harmonics=64, n_filter_banks=64, hidden_size=512, n_mfcc=30):
        super().__init__()
        self.n_harmonics = n_harmonics
        self.n_filter_banks = n_filter_banks

        input_size = 1 + 1 + n_mfcc  # f0 + loudness + MFCCs

        self.gru = nn.GRU(
            input_size=input_size, hidden_size=hidden_size,
            num_layers=2, batch_first=True, dropout=0.1
        )

        self.harmonic_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_harmonics),
            nn.Softplus()
        )

        self.noise_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_filter_banks),
            nn.Sigmoid()
        )

    def forward(self, f0, loudness, mfcc):
        """Predict synthesis parameters for every frame.

        Args:
            f0: ``(batch, frames, 1)`` fundamental frequency in Hz.
            loudness: ``(batch, frames, 1)`` loudness in dB (20 * log10 amplitude).
            mfcc: ``(batch, frames, n_mfcc)`` MFCCs.

        Returns:
            ``(harmonic_amplitudes, filter_magnitudes)`` with shapes
            ``(batch, frames, n_harmonics)`` and ``(batch, frames, n_filter_banks)``.
        """
        # NOTE(review): features enter the GRU unscaled (Hz, dB, raw MFCCs);
        # normalising them would likely make training easier — consider.
        x = torch.cat([f0, loudness, mfcc], dim=-1)
        x, _ = self.gru(x)

        harmonic_amplitudes = self.harmonic_head(x)
        filter_magnitudes = self.noise_head(x)

        # Invert the dB conversion: amplitude = 10 ** (dB / 20).
        # (exp(dB / 20) is not the inverse of 20 * log10 and mis-scales quiet frames.)
        loudness_scale = torch.pow(10.0, loudness / 20.0)
        harmonic_amplitudes = harmonic_amplitudes * loudness_scale

        return harmonic_amplitudes, filter_magnitudes


class DDSPModel(nn.Module):
    """End-to-end DDSP model: control network plus harmonic and noise synthesizers.

    ``forward`` returns ``(mix, harmonic, noise)``. ``mix`` blends the two
    synthesized signals with the learned weight ``sigmoid(harmonic_noise_ratio)``.
    """

    def __init__(self, sample_rate=22050, n_harmonics=64, n_filter_banks=64, hidden_size=512):
        super().__init__()
        self.sample_rate = sample_rate

        # Control network: (f0, loudness, mfcc) -> per-frame synthesis parameters.
        self.synthesizer = DDSPSynthesizer(n_harmonics, n_filter_banks, hidden_size)
        # Signal generators driven by those parameters.
        self.harmonic_osc = HarmonicOscillator(sample_rate, n_harmonics)
        self.noise_gen = FilteredNoiseGenerator(sample_rate, n_filter_banks)

        # Learned logit for the harmonic/noise balance; sigmoid(0.8) ~= 0.69 at init.
        self.harmonic_noise_ratio = nn.Parameter(torch.tensor(0.8))

    def forward(self, f0, loudness, mfcc):
        amps, noise_mags = self.synthesizer(f0, loudness, mfcc)

        # The oscillator wants f0 as (batch, frames), without the feature axis.
        harmonic = self.harmonic_osc(f0.squeeze(-1), amps)
        noise = self.noise_gen(noise_mags)

        weight = torch.sigmoid(self.harmonic_noise_ratio)
        mix = weight * harmonic + (1 - weight) * noise
        return mix, harmonic, noise


class MultiScaleSpectralLoss(nn.Module):
    """Multi-scale spectral loss: mean L1 distance between log-magnitude STFTs.

    Args:
        fft_sizes: FFT sizes to compare at. Each uses a Hann window and a hop of
            ``fft_size // 4``. The sizes are copied into a per-instance list, so
            instances never share (or mutate) the default.

    Raises:
        ValueError: If ``fft_sizes`` is empty (the average would divide by zero).
    """

    def __init__(self, fft_sizes=(2048, 1024, 512, 256)):
        super().__init__()
        self.fft_sizes = list(fft_sizes)
        if not self.fft_sizes:
            raise ValueError("fft_sizes must contain at least one FFT size")

    def forward(self, pred_audio, target_audio):
        """Return the loss averaged over scales for waveforms like ``(batch, samples)``."""
        total_loss = 0.0

        for fft_size in self.fft_sizes:
            # One window per scale, shared by both signals.
            window = torch.hann_window(fft_size, device=pred_audio.device)
            stft_kwargs = dict(
                n_fft=fft_size, hop_length=fft_size // 4,
                window=window, return_complex=True,
            )
            # The 1e-5 floor keeps log() finite on silent bins.
            pred_log_mag = torch.log(torch.stft(pred_audio, **stft_kwargs).abs() + 1e-5)
            target_log_mag = torch.log(torch.stft(target_audio, **stft_kwargs).abs() + 1e-5)

            total_loss += F.l1_loss(pred_log_mag, target_log_mag)

        return total_loss / len(self.fft_sizes)


# Confirmation that every class definition in this cell executed without error.
print("✅ DDSP components defined", end="\n")

## 4. Feature Extraction

In [None]:
def extract_features(audio_path, sample_rate=22050, hop_length=512):
    """Extract f0, loudness, and MFCCs from audio.

    Args:
        audio_path: Path to any file ``librosa.load`` can read. It is resampled
            to ``sample_rate`` and downmixed to mono.
        sample_rate: Target sample rate in Hz.
        hop_length: Samples between feature frames. It must equal the hop the
            synthesizer renders per frame (512 in this notebook's oscillator and
            noise modules). 512 is ~43 frames/s at 22,050 Hz. It is now fixed
            rather than derived from ``sample_rate``, so the features stay aligned
            with the synthesizer at other sample rates too.

    Returns:
        dict with keys:
            ``'f0'``: ``(n_frames,)`` YIN f0 estimate in Hz (NaN/negative -> 0).
            ``'loudness'``: ``(n_frames,)`` frame RMS in dB (``ref=1.0``).
            ``'mfcc'``: ``(n_frames, 30)`` MFCCs.
            ``'audio'``: the loaded mono waveform (full length, untrimmed).
            ``'n_frames'``: common frame count after trimming all features.
    """
    audio, sr = librosa.load(audio_path, sr=sample_rate, mono=True)

    # F0 (search range C2..C7)
    f0_yin = librosa.yin(
        audio, fmin=librosa.note_to_hz('C2'),
        fmax=librosa.note_to_hz('C7'), sr=sr, hop_length=hop_length
    )
    f0_yin = np.nan_to_num(f0_yin, nan=0.0)
    f0_yin = np.maximum(f0_yin, 0.0)

    # Loudness: frame RMS converted to dB.
    loudness = librosa.feature.rms(
        y=audio, frame_length=2048, hop_length=hop_length
    )[0]
    loudness_db = librosa.amplitude_to_db(loudness, ref=1.0)

    # MFCCs, transposed to (frames, coefficients).
    mfcc = librosa.feature.mfcc(
        y=audio, sr=sr, n_mfcc=30, hop_length=hop_length
    ).T

    # Guard against off-by-one frame counts between the three extractors.
    min_len = min(len(f0_yin), len(loudness_db), len(mfcc))

    return {
        'f0': f0_yin[:min_len],
        'loudness': loudness_db[:min_len],
        'mfcc': mfcc[:min_len],
        'audio': audio,
        'n_frames': min_len
    }


# Extract features from the uploaded clip and print a short summary.
print("📊 Extracting features...")
features = extract_features(audio_path)
for line in (
    f"   F0 range: {features['f0'].min():.1f} - {features['f0'].max():.1f} Hz",
    f"   Frames: {features['n_frames']}",
    f"   Duration: {len(features['audio']) / 22050:.2f}s",
):
    print(line)

# Listen to original (as loaded: resampled and mono)
print("\n🎵 Original audio:")
display(Audio(features['audio'], rate=22050))

## 5. Train DDSP Model

This will take ~5-10 minutes on GPU, ~30 minutes on CPU.

In [None]:
# Prepare tensors: add a leading batch dimension of 1. Scalar-per-frame
# features (f0, loudness) also get a trailing feature axis, so every control
# input is shaped (batch, frames, features).
f0 = torch.tensor(features['f0'], dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
loudness = torch.tensor(features['loudness'], dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
mfcc = torch.tensor(features['mfcc'], dtype=torch.float32).unsqueeze(0).to(device)
target_audio = torch.tensor(features['audio'], dtype=torch.float32).unsqueeze(0).to(device)

# Create model
print("🏗️  Building model...")
model = DDSPModel(sample_rate=22050, n_harmonics=64, n_filter_banks=64, hidden_size=512).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"   Parameters: {n_params:,}")

# Optimizer and loss
optimizer = optim.Adam(model.parameters(), lr=1e-3)
spectral_loss_fn = MultiScaleSpectralLoss().to(device)

# Training loop: each "epoch" is one step on the whole clip.
N_EPOCHS = 1000
losses = []
best_loss = float('inf')
best_state = None  # snapshot of the weights that produced best_loss

print(f"\n🎯 Training for {N_EPOCHS} epochs...")
for epoch in tqdm(range(N_EPOCHS)):
    model.train()
    optimizer.zero_grad()

    pred_audio, _, _ = model(f0, loudness, mfcc)

    # The synthesizer renders n_frames * hop samples, which can differ from the
    # clip length; compare only the overlapping region.
    min_len = min(pred_audio.shape[1], target_audio.shape[1])
    pred_audio_trim = pred_audio[:, :min_len]
    target_audio_trim = target_audio[:, :min_len]

    spec_loss = spectral_loss_fn(pred_audio_trim, target_audio_trim)
    time_loss = F.l1_loss(pred_audio_trim, target_audio_trim)
    total_loss = 1.0 * spec_loss + 0.1 * time_loss

    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    loss_value = total_loss.item()
    losses.append(loss_value)

    # total_loss was computed with the pre-step weights, so snapshot them
    # before optimizer.step() changes them.
    if loss_value < best_loss:
        best_loss = loss_value
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    optimizer.step()

    if (epoch + 1) % 100 == 0:
        # tqdm.write keeps the progress bar intact, unlike print.
        tqdm.write(f"\nEpoch {epoch+1}: Loss={loss_value:.6f}, Best={best_loss:.6f}")

# Leave the model holding the weights that achieved the reported best loss.
if best_state is not None:
    model.load_state_dict(best_state)

print(f"\n✅ Training complete! Best loss: {best_loss:.6f}")

# Plot training curve
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('DDSP Training Loss')
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.show()

## 6. Test Reconstruction

In [None]:
# Render the full clip once, with dropout disabled and without gradients.
model.eval()
with torch.no_grad():
    pred_audio, harmonic_audio, noise_audio = model(f0, loudness, mfcc)

# Move each rendered signal to host memory as a 1-D numpy array.
pred_audio_np, harmonic_audio_np, noise_audio_np = (
    t.squeeze().cpu().numpy() for t in (pred_audio, harmonic_audio, noise_audio)
)

# Play the mix, then each component on its own.
for heading, clip in (
    ("🎵 Reconstructed audio:", pred_audio_np),
    ("\n🎵 Harmonic component only:", harmonic_audio_np),
    ("\n🎵 Noise component only:", noise_audio_np),
):
    print(heading)
    display(Audio(clip, rate=22050))

## 7. Generate Discrete Growing Stages

In [None]:
class ADSREnvelope:
    """Piecewise-linear attack/decay/sustain/release envelope generator.

    Args:
        sr: Sample rate in Hz. Segment durations (seconds) are converted to
            whole samples with ``int(seconds * sr)``.
    """

    def __init__(self, sr=22050):
        self.sr = sr

    def generate(self, duration, attack=0.05, decay=0.1, sustain_level=0.7, release=0.2):
        """Return an envelope of exactly ``int(duration * sr)`` samples.

        Layout:
            - attack ramps 0 -> 1;
            - decay ramps 1 -> ``sustain_level``;
            - sustain holds ``sustain_level`` for whatever time remains;
            - release ramps ``sustain_level`` -> 0.

        When attack + decay + release exceed ``duration`` there is no sustain
        and the envelope is cut off at the end, so the release may not reach 0.
        The result is zero-padded or truncated to the exact length.
        """
        def n_of(seconds):
            return int(seconds * self.sr)

        total = n_of(duration)
        n_attack = n_of(attack)
        n_decay = n_of(decay)
        n_release = n_of(release)
        n_sustain = max(0, total - n_attack - n_decay - n_release)

        # Seed with an empty float64 array so concatenate works even when every
        # segment is empty.
        pieces = [np.zeros(0)]
        if n_attack > 0:
            pieces.append(np.linspace(0, 1, n_attack))
        if n_decay > 0:
            pieces.append(np.linspace(1, sustain_level, n_decay))
        if n_sustain > 0:
            pieces.append(np.ones(n_sustain) * sustain_level)
        if n_release > 0:
            pieces.append(np.linspace(sustain_level, 0, n_release))
        envelope = np.concatenate(pieces)

        missing = total - len(envelope)
        if missing > 0:
            return np.pad(envelope, (0, missing))
        if missing < 0:
            return envelope[:total]
        return envelope


def synthesize_stage(model, features, active_harmonics, duration=0.5, device='cpu'):
    """Synthesize one stage: a steady note with only the masked-in harmonics.

    The controls are held constant at clip-level averages of ``features``:
    mean voiced f0, mean loudness, and the mean MFCC vector. Stages therefore
    differ only through ``active_harmonics``.

    Args:
        model: Trained DDSP model. Uses ``synthesizer``, ``harmonic_osc``,
            ``noise_gen`` and ``harmonic_noise_ratio``. NOTE(review): this
            function does not call ``model.eval()``; it assumes the caller did
            (the reconstruction cell does).
        features: Dict from ``extract_features`` (reads 'f0', 'loudness', 'mfcc').
        active_harmonics: Length-n_harmonics mask. 1.0 = on, 0.0 = off;
            intermediate values scale the harmonic.
        duration: Stage length in seconds.
        device: Device for the control tensors; should match the model's.

    Returns:
        numpy array of ``int(duration * sample_rate)`` samples, shaped by an
        ADSR envelope and peak-normalized to 0.8 (unless silent).
    """
    sample_rate = getattr(model, 'sample_rate', 22050)
    # Samples rendered per control frame (512 in this notebook's oscillator).
    hop_length = getattr(model.harmonic_osc, 'hop_length', 512)
    target_samples = int(duration * sample_rate)
    # Round the frame count up (and to at least one frame) so the render covers
    # the whole stage; the excess is trimmed below. Truncating instead left a
    # zero-padded tail that cut the release short, and crashed on 0 frames.
    n_frames = max(1, (target_samples + hop_length - 1) // hop_length)

    f0_track = features['f0']
    voiced = f0_track[f0_track > 0]
    # Fall back to 220 Hz when the clip has no voiced frames.
    f0_mean = float(voiced.mean()) if voiced.size else 220.0

    f0_tensor = torch.full((1, n_frames, 1), f0_mean, device=device)
    loudness_tensor = torch.full((1, n_frames, 1), float(features['loudness'].mean()), device=device)
    mfcc_mean = torch.tensor(features['mfcc'].mean(axis=0), dtype=torch.float32, device=device)
    mfcc_tensor = mfcc_mean.unsqueeze(0).unsqueeze(0).expand(1, n_frames, -1)

    with torch.no_grad():
        # Get synthesis parameters from model
        harmonic_amplitudes, filter_magnitudes = model.synthesizer(f0_tensor, loudness_tensor, mfcc_tensor)

        # Apply harmonic mask to progressively add harmonics
        harmonic_mask = torch.tensor(active_harmonics, dtype=torch.float32, device=device)
        harmonic_mask = harmonic_mask.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, n_harmonics]
        masked_harmonics = harmonic_amplitudes * harmonic_mask

        # Generate audio with masked harmonics (noise is left unmasked).
        f0_hz = f0_tensor.squeeze(-1)
        harmonic_audio = model.harmonic_osc(f0_hz, masked_harmonics)
        noise_audio = model.noise_gen(filter_magnitudes)

        ratio = torch.sigmoid(model.harmonic_noise_ratio)
        pred_audio = ratio * harmonic_audio + (1 - ratio) * noise_audio

    # Index the batch instead of squeeze() so a 1-sample render stays 1-D.
    audio = pred_audio[0].cpu().numpy()
    if len(audio) < target_samples:
        audio = np.pad(audio, (0, target_samples - len(audio)))
    else:
        audio = audio[:target_samples]

    # Apply ADSR envelope
    envelope = ADSREnvelope(sr=sample_rate).generate(duration)
    audio = audio * envelope

    # Normalize to a 0.8 peak; skip silent or empty stages.
    peak = np.abs(audio).max() if audio.size else 0.0
    if peak > 0:
        audio = audio / peak * 0.8

    return audio


def generate_stages(strategy='linear', n_harmonics=64):
    """Generate stage masks.

    Args:
        strategy: Growth schedule. Only ``'linear'`` is implemented: stage ``k``
            (1-based) enables the lowest ``k`` harmonics, giving ``n_harmonics``
            stages.
        n_harmonics: Length of each mask; should match the model's harmonic count.

    Returns:
        List of float64 arrays of shape ``(n_harmonics,)`` with 1.0 for active
        harmonics and 0.0 otherwise.

    Raises:
        ValueError: If ``strategy`` is not recognized. An unknown strategy would
            otherwise yield no stages and fail later with an unrelated error.
    """
    if strategy != 'linear':
        raise ValueError(f"Unknown stage strategy {strategy!r}; expected 'linear'")
    positions = np.arange(n_harmonics)
    return [(positions < n_active).astype(np.float64) for n_active in range(1, n_harmonics + 1)]


# Render one clip per stage, with a short gap of silence between consecutive stages.
print("🎵 Generating discrete growing stages...")
STAGE_DURATION = 0.5
SILENCE_DURATION = 0.2

stages = generate_stages('linear', 64)
silence = np.zeros(int(SILENCE_DURATION * 22050))
audio_segments = []

for stage_idx, stage_mask in enumerate(tqdm(stages)):
    if stage_idx > 0:
        audio_segments.append(silence)  # gap precedes every stage except the first
    audio_segments.append(synthesize_stage(model, features, stage_mask, STAGE_DURATION, device))

full_audio = np.concatenate(audio_segments)

print(f"\n✅ Generated {len(stages)} stages")
print(f"   Total duration: {len(full_audio)/22050:.2f}s")

# Listen
print("\n🎧 Discrete growing stages with DDSP:")
display(Audio(full_audio, rate=22050))

## 8. Download Results

In [None]:
# Save files.
# The reconstruction is raw model output and is never normalized upstream, so
# scale it down (never up) if its peak would exceed full scale. soundfile's
# default WAV subtype is 16-bit PCM, which cannot represent |x| > 1.
# The grown-timbre stages are already peak-normalized to 0.8 per stage.
recon_peak = float(np.abs(pred_audio_np).max()) if pred_audio_np.size else 0.0
recon_to_save = pred_audio_np * (0.99 / recon_peak) if recon_peak > 1.0 else pred_audio_np
sf.write('ddsp_reconstructed.wav', recon_to_save, 22050)
sf.write('ddsp_grown_timbre.wav', full_audio, 22050)

# Save model weights together with the extracted features.
# NOTE(review): `features` holds numpy arrays; torch versions whose torch.load
# defaults to weights_only=True may refuse them. Confirm, and pass
# weights_only=False (trusted files only) when reloading.
torch.save({
    'model_state_dict': model.state_dict(),
    'features': features,
}, 'ddsp_model.pt')

# Download in Colab (re-imported so this cell also works on its own).
from google.colab import files
files.download('ddsp_reconstructed.wav')
files.download('ddsp_grown_timbre.wav')
files.download('ddsp_model.pt')

print("✅ Files ready for download!")