# Timbre Generation: Conservative Fix 🎻

**Lessons Learned**:
- ❌ Circular padding broke temporal structure
- ❌ Aggressive frequency weighting (5.0x) was too much
- ✅ Original approach had structure, just needed gentle improvements

**Conservative Changes**:
1. ✅ Fix dB-to-power conversion: `/20 → /10`
2. ✅ Moderate frequency weighting: top weight `2.0 → 3.0` (not 5.0!), so weights now span `1.0 → 3.0`
3. ✅ Moderate L1 weight: `0.1 → 0.3` (not 0.5!)
4. ✅ NO circular padding (keep original zero padding)

**Expected**: Slightly better high-freq content, correct energy, minimal risk

In [None]:
!pip install librosa soundfile bigvgan -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import soundfile as sf
import os
import bigvgan

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
class Config:
    """Hyper-parameters shared by the notebook's audio, model and training cells."""

    # --- Audio / mel-spectrogram analysis ---
    SAMPLE_RATE: int = 22050   # sample rate used when loading and writing audio
    N_FFT: int = 1024          # FFT size passed to the mel-spectrogram
    HOP_LENGTH: int = 256      # hop between analysis frames, in samples
    N_MELS: int = 80           # number of mel bands (grid height)

    # --- Cellular-automaton grid ---
    CELL_CHANNELS: int = 16            # state channels per grid cell
    UPDATE_STEPS_TRAIN: int = 96       # model applications per training epoch
    UPDATE_STEPS_INFERENCE: int = 96   # model applications at generation time

    # --- Optimisation / logging ---
    LEARNING_RATE: float = 2e-4
    NUM_EPOCHS: int = 8000
    LOG_INTERVAL: int = 100                      # epochs between progress snapshots
    OUTPUT_IMAGE_DIR: str = 'training_progress'  # where snapshot PNGs are written

In [None]:
def create_auditory_grid(audio_path, config):
    """Load an audio file and return its mel-spectrogram in dB.

    The waveform is resampled to ``config.SAMPLE_RATE`` and mixed to mono,
    then zero-padded at the end so that its length is a whole multiple of
    ``8 * config.HOP_LENGTH`` samples before the mel-spectrogram is computed.

    Args:
        audio_path: Path of the audio file to load.
        config: Object providing SAMPLE_RATE, N_FFT, HOP_LENGTH and N_MELS.

    Returns:
        The mel power spectrogram converted to dB relative to its own
        maximum (``ref=np.max``), so the loudest bin is 0 dB.
    """
    waveform, _ = librosa.load(audio_path, sr=config.SAMPLE_RATE, mono=True)

    # Round the sample count up to the next multiple of 8 hops.
    # NOTE(review): the frame count librosa actually returns depends on its
    # centering/padding defaults, so it may not itself be a multiple of 8 —
    # confirm if downstream code relies on that.
    samples_per_block = 8 * config.HOP_LENGTH
    leftover = len(waveform) % samples_per_block
    if leftover:
        waveform = np.pad(waveform, (0, samples_per_block - leftover), mode='constant')

    mel_power = librosa.feature.melspectrogram(
        y=waveform,
        sr=config.SAMPLE_RATE,
        n_fft=config.N_FFT,
        hop_length=config.HOP_LENGTH,
        n_mels=config.N_MELS,
    )
    return librosa.power_to_db(mel_power, ref=np.max)

def visualize_spectrogram(spectrogram, config, title='Mel-Spectrogram', output_path=None):
    """Plot a mel-spectrogram and either save it to disk or show it.

    Args:
        spectrogram: 2-D array to plot; a ``torch.Tensor`` is detached and
            moved to the CPU as a NumPy array first.
        config: Object providing SAMPLE_RATE and HOP_LENGTH for the axes.
        title: Figure title.
        output_path: If truthy, the figure is saved there and closed;
            otherwise it is shown interactively.
    """
    plt.figure(figsize=(12, 5))

    is_tensor = isinstance(spectrogram, torch.Tensor)
    plot_data = spectrogram.detach().cpu().numpy() if is_tensor else spectrogram

    librosa.display.specshow(
        plot_data,
        sr=config.SAMPLE_RATE,
        hop_length=config.HOP_LENGTH,
        x_axis='time',
        y_axis='mel',
    )
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()

    # No destination given: display interactively and leave the figure open.
    if not output_path:
        plt.show()
        return

    # Saving: close afterwards so repeated snapshots do not pile up figures.
    plt.savefig(output_path, bbox_inches='tight')
    plt.close()

In [None]:
# Fetch the pretrained 22 kHz / 80-band BigVGAN checkpoint, move it to the
# selected device and switch it to eval mode in one expression.
print("Loading BigVGAN...")
bigvgan_model = (
    bigvgan.BigVGAN
    .from_pretrained('nvidia/bigvgan_base_22khz_80band', use_cuda_kernel=False)
    .to(device)
    .eval()
)
print("BigVGAN loaded!")

In [None]:
class TNCAModel(nn.Module):
    """Original TNCA (NO circular padding).

    A neural cellular automaton over a (batch, channels, height, width) grid.
    Each call to ``forward`` performs one update step: every cell perceives
    its 3x3 neighbourhood through fixed Sobel/Laplacian filters, a per-cell
    MLP (1x1 convolutions) produces a state delta, and the result is gated by
    a "living" mask derived from channel 1 (the alpha channel).

    Args:
        num_channels: Number of state channels per cell. Channel 1 is used as
            the alpha channel, so at least 2 channels are required.
    """

    def __init__(self, num_channels=16):
        super().__init__()
        self.num_channels = num_channels
        # Perception = identity + sobel_x + sobel_y + laplacian per channel.
        perception_vector_size = self.num_channels * 4
        self.update_mlp = nn.Sequential(
            nn.Conv2d(perception_vector_size, 128, 1), nn.ReLU(),
            nn.Conv2d(128, 64, 1), nn.ReLU(),
            nn.Conv2d(64, self.num_channels, 1, bias=True)
        )
        # Start from a zero update, except a constant +1 on the alpha channel.
        final_layer = self.update_mlp[-1]
        final_layer.weight.data.zero_()
        final_layer.bias.data.zero_()
        final_layer.bias.data[1] = 1.0

        # The perception filters are constants, so build them once here rather
        # than on every forward step. They are registered as non-persistent
        # buffers: they follow the module across devices but stay out of the
        # state_dict, so existing checkpoints load unchanged.
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
        ).view(1, 1, 3, 3)
        sobel_y = sobel_x.transpose(-2, -1).contiguous()
        laplacian = torch.tensor(
            [[1, 2, 1], [2, -12, 2], [1, 2, 1]], dtype=torch.float32
        ).view(1, 1, 3, 3)
        for buffer_name, base_kernel in (
            ("_sobel_x_kernel", sobel_x),
            ("_sobel_y_kernel", sobel_y),
            ("_laplacian_kernel", laplacian),
        ):
            # One copy of the 3x3 filter per channel (depthwise convolution).
            self.register_buffer(
                buffer_name,
                base_kernel.repeat(self.num_channels, 1, 1, 1),
                persistent=False,
            )

    def perceive(self, grid):
        """Return the grid concatenated with its per-channel x/y gradients and Laplacian.

        Args:
            grid: Tensor of shape (batch, num_channels, height, width).

        Returns:
            Tensor of shape (batch, 4 * num_channels, height, width), ordered
            as [grid, grad_x, grad_y, laplacian] along the channel axis.
        """
        def depthwise(kernel):
            # ORIGINAL zero padding (no circular!). The .to() is a no-op when
            # the grid already lives on the module's device.
            return F.conv2d(grid, kernel.to(grid.device), padding=1, groups=self.num_channels)

        grad_x = depthwise(self._sobel_x_kernel)
        grad_y = depthwise(self._sobel_y_kernel)
        lap = depthwise(self._laplacian_kernel)

        return torch.cat([grid, grad_x, grad_y, lap], dim=1)

    def forward(self, grid):
        """Apply one cellular-automaton update step and return the new grid."""
        perception_vector = self.perceive(grid)
        ds = self.update_mlp(perception_vector)
        grid = grid + ds
        # Keep alpha (channel 1) inside [0, 1] without touching other channels.
        clamped_alpha = torch.clamp(grid[:, 1:2, :, :], 0.0, 1.0)
        grid = torch.cat([grid[:, :1, :, :], clamped_alpha, grid[:, 2:, :, :]], dim=1)
        alpha_channel = grid[:, 1:2, :, :]
        # A cell is scaled by the largest alpha in its 3x3 neighbourhood.
        living_mask = F.max_pool2d(alpha_channel, kernel_size=3, stride=1, padding=1)
        grid = grid * living_mask
        return grid

class PerceptualLoss(nn.Module):
    """Style-type loss comparing Gram matrices of convolutional features.

    The feature extractor is a stack of three stride-2 convolutions, each
    followed by a ReLU, using PyTorch's default initialisation.

    Args:
        device: Device the feature extractor is moved to.
    """

    def __init__(self, device):
        super().__init__()
        # (in_channels, out_channels, kernel_size, padding); every layer uses stride 2.
        conv_specs = ((1, 16, 7, 3), (16, 32, 5, 2), (32, 64, 3, 1))
        stages = []
        for in_ch, out_ch, kernel, pad in conv_specs:
            stages.append(nn.Conv2d(in_ch, out_ch, kernel, 2, pad))
            stages.append(nn.ReLU())
        self.feature_extractor = nn.Sequential(*stages).to(device)

    def gram_matrix(self, f):
        """Return the channel-by-channel Gram matrix of ``f``, normalised by its element count per sample."""
        batch, channels, height, width = f.size()
        flattened = f.view(batch, channels, height * width)
        channel_products = torch.bmm(flattened, flattened.transpose(1, 2))
        return channel_products / (channels * height * width)

    def forward(self, gen_mel, target_mel):
        """Return the MSE between the Gram matrices of the two inputs' features."""
        generated_gram = self.gram_matrix(self.feature_extractor(gen_mel))
        target_gram = self.gram_matrix(self.feature_extractor(target_mel))
        return F.mse_loss(generated_gram, target_gram)

In [None]:
def train(model, vocoder, loss_fn, optimizer, target_mel, config):
    """Train ``model`` to grow the target mel-spectrogram from a single seed cell.

    Every epoch starts from the same seed grid (all zeros except alpha = 1 at
    the centre cell), applies the model ``config.UPDATE_STEPS_TRAIN`` times,
    maps channel 0 to dB with ``tanh(x) * 40 - 40`` (range [-80, 0]) and
    compares it with the target on the bins where the target is above -70 dB.

    The loss is ``2.0 * perceptual + 0.3 * L1`` with a frequency weighting
    that rises linearly from 1.0 at mel bin 0 to 3.0 at the last bin. The
    weighting is applied element-wise inside the L1 term. Previously it was
    multiplied onto the already-reduced scalar loss, which only rescaled the
    whole loss by the mean weight and never emphasised any frequency.

    Args:
        model: The cellular-automaton module being optimised.
        vocoder: Unused during training; kept for interface compatibility.
        loss_fn: Callable ``(generated, target) -> scalar`` perceptual loss.
        optimizer: Optimizer over ``model``'s parameters.
        target_mel: 2-D array (n_mels, frames) of target dB values.
        config: Object providing CELL_CHANNELS, N_MELS, UPDATE_STEPS_TRAIN,
            NUM_EPOCHS, LOG_INTERVAL and OUTPUT_IMAGE_DIR.

    Returns:
        The trained ``model`` (same object that was passed in).
    """
    print("\n--- Conservative Training ---")
    perceptual_loss_fn = loss_fn

    # Progress snapshots are saved below; make sure their directory exists.
    if config.OUTPUT_IMAGE_DIR:
        os.makedirs(config.OUTPUT_IMAGE_DIR, exist_ok=True)

    # Shape (1, 1, n_mels, frames) so it lines up with the generated channel.
    target_mel_tensor = torch.tensor(target_mel, dtype=torch.float32).to(device).unsqueeze(0).unsqueeze(1)
    target_mask = (target_mel_tensor > -70.0).float()
    # The masked target never changes, so compute it once outside the loop.
    masked_target = target_mel_tensor * target_mask

    # MODERATE frequency weighting: 1.0 → 3.0 (not 5.0!)
    n_mels = target_mel_tensor.shape[2]
    frequency_loss_weights = torch.linspace(1.0, 3.0, n_mels, device=device).view(1, 1, n_mels, 1)
    mean_frequency_weight = frequency_loss_weights.mean()
    print(f"Frequency weighting: 1.0 → 3.0 (conservative)")

    seed_grid = torch.zeros(1, config.CELL_CHANNELS, config.N_MELS, target_mel.shape[1], device=device)
    h, w = seed_grid.shape[2], seed_grid.shape[3]
    seed_grid[:, 1, h//2, w//2] = 1.0

    pbar = tqdm(range(config.NUM_EPOCHS), desc="Training...")

    loss = None  # stays None when NUM_EPOCHS == 0
    for epoch in pbar:
        grid = seed_grid.clone()
        for _ in range(config.UPDATE_STEPS_TRAIN):
            grid = model(grid)

        generated_mel_db_unscaled = grid[:, 0:1, :, :]
        generated_mel_db = torch.tanh(generated_mel_db_unscaled) * 40.0 - 40.0
        masked_generated = generated_mel_db * target_mask

        p_loss = perceptual_loss_fn(masked_generated, masked_target)

        # Frequency-weighted L1: weight each mel bin's absolute error BEFORE
        # reducing, so higher bins really contribute more to the gradient.
        abs_error = (masked_generated - masked_target).abs()
        l1_loss = (abs_error * frequency_loss_weights).mean()

        # MODERATE L1 weight: 0.3 (not 0.5!)
        # The Gram-matrix perceptual term has no per-bin structure, so it
        # keeps the uniform scale (the mean weight) it had before the fix.
        loss = 2.0 * mean_frequency_weight * p_loss + 0.3 * l1_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % config.LOG_INTERVAL == 0:
            pbar.set_description(f"Epoch {epoch+1}, Loss: {loss.item():.6f} (P: {p_loss.item():.6f}, L1: {l1_loss.item():.6f})")
            filepath = os.path.join(config.OUTPUT_IMAGE_DIR, f'epoch_{(epoch+1):04d}.png')
            visualize_spectrogram(generated_mel_db.squeeze(), config,
                                title=f'Conservative - Epoch {epoch+1}', output_path=filepath)

    if loss is not None:
        print(f"\nTraining complete. Final loss: {loss.item():.6f}")
    return model

In [None]:
def inference(model, vocoder, target_mel, config, output_path='generated_conservative.wav'):
    """Grow a mel-spectrogram with the trained model and vocode it to audio.

    The model is run for ``config.UPDATE_STEPS_INFERENCE`` steps from the
    single-cell seed, channel 0 is mapped to dB with ``tanh(x) * 40 - 40``,
    converted to the vocoder's input domain, vocoded, written to
    ``output_path`` and plotted.

    Args:
        model: Trained cellular-automaton module (switched to eval mode).
        vocoder: Callable mapping a (1, n_mels, frames) mel tensor to a waveform.
        target_mel: Array or tensor (n_mels, frames); only its frame count is
            used, to size the generated grid.
        config: Object providing CELL_CHANNELS, N_MELS,
            UPDATE_STEPS_INFERENCE and SAMPLE_RATE.
        output_path: Destination of the generated audio file.

    Returns:
        The generated waveform as a NumPy array.
    """
    print("\n--- Conservative Inference ---")
    model.eval()

    if isinstance(target_mel, np.ndarray):
        target_mel = torch.tensor(target_mel, dtype=torch.float32, device=device)
    else:
        target_mel = target_mel.to(device)

    seed_grid = torch.zeros(1, config.CELL_CHANNELS, config.N_MELS, target_mel.shape[1], device=device)
    h, w = seed_grid.shape[2], seed_grid.shape[3]
    seed_grid[:, 1, h//2, w//2] = 1.0

    with torch.no_grad():
        final_grid = seed_grid.clone()
        for _ in tqdm(range(config.UPDATE_STEPS_INFERENCE), desc="Generating..."):
            final_grid = model(final_grid)

        final_mel_db_unscaled = final_grid[:, 0, :, :]
        final_mel_db = torch.tanh(final_mel_db_unscaled) * 40.0 - 40.0

        # BigVGAN is trained on log-compressed *magnitude* mel-spectrograms,
        # log(clamp(mel, min=1e-5)), not on linear power. Feeding linear
        # values puts the input far outside that range. The dB values here
        # are power dB (10 * log10), so amplitude is 10^(dB/20).
        # NOTE(review): the target was normalised with ref=np.max, so the
        # absolute level is lost (0 dB maps to amplitude 1.0). The target's
        # mel settings may also differ from the checkpoint's (e.g. fmax);
        # confirm against the vocoder's config.
        print("Converting dB to log-amplitude mel: ln(10^(dB/20))")
        final_mel_amplitude = torch.pow(10.0, final_mel_db / 20.0)
        vocoder_input = torch.log(torch.clamp(final_mel_amplitude, min=1e-5))

        # The vocoder expects a leading batch dimension.
        if vocoder_input.dim() == 2:
            vocoder_input = vocoder_input.unsqueeze(0)

        print("Generating audio with BigVGAN...")
        final_waveform = vocoder(vocoder_input).squeeze()

    waveform_np = final_waveform.cpu().numpy()
    sf.write(output_path, waveform_np, config.SAMPLE_RATE)

    print(f"\n✅ Audio saved: {output_path}")
    print(f"Duration: {len(waveform_np) / config.SAMPLE_RATE:.2f}s")
    print(f"RMS energy: {np.sqrt(np.mean(waveform_np**2)):.4f}")

    visualize_spectrogram(final_mel_db.squeeze().cpu(), config, title='Final (Conservative)')
    return waveform_np

In [None]:
# Main execution: build the target, train the automaton, then synthesise audio.
config = Config()
os.makedirs(config.OUTPUT_IMAGE_DIR, exist_ok=True)

target_audio_path = 'violin.wav'

if os.path.exists(target_audio_path):
    # Target mel-spectrogram (dB) that the automaton learns to reproduce.
    target_spectrogram = create_auditory_grid(target_audio_path, config)
    visualize_spectrogram(target_spectrogram, config, title='Target')

    tnca_model = TNCAModel(config.CELL_CHANNELS).to(device)
    perceptual_loss_fn = PerceptualLoss(device).to(device)
    # Only the automaton's parameters are handed to the optimizer.
    optimizer = optim.Adam(tnca_model.parameters(), lr=config.LEARNING_RATE)

    # Summary banner, emitted as one print with one entry per line.
    print(
        "\n" + "="*60,
        "CONSERVATIVE APPROACH:",
        "  ✅ NO circular padding (original zero padding)",
        "  ✅ Frequency weighting: 1.0 → 3.0 (moderate)",
        "  ✅ L1 weight: 0.3 (moderate increase)",
        "  ✅ dB conversion: 10^(dB/10) for power",
        "="*60,
        sep="\n",
    )

    trained_model = train(
        tnca_model, bigvgan_model, perceptual_loss_fn, optimizer, target_spectrogram, config
    )
    generated_audio = inference(trained_model, bigvgan_model, target_spectrogram, config)

    print("\n🎉 Conservative approach complete!")
    print("Expected: Stable training, slight high-freq improvement, correct energy")
else:
    # Without the target recording there is nothing to train on.
    print(f"ERROR: {target_audio_path} not found")