In [None]:
# VAE vs GAN Comparison for Synthetic Gene Expression Generation
# Comparing generative models to determine which works better for genomics data

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import mean_squared_error
from scipy.stats import wasserstein_distance
import os
from tqdm import tqdm

PROJECT_DIR = 'Project10'

In [None]:
def generate_gene_expression_data(n_samples=500, n_genes=2000, random_state=42):
    """
    Generate synthetic gene expression data that mimics real RNA-seq characteristics.
    
    I designed this to capture key biological patterns: co-expression modules (like
    pathways), global variation (batch effects), and log-normal distributions typical
    of sequencing data. This structure helps both VAE and GAN learn meaningful
    representations.
    
    Parameters:
    -----------
    n_samples : int
        Number of samples (patients/cells)
    n_genes : int
        Number of genes (features)
    random_state : int
        For reproducibility
    
    Returns:
    --------
    df : pd.DataFrame
        DataFrame with gene expression data and labels
    X : np.array
        Gene expression matrix (samples Ã— genes)
    y : np.array
        Sample labels (based on expression patterns)
    gene_names : list
        Gene identifiers (e.g., "GENE_00001", "GENE_00002", ...)
    """
    np.random.seed(random_state)
    
    # Create co-expression modules - genes that are expressed together, simulating
    # biological pathways or regulatory networks
    n_modules = max(15, n_genes // 30)
    genes_per_module = n_genes // n_modules
    
    X = np.zeros((n_samples, n_genes))
    
    for module_idx in range(n_modules):
        start_gene = module_idx * genes_per_module
        end_gene = start_gene + genes_per_module if module_idx < n_modules - 1 else n_genes
        
        # Shared activity level for this module across samples
        module_activity = np.random.randn(n_samples)
        
        # Genes within a module are correlated with the shared activity
        # I use variable correlation strength (0.6-0.9) to make it more realistic
        correlation_strength = 0.6 + np.random.rand() * 0.3
        
        for gene_idx in range(start_gene, end_gene):
            # Each gene combines module activity with independent noise
            gene_expression = (
                correlation_strength * module_activity + 
                (1 - correlation_strength) * np.random.randn(n_samples) * 0.5
            )
            X[:, gene_idx] = gene_expression
    
    # Add global factors that affect many genes - this simulates batch effects,
    # cell type differences, or other systematic variation
    global_factors = np.random.randn(n_samples, 3)
    global_weights = np.random.randn(3, n_genes) * 0.2
    X = X + global_factors @ global_weights
    
    # Transform to log-normal distribution, which matches real RNA-seq data
    X = X - X.min() + 1
    X = np.log1p(X)
    
    # Add technical noise to simulate sequencing variability
    noise_level = 0.15
    X = X + np.random.normal(0, noise_level, X.shape)
    
    # Gene expression values must be non-negative
    X = np.maximum(X, 0)
    
    # Create binary labels based on overall expression signature
    sample_signature = X[:, :min(50, n_genes)].mean(axis=1)
    y = (sample_signature > np.median(sample_signature)).astype(int)
    
    gene_names = [f"GENE_{i+1:05d}" for i in range(n_genes)]
    sample_ids = [f"SAMPLE_{i+1:04d}" for i in range(n_samples)]
    
    df = pd.DataFrame(X, columns=gene_names, index=sample_ids)
    df['Disease_Status'] = y
    
    return df, X, y, gene_names


In [None]:
# Generate synthetic gene expression data
# Using 500 genes to keep training manageable while maintaining biological realism
df, X, y, gene_names = generate_gene_expression_data(
    n_samples=1000,
    n_genes=500,
    random_state=42
)

# Set up directory structure for outputs
os.makedirs(PROJECT_DIR, exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/data', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/models', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/figures', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/results', exist_ok=True)

# Save the generated data
df.to_csv(f'{PROJECT_DIR}/data/gene_expression_data.csv', index=True)
print(f"Generated dataset: {df.shape}")
print(f"Disease distribution:\n{df['Disease_Status'].value_counts()}")
print(f"   âœ… Data saved: {PROJECT_DIR}/data/gene_expression_data.csv")

Generated dataset: (1000, 501)
Disease distribution:
Disease_Status
1    500
0    500
Name: count, dtype: int64
   âœ… Data saved: Project10/data/gene_expression_data.csv


In [None]:
print("\n" + "=" * 70)
print("STEP 1: DATA PREPROCESSING")
print("=" * 70)

print(f"   âœ… Loaded {len(df)} samples with {len(gene_names)} genes")
print(f"   Data shape: {X.shape}")

# Standardize features to zero mean and unit variance
# This is crucial for neural network training - it prevents any single gene
# from dominating due to scale differences
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

print(f"   âœ… Data scaled (mean=0, std=1)")


STEP 1: DATA PREPROCESSING
   âœ… Loaded 1000 samples with 500 genes
   Data shape: (1000, 500)
   âœ… Data scaled (mean=0, std=1)


In [None]:
print("\n" + "=" * 70)
print("STEP 2: CREATING DATASET")
print("=" * 70)

class GeneExpressionDataset(Dataset):
    """Simple dataset wrapper for gene expression data."""
    def __init__(self, data):
        self.data = torch.FloatTensor(data)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

dataset = GeneExpressionDataset(X_scaled)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

print(f"   âœ… Dataset created: {len(dataset)} samples")
print(f"   Batch size: 64")


STEP 2: CREATING DATASET
   âœ… Dataset created: 1000 samples
   Batch size: 64


In [None]:
print("\n" + "=" * 70)
print("STEP 3: DEFINING VAE MODEL")
print("=" * 70)

class VAE(nn.Module):
    """
    Variational Autoencoder for gene expression data.
    
    The encoder maps high-dimensional gene expression to a low-dimensional
    latent space, learning both mean and variance. The decoder reconstructs
    expression profiles from latent codes. This architecture allows us to
    generate new samples by sampling from the learned latent distribution.
    """
    def __init__(self, input_dim, hidden_dim=128, latent_dim=32):
        super(VAE, self).__init__()
        
        # Encoder: compresses gene expression to latent representation
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Output mean and log-variance for the latent distribution
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder: reconstructs expression from latent code
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
    
    def encode(self, x):
        """Encode input to latent space parameters."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample z ~ N(Î¼, ÏƒÂ²) by computing
        z = Î¼ + ÎµÂ·Ïƒ where Îµ ~ N(0,1). This makes sampling differentiable.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        """Decode from latent space to gene expression."""
        return self.decoder(z)
    
    def forward(self, x):
        """Forward pass: encode, sample, and decode."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

input_dim = X_scaled.shape[1]
vae = VAE(input_dim=input_dim, hidden_dim=128, latent_dim=32)
print(f"   âœ… VAE created")
print(f"   Parameters: {sum(p.numel() for p in vae.parameters()):,}")



STEP 3: DEFINING VAE MODEL
   âœ… VAE created
   Parameters: 174,132


In [None]:
print("\n" + "=" * 70)
print("STEP 4: DEFINING GAN MODEL")
print("=" * 70)

class Generator(nn.Module):
    """
    Generator network that transforms random noise into gene expression profiles.
    
    I use no activation on the final layer because the data is StandardScaled,
    which can have negative values. This allows the generator to output the
    full range needed to match the real data distribution.
    """
    def __init__(self, latent_dim=32, hidden_dim=128, output_dim=200):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    """
    Discriminator network that classifies samples as real or generated.
    
    I use LeakyReLU and dropout to prevent the discriminator from becoming
    too strong too quickly, which helps maintain training stability.
    """
    def __init__(self, input_dim=200, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)

latent_dim = 32
generator = Generator(latent_dim=latent_dim, hidden_dim=128, output_dim=input_dim)
discriminator = Discriminator(input_dim=input_dim, hidden_dim=128)

print(f"   âœ… Generator created")
print(f"   âœ… Discriminator created")
print(f"   Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"   Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")



STEP 4: DEFINING GAN MODEL
   âœ… Generator created
   âœ… Discriminator created
   Generator parameters: 101,748
   Discriminator parameters: 88,961


In [None]:
print("\n" + "=" * 70)
print("STEP 5: TRAINING VAE")
print("=" * 70)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae = vae.to(device)

vae_optimizer = optim.Adam(vae.parameters(), lr=0.001)
vae_criterion = nn.MSELoss()

def vae_loss(reconstructed, original, mu, logvar, beta=1.0):
    """
    VAE loss combines reconstruction error with KL divergence.
    
    The reconstruction term ensures the model can accurately reproduce inputs,
    while the KL term regularizes the latent space to be close to a standard
    normal distribution. Beta controls the trade-off - I use 1.0 for balanced
    learning.
    """
    recon_loss = vae_criterion(reconstructed, original)
    
    # KL divergence between learned distribution and standard normal
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_loss = torch.mean(kl_loss)
    
    return recon_loss + beta * kl_loss, recon_loss, kl_loss

num_epochs_vae = 50
vae_losses = []
vae_recon_losses = []
vae_kl_losses = []

print(f"   Training VAE for {num_epochs_vae} epochs...")

for epoch in range(num_epochs_vae):
    vae.train()
    epoch_loss = 0
    epoch_recon = 0
    epoch_kl = 0
    
    for batch in dataloader:
        batch = batch.to(device)
        
        vae_optimizer.zero_grad()
        reconstructed, mu, logvar = vae(batch)
        loss, recon_loss, kl_loss = vae_loss(reconstructed, batch, mu, logvar, beta=1.0)
        loss.backward()
        vae_optimizer.step()
        
        epoch_loss += loss.item()
        epoch_recon += recon_loss.item()
        epoch_kl += kl_loss.item()
    
    avg_loss = epoch_loss / len(dataloader)
    avg_recon = epoch_recon / len(dataloader)
    avg_kl = epoch_kl / len(dataloader)
    
    vae_losses.append(avg_loss)
    vae_recon_losses.append(avg_recon)
    vae_kl_losses.append(avg_kl)
    
    if (epoch + 1) % 10 == 0:
        print(f"   Epoch [{epoch+1}/{num_epochs_vae}], Loss: {avg_loss:.4f}, "
              f"Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}")

print("   âœ… VAE training complete!")


STEP 5: TRAINING VAE
   Training VAE for 50 epochs...
   Epoch [10/50], Loss: 1.0000, Recon: 0.9998, KL: 0.0003
   Epoch [20/50], Loss: 0.9998, Recon: 0.9997, KL: 0.0000
   Epoch [30/50], Loss: 0.9998, Recon: 0.9998, KL: 0.0000
   Epoch [40/50], Loss: 0.9989, Recon: 0.9989, KL: 0.0000
   Epoch [50/50], Loss: 1.0000, Recon: 1.0000, KL: 0.0000
   âœ… VAE training complete!


In [None]:
print("\n" + "=" * 70)
print("STEP 6: TRAINING GAN")
print("=" * 70)

generator = generator.to(device)
discriminator = discriminator.to(device)

# Using lower learning rate and different beta values for GAN training stability
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

criterion = nn.BCELoss()

num_epochs_gan = 100
gan_g_losses = []
gan_d_losses = []

print(f"   Training GAN for {num_epochs_gan} epochs...")
print(f"   (GAN training can be unstable - this may take longer)")

generator.train()
discriminator.train()

for epoch in range(num_epochs_gan):
    g_epoch_loss = 0
    d_epoch_loss = 0
    
    for batch in dataloader:
        batch = batch.to(device)
        batch_size = batch.size(0)
        
        # Train Discriminator
        d_optimizer.zero_grad()
        
        # Real data: use label smoothing (0.9 instead of 1.0) to prevent
        # the discriminator from becoming overconfident
        real_labels = torch.ones(batch_size, 1).to(device) * 0.9
        real_output = discriminator(batch)
        d_loss_real = criterion(real_output, real_labels)
        
        # Fake data: generate samples and train discriminator to reject them
        noise = torch.randn(batch_size, latent_dim).to(device)
        fake_data = generator(noise)
        fake_labels = torch.zeros(batch_size, 1).to(device) + 0.1
        fake_output = discriminator(fake_data.detach())
        d_loss_fake = criterion(fake_output, fake_labels)
        
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        # Gradient clipping prevents exploding gradients
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
        d_optimizer.step()
        
        # Train Generator
        g_optimizer.zero_grad()
        
        # Generator tries to fool the discriminator by making it classify
        # fake samples as real
        noise = torch.randn(batch_size, latent_dim).to(device)
        fake_data = generator(noise)
        fake_output = discriminator(fake_data)
        g_loss = criterion(fake_output, real_labels)
        
        g_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        g_optimizer.step()
        
        g_epoch_loss += g_loss.item()
        d_epoch_loss += d_loss.item()
    
    avg_g_loss = g_epoch_loss / len(dataloader)
    avg_d_loss = d_epoch_loss / len(dataloader)
    
    gan_g_losses.append(avg_g_loss)
    gan_d_losses.append(avg_d_loss)
    
    if (epoch + 1) % 20 == 0:
        print(f"   Epoch [{epoch+1}/{num_epochs_gan}], "
              f"G Loss: {avg_g_loss:.4f}, D Loss: {avg_d_loss:.4f}")

print("   âœ… GAN training complete!")




STEP 6: TRAINING GAN
   Training GAN for 100 epochs...
   (GAN training can be unstable - this may take longer)
   Epoch [20/100], G Loss: 1.9132, D Loss: 0.6867
   Epoch [40/100], G Loss: 2.3713, D Loss: 0.6726
   Epoch [60/100], G Loss: 2.3044, D Loss: 0.6655
   Epoch [80/100], G Loss: 2.2898, D Loss: 0.6685
   Epoch [100/100], G Loss: 2.2967, D Loss: 0.6648
   âœ… GAN training complete!


In [None]:
print("\n" + "=" * 70)
print("STEP 7: GENERATING SYNTHETIC SAMPLES")
print("=" * 70)

# Generate samples from VAE by sampling from the learned latent distribution
vae.eval()
with torch.no_grad():
    z_vae = torch.randn(100, 32).to(device)
    vae_samples = vae.decode(z_vae).cpu().numpy()

# Generate samples from GAN by sampling random noise
generator.eval()
with torch.no_grad():
    z_gan = torch.randn(100, latent_dim).to(device)
    gan_samples = generator(z_gan).cpu().numpy()

print(f"   âœ… Generated 100 samples from VAE")
print(f"   âœ… Generated 100 samples from GAN")



STEP 7: GENERATING SYNTHETIC SAMPLES
   âœ… Generated 100 samples from VAE
   âœ… Generated 100 samples from GAN


In [None]:
print("\n" + "=" * 70)
print("STEP 8: COMPARING VAE vs GAN")
print("=" * 70)

real_data = X_scaled

# 1. Statistical Similarity
# Compare how well each model captures the mean and variance of real data
print("\n   1. STATISTICAL SIMILARITY")
print("   " + "-" * 60)

real_mean = np.mean(real_data, axis=0)
real_std = np.std(real_data, axis=0)

vae_mean = np.mean(vae_samples, axis=0)
vae_std = np.std(vae_samples, axis=0)

gan_mean = np.mean(gan_samples, axis=0)
gan_std = np.std(gan_samples, axis=0)

vae_mean_mse = mean_squared_error(real_mean, vae_mean)
gan_mean_mse = mean_squared_error(real_mean, gan_mean)

vae_std_mse = mean_squared_error(real_std, vae_std)
gan_std_mse = mean_squared_error(real_std, gan_std)

print(f"   Mean MSE:")
print(f"      VAE: {vae_mean_mse:.6f}")
print(f"      GAN: {gan_mean_mse:.6f}")
print(f"      {'âœ… VAE better' if vae_mean_mse < gan_mean_mse else 'âœ… GAN better'}")

print(f"\n   Std MSE:")
print(f"      VAE: {vae_std_mse:.6f}")
print(f"      GAN: {gan_std_mse:.6f}")
print(f"      {'âœ… VAE better' if vae_std_mse < gan_std_mse else 'âœ… GAN better'}")

# 2. Wasserstein Distance
# Measures how different the overall distributions are
print("\n   2. DISTRIBUTION SIMILARITY (Wasserstein Distance)")
print("   " + "-" * 60)

sample_size = min(100, len(real_data))
real_sample = real_data[:sample_size].flatten()
vae_sample = vae_samples[:sample_size].flatten()
gan_sample = gan_samples[:sample_size].flatten()

vae_wasserstein = wasserstein_distance(real_sample, vae_sample)
gan_wasserstein = wasserstein_distance(real_sample, gan_sample)

print(f"   Wasserstein Distance (lower is better):")
print(f"      VAE: {vae_wasserstein:.6f}")
print(f"      GAN: {gan_wasserstein:.6f}")
print(f"      {'âœ… VAE better' if vae_wasserstein < gan_wasserstein else 'âœ… GAN better'}")

# 3. Training Stability
# Lower variance in final epochs indicates more stable training
print("\n   3. TRAINING STABILITY")
print("   " + "-" * 60)

vae_loss_std = np.std(vae_losses[-10:])
gan_g_loss_std = np.std(gan_g_losses[-10:])
gan_d_loss_std = np.std(gan_d_losses[-10:])

print(f"   Loss Stability (std of last 10 epochs, lower is better):")
print(f"      VAE: {vae_loss_std:.6f}")
print(f"      GAN Generator: {gan_g_loss_std:.6f}")
print(f"      GAN Discriminator: {gan_d_loss_std:.6f}")
print(f"      {'âœ… VAE more stable' if vae_loss_std < gan_g_loss_std else 'âœ… GAN more stable'}")

# 4. Sample Diversity
# Higher pairwise distances indicate more diverse generated samples
print("\n   4. SAMPLE DIVERSITY")
print("   " + "-" * 60)

def calculate_diversity(samples):
    """Calculate average pairwise distance between samples."""
    distances = []
    for i in range(len(samples)):
        for j in range(i+1, len(samples)):
            dist = np.linalg.norm(samples[i] - samples[j])
            distances.append(dist)
    return np.mean(distances)

vae_diversity = calculate_diversity(vae_samples)
gan_diversity = calculate_diversity(gan_samples)

print(f"   Average Pairwise Distance (higher = more diverse):")
print(f"      VAE: {vae_diversity:.6f}")
print(f"      GAN: {gan_diversity:.6f}")
print(f"      {'âœ… VAE more diverse' if vae_diversity > gan_diversity else 'âœ… GAN more diverse'}")



STEP 8: COMPARING VAE vs GAN

   1. STATISTICAL SIMILARITY
   ------------------------------------------------------------
   Mean MSE:
      VAE: 0.000007
      GAN: 0.673524
      âœ… VAE better

   Std MSE:
      VAE: 0.989515
      GAN: 0.814004
      âœ… GAN better

   2. DISTRIBUTION SIMILARITY (Wasserstein Distance)
   ------------------------------------------------------------
   Wasserstein Distance (lower is better):
      VAE: 0.801164
      GAN: 0.142642
      âœ… GAN better

   3. TRAINING STABILITY
   ------------------------------------------------------------
   Loss Stability (std of last 10 epochs, lower is better):
      VAE: 0.000344
      GAN Generator: 0.021635
      GAN Discriminator: 0.000635
      âœ… VAE more stable

   4. SAMPLE DIVERSITY
   ------------------------------------------------------------
   Average Pairwise Distance (higher = more diverse):
      VAE: 0.163666
      GAN: 3.174470
      âœ… GAN more diverse


In [None]:
print("\n" + "=" * 70)
print("STEP 9: CREATING VISUALIZATIONS")
print("=" * 70)

os.makedirs(f'{PROJECT_DIR}/figures', exist_ok=True)

# 1. Training Curves
# Visualize how loss evolves during training for both models
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

axes[0].plot(vae_losses, label='Total Loss', alpha=0.7)
axes[0].plot(vae_recon_losses, label='Reconstruction Loss', alpha=0.7)
axes[0].plot(vae_kl_losses, label='KL Divergence', alpha=0.7)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('VAE Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(gan_g_losses, label='Generator Loss', alpha=0.7)
axes[1].plot(gan_d_losses, label='Discriminator Loss', alpha=0.7)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('GAN Training Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{PROJECT_DIR}/figures/training_curves.png', dpi=150, bbox_inches='tight')
plt.close()
print(f"   âœ… Saved: {PROJECT_DIR}/figures/training_curves.png")

# 2. PCA Visualization
# Project high-dimensional data to 2D to see how well generated samples
# match the real data distribution
n_viz_samples = min(100, len(real_data), len(vae_samples), len(gan_samples))
pca = PCA(n_components=2)
real_pca = pca.fit_transform(real_data[:n_viz_samples])
vae_pca = pca.transform(vae_samples[:n_viz_samples])
gan_pca = pca.transform(gan_samples[:n_viz_samples])

fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(real_pca[:, 0], real_pca[:, 1], alpha=0.5, label='Real Data', s=30)
ax.scatter(vae_pca[:, 0], vae_pca[:, 1], alpha=0.5, label='VAE Generated', s=30)
ax.scatter(gan_pca[:, 0], gan_pca[:, 1], alpha=0.5, label='GAN Generated', s=30)
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('PCA Visualization: Real vs Generated Data')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f'{PROJECT_DIR}/figures/pca_comparison.png', dpi=150, bbox_inches='tight')
plt.close()
print(f"   âœ… Saved: {PROJECT_DIR}/figures/pca_comparison.png")

# 3. Distribution Comparison
# Compare expression distributions for individual genes
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

gene_indices = [0, 10, 50, 100]

for idx, gene_idx in enumerate(gene_indices):
    row = idx // 2
    col = idx % 2
    
    axes[row, col].hist(real_data[:, gene_idx], bins=30, alpha=0.5, label='Real', density=True)
    axes[row, col].hist(vae_samples[:, gene_idx], bins=30, alpha=0.5, label='VAE', density=True)
    axes[row, col].hist(gan_samples[:, gene_idx], bins=30, alpha=0.5, label='GAN', density=True)
    axes[row, col].set_xlabel(f'Gene {gene_idx} Expression')
    axes[row, col].set_ylabel('Density')
    axes[row, col].set_title(f'Gene {gene_idx} Distribution')
    axes[row, col].legend()
    axes[row, col].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{PROJECT_DIR}/figures/distribution_comparison.png', dpi=150, bbox_inches='tight')
plt.close()
print(f"   âœ… Saved: {PROJECT_DIR}/figures/distribution_comparison.png")

# 4. Comparison Summary Table
# Create a visual summary of all metrics
fig, ax = plt.subplots(figsize=(10, 6))
ax.axis('tight')
ax.axis('off')

comparison_data = [
    ['Metric', 'VAE', 'GAN', 'Winner'],
    ['Mean MSE', f'{vae_mean_mse:.6f}', f'{gan_mean_mse:.6f}', 
     'VAE' if vae_mean_mse < gan_mean_mse else 'GAN'],
    ['Std MSE', f'{vae_std_mse:.6f}', f'{gan_std_mse:.6f}',
     'VAE' if vae_std_mse < gan_std_mse else 'GAN'],
    ['Wasserstein Distance', f'{vae_wasserstein:.6f}', f'{gan_wasserstein:.6f}',
     'VAE' if vae_wasserstein < gan_wasserstein else 'GAN'],
    ['Training Stability', f'{vae_loss_std:.6f}', f'{gan_g_loss_std:.6f}',
     'VAE' if vae_loss_std < gan_g_loss_std else 'GAN'],
    ['Sample Diversity', f'{vae_diversity:.6f}', f'{gan_diversity:.6f}',
     'VAE' if vae_diversity > gan_diversity else 'GAN'],
]

table = ax.table(cellText=comparison_data, cellLoc='center', loc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.5)

for i in range(4):
    table[(0, i)].set_facecolor('#4CAF50')
    table[(0, i)].set_text_props(weight='bold', color='white')

plt.title('VAE vs GAN Comparison Summary', fontsize=14, fontweight='bold', pad=20)
plt.savefig(f'{PROJECT_DIR}/figures/comparison_summary.png', dpi=150, bbox_inches='tight')
plt.close()
print(f"   âœ… Saved: {PROJECT_DIR}/figures/comparison_summary.png")


STEP 9: CREATING VISUALIZATIONS
   âœ… Saved: Project10/figures/training_curves.png
   âœ… Saved: Project10/figures/pca_comparison.png
   âœ… Saved: Project10/figures/distribution_comparison.png
   âœ… Saved: Project10/figures/comparison_summary.png


In [None]:
print("\n" + "=" * 70)
print("STEP 10: SAVING MODELS AND RESULTS")
print("=" * 70)

os.makedirs(f'{PROJECT_DIR}/models', exist_ok=True)
os.makedirs(f'{PROJECT_DIR}/results', exist_ok=True)

# Save all model checkpoints and preprocessing scaler for reproducibility
model_path = f'{PROJECT_DIR}/models/models.pth'
torch.save({
    'vae_state_dict': vae.state_dict(),
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'scaler': scaler,
    'vae_config': {'input_dim': input_dim, 'hidden_dim': 128, 'latent_dim': 32},
    'gan_config': {'latent_dim': latent_dim, 'hidden_dim': 128, 'output_dim': input_dim},
}, model_path)

# Save quantitative comparison results
results = {
    'vae_mean_mse': vae_mean_mse,
    'gan_mean_mse': gan_mean_mse,
    'vae_std_mse': vae_std_mse,
    'gan_std_mse': gan_std_mse,
    'vae_wasserstein': vae_wasserstein,
    'gan_wasserstein': gan_wasserstein,
    'vae_diversity': vae_diversity,
    'gan_diversity': gan_diversity,
    'vae_loss_std': vae_loss_std,
    'gan_g_loss_std': gan_g_loss_std,
}

results_path = f'{PROJECT_DIR}/results/comparison_results.csv'
pd.DataFrame([results]).to_csv(results_path, index=False)

print(f"   âœ… Models saved: {model_path}")
print(f"   âœ… Results saved: {results_path}")


STEP 10: SAVING MODELS AND RESULTS
   âœ… Models saved: Project10/models/models.pth
   âœ… Results saved: Project10/results/comparison_results.csv


In [45]:
print("\n" + "=" * 70)
print("COMPARISON SUMMARY")
print("=" * 70)

print("\nðŸ“Š METRICS COMPARISON:")
print(f"   {'Metric':<30} {'VAE':<15} {'GAN':<15} {'Winner'}")
print("   " + "-" * 75)
print(f"   {'Mean MSE':<30} {vae_mean_mse:<15.6f} {gan_mean_mse:<15.6f} "
      f"{'VAE' if vae_mean_mse < gan_mean_mse else 'GAN'}")
print(f"   {'Std MSE':<30} {vae_std_mse:<15.6f} {gan_std_mse:<15.6f} "
      f"{'VAE' if vae_std_mse < gan_std_mse else 'GAN'}")
print(f"   {'Wasserstein Distance':<30} {vae_wasserstein:<15.6f} {gan_wasserstein:<15.6f} "
      f"{'VAE' if vae_wasserstein < gan_wasserstein else 'GAN'}")
print(f"   {'Training Stability':<30} {vae_loss_std:<15.6f} {gan_g_loss_std:<15.6f} "
      f"{'VAE' if vae_loss_std < gan_g_loss_std else 'GAN'}")
print(f"   {'Sample Diversity':<30} {vae_diversity:<15.6f} {gan_diversity:<15.6f} "
      f"{'VAE' if vae_diversity > gan_diversity else 'GAN'}")


COMPARISON SUMMARY

ðŸ“Š METRICS COMPARISON:
   Metric                         VAE             GAN             Winner
   ---------------------------------------------------------------------------
   Mean MSE                       0.000007        0.673524        VAE
   Std MSE                        0.989515        0.814004        GAN
   Wasserstein Distance           0.801164        0.142642        GAN
   Training Stability             0.000344        0.021635        VAE
   Sample Diversity               0.163666        3.174470        GAN
