# SpecGAN Training for Solar Radio Burst Generation

**Based on:** Chris Donahue's SpecGAN (https://github.com/chrisdonahue/wavegan)  
**Ported from:** TensorFlow to PyTorch  
**Data:** 128×128 CSV spectrogram windows of solar radio bursts

**Key SpecGAN Features:**
- Per-frequency bin normalization (preserves frequency-specific characteristics)
- WGAN-GP loss (more stable training than DCGAN)
- Single-channel architecture (matches grayscale spectrograms)
- 5×5 kernels (larger receptive field than DCGAN's 4×4)
- D:G update ratio of 5:1 (discriminator trains 5 times per generator update)

**Prerequisites:**
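Run `compute_moments.py` **before** training to generate the per-frequency moments file. This notebook expects it at `./checkpoints_specgan/type3_moments.npz` (see `moments_path` in the configuration cell).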


## 1. Imports


In [None]:
import os
import sys
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.utils as vutils

# Add paths for local modules
sys.path.insert(0, '/Users/remiliascarlet/Desktop/MDP/transfer_learning/dcgan')

# Import SpecGAN components
from specgan.specgan_models import SpecGANGenerator, SpecGANDiscriminator, weights_init
from specgan.specgan_utils import (
    PerFrequencyNormalizer, GANLoss, compute_gradient_penalty,
    save_gan_checkpoint, load_gan_checkpoint, get_specgan_optimizer,
    SPECGAN_DEFAULTS
)
from csv_spectrogram_dataset import CSVSpectrogramDataset


## 2. Set Random Seed for Reproducibility


In [None]:
# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU and, when
# present, all CUDA devices) from a single value.
seed = 999
print(f"Using Seed: {seed}")
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


## 3. Configuration (SpecGAN Defaults)

**From train_specgan.py, Lines 687-712**


In [None]:
# --- Paths -----------------------------------------------------------------
dataroot = "/Users/remiliascarlet/Desktop/MDP/transfer_learning/burst_data/csv/gan_training_windows_128/type_3/"
moments_path = "./checkpoints_specgan/type3_moments.npz"  # written by compute_moments.py
checkpoint_dir = "./checkpoints_specgan"

# --- Architecture (SpecGAN defaults, train_specgan.py lines 697-700) --------
nz, nc = 100, 1        # latent dimension; single-channel spectrogram
kernel_len = 5         # 5x5 convolution kernels (DCGAN uses 4x4)
dim = 64               # channel-width multiplier
use_batchnorm = False  # SpecGAN default: no BatchNorm

# --- Training (train_specgan.py lines 701-705) -------------------------------
disc_nupdates = 5      # discriminator updates per generator update
loss_type = 'wgan-gp'  # SpecGAN default loss
batch_size = 16        # reduced from 64 because the dataset is small
num_epochs = 500
workers = 2            # DataLoader worker processes

# --- Adam settings for WGAN-GP (train_specgan.py lines 261-269) --------------
lr = 1e-4              # shared by G and D
beta1, beta2 = 0.5, 0.9

# --- Hardware / checkpointing ------------------------------------------------
ngpu = 1               # number of GPUs to use
save_interval = 5      # periodic checkpoint every N epochs

# Echo the configuration as a banner-framed summary.
_rule = "=" * 70
_summary_lines = [
    _rule,
    "SpecGAN Configuration (following original SpecGAN defaults)",
    _rule,
    f"Data root: {dataroot}",
    f"Moments file: {moments_path}",
    f"Checkpoint dir: {checkpoint_dir}",
    "\nArchitecture (from specgan.py):",
    f"  Latent dim (nz): {nz}",
    f"  Channels (nc): {nc} (single-channel)",
    f"  Kernel size: {kernel_len}×{kernel_len}",
    f"  Dimension multiplier: {dim}",
    f"  BatchNorm: {use_batchnorm}",
    "\nTraining (from train_specgan.py):",
    f"  Loss type: {loss_type}",
    f"  D:G update ratio: {disc_nupdates}:1",
    f"  Batch size: {batch_size}",
    f"  Epochs: {num_epochs}",
    f"  Learning rate: {lr}",
    f"  Adam betas: ({beta1}, {beta2})",
    _rule,
]
for _line in _summary_lines:
    print(_line)


## 4. Load Dataset with Per-Frequency Normalization

**Uses pre-computed moments from compute_moments.py**


In [None]:
# Create dataset with SpecGAN features
# This replaces SpecGAN's loader.py + t_to_f() normalization
dataset = CSVSpectrogramDataset(
    root_dir=dataroot,
    normalize_method='per_frequency',  # SpecGAN's key feature!
    grayscale=True,                    # Single channel (nc=1)
    moments_path=moments_path,         # Load pre-computed moments
    augment=True,                      # Enable temporal shift augmentation
    subsample_ratio=1.0
)

# Create dataloader
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    drop_last=True  # Ensure consistent batch size
)

# Device configuration
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
print(f"\n🖥️  Using device: {device}")

# Pull one batch to sanity-check shape and value range
print("\n📊 Loading sample batch...")
real_batch = next(iter(dataloader))
print(f"Batch shape: {real_batch.shape}")  # Should be [batch_size, 1, 128, 128]
print(f"Batch value range: [{real_batch.min():.3f}, {real_batch.max():.3f}]")
print(f"Expected: [-1, 1] (per-frequency normalized)")

# Visualize on the CPU: make_grid does not need the GPU, so there is no point
# in a host -> device -> host round-trip for a preview.
plt.figure(figsize=(10, 10))
plt.axis("off")
plt.title("Sample Training Spectrograms (Type 3, Per-Frequency Normalized)")
# make_grid replicates a 1-channel batch to 3 identical channels and, with
# normalize=True, min-max rescales the grid to [0, 1]; show channel 0.
grid = vutils.make_grid(real_batch[:16], padding=2, normalize=True, nrow=4)
plt.imshow(grid[0].numpy(), cmap='hot')
plt.colorbar()
plt.show()


## 5. Build Models (SpecGAN Architecture)

**From specgan.py: SpecGANGenerator and SpecGANDiscriminator**


In [None]:
# Instantiate both SpecGAN networks (generator first, then discriminator) and
# move them to the selected device.
netG = SpecGANGenerator(
    nz=nz, kernel_len=kernel_len, dim=dim, nc=nc,
    use_batchnorm=use_batchnorm, ngpu=ngpu,
).to(device)
netD = SpecGANDiscriminator(
    kernel_len=kernel_len, dim=dim, nc=nc,
    use_batchnorm=use_batchnorm, ngpu=ngpu,
).to(device)

# Wrap in DataParallel only when more than one CUDA device is requested.
if ngpu > 1 and device.type == 'cuda':
    _gpu_ids = list(range(ngpu))
    netG = nn.DataParallel(netG, _gpu_ids)
    netD = nn.DataParallel(netD, _gpu_ids)

# Apply the project's weights_init to every submodule (G, then D).
for _model in (netG, netD):
    _model.apply(weights_init)

# Architecture summaries; the MB figure assumes 4 bytes per parameter.
total_params_g = sum(p.numel() for p in netG.parameters())
total_params_d = sum(p.numel() for p in netD.parameters())
_rule = "-" * 70
for _title, _model, _count in (
    ("Generator Architecture", netG, total_params_g),
    ("Discriminator Architecture", netD, total_params_d),
):
    print("\n" + _rule)
    print(_title)
    print(_rule)
    print(_model)
    print(f"Total params: {_count:,} ({_count * 4 / (1024*1024):.2f} MB)")
print(_rule)


## 6. Setup Optimizers (SpecGAN Recommended Settings)

**From train_specgan.py, Lines 261-269 (WGAN-GP settings)**


In [None]:
# Build the G and D optimizers with the project helper, which chooses the
# optimizer and its settings from loss_type.
optimizerG, optimizerD = get_specgan_optimizer(netG, netD, loss_type=loss_type)

# NOTE: the helper is not passed lr / beta1 / beta2 from the configuration
# cell, so editing those variables does not by itself change training.
# Report the values the optimizers actually hold, and flag any disagreement
# with the configuration instead of echoing the config values as if applied.
print("✅ Optimizers initialized")
print(f"   Loss type: {loss_type}")
for _opt_name, _opt in (("G", optimizerG), ("D", optimizerD)):
    _group = _opt.param_groups[0]
    _actual_lr = _group.get('lr')
    _actual_betas = _group.get('betas')  # None for optimizers without betas (e.g. RMSprop)
    print(f"   {_opt_name}: learning rate={_actual_lr}, betas={_actual_betas}")
    if _actual_lr != lr or (_actual_betas is not None and tuple(_actual_betas) != (beta1, beta2)):
        print(f"   ⚠️  {_opt_name} optimizer settings differ from config "
              f"(lr={lr}, betas=({beta1}, {beta2}))")

# Fixed noise for visualization (similar to SpecGAN's preview): the same 64
# latent vectors are decoded at every snapshot during training.
fixed_noise = torch.randn(64, nz, device=device)


## 7. Training Loop (SpecGAN Training Strategy)

**From train_specgan.py, Lines 278-295:**
- Train Discriminator `disc_nupdates` times (default: 5)
- Then train Generator once
- Uses WGAN-GP loss with gradient penalty


In [None]:
# Training tracking lists (module-level so the plotting cells below can use them)
img_list = []        # G(fixed_noise) grids from make_grid(normalize=True), rescaled to [0, 1]
G_losses = []        # generator loss, one entry per batch
D_losses = []        # loss of the final D update in each batch
D_real_history = []  # per-epoch mean D output on real data
D_fake_history = []  # per-epoch mean D output on the fakes used for the G update
iters = 0

# Fail fast on configuration problems instead of mid-epoch:
# - an unknown loss_type would leave D_loss / G_loss undefined (NameError);
# - with drop_last=True a dataset smaller than batch_size yields zero batches,
#   which would make every per-epoch average NaN.
supported_losses = ('wgan-gp', 'wgan', 'dcgan', 'lsgan')
if loss_type not in supported_losses:
    raise ValueError(f"Unsupported loss_type {loss_type!r}; expected one of {supported_losses}")
if len(dataloader) == 0:
    raise ValueError("dataloader yields no batches; lower batch_size or add data")

if loss_type == 'wgan':
    # Weight clipping is only used by vanilla WGAN; import it once up front
    # rather than on every discriminator step.
    from specgan.specgan_utils import clip_discriminator_weights

# torch.save (periodic checkpoints below) does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

# Best model tracking (a higher quality_metric is better).
# Start at -inf: the WGAN / WGAN-GP metric, -|mean D(fake)|, is never positive,
# so starting at 0.0 meant the "best" checkpoint was never saved.
best_quality_metric = float('-inf')
best_epoch = -1  # -1 until the first checkpoint is written

print("="*70)
print("Starting SpecGAN Training Loop")
print("="*70)
print(f"Total batches per epoch: {len(dataloader)}")
print(f"Total iterations: {num_epochs * len(dataloader)}")
print(f"\nSpecGAN Training Strategy:")
print(f"  - Loss: {loss_type}")
print(f"  - D updates: {disc_nupdates} times per G update")
print(f"  - Per-frequency normalization: Enabled")
print(f"  - Temporal augmentation: Enabled")
print(f"  - Single channel (nc={nc})")
print("="*70)
print()

# Main training loop
for epoch in range(num_epochs):
    epoch_D_loss = []
    epoch_G_loss = []
    epoch_D_real = []
    epoch_D_fake = []

    for i, data in enumerate(dataloader, 0):
        real_data = data.to(device)
        b_size = real_data.size(0)

        ############################
        # (1) Update D `disc_nupdates` times
        #     From train_specgan.py, Lines 287-292
        ############################
        for _ in range(disc_nupdates):
            netD.zero_grad()

            # Forward pass real batch through D
            D_real = netD(real_data)

            # The fakes are only D inputs here, so skip building G's autograd
            # graph (the original detached them anyway).
            noise = torch.randn(b_size, nz, device=device)
            with torch.no_grad():
                fake_data = netG(noise)

            # Forward pass fake batch through D
            D_fake = netD(fake_data)

            # Compute D loss (WGAN-GP from Lines 222-236); G loss is computed in (2)
            if loss_type == 'wgan-gp':
                _, D_loss = GANLoss.wgan_gp_loss(
                    D_real, D_fake, netD, real_data, fake_data, device, lambda_gp=10
                )
            elif loss_type == 'dcgan':
                _, D_loss = GANLoss.dcgan_loss(D_real, D_fake)
            elif loss_type == 'lsgan':
                _, D_loss = GANLoss.lsgan_loss(D_real, D_fake)
            else:  # 'wgan' (loss_type validated above)
                _, D_loss = GANLoss.wgan_loss(D_real, D_fake)

            # Backprop and optimize D
            D_loss.backward()
            optimizerD.step()

            # Weight clipping for vanilla WGAN (not WGAN-GP), Lines 291-292
            if loss_type == 'wgan':
                clip_discriminator_weights(netD, clip_value=0.01)

        ############################
        # (2) Update G once
        #     From train_specgan.py, Lines 294-295
        ############################
        netG.zero_grad()

        # Fresh fake batch, this time with gradients flowing back into G
        noise = torch.randn(b_size, nz, device=device)
        fake_data = netG(noise)
        D_fake_for_G = netD(fake_data)

        if loss_type in ('wgan-gp', 'wgan'):
            G_loss = -torch.mean(D_fake_for_G)  # Wasserstein generator loss
        elif loss_type == 'dcgan':
            G_loss, _ = GANLoss.dcgan_loss(D_real, D_fake_for_G)
        else:  # 'lsgan'
            G_loss, _ = GANLoss.lsgan_loss(D_real, D_fake_for_G)

        # Backprop and optimize G
        G_loss.backward()
        optimizerG.step()

        # Record statistics (D_real comes from the last D update above)
        d_loss_value = D_loss.item()
        g_loss_value = G_loss.item()
        D_real_mean = D_real.mean().item()
        D_fake_mean = D_fake_for_G.mean().item()

        epoch_D_loss.append(d_loss_value)
        epoch_G_loss.append(g_loss_value)
        epoch_D_real.append(D_real_mean)
        epoch_D_fake.append(D_fake_mean)

        # Save for plotting
        G_losses.append(g_loss_value)
        D_losses.append(d_loss_value)

        # Output training stats every 5 batches
        if i % 5 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                  f'Loss_D: {d_loss_value:.4f} Loss_G: {g_loss_value:.4f} '
                  f'D(real): {D_real_mean:.4f} D(fake): {D_fake_mean:.4f}')

        # Snapshot G(fixed_noise) every 100 iterations and at the very end
        if (iters % 100 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True, nrow=8))

        iters += 1

    # ===== End of epoch - evaluate and save models =====
    avg_D_loss = np.mean(epoch_D_loss)
    avg_G_loss = np.mean(epoch_G_loss)
    avg_D_real = np.mean(epoch_D_real)
    avg_D_fake = np.mean(epoch_D_fake)

    D_real_history.append(avg_D_real)
    D_fake_history.append(avg_D_fake)

    print(f"\n{'='*70}")
    print(f"Epoch {epoch} Summary:")
    print(f"  Avg D_loss: {avg_D_loss:.4f}")
    print(f"  Avg G_loss: {avg_G_loss:.4f}")
    print(f"  Avg D(real): {avg_D_real:.4f}")
    print(f"  Avg D(fake): {avg_D_fake:.4f}")
    print(f"  Best quality so far: {best_quality_metric:.4f} (Epoch {best_epoch})")

    # Checkpoint-selection heuristic (higher is better):
    # - dcgan / lsgan: mean D(fake) itself;
    # - wgan / wgan-gp: -|mean D(fake)|, i.e. critic output on fakes closest to 0.
    # NOTE(review): a WGAN critic's output has no fixed zero point, so this is
    # only a rough proxy for sample quality.
    quality_metric = avg_D_fake if loss_type in ['dcgan', 'lsgan'] else -abs(avg_D_fake)

    if quality_metric > best_quality_metric:
        if best_quality_metric == float('-inf'):
            print(f"  🎉 First checkpoint (quality {quality_metric:.4f}). Saving checkpoint...")
        else:
            improvement = quality_metric - best_quality_metric
            print(f"  🎉 Quality improved by {improvement:.4f}! Saving checkpoint...")
        best_quality_metric = quality_metric
        best_epoch = epoch

        save_gan_checkpoint(
            netG, netD, optimizerG, optimizerD, epoch, quality_metric,
            checkpoint_dir,
            hyperparams={'nz': nz, 'nc': nc, 'kernel_len': kernel_len, 'dim': dim}
        )
    else:
        print(f"  ⏭️  No improvement this epoch")

    # Periodic save
    if (epoch + 1) % save_interval == 0:
        periodic_path = os.path.join(checkpoint_dir, f"periodic_epoch_{epoch}.pth")
        torch.save({
            'epoch': epoch,
            'generator_state_dict': netG.state_dict(),
            'discriminator_state_dict': netD.state_dict(),
            'optimizerG_state_dict': optimizerG.state_dict(),
            'optimizerD_state_dict': optimizerD.state_dict(),
        }, periodic_path)
        print(f"  💾 Periodic save: {periodic_path}")

    print("="*70 + "\n")

print("\n✅ Training completed!")
print(f"Best quality metric: {best_quality_metric:.4f} (Epoch {best_epoch})")


In [None]:
# Left panel: per-iteration losses. Right panel: per-epoch mean D outputs.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
loss_ax, disc_ax = axes

for _series, _name in ((G_losses, 'G Loss'), (D_losses, 'D Loss')):
    loss_ax.plot(_series, label=_name, alpha=0.7)
loss_ax.set_xlabel('Iterations')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Generator and Discriminator Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

for _series, _name in ((D_real_history, 'D(real)'), (D_fake_history, 'D(fake)')):
    disc_ax.plot(_series, label=_name, alpha=0.7)
disc_ax.axhline(y=0, color='r', linestyle='--', alpha=0.3, label='Target (WGAN)')
disc_ax.set_xlabel('Epochs')
disc_ax.set_ylabel('Discriminator Output')
disc_ax.set_title('D(real) and D(fake) Over Training')
disc_ax.legend()
disc_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


## 9. Visualize Training Progress (Animation)


In [None]:
# One animation frame per generator snapshot collected during training.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")

# Each snapshot is a [C, H, W] grid from make_grid; for the single-channel
# input its channels are identical copies, so channel 0 is displayed.
ims = [
    [plt.imshow(snapshot[0].numpy(), animated=True, cmap='hot')]
    for snapshot in img_list
]

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())


## 10. Final Generated Spectrograms


In [None]:
# Visualize the last generator snapshot recorded during training.
if not img_list:
    raise RuntimeError("img_list is empty; run the training loop first")

plt.figure(figsize=(12, 12))
plt.axis("off")
plt.title("Generated Type 3 Radio Burst Spectrograms (SpecGAN, Final Epoch)")
# Snapshots were built with make_grid(normalize=True), which min-max rescales
# the whole grid to [0, 1] (not the model's [-1, 1] range); channel 0 is shown
# because make_grid replicates a single input channel to three.
plt.imshow(img_list[-1][0].numpy(), cmap='hot')
plt.colorbar(label='Min-max rescaled intensity [0, 1]')
plt.show()


## 11. Real vs Fake Comparison


In [None]:
# Side-by-side comparison: a fresh real batch vs. a freshly generated batch.
real_batch = next(iter(dataloader))

with torch.no_grad():
    noise = torch.randn(batch_size, nz, device=device)
    fake_batch = netG(noise).cpu()

fig, axes = plt.subplots(1, 2, figsize=(15, 7))
_panels = (
    (real_batch, "Real Type 3 Burst Spectrograms"),
    (fake_batch, "Generated (SpecGAN) Type 3 Spectrograms"),
)
# Each panel: first 16 samples in a 4x4 grid, rescaled to [0, 1], channel 0.
for _ax, (_batch, _title) in zip(axes, _panels):
    _grid_img = vutils.make_grid(_batch[:16], nrow=4, padding=2, normalize=True)[0]
    _ax.imshow(_grid_img.numpy(), cmap='hot')
    _ax.set_title(_title)
    _ax.axis('off')

plt.tight_layout()
plt.show()


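## 12. Generate New Samples from the Trained Generator

*Note: this cell uses the in-memory generator from the final training epoch, not the best checkpoint saved during training. Load that checkpoint first if you want samples from the best epoch.*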


In [None]:
# Decode a fresh set of latent vectors with the in-memory generator.
num_samples = 16

with torch.no_grad():
    noise = torch.randn(num_samples, nz, device=device)
    generated = netG(noise).cpu()

# 4-wide grid, min-max rescaled to [0, 1] by make_grid(normalize=True).
grid = vutils.make_grid(generated, padding=2, normalize=True, nrow=4)

plt.figure(figsize=(10, 10))
plt.axis("off")
plt.title("Newly Generated Type 3 Solar Radio Burst Spectrograms (SpecGAN)")
plt.imshow(grid[0].numpy(), cmap='hot')  # channel 0 of the replicated grid
plt.colorbar(label='Normalized Intensity')
plt.show()

print(f"✅ Generated {num_samples} synthetic Type 3 radio burst spectrograms using SpecGAN!")


## 13. SpecGAN vs DCGAN Comparison Summary

**Key improvements in this SpecGAN setup (compared with the earlier DCGAN):**
1. **Per-frequency normalization** - Each frequency bin normalized independently
2. **Single-channel design** - Matches spectrogram data (no artificial RGB)
3. **5×5 kernels** - Larger receptive field than DCGAN's 4×4
4. **WGAN-GP loss** - More stable training than BCE
5. **5:1 D:G ratio** - Discriminator trains more to stay balanced
6. **Temporal augmentation** - Random shifts break spatial bias

**Expected improvements over original DCGAN:**
- Reduced horizontal striping (per-frequency norm handles frequency-dependent noise)
- Reduced (ideally eliminated) left-side concentration (temporal augmentation)
- Better stability (WGAN-GP loss)
- Cleaner frequency structures (domain-aware normalization)
