# DDSP Emergency Fix - Train ONE File at a Time

**Problem**: The GPU is enabled, but training is unexpectedly slow (11.9 s per forward pass).

**Solution**: Train on one file at a time (sequentially) instead of batching all files together.

**Expected time**: ~90–120 minutes for 8 files (vs. 61 hours previously).

In [None]:
# 1. Setup
!pip install torch librosa soundfile matplotlib scipy tqdm -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import librosa
import soundfile as sf
from tqdm import tqdm
from IPython.display import Audio, display
import glob
from pathlib import Path
import time
import os

# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Device: {device}")

# Automatic mixed precision is switched off because it was suspected of
# causing the slowdown. NOTE(review): nothing visible in this notebook reads
# USE_AMP; it only drives the status message below.
USE_AMP = False
print(f"Mixed precision: {'ENABLED' if USE_AMP else 'DISABLED'}")

In [None]:
# 2. Upload files
# Colab-only: open the browser upload dialog, move every uploaded file into
# scale_tones/, then collect the .wav files found there (sorted by path).
from google.colab import files
uploaded = files.upload()

os.makedirs('scale_tones', exist_ok=True)
for filename in uploaded:  # iterating a dict yields its keys (the file names)
    destination = f'scale_tones/{filename}'
    os.rename(filename, destination)

audio_files = glob.glob('scale_tones/*.wav')
audio_files.sort()
print(f"Uploaded {len(audio_files)} files")

In [None]:
# 3. Define DDSP model (SIMPLIFIED - no mixed precision complications)

class HarmonicOscillator(nn.Module):
    """Additive synthesizer that sums sinusoids at integer multiples of f0.

    Frame-rate controls are linearly upsampled to audio rate, with
    ``hop_length`` audio samples per control frame. Harmonics whose
    instantaneous frequency reaches Nyquist (``sample_rate / 2``) are muted so
    they cannot alias back into the audible band.
    """

    def __init__(self, sample_rate=22050, n_harmonics=64, hop_length=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_harmonics = n_harmonics
        # Audio samples per control frame. It must match the hop used when the
        # features were extracted; 512 was previously hard-coded in forward().
        self.hop_length = hop_length

    def forward(self, f0_hz, harmonic_amplitudes):
        """Render the harmonic component.

        Args:
            f0_hz: (batch, n_frames) fundamental frequency per frame, in Hz.
            harmonic_amplitudes: (batch, n_frames, >= n_harmonics) per-frame
                amplitude of each harmonic. Only the first ``n_harmonics``
                channels are used.

        Returns:
            (batch, n_frames * hop_length) audio tensor.

        Raises:
            ValueError: if ``harmonic_amplitudes`` has fewer than
                ``n_harmonics`` channels.
        """
        n_frames = f0_hz.shape[1]
        n_samples = n_frames * self.hop_length

        amplitudes = harmonic_amplitudes[:, :, :self.n_harmonics]
        if amplitudes.shape[-1] != self.n_harmonics:
            raise ValueError(
                f"expected at least {self.n_harmonics} harmonic channels, "
                f"got {harmonic_amplitudes.shape[-1]}"
            )

        f0_upsampled = F.interpolate(
            f0_hz.unsqueeze(1), size=n_samples, mode='linear', align_corners=True
        ).squeeze(1)  # (batch, n_samples)

        amplitudes_upsampled = F.interpolate(
            amplitudes.transpose(1, 2), size=n_samples,
            mode='linear', align_corners=True
        ).transpose(1, 2)  # (batch, n_samples, n_harmonics)

        # Instantaneous phase of the fundamental in radians, integrated sample
        # by sample so that pitch glides stay phase-continuous.
        phase = 2 * torch.pi * torch.cumsum(f0_upsampled / self.sample_rate, dim=1)

        # Synthesize all harmonics in one broadcasted op instead of a Python
        # loop with one set of kernel launches per harmonic.
        harmonic_numbers = torch.arange(
            1, self.n_harmonics + 1, device=f0_hz.device, dtype=phase.dtype
        )  # (n_harmonics,)
        harmonic_phase = phase.unsqueeze(-1) * harmonic_numbers
        harmonic_freqs = f0_upsampled.unsqueeze(-1) * harmonic_numbers
        # Anti-aliasing mask: 1 where the harmonic is below Nyquist, else 0.
        below_nyquist = (harmonic_freqs < self.sample_rate / 2).to(amplitudes_upsampled.dtype)

        harmonics = torch.sin(harmonic_phase) * amplitudes_upsampled * below_nyquist
        return harmonics.sum(dim=-1)


class FilteredNoiseGenerator(nn.Module):
    """Subtractive synthesizer that shapes white noise with a learned filter.

    ``n_filter_banks`` band centres are log-spaced from 20 Hz to Nyquist. The
    gain applied to each FFT bin is a normalized Gaussian-in-log-frequency
    blend of the band magnitudes.
    """

    def __init__(self, sample_rate=22050, n_filter_banks=64, hop_length=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_filter_banks = n_filter_banks
        # Audio samples per control frame (previously hard-coded as 512).
        self.hop_length = hop_length
        self.register_buffer(
            'filter_freqs',
            torch.logspace(torch.log10(torch.tensor(20.0)), torch.log10(torch.tensor(sample_rate / 2.0)), n_filter_banks)
        )

    def _bin_weights(self, freqs):
        """Return (n_bins, n_filter_banks) band weights; each row sums to ~1."""
        log_bands = torch.log(self.filter_freqs + 1e-7)  # (n_filter_banks,)
        log_bins = torch.log(freqs + 1e-7)               # (n_bins,)
        distances = torch.abs(log_bands.unsqueeze(0) - log_bins.unsqueeze(1))
        weights = torch.exp(-distances**2 / 0.5)
        return weights / (weights.sum(dim=1, keepdim=True) + 1e-7)

    def forward(self, filter_magnitudes):
        """Render the noise component.

        Args:
            filter_magnitudes: (batch, n_frames, n_filter_banks) band magnitudes.

        Returns:
            (batch, n_frames * hop_length) filtered noise. The noise is a fresh
            draw from torch's global RNG on every call.
        """
        batch_size, n_frames, _ = filter_magnitudes.shape
        n_samples = n_frames * self.hop_length
        device = filter_magnitudes.device

        noise = torch.randn(batch_size, n_samples, device=device)
        noise_fft = torch.fft.rfft(noise, dim=1)
        freqs = torch.fft.rfftfreq(n_samples, 1/self.sample_rate).to(device)

        # NOTE(review): the magnitudes are averaged over time, so one static
        # filter is applied to the whole clip and frame-to-frame variation is
        # discarded. This is kept as-is to preserve the model's behaviour; a
        # time-varying filter would need frame-wise (STFT-domain) filtering.
        mean_magnitudes = F.interpolate(
            filter_magnitudes.transpose(1, 2), size=n_samples,
            mode='linear', align_corners=True
        ).mean(dim=2)  # (batch, n_filter_banks)

        # A single matrix product replaces the former Python loop over every
        # FFT bin (~n_samples / 2 iterations per forward pass, each launching
        # several small kernels), presumably the bulk of the forward-pass cost.
        filter_response = mean_magnitudes @ self._bin_weights(freqs).T  # (batch, n_bins)

        filtered_fft = noise_fft * filter_response
        return torch.fft.irfft(filtered_fft, n=n_samples, dim=1)


class DDSPSynthesizer(nn.Module):
    """GRU decoder that maps per-frame conditioning features to synth controls.

    Each input frame is the concatenation of f0, loudness and ``n_mfcc``
    MFCCs. Each output frame holds ``n_harmonics`` non-negative harmonic
    amplitudes (loudness-scaled) and ``n_filter_banks`` noise-band magnitudes
    in (0, 1).
    """

    def __init__(self, n_harmonics=64, n_filter_banks=64, hidden_size=512, n_mfcc=30):
        super().__init__()
        input_size = 1 + 1 + n_mfcc  # f0 + loudness + MFCCs per frame
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True, dropout=0.1)

        # Softplus keeps harmonic amplitudes non-negative.
        self.harmonic_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_harmonics),
            nn.Softplus()
        )

        # Sigmoid bounds noise-band magnitudes to (0, 1).
        self.noise_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_filter_banks),
            nn.Sigmoid()
        )

    def forward(self, f0, loudness, mfcc):
        """Predict synthesizer controls for every frame.

        Args:
            f0: (batch, n_frames, 1) pitch in Hz, fed to the GRU unnormalized.
            loudness: (batch, n_frames, 1) loudness in dB. Assumed to be
                20*log10 of an amplitude, as produced by
                ``librosa.amplitude_to_db`` in the feature extractor.
            mfcc: (batch, n_frames, n_mfcc) MFCC features.

        Returns:
            (harmonic_amplitudes, filter_magnitudes) with shapes
            (batch, n_frames, n_harmonics) and (batch, n_frames, n_filter_banks).
        """
        x = torch.cat([f0, loudness, mfcc], dim=-1)
        x, _ = self.gru(x)
        harmonic_amplitudes = self.harmonic_head(x)
        filter_magnitudes = self.noise_head(x)
        # Convert dB back to a linear amplitude: 10 ** (dB / 20). The previous
        # exp(dB / 20) used the wrong base and shrank the exponent by a factor
        # of ln(10) ~= 2.3, compressing the dynamic range.
        loudness_scale = 10.0 ** (loudness / 20.0)
        harmonic_amplitudes = harmonic_amplitudes * loudness_scale
        return harmonic_amplitudes, filter_magnitudes


class DDSPModel(nn.Module):
    """End-to-end DDSP model: GRU decoder feeding harmonic + noise synthesis.

    The output audio is a learned convex mix of the harmonic and
    filtered-noise branches.
    """

    def __init__(self, sample_rate=22050, n_harmonics=64, n_filter_banks=64, hidden_size=512):
        super().__init__()
        self.sample_rate = sample_rate
        self.synthesizer = DDSPSynthesizer(n_harmonics, n_filter_banks, hidden_size)
        self.harmonic_osc = HarmonicOscillator(sample_rate, n_harmonics)
        self.noise_gen = FilteredNoiseGenerator(sample_rate, n_filter_banks)
        # Raw logit of the harmonic share; it is squashed with a sigmoid in
        # forward(). NOTE(review): sigmoid(0.8) ~= 0.69, so training starts at
        # roughly 69% harmonic, not 80% — confirm this is intended.
        self.harmonic_noise_ratio = nn.Parameter(torch.tensor(0.8))

    def forward(self, f0, loudness, mfcc):
        """Synthesize audio from per-frame features.

        Args:
            f0: (batch, n_frames, 1) pitch in Hz.
            loudness: (batch, n_frames, 1) loudness features.
            mfcc: (batch, n_frames, n_mfcc) MFCC features.

        Returns:
            (mixed_audio, harmonic_audio, noise_audio), each (batch, n_samples).
        """
        harmonic_amplitudes, filter_magnitudes = self.synthesizer(f0, loudness, mfcc)
        harmonic_audio = self.harmonic_osc(f0.squeeze(-1), harmonic_amplitudes)
        noise_audio = self.noise_gen(filter_magnitudes)

        harmonic_weight = torch.sigmoid(self.harmonic_noise_ratio)
        mixed = harmonic_weight * harmonic_audio + (1 - harmonic_weight) * noise_audio
        return mixed, harmonic_audio, noise_audio


class MultiScaleSpectralLoss(nn.Module):
    """Mean L1 distance between log-magnitude STFTs at several FFT sizes.

    Each scale uses a Hann window and a hop of ``fft_size // 4``. The returned
    value is the average of the per-scale L1 losses.
    """

    def __init__(self, fft_sizes=(2048, 1024, 512, 256)):
        """
        Args:
            fft_sizes: iterable of FFT sizes, one per scale.

        Raises:
            ValueError: if ``fft_sizes`` is empty (forward would divide by zero).
        """
        super().__init__()
        # Copy into a fresh list per instance. The old ``[...]`` default was a
        # single list object shared by every instance built with the default.
        self.fft_sizes = list(fft_sizes)
        if not self.fft_sizes:
            raise ValueError("fft_sizes must contain at least one FFT size")

    def forward(self, pred_audio, target_audio):
        """Return the scalar multi-scale spectral loss.

        Args:
            pred_audio: predicted waveform(s); assumed (batch, n_samples).
            target_audio: target waveform(s), same shape and device as
                ``pred_audio``.
        """
        total_loss = 0.0
        for fft_size in self.fft_sizes:
            # One window per scale, shared by both STFTs.
            window = torch.hann_window(fft_size, device=pred_audio.device)
            stft_kwargs = dict(
                n_fft=fft_size, hop_length=fft_size // 4,
                window=window, return_complex=True,
            )
            # The 1e-5 floor keeps log() finite on silent bins.
            pred_log_mag = torch.log(torch.abs(torch.stft(pred_audio, **stft_kwargs)) + 1e-5)
            target_log_mag = torch.log(torch.abs(torch.stft(target_audio, **stft_kwargs)) + 1e-5)
            total_loss += F.l1_loss(pred_log_mag, target_log_mag)
        return total_loss / len(self.fft_sizes)

# Cell-completion marker: reached only if every class definition above ran without error.
print("✅ Model defined")

In [None]:
# 4. Feature extraction

def extract_features(audio_path, sample_rate=22050, hop_length=512):
    """Load an audio file and compute frame-rate conditioning features.

    Args:
        audio_path: path to an audio file readable by ``librosa.load``.
        sample_rate: target sample rate; the file is resampled to it and
            mixed down to mono.
        hop_length: audio samples between feature frames. It must equal the
            synthesizer's hop (512 samples in the model as configured here) so
            that ``n_frames * hop_length`` synthesized samples line up with the
            waveform. The previous ``int(sample_rate / 43.066)`` gave 512 only
            at 22050 Hz.

    Returns:
        dict with keys:
            'f0': (n_frames,) YIN pitch estimate in Hz, NaN -> 0, clipped to >= 0.
            'loudness': (n_frames,) frame RMS in dB (``amplitude_to_db``, ref=1.0).
            'mfcc': (n_frames, 30) MFCCs.
            'audio': the full mono waveform, not trimmed to the frame grid.
            'n_frames': number of frames shared by all three features.
    """
    audio, sr = librosa.load(audio_path, sr=sample_rate, mono=True)

    f0_yin = librosa.yin(audio, fmin=librosa.note_to_hz('C2'), fmax=librosa.note_to_hz('C7'), sr=sr, hop_length=hop_length)
    # Defensive clean-up so no NaN or negative pitch reaches the oscillator.
    f0_yin = np.nan_to_num(f0_yin, nan=0.0)
    f0_yin = np.maximum(f0_yin, 0.0)

    loudness = librosa.feature.rms(y=audio, frame_length=2048, hop_length=hop_length)[0]
    loudness_db = librosa.amplitude_to_db(loudness, ref=1.0)

    mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=30, hop_length=hop_length).T

    # Guard against off-by-one differences in frame counts between analyses.
    min_len = min(len(f0_yin), len(loudness_db), len(mfcc))

    return {
        'f0': f0_yin[:min_len],
        'loudness': loudness_db[:min_len],
        'mfcc': mfcc[:min_len],
        'audio': audio,
        'n_frames': min_len
    }

# Cell-completion marker: reached only if extract_features above was defined without error.
print("✅ Feature extraction ready")

In [None]:
# 5. Train ONE FILE AT A TIME (sequential, not batch)
# Each file gets its own freshly initialized model, trained for N_EPOCHS
# full-clip steps and saved to outputs/model_<stem>.pt.

N_EPOCHS = 1000
os.makedirs('outputs', exist_ok=True)

print(f"🎯 Training {len(audio_files)} files SEQUENTIALLY")
print(f"   Epochs per file: {N_EPOCHS}")
print(f"   Estimated time per file: 10-15 minutes")
print(f"   Total estimated time: {len(audio_files) * 12} minutes\n")

total_start_time = time.time()

for file_idx, audio_file in enumerate(audio_files):
    print(f"\n{'='*70}")
    print(f"File {file_idx + 1}/{len(audio_files)}: {Path(audio_file).name}")
    print(f"{'='*70}\n")

    # Extract features for this file
    print("📊 Extracting features...")
    features = extract_features(audio_file)

    # Add a batch dim (and a trailing channel dim for the scalar features):
    # f0/loudness -> (1, n_frames, 1), mfcc -> (1, n_frames, n_mfcc),
    # audio -> (1, n_samples).
    f0 = torch.tensor(features['f0'], dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
    loudness = torch.tensor(features['loudness'], dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
    mfcc = torch.tensor(features['mfcc'], dtype=torch.float32).unsqueeze(0).to(device)
    target_audio = torch.tensor(features['audio'], dtype=torch.float32).unsqueeze(0).to(device)

    print(f"   F0: {features['f0'].min():.1f} - {features['f0'].max():.1f} Hz")
    print(f"   Frames: {features['n_frames']}")

    # Create fresh model for this file
    model = DDSPModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = MultiScaleSpectralLoss().to(device)

    # Training loop: one optimizer step per epoch on the whole clip.
    losses = []
    best_loss = float('inf')
    file_start_time = time.time()

    print(f"\n🚀 Training...")

    for epoch in tqdm(range(N_EPOCHS), desc=f"File {file_idx+1}"):
        model.train()
        optimizer.zero_grad()

        # Simple forward pass (NO mixed precision)
        pred_audio, _, _ = model(f0, loudness, mfcc)

        # The synthesized length (n_frames * hop) rarely equals the waveform
        # length, so compare only the overlapping prefix.
        min_len = min(pred_audio.shape[1], target_audio.shape[1])
        pred_audio_trim = pred_audio[:, :min_len]
        target_audio_trim = target_audio[:, :min_len]

        spec_loss = loss_fn(pred_audio_trim, target_audio_trim)
        time_loss = F.l1_loss(pred_audio_trim, target_audio_trim)
        total_loss = 1.0 * spec_loss + 0.1 * time_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_value = total_loss.item()  # .item() also syncs the GPU, so the timings below are real
        losses.append(loss_value)

        if loss_value < best_loss:
            best_loss = loss_value

        # tqdm.write prints above the progress bar instead of breaking it.
        if epoch == 0:
            epoch_time = time.time() - file_start_time
            tqdm.write(f"\n   First epoch: {epoch_time:.1f}s")
            tqdm.write(f"   Loss: {loss_value:.6f}\n")

        if (epoch + 1) % 200 == 0:
            elapsed = time.time() - file_start_time
            tqdm.write(f"\n   Epoch {epoch+1}: Loss={loss_value:.6f}, Best={best_loss:.6f}, Time={elapsed/60:.1f}min")

    file_time = time.time() - file_start_time
    final_loss = losses[-1] if losses else float('nan')

    # Save model. The weights are those of the LAST epoch, not the best-loss
    # epoch; best_loss is stored for reference only.
    # NOTE(review): 'features' holds NumPy arrays, so PyTorch versions whose
    # torch.load defaults to weights_only=True (2.6+) need
    # torch.load(path, weights_only=False) to read this file.
    torch.save({
        'model_state_dict': model.state_dict(),
        'features': features,
        'losses': losses,
        'best_loss': best_loss,
    }, f"outputs/model_{Path(audio_file).stem}.pt")

    print(f"\n✅ File {file_idx + 1} complete!")
    print(f"   Time: {file_time/60:.1f} minutes")
    print(f"   Final loss: {final_loss:.6f}")
    print(f"   Best loss: {best_loss:.6f}")
    print(f"   Model saved: model_{Path(audio_file).stem}.pt")

total_time = time.time() - total_start_time

print(f"\n\n{'='*70}")
print(f"🎉 ALL FILES COMPLETE!")
print(f"{'='*70}")
print(f"Total time: {total_time/60:.1f} minutes")
if audio_files:
    print(f"Average per file: {total_time/60/len(audio_files):.1f} minutes")
else:
    # Avoid a ZeroDivisionError when no .wav files were uploaded.
    print("Average per file: n/a (no .wav files found in scale_tones/)")
print(f"\nModels saved in: outputs/")