# DCGAN Training with CSV Spectrogram Data

This notebook demonstrates how to train a DCGAN on 128×128 CSV spectrogram windows prepared for solar radio burst generation.

**Key Differences from Image-based Training:**
- Uses custom `CSVSpectrogramDataset` instead of `ImageFolder`
- Loads numerical CSV data instead of PNG/JPG images
- Directly processes spectral intensity values
- Optimized for 128×128 input (instead of 64×64)


## 1. Imports and Setup


In [None]:
import os
import sys
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Import custom CSV dataset loader
sys.path.append('/Users/remiliascarlet/Desktop/MDP/transfer_learning/dcgan')
from csv_spectrogram_dataset import CSVSpectrogramDataset


## 2. Set Random Seed for Reproducibility


In [None]:
# Seed every global RNG the notebook relies on so that a fresh
# "Restart Kernel -> Run All" repeats the same run.
seed = 999  # Set manually for reproducible results
print("Using Seed: ", seed)
random.seed(seed)        # Python standard-library RNG
np.random.seed(seed)     # NumPy global RNG (was previously left unseeded)
torch.manual_seed(seed)  # PyTorch RNG
# Make PyTorch raise on ops that have no deterministic implementation.
# NOTE(review): on CUDA this mode may additionally require the
# CUBLAS_WORKSPACE_CONFIG environment variable - confirm against the
# PyTorch reproducibility notes for the installed version.
torch.use_deterministic_algorithms(True)  # Needed for reproducible results


## 3. Hyperparameters and Configuration


In [None]:
# ---- Paths ---------------------------------------------------------------
# Training windows: only the type_3 directory is used (218 samples)
dataroot = "/Users/remiliascarlet/Desktop/MDP/transfer_learning/burst_data/csv/gan_training_windows_128/type_3/"
# Directory that receives the training checkpoints
checkpoint_dir = "./checkpoints_gan_type3"

# ---- Data loading ----------------------------------------------------------
workers = 2        # DataLoader worker processes
batch_size = 16    # chosen for 128×128 inputs
image_size = 128   # each CSV window is 128×128
nc = 3             # channels (CSV data duplicated to RGB for compatibility)

# ---- Network sizes ---------------------------------------------------------
nz = 100   # length of the latent vector
ngf = 64   # base feature-map width of the Generator
ndf = 64   # base feature-map width of the Discriminator

# ---- Optimisation ----------------------------------------------------------
num_epochs = 500
lr = 0.0002      # Generator learning rate (value from the DCGAN paper)
lr_d = 0.00005   # Discriminator learning rate (lr/4) so D does not overpower G
beta1 = 0.5      # first Adam beta
ngpu = 1         # GPUs to use (0 means CPU)

# ---- Label smoothing (makes the Discriminator's task harder) ---------------
real_label_smooth = 0.9  # target for real samples instead of 1.0
fake_label_smooth = 0.1  # target for fake samples instead of 0.0

# ---- Checkpointing ---------------------------------------------------------
save_interval = 5  # periodic save every N epochs, on top of best-model saves

# Echo the effective configuration, one line per setting
config_report = [
    "📋 Configuration:",
    f"   Data root: {dataroot}",
    f"   Checkpoint dir: {checkpoint_dir}",
    f"   Image size: {image_size}×{image_size}",
    f"   Batch size: {batch_size}",
    f"   Epochs: {num_epochs}",
    f"   Generator LR: {lr}",
    f"   Discriminator LR: {lr_d} (reduced to prevent overpowering)",
    f"   Label smoothing: Real={real_label_smooth}, Fake={fake_label_smooth}",
    f"   Save interval: every {save_interval} epochs",
    "   Training Type 3 bursts only",
]
for report_line in config_report:
    print(report_line)


## 4. Load CSV Spectrogram Data


In [None]:
# Wrap the CSV windows under `dataroot` with the project's dataset class.
# NOTE(review): 'minmax' presumably scales values to [-1, 1] to match the
# generator's tanh output - confirm in CSVSpectrogramDataset.
dataset = CSVSpectrogramDataset(
    subsample_ratio=1.0,         # keep every sample
    grayscale=False,             # request 3-channel output
    normalize_method='minmax',
    root_dir=dataroot,
)

# Shuffled mini-batch iterator over the dataset
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=workers,
    shuffle=True,
    batch_size=batch_size,
)

# Use the first CUDA device only if one is available and GPUs were requested
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda:0" if use_cuda else "cpu")
print(f"\n🖥️  Using device: {device}")

# Pull one batch and report its shape and value range as a sanity check
real_batch = next(iter(dataloader))
print(f"\nBatch shape: {real_batch.shape}")
print(f"Batch value range: [{real_batch.min():.3f}, {real_batch.max():.3f}]")

# Tile up to 16 samples into a single image and move channels last for imshow
sample_grid = vutils.make_grid(real_batch[:16].to(device), padding=2, normalize=True).cpu()
plt.figure(figsize=(10, 10))
plt.axis("off")
plt.title("Example Training Spectrograms (Type 3 Bursts)")
plt.imshow(np.transpose(sample_grid, (1, 2, 0)))
plt.show()


## 5. Weight Initialization


In [None]:
# DCGAN-style initialiser; applied to netG and netD through Module.apply
def weights_init(m):
    """Re-initialise one module's conv / batch-norm parameters in place.

    Dispatch is on the module's class *name*:
      - a name containing "Conv": weights drawn from N(mean=0.0, std=0.02)
      - otherwise a name containing "BatchNorm": weights drawn from
        N(mean=1.0, std=0.02) and bias set to 0
    Any other module is left untouched. Returns None.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)


## 6. Generator Network (128×128 Version)


In [None]:
# Generator for 128×128 images
class Generator(nn.Module):
    """DCGAN generator: latent vector -> ``out_channels`` x 128 x 128 image.

    The first transposed convolution expands a 1x1 latent input to 4x4; five
    stride-2 transposed convolutions then double the spatial size each time
    (4 -> 8 -> 16 -> 32 -> 64 -> 128). The closing ``Tanh`` bounds the output
    to [-1, 1].

    Args:
        ngpu: Number of GPUs; only stored on the instance by this class.
        latent_dim: Channel count of the latent input. ``None`` (default)
            uses the notebook-level ``nz``.
        feature_maps: Base feature-map width. ``None`` (default) uses the
            notebook-level ``ngf``.
        out_channels: Channels of the generated image. ``None`` (default)
            uses the notebook-level ``nc``.
    """

    def __init__(self, ngpu, latent_dim=None, feature_maps=None, out_channels=None):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        # Only fall back to the notebook globals when a size is not supplied,
        # so the class can also be rebuilt from sizes stored in a checkpoint.
        latent_dim = nz if latent_dim is None else latent_dim
        feature_maps = ngf if feature_maps is None else feature_maps
        out_channels = nc if out_channels is None else out_channels
        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1 -> (feature_maps*16) x 4 x 4
            nn.ConvTranspose2d(latent_dim, feature_maps * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 16),
            nn.ReLU(True),
            # -> (feature_maps*8) x 8 x 8
            nn.ConvTranspose2d(feature_maps * 16, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            # -> (feature_maps*4) x 16 x 16
            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            # -> (feature_maps*2) x 32 x 32
            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            # -> feature_maps x 64 x 64
            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            # -> out_channels x 128 x 128, squashed to [-1, 1]
            nn.ConvTranspose2d(feature_maps, out_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)

# Instantiate G on the chosen device; wrap it for multi-GPU use only when
# running on CUDA with more than one GPU requested
netG = Generator(ngpu).to(device)
wrap_generator = device.type == 'cuda' and ngpu > 1
if wrap_generator:
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))
# Re-initialise every sub-module with weights_init, then show the layout
netG.apply(weights_init)
print(netG)


## 7. Discriminator Network (128×128 Version)


In [None]:
# Discriminator for 128×128 images
class Discriminator(nn.Module):
    """DCGAN discriminator: ``in_channels`` x 128 x 128 image -> probability.

    Five stride-2 convolutions halve the spatial size each time
    (128 -> 64 -> 32 -> 16 -> 8 -> 4); a final 4x4 convolution reduces the
    4x4 map to a single value, which ``Sigmoid`` maps into (0, 1). The output
    keeps the shape (N, 1, 1, 1); callers flatten it themselves.

    Args:
        ngpu: Number of GPUs; only stored on the instance by this class.
        in_channels: Channels of the input image. ``None`` (default) uses the
            notebook-level ``nc``.
        feature_maps: Base feature-map width. ``None`` (default) uses the
            notebook-level ``ndf``.
    """

    def __init__(self, ngpu, in_channels=None, feature_maps=None):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        # Only fall back to the notebook globals when a size is not supplied,
        # so the class can also be rebuilt from sizes stored in a checkpoint.
        in_channels = nc if in_channels is None else in_channels
        feature_maps = ndf if feature_maps is None else feature_maps
        self.main = nn.Sequential(
            # Input: in_channels x 128 x 128 -> feature_maps x 64 x 64
            # (no batch norm on the first layer)
            nn.Conv2d(in_channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (feature_maps*2) x 32 x 32
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (feature_maps*4) x 16 x 16
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (feature_maps*8) x 8 x 8
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (feature_maps*16) x 4 x 4
            nn.Conv2d(feature_maps * 8, feature_maps * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 16),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1 real/fake probability
            nn.Conv2d(feature_maps * 16, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)

# Instantiate D on the chosen device; wrap it for multi-GPU use only when
# running on CUDA with more than one GPU requested
netD = Discriminator(ngpu).to(device)
wrap_discriminator = device.type == 'cuda' and ngpu > 1
if wrap_discriminator:
    netD = nn.DataParallel(netD, device_ids=list(range(ngpu)))
# Re-initialise every sub-module with weights_init, then show the layout
netD.apply(weights_init)
print(netD)


## 8. Loss Function and Optimizers


In [None]:
# Binary cross-entropy between D's probability output and the target labels
criterion = nn.BCELoss()

# A fixed batch of 64 latent vectors, reused to visualise G's progress
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Smoothed label convention taken from the configuration cell
# (0.9 for real instead of 1.0, 0.1 for fake instead of 0.0)
real_label = real_label_smooth
fake_label = fake_label_smooth

# One Adam optimizer per network, sharing the betas; D uses the smaller
# learning rate so it does not overpower G
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr_d, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

print("✅ Loss function and optimizers initialized")
print(f"   Discriminator LR: {lr_d}")
print(f"   Generator LR: {lr}")
print(f"   Label smoothing: Real={real_label}, Fake={fake_label}")


## 8.5 Model Checkpoint Saving Function


In [None]:
def save_gan_checkpoint(netG, netD, optimizerG, optimizerD, epoch, quality_metric, checkpoint_dir):
    """
    Write one checkpoint file holding both networks and both optimizers.

    The file is named ``checkpoint_epoch_<epoch>_quality_<metric>.pth`` (metric
    formatted to four decimals) and is placed in ``checkpoint_dir``, which is
    created if it does not exist. Besides the four state dicts, the checkpoint
    records the epoch, the metric, and the notebook-level architecture settings
    ``image_size``, ``nz``, ``ngf``, ``ndf`` and ``nc``.

    Args:
        netG: Generator network
        netD: Discriminator network
        optimizerG: Generator optimizer
        optimizerD: Discriminator optimizer
        epoch: Current epoch number
        quality_metric: Quality metric value (D(G(z)) second value)
        checkpoint_dir: Directory to save checkpoints

    Returns:
        Path of the file that was written.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # File name encodes both the epoch and the metric
    file_name = f"checkpoint_epoch_{epoch}_quality_{quality_metric:.4f}.pth"
    checkpoint_path = os.path.join(checkpoint_dir, file_name)

    # Learned state of both networks and their optimizers
    learned_state = {
        'generator_state_dict': netG.state_dict(),
        'discriminator_state_dict': netD.state_dict(),
        'optimizerG_state_dict': optimizerG.state_dict(),
        'optimizerD_state_dict': optimizerD.state_dict(),
    }
    # Architecture settings read from the notebook globals
    architecture = {
        'image_size': image_size,
        'nz': nz,
        'ngf': ngf,
        'ndf': ndf,
        'nc': nc,
    }
    checkpoint = {
        'epoch': epoch,
        **learned_state,
        'quality_metric': quality_metric,
        **architecture,
    }

    torch.save(checkpoint, checkpoint_path)
    print(f"💾 Checkpoint saved: epoch {epoch}, quality metric {quality_metric:.4f}")
    print(f"   Path: {checkpoint_path}")

    return checkpoint_path

print("✅ Checkpoint saving function defined")


## 9. Training Loop


In [None]:
# Training Loop
#
# Alternates one Discriminator update and one Generator update per batch,
# records losses and D outputs, snapshots G's output on `fixed_noise`, and
# writes checkpoints at the end of each epoch.

# Lists to keep track of progress
img_list = []        # image grids of G(fixed_noise), one per snapshot
G_losses = []        # generator loss per iteration
D_losses = []        # discriminator loss (real + fake) per iteration
D_x_history = []     # per-epoch mean of D(x)
D_G_z_history = []   # per-epoch mean of D(G(z)) measured after the D update
iters = 0

# Best model tracking
# NOTE(review): mean D(G(z)) is also high whenever D is weak, so treat it as
# a rough proxy for sample quality and verify checkpoints visually.
best_quality_metric = 0.0  # D(G(z)) second value, higher is better
best_epoch = 0

# Noisy label smoothing: targets are drawn uniformly within +/- label_noise of
# the configured smoothed labels. With the default real_label=0.9 and
# fake_label=0.1 this gives [0.8, 1.0] and [0.0, 0.2]. The ranges are clamped
# because BCELoss targets must lie in [0, 1].
label_noise = 0.1
real_label_range = (max(0.0, real_label - label_noise), min(1.0, real_label + label_noise))
fake_label_range = (max(0.0, fake_label - label_noise), min(1.0, fake_label + label_noise))

print("🚀 Starting Training Loop with Stabilization Techniques...")
print(f"Total batches per epoch: {len(dataloader)}")
print(f"Total iterations: {num_epochs * len(dataloader)}")
print(f"🔧 Techniques applied:")
print(f"   - Discriminator LR reduced to {lr_d} (vs {lr} for Generator)")
print(f"   - Label smoothing: Real labels ~{real_label}, Fake labels ~{fake_label}")
print(f"   - Random label noise added for robustness")
print(f"   - Best model tracking based on D(G(z)) quality metric")
print(f"   - Checkpoints saved to: {checkpoint_dir}")
print("-" * 70)

# For each epoch
for epoch in range(num_epochs):
    # Track metrics for this epoch
    epoch_D_x = []
    epoch_D_G_z2 = []  # D(G(z)) second value - quality metric

    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data.to(device)
        b_size = real_cpu.size(0)

        # Noisy real targets, sampled on the CPU and then moved to the device
        label = torch.empty(b_size, dtype=torch.float).uniform_(*real_label_range).to(device)

        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)

        # Noisy fake targets, sampled the same way as the real ones
        label = torch.empty(b_size, dtype=torch.float).uniform_(*fake_label_range).to(device)

        # Classify all fake batch with D; detach so no gradient reaches G here
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Gradients accumulate on top of those from the real batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        # For generator training, use strong real labels (no smoothing)
        label = torch.full((b_size,), 1.0, dtype=torch.float, device=device)
        # Since we just updated D, perform another forward pass through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Track metrics for this batch
        epoch_D_x.append(D_x)
        epoch_D_G_z2.append(D_G_z2)

        # Output training stats
        if i % 5 == 0:  # Print every 5 batches
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Snapshot G's output on fixed_noise every 100 iterations and on the
        # very last batch of the last epoch
        if (iters % 100 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                snapshot = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(snapshot, padding=2, normalize=True))

        iters += 1

    # ===== End of epoch - evaluate and save models =====
    avg_D_x = np.mean(epoch_D_x)
    avg_D_G_z2 = np.mean(epoch_D_G_z2)  # Quality metric: higher is better (G fools D more)

    D_x_history.append(avg_D_x)
    D_G_z_history.append(avg_D_G_z2)

    print(f"\n📊 Epoch {epoch} Summary:")
    print(f"   Avg D(x): {avg_D_x:.4f}")
    print(f"   Avg D(G(z)) [Quality Metric]: {avg_D_G_z2:.4f}")
    print(f"   Best Quality so far: {best_quality_metric:.4f} (Epoch {best_epoch})")

    # Save model if quality improved
    if avg_D_G_z2 > best_quality_metric:
        improvement = avg_D_G_z2 - best_quality_metric
        best_quality_metric = avg_D_G_z2
        best_epoch = epoch

        print(f"🎉 Quality improved by {improvement:.4f}! Saving best model...")
        save_gan_checkpoint(netG, netD, optimizerG, optimizerD, epoch,
                          best_quality_metric, checkpoint_dir)
    else:
        print(f"⏭️  No improvement (current: {avg_D_G_z2:.4f} vs best: {best_quality_metric:.4f})")

    # Also save periodically regardless of improvement
    if (epoch + 1) % save_interval == 0:
        print(f"💾 Periodic save at epoch {epoch}...")
        # Do not rely on an earlier best-model save having created the directory
        os.makedirs(checkpoint_dir, exist_ok=True)
        periodic_path = os.path.join(checkpoint_dir, f"periodic_epoch_{epoch}.pth")
        # Same keys as the best-model checkpoints, including the architecture
        # settings, so either kind of file can be used to rebuild the networks
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': netG.state_dict(),
            'discriminator_state_dict': netD.state_dict(),
            'optimizerG_state_dict': optimizerG.state_dict(),
            'optimizerD_state_dict': optimizerD.state_dict(),
            'quality_metric': avg_D_G_z2,
            'image_size': image_size,
            'nz': nz,
            'ngf': ngf,
            'ndf': ndf,
            'nc': nc,
        }
        torch.save(checkpoint, periodic_path)
        print(f"   Saved to: {periodic_path}")

    print("-" * 70)

print("\n🎉 Training completed!")
print(f"\n📊 Final Statistics:")
print(f"   Best Quality Metric: {best_quality_metric:.4f} (Epoch {best_epoch})")
print(f"   Total epochs trained: {num_epochs}")
print(f"   Models saved in: {checkpoint_dir}")
print(f"\n💡 Tip: Check generated images in Cell 24 to verify quality!")


## 10. Plot Training Losses


In [None]:
# Per-iteration loss curves for both networks on a single set of axes
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
for loss_curve, curve_label in ((G_losses, "G Loss"), (D_losses, "D Loss")):
    ax.plot(loss_curve, label=curve_label)
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()


## 11. Visualize Results - Animation


In [None]:
# Turn the saved grids of G(fixed_noise) into a frame-by-frame animation
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = []
for grid in img_list:
    # Each grid is channels-first; imshow needs channels last
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Last expression of the cell: render the animation as an interactive HTML widget
HTML(ani.to_jshtml())


## 12. Final Generated Spectrograms


In [None]:
# Show the most recent snapshot of G(fixed_noise); move channels last for imshow
final_grid = np.transpose(img_list[-1], (1, 2, 0))
plt.figure(figsize=(12, 12))
plt.axis("off")
plt.title("Generated Fake Type 3 Radio Burst Spectrograms (Final)")
plt.imshow(final_grid)
plt.show()


## 13. Real vs Fake Comparison


In [None]:
# Draw a fresh batch of real spectrograms for the side-by-side comparison
real_batch = next(iter(dataloader))

# Tile up to 32 real samples into one image (channels last for imshow)
real_grid = vutils.make_grid(real_batch[:32].to(device), padding=2, normalize=True).cpu()
real_image = np.transpose(real_grid, (1, 2, 0))
# Most recent snapshot of G(fixed_noise)
fake_image = np.transpose(img_list[-1], (1, 2, 0))

plt.figure(figsize=(15, 7))

# Left panel: real data
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Type 3 Burst Spectrograms")
plt.imshow(real_image)

# Right panel: generated data
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Generated Fake Type 3 Burst Spectrograms")
plt.imshow(fake_image)

plt.tight_layout()
plt.show()


## 14. Save Trained Models


In [None]:
# Write the final weights of both networks to the working directory
generator_file = 'generator_type3_128x128.pth'
discriminator_file = 'discriminator_type3_128x128.pth'
torch.save(netG.state_dict(), generator_file)
torch.save(netD.state_dict(), discriminator_file)
print("✅ Models saved:")
for saved_file in (generator_file, discriminator_file):
    print(f"   - {saved_file}")


## 15. Generate New Samples from Trained Model


In [None]:
# Sample fresh latent vectors and run them through the trained generator,
# without tracking gradients
num_samples = 16
with torch.no_grad():
    noise = torch.randn(num_samples, nz, 1, 1, device=device)
    generated = netG(noise).cpu()

# Tile the new samples into one image and move channels last for imshow
generated_grid = vutils.make_grid(generated, padding=2, normalize=True)
plt.figure(figsize=(10, 10))
plt.axis("off")
plt.title("Newly Generated Type 3 Solar Radio Burst Spectrograms")
plt.imshow(np.transpose(generated_grid, (1, 2, 0)))
plt.show()

print(f"✅ Generated {num_samples} new synthetic Type 3 radio burst spectrograms!")
