# Timbre Generation: Dual-Spectrogram Approach 🎼

**Innovation**: Generate BOTH mel + linear spectrograms simultaneously

**Why Dual Representations**:
- **Mel (80 bands)**: Perceptually relevant, matches human hearing, good for vocoder
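- **Linear (513 bins = N_FFT/2 + 1)**: Full STFT frequency resolution, so closely spaced harmonics (especially at high frequencies) stay separate instead of being merged into wide mel bands
- **Complementary**: Mel gives a perceptually weighted view that matches the vocoder's input; linear keeps the fine high-frequency detail that the mel filterbank compresses away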

**Architecture**:
```
Two TNCA grids — mel (80 × T) and linear (513 × T) — updated by ONE shared network.
Channel layout of each grid:
├─ Channel 0: Spectrogram value (mel on the mel grid, linear on the linear grid)
├─ Channel 1: Alpha (life mask)
└─ Channels 2-15: Hidden state
```

In [None]:
!pip install librosa soundfile bigvgan -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import soundfile as sf
import os
import bigvgan

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
class Config:
    """Hyper-parameters shared by preprocessing, training and inference.

    Everything is a class attribute, so either ``Config`` or ``Config()`` can
    be passed wherever a config object is expected.
    """

    # --- Audio / STFT ---
    SAMPLE_RATE = 22050
    N_FFT = 1024
    HOP_LENGTH = 256
    N_MELS = 80  # Mel spectrogram bands
    # Linear (STFT) bins are derived from N_FFT so the two can never disagree.
    N_LINEAR = N_FFT // 2 + 1  # 513 for N_FFT = 1024

    # --- TNCA ---
    CELL_CHANNELS = 16  # Total channels in TNCA
    UPDATE_STEPS_TRAIN = 96
    UPDATE_STEPS_INFERENCE = 96

    # --- Optimisation / logging ---
    LEARNING_RATE = 2e-4
    NUM_EPOCHS = 8000
    LOG_INTERVAL = 100
    OUTPUT_IMAGE_DIR = 'training_progress'

In [None]:
def pad_to_frame_multiple(waveform, hop_length, multiple=8):
    """Right-pad *waveform* with zeros so its centered STFT has a frame count divisible by *multiple*.

    With librosa's default ``center=True`` a signal of ``L`` samples produces
    ``1 + L // hop_length`` frames, so the padded length is chosen as
    ``(target_frames - 1) * hop_length``. The input is returned unchanged when
    its frame count is already a multiple.
    """
    n_frames = 1 + len(waveform) // hop_length
    if n_frames % multiple == 0:
        return waveform
    target_frames = -(-n_frames // multiple) * multiple  # integer ceil to a multiple
    target_samples = (target_frames - 1) * hop_length
    # target_frames > n_frames guarantees target_samples > len(waveform).
    return np.pad(waveform, (0, target_samples - len(waveform)), 'constant')


def create_dual_spectrograms(audio_path, config):
    """Create BOTH mel and linear spectrograms (in dB) for one audio file.

    The audio is loaded as mono at ``config.SAMPLE_RATE`` and zero-padded so
    the STFT frame count is a multiple of 8. Both spectrograms are power
    spectrograms converted with ``power_to_db(ref=np.max)``, i.e. each is
    normalised so its own maximum is 0 dB.

    NOTE(review): the mel filterbank uses librosa defaults (e.g. fmax = sr/2);
    BigVGAN's own feature extractor may use different settings — confirm
    against the vocoder's config if the vocoded audio sounds off.

    Returns:
        (mel_db, linear_db, waveform): ``mel_db`` has ``config.N_MELS`` rows,
        ``linear_db`` has ``config.N_FFT // 2 + 1`` rows, both share the same
        number of frames; ``waveform`` is the padded signal.
    """
    waveform, _ = librosa.load(audio_path, sr=config.SAMPLE_RATE, mono=True)
    waveform = pad_to_frame_multiple(waveform, config.HOP_LENGTH, 8)

    # Linear power spectrogram (full frequency resolution); computed once and
    # reused for the mel projection instead of running a second STFT.
    linear_spec = np.abs(
        librosa.stft(waveform, n_fft=config.N_FFT, hop_length=config.HOP_LENGTH)
    ) ** 2
    mel_spec = librosa.feature.melspectrogram(
        S=linear_spec, sr=config.SAMPLE_RATE, n_fft=config.N_FFT, n_mels=config.N_MELS
    )

    mel_db = librosa.power_to_db(mel_spec, ref=np.max)
    linear_db = librosa.power_to_db(linear_spec, ref=np.max)
    return mel_db, linear_db, waveform

def visualize_dual_spectrograms(mel_spec, linear_spec, config, title_prefix='', output_path=None):
    """Visualize both spectrograms side by side (mel left, linear right).

    Args:
        mel_spec: 2-D array or tensor (mel bands x frames), in dB.
        linear_spec: 2-D array or tensor (linear bins x frames), in dB.
        config: supplies SAMPLE_RATE and HOP_LENGTH for the axis labels.
        title_prefix: text prepended to each subplot title.
        output_path: if given, the figure is saved there and closed;
            otherwise it is shown interactively.
    """
    def _to_numpy(spec):
        # Tensors may still be attached to the autograd graph or live on GPU.
        if isinstance(spec, torch.Tensor):
            return spec.detach().cpu().numpy()
        return spec

    mel_spec = _to_numpy(mel_spec)
    linear_spec = _to_numpy(linear_spec)

    fig, (ax_mel, ax_lin) = plt.subplots(1, 2, figsize=(18, 5))
    panels = (
        (ax_mel, mel_spec, 'mel', 'Mel Spectrogram'),
        (ax_lin, linear_spec, 'linear', 'Linear Spectrogram'),
    )
    for ax, spec, y_axis, label in panels:
        mesh = librosa.display.specshow(spec, sr=config.SAMPLE_RATE, hop_length=config.HOP_LENGTH,
                                        x_axis='time', y_axis=y_axis, ax=ax)
        # Band count comes from the data so titles stay right if N_MELS / N_FFT change.
        ax.set_title(f'{title_prefix} {label} ({spec.shape[0]} bands)')
        fig.colorbar(mesh, ax=ax, format='%+2.0f dB')

    fig.tight_layout()
    if output_path:
        fig.savefig(output_path, bbox_inches='tight')
        # Close this specific figure so repeated calls during training don't leak figures.
        plt.close(fig)
    else:
        plt.show()

In [None]:
print("Loading BigVGAN...")
bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_base_22khz_80band', use_cuda_kernel=False)
# Fold weight normalisation into the conv weights once, as BigVGAN's own
# inference example does: same computation, less per-call work.
bigvgan_model.remove_weight_norm()
bigvgan_model = bigvgan_model.to(device).eval()
print("BigVGAN loaded!")

In [None]:
class DualSpecTNCAModel(nn.Module):
    """TNCA that generates BOTH mel and linear spectrograms.

    One update network is shared by two independent grids (a mel grid and a
    linear grid of different heights). Per grid, channel 0 holds the
    spectrogram value, channel 1 is alpha (the life mask) and the remaining
    channels are hidden state.
    """

    def __init__(self, num_channels=16, n_mels=80, n_linear=513):
        super().__init__()
        self.num_channels = num_channels
        self.n_mels = n_mels
        self.n_linear = n_linear

        perception_vector_size = self.num_channels * 4

        self.update_mlp = nn.Sequential(
            nn.Conv2d(perception_vector_size, 128, 1), nn.ReLU(),
            nn.Conv2d(128, 64, 1), nn.ReLU(),
            nn.Conv2d(64, self.num_channels, 1, bias=True)
        )

        # Zero-init the last layer so the initial update is just the bias.
        self.update_mlp[-1].weight.data.zero_()
        self.update_mlp[-1].bias.data.zero_()
        self.update_mlp[-1].bias.data[1] = 1.0  # Alpha channel

        # Fixed perception filters, built once instead of on every step.
        # Non-persistent buffers follow .to(device) but stay out of state_dict.
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        sobel_y = sobel_x.transpose(-2, -1).contiguous()
        laplacian = torch.tensor([[1, 2, 1], [2, -12, 2], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('sobel_x_kernel', sobel_x.repeat(num_channels, 1, 1, 1), persistent=False)
        self.register_buffer('sobel_y_kernel', sobel_y.repeat(num_channels, 1, 1, 1), persistent=False)
        self.register_buffer('laplacian_kernel', laplacian.repeat(num_channels, 1, 1, 1), persistent=False)

    def perceive(self, grid):
        """Return [state, Sobel-x, Sobel-y, Laplacian] stacked on the channel axis (4 * num_channels)."""
        groups = self.num_channels  # depthwise: each channel filtered independently
        grad_x = F.conv2d(grid, self.sobel_x_kernel, padding=1, groups=groups)
        grad_y = F.conv2d(grid, self.sobel_y_kernel, padding=1, groups=groups)
        lap = F.conv2d(grid, self.laplacian_kernel, padding=1, groups=groups)
        return torch.cat([grid, grad_x, grad_y, lap], dim=1)

    def _apply_life_mask(self, grid):
        """Clamp alpha (channel 1) to [0, 1] and zero every cell with no living 3x3 neighbour."""
        alpha = torch.clamp(grid[:, 1:2, :, :], 0.0, 1.0)
        grid = torch.cat([grid[:, :1, :, :], alpha, grid[:, 2:, :, :]], dim=1)
        living_mask = F.max_pool2d(alpha, kernel_size=3, stride=1, padding=1)
        return grid * living_mask

    def _step(self, grid):
        """One residual TNCA update of a single grid (before masking)."""
        return grid + self.update_mlp(self.perceive(grid))

    def forward(self, grid_mel, grid_linear):
        """Forward pass for BOTH grids; returns the updated (grid_mel, grid_linear)."""
        # The masked grids must be returned; rebinding a loop variable (as a
        # `for grid in [...]` loop would) leaves the originals untouched.
        grid_mel = self._apply_life_mask(self._step(grid_mel))
        grid_linear = self._apply_life_mask(self._step(grid_linear))
        return grid_mel, grid_linear

class PerceptualLoss(nn.Module):
    """Gram-matrix (style) loss on features from a small, randomly initialised CNN.

    The feature extractor is a fixed random projection: in this notebook the
    optimiser only receives the TNCA model's parameters, so the extractor's
    weights are frozen here. Gradients still flow through it to the generated
    spectrogram, but no parameter gradients are computed or accumulated.
    """

    def __init__(self, device):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 16, 7, 2, 3), nn.ReLU(),
            nn.Conv2d(16, 32, 5, 2, 2), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU()
        ).to(device)
        self.feature_extractor.requires_grad_(False)

    def gram_matrix(self, f):
        """Return the (b, c, c) Gram matrix of a (b, c, h, w) feature map, normalised by c*h*w."""
        b, c, h, w = f.size()
        f = f.reshape(b, c, h * w)  # reshape also tolerates non-contiguous inputs
        return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)

    def forward(self, gen_mel, target_mel):
        """MSE between the Gram matrices of generated and target spectrograms ((b, 1, H, W) inputs)."""
        gen_gram = self.gram_matrix(self.feature_extractor(gen_mel))
        # The target is a constant; don't build an autograd graph for it.
        with torch.no_grad():
            target_gram = self.gram_matrix(self.feature_extractor(target_mel))
        return F.mse_loss(gen_gram, target_gram)

In [None]:
def train_dual(model, vocoder, loss_fn, optimizer, target_mel, target_linear, config):
    """Train *model* to grow both target spectrograms from a single seed cell per grid.

    Every epoch restarts from the same seeds (one live cell, alpha = 1, in the
    centre of each grid), runs ``config.UPDATE_STEPS_TRAIN`` TNCA steps,
    decodes channel 0 of each grid to dB with ``tanh(x) * 40 - 40`` (so the
    output lies in (-80, 0) dB) and minimises

        2.0 * perceptual(mel) + 0.3 * L1(mel) + 0.5 * L1(linear)

    where every term only sees target cells louder than -70 dB.

    Args:
        model: module mapping (grid_mel, grid_linear) -> (grid_mel, grid_linear).
        vocoder: unused during training; kept so the signature mirrors inference_dual.
        loss_fn: perceptual loss applied to the masked mel spectrograms.
        optimizer: optimiser over *model*'s parameters.
        target_mel: 2-D array (mel bands x frames), in dB.
        target_linear: 2-D array (linear bins x frames), in dB.
        config: supplies CELL_CHANNELS, N_MELS, N_LINEAR, UPDATE_STEPS_TRAIN,
            NUM_EPOCHS, LOG_INTERVAL and OUTPUT_IMAGE_DIR.

    Returns:
        The trained *model* (the same object that was passed in).
    """
    print("\n--- Dual-Spectrogram Training ---")
    model.train()
    perceptual_loss_fn = loss_fn
    l1_loss_fn = nn.L1Loss()

    def _target(spec):
        # (rows, frames) array -> (1, 1, rows, frames) tensor on the training device.
        return torch.as_tensor(spec, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(1)

    def _seed(n_rows, n_frames):
        # Empty grid with a single live cell (alpha = 1) in the centre.
        seed = torch.zeros(1, config.CELL_CHANNELS, n_rows, n_frames, device=device)
        seed[:, 1, n_rows // 2, n_frames // 2] = 1.0
        return seed

    target_mel_tensor = _target(target_mel)
    target_linear_tensor = _target(target_linear)

    # Only supervise cells that are audible in the target.
    target_mel_mask = (target_mel_tensor > -70.0).float()
    target_linear_mask = (target_linear_tensor > -70.0).float()
    # Masked targets do not change between epochs: compute them once.
    masked_target_mel = target_mel_tensor * target_mel_mask
    masked_target_linear = target_linear_tensor * target_linear_mask

    # The two grids have DIFFERENT heights (N_MELS vs N_LINEAR rows).
    seed_grid_mel = _seed(config.N_MELS, target_mel.shape[1])
    seed_grid_linear = _seed(config.N_LINEAR, target_linear.shape[1])

    total_loss = None  # stays None when NUM_EPOCHS == 0
    pbar = tqdm(range(config.NUM_EPOCHS), desc="Training...")

    for epoch in pbar:
        grid_mel = seed_grid_mel.clone()
        grid_linear = seed_grid_linear.clone()

        # Evolve BOTH grids
        for _ in range(config.UPDATE_STEPS_TRAIN):
            grid_mel, grid_linear = model(grid_mel, grid_linear)

        # Decode channel 0 to dB
        gen_mel_db = torch.tanh(grid_mel[:, 0:1, :, :]) * 40.0 - 40.0
        gen_linear_db = torch.tanh(grid_linear[:, 0:1, :, :]) * 40.0 - 40.0

        # Mel: perceptual + L1
        masked_gen_mel = gen_mel_db * target_mel_mask
        p_loss_mel = perceptual_loss_fn(masked_gen_mel, masked_target_mel)
        l1_loss_mel = l1_loss_fn(masked_gen_mel, masked_target_mel)
        mel_loss = 2.0 * p_loss_mel + 0.3 * l1_loss_mel

        # Linear: L1 only (no perceptual term)
        linear_loss = l1_loss_fn(gen_linear_db * target_linear_mask, masked_target_linear)

        total_loss = mel_loss + 0.5 * linear_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if (epoch + 1) % config.LOG_INTERVAL == 0:
            pbar.set_description(
                f"Epoch {epoch+1}, Loss: {total_loss.item():.6f} "
                f"(Mel: {mel_loss.item():.6f}, Lin: {linear_loss.item():.6f})"
            )
            filepath = os.path.join(config.OUTPUT_IMAGE_DIR, f'epoch_{(epoch+1):04d}.png')
            visualize_dual_spectrograms(
                gen_mel_db.squeeze(), gen_linear_db.squeeze(), config,
                title_prefix=f'Epoch {epoch+1}', output_path=filepath
            )

    if total_loss is None:
        print("\nTraining skipped: NUM_EPOCHS is 0.")
    else:
        print(f"\nTraining complete. Final loss: {total_loss.item():.6f}")
    return model

In [None]:
def db_to_bigvgan_log_mel(mel_db, clip_val=1e-5):
    """Convert a power-dB mel spectrogram to the natural-log magnitude scale BigVGAN consumes.

    BigVGAN's reference feature extractor feeds the vocoder
    ``log(clamp(mel_magnitude, min=1e-5))``, i.e. a natural-log magnitude, not
    linear power. Power dB -> magnitude is ``10 ** (dB / 20)``, whose natural
    log is ``dB * ln(10) / 20``; the result is clamped at ``ln(clip_val)``.

    NOTE(review): this is an approximation. librosa's mel here is built from
    |STFT|**2 (mel of power) while BigVGAN builds it from |STFT|. The
    ``ref=np.max`` normalisation also shifts the overall level, and the
    filterbank settings may differ — confirm against the vocoder config.
    """
    scale = float(np.log(10.0)) / 20.0
    return torch.clamp(mel_db * scale, min=float(np.log(clip_val)))


def inference_dual(model, vocoder, target_mel, target_linear, config, output_path='generated_dual.wav'):
    """Grow both spectrograms from a seed, vocode the mel one with *vocoder* and save the audio.

    Only the target shapes (frame counts) are used; their values are not.
    Writes a WAV file to *output_path*, prints its RMS energy and plots the
    generated spectrograms.

    Returns:
        The generated waveform as a NumPy array.
    """
    print("\n--- Dual-Spectrogram Inference ---")
    model.eval()

    n_frames_mel = target_mel.shape[1]
    n_frames_linear = target_linear.shape[1]

    # Same seeds as training: a single live cell in the centre of each grid.
    grid_mel = torch.zeros(1, config.CELL_CHANNELS, config.N_MELS, n_frames_mel, device=device)
    grid_linear = torch.zeros(1, config.CELL_CHANNELS, config.N_LINEAR, n_frames_linear, device=device)
    grid_mel[:, 1, config.N_MELS // 2, n_frames_mel // 2] = 1.0
    grid_linear[:, 1, config.N_LINEAR // 2, n_frames_linear // 2] = 1.0

    with torch.no_grad():
        for _ in tqdm(range(config.UPDATE_STEPS_INFERENCE), desc="Generating..."):
            grid_mel, grid_linear = model(grid_mel, grid_linear)

        # Decode channel 0 exactly as in training: values in (-80, 0) dB.
        final_mel_db = torch.tanh(grid_mel[:, 0, :, :]) * 40.0 - 40.0
        final_linear_db = torch.tanh(grid_linear[:, 0, :, :]) * 40.0 - 40.0

        # BigVGAN takes a log-magnitude mel, not linear power.
        vocoder_input = db_to_bigvgan_log_mel(final_mel_db)
        if vocoder_input.dim() == 2:
            vocoder_input = vocoder_input.unsqueeze(0)

        print("Using MEL spectrogram for BigVGAN vocoding...")
        final_waveform = vocoder(vocoder_input).squeeze()

    waveform_np = final_waveform.cpu().numpy()
    sf.write(output_path, waveform_np, config.SAMPLE_RATE)

    print(f"\n✅ Audio saved: {output_path}")
    print(f"RMS energy: {np.sqrt(np.mean(waveform_np**2)):.4f}")

    visualize_dual_spectrograms(final_mel_db.squeeze().cpu(), final_linear_db.squeeze().cpu(),
                                config, title_prefix='Final')
    return waveform_np

In [None]:
# Main execution: build the targets, train the dual-grid TNCA, then vocode the result.
config = Config()
os.makedirs(config.OUTPUT_IMAGE_DIR, exist_ok=True)

target_audio_path = 'violin.wav'

if os.path.exists(target_audio_path):
    target_mel, target_linear, waveform = create_dual_spectrograms(target_audio_path, config)
    visualize_dual_spectrograms(target_mel, target_linear, config, title_prefix='Target')

    model = DualSpecTNCAModel(config.CELL_CHANNELS, config.N_MELS, config.N_LINEAR).to(device)
    perceptual_loss_fn = PerceptualLoss(device).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    # Summary banner (one print, one line per argument).
    print(
        "\n" + "=" * 60,
        "DUAL-SPECTROGRAM APPROACH:",
        f"  ✅ Mel: {config.N_MELS} bands (perceptual)",
        f"  ✅ Linear: {config.N_LINEAR} bands (full resolution)",
        "  ✅ Complementary representations",
        "  ✅ Better high-frequency capture expected",
        "=" * 60,
        sep="\n",
    )

    trained_model = train_dual(
        model, bigvgan_model, perceptual_loss_fn, optimizer,
        target_mel, target_linear, config,
    )
    generated_audio = inference_dual(trained_model, bigvgan_model, target_mel, target_linear, config)

    print("\n🎉 Dual-spectrogram training complete!")
    print("Linear spectrogram should capture high-frequency detail missed by mel-only")
else:
    print(f"ERROR: {target_audio_path} not found")