# Generating Audio Timbres with Neural Cellular Automata 🎶
## Phase 1: BigVGAN Integration (FIXED VERSION)

**Fixes Applied**:
1. ✅ Corrected dB-to-power conversion (was `/20`, now `/10`)
2. ✅ Increased frequency weighting (from `2.0` to `5.0`)
3. ✅ Increased L1 loss weight (from `0.1` to `0.5`)
4. ✅ Added circular padding for CA boundary conditions

**Expected Improvements**:
- Correct energy distribution (output is no longer ~4× too loud)
- Better high-frequency content (above 4kHz)
- Reduced edge artifacts
- More natural violin timbre

In [None]:
# Install necessary libraries
!pip install librosa soundfile
!pip install bigvgan  # Universal vocoder for music/speech

# Import all required modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import soundfile as sf
import os
import bigvgan

# Select the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
class Config:
    """Hyperparameters shared by preprocessing, training and inference."""

    # --- Audio / mel-spectrogram analysis ---
    SAMPLE_RATE: int = 22050
    N_FFT: int = 1024
    HOP_LENGTH: int = 256
    N_MELS: int = 80  # Matches BigVGAN base model

    # --- Cellular automaton ---
    CELL_CHANNELS: int = 16  # channels per grid cell
    UPDATE_STEPS_TRAIN: int = 96  # CA update steps per training epoch
    UPDATE_STEPS_INFERENCE: int = 96  # CA update steps when generating

    # --- Optimisation / logging ---
    LEARNING_RATE: float = 2e-4
    NUM_EPOCHS: int = 8000
    LOG_INTERVAL: int = 100  # epochs between progress snapshots
    OUTPUT_IMAGE_DIR: str = "training_progress"

In [None]:
def create_auditory_grid(audio_path, config):
    """Load an audio file and return its mel spectrogram in decibels.

    The waveform is resampled to ``config.SAMPLE_RATE``, mixed down to mono
    and zero-padded at the end so that it spans a whole number of 8-frame
    groups. The returned spectrogram then has a time dimension that is an
    exact multiple of 8.

    Args:
        audio_path: Path of the audio file to load.
        config: Object providing ``SAMPLE_RATE``, ``N_FFT``, ``HOP_LENGTH``
            and ``N_MELS``.

    Returns:
        A 2-D array of shape ``(N_MELS, frames)`` holding the mel power
        spectrogram in dB relative to its own maximum (``ref=np.max``), so
        the loudest bin is 0 dB and the absolute level is discarded.
    """
    waveform, _ = librosa.load(audio_path, sr=config.SAMPLE_RATE, mono=True)

    # Round the length up to a whole number of 8-hop groups using integer
    # arithmetic (the negated floor division is a ceiling division).
    samples_per_group = 8 * config.HOP_LENGTH
    n_groups = -(-len(waveform) // samples_per_group)
    target_frames = n_groups * 8
    padding_needed = n_groups * samples_per_group - len(waveform)
    if padding_needed > 0:
        waveform = np.pad(waveform, (0, padding_needed), 'constant')

    mel_spectrogram = librosa.feature.melspectrogram(
        y=waveform,
        sr=config.SAMPLE_RATE,
        n_fft=config.N_FFT,
        hop_length=config.HOP_LENGTH,
        n_mels=config.N_MELS,
    )

    # With librosa's default centred framing the STFT yields
    # 1 + len(y) // hop_length frames, i.e. one more than a multiple of 8.
    # Drop the trailing frame (centred on the very end of the padded signal)
    # so the time dimension really is a multiple of 8.
    mel_spectrogram = mel_spectrogram[:, :target_frames]

    # NOTE(review): the mel settings here are librosa defaults (power
    # spectrogram, full-band filters); the pretrained vocoder may have been
    # trained with different fmin/fmax - confirm against its config.
    return librosa.power_to_db(mel_spectrogram, ref=np.max)

def visualize_spectrogram(spectrogram, config, title='Mel-Spectrogram', output_path=None):
    """Plot a mel spectrogram and either save it or display it.

    Args:
        spectrogram: 2-D spectrogram as a NumPy array or a ``torch.Tensor``
            (tensors are detached and copied to host memory first).
        config: Object providing ``SAMPLE_RATE`` and ``HOP_LENGTH`` for the
            axis scaling.
        title: Figure title.
        output_path: When truthy, the figure is written to this path and
            closed; otherwise it is shown interactively.
    """
    plt.figure(figsize=(12, 5))

    data = spectrogram
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()

    axis_options = dict(
        sr=config.SAMPLE_RATE,
        hop_length=config.HOP_LENGTH,
        x_axis='time',
        y_axis='mel',
    )
    librosa.display.specshow(data, **axis_options)
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()

    if not output_path:
        plt.show()
        return

    # Saving path: close the figure afterwards so repeated calls do not
    # keep figures open.
    plt.savefig(output_path, bbox_inches='tight')
    plt.close()

In [None]:
# Load the pretrained BigVGAN vocoder, move it to the selected device and
# switch it to evaluation mode.
print("Loading BigVGAN universal vocoder...")
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
    'nvidia/bigvgan_base_22khz_80band', use_cuda_kernel=False
).to(device)
bigvgan_model.eval()
print("BigVGAN vocoder loaded successfully!")
print("Sample rate: 22kHz, Mel bands: 80")

In [None]:
class TNCAModel(nn.Module):
    """Textured Neural Cellular Automata with circular padding (FIXED).

    Each cell holds ``num_channels`` values. Channel 1 is treated as the
    "alpha" (alive) channel: it is clamped to ``[0, 1]`` after every update
    and its 3x3 neighbourhood maximum gates the whole cell state.

    Args:
        num_channels: Number of channels per cell. Must be at least 2 because
            channel 1 is the alpha channel.
    """

    def __init__(self, num_channels=16):
        super().__init__()
        self.num_channels = num_channels
        # Perception = identity + Sobel-x + Sobel-y + Laplacian per channel.
        perception_vector_size = self.num_channels * 4

        self.update_mlp = nn.Sequential(
            nn.Conv2d(perception_vector_size, 128, 1),
            nn.ReLU(),
            nn.Conv2d(128, 64, 1),
            nn.ReLU(),
            nn.Conv2d(64, self.num_channels, 1, bias=True)
        )

        # Initialize for stability: the last layer starts as a constant
        # update that only pushes the alpha channel upwards.
        last_layer = self.update_mlp[-1]
        with torch.no_grad():
            last_layer.weight.zero_()
            last_layer.bias.zero_()
            last_layer.bias[1] = 1.0  # Alpha channel bias

        # The perception stencils are constants, so build them once here
        # instead of on every CA step. They are registered as non-persistent
        # buffers: they follow the module across devices but are not added to
        # the state_dict, keeping existing checkpoints loadable.
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
        )
        laplacian = torch.tensor(
            [[1, 2, 1], [2, -12, 2], [1, 2, 1]], dtype=torch.float32
        )
        self.register_buffer(
            "_sobel_x_kernel",
            self._depthwise_kernel(sobel_x, num_channels),
            persistent=False,
        )
        self.register_buffer(
            "_sobel_y_kernel",
            self._depthwise_kernel(sobel_x.t(), num_channels),
            persistent=False,
        )
        self.register_buffer(
            "_laplacian_kernel",
            self._depthwise_kernel(laplacian, num_channels),
            persistent=False,
        )

    @staticmethod
    def _depthwise_kernel(stencil, num_channels):
        """Tile a 3x3 stencil into a ``(num_channels, 1, 3, 3)`` conv weight."""
        return stencil.contiguous().view(1, 1, 3, 3).repeat(num_channels, 1, 1, 1)

    def perceive(self, grid):
        """Apply perception filters with CIRCULAR padding (FIX #4).

        Args:
            grid: Cell state of shape ``(batch, num_channels, H, W)``.

        Returns:
            Tensor of shape ``(batch, 4 * num_channels, H, W)``: the grid
            itself followed by its Sobel-x, Sobel-y and Laplacian responses.
        """
        # Wrap the grid on both spatial axes to avoid edge artifacts.
        # Pad: (left, right, top, bottom)
        grid_padded = F.pad(grid, (1, 1, 1, 1), mode='circular')

        # padding=0 because the grid was pre-padded; groups=num_channels makes
        # each filter act on its own channel only.
        grad_x = F.conv2d(
            grid_padded, self._sobel_x_kernel, padding=0, groups=self.num_channels
        )
        grad_y = F.conv2d(
            grid_padded, self._sobel_y_kernel, padding=0, groups=self.num_channels
        )
        lap = F.conv2d(
            grid_padded, self._laplacian_kernel, padding=0, groups=self.num_channels
        )

        return torch.cat([grid, grad_x, grad_y, lap], dim=1)

    def forward(self, grid):
        """Advance the automaton by one step and return the new grid."""
        perception_vector = self.perceive(grid)
        ds = self.update_mlp(perception_vector)
        grid = grid + ds

        # Clamp alpha channel; rebuilt with cat instead of an in-place write
        # so autograd history stays intact.
        clamped_alpha = torch.clamp(grid[:, 1:2, :, :], 0.0, 1.0)
        grid = torch.cat([grid[:, :1, :, :], clamped_alpha, grid[:, 2:, :, :]], dim=1)

        # Apply living mask (also with circular padding): every channel is
        # scaled by the maximum alpha found in the cell's 3x3 neighbourhood.
        alpha_padded = F.pad(clamped_alpha, (1, 1, 1, 1), mode='circular')
        living_mask = F.max_pool2d(alpha_padded, kernel_size=3, stride=1, padding=0)
        grid = grid * living_mask

        return grid

class PerceptualLoss(nn.Module):
    """Perceptual (texture) loss based on Gram matrices of conv features.

    The feature extractor is a small stack of strided convolutions with its
    default random initialisation; no pretrained weights are loaded. It is
    frozen so that backpropagating the loss does not compute or accumulate
    gradients for weights that no optimizer updates.

    Args:
        device: Device on which the feature extractor is placed.
    """

    def __init__(self, device):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 16, 7, 2, 3), nn.ReLU(),
            nn.Conv2d(16, 32, 5, 2, 2), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU()
        ).to(device)
        # Fixed random features: gradients still flow to the *inputs*, but
        # the extractor's own parameters are excluded from autograd.
        self.feature_extractor.requires_grad_(False)

    def gram_matrix(self, f):
        """Return the normalised channel-by-channel Gram matrix of ``f``.

        Args:
            f: Feature map of shape ``(batch, channels, H, W)``.

        Returns:
            Tensor of shape ``(batch, channels, channels)`` equal to
            ``F @ F^T / (channels * H * W)`` with ``F`` the flattened features.
        """
        b, c, h, w = f.size()
        # reshape (rather than view) also accepts non-contiguous inputs.
        flat = f.reshape(b, c, h * w)
        return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)

    def forward(self, gen_mel, target_mel):
        """Mean squared error between the Gram matrices of both inputs.

        Both arguments are single-channel images of shape
        ``(batch, 1, H, W)``, as required by the first convolution.
        """
        gen_gram = self.gram_matrix(self.feature_extractor(gen_mel))
        target_gram = self.gram_matrix(self.feature_extractor(target_mel))
        return F.mse_loss(gen_gram, target_gram)

In [None]:
def train(model, vocoder, loss_fn, optimizer, target_mel, config):
    """Train the TNCA model to grow the target mel spectrogram from a seed.

    Every epoch starts from the same single-cell seed, runs the automaton for
    ``config.UPDATE_STEPS_TRAIN`` steps, maps channel 0 to the ``[-80, 0]`` dB
    range with ``tanh`` and compares it with the target on the non-silent
    region (target above -70 dB).

    The loss is ``2.0 * w_mean * perceptual + 0.5 * weighted_L1`` where the L1
    error is weighted per mel bin (1.0 at the lowest bin, rising linearly to
    5.0 at the highest) *before* averaging. Previously the weights multiplied
    the already-reduced scalar loss, which only rescaled it by their mean and
    gave no frequency emphasis. The perceptual term is multiplied by that
    mean weight (``w_mean``) so the overall loss scale is unchanged.

    Args:
        model: The cellular-automaton module to optimise.
        vocoder: Unused here; kept so the call signature stays the same.
        loss_fn: Callable ``(generated, target) -> scalar`` perceptual loss.
        optimizer: Optimizer holding ``model``'s parameters.
        target_mel: 2-D target spectrogram in dB, shape ``(n_mels, frames)``.
        config: Object providing ``CELL_CHANNELS``, ``UPDATE_STEPS_TRAIN``,
            ``NUM_EPOCHS``, ``LOG_INTERVAL`` and ``OUTPUT_IMAGE_DIR``
            (plus whatever ``visualize_spectrogram`` reads).

    Returns:
        The same ``model`` instance, trained in place.
    """
    print("\n--- Starting Training (FIXED VERSION) ---")
    perceptual_loss_fn = loss_fn

    # Progress snapshots are written here; create it so train() also works
    # when called without the notebook's setup cell.
    os.makedirs(config.OUTPUT_IMAGE_DIR, exist_ok=True)

    # Shape (1, 1, n_mels, frames).
    target_mel_tensor = torch.as_tensor(
        target_mel, dtype=torch.float32
    ).to(device).unsqueeze(0).unsqueeze(1)

    # Mask for non-silent regions
    target_mask = (target_mel_tensor > -70.0).float()
    # Loop-invariant, so computed once.
    masked_target = target_mel_tensor * target_mask

    # FIX #2: INCREASED frequency weighting (1.0 -> 5.0 instead of 2.0)
    n_mels, n_frames = target_mel_tensor.shape[2], target_mel_tensor.shape[3]
    frequency_loss_weights = torch.linspace(
        1.0, 5.0, n_mels, device=device  # Changed from 2.0 to 5.0
    ).view(1, 1, n_mels, 1)
    mean_frequency_weight = frequency_loss_weights.mean()
    print(f"Frequency weighting: 1.0 (low freq) -> 5.0 (high freq)")

    # Initialize seed grid: all zeros except one live cell (alpha channel 1)
    # in the centre.
    seed_grid = torch.zeros(
        1, config.CELL_CHANNELS, n_mels, n_frames, device=device
    )
    seed_grid[:, 1, n_mels // 2, n_frames // 2] = 1.0

    pbar = tqdm(range(config.NUM_EPOCHS), desc="Training...")

    loss = None  # stays None when NUM_EPOCHS == 0
    for epoch in pbar:
        grid = seed_grid.clone()

        # Evolve the grid
        for _ in range(config.UPDATE_STEPS_TRAIN):
            grid = model(grid)

        # Extract channel 0 and scale to dB range: tanh in [-1, 1] maps to
        # [-80, 0].
        generated_mel_db = torch.tanh(grid[:, 0:1, :, :]) * 40.0 - 40.0
        masked_generated = generated_mel_db * target_mask

        # Calculate losses
        p_loss = perceptual_loss_fn(masked_generated, masked_target)
        abs_error = (masked_generated - masked_target).abs()
        l1_loss = abs_error.mean()  # unweighted, reported in the log
        weighted_l1_loss = (abs_error * frequency_loss_weights).mean()

        # FIX #3: INCREASED L1 loss weight (0.1 -> 0.5)
        loss = 2.0 * mean_frequency_weight * p_loss + 0.5 * weighted_l1_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logging
        if (epoch + 1) % config.LOG_INTERVAL == 0:
            pbar.set_description(
                f"Epoch {epoch+1}, Loss: {loss.item():.6f} "
                f"(P: {p_loss.item():.6f}, L1: {l1_loss.item():.6f})"
            )
            filepath = os.path.join(
                config.OUTPUT_IMAGE_DIR,
                f'epoch_{(epoch+1):04d}.png'
            )
            visualize_spectrogram(
                generated_mel_db.squeeze(),
                config,
                title=f'Generated Spectrogram - Epoch {epoch+1} (FIXED)',
                output_path=filepath
            )

    if loss is None:
        print("\nTraining complete. No epochs were run.")
    else:
        print(f"\nTraining complete. Final loss: {loss.item():.6f}")
    return model

In [None]:
def inference(model, vocoder, target_mel, config, output_path='generated_bigvgan_FIXED.wav'):
    """Grow a spectrogram with the trained automaton and vocode it to audio.

    The automaton is run for ``config.UPDATE_STEPS_INFERENCE`` steps from the
    single-cell seed, channel 0 is mapped to dB, converted to mel power
    (``10 ** (dB / 10)``) and then log-compressed before being passed to the
    vocoder. The waveform is written to ``output_path`` and the generated
    spectrogram is displayed.

    Args:
        model: Trained cellular-automaton module (switched to eval mode).
        vocoder: Callable mapping a ``(batch, mels, frames)`` log-mel tensor
            to a waveform tensor.
        target_mel: Array or tensor of shape ``(n_mels, frames)``; only its
            frame count is used, to size the generated grid.
        config: Object providing ``CELL_CHANNELS``, ``N_MELS``,
            ``UPDATE_STEPS_INFERENCE`` and ``SAMPLE_RATE``.
        output_path: Destination audio file.

    Returns:
        The generated waveform as a 1-D NumPy array.
    """
    print("\n--- Running Inference with BigVGAN (FIXED) ---")

    model.eval()

    # target_mel is only consulted for its frame count, which both NumPy
    # arrays and tensors expose through .shape.
    n_frames = target_mel.shape[1]

    # Initialize seed grid: one live cell (alpha channel 1) in the centre.
    seed_grid = torch.zeros(
        1, config.CELL_CHANNELS, config.N_MELS, n_frames, device=device
    )
    h, w = seed_grid.shape[2], seed_grid.shape[3]
    seed_grid[:, 1, h//2, w//2] = 1.0

    with torch.no_grad():
        final_grid = seed_grid.clone()

        for _ in tqdm(range(config.UPDATE_STEPS_INFERENCE), desc="Generating..."):
            final_grid = model(final_grid)

        # Extract channel 0 and scale to [-80, 0] dB; shape (1, mels, frames).
        final_mel_db_unscaled = final_grid[:, 0, :, :]
        final_mel_db = torch.tanh(final_mel_db_unscaled) * 40.0 - 40.0

        # FIX #1: CORRECT dB to power conversion
        # librosa.power_to_db uses: dB = 10 * log10(power)
        # Therefore: power = 10^(dB/10)
        print("Converting dB to power spectrum (10^(dB/10))...")
        final_mel_power = torch.pow(10.0, final_mel_db / 10.0)  # FIXED: was /20, now /10

        print(f"Mel spectrogram shape: {final_mel_power.shape}")
        print(f"Mel power range: [{final_mel_power.min():.6f}, {final_mel_power.max():.6f}]")

        # BigVGAN is trained on log-compressed mel *magnitudes*,
        # log(clamp(mel_magnitude, min=1e-5)), not on linear power values.
        # Feeding linear power in [0, 1] is read by the vocoder as uniformly
        # loud log-mel frames.
        # NOTE(review): the target was normalised with ref=np.max, so the
        # absolute level is unknown and the peak maps to magnitude 1.0 here;
        # sqrt(mel power) is also only an approximation of a mel filterbank
        # applied to magnitudes - confirm the level/timbre by listening.
        print("Applying BigVGAN log-magnitude compression...")
        final_mel_magnitude = torch.sqrt(final_mel_power)
        vocoder_input = torch.log(torch.clamp(final_mel_magnitude, min=1e-5))

        # BigVGAN expects [batch, mels, time]
        if vocoder_input.dim() == 2:
            vocoder_input = vocoder_input.unsqueeze(0)

        print("Generating audio with BigVGAN vocoder...")

        # BigVGAN inference
        final_waveform = vocoder(vocoder_input)

        # BigVGAN returns [batch, 1, samples], squeeze to [samples]
        final_waveform = final_waveform.squeeze()

    # Save audio
    waveform_np = final_waveform.cpu().numpy()
    sf.write(output_path, waveform_np, config.SAMPLE_RATE)

    print(f"\n✅ Inference complete!")
    print(f"Audio saved to: {output_path}")
    print(f"Audio duration: {len(waveform_np) / config.SAMPLE_RATE:.2f} seconds")
    print(f"RMS energy: {np.sqrt(np.mean(waveform_np**2)):.4f}")

    visualize_spectrogram(
        final_mel_db.squeeze().cpu(),
        config,
        title='Final Generated Spectrogram (BigVGAN FIXED)'
    )

    return waveform_np

In [None]:
# --- Main Execution ---
# End-to-end run: load the target audio, train the automaton on its
# spectrogram, then generate audio through the vocoder.

config = Config()
os.makedirs(config.OUTPUT_IMAGE_DIR, exist_ok=True)

# Load target audio
target_audio_path = 'violin.wav'  # <-- Update with your file path

if os.path.exists(target_audio_path):
    print("Loading target audio...")
    target_spectrogram = create_auditory_grid(target_audio_path, config)

    print("\nTarget Spectrogram:")
    visualize_spectrogram(target_spectrogram, config, title='Target Spectrogram')

    # Build the automaton, the loss module and the optimizer.
    print("\nInitializing TNCA model (FIXED with circular padding)...")
    tnca_model = TNCAModel(config.CELL_CHANNELS).to(device)
    perceptual_loss_fn = PerceptualLoss(device).to(device)
    optimizer = optim.Adam(tnca_model.parameters(), lr=config.LEARNING_RATE)

    # Summary banner, emitted as one newline-joined message.
    print("\n".join([
        "\n" + "=" * 70,
        "FIXES APPLIED:",
        "  1. ✅ dB-to-power: Changed /20 to /10 (correct for power spectrograms)",
        "  2. ✅ Frequency weights: Increased from 2.0 to 5.0 (emphasize high freq)",
        "  3. ✅ L1 loss weight: Increased from 0.1 to 0.5 (capture more detail)",
        "  4. ✅ Circular padding: Eliminate edge artifacts in CA",
        "=" * 70,
    ]))

    # Train
    trained_model = train(
        model=tnca_model,
        vocoder=bigvgan_model,
        loss_fn=perceptual_loss_fn,
        optimizer=optimizer,
        target_mel=target_spectrogram,
        config=config,
    )

    # Generate audio with BigVGAN
    generated_audio = inference(
        model=trained_model,
        vocoder=bigvgan_model,
        target_mel=target_spectrogram,
        config=config,
        output_path='generated_bigvgan_FIXED.wav',
    )

    print("\n".join([
        "\n" + "=" * 70,
        "🎉 TRAINING COMPLETE WITH FIXES!",
        "=" * 70,
        "\nCompare 'generated_bigvgan_FIXED.wav' with 'generated_bigvgan-2.wav'",
        "\nExpected improvements:",
        "  - Correct energy level (RMS ~0.10, not 0.42)",
        "  - Better high-frequency content (above 4kHz)",
        "  - No edge artifacts (vertical bands)",
        "  - More natural violin timbre",
        "\nIf quality is good, proceed to Phase 2: Audio-domain loss",
    ]))
else:
    # Nothing to train on without the target recording.
    print(f"ERROR: Audio file not found at '{target_audio_path}'")
    print("Please upload the file and update the path.")