# MusicControlNet Training Test

This notebook smoke-tests the training pipeline on a small subset of the NSynth dataset: 100 examples streamed from the NSynth train split and re-split into 70 train, 15 validation, and 15 test examples.

## Objectives:
- Load a small subset of NSynth dataset
- Prepare audio data (convert to mel-spectrograms)
- Build stand-in conditioning embeddings from metadata (instrument family, pitch, velocity, source) in place of real text embeddings
- Train the UNetMelGenerator model
- Validate the training loop works correctly

## 1. Install Required Libraries

In [None]:
!pip install datasets librosa matplotlib seaborn pandas numpy plotly soundfile tqdm torch torchaudio transformers -q


## 2. Import Libraries

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import librosa
from datasets import load_dataset
from huggingface_hub import HfApi
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import our model
import sys
sys.path.append('.')
from model import build_model

# Choose the compute device: Apple MPS if present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    _device_name = 'mps'
elif torch.cuda.is_available():
    _device_name = 'cuda'
else:
    _device_name = 'cpu'
device = torch.device(_device_name)
print(f"Using device: {device}")

# Seed the global torch and numpy RNGs (torch's drives DataLoader shuffling and the
# per-item noise drawn in the dataset) so runs are reproducible.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

print("Libraries imported successfully!")

## 3. Load Small Subset of NSynth Dataset

The dataset is opened in streaming mode, so it is not downloaded in bulk; the next step takes just 100 examples from the train split.

In [None]:
print("Loading NSynth dataset from Parquet files...")
print("(Loading small subset for testing)\n")

# load_dataset and HfApi are imported once, in the imports cell at the top of the notebook.
HF_DATASET_REPO = "jg583/NSynth"

# Initialize HF API
api = HfApi()

# List every file in the dataset repository so the Parquet shards can be located.
print("Discovering dataset files...")
repo_files = list(api.list_repo_files(HF_DATASET_REPO, repo_type="dataset"))

# Keep only Parquet shards
parquet_files = [f for f in repo_files if f.endswith('.parquet')]

# Assign shards to splits by a case-insensitive substring of their path.
# 'validation' contains 'valid', so a single check covers both spellings.
train_parquet = [f for f in parquet_files if 'train' in f.lower()]
valid_parquet = [f for f in parquet_files if 'valid' in f.lower()]
test_parquet = [f for f in parquet_files if 'test' in f.lower()]

print(f"Found {len(parquet_files)} Parquet file(s)")
print(f"  - Train: {len(train_parquet)} files")
print(f"  - Valid: {len(valid_parquet)} files")
print(f"  - Test: {len(test_parquet)} files")

# Every example used below is drawn from the train split. Fail here with a clear message
# rather than with a bare KeyError on dataset['train'] in the next cell.
if not train_parquet:
    raise RuntimeError(
        f"No train Parquet shards found in {HF_DATASET_REPO}; Parquet files seen: {parquet_files}"
    )

# Split name -> hf:// URLs of its shards (splits with no shards are omitted).
split_shards = {"train": train_parquet, "valid": valid_parquet, "test": test_parquet}
data_files = {
    split: [f"hf://datasets/{HF_DATASET_REPO}/{f}" for f in shards]
    for split, shards in split_shards.items()
    if shards
}

# streaming=True: examples are read lazily as the splits are iterated, not downloaded up front.
print("\nLoading dataset using Parquet loader...")
dataset = load_dataset(
    "parquet",
    data_files=data_files,
    streaming=True
)

print(f"\n✅ Dataset loaded successfully!")
print(f"Available splits: {list(dataset.keys())}")

## 4. Extract 100 Samples and Split (70/15/15)

In [None]:
def collect_samples(dataset_stream, n_samples):
    """Take the first ``n_samples`` items from a (possibly streaming) iterable.

    Args:
        dataset_stream: Any iterable of examples, e.g. a streaming HF dataset split.
        n_samples: Maximum number of items to take; values <= 0 give an empty list.

    Returns:
        list: Up to ``n_samples`` items in stream order (fewer if the stream runs out).
    """
    import itertools  # local import: not brought in by the notebook's imports cell

    n_samples = max(0, n_samples)
    print(f"Collecting {n_samples} samples...")
    # islice stops after exactly n_samples items. An enumerate/break loop would pull and
    # decode one extra example from the stream just to discover that it should stop.
    samples = list(tqdm(itertools.islice(dataset_stream, n_samples), total=n_samples))
    print(f"Collected {len(samples)} samples")
    return samples

# Pool size and split sizes; the test split gets whatever remains after train + val.
N_TOTAL_SAMPLES = 100
N_TRAIN_SAMPLES = 70
N_VAL_SAMPLES = 15

print(f"Collecting {N_TOTAL_SAMPLES} samples from train split...")
all_samples = collect_samples(dataset['train'], N_TOTAL_SAMPLES)

# The slices below would silently shrink if the stream returned fewer items.
assert len(all_samples) == N_TOTAL_SAMPLES, (
    f"Expected {N_TOTAL_SAMPLES} samples, got {len(all_samples)}"
)

# Shuffle the pool with a dedicated generator (leaves the global numpy RNG untouched).
# The stream may be ordered, e.g. by instrument, so contiguous slices could otherwise put
# families into val/test that never appear in train.
pool_order = np.random.default_rng(42).permutation(len(all_samples))
all_samples = [all_samples[i] for i in pool_order]

# Split into 70/15/15
val_start = N_TRAIN_SAMPLES
test_start = N_TRAIN_SAMPLES + N_VAL_SAMPLES
train_samples = all_samples[:val_start]
val_samples = all_samples[val_start:test_start]
test_samples = all_samples[test_start:N_TOTAL_SAMPLES]

print(f"\n{'='*80}")
print("DATA SPLIT COMPLETE")
print(f"{'='*80}")
print(f"Train samples: {len(train_samples)}")
print(f"Validation samples: {len(val_samples)}")
print(f"Test samples: {len(test_samples)}")
print(f"Total: {len(all_samples)}")

# Show a sample
print("\n" + "="*80)
print("Sample from training set:")
print("="*80)
sample = train_samples[0]
for key, value in sample.items():
    if key != 'audio':
        print(f"{key:25s}: {value}")
    elif isinstance(value, dict) and 'array' in value:
        print(f"{key:25s}: [array of {len(value['array'])} samples at {value['sampling_rate']} Hz]")
    else:
        # Audio in another layout (e.g. undecoded): show its type instead of printing nothing.
        print(f"{key:25s}: <{type(value).__name__}>")

## 5. Audio Processing Configuration

In [None]:
# Mel-spectrogram parameters
SAMPLE_RATE = 16000
N_FFT = 1024
HOP_LENGTH = 256
N_MELS = 128
DURATION = 4.0  # seconds
TARGET_LENGTH = int(SAMPLE_RATE * DURATION)  # 64,000 samples
# librosa.feature.melspectrogram centers frames (center=True by default). This pads the
# signal and yields 1 + len // hop frames: 1 + 64000 // 256 = 251, not 250.
# NOTE(review): 251 is odd. Confirm that model.py's UNet handles a time axis that is not
# divisible by its total downsampling factor.
MEL_TIME_STEPS = 1 + TARGET_LENGTH // HOP_LENGTH  # 251 time steps

print("Audio Processing Configuration:")
print(f"  Sample Rate: {SAMPLE_RATE} Hz")
print(f"  Duration: {DURATION} seconds")
print(f"  Target Length: {TARGET_LENGTH} samples")
print(f"  N_FFT: {N_FFT}")
print(f"  Hop Length: {HOP_LENGTH}")
print(f"  N_Mels: {N_MELS}")
print(f"  Mel Time Steps: {MEL_TIME_STEPS}")

## 6. Create PyTorch Dataset Class

In [None]:
class NSynthMelDataset(Dataset):
    """NSynth examples -> (noisy input, log-mel target, metadata conditioning vector).

    Each item is built on the fly from a raw NSynth example dict:
      * the waveform is resampled to ``sample_rate``, padded/trimmed to ``duration``
        seconds, turned into a log-mel spectrogram and scaled to [-1, 1];
      * a ``EMBED_DIM``-long conditioning vector is derived from metadata (instrument
        family one-hot, pitch, velocity, instrument-source one-hot). It stands in for
        a real text embedding.

    Args:
        samples: Sequence of NSynth example dicts. Keys read: 'audio',
            'instrument_family_str', 'instrument_source_str', 'pitch', 'velocity',
            'note', 'instrument_str'.
        sample_rate, n_mels, n_fft, hop_length: Mel-spectrogram parameters.
        duration: Clip length in seconds after padding/trimming.
        family_to_idx: Optional family -> one-hot index mapping to reuse, e.g. the
            training set's, so val/test embeddings share its layout. When None, the
            mapping is built from the sorted families present in ``samples``.
    """

    # Length of the conditioning vector returned under 'text_emb'.
    EMBED_DIM = 512
    # One-hot slots for the instrument_source_str values.
    SOURCE_TO_IDX = {'acoustic': 0, 'electronic': 1, 'synthetic': 2}
    # dB span mapped onto [-1, 1]; matches power_to_db's default top_db=80.
    DB_RANGE = 80.0

    def __init__(self, samples, sample_rate=16000, n_mels=128, n_fft=1024,
                 hop_length=256, duration=4.0, family_to_idx=None):
        self.samples = samples
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.target_length = int(sample_rate * duration)

        if family_to_idx is None:
            # Sorted so the family -> index assignment is deterministic.
            families = sorted({s['instrument_family_str'] for s in samples})
            family_to_idx = {family: i for i, family in enumerate(families)}
        self.family_to_idx = dict(family_to_idx)

        print(f"Dataset created with {len(samples)} samples")
        print(f"Instrument families: {len(self.family_to_idx)}")
        print(f"Families: {list(self.family_to_idx.keys())}")

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _normalize_db(mel_spec_db, db_range=80.0):
        """Linearly map dB values in [-db_range, 0] onto [-1, 1], clipping anything outside.

        librosa.power_to_db(..., ref=np.max) with its default top_db=80 produces values
        in [-80, 0]. So the floor (-80 dB) maps to -1 and the loudest bin (0 dB) to +1.
        """
        half_range = db_range / 2.0
        return np.clip((mel_spec_db + half_range) / half_range, -1.0, 1.0)

    def _process_audio(self, audio_array, sr):
        """Convert a 1-D waveform sampled at ``sr`` Hz to a normalized log-mel spectrogram.

        Returns:
            np.ndarray of shape (n_mels, n_frames) with values in [-1, 1].
        """
        # Float32 numpy array (copies only when a conversion is needed).
        audio_array = np.asarray(audio_array, dtype=np.float32)

        if sr != self.sample_rate:
            audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=self.sample_rate)

        # Zero-pad at the end, or trim, to exactly target_length samples.
        n_missing = self.target_length - len(audio_array)
        if n_missing > 0:
            audio_array = np.pad(audio_array, (0, n_missing))
        else:
            audio_array = audio_array[:self.target_length]

        mel_spec = librosa.feature.melspectrogram(
            y=audio_array,
            sr=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmin=0,
            fmax=self.sample_rate // 2
        )

        # dB relative to the loudest bin of this clip, so the maximum is 0 dB.
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        return self._normalize_db(mel_spec_db, self.DB_RANGE)

    def _create_text_embedding(self, sample):
        """Build the EMBED_DIM-long float32 metadata vector for one example.

        Layout (F = len(self.family_to_idx)):
          [0, F)        one-hot instrument family (all zeros if the family is not in
                        the mapping, i.e. "unknown")
          F             pitch / 127
          F + 1         velocity / 127
          [F+2, F+5)    one-hot instrument source (all zeros for unlisted sources)
          rest          zeros
        """
        n_families = len(self.family_to_idx)
        embedding = np.zeros(self.EMBED_DIM, dtype=np.float32)

        # Families absent from the mapping (e.g. seen only in val/test) are left as an
        # all-zero one-hot instead of raising KeyError mid-epoch.
        family_idx = self.family_to_idx.get(sample['instrument_family_str'])
        if family_idx is not None:
            embedding[family_idx] = 1.0

        embedding[n_families] = sample['pitch'] / 127.0
        embedding[n_families + 1] = sample['velocity'] / 127.0

        source_idx = self.SOURCE_TO_IDX.get(sample['instrument_source_str'])
        if source_idx is not None:
            embedding[n_families + 2 + source_idx] = 1.0

        return embedding

    def __getitem__(self, idx):
        """Build one training example.

        Returns:
            dict with
              'input':    float32 tensor [2, n_mels, T]: channel 0 is the clean mel,
                          channel 1 is fresh N(0, 1) noise;
              'target':   float32 tensor [1, n_mels, T]: the clean mel;
              'text_emb': float32 tensor [EMBED_DIM]: metadata conditioning vector;
              'metadata': raw fields for display.
        """
        sample = self.samples[idx]

        audio_data = sample['audio']
        if isinstance(audio_data, dict):
            # Decoded-audio layout: waveform under 'array' plus its sampling rate.
            # NOTE(review): undecoded audio (no 'array' key) would raise KeyError here.
            audio_array = audio_data['array']
            sr = audio_data['sampling_rate']
        else:
            # Otherwise assume it is the raw waveform already at self.sample_rate.
            audio_array = audio_data
            sr = self.sample_rate

        mel_spec = self._process_audio(audio_array, sr)
        text_emb = self._create_text_embedding(sample)

        mel_tensor = torch.as_tensor(mel_spec, dtype=torch.float32).unsqueeze(0)  # [1, n_mels, T]
        text_tensor = torch.as_tensor(text_emb, dtype=torch.float32)

        # New noise on every access, drawn from torch's global RNG.
        noise = torch.randn_like(mel_tensor)

        # NOTE(review): channel 0 of the input *is* the target, so the model can minimize
        # the loss by copying it. Acceptable for a plumbing test, not for learning generation.
        input_tensor = torch.cat([mel_tensor, noise], dim=0)  # [2, n_mels, T]

        return {
            'input': input_tensor,
            'target': mel_tensor,
            'text_emb': text_tensor,
            'metadata': {
                'note': sample['note'],
                'instrument': sample['instrument_str'],
                'family': sample['instrument_family_str'],
                'pitch': sample['pitch'],
                'velocity': sample['velocity']
            }
        }

# Mel-spectrogram settings shared by every split.
mel_kwargs = dict(sample_rate=SAMPLE_RATE, n_mels=N_MELS, n_fft=N_FFT,
                  hop_length=HOP_LENGTH, duration=DURATION)

train_dataset = NSynthMelDataset(train_samples, **mel_kwargs)
val_dataset = NSynthMelDataset(val_samples, **mel_kwargs)
test_dataset = NSynthMelDataset(test_samples, **mel_kwargs)

# Val/test embeddings must use the training set's family -> index layout so the same
# embedding slot means the same family everywhere.
val_dataset.family_to_idx = train_dataset.family_to_idx
test_dataset.family_to_idx = train_dataset.family_to_idx

# Families that occur only in val/test have no slot in the training layout; report them
# up front rather than discovering them mid-epoch.
unseen_families = sorted(
    {s['instrument_family_str'] for s in val_samples + test_samples}
    - set(train_dataset.family_to_idx)
)
if unseen_families:
    print(f"⚠️ Families in val/test but not in train (no one-hot slot): {unseen_families}")

print("\n" + "="*80)
print("DATASETS CREATED")
print("="*80)
print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")


## 7. Test Dataset Output

In [None]:
# Pull a single item and check its tensor shapes before building the loaders.
print("Testing dataset output...")
sample = train_dataset[0]

# With the default settings: input [2, 128, 251], target [1, 128, 251], text embedding [512].
print(f"\nInput shape: {sample['input'].shape}")
print(f"Target shape: {sample['target'].shape}")
print(f"Text embedding shape: {sample['text_emb'].shape}")

print(f"\nMetadata:")
for meta_key, meta_value in sample['metadata'].items():
    print(f"  {meta_key}: {meta_value}")

# Plot the input's mel and noise channels next to the target, each as a mel-frequency x time image.
panels = [
    (sample['input'][0], 'Input Mel-spectrogram'),
    (sample['input'][1], 'Input Noise'),
    (sample['target'][0], 'Target Mel-spectrogram'),
]
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image.numpy(), aspect='auto', origin='lower', cmap='viridis')
    ax.set_title(title)
    ax.set_xlabel('Time')
    ax.set_ylabel('Mel Frequency')

plt.tight_layout()
plt.show()

## 8. Create DataLoaders

In [None]:
BATCH_SIZE = 4


def _make_loader(split_dataset, shuffle):
    """Wrap a split in a single-process DataLoader using the shared BATCH_SIZE."""
    return DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=0)


# Only training batches are shuffled; val/test keep a fixed order.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

print("DataLoaders created:")
for split_name, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"  {split_name} batches: {len(split_loader)}")

# Pull one training batch to confirm that collation stacks the tensors as expected.
batch = next(iter(train_loader))
print("\nSample batch shapes:")
for label, field in (("Input", 'input'), ("Target", 'target'), ("Text embedding", 'text_emb')):
    print(f"  {label}: {batch[field].shape}")

## 9. Initialize Model

In [None]:
# Model configuration.
# N_MELS is deliberately not redefined here. It comes from the audio-processing cell,
# so the model's mel axis always matches the spectrograms the dataset produces.
IN_CHANNELS = 2  # mel + noise
BASE_CHANNELS = 32  # Reduced for faster training on small dataset
CHANNEL_MULTS = (1, 2, 2, 4)  # Reduced complexity
RAW_TEXT_EMB_DIM = 512  # must equal the length of the dataset's 'text_emb' vectors
TEXT_COND_DIM = 128  # Reduced

print("Building model...")
gen, text_proj = build_model(
    n_mels=N_MELS,
    in_channels=IN_CHANNELS,
    base_channels=BASE_CHANNELS,
    channel_mults=CHANNEL_MULTS,
    raw_text_emb_dim=RAW_TEXT_EMB_DIM,
    text_cond_dim=TEXT_COND_DIM
)

# Move to device
gen = gen.to(device)
text_proj = text_proj.to(device)

# Count parameters
def count_parameters(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Count each module once and reuse the numbers for the total.
gen_param_count = count_parameters(gen)
text_proj_param_count = count_parameters(text_proj)
rule = '=' * 80

print(f"\n{rule}")
print("MODEL SUMMARY")
print(rule)
print(f"Generator parameters: {gen_param_count:,}")
print(f"Text projector parameters: {text_proj_param_count:,}")
print(f"Total parameters: {gen_param_count + text_proj_param_count:,}")
print(f"Device: {device}")

## 10. Test Forward Pass

In [None]:
print("Testing forward pass...")

# Get a batch
batch = next(iter(train_loader))
x = batch['input'].to(device)
target = batch['target'].to(device)
text_emb = batch['text_emb'].to(device)

print(f"Input shape: {x.shape}")
print(f"Target shape: {target.shape}")
print(f"Text embedding shape: {text_emb.shape}")

# Run the smoke test in eval mode. In train mode, layers such as BatchNorm (if model.py
# uses any) update their running statistics even under torch.no_grad(). The training
# loop switches the modules back to train mode itself.
gen.eval()
text_proj.eval()
with torch.no_grad():
    text_cond = text_proj(text_emb)
    output = gen(x, text_cond)

print(f"\nText condition shape: {text_cond.shape}")
print(f"Output shape: {output.shape}")

# The loss is MSE(output, target). A shape mismatch would broadcast, and the warning
# is hidden by the filterwarnings('ignore') in the imports cell, so check explicitly.
assert output.shape == target.shape, (
    f"Output shape {tuple(output.shape)} != target shape {tuple(target.shape)}"
)
print(f"\n✅ Forward pass successful!")

## 11. Define Training Configuration

In [None]:
# Optimization hyperparameters
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10
GRAD_CLIP = 1.0

# Reconstruction objective: mean squared error between generated and target mels.
criterion = nn.MSELoss()

# A single AdamW optimizer updates both the generator and the text projector.
trainable_params = [*gen.parameters(), *text_proj.parameters()]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=LEARNING_RATE,
    betas=(0.9, 0.999),
    weight_decay=1e-4
)

# Cosine decay from LEARNING_RATE down to 1e-6 over NUM_EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
    eta_min=1e-6
)

print("Training Configuration:")
for config_line in (
    f"  Learning Rate: {LEARNING_RATE}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Batch Size: {BATCH_SIZE}",
    f"  Gradient Clipping: {GRAD_CLIP}",
    "  Optimizer: AdamW",
    "  Scheduler: CosineAnnealingLR",
    "  Loss: MSE",
):
    print(config_line)

## 12. Training Loop

In [None]:
from IPython.display import clear_output


def _run_epoch(loader, train, desc):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    Args:
        loader: DataLoader yielding dicts with 'input', 'target' and 'text_emb' tensors.
        train: If True, put gen/text_proj in train mode and update them with the
            optimizer. Otherwise evaluate in eval mode without building gradients.
        desc: tqdm progress-bar label.

    Returns:
        float: Mean per-sample loss, with each batch weighted by its size. NaN if
        ``loader`` yields no batches.
    """
    gen.train(train)
    text_proj.train(train)
    all_params = [*gen.parameters(), *text_proj.parameters()]
    loss_sum, n_items = 0.0, 0

    pbar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(train):
        for batch in pbar:
            x = batch['input'].to(device)
            target = batch['target'].to(device)
            text_emb = batch['text_emb'].to(device)

            if train:
                optimizer.zero_grad()
            text_cond = text_proj(text_emb)
            output = gen(x, text_cond)
            loss = criterion(output, target)

            if train:
                loss.backward()
                # Both modules share one optimizer, so clip their joint gradient norm.
                # Clipping each module separately would let the combined norm reach
                # sqrt(2) * GRAD_CLIP.
                torch.nn.utils.clip_grad_norm_(all_params, GRAD_CLIP)
                optimizer.step()

            batch_loss = loss.item()
            # criterion averages within the batch. Re-weighting by batch size keeps a short
            # final batch from counting as much as a full one in the epoch mean.
            loss_sum += batch_loss * x.size(0)
            n_items += x.size(0)
            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return loss_sum / n_items if n_items else float('nan')


# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'lr': []
}

print("Starting training...")
print("="*80)

for epoch in range(NUM_EPOCHS):
    epoch_tag = f'Epoch {epoch+1}/{NUM_EPOCHS}'
    # The LR actually used during this epoch. Read it before scheduler.step() advances it,
    # otherwise the history and printout would show the next epoch's LR.
    current_lr = optimizer.param_groups[0]['lr']

    avg_train_loss = _run_epoch(train_loader, train=True, desc=f'{epoch_tag} [Train]')
    avg_val_loss = _run_epoch(val_loader, train=False, desc=f'{epoch_tag} [Val]')

    # Advance the cosine schedule once per epoch.
    scheduler.step()

    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['lr'].append(current_lr)

    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}")
    print(f"  Learning Rate: {current_lr:.6f}")
    print("-" * 80)

print("\n✅ Training complete!")

## 13. Plot Training History

In [None]:
# Build the x-axis from what was actually recorded, not from NUM_EPOCHS, so the plot
# still works after an interrupted run or if NUM_EPOCHS was edited after training.
assert history['train_loss'], "No training history recorded; run the training loop first."
epochs = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
axes[0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
axes[0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss (MSE)')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Learning rate
axes[1].plot(epochs, history['lr'], 'g-', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Learning Rate')
axes[1].set_title('Learning Rate Schedule')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print final metrics
best_epoch = int(np.argmin(history['val_loss'])) + 1
print("\nFinal Metrics:")
print(f"  Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"  Final Val Loss: {history['val_loss'][-1]:.4f}")
print(f"  Best Val Loss: {min(history['val_loss']):.4f} (Epoch {best_epoch})")

## 14. Test Set Evaluation

In [None]:
print("Evaluating on test set...")

gen.eval()
text_proj.eval()

# One MSE value per test example (mean over all of its elements). The statistics then
# describe examples rather than uneven batches. This is the same quantity nn.MSELoss
# computes, just without averaging over the batch.
test_losses = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):
        x = batch['input'].to(device)
        target = batch['target'].to(device)
        text_emb = batch['text_emb'].to(device)

        text_cond = text_proj(text_emb)
        output = gen(x, text_cond)

        per_sample = F.mse_loss(output, target, reduction='none').flatten(1).mean(dim=1)
        test_losses.extend(per_sample.cpu().tolist())

assert test_losses, "Test loader produced no batches"
avg_test_loss = float(np.mean(test_losses))
print(f"\n{'='*80}")
print("TEST SET RESULTS")
print(f"{'='*80}")
print(f"Average Test Loss: {avg_test_loss:.4f}")
print(f"Min Test Loss (per sample): {min(test_losses):.4f}")
print(f"Max Test Loss (per sample): {max(test_losses):.4f}")
print(f"Std Test Loss (per sample): {np.std(test_losses):.4f}")

## 15. Visualize Model Predictions

In [None]:
# Get a test batch for visualization
test_batch = next(iter(test_loader))
x_test = test_batch['input'].to(device)
target_test = test_batch['target'].to(device)
text_emb_test = test_batch['text_emb'].to(device)
metadata_test = test_batch['metadata']

# Put both modules (not just the generator) in eval mode for inference.
gen.eval()
text_proj.eval()
with torch.no_grad():
    text_cond_test = text_proj(text_emb_test)
    output_test = gen(x_test, text_cond_test)

# Move to CPU for plotting
x_cpu = x_test.cpu()
target_cpu = target_test.cpu()
output_cpu = output_test.cpu()

# Plot up to the first 4 samples in the batch.
n_samples = min(4, len(x_cpu))
# squeeze=False always returns a 2-D (rows x cols) Axes array, even when n_samples == 1.
fig, axes = plt.subplots(n_samples, 3, figsize=(15, 4*n_samples), squeeze=False)

for i, ax_row in enumerate(axes):
    family = metadata_test["family"][i]
    # The default collate turns integer metadata into tensors; int() renders "60", not "tensor(60)".
    pitch = int(metadata_test["pitch"][i])
    mse = F.mse_loss(output_cpu[i], target_cpu[i]).item()

    # Input mel-spectrogram
    ax_row[0].imshow(x_cpu[i, 0].numpy(), aspect='auto', origin='lower', cmap='viridis')
    ax_row[0].set_title(f'Input Mel\n{family}')
    ax_row[0].set_ylabel('Mel Frequency')

    # Target mel-spectrogram
    ax_row[1].imshow(target_cpu[i, 0].numpy(), aspect='auto', origin='lower', cmap='viridis')
    ax_row[1].set_title(f'Target Mel\nPitch: {pitch}')

    # Generated mel-spectrogram
    ax_row[2].imshow(output_cpu[i, 0].numpy(), aspect='auto', origin='lower', cmap='viridis')
    ax_row[2].set_title(f'Generated Mel\nMSE: {mse:.4f}')

    # Only label the time axis on the bottom row.
    if i == n_samples - 1:
        for ax in ax_row:
            ax.set_xlabel('Time')

plt.tight_layout()
plt.show()

## 16. Save Model Checkpoint

In [None]:
# Store the history as plain Python floats. numpy scalars (e.g. np.float64 from np.mean)
# are not on torch.load's weights_only allowlist, and weights_only=True is the default
# since PyTorch 2.6.
history_to_save = {name: [float(v) for v in values] for name, values in history.items()}

checkpoint = {
    # Epochs actually completed (equals NUM_EPOCHS after an uninterrupted run).
    'epoch': len(history['train_loss']),
    'gen_state_dict': gen.state_dict(),
    'text_proj_state_dict': text_proj.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'history': history_to_save,
    'config': {
        'n_mels': N_MELS,
        'in_channels': IN_CHANNELS,
        'base_channels': BASE_CHANNELS,
        'channel_mults': CHANNEL_MULTS,
        'raw_text_emb_dim': RAW_TEXT_EMB_DIM,
        'text_cond_dim': TEXT_COND_DIM,
        'sample_rate': SAMPLE_RATE,
        'n_fft': N_FFT,
        'hop_length': HOP_LENGTH,
        'duration': DURATION
    }
}

checkpoint_path = 'model_checkpoint_test.pt'
torch.save(checkpoint, checkpoint_path)
print(f"✅ Model checkpoint saved to: {checkpoint_path}")

# Also save just the model weights
weights_path = 'model_weights_test.pt'
torch.save({
    'gen': gen.state_dict(),
    'text_proj': text_proj.state_dict()
}, weights_path)
print(f"✅ Model weights saved to: {weights_path}")

## 17. Summary

### Training Test Results:
- ✅ Successfully loaded 100 samples from NSynth dataset
- ✅ Split data into 70/15/15 (train/val/test)
- ✅ Converted audio to mel-spectrograms
- ✅ Created simple text embeddings from metadata
- ✅ Built UNetMelGenerator model
- ✅ Trained for `NUM_EPOCHS` epochs (10 by default)
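- ✅ Model runs end-to-end and outputs mel-spectrograms conditioned on metadata-derived embeddings (caveat: the model input contains the clean target mel, so this run validates the plumbing, not generation quality)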

### Next Steps:
1. Scale up to full dataset
2. Implement proper text encoder (CLAP/T5)
3. Add diffusion noise schedule
4. Implement proper evaluation metrics (e.g., Fréchet Audio Distance (FAD), the audio counterpart of FID, or classifier-based scores such as IS)
5. Add mel-to-audio vocoder (HiFi-GAN, etc.)
6. Experiment with different architectures and hyperparameters