# DDSP Timbre Grower - Quick Fix Version

**Fixes Applied**:
- ✅ Removed torch.compile (causes hanging)
- ✅ Fixed autocast() device_type parameter
- ✅ Simplified to core optimizations only
- ✅ Added early progress indicators

**Expected Performance**: ~60-90 minutes for 1000 epochs with multiple files

## 1. Setup & Dependencies

In [None]:
# Install dependencies
!pip install torch librosa soundfile matplotlib scipy tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
from scipy import signal as scipy_signal
from tqdm import tqdm
from IPython.display import Audio, display
import glob
from pathlib import Path
import time
import os

# Pick the AMP (automatic mixed precision) API by capability instead of by
# version string: not every 2.x release exports GradScaler from torch.amp, so
# a plain `startswith('2.')` check can select an import that does not exist.
print(f"PyTorch version: {torch.__version__}")

try:
    from torch.amp import autocast, GradScaler
    # The torch.amp API takes the device type explicitly.
    DEVICE_TYPE = 'cuda'
    print("Using torch.amp AMP API")
except ImportError:
    from torch.cuda.amp import autocast, GradScaler
    DEVICE_TYPE = None  # Legacy API doesn't take device_type
    print("Using legacy torch.cuda.amp AMP API")

# Check for GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\nUsing device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Upload Target Audio Files

In [None]:
# Google Colab only: open the browser upload dialog (accepts several files).
from google.colab import files

print("üìÅ Upload your scale tone audio files:")
uploaded = files.upload()

# Collect everything that was uploaded under one directory.
upload_dir = 'scale_tones'
os.makedirs(upload_dir, exist_ok=True)
for filename in uploaded:
    os.rename(filename, f'{upload_dir}/{filename}')

# Only .wav files are picked up for training; list them in sorted order.
audio_files = sorted(glob.glob(f'{upload_dir}/*.wav'))
print(f"\n‚úÖ Uploaded {len(audio_files)} files:")
for wav_path in audio_files:
    print(f"   - {Path(wav_path).name}")

## 3. DDSP Core Components

In [None]:
class HarmonicOscillator(nn.Module):
    """Differentiable harmonic oscillator for additive synthesis.

    Sums ``n_harmonics`` sinusoids at integer multiples of a frame-rate
    fundamental frequency, each weighted by its own frame-rate amplitude.

    Args:
        sample_rate: Audio sample rate in Hz.
        n_harmonics: Number of harmonics (fundamental included) to sum.
        hop_length: Audio samples generated per control frame.
    """

    def __init__(self, sample_rate=22050, n_harmonics=64, hop_length=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_harmonics = n_harmonics
        self.hop_length = hop_length

    def forward(self, f0_hz, harmonic_amplitudes):
        """Render audio from frame-rate controls.

        Args:
            f0_hz: Fundamental frequency in Hz, shape (batch, n_frames).
            harmonic_amplitudes: Per-harmonic amplitudes, shape
                (batch, n_frames, n_harmonics).

        Returns:
            Audio tensor of shape (batch, n_frames * hop_length).
        """
        batch_size, n_frames = f0_hz.shape
        n_samples = n_frames * self.hop_length

        harmonic_numbers = torch.arange(
            1, self.n_harmonics + 1, device=f0_hz.device, dtype=f0_hz.dtype
        )

        # Silence any harmonic whose frequency reaches Nyquist; a sinusoid
        # there would alias back into the audible band. Done at frame rate so
        # the mask is tiny and the upsampling below fades it in and out.
        harmonic_freqs = f0_hz.unsqueeze(-1) * harmonic_numbers
        below_nyquist = harmonic_freqs < (self.sample_rate / 2.0)
        harmonic_amplitudes = (
            harmonic_amplitudes[..., :self.n_harmonics]
            * below_nyquist.to(harmonic_amplitudes.dtype)
        )

        # Upsample both controls from frame rate to audio rate.
        f0_upsampled = F.interpolate(
            f0_hz.unsqueeze(1), size=n_samples, mode='linear', align_corners=True
        ).squeeze(1)

        harmonic_amplitudes_upsampled = F.interpolate(
            harmonic_amplitudes.transpose(1, 2), size=n_samples,
            mode='linear', align_corners=True
        ).transpose(1, 2)

        # Instantaneous phase of the fundamental: running sum of the
        # per-sample frequency expressed in cycles, converted to radians.
        phase = 2 * torch.pi * torch.cumsum(f0_upsampled / self.sample_rate, dim=1)

        # All harmonics at once: (batch, n_samples, n_harmonics), then mix down.
        harmonic_phases = phase.unsqueeze(-1) * harmonic_numbers
        audio = (torch.sin(harmonic_phases) * harmonic_amplitudes_upsampled).sum(dim=-1)

        return audio


class FilteredNoiseGenerator(nn.Module):
    """Differentiable filtered noise generator.

    White noise is shaped in the frequency domain by a filter built from
    ``n_filter_banks`` log-spaced bands between 20 Hz and Nyquist.

    Args:
        sample_rate: Audio sample rate in Hz.
        n_filter_banks: Number of log-spaced filter bands.
        hop_length: Audio samples generated per control frame.
    """

    def __init__(self, sample_rate=22050, n_filter_banks=64, hop_length=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_filter_banks = n_filter_banks
        self.hop_length = hop_length

        # logspace is given plain Python floats for its endpoints (taken from
        # the same float32 log10 values) rather than 0-dim tensors.
        log_low = torch.log10(torch.tensor(20.0)).item()
        log_high = torch.log10(torch.tensor(sample_rate / 2.0)).item()
        self.register_buffer(
            'filter_freqs',
            torch.logspace(log_low, log_high, n_filter_banks)
        )

    def forward(self, filter_magnitudes):
        """Generate filtered noise.

        Args:
            filter_magnitudes: Band magnitudes, shape
                (batch, n_frames, n_filter_banks).

        Returns:
            Noise tensor of shape (batch, n_frames * hop_length).

        Note:
            The band magnitudes are averaged over time before filtering, so
            one static filter is applied to the whole clip.
        """
        batch_size, n_frames, _ = filter_magnitudes.shape
        n_samples = n_frames * self.hop_length
        device = filter_magnitudes.device

        noise = torch.randn(batch_size, n_samples, device=device)
        noise_fft = torch.fft.rfft(noise, dim=1)
        freqs = torch.fft.rfftfreq(n_samples, 1/self.sample_rate).to(device)

        # Time-average of the audio-rate (linearly upsampled) magnitudes.
        # It does not depend on the frequency bin, so compute it exactly once.
        filter_magnitudes_upsampled = F.interpolate(
            filter_magnitudes.transpose(1, 2), size=n_samples,
            mode='linear', align_corners=True
        )
        mean_magnitudes = filter_magnitudes_upsampled.mean(dim=2).float()

        # Gaussian weights in log-frequency between every rFFT bin and every
        # band centre, shape (n_bins, n_filter_banks), normalised per bin.
        # The DC bin is so far from every band that its weights underflow to
        # zero, which removes DC from the output.
        log_bands = torch.log(self.filter_freqs + 1e-7)
        log_bins = torch.log(freqs + 1e-7)
        distances = torch.abs(log_bands.unsqueeze(0) - log_bins.unsqueeze(1))
        weights = torch.exp(-distances**2 / 0.5)
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-7)

        # One matmul replaces the per-bin Python loop: (batch, n_bins).
        filter_response = (mean_magnitudes @ weights.t()).to(noise.dtype)

        filtered_fft = noise_fft * filter_response
        filtered_noise = torch.fft.irfft(filtered_fft, n=n_samples, dim=1)

        return filtered_noise


class DDSPSynthesizer(nn.Module):
    """Neural network that maps features to synthesis parameters.

    A 2-layer GRU reads the per-frame features and two small MLP heads turn
    its hidden state into harmonic amplitudes and noise-filter magnitudes.

    Args:
        n_harmonics: Size of the harmonic-amplitude output.
        n_filter_banks: Size of the noise-filter output.
        hidden_size: GRU and MLP hidden width.
        n_mfcc: Number of MFCC coefficients in the input features.
    """

    def __init__(self, n_harmonics=64, n_filter_banks=64, hidden_size=512, n_mfcc=30):
        super().__init__()
        self.n_harmonics = n_harmonics
        self.n_filter_banks = n_filter_banks

        # One f0 value + one loudness value + the MFCC vector per frame.
        input_size = 1 + 1 + n_mfcc

        self.gru = nn.GRU(
            input_size=input_size, hidden_size=hidden_size,
            num_layers=2, batch_first=True, dropout=0.1
        )

        # Softplus keeps harmonic amplitudes non-negative.
        self.harmonic_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_harmonics),
            nn.Softplus()
        )

        # Sigmoid bounds the filter magnitudes to (0, 1).
        self.noise_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_filter_banks),
            nn.Sigmoid()
        )

    def forward(self, f0, loudness, mfcc):
        """Predict synthesis parameters for every frame.

        Args:
            f0: Fundamental frequency, shape (batch, n_frames, 1).
            loudness: Loudness in dB, shape (batch, n_frames, 1).
            mfcc: MFCC features, shape (batch, n_frames, n_mfcc).

        Returns:
            Tuple ``(harmonic_amplitudes, filter_magnitudes)`` with shapes
            (batch, n_frames, n_harmonics) and (batch, n_frames, n_filter_banks).
        """
        x = torch.cat([f0, loudness, mfcc], dim=-1)
        x, _ = self.gru(x)

        harmonic_amplitudes = self.harmonic_head(x)
        filter_magnitudes = self.noise_head(x)

        # Decibels map to linear amplitude as 10 ** (dB / 20); the natural
        # exponential used previously under-weights level differences
        # (-20 dB must scale by 0.1, not by e**-1 ~= 0.37).
        loudness_scale = torch.pow(10.0, loudness / 20.0)
        harmonic_amplitudes = harmonic_amplitudes * loudness_scale

        return harmonic_amplitudes, filter_magnitudes


class DDSPModel(nn.Module):
    """End-to-end DDSP model: feature network plus the two synthesizers.

    The synthesizer network predicts control signals, the harmonic oscillator
    and the filtered-noise generator render them, and a single learnable
    scalar blends the two audio streams.
    """

    def __init__(self, sample_rate=22050, n_harmonics=64, n_filter_banks=64, hidden_size=512):
        super().__init__()
        self.sample_rate = sample_rate

        self.synthesizer = DDSPSynthesizer(n_harmonics, n_filter_banks, hidden_size)
        self.harmonic_osc = HarmonicOscillator(sample_rate, n_harmonics)
        self.noise_gen = FilteredNoiseGenerator(sample_rate, n_filter_banks)

        # Raw (pre-sigmoid) mixing weight. forward() squashes it, so the
        # effective initial harmonic share is sigmoid(0.8), not 0.8 itself.
        self.harmonic_noise_ratio = nn.Parameter(torch.tensor(0.8))

    def forward(self, f0, loudness, mfcc):
        """Synthesize audio from per-frame features.

        Args:
            f0: Fundamental frequency with a trailing singleton dimension.
            loudness: Loudness feature, forwarded to the synthesizer network.
            mfcc: MFCC features, forwarded to the synthesizer network.

        Returns:
            Tuple ``(audio, harmonic_audio, noise_audio)``: the blended
            signal followed by its two unmixed components.
        """
        amplitudes, noise_filter = self.synthesizer(f0, loudness, mfcc)

        # The oscillator takes f0 without the trailing feature dimension.
        harmonic_audio = self.harmonic_osc(f0.squeeze(-1), amplitudes)
        noise_audio = self.noise_gen(noise_filter)

        mix = torch.sigmoid(self.harmonic_noise_ratio)
        audio = mix * harmonic_audio + (1 - mix) * noise_audio

        return audio, harmonic_audio, noise_audio


class OptimizedMultiScaleSpectralLoss(nn.Module):
    """Multi-scale spectral loss with cached target computation.

    The loss is the mean, over several FFT sizes, of the L1 distance between
    log-magnitude STFTs of the prediction and of a fixed target. The target
    STFTs are computed once by :meth:`precompute_target` and reused.

    Args:
        fft_sizes: FFT sizes to evaluate; each uses a hop of ``fft_size // 4``.

    Raises:
        ValueError: If ``fft_sizes`` is empty.
    """

    def __init__(self, fft_sizes=(2048, 1024, 512, 256)):
        super().__init__()
        # Copy into a fresh list so instances never share (or mutate) the
        # default argument.
        self.fft_sizes = list(fft_sizes)
        if not self.fft_sizes:
            raise ValueError("fft_sizes must contain at least one FFT size")
        self.target_stfts = {}

    @staticmethod
    def _log_magnitude(audio, fft_size, device):
        """Return the log-magnitude STFT of `audio` for one FFT size."""
        spectrum = torch.stft(
            audio, n_fft=fft_size, hop_length=fft_size // 4,
            window=torch.hann_window(fft_size, device=device),
            return_complex=True
        )
        return torch.log(torch.abs(spectrum) + 1e-5)

    def precompute_target(self, target_audio, device):
        """Precompute target STFTs once before training.

        Args:
            target_audio: Target waveform(s) passed to ``torch.stft``.
            device: Device on which the analysis windows are created; must be
                the device `target_audio` lives on.
        """
        print("üìä Precomputing target STFTs...")
        with torch.no_grad():
            self.target_stfts = {
                fft_size: self._log_magnitude(target_audio, fft_size, device)
                for fft_size in self.fft_sizes
            }
        print("‚úÖ Target STFTs cached")

    def forward(self, pred_audio):
        """Compute loss using cached target STFTs.

        Args:
            pred_audio: Predicted waveform(s), same batch layout as the target.

        Returns:
            Scalar tensor: mean log-magnitude L1 distance over all FFT sizes.

        Raises:
            RuntimeError: If :meth:`precompute_target` has not been called.
        """
        missing = [size for size in self.fft_sizes if size not in self.target_stfts]
        if missing:
            raise RuntimeError(
                f"No cached target STFT for fft sizes {missing}; "
                "call precompute_target() before computing the loss"
            )

        total_loss = 0.0
        device = pred_audio.device

        for fft_size in self.fft_sizes:
            pred_log_mag = self._log_magnitude(pred_audio, fft_size, device)
            target_log_mag = self.target_stfts[fft_size]

            # Prediction and target may differ by a few samples in length;
            # compare only the STFT frames both of them have.
            n_frames = min(pred_log_mag.shape[-1], target_log_mag.shape[-1])
            total_loss = total_loss + F.l1_loss(
                pred_log_mag[..., :n_frames], target_log_mag[..., :n_frames]
            )

        return total_loss / len(self.fft_sizes)


# Confirmation message emitted after the component definitions above have run.
print("‚úÖ DDSP components defined")

## 4. Multi-File Feature Extraction

In [None]:
def extract_features(audio_path, sample_rate=22050, hop_length=512):
    """Extract f0, loudness, and MFCCs from audio.

    Args:
        audio_path: Path of the audio file to load.
        sample_rate: Rate the audio is resampled to on load.
        hop_length: Samples between consecutive feature frames. Defaults to
            512, the value the previous ``int(sample_rate / 43.066)`` formula
            produced at 22050 Hz; it is now independent of ``sample_rate`` so
            it always equals the 512 samples per frame the synthesizers in
            this file generate.

    Returns:
        Dict with ``'f0'`` (Hz, per frame), ``'loudness'`` (dB, per frame),
        ``'mfcc'`` (frames x 30), ``'audio'`` (the loaded mono waveform) and
        ``'n_frames'`` (common frame count of the three feature tracks).
    """
    audio, sr = librosa.load(audio_path, sr=sample_rate, mono=True)

    # Pitch track restricted to C2..C7; sanitise any NaN or negative values.
    f0_yin = librosa.yin(
        audio, fmin=librosa.note_to_hz('C2'),
        fmax=librosa.note_to_hz('C7'), sr=sr, hop_length=hop_length
    )
    f0_yin = np.nan_to_num(f0_yin, nan=0.0)
    f0_yin = np.maximum(f0_yin, 0.0)

    # Frame-wise RMS energy converted to decibels relative to amplitude 1.0.
    loudness = librosa.feature.rms(
        y=audio, frame_length=2048, hop_length=hop_length
    )[0]
    loudness_db = librosa.amplitude_to_db(loudness, ref=1.0)

    # Transposed so that the first axis is the frame index.
    mfcc = librosa.feature.mfcc(
        y=audio, sr=sr, n_mfcc=30, hop_length=hop_length
    ).T

    # Trim every track to the shortest one so they stay aligned.
    min_len = min(len(f0_yin), len(loudness_db), len(mfcc))

    return {
        'f0': f0_yin[:min_len],
        'loudness': loudness_db[:min_len],
        'mfcc': mfcc[:min_len],
        'audio': audio,
        'n_frames': min_len
    }


def _fit_to_length(values, length, pad_mode):
    """Truncate or pad `values` along axis 0 to exactly `length` entries."""
    current = len(values)
    if current >= length:
        return values[:length]
    pad_width = [(0, length - current)] + [(0, 0)] * (values.ndim - 1)
    return np.pad(values, pad_width, mode=pad_mode)


def load_multiple_files(file_paths, sample_rate=22050, target_frames=None,
                        samples_per_frame=512):
    """Load and batch multiple audio files for parallel training.

    Args:
        file_paths: Paths of the audio files to load.
        sample_rate: Sample rate forwarded to ``extract_features``.
        target_frames: Frame count every file is padded or truncated to.
            Defaults to the longest file's frame count.
        samples_per_frame: Audio samples per feature frame; every waveform is
            brought to exactly ``target_frames * samples_per_frame`` samples.

    Returns:
        Dict of batched tensors: ``'f0'`` and ``'loudness'`` with shape
        (files, frames, 1), ``'mfcc'`` (files, frames, n_mfcc), ``'audio'``
        (files, samples), plus ``'n_frames'`` and ``'file_paths'``.

    Raises:
        ValueError: If ``file_paths`` is empty.
    """
    if not file_paths:
        raise ValueError("load_multiple_files() needs at least one audio file")

    print(f"üìä Loading {len(file_paths)} files...")
    all_features = []

    for path in tqdm(file_paths, desc="Extracting features"):
        features = extract_features(path, sample_rate)
        all_features.append(features)
        print(f"   {Path(path).name}: F0 {features['f0'].min():.1f}-{features['f0'].max():.1f} Hz, {features['n_frames']} frames")

    if target_frames is None:
        target_frames = max(f['n_frames'] for f in all_features)

    print(f"\nüìê Padding all files to {target_frames} frames...")

    audio_samples = target_frames * samples_per_frame

    batch_f0 = []
    batch_loudness = []
    batch_mfcc = []
    batch_audio = []

    for features in all_features:
        # Feature tracks repeat their last frame when padded.
        batch_f0.append(_fit_to_length(features['f0'], target_frames, 'edge'))
        batch_loudness.append(_fit_to_length(features['loudness'], target_frames, 'edge'))
        batch_mfcc.append(_fit_to_length(features['mfcc'], target_frames, 'edge'))
        # Audio is zero-padded or truncated to one common length for EVERY
        # file. Previously the longest file was only truncated (leaving it
        # shorter than the padded ones), which made np.stack fail as soon as
        # the batch held files of different lengths.
        batch_audio.append(_fit_to_length(features['audio'], audio_samples, 'constant'))

    print(f"‚úÖ Batch prepared: {len(file_paths)} files √ó {target_frames} frames")

    return {
        'f0': torch.tensor(np.stack(batch_f0), dtype=torch.float32).unsqueeze(-1),
        'loudness': torch.tensor(np.stack(batch_loudness), dtype=torch.float32).unsqueeze(-1),
        'mfcc': torch.tensor(np.stack(batch_mfcc), dtype=torch.float32),
        'audio': torch.tensor(np.stack(batch_audio), dtype=torch.float32),
        'n_frames': target_frames,
        'file_paths': file_paths
    }


# Extract features from every uploaded file and assemble them into one batch.
print("üìÅ Loading audio files...\n")
batch_features = load_multiple_files(audio_files, sample_rate=22050)
print(f"\nüéµ Ready to train on {len(audio_files)} scale tones simultaneously!")

## 5. Streamlined Training Loop (No torch.compile)

In [None]:
# Move the batched tensors onto the training device.
f0, loudness, mfcc, target_audio = (
    batch_features[key].to(device) for key in ('f0', 'loudness', 'mfcc', 'audio')
)

print(f"‚úÖ Batch loaded to {device}:")
for label, tensor in (('f0', f0), ('loudness', loudness), ('mfcc', mfcc), ('audio', target_audio)):
    print(f"   {label}: {tensor.shape}")

# Build the model eagerly (torch.compile is deliberately not used).
print("\nüèóÔ∏è  Building model...")
model = DDSPModel(sample_rate=22050, n_harmonics=64, n_filter_banks=64, hidden_size=512)
model = model.to(device)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"   Parameters: {n_params:,}")
print("   torch.compile: DISABLED (avoids hanging)")

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Gradient scaler for mixed precision; only the newer API takes a device type.
scaler = GradScaler(DEVICE_TYPE) if DEVICE_TYPE else GradScaler()
print("‚ö° Mixed precision training enabled")

# Spectral loss with the target STFTs computed once up front.
spectral_loss_fn = OptimizedMultiScaleSpectralLoss().to(device)
spectral_loss_fn.precompute_target(target_audio, device)

# Training configuration
N_EPOCHS = 1000
losses = []
best_loss = float('inf')

print(f"\nüéØ Training {len(audio_files)} files for {N_EPOCHS} epochs...")
print("   Expected time: ~60-90 minutes\n")

start_time = time.time()

# Smoke-test a single forward pass before committing to the full run.
print("üß™ Testing forward pass...")
model.eval()
amp_context = autocast(device_type=DEVICE_TYPE) if DEVICE_TYPE else autocast()
with torch.no_grad(), amp_context:
    test_out, _, _ = model(f0, loudness, mfcc)
print(f"‚úÖ Forward pass successful! Output shape: {test_out.shape}")
print("\nüöÄ Starting training...\n")

model.train()

# autocast() arguments: the newer AMP API requires device_type, the legacy
# torch.cuda.amp one takes none. Built once so the loop has a single code path.
autocast_kwargs = {'device_type': DEVICE_TYPE} if DEVICE_TYPE else {}

for epoch in tqdm(range(N_EPOCHS), desc="Training"):
    optimizer.zero_grad()

    # Mixed precision forward pass and loss computation.
    with autocast(**autocast_kwargs):
        pred_audio, harmonic_audio, noise_audio = model(f0, loudness, mfcc)

        # Compare prediction and target over their common length only.
        min_len = min(pred_audio.shape[1], target_audio.shape[1])
        pred_audio_trim = pred_audio[:, :min_len]
        target_audio_trim = target_audio[:, :min_len]

        spec_loss = spectral_loss_fn(pred_audio_trim)
        time_loss = F.l1_loss(pred_audio_trim, target_audio_trim)
        total_loss = 1.0 * spec_loss + 0.1 * time_loss

    # Backward pass
    scaler.scale(total_loss).backward()
    # Gradients are still multiplied by the loss scale at this point. Unscale
    # them first so max_norm=1.0 applies to the true gradient norm; clipping
    # scaled gradients would shrink every update by roughly the scale factor.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()

    loss_value = total_loss.item()
    losses.append(loss_value)

    if loss_value < best_loss:
        best_loss = loss_value

    # Show first epoch completion
    if epoch == 0:
        elapsed = time.time() - start_time
        print(f"\n‚úÖ First epoch complete! ({elapsed:.1f}s)")
        print(f"   Loss: {loss_value:.6f}")
        print(f"   Training is working!\n")

    # Periodic progress report with a linear ETA extrapolation.
    if (epoch + 1) % 100 == 0:
        elapsed = time.time() - start_time
        eta = (elapsed / (epoch + 1)) * (N_EPOCHS - epoch - 1)
        print(f"\nüìä Epoch {epoch+1}/{N_EPOCHS}:")
        print(f"   Loss: {loss_value:.6f} | Best: {best_loss:.6f}")
        print(f"   Time: {elapsed/60:.1f}min | ETA: {eta/60:.1f}min")

total_time = time.time() - start_time

# Report the last epoch's loss and the best loss separately; the previous
# "Final loss" line actually printed the best value seen.
final_loss = losses[-1] if losses else float('nan')

print(f"\n‚úÖ Training complete!")
print(f"   Total time: {total_time/60:.1f} minutes")
print(f"   Per-file time: {total_time/60/len(audio_files):.1f} minutes")
print(f"   Final loss: {final_loss:.6f} | Best loss: {best_loss:.6f}")

# Plot training curve
fig, (ax_full, ax_detail) = plt.subplots(1, 2, figsize=(12, 5))

# Left: the whole run on a log scale.
ax_full.plot(losses)
ax_full.set_xlabel('Epoch')
ax_full.set_ylabel('Loss')
ax_full.set_title('Training Loss')
ax_full.set_yscale('log')
ax_full.grid(True, alpha=0.3)

# Right: the last 500 epochs, plotted against their real epoch numbers.
detail_start = max(0, len(losses) - 500)
ax_detail.plot(range(detail_start, len(losses)), losses[detail_start:])
ax_detail.set_xlabel('Epoch (last 500)')
ax_detail.set_ylabel('Loss')
ax_detail.set_title('Training Loss (Detail)')
ax_detail.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()