In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.stats import ks_2samp, pearsonr
import warnings
warnings.filterwarnings('ignore')

In [None]:
def generate_gene_expression_data(n_samples=500, n_genes=2000, random_state=42):
    """
    Generate synthetic gene expression data optimized for VAE training.
    
    I designed this to capture key characteristics of real RNA-seq data:
    gene co-expression modules (simulating biological pathways), log-normal
    distributions, and technical noise. This structure helps the VAE learn
    meaningful representations.
    """
    np.random.seed(random_state)
    
    # Create gene modules that simulate co-expression patterns found in real data
    # Each module represents a biological pathway where genes are correlated
    n_modules = max(15, n_genes // 30)
    genes_per_module = n_genes // n_modules
    
    X = np.zeros((n_samples, n_genes))
    
    for module_idx in range(n_modules):
        start_gene = module_idx * genes_per_module
        end_gene = start_gene + genes_per_module if module_idx < n_modules - 1 else n_genes
        
        # Shared activity factor drives correlation within each module
        module_activity = np.random.randn(n_samples)
        
        # Vary correlation strength to make modules more realistic
        correlation_strength = 0.6 + np.random.rand() * 0.3
        
        for gene_idx in range(start_gene, end_gene):
            # Each gene combines module activity with independent noise
            gene_expression = (
                correlation_strength * module_activity + 
                (1 - correlation_strength) * np.random.randn(n_samples) * 0.5
            )
            X[:, gene_idx] = gene_expression
    
    # Add global factors that affect all genes (simulates batch effects or cell type differences)
    global_factors = np.random.randn(n_samples, 3)
    global_weights = np.random.randn(3, n_genes) * 0.2
    X = X + global_factors @ global_weights
    
    # Transform to log-normal distribution, which matches real RNA-seq data
    X = X - X.min() + 1
    X = np.log1p(X)
    
    # Add technical noise to simulate sequencing artifacts
    noise_level = 0.15
    X = X + np.random.normal(0, noise_level, X.shape)
    
    # Gene expression values must be non-negative
    X = np.maximum(X, 0)
    
    # Create binary labels based on expression patterns for downstream analysis
    sample_signature = X[:, :min(50, n_genes)].mean(axis=1)
    y = (sample_signature > np.median(sample_signature)).astype(int)
    
    gene_names = [f"GENE_{i+1:05d}" for i in range(n_genes)]
    sample_ids = [f"SAMPLE_{i+1:04d}" for i in range(n_samples)]
    
    df = pd.DataFrame(X, columns=gene_names, index=sample_ids)
    df['Disease_Status'] = y
    
    return df, X, y, gene_names


In [None]:
# Generate VAE-optimized gene expression data
df, X, y, gene_names = generate_gene_expression_data(
    n_samples=1000,
    n_genes=500,
    random_state=42
)



# Save to CSV
df.to_csv('gene_expression_data.csv', index=True)
print(f"Generated dataset: {df.shape}")
print(f"Disease distribution:\n{df['Disease_Status'].value_counts()}")

Generated dataset: (1000, 501)
Disease distribution:
Disease_Status
1    500
0    500
Name: count, dtype: int64


In [None]:
# Standardize features to help with training stability
# Since we're learning a generative model, using the full dataset for scaling is appropriate
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

class GeneExpressionDataset(Dataset):
    def __init__(self, data):
        self.data = torch.FloatTensor(data)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

dataset = GeneExpressionDataset(X_scaled)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


In [None]:
class VAE(nn.Module):
    """
    Variational Autoencoder for gene expression data.
    
    I used batch normalization and dropout to stabilize training, and LeakyReLU
    instead of ReLU to prevent dead neurons. The architecture gradually compresses
    from 500 genes down to a 32-dimensional latent space.
    """
    def __init__(self, input_dim, hidden_dim=128, latent_dim=32, dropout=0.2):
        super(VAE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.LeakyReLU(0.2)
        )
        
        # Map to latent distribution parameters
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)
        
        # Decoder mirrors encoder structure
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.LeakyReLU(0.2),
            
            nn.Linear(hidden_dim * 2, input_dim)
        )
    
    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample from N(mu, var) by transforming N(0,1).
        This allows gradients to flow through the sampling operation.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

In [None]:
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """
    VAE loss combines reconstruction error with KL divergence regularization.
    The beta parameter controls the trade-off between reconstruction quality
    and latent space regularization.
    """
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    
    # KL divergence encourages latent distribution to match standard normal
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    total_loss = recon_loss + beta * kl_loss
    return total_loss, recon_loss, kl_loss

In [None]:
def train_vae(model, dataloader, num_epochs=50, device='cuda', beta=1.0, 
              patience=10, min_delta=100, checkpoint_path='best_vae_model.pth'):
    """
    Train VAE with early stopping and automatic checkpointing.
    I use gradient clipping to prevent exploding gradients and a learning rate
    scheduler that reduces LR when loss plateaus.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 
                                                      factor=0.5, patience=5, verbose=True)
    
    history = {'total_loss': [], 'recon_loss': [], 'kl_loss': []}
    best_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        total_loss_epoch = 0
        recon_loss_epoch = 0
        kl_loss_epoch = 0
        
        for batch in dataloader:
            batch = batch.to(device)
            
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(batch)
            
            loss, recon_loss, kl_loss = vae_loss(
                recon_batch, batch, mu, logvar, beta=beta
            )
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss_epoch += loss.item()
            recon_loss_epoch += recon_loss.item()
            kl_loss_epoch += kl_loss.item()
        
        avg_total = total_loss_epoch / len(dataloader)
        avg_recon = recon_loss_epoch / len(dataloader)
        avg_kl = kl_loss_epoch / len(dataloader)
        
        history['total_loss'].append(avg_total)
        history['recon_loss'].append(avg_recon)
        history['kl_loss'].append(avg_kl)
        
        scheduler.step(avg_total)
        
        if avg_total < best_loss - min_delta:
            best_loss = avg_total
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            if (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}] - NEW BEST MODEL SAVED')
        else:
            patience_counter += 1
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Total Loss: {avg_total:.4f}, Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}')
            print(f'Best Loss: {best_loss:.4f}, Patience: {patience_counter}/{patience}')
        
        if patience_counter >= patience:
            print(f'\nEarly stopping at epoch {epoch+1}')
            print(f'Loading best model with loss: {best_loss:.4f}')
            model.load_state_dict(torch.load(checkpoint_path))
            break
    
    return model, history

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae_model = VAE(input_dim=X_scaled.shape[1], hidden_dim=128, latent_dim=32)

trained_vae, history = train_vae(
    vae_model, dataloader,
    num_epochs=50,
    device=device,
    beta=1.0
)


Epoch [10/50] - NEW BEST MODEL SAVED
Epoch [10/50]
Total Loss: 26736.9033, Recon: 25581.8240, KL: 1155.0791
Best Loss: 26736.9033, Patience: 0/10
Epoch [20/50]
Total Loss: 25288.9240, Recon: 23877.1299, KL: 1411.7939
Best Loss: 25354.4303, Patience: 1/10
Epoch [30/50] - NEW BEST MODEL SAVED
Epoch [30/50]
Total Loss: 24628.3772, Recon: 23190.4320, KL: 1437.9452
Best Loss: 24628.3772, Patience: 0/10
Epoch [40/50]
Total Loss: 24269.7368, Recon: 22845.3348, KL: 1424.4021
Best Loss: 24344.4175, Patience: 2/10
Epoch [50/50]
Total Loss: 23793.7281, Recon: 22387.3011, KL: 1406.4269
Best Loss: 23829.1682, Patience: 2/10


In [None]:
def visualize_latent_space(model, X_scaled, y, device='cuda', n_samples=500):
    """
    Project the learned latent space to 2D for visualization.
    I use both PCA (fast) and t-SNE (better separation) to see how samples
    cluster in the compressed representation.
    """
    model.eval()
    with torch.no_grad():
        X_tensor = torch.FloatTensor(X_scaled[:n_samples]).to(device)
        mu, logvar = model.encode(X_tensor)
        z = model.reparameterize(mu, logvar)
        z_np = z.cpu().numpy()
        y_subset = y[:n_samples]
    
    pca = PCA(n_components=2)
    z_pca = pca.fit_transform(z_np)
    
    print("Computing t-SNE embedding (this may take a minute)...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    z_tsne = tsne.fit_transform(z_np)
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    scatter1 = axes[0].scatter(z_pca[:, 0], z_pca[:, 1], c=y_subset, 
                              cmap='viridis', alpha=0.6, s=20)
    axes[0].set_title('Latent Space - PCA Projection')
    axes[0].set_xlabel('PC1')
    axes[0].set_ylabel('PC2')
    axes[0].grid(True, alpha=0.3)
    plt.colorbar(scatter1, ax=axes[0], label='Disease Status')
    
    scatter2 = axes[1].scatter(z_tsne[:, 0], z_tsne[:, 1], c=y_subset, 
                              cmap='viridis', alpha=0.6, s=20)
    axes[1].set_title('Latent Space - t-SNE Projection')
    axes[1].set_xlabel('t-SNE 1')
    axes[1].set_ylabel('t-SNE 2')
    axes[1].grid(True, alpha=0.3)
    plt.colorbar(scatter2, ax=axes[1], label='Disease Status')
    
    plt.tight_layout()
    plt.savefig('latent_space_visualization.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("Latent space visualization saved!")
    
    return z_np

latent_vectors = visualize_latent_space(trained_vae, X_scaled, y, device=device, n_samples=500)

def interpolate_latent(model, z1, z2, n_steps=10, device='cuda'):
    """
    Linearly interpolate between two latent vectors and decode to see
    how gene expression changes smoothly in the latent space.
    """
    model.eval()
    alphas = np.linspace(0, 1, n_steps)
    interpolated = []
    
    with torch.no_grad():
        for alpha in alphas:
            z_interp = (1 - alpha) * z1 + alpha * z2
            recon = model.decode(z_interp.unsqueeze(0).to(device))
            interpolated.append(recon.cpu().numpy())
    
    return np.vstack(interpolated)

print("\nGenerating latent space interpolation...")
sample_idx1, sample_idx2 = 0, 100
z1 = torch.FloatTensor(latent_vectors[sample_idx1]).to(device)
z2 = torch.FloatTensor(latent_vectors[sample_idx2]).to(device)

interpolated_samples = interpolate_latent(trained_vae, z1, z2, n_steps=10, device=device)
interpolated_original = scaler.inverse_transform(interpolated_samples)

fig, axes = plt.subplots(2, 5, figsize=(20, 8))
for i in range(10):
    row = i // 5
    col = i % 5
    axes[row, col].plot(interpolated_original[i, :20], 'o-', markersize=3)
    axes[row, col].set_title(f'Step {i+1}')
    axes[row, col].set_ylabel('Expression')
    axes[row, col].grid(True, alpha=0.3)
    if row == 1:
        axes[row, col].set_xlabel('Gene Index')

plt.suptitle('Latent Space Interpolation (First 20 Genes)', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('latent_interpolation.png', dpi=150, bbox_inches='tight')
plt.close()
print("Interpolation visualization saved!")



Computing t-SNE embedding (this may take a minute)...
Latent space visualization saved!

Generating latent space interpolation...
Interpolation visualization saved!


In [None]:
def generate_samples(model, n_samples=100, device='cuda'):
    """
    Generate new synthetic samples by sampling from the learned latent distribution
    and decoding back to gene expression space.
    """
    model.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, 32).to(device)
        generated = model.decode(z).cpu().numpy()
        generated_original = scaler.inverse_transform(generated)
    
    return generated_original

synthetic_data = generate_samples(trained_vae, n_samples=100, device=device)

print(f"\nReal data shape: {X.shape}")
print(f"Synthetic data shape: {synthetic_data.shape}")
print(f"\nReal data stats:")
print(pd.DataFrame(X).describe().iloc[1:3])
print(f"\nSynthetic data stats:")
print(pd.DataFrame(synthetic_data).describe().iloc[1:3])


Real data shape: (1000, 500)
Synthetic data shape: (100, 500)

Real data stats:
           0         1         2         3         4         5         6    \
mean  1.779777  1.774891  1.787856  1.789800  1.784244  1.778544  1.785110   
std   0.188598  0.198636  0.191158  0.185457  0.185571  0.196865  0.197522   

           7         8         9    ...       490       491       492  \
mean  1.779795  1.776958  1.774142  ...  1.771237  1.779141  1.775565   
std   0.185122  0.208729  0.193769  ...  0.221526  0.216452  0.226163   

           493       494       495       496       497       498       499  
mean  1.772184  1.771721  1.769287  1.778823  1.775983  1.774925  1.777608  
std   0.225290  0.221796  0.233326  0.223448  0.216436  0.239569  0.215450  

[2 rows x 500 columns]

Synthetic data stats:
           0         1         2         3         4         5         6    \
mean  1.792412  1.791412  1.796760  1.788679  1.782088  1.790907  1.796872   
std   0.058046  0.053509  0.04

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

axes[0].plot(history['total_loss'], label='Total Loss', linewidth=2)
axes[0].plot(history['recon_loss'], label='Reconstruction Loss', linewidth=2)
axes[0].plot(history['kl_loss'], label='KL Divergence', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss Curves')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(history['recon_loss'], label='Reconstruction', linewidth=2)
axes[1].plot(history['kl_loss'], label='KL Divergence', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('Loss Components')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Track the balance between reconstruction and regularization
kl_recon_ratio = [kl/r if r > 0 else 0 for kl, r in zip(history['kl_loss'], history['recon_loss'])]
axes[2].plot(kl_recon_ratio, label='KL/Recon Ratio', linewidth=2, color='purple')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Ratio')
axes[2].set_title('KL/Reconstruction Ratio')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.close()
print("Training curves saved!")

# Compare distributions of individual genes between real and synthetic data
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

gene_indices = [0, 10, 50, 100]
for idx, gene_idx in enumerate(gene_indices):
    row = idx // 2
    col = idx % 2
    axes[row, col].hist(X[:, gene_idx], bins=30, alpha=0.6, label='Real', color='blue', density=True)
    axes[row, col].hist(synthetic_data[:, gene_idx], bins=30, alpha=0.6, label='Synthetic', 
                       color='orange', density=True)
    axes[row, col].set_title(f'Gene {gene_idx} Distribution')
    axes[row, col].set_xlabel('Expression Level')
    axes[row, col].set_ylabel('Density')
    axes[row, col].legend()
    axes[row, col].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('gene_distribution_comparison.png', dpi=150, bbox_inches='tight')
plt.close()
print("Distribution comparison saved!")

# Statistical evaluation of synthetic data quality
real_mean = X.mean(axis=0)
real_std = X.std(axis=0)
synth_mean = synthetic_data.mean(axis=0)
synth_std = synthetic_data.std(axis=0)

mean_corr, mean_p = pearsonr(real_mean, synth_mean)
std_corr, std_p = pearsonr(real_std, synth_std)

print("\n" + "="*60)
print("EVALUATION METRICS")
print("="*60)
print(f"Mean correlation (real vs synthetic): {mean_corr:.4f} (p={mean_p:.2e})")
print(f"Std correlation (real vs synthetic): {std_corr:.4f} (p={std_p:.2e})")

# Kolmogorov-Smirnov test to compare distributions
ks_results = []
for i in range(0, min(50, X.shape[1]), 5):
    ks_stat, ks_p = ks_2samp(X[:, i], synthetic_data[:, i])
    ks_results.append(ks_stat)

avg_ks_stat = np.mean(ks_results)
print(f"Average KS statistic (lower is better): {avg_ks_stat:.4f}")
print(f"  (KS < 0.1: excellent, < 0.2: good, < 0.3: acceptable)")

variance_ratio = synthetic_data.std(axis=0).mean() / X.std(axis=0).mean()
print(f"\nVariance ratio (synthetic/real): {variance_ratio:.4f}")
print(f"  (1.0 = perfect match, < 1.0 = synthetic has less variance)")

print("="*60)

# Compare gene-gene correlation structures
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

n_genes_sample = 50
sample_indices = np.random.choice(X.shape[1], n_genes_sample, replace=False)

real_corr = np.corrcoef(X[:, sample_indices].T)
synth_corr = np.corrcoef(synthetic_data[:, sample_indices].T)

im1 = axes[0].imshow(real_corr, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
axes[0].set_title('Real Data Correlation Matrix (50 genes)')
axes[0].set_xlabel('Gene Index')
axes[0].set_ylabel('Gene Index')
plt.colorbar(im1, ax=axes[0])

im2 = axes[1].imshow(synth_corr, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
axes[1].set_title('Synthetic Data Correlation Matrix (50 genes)')
axes[1].set_xlabel('Gene Index')
axes[1].set_ylabel('Gene Index')
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.savefig('correlation_comparison.png', dpi=150, bbox_inches='tight')
plt.close()
print("Correlation comparison saved!")

corr_corr, corr_p = pearsonr(real_corr.flatten(), synth_corr.flatten())
print(f"\nCorrelation matrix similarity: {corr_corr:.4f} (p={corr_p:.2e})")

torch.save(trained_vae.state_dict(), 'gene_expression_vae.pth')
print("\nVAE model saved!")

print("\n" + "="*70)
print("PROJECT SUMMARY - VAE for Gene Expression Data Generation")
print("="*70)
print(f"\nðŸ“Š Dataset:")
print(f"   - Samples: {X.shape[0]}")
print(f"   - Genes: {X.shape[1]}")
print(f"   - Generated synthetic samples: {synthetic_data.shape[0]}")

print(f"\nðŸ¤– Model Architecture:")
print(f"   - Input dimension: {X_scaled.shape[1]}")
print(f"   - Latent dimension: 32")
print(f"   - Hidden dimension: 128")
print(f"   - Total parameters: {sum(p.numel() for p in trained_vae.parameters()):,}")

print(f"\nðŸ“ˆ Training Performance:")
print(f"   - Final total loss: {history['total_loss'][-1]:.2f}")
print(f"   - Final reconstruction loss: {history['recon_loss'][-1]:.2f}")
print(f"   - Final KL divergence: {history['kl_loss'][-1]:.2f}")
print(f"   - Training epochs: {len(history['total_loss'])}")

print(f"\nâœ… Generated Files:")
print(f"   - gene_expression_data.csv (original data)")
print(f"   - gene_expression_vae.pth (trained model)")
print(f"   - best_vae_model.pth (best checkpoint)")
print(f"   - training_curves.png")
print(f"   - gene_distribution_comparison.png")
print(f"   - correlation_comparison.png")
print(f"   - latent_space_visualization.png")
print(f"   - latent_interpolation.png")

print(f"\nðŸ’¡ Key Improvements Made:")
print(f"   âœ“ Enhanced VAE architecture (batch norm, dropout, deeper layers)")
print(f"   âœ“ Early stopping and model checkpointing")
print(f"   âœ“ Learning rate scheduling")
print(f"   âœ“ Comprehensive evaluation metrics")
print(f"   âœ“ Latent space visualization and interpolation")
print(f"   âœ“ Statistical validation (KS test, correlations)")

print("\n" + "="*70)

Training curves saved!
Distribution comparison saved!

EVALUATION METRICS
Mean correlation (real vs synthetic): 0.3384 (p=7.33e-15)
Std correlation (real vs synthetic): 0.8330 (p=3.97e-130)
Average KS statistic (lower is better): 0.3023
  (KS < 0.1: excellent, < 0.2: good, < 0.3: acceptable)

Variance ratio (synthetic/real): 0.3860
  (1.0 = perfect match, < 1.0 = synthetic has less variance)
Correlation comparison saved!

Correlation matrix similarity: 0.7687 (p=0.00e+00)

VAE model saved!

PROJECT SUMMARY - VAE for Gene Expression Data Generation

ðŸ“Š Dataset:
   - Samples: 1000
   - Genes: 500
   - Generated synthetic samples: 100

ðŸ¤– Model Architecture:
   - Input dimension: 500
   - Latent dimension: 32
   - Hidden dimension: 128
   - Total parameters: 347,316

ðŸ“ˆ Training Performance:
   - Final total loss: 23793.73
   - Final reconstruction loss: 22387.30
   - Final KL divergence: 1406.43
   - Training epochs: 50

âœ… Generated Files:
   - gene_expression_data.csv (original 