Diffusion Model Experiments for Genomic Structural Variants

Experimental Design:
- Instances 1-3: Architecture optimization (U-Net Small/Medium/Large)
- Instances 4-7: Noise schedule optimization (Standard/Low/High/Fast)
- Instance 8: Evaluation strategy optimization
- Instance 9: Final model training with best configuration

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd
from scipy import stats
from scipy.stats import ttest_rel
import copy
from itertools import product
import glob
import warnings
import random
from collections import defaultdict
import shutil
import time
# NOTE(review): this silences *every* warning for the whole session (including
# sklearn metric warnings such as zero-division in precision/recall); consider
# narrowing the filter.
warnings.filterwarnings('ignore')

# Input data, results and figure locations (relative to the current working dir).
DATA_DIR = '../data/processed/all_datasets_images_rgb'
SAVE_DIR = '../data/processed/experiment_results'
FIGURES_DIR = '../figures'

# Output directories are created eagerly so later cells can write into them.
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(FIGURES_DIR, exist_ok=True)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

In [None]:
class DiffusionSchedule:
    """Linear-beta DDPM noise schedule with closed-form forward noising.

    For t = 0 .. timesteps-1 it precomputes:
        betas                          linearly spaced from beta_start to beta_end
        alphas                         1 - betas
        alphas_cumprod                 running product of alphas (alpha-bar_t)
        sqrt_alphas_cumprod            sqrt(alpha-bar_t)
        sqrt_one_minus_alphas_cumprod  sqrt(1 - alpha-bar_t)
    """
    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.timesteps = timesteps
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alpha_bar
        self.sqrt_alphas_cumprod = alpha_bar.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1.0 - alpha_bar).sqrt()

    def q_sample(self, x_start, t, noise=None):
        """Draw x_t from q(x_t | x_0) = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise.

        ``t`` indexes the schedule once per batch element. The gathered
        coefficients are reshaped to (-1, 1, 1, 1), i.e. this is written for
        4-D (batch, C, H, W) inputs. Standard-normal noise is drawn when
        ``noise`` is not supplied.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        signal_scale = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        noise_scale = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return signal_scale * x_start + noise_scale * noise

    def to(self, device):
        """Move every precomputed schedule tensor to ``device`` in place and return self."""
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
        return self

class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    Maps a 1-D tensor of timesteps of shape (batch,) to a (batch, dim) float
    tensor. The first ``dim // 2`` features are sines and the next ``dim // 2``
    are cosines, at geometrically spaced frequencies from 1 down to 1/10000.

    Args:
        dim: output feature size. For odd ``dim`` the last column is zero
            padding, so the output always has exactly ``dim`` features and can
            feed an ``nn.Linear(dim, ...)``.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # With half_dim == 1 the spacing divisor (half_dim - 1) would be 0 and
        # the embedding would become NaN; a single frequency of 1 is used instead.
        denom = max(half_dim - 1, 1)
        log_step = torch.log(torch.tensor(10000.0)) / denom
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        missing = self.dim - embeddings.shape[-1]
        if missing > 0:
            # Odd dim: pad so the feature count matches what consumers expect.
            embeddings = F.pad(embeddings, (0, missing))
        return embeddings

class ResBlock(nn.Module):
    """Two conv -> GroupNorm -> ReLU stages with an additive timestep bias and a skip.

    Args:
        in_ch: input channels.
        out_ch: output channels; must be divisible by 8 (GroupNorm uses 8 groups).
        time_dim: size of the timestep embedding passed to ``forward``.

    ``forward(x, time_emb)`` keeps the spatial size (3x3 convs, padding 1).
    The shortcut is a 1x1 conv when the channel count changes, otherwise identity.
    """
    def __init__(self, in_ch, out_ch, time_dim=128):
        super().__init__()
        # Submodule names and creation order are part of the state_dict layout
        # and of seeded initialisation, so they stay fixed.
        self.time_mlp = nn.Linear(time_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        if in_ch == out_ch:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x, time_emb):
        shortcut = self.residual_conv(x)
        # Project the embedding to one bias per channel, broadcast over H and W.
        time_bias = self.time_mlp(time_emb)[:, :, None, None]
        out = F.relu(self.norm1(self.conv1(x))) + time_bias
        out = F.relu(self.norm2(self.conv2(out)))
        return out + shortcut

class UNetSmall(nn.Module):
    """Small U-Net noise predictor: three pooling levels, 32 -> 64 -> 128 channels.

    ``forward(x, timestep)`` returns ``out_channels`` maps with the same spatial
    size as ``x``; every ResBlock receives the sinusoidal embedding of
    ``timestep`` (size ``time_dim``). Each spatial side must be at least 8 so
    that three 2x max-poolings stay non-empty. It does not need to be divisible
    by 8: the decoder resizes to each skip connection's exact size.
    """
    def __init__(self, in_channels=3, out_channels=3, time_dim=64):
        super().__init__()
        self.time_embedding = TimeEmbedding(time_dim)
        self.conv_in = nn.Conv2d(in_channels, 32, 3, padding=1)
        self.down1 = ResBlock(32, 32, time_dim)
        self.down2 = ResBlock(32, 64, time_dim)
        self.down3 = ResBlock(64, 128, time_dim)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = ResBlock(128, 128, time_dim)
        # Decoder inputs are (upsampled features + skip) channels.
        self.up3 = ResBlock(128 + 128, 64, time_dim)
        self.up2 = ResBlock(64 + 64, 32, time_dim)
        self.up1 = ResBlock(32 + 32, 32, time_dim)
        self.conv_out = nn.Conv2d(32, out_channels, 3, padding=1)

    def forward(self, x, timestep):
        time_emb = self.time_embedding(timestep)
        # Encoder: each level halves the spatial size (floor) before its ResBlock.
        x1 = F.relu(self.conv_in(x))
        x1 = self.down1(x1, time_emb)
        x2 = self.down2(self.pool(x1), time_emb)
        x3 = self.down3(self.pool(x2), time_emb)
        # Bottleneck
        x_bottle = self.bottleneck(self.pool(x3), time_emb)
        # Decoder: upsample to the skip tensor's exact size instead of a fixed 2x,
        # so sizes made odd by floor-pooling still concatenate. For sizes divisible
        # by 8 this is the same as scale_factor=2.
        x = F.interpolate(x_bottle, size=x3.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up3(torch.cat([x, x3], dim=1), time_emb)
        x = F.interpolate(x, size=x2.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up2(torch.cat([x, x2], dim=1), time_emb)
        x = F.interpolate(x, size=x1.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up1(torch.cat([x, x1], dim=1), time_emb)
        return self.conv_out(x)

class UNetMedium(nn.Module):
    """Medium U-Net noise predictor: four pooling levels, 64 -> 512 channels.

    ``forward(x, timestep)`` returns ``out_channels`` maps with the same spatial
    size as ``x``; every ResBlock receives the sinusoidal embedding of
    ``timestep`` (size ``time_dim``). Each spatial side must be at least 16 so
    that four 2x max-poolings stay non-empty. It does not need to be divisible
    by 16: the decoder resizes to each skip connection's exact size.
    """
    def __init__(self, in_channels=3, out_channels=3, time_dim=128):
        super().__init__()
        self.time_embedding = TimeEmbedding(time_dim)
        self.conv_in = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.down1 = ResBlock(64, 64, time_dim)
        self.down2 = ResBlock(64, 128, time_dim)
        self.down3 = ResBlock(128, 256, time_dim)
        self.down4 = ResBlock(256, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = ResBlock(512, 512, time_dim)
        # Decoder inputs are (upsampled features + skip) channels.
        self.up4 = ResBlock(512 + 512, 256, time_dim)
        self.up3 = ResBlock(256 + 256, 128, time_dim)
        self.up2 = ResBlock(128 + 128, 64, time_dim)
        self.up1 = ResBlock(64 + 64, 64, time_dim)
        self.conv_out = nn.Conv2d(64, out_channels, 3, padding=1)

    def forward(self, x, timestep):
        time_emb = self.time_embedding(timestep)
        # Encoder: each level halves the spatial size (floor) before its ResBlock.
        x1 = F.relu(self.conv_in(x))
        x1 = self.down1(x1, time_emb)
        x2 = self.down2(self.pool(x1), time_emb)
        x3 = self.down3(self.pool(x2), time_emb)
        x4 = self.down4(self.pool(x3), time_emb)
        # Bottleneck
        x_bottle = self.bottleneck(self.pool(x4), time_emb)
        # Decoder: upsample to the skip tensor's exact size instead of a fixed 2x,
        # so sizes made odd by floor-pooling still concatenate. For sizes divisible
        # by 16 this is the same as scale_factor=2.
        x = F.interpolate(x_bottle, size=x4.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up4(torch.cat([x, x4], dim=1), time_emb)
        x = F.interpolate(x, size=x3.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up3(torch.cat([x, x3], dim=1), time_emb)
        x = F.interpolate(x, size=x2.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up2(torch.cat([x, x2], dim=1), time_emb)
        x = F.interpolate(x, size=x1.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up1(torch.cat([x, x1], dim=1), time_emb)
        return self.conv_out(x)

class UNetLarge(nn.Module):
    """Large U-Net noise predictor: four pooling levels, 128 -> 1024 channels.

    ``forward(x, timestep)`` returns ``out_channels`` maps with the same spatial
    size as ``x``; every ResBlock receives the sinusoidal embedding of
    ``timestep`` (size ``time_dim``). Each spatial side must be at least 16 so
    that four 2x max-poolings stay non-empty. It does not need to be divisible
    by 16: the decoder resizes to each skip connection's exact size.
    """
    def __init__(self, in_channels=3, out_channels=3, time_dim=256):
        super().__init__()
        self.time_embedding = TimeEmbedding(time_dim)
        self.conv_in = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.down1 = ResBlock(128, 128, time_dim)
        self.down2 = ResBlock(128, 256, time_dim)
        self.down3 = ResBlock(256, 512, time_dim)
        self.down4 = ResBlock(512, 1024, time_dim)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = ResBlock(1024, 1024, time_dim)
        # Decoder inputs are (upsampled features + skip) channels.
        self.up4 = ResBlock(1024 + 1024, 512, time_dim)
        self.up3 = ResBlock(512 + 512, 256, time_dim)
        self.up2 = ResBlock(256 + 256, 128, time_dim)
        self.up1 = ResBlock(128 + 128, 128, time_dim)
        self.conv_out = nn.Conv2d(128, out_channels, 3, padding=1)

    def forward(self, x, timestep):
        time_emb = self.time_embedding(timestep)
        # Encoder: each level halves the spatial size (floor) before its ResBlock.
        x1 = F.relu(self.conv_in(x))
        x1 = self.down1(x1, time_emb)
        x2 = self.down2(self.pool(x1), time_emb)
        x3 = self.down3(self.pool(x2), time_emb)
        x4 = self.down4(self.pool(x3), time_emb)
        # Bottleneck
        x_bottle = self.bottleneck(self.pool(x4), time_emb)
        # Decoder: upsample to the skip tensor's exact size instead of a fixed 2x,
        # so sizes made odd by floor-pooling still concatenate. For sizes divisible
        # by 16 this is the same as scale_factor=2.
        x = F.interpolate(x_bottle, size=x4.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up4(torch.cat([x, x4], dim=1), time_emb)
        x = F.interpolate(x, size=x3.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up3(torch.cat([x, x3], dim=1), time_emb)
        x = F.interpolate(x, size=x2.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up2(torch.cat([x, x2], dim=1), time_emb)
        x = F.interpolate(x, size=x1.shape[-2:], mode='bilinear', align_corners=False)
        x = self.up1(torch.cat([x, x1], dim=1), time_emb)
        return self.conv_out(x)

In [None]:
# Dataset Classes

class DiffusionDataset(Dataset):
    """TP-only image dataset for diffusion training, scaled to the [-1, 1] range.

    Args:
        data_list: record dicts with at least ``'label'`` (1 = TP) and
            ``'filepath'``. The file is a ``torch.save`` file holding either an
            image tensor or a dict with an ``'image'`` tensor. Only
            ``label == 1`` records are kept.
        transform: optional callable applied to the scaled image.
        channels: channel count to return. Extra channels are dropped and
            missing ones are zero-padded before scaling.

    ``__getitem__`` returns ``(image, label)`` with ``label`` as the raw int.
    A file that fails to load is replaced by a blank (all -1) 224x224 image so
    one bad file does not abort an epoch. The first few failures are printed,
    and ``load_failures`` counts all of them.
    """

    # Stop printing individual load failures after this many (per instance).
    _MAX_REPORTED_FAILURES = 10

    def __init__(self, data_list, transform=None, channels=3):
        # Filter to TP only for diffusion training (learn TP patterns)
        self.tp_data = [x for x in data_list if x['label'] == 1]
        self.transform = transform
        self.channels = channels
        self.load_failures = 0
        print(f"   DiffusionDataset: {len(self.tp_data)} TP samples ({channels}-channel)")

    def __len__(self):
        return len(self.tp_data)

    def _load_image(self, filepath):
        """Load one image file and return it as a float (channels, H, W) tensor in [-1, 1]."""
        # NOTE(review): torch.load unpickles the file; only use on trusted data.
        data = torch.load(filepath, map_location='cpu')
        image = data['image'] if isinstance(data, dict) else data

        n_ch = image.shape[0]
        if n_ch < self.channels:
            # Pad with the image's own dtype so uint8 data stays on the uint8 path
            # below instead of relying on torch.cat dtype promotion.
            padding = torch.zeros(self.channels - n_ch, *image.shape[1:], dtype=image.dtype)
            image = torch.cat([image, padding], dim=0)
        elif n_ch > self.channels:
            image = image[:self.channels]

        # Normalize to [0,1], then map to [-1,1] for diffusion.
        # NOTE(review): non-uint8 inputs are assumed to be on a 0-255 scale; a
        # float image already in [0,1] would collapse to ~0 here. Confirm the
        # storage format.
        if image.dtype == torch.uint8:
            image = image.float() / 255.0
        else:
            image = torch.clamp(image / 255.0, 0.0, 1.0)
        return image * 2.0 - 1.0

    def _report_failure(self, filepath, error):
        """Count a failed load and print it, up to ``_MAX_REPORTED_FAILURES`` times."""
        self.load_failures += 1
        if self.load_failures <= self._MAX_REPORTED_FAILURES:
            print(f"   Warning: failed to load {filepath} ({error}); using blank image")
        if self.load_failures == self._MAX_REPORTED_FAILURES:
            print("   Warning: further load failures will not be printed")

    def __getitem__(self, idx):
        item = self.tp_data[idx]

        try:
            image = self._load_image(item['filepath'])
        except Exception as e:
            # Best effort: keep training going with a blank image, but say so.
            # .get() because a missing 'filepath' key is one of the failures.
            self._report_failure(item.get('filepath'), e)
            image = torch.full((self.channels, 224, 224), -1.0)

        if self.transform:
            image = self.transform(image)

        return image, item['label']

class ClassificationDataset(Dataset):
    """TP + FP image dataset for classification evaluation, scaled to [-1, 1].

    Args:
        data_list: record dicts with ``'label'`` (1 = TP, 0 = FP) and
            ``'filepath'``. The file is a ``torch.save`` file holding either an
            image tensor or a dict with an ``'image'`` tensor.
        transform: optional callable applied to the scaled image.
        channels: channel count to return. Extra channels are dropped and
            missing ones are zero-padded before scaling.

    ``__getitem__`` returns ``(image, label)`` with ``label`` as a long tensor.
    A file that fails to load is replaced by a blank (all -1) 224x224 image.
    The first few failures are printed, and ``load_failures`` counts all of them.
    """

    # Stop printing individual load failures after this many (per instance).
    _MAX_REPORTED_FAILURES = 10

    def __init__(self, data_list, transform=None, channels=3):
        self.data = data_list
        self.transform = transform
        self.channels = channels
        self.load_failures = 0
        print(f"   ClassificationDataset: {len(self.data)} samples ({channels}-channel)")

        # Print class distribution (denominator guarded for an empty data_list).
        tp_count = sum(1 for x in data_list if x['label'] == 1)
        fp_count = len(data_list) - tp_count
        denom = max(len(data_list), 1)
        print(f"      TP: {tp_count} samples ({100*tp_count/denom:.1f}%)")
        print(f"      FP: {fp_count} samples ({100*fp_count/denom:.1f}%)")

    def __len__(self):
        return len(self.data)

    def _load_image(self, filepath):
        """Load one image file and return it as a float (channels, H, W) tensor in [-1, 1]."""
        # NOTE(review): torch.load unpickles the file; only use on trusted data.
        data = torch.load(filepath, map_location='cpu')
        image = data['image'] if isinstance(data, dict) else data

        n_ch = image.shape[0]
        if n_ch < self.channels:
            # Pad with the image's own dtype so uint8 data stays on the uint8 path
            # below instead of relying on torch.cat dtype promotion.
            padding = torch.zeros(self.channels - n_ch, *image.shape[1:], dtype=image.dtype)
            image = torch.cat([image, padding], dim=0)
        elif n_ch > self.channels:
            image = image[:self.channels]

        # Normalize to [0,1], then map to [-1,1] (same scaling as training).
        # NOTE(review): non-uint8 inputs are assumed to be on a 0-255 scale.
        # Confirm the storage format.
        if image.dtype == torch.uint8:
            image = image.float() / 255.0
        else:
            image = torch.clamp(image / 255.0, 0.0, 1.0)
        return image * 2.0 - 1.0

    def _report_failure(self, filepath, error):
        """Count a failed load and print it, up to ``_MAX_REPORTED_FAILURES`` times."""
        self.load_failures += 1
        if self.load_failures <= self._MAX_REPORTED_FAILURES:
            print(f"   Warning: failed to load {filepath} ({error}); using blank image")
        if self.load_failures == self._MAX_REPORTED_FAILURES:
            print("   Warning: further load failures will not be printed")

    def __getitem__(self, idx):
        item = self.data[idx]

        try:
            image = self._load_image(item['filepath'])
        except Exception as e:
            # Best effort: score a blank image rather than abort evaluation, but say so.
            self._report_failure(item.get('filepath'), e)
            image = torch.full((self.channels, 224, 224), -1.0)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(item['label'], dtype=torch.long)

In [None]:
# Data Loading Functions

def _parse_sv_record(dataset_name, dataset_path, filename):
    """Build a metadata record from an SV image file name, or return None if it does not parse."""
    parts = filename[:-3].split('_')  # drop the '.pt' suffix
    if len(parts) < 8:
        return None
    try:
        label = parts[2]
        chrom = parts[3]
        pos = int(parts[4])
        end = int(parts[5])
        svtype = parts[6]
        svlen_str = parts[7]

        if svlen_str.endswith('bp'):
            svlen = int(svlen_str.replace('bp', ''))
        elif svlen_str.isdigit() or (svlen_str.startswith('-') and svlen_str[1:].isdigit()):
            # Accept a leading minus so signed lengths are kept instead of zeroed.
            svlen = int(svlen_str)
        else:
            svlen = 0
    except (ValueError, IndexError):
        return None

    return {
        'dataset': dataset_name,
        'filepath': os.path.join(dataset_path, filename),
        'label': 1 if label == 'TP' else 0,
        'chrom': chrom,
        'pos': pos,
        'end': end,
        'svtype': svtype,
        'svlen': svlen,
        'genome': dataset_name
    }


def load_sv_data_with_progress(data_dir, max_per_dataset=6250):
    """Index SV image files (``*.pt``) for each known dataset under ``data_dir``.

    Metadata comes from the file name only; files are not opened. A name must
    split on ``_`` into at least 8 fields: field 2 is the label (``TP`` -> 1,
    anything else -> 0), then chrom, pos, end, svtype and svlen (``123``,
    ``-123`` or ``123bp``; otherwise 0). Names that do not parse are skipped.

    Args:
        data_dir: directory containing one sub-directory per dataset
            (HG002_GRCh37, HG002_GRCh38, HG005_GRCh38); missing ones are skipped.
        max_per_dataset: if a dataset has more ``.pt`` files than this, a random
            subset of this size is drawn with a private RNG seeded with 42. The
            global ``random`` state is left untouched.

    Returns:
        list of record dicts with keys dataset, filepath, label, chrom, pos,
        end, svtype, svlen, genome.
    """
    all_data = []
    datasets = ['HG002_GRCh37', 'HG002_GRCh38', 'HG005_GRCh38']

    print(f"Loading data for diffusion (max {max_per_dataset} per dataset)...")

    for dataset_name in datasets:
        dataset_path = os.path.join(data_dir, dataset_name)

        if not os.path.exists(dataset_path):
            print(f"   Dataset not found: {dataset_path}")
            continue

        print(f"   Discovering files in {dataset_name}...")

        try:
            pt_filenames = [f for f in os.listdir(dataset_path) if f.endswith('.pt')]
            n_available = len(pt_filenames)

            if n_available > max_per_dataset:
                # Same subset as random.seed(42); random.sample(...), without
                # resetting the caller's global RNG state.
                pt_filenames = random.Random(42).sample(pt_filenames, max_per_dataset)
                print(f"     Sampled {max_per_dataset} from {n_available} files")

            dataset_files = []
            for i, filename in enumerate(pt_filenames):
                if i % 1000 == 0 and i > 0:
                    print(f"     Processed {i}/{len(pt_filenames)} files...")
                record = _parse_sv_record(dataset_name, dataset_path, filename)
                if record is not None:
                    dataset_files.append(record)

            all_data.extend(dataset_files)
            print(f"   {dataset_name}: {len(dataset_files)} files loaded")

        except Exception as e:
            # Best effort per dataset: report and continue with the others.
            print(f"   Error in {dataset_name}: {e}")
            continue

    print(f"Total loaded: {len(all_data)} files")
    return all_data

def create_data_splits(rgb_data):
    """Split SV records into train / validation / test sets by genome build.

    HG002_GRCh38 and HG005_GRCh38 records form the training pool. 20% of that
    pool is held out for validation, stratified jointly on label and dataset
    (``random_state=42``). All HG002_GRCh37 records form the test set, which
    gives a cross-genome evaluation matching the CNN experiments.

    Returns:
        (train_final, val_data, test_data) lists of the original record dicts.
    """
    train_sources = {'HG002_GRCh38', 'HG005_GRCh38'}
    test_sources = {'HG002_GRCh37'}

    train_pool = [rec for rec in rgb_data if rec['dataset'] in train_sources]
    test_data = [rec for rec in rgb_data if rec['dataset'] in test_sources]

    # Stratify on "<label>_<dataset>" so both class balance and genome mix survive the split.
    strata = ['{}_{}'.format(rec['label'], rec['dataset']) for rec in train_pool]
    train_final, val_data = train_test_split(
        train_pool, test_size=0.2, stratify=strata, random_state=42
    )
    return train_final, val_data, test_data

In [None]:
# Diffusion model architectures keyed by task name. Each entry holds the model
# class, its extra constructor kwargs ('params') and descriptive metadata.
DIFFUSION_ARCHITECTURES = {
    arch_key: {
        'class': model_cls,
        'params': {'time_dim': time_dim},
        'name': display_name,
        'description': description,
        'expected_performance': expected,
    }
    for arch_key, model_cls, time_dim, display_name, description, expected in (
        ('unet_small', UNetSmall, 64, 'U-Net Small',
         'Lightweight model for fast optimization',
         'Fast but limited capacity'),
        ('unet_medium', UNetMedium, 128, 'U-Net Medium',
         'Balanced capacity and efficiency',
         'Good balance of speed and quality'),
        ('unet_large', UNetLarge, 256, 'U-Net Large',
         'Maximum capacity for complex patterns',
         'Best quality but slower'),
    )
}

# Noise-schedule variants for the forward diffusion process. 'beta_start',
# 'beta_end' and 'timesteps' are DiffusionSchedule constructor arguments; the
# remaining keys are descriptive metadata.
NOISE_SCHEDULES = dict(
    standard=dict(
        beta_start=0.0001,
        beta_end=0.02,
        timesteps=1000,
        name='Standard Schedule',
        description='Default DDPM noise schedule',
        expected_impact='baseline',
    ),
    low_noise=dict(
        beta_start=0.00005,
        beta_end=0.01,
        timesteps=1000,
        name='Low Noise Schedule',
        description='Gentler noise for better reconstruction',
        expected_impact='positive (easier learning)',
    ),
    high_noise=dict(
        beta_start=0.0002,
        beta_end=0.03,
        timesteps=1000,
        name='High Noise Schedule',
        description='Stronger noise for robust learning',
        expected_impact='mixed (harder but more robust)',
    ),
    fast_schedule=dict(
        beta_start=0.0001,
        beta_end=0.02,
        timesteps=500,
        name='Fast Schedule',
        description='Fewer timesteps for efficiency',
        expected_impact='mixed (faster but less precise)',
    ),
)

# Evaluation strategies for diffusion-based classification: which timesteps to
# noise each test image at, and how many noise draws to take per timestep.
EVALUATION_STRATEGIES = dict(
    single_t25=dict(
        name='Single Timestep (t=25)',
        timesteps=[25],
        n_samples=1,
        description='Fast single-point evaluation',
        expected_impact='fast but potentially noisy',
    ),
    multi_timestep=dict(
        name='Multi-Timestep Average',
        timesteps=[10, 25, 50, 100],
        n_samples=1,
        description='Average across multiple noise levels',
        expected_impact='more stable than single timestep',
    ),
)

In [None]:
# Training and Evaluation Functions

def train_diffusion_with_progress(model, schedule, train_data, hyperparams, epochs=30):
    """Train a noise-prediction diffusion model on the TP records of ``train_data``.

    Each step draws a uniform random timestep per image and noises the image
    with ``schedule.q_sample``. The model is trained to regress the injected
    noise with an MSE loss; gradients are clipped to norm 1.0.

    Early stopping watches the *training* loss; there is no validation pass.
    Training stops after ``patience`` consecutive epochs without a new best
    loss, and the weights from the best epoch are restored.

    Args:
        model: network called as ``model(noisy_images, t)``; expected to already
            be on the module-level ``device``.
        schedule: object exposing ``timesteps`` and ``q_sample`` (a DiffusionSchedule).
        train_data: record dicts; only ``label == 1`` entries are used.
        hyperparams: dict with ``'batch_size'`` and ``'lr'``. Optional keys are
            ``'weight_decay'`` (default 1e-4) and ``'patience'`` (default 8).
        epochs: maximum number of epochs.

    Returns:
        (model, best_loss, epochs_run, training_history). ``best_loss`` is
        ``inf`` and the final weights are kept if no epoch produced a finite loss
        (e.g. NaN from the start).

    Raises:
        ValueError: if ``train_data`` contains no TP samples.
    """
    print(f"Creating diffusion training dataset...")
    train_dataset = DiffusionDataset(train_data, transform=None, channels=3)
    if len(train_dataset) == 0:
        raise ValueError("No TP samples in train_data; cannot train diffusion model")
    train_loader = DataLoader(train_dataset, batch_size=hyperparams['batch_size'], shuffle=True, num_workers=0)

    optimizer = torch.optim.AdamW(model.parameters(), lr=hyperparams['lr'], weight_decay=hyperparams.get('weight_decay', 1e-4))
    patience = hyperparams.get('patience', 8)

    print(f"Starting diffusion training loop...")

    best_loss = float('inf')
    best_model_state = None  # stays None if no epoch ever improves (e.g. NaN loss)
    patience_counter = 0
    training_history = []
    epochs_run = 0  # defined even when epochs == 0

    for epoch in range(epochs):
        epochs_run = epoch + 1
        print(f"\nEpoch {epoch+1}/{epochs}")

        model.train()
        epoch_loss = 0.0

        for images, _ in train_loader:
            images = images.to(device)
            batch_size_actual = images.shape[0]

            # Sample timesteps and noise
            t = torch.randint(0, schedule.timesteps, (batch_size_actual,), device=device).long()
            noise = torch.randn_like(images)
            noisy_images = schedule.q_sample(images, t, noise)

            # Predict noise
            predicted_noise = model(noisy_images, t)
            loss = F.mse_loss(predicted_noise, noise)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        print(f"   Training: {avg_loss:.4f} loss")

        training_history.append({
            'epoch': epoch,
            'train_loss': avg_loss
        })

        # Early stopping on training-loss improvement. A NaN loss never compares
        # as smaller, so it counts toward patience.
        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"   Early stopping after {epoch+1} epochs (loss plateau)")
                break

    # Restore the best weights, if any epoch produced a finite improvement.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print("   Warning: no finite training loss observed; keeping final weights")
    return model, best_loss, epochs_run, training_history

def evaluate_diffusion_for_classification(model, schedule, test_data, eval_config, model_name):
    """Score test records by the model's noise-prediction error (lower error = more TP-like).

    Each image is noised with fresh Gaussian noise once per timestep in
    ``eval_config['timesteps']``, repeated ``eval_config['n_samples']`` times.
    For every draw, the MSE between predicted and injected noise is computed
    per image (mean over C, H, W). The score is the negated mean over all draws,
    so a model trained on TP images gives TP-like images higher scores.

    Whole DataLoader batches are pushed through the model at once. This assumes
    the model treats samples independently in eval mode, which holds for
    GroupNorm-based networks such as the U-Nets here. TODO confirm for others.

    Returns:
        (scores, labels) as 1-D numpy arrays in ``test_data`` order.

    Raises:
        ValueError: if the config requests no timestep/sample evaluations.
    """
    timesteps = list(eval_config['timesteps'])
    n_samples = eval_config['n_samples']
    if not timesteps or n_samples < 1:
        raise ValueError(f"eval_config {eval_config['name']!r} requests no evaluations")

    print(f"Evaluating {model_name} with {eval_config['name']}")

    test_dataset = ClassificationDataset(test_data, transform=None, channels=3)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)

    model.eval()
    all_scores = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc=f'Evaluating {model_name}', leave=False):
            images = images.to(device)
            batch_size_actual = images.shape[0]

            draw_losses = []  # one (batch,) tensor of per-image MSE per noise draw
            for t_val in timesteps:
                t = torch.full((batch_size_actual,), t_val, device=device, dtype=torch.long)
                for _ in range(n_samples):
                    noise = torch.randn_like(images)
                    noisy_images = schedule.q_sample(images, t, noise)
                    predicted_noise = model(noisy_images, t)
                    per_image = F.mse_loss(predicted_noise, noise, reduction='none')
                    draw_losses.append(per_image.flatten(1).mean(dim=1))

            # Average over timesteps/draws. Lower reconstruction loss means more
            # TP-like, so negate it to get a higher-is-more-TP score.
            avg_loss = torch.stack(draw_losses).mean(dim=0)
            all_scores.extend((-avg_loss).cpu().numpy())
            all_labels.extend(labels.numpy())

    return np.array(all_scores), np.array(all_labels)

In [None]:
# Instance assignment: 1-based instance id -> (task key, task type). Ids follow
# list position: 1-3 architectures, 4-7 noise schedules, 8 evaluation, 9 final.
DIFFUSION_INSTANCE_TASKS = dict(enumerate([
    ('unet_small', 'architecture'),
    ('unet_medium', 'architecture'),
    ('unet_large', 'architecture'),
    ('standard', 'noise_schedule'),
    ('low_noise', 'noise_schedule'),
    ('high_noise', 'noise_schedule'),
    ('fast_schedule', 'noise_schedule'),
    ('single_t25', 'evaluation'),
    ('best_final_model', 'final'),
], start=1))

def run_diffusion_architecture_experiment(train_files, val_files, test_files, target_arch):
    """Benchmark one U-Net architecture over three seeded runs (instances 1-3).

    Each run seeds torch / numpy / random with 42 + run and builds a fresh
    ``target_arch`` model with the 'standard' noise schedule. It trains for up
    to 25 epochs on ``train_files`` and scores ``test_files`` with the
    'multi_timestep' evaluation. AUC uses the raw scores; precision, recall and
    F1 threshold the scores at their median. ``val_files`` is accepted but not
    used here.

    Returns:
        dict with per-run metrics, mean/std of AUC and training loss, and
        JSON-friendly architecture metadata; ``instance_id`` is None for the
        caller to fill in.
    """
    arch = DIFFUSION_ARCHITECTURES[target_arch]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f"\nDIFFUSION ARCHITECTURE EXPERIMENT: {arch['name']}")
    print(f"   Expected Performance: {arch['expected_performance']}")

    baseline_params = dict(lr=1e-4, batch_size=16, weight_decay=1e-4)

    # The architecture comparison holds the noise schedule and evaluation fixed.
    noise_config = NOISE_SCHEDULES['standard']
    eval_config = EVALUATION_STRATEGIES['multi_timestep']
    schedule_kwargs = {
        key: value for key, value in noise_config.items()
        if key in ('timesteps', 'beta_start', 'beta_end')
    }

    run_metrics = []
    for run in range(3):
        print(f"     Architecture run {run+1}/3...")

        seed = 42 + run
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        model = arch['class'](in_channels=3, out_channels=3, **arch['params']).to(device)
        schedule = DiffusionSchedule(**schedule_kwargs).to(device)

        trained_model, train_loss, epochs, _history = train_diffusion_with_progress(
            model, schedule, train_files, baseline_params, epochs=25
        )
        scores, labels = evaluate_diffusion_for_classification(
            trained_model, schedule, test_files, eval_config, f"{target_arch}_run{run}"
        )

        auc = roc_auc_score(labels, scores)
        median_split = (scores > np.median(scores)).astype(int)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, median_split, average='binary'
        )

        run_metrics.append(dict(
            run=run,
            auc=auc,
            f1=f1,
            precision=precision,
            recall=recall,
            train_loss=train_loss,
            epochs=epochs,
        ))
        print(f"       Run {run+1}: AUC={auc:.3f}, Loss={train_loss:.4f}")

        # Free the model's GPU memory before the next run allocates a new one.
        del trained_model, model, schedule
        torch.cuda.empty_cache()

    aucs = [m['auc'] for m in run_metrics]
    losses = [m['train_loss'] for m in run_metrics]
    auc_mean, auc_std = np.mean(aucs), np.std(aucs)
    loss_mean, loss_std = np.mean(losses), np.std(losses)

    final_results = {
        'instance_id': None,  # Will be set by caller
        'task': target_arch,
        'task_type': 'architecture',
        'timestamp': timestamp,
        # The model class itself is left out: only JSON-serialisable metadata.
        'architecture_config': {
            'name': arch['name'],
            'params': arch['params'],
            'description': arch['description'],
            'expected_performance': arch['expected_performance'],
        },
        'runs': run_metrics,
        'auc_mean': auc_mean,
        'auc_std': auc_std,
        'loss_mean': loss_mean,
        'loss_std': loss_std,
        'summary': f"{arch['name']}: AUC={auc_mean:.3f}±{auc_std:.3f}",
    }

    print(f"\nARCHITECTURE RESULTS:")
    print(f"   {final_results['summary']}")
    print(f"   Training Loss: {loss_mean:.4f}±{loss_std:.4f}")

    return final_results

def run_single_diffusion_instance(instance_id):
    """Run one diffusion pipeline instance and save its results as JSON in ``SAVE_DIR``.

    The instance's ``(task_key, task_type)`` comes from ``DIFFUSION_INSTANCE_TASKS``,
    and the branch is chosen by ``task_type``. Only experiments that actually run
    (currently ``'architecture'``) are written to
    ``diffusion_instance{id}_{task_key}_results_{timestamp}.json``. The other
    task types are still placeholders: their status dict is returned but NOT
    saved. Otherwise a resume check would treat them as finished forever.

    Returns:
        the result dict, or ``{"error": ..., "instance_id": ...}`` if anything
        in the guarded section raised. The traceback is printed in that case.

    Raises:
        KeyError: for an ``instance_id`` missing from ``DIFFUSION_INSTANCE_TASKS``.
    """
    import traceback

    task_key, task_type = DIFFUSION_INSTANCE_TASKS[instance_id]
    print(f"STARTING DIFFUSION INSTANCE {instance_id}: {task_key} ({task_type})")

    # Banner and status for task types whose experiment is not implemented yet.
    placeholders = {
        'noise_schedule': ("NOISE SCHEDULE OPTIMIZATION", "noise_schedule_experiment"),
        'evaluation': ("EVALUATION STRATEGY OPTIMIZATION", "evaluation_experiment"),
        'final': ("FINAL MODEL TRAINING", "final_training"),
    }

    try:
        if task_type == 'architecture':
            # Up to 6,250 files per dataset; train/val come from the two GRCh38 genomes.
            rgb_data = load_sv_data_with_progress(DATA_DIR, max_per_dataset=6250)
            train_final, val_data, test_data = create_data_splits(rgb_data)
            print(f"Data splits: Train={len(train_final)}, Val={len(val_data)}, Test={len(test_data)}")

            result = run_diffusion_architecture_experiment(train_final, val_data, test_data, task_key)
            result['instance_id'] = instance_id

        elif task_type in placeholders:
            banner, status = placeholders[task_type]
            print(f"{banner} - Instance {instance_id}")
            print(f"   '{task_type}' experiment not implemented yet; results not saved")
            return {"instance_id": instance_id, "task": task_key,
                    "task_type": task_type, "status": status}

        else:
            raise ValueError(f"Unknown task type {task_type!r} for instance {instance_id}")

        # Save results
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_filename = f'diffusion_instance{instance_id}_{task_key}_results_{timestamp}.json'
        results_filepath = os.path.join(SAVE_DIR, results_filename)

        with open(results_filepath, 'w') as f:
            json.dump(result, f, indent=2)
        print(f"Results saved: {results_filename}")

        return result

    except Exception as e:
        # Top-level boundary of one instance: report and let the pipeline continue.
        print(f"\nDIFFUSION INSTANCE {instance_id} FAILED: {e}")
        traceback.print_exc()
        return {"error": str(e), "instance_id": instance_id}

In [None]:
# Automation Pipeline

def check_diffusion_instance_complete(instance_id, save_dir=None):
    """Return whether a results JSON already exists for ``instance_id``.

    Looks in ``save_dir`` (default: the module-level ``SAVE_DIR``) for files
    named ``diffusion_instance{instance_id}_*results_*.json``. The trailing
    underscore keeps instance 1 from matching instance 12.

    Returns:
        ``(True, filename)`` for the lexicographically last match, or
        ``(False, None)`` if there is no match or the directory cannot be read.
        For names ending in a ``%Y%m%d_%H%M%S`` timestamp, the last match is the
        newest.
    """
    directory = SAVE_DIR if save_dir is None else save_dir
    prefix = f'diffusion_instance{instance_id}_'
    try:
        filenames = os.listdir(directory)
    except OSError:
        return False, None

    matches = sorted(
        name for name in filenames
        if name.startswith(prefix) and 'results_' in name and name.endswith('.json')
    )
    if matches:
        return True, matches[-1]
    return False, None

def _diffusion_result_row(data):
    """Convert one results dict into the display strings used by the summary table."""
    if 'auc_mean' in data:
        auc_mean = data['auc_mean']
        auc_std = data.get('auc_std', 0)
        auc_str = f"{auc_mean:.3f}±{auc_std:.3f}" if auc_std > 0 else f"{auc_mean:.3f}"
    else:
        auc_str = "N/A"  # e.g. placeholder results carry no metrics

    loss = data.get('loss_mean', data.get('train_loss'))
    return {
        # Strings throughout so a missing/None field cannot break the formatting.
        'Instance': str(data.get('instance_id', 'N/A')),
        'Task': str(data.get('task', 'N/A')),
        'Type': str(data.get('task_type', 'N/A')),
        'AUC': auc_str,
        'Loss': f"{loss:.4f}" if loss is not None else "N/A",
    }


def show_diffusion_results_table(save_dir=None):
    """Print a summary table of every saved diffusion results JSON.

    Reads ``diffusion_instance*results_*.json`` files from ``save_dir``
    (default: the module-level ``SAVE_DIR``) in filename order. Missing fields
    are shown as ``N/A``, so partial or placeholder result files still appear.
    A file that cannot be read or parsed is reported and skipped.
    """
    directory = SAVE_DIR if save_dir is None else save_dir
    print("\nDIFFUSION EXPERIMENTAL RESULTS SUMMARY")
    print("="*70)

    try:
        filenames = os.listdir(directory)
    except OSError:
        print("   No diffusion results directory found yet")
        return

    result_files = sorted(
        name for name in filenames
        if name.startswith('diffusion_instance') and 'results_' in name and name.endswith('.json')
    )
    if not result_files:
        print("   No diffusion results files found yet")
        return

    results_data = []
    for file in result_files:
        try:
            with open(os.path.join(directory, file), 'r') as f:
                data = json.load(f)
            results_data.append(_diffusion_result_row(data))
        except Exception as e:  # best effort: report the bad file, keep the rest
            print(f"   Error reading {file}: {e}")

    if results_data:
        print(f"   Found {len(results_data)} completed diffusion experiments:")
        print()
        print("   Instance | Task               | Type         | AUC Score   | Train Loss")
        print("   ---------|--------------------|--------------|-----------:|----------:")
        for row in results_data:
            print(f"   {row['Instance']:>8} | {row['Task']:18} | {row['Type']:12} | {row['AUC']:11} | {row['Loss']:10}")
        print()
        print("   Lower train loss = better diffusion learning")
        print("   Higher AUC = better classification performance")
    else:
        print("   No valid diffusion results found")

def run_all_diffusion_instances(force_rerun=False):
    """Run diffusion experiment instances 1-9 in order, resuming by default.

    Every instance is first checked with ``check_diffusion_instance_complete``.
    In resume mode (the default), instances that report complete are skipped.
    With ``force_rerun=True`` every instance is rerun regardless. Each pending
    instance is executed with ``run_single_diffusion_instance``, and the results
    table is printed after each instance and once more at the end.

    Args:
        force_rerun: If True, rerun all instances even when a completed result
            already exists.

    Returns:
        dict mapping instance id to the value returned by
        ``run_single_diffusion_instance`` for the instances run in this call.
        Instances skipped because they were already complete are not included,
        so the dict is empty when nothing was pending.
    """
    print(f"\n{'='*70}")
    print("STARTING SMART AUTOMATED DIFFUSION PIPELINE")
    print("Novel Research: First systematic diffusion evaluation in genomics!")
    print(f"Force rerun: {'Yes' if force_rerun else 'No (resume mode)'}")
    print("Instances 1-3: Architecture optimization (3 runs each, 15K training samples)")
    print("Instances 4-7: Noise schedule optimization (3 runs each, 15K training samples)")
    print("Instance 8: Evaluation strategy optimization (5 runs, 15K training samples)")
    print("Instance 9: Final training with best config (save final model)")
    print("Estimated time: 8-10 hours total")
    print(f"{'='*70}")

    all_results = {}
    valid_instances = [1, 2, 3, 4, 5, 6, 7, 8, 9]

    # Check existing progress. Completed instances are recorded in BOTH modes so
    # the force-rerun message below can report that existing results will be
    # redone; only the pending list depends on force_rerun.
    completed_instances = []
    pending_instances = []
    for instance_id in valid_instances:
        is_complete, result_file = check_diffusion_instance_complete(instance_id)
        if is_complete:
            completed_instances.append(instance_id)
        if is_complete and not force_rerun:
            print(f"Diffusion Instance {instance_id}: Already complete ({result_file})")
        else:
            pending_instances.append(instance_id)

    if completed_instances:
        print(f"\nFOUND {len(completed_instances)} COMPLETED DIFFUSION INSTANCES")
        if force_rerun:
            print("FORCE RERUN: Will redo all instances")
        else:
            print("RESUME MODE: Will skip completed instances")
            print(f"PENDING: Instances {pending_instances}")

    if not pending_instances:
        print("\nALL DIFFUSION INSTANCES ALREADY COMPLETE!")
        show_diffusion_results_table()
        return all_results

    # Run pending instances
    for instance_id in pending_instances:
        # perf_counter is monotonic, so elapsed time is unaffected by
        # wall-clock adjustments during long training runs.
        start_time = time.perf_counter()

        print(f"\n{'='*50}")
        print(f"RUNNING DIFFUSION INSTANCE {instance_id}")
        print(f"{'='*50}")

        result = run_single_diffusion_instance(instance_id)
        all_results[instance_id] = result

        elapsed = time.perf_counter() - start_time
        print(f"\nDiffusion Instance {instance_id} completed in {elapsed/60:.1f} minutes")

        # Show progress table after each instance
        print("\nDIFFUSION PROGRESS UPDATE:")
        show_diffusion_results_table()

        # Release cached GPU memory between instances; skipped on CPU-only runs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Final summary
    print(f"\n{'='*70}")
    print("FINAL DIFFUSION RESULTS SUMMARY:")
    show_diffusion_results_table()
    print(f"{'='*70}")

    return all_results

In [None]:
# Run Experiments: show which configuration each instance id maps to, followed
# by the two supported ways of launching the experiments (automated or manual).
print("\n".join((
    "DIFFUSION INSTANCE ASSIGNMENT:",
    "ARCHITECTURE OPTIMIZATION (15K training samples):",
    "1: unet_small | 2: unet_medium | 3: unet_large",
    "NOISE SCHEDULE OPTIMIZATION (15K training samples):",
    "4: standard | 5: low_noise | 6: high_noise | 7: fast_schedule",
    "EVALUATION STRATEGY OPTIMIZATION (15K training samples):",
    "8: single_t25",
    "FINAL TRAINING:",
    "9: Best diffusion configuration on full dataset (save final model)",
    "=" * 70,
    "",
    "OPTION 1 - SMART AUTOMATION (RECOMMENDED):",
    "   all_results = run_all_diffusion_instances()                    # Resume from where you left off",
    "   all_results = run_all_diffusion_instances(force_rerun=True)    # Rerun everything",
    "",
    "OPTION 2 - MANUAL MODE:",
    "   INSTANCE_ID = 1  # Set instance ID",
    "   results = run_single_diffusion_instance(INSTANCE_ID)        # Run specific instance",
    "   show_diffusion_results_table()                              # View progress table",
)))

In [None]:
# Launch the full pipeline in resume mode: instances that already have a
# completed result are skipped; pass force_rerun=True to redo everything.
all_results = run_all_diffusion_instances(force_rerun=False)