Diffusion Latent Feature Extraction & Classification

Extract internal representations from a trained diffusion model and use them as features for downstream classification.

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import h5py
import json
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_fscore_support, accuracy_score
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
# NOTE(review): this silences every Python warning for the whole session
# (e.g. solver convergence warnings); consider a narrower filter.
warnings.filterwarnings('ignore')


# Configuration - GitHub repository structure
DATA_DIR = '../data/processed/all_datasets_images_rgb'
SAVE_DIR = '../data/processed/experiment_results'
LATENTS_DIR = '../data/processed/diffusion_latents'
FIGURES_DIR = '../figures'

# Create the output directories up front (no-op if they already exist).
for _output_dir in (SAVE_DIR, LATENTS_DIR, FIGURES_DIR):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

# Checkpoint written by the model-training notebook.
MODEL_FILENAME = 'diffusion_final_model_20250730_171711.pth'
MODEL_PATH = os.path.join(SAVE_DIR, MODEL_FILENAME)

# Use the GPU when PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Load Model Architecture (same as previous notebooks)

class DiffusionSchedule:
    """Linear-beta noise schedule for the forward (noising) diffusion process.

    ``timesteps`` betas are spaced evenly from ``beta_start`` to ``beta_end``;
    the alphas, their cumulative product and the two square-root terms used
    by :meth:`q_sample` are precomputed as 1-D tensors of length ``timesteps``.
    """

    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.timesteps = timesteps
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        cumulative = torch.cumprod(alphas, dim=0)

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = cumulative
        self.sqrt_alphas_cumprod = cumulative.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1.0 - cumulative).sqrt()

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep(s) ``t``.

        Computes ``sqrt(a_bar_t) * x_start + sqrt(1 - a_bar_t) * noise``, with
        per-sample coefficients reshaped to ``(-1, 1, 1, 1)`` so they broadcast
        over a 4-D batch. Fresh Gaussian noise is drawn when ``noise`` is None.
        """
        noise = torch.randn_like(x_start) if noise is None else noise
        per_sample = (-1, 1, 1, 1)
        signal_scale = self.sqrt_alphas_cumprod[t].view(*per_sample)
        noise_scale = self.sqrt_one_minus_alphas_cumprod[t].view(*per_sample)
        return signal_scale * x_start + noise_scale * noise

    def to(self, device):
        """Move every precomputed schedule tensor to ``device``; returns ``self``."""
        tensor_names = ('betas', 'alphas', 'alphas_cumprod',
                        'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod')
        moved = {name: getattr(self, name).to(device) for name in tensor_names}
        for name, tensor in moved.items():
            setattr(self, name, tensor)
        return self

class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding: ``dim // 2`` sine features then ``dim // 2`` cosine features."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed a 1-D tensor of timesteps into a ``(len(time), 2 * (dim // 2))`` tensor."""
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000 over `half` steps.
        log_step = torch.log(torch.tensor(10000.0)) / (half - 1)
        freqs = torch.exp(-log_step * torch.arange(half, device=time.device))
        angles = time[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

class ResBlock(nn.Module):
    """Residual block: two conv -> GroupNorm -> ReLU stages with a per-channel time bias.

    The time embedding is projected to ``out_ch`` and added after the first
    stage; the input is added back through a 1x1 conv when the channel count
    changes (identity otherwise). Submodule attribute names form the
    checkpoint ``state_dict`` keys and must not be renamed.
    """

    def __init__(self, in_ch, out_ch, time_dim=128):
        super().__init__()
        # Construction order kept as originally written (default init draws
        # from the global RNG in this order).
        self.time_mlp = nn.Linear(time_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.residual_conv = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x, time_emb):
        """Apply the block to ``x`` conditioned on the timestep embedding ``time_emb``."""
        hidden = F.relu(self.norm1(self.conv1(x)))
        time_bias = self.time_mlp(time_emb)[:, :, None, None]
        hidden = F.relu(self.norm2(self.conv2(hidden + time_bias)))
        return hidden + self.residual_conv(x)

class UNetLarge(nn.Module):
    """Four-level UNet (128-256-512-1024 channels) used as the diffusion denoiser.

    The encoder max-pools by 2 after each of ``down1``..``down4`` and before the
    bottleneck; the decoder upsamples bilinearly and concatenates the matching
    encoder activation before each ``up*`` block. Submodule names are the keys
    of saved checkpoints' ``state_dict`` and must stay unchanged.
    """

    def __init__(self, in_channels=3, out_channels=3, time_dim=256):
        super().__init__()
        self.time_embedding = TimeEmbedding(time_dim)
        self.conv_in = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.down1 = ResBlock(128, 128, time_dim)
        self.down2 = ResBlock(128, 256, time_dim)
        self.down3 = ResBlock(256, 512, time_dim)
        self.down4 = ResBlock(512, 1024, time_dim)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = ResBlock(1024, 1024, time_dim)
        self.up4 = ResBlock(1024 + 1024, 512, time_dim)
        self.up3 = ResBlock(512 + 512, 256, time_dim)
        self.up2 = ResBlock(256 + 256, 128, time_dim)
        self.up1 = ResBlock(128 + 128, 128, time_dim)
        self.conv_out = nn.Conv2d(128, out_channels, 3, padding=1)

    @staticmethod
    def _upsample_to(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s spatial size so they can be concatenated.

        When ``skip`` is exactly twice the size of ``x`` this matches
        ``scale_factor=2`` upsampling. Unlike a fixed factor it also handles
        skips whose pooled size was floored (odd H/W), so inputs need not be
        divisible by 16.
        """
        return F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)

    def forward(self, x, timestep):
        """Predict the output for noisy input ``x`` at timestep(s) ``timestep``."""
        time_emb = self.time_embedding(timestep)
        # Encoder
        x1 = F.relu(self.conv_in(x))
        x1 = self.down1(x1, time_emb)
        x2 = self.pool(x1)
        x2 = self.down2(x2, time_emb)
        x3 = self.pool(x2)
        x3 = self.down3(x3, time_emb)
        x4 = self.pool(x3)
        x4 = self.down4(x4, time_emb)
        # Bottleneck
        x_bottle = self.pool(x4)
        x_bottle = self.bottleneck(x_bottle, time_emb)
        # Decoder: upsample to each skip's exact size before concatenating.
        x = self._upsample_to(x_bottle, x4)
        x = torch.cat([x, x4], dim=1)
        x = self.up4(x, time_emb)
        x = self._upsample_to(x, x3)
        x = torch.cat([x, x3], dim=1)
        x = self.up3(x, time_emb)
        x = self._upsample_to(x, x2)
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x, time_emb)
        x = self._upsample_to(x, x1)
        x = torch.cat([x, x1], dim=1)
        x = self.up1(x, time_emb)
        return self.conv_out(x)

In [None]:
# Dataset for feature extraction
class FeatureExtractionDataset(Dataset):
    """Map-style dataset yielding ``(image, info, success)`` for feature extraction.

    Each ``info`` dict must provide ``'filepath'`` to a ``torch.save`` file that
    holds either a tensor or a dict with an ``'image'`` tensor. Images are cast
    to float32, scaled to ``[-1, 1]`` and forced to exactly 3 channels. A file
    that cannot be loaded or processed yields a constant ``-1`` image of
    ``fallback_size`` with ``success=False``, so one bad file does not abort
    the whole DataLoader pass (callers are expected to check ``success``).
    """

    def __init__(self, file_info_list, fallback_size=(224, 224)):
        self.file_info = file_info_list
        # Spatial (H, W) of the placeholder image returned for unreadable files.
        self.fallback_size = tuple(fallback_size)

    def __len__(self):
        return len(self.file_info)

    def __getitem__(self, idx):
        info = self.file_info[idx]
        try:
            # NOTE(review): torch.load unpickles its input; use only with trusted data files.
            data = torch.load(info['filepath'], map_location='cpu')
            image = data['image'] if isinstance(data, dict) else data
            image = torch.as_tensor(image)

            # A 2-D tensor is a single-channel (H, W) image; add the channel axis
            # so the channel handling below does not mistake H for channels.
            if image.dim() == 2:
                image = image.unsqueeze(0)

            # Scale to [0, 1] in float32 (the model's dtype). Non-uint8 inputs are
            # assumed to be 0-255 valued and are clamped.
            # NOTE(review): float images already in [0, 1] would be squashed by /255 — confirm storage convention.
            if image.dtype == torch.uint8:
                image = image.float() / 255.0
            else:
                image = torch.clamp(image.float() / 255.0, 0.0, 1.0)

            # Force exactly 3 channels: zero-pad missing channels (they map to -1
            # after the shift below) or drop extra ones.
            n_channels = image.shape[0]
            if n_channels < 3:
                padding = image.new_zeros((3 - n_channels, *image.shape[1:]))
                image = torch.cat([image, padding], dim=0)
            elif n_channels > 3:
                image = image[:3]

            image = image * 2.0 - 1.0
            success = True

        except Exception:
            # Deliberate best-effort: substitute a constant image and flag it.
            image = torch.full((3, *self.fallback_size), -1.0)
            success = False

        return image, info, success

# Latent Feature Extractor

class RawLatentExtractor:
    """Capture globally average-pooled activations from selected UNet submodules.

    Forward hooks on the model's ``bottleneck``, ``down4`` and ``up4``
    submodules record each one's output tensor. ``extract_raw_latents_single``
    noises an image, runs one forward pass and pools every captured 4-D
    activation to one value per channel.
    """

    # Attribute names of the model submodules whose outputs are captured.
    HOOK_TARGETS = ('bottleneck', 'down4', 'up4')

    def __init__(self, model, schedule, device):
        self.model = model
        self.schedule = schedule
        self.device = device
        self.activations = {}
        # Handles for the hooks this extractor registered, so clear_hooks()
        # removes exactly these and leaves hooks owned by other code intact.
        self._hook_handles = []

    def register_hooks_for_raw_vectors(self):
        """Attach forward hooks that record the outputs of ``HOOK_TARGETS``.

        Hooks previously registered by this extractor are removed first, so
        calling this more than once does not stack duplicate hooks.
        """
        self.clear_hooks()

        def get_activation(name):
            def hook(module, inputs, output):
                if isinstance(output, torch.Tensor):
                    self.activations[name] = output.detach()
            return hook

        for name in self.HOOK_TARGETS:
            submodule = getattr(self.model, name)
            self._hook_handles.append(submodule.register_forward_hook(get_activation(name)))

    def clear_hooks(self):
        """Remove the hooks registered by this extractor (and only those)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def extract_raw_latents_single(self, image, timestep=25):
        """Return pooled activation vectors for one image.

        Args:
            image: Image tensor with a leading batch axis; callers pass a batch
                of one (assumed ``[1, C, H, W]`` on ``self.device`` — TODO confirm
                for other callers).
            timestep: Diffusion timestep at which the image is noised.

        Returns:
            dict mapping ``'<hook name>_vector'`` to a flattened numpy array of
            per-channel means for every captured 4-D activation.
        """
        self.activations.clear()

        t = torch.tensor([timestep], device=self.device).long()
        # Fresh noise from the global torch RNG on every call: features for the
        # same image differ between calls unless that RNG is seeded.
        noise = torch.randn_like(image)
        noisy_image = self.schedule.q_sample(image, t, noise)

        with torch.no_grad():
            _ = self.model(noisy_image, t)

        features = {}
        for name, activation in self.activations.items():
            if activation.dim() == 4:  # [B, C, H, W]
                # Global average pooling gives a size-independent per-channel vector.
                pooled = F.adaptive_avg_pool2d(activation, (1, 1))
                flattened = pooled.view(pooled.shape[0], -1)
                features[f'{name}_vector'] = flattened.cpu().numpy().flatten()

        return features

In [None]:
# Main Feature Extraction Functions

def load_trained_diffusion_model():
    """Load the trained UNet and its noise schedule from ``MODEL_PATH``.

    The checkpoint must contain ``'model_state_dict'`` and ``'schedule_config'``
    (with ``timesteps``, ``beta_start`` and ``beta_end``). The model is moved to
    ``device`` and put in eval mode.

    Returns:
        ``(model, schedule)``, or ``(None, None)`` if the checkpoint file is missing.
    """
    print("Loading trained diffusion model...")

    if not os.path.exists(MODEL_PATH):
        for message in (f"Model file not found: {MODEL_PATH}",
                        "Please run Notebook 1 (diffusion experiments) first to train the model"):
            print(message)
        return None, None

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False)

    net = UNetLarge(in_channels=3, out_channels=3, time_dim=256).to(device)
    net.load_state_dict(ckpt['model_state_dict'])
    net.eval()

    cfg = ckpt['schedule_config']
    noise_schedule = DiffusionSchedule(
        **{key: cfg[key] for key in ('timesteps', 'beta_start', 'beta_end')}
    ).to(device)

    n_params_millions = sum(param.numel() for param in net.parameters()) / 1e6
    print("   Model loaded successfully")
    print(f"   Architecture: {ckpt.get('architecture', 'Unknown')}")
    print(f"   Parameters: {n_params_millions:.1f}M")

    return net, noise_schedule

def load_all_genomic_data(data_dir=None, datasets=('HG002_GRCh37', 'HG002_GRCh38', 'HG005_GRCh38')):
    """Index labelled ``.pt`` image files for feature extraction.

    Scans ``<data_dir>/<dataset>/`` for each dataset name. A file is kept when
    its underscore-separated stem has at least 8 fields and field 2 is ``'TP'``
    or ``'FP'``; field 6 is recorded as the SV type.

    Args:
        data_dir: Root directory of the datasets; defaults to ``DATA_DIR``.
        datasets: Dataset sub-directory names to scan, in order.

    Returns:
        List of dicts with keys ``filename``, ``filepath``, ``dataset``,
        ``label_str``, ``svtype`` and ``binary_label`` (1 for TP, 0 for FP).
        Filenames are sorted within each dataset so the order is the same on
        every filesystem (``os.listdir`` order is arbitrary), which keeps seeded
        subsampling reproducible. An empty list is returned when nothing matches.
    """
    print("Loading all genomic data files...")

    root = DATA_DIR if data_dir is None else data_dir
    all_file_info = []

    for dataset_name in datasets:
        dataset_path = os.path.join(root, dataset_name)
        if not os.path.exists(dataset_path):
            print(f"   Dataset not found: {dataset_path}")
            continue

        print(f"   Scanning {dataset_name}...")
        for filename in sorted(os.listdir(dataset_path)):
            if not filename.endswith('.pt'):
                continue
            parts = filename[:-3].split('_')
            if len(parts) < 8:
                continue

            label_str = parts[2]  # TP or FP
            svtype = parts[6]     # INS, DEL, etc.
            if label_str not in ('TP', 'FP'):
                continue

            all_file_info.append({
                'filename': filename,
                'filepath': os.path.join(dataset_path, filename),
                'dataset': dataset_name,
                'label_str': label_str,
                'svtype': svtype,
                'binary_label': 1 if label_str == 'TP' else 0,
            })

    total = len(all_file_info)
    print(f"   Total files found: {total}")
    if total == 0:
        # Avoid dividing by zero in the distribution printout.
        print("   No TP/FP files found; check data_dir and the filename format")
        return all_file_info

    # Show distribution
    tp_count = sum(1 for x in all_file_info if x['label_str'] == 'TP')
    fp_count = total - tp_count
    print(f"   TP: {tp_count:,} samples ({100*tp_count/total:.1f}%)")
    print(f"   FP: {fp_count:,} samples ({100*fp_count/total:.1f}%)")

    return all_file_info

def extract_and_save_diffusion_latents(max_samples=None):
    """Extract pooled diffusion-UNet activations for every sample and save them.

    Loads the trained model and the file index, runs each image once through
    the UNet (noised at timestep 25) while forward hooks capture pooled
    activations, then writes to ``LATENTS_DIR``: an HDF5 feature matrix
    (``diffusion_latent_features.h5``), a per-row metadata CSV
    (``diffusion_latent_metadata.csv``) and a JSON summary
    (``diffusion_latent_summary.json``).

    Args:
        max_samples: If truthy and smaller than the number of files found,
            sample this many files with a fixed seed (42) for a quick run.

    Returns:
        ``(features_array, metadata_df, summary)``, or ``(None, None, None)``
        when the model is missing or no data files were found. Rows whose image
        failed to load or whose extraction raised are left all-zero and marked
        ``extraction_success=False`` in the metadata.
    """
    print("EXTRACTING AND SAVING DIFFUSION LATENT VECTORS")
    print("="*60)

    # Load model
    model, schedule = load_trained_diffusion_model()
    if model is None:
        return None, None, None

    # Load data
    all_file_info = load_all_genomic_data()
    if not all_file_info:
        # Nothing to extract; also avoids writing empty output files.
        print("No genomic data files found; nothing to extract")
        return None, None, None

    # Limit samples if specified
    if max_samples and len(all_file_info) > max_samples:
        import random
        # A private seeded RNG picks the same subset as random.seed(42) followed
        # by random.sample, without resetting the session's global random state.
        all_file_info = random.Random(42).sample(all_file_info, max_samples)
        print(f"Limited to {max_samples} samples for testing")

    def collate_fn(batch):
        # Stack images; keep per-sample info dicts and success flags as lists.
        images = torch.stack([item[0] for item in batch])
        infos = [item[1] for item in batch]
        successes = [item[2] for item in batch]
        return images, infos, successes

    dataset = FeatureExtractionDataset(all_file_info)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn)

    total_samples = len(all_file_info)
    metadata_list = []
    failed_count = 0

    extractor = RawLatentExtractor(model, schedule, device)
    extractor.register_hooks_for_raw_vectors()
    try:
        # Probe once to learn the feature dimension (activations are pooled per
        # channel, so the probe's spatial size does not matter).
        test_img = torch.zeros(1, 3, 224, 224, device=device) * 2.0 - 1.0
        test_features = extractor.extract_raw_latents_single(test_img)
        n_features = sum(len(v) for v in test_features.values())
        print(f"   Raw latent vectors: {n_features} features per sample")

        # Prepare storage; rows of failed samples stay zero.
        features_array = np.zeros((total_samples, n_features), dtype=np.float32)

        print(f"Processing {total_samples:,} samples...")

        with torch.no_grad():
            for batch_images, batch_infos, batch_successes in tqdm(dataloader, desc="Extracting features"):
                batch_images = batch_images.to(device)

                # Process each image individually
                for image, info, success in zip(batch_images, batch_infos, batch_successes):
                    sample_idx = len(metadata_list)
                    extraction_success = False

                    if success:
                        try:
                            features_dict = extractor.extract_raw_latents_single(image.unsqueeze(0), timestep=25)
                            # Sorted keys give the same concatenation order for every sample.
                            features_array[sample_idx] = np.concatenate(
                                [features_dict[name] for name in sorted(features_dict)]
                            )
                            extraction_success = True
                        except Exception as e:
                            print(f"Feature extraction failed for {info['filename']}: {e}")

                    if not extraction_success:
                        features_array[sample_idx] = 0.0
                        failed_count += 1

                    metadata_list.append({
                        'sample_idx': sample_idx,
                        'filename': info['filename'],
                        'dataset': info['dataset'],
                        'label_str': info['label_str'],
                        'svtype': info['svtype'],
                        'binary_label': info['binary_label'],
                        'extraction_success': extraction_success,
                        'filepath': info['filepath']
                    })
    finally:
        # Remove the hooks even if extraction raises part-way through.
        extractor.clear_hooks()

    print("Feature extraction complete!")
    print(f"   Successfully processed: {total_samples - failed_count:,}/{total_samples:,} samples")
    print(f"   Failed extractions: {failed_count:,}")

    # Save features to HDF5
    features_file = os.path.join(LATENTS_DIR, 'diffusion_latent_features.h5')
    print(f"Saving features to: {features_file}")

    with h5py.File(features_file, 'w') as f:
        f.create_dataset('features', data=features_array, compression='gzip')
        f.attrs['model_path'] = MODEL_PATH
        f.attrs['feature_dim'] = n_features
        f.attrs['total_samples'] = total_samples
        f.attrs['extraction_date'] = datetime.now().isoformat()

    # Save metadata to CSV
    metadata_file = os.path.join(LATENTS_DIR, 'diffusion_latent_metadata.csv')
    print(f"Saving metadata to: {metadata_file}")

    metadata_df = pd.DataFrame(metadata_list)
    metadata_df.to_csv(metadata_file, index=False)

    # Create summary
    summary = {
        'extraction_date': datetime.now().isoformat(),
        'model_path': MODEL_PATH,
        'total_samples': int(total_samples),
        'successful_extractions': int(total_samples - failed_count),
        'failed_extractions': int(failed_count),
        'feature_dimension': int(n_features),
        'features_file': features_file,
        'metadata_file': metadata_file,
        'datasets': {k: int(v) for k, v in metadata_df['dataset'].value_counts().items()},
        'labels': {k: int(v) for k, v in metadata_df['label_str'].value_counts().items()},
        'sv_types': {k: int(v) for k, v in metadata_df['svtype'].value_counts().items()}
    }

    summary_file = os.path.join(LATENTS_DIR, 'diffusion_latent_summary.json')
    print(f"Saving summary to: {summary_file}")

    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2)

    print("\nDIFFUSION FEATURE EXTRACTION COMPLETE!")
    print(f"Files saved in: {LATENTS_DIR}")
    print(f"   Features: diffusion_latent_features.h5 ({features_array.nbytes / 1024**2:.1f} MB)")
    print("   Metadata: diffusion_latent_metadata.csv")
    print("   Summary: diffusion_latent_summary.json")

    return features_array, metadata_df, summary

def load_saved_diffusion_features(latents_dir=None):
    """Load previously saved diffusion latent features and their metadata.

    Args:
        latents_dir: Directory holding ``diffusion_latent_features.h5`` and
            ``diffusion_latent_metadata.csv``; defaults to ``LATENTS_DIR``.

    Returns:
        ``(features, metadata_df, attrs)``, where ``attrs`` is the HDF5 file's
        attribute dict. Returns ``(None, None, None)`` if either file is
        missing or the feature and metadata row counts disagree.
    """
    root = LATENTS_DIR if latents_dir is None else latents_dir
    features_file = os.path.join(root, 'diffusion_latent_features.h5')
    metadata_file = os.path.join(root, 'diffusion_latent_metadata.csv')

    if not os.path.exists(features_file):
        print(f"Features file not found: {features_file}")
        print("Please run feature extraction first")
        return None, None, None

    # Checked before reading anything so a half-written output set is reported
    # instead of raising FileNotFoundError from read_csv.
    if not os.path.exists(metadata_file):
        print(f"Metadata file not found: {metadata_file}")
        print("Please run feature extraction first")
        return None, None, None

    print("Loading saved diffusion latent features...")

    # Load features
    with h5py.File(features_file, 'r') as f:
        features = f['features'][:]
        attrs = dict(f.attrs)

    # Load metadata
    metadata_df = pd.read_csv(metadata_file)

    # Rows are aligned by position, so the two files must have the same length.
    if len(features) != len(metadata_df):
        print(f"Row count mismatch: {len(features):,} feature vectors vs {len(metadata_df):,} metadata rows")
        return None, None, None

    print(f"   Loaded {len(features):,} diffusion feature vectors")
    print(f"   Feature dimension: {features.shape[1]}")

    return features, metadata_df, attrs

In [None]:
# Classification Functions

def train_classifiers_on_latents(test_size=0.2):
    """Fit and evaluate several classifiers on the saved diffusion latents.

    Loads the saved features and metadata and drops rows whose extraction
    failed. It then makes a stratified train/test split (seed 42) and
    standardises with a scaler fitted on the training split only. It trains a
    logistic regression and two random forests with balanced class weights.
    Test-set metrics are printed and written to a timestamped JSON file in
    ``SAVE_DIR``.

    Args:
        test_size: Fraction of samples held out for evaluation.

    Returns:
        ``(results, X_test, y_test, scaler)``. ``results`` maps classifier
        name to its metrics, fitted estimator and test predictions; ``X_test``
        is unscaled. On failure all four are ``None``, so callers can always
        unpack four values.
    """
    print("TRAINING CLASSIFIERS ON DIFFUSION LATENTS")
    print("="*60)

    # Load features
    features, metadata_df, attrs = load_saved_diffusion_features()
    if features is None:
        return None, None, None, None

    # Prepare data
    X = features
    y = metadata_df['binary_label'].values

    # Failed samples were saved as all-zero vectors with their real labels;
    # training or scoring on them would distort both the models and the metrics.
    n_excluded = 0
    if 'extraction_success' in metadata_df.columns:
        ok_mask = metadata_df['extraction_success'].astype(bool).values
        n_excluded = int((~ok_mask).sum())
        X, y = X[ok_mask], y[ok_mask]

    print("Dataset:")
    print(f"   Total samples: {len(X):,}")
    if n_excluded:
        print(f"   Excluded (failed extraction): {n_excluded:,}")
    print(f"   Features: {X.shape[1]}")
    print(f"   TP samples: {int((y == 1).sum()):,}")
    print(f"   FP samples: {int((y == 0).sum()):,}")

    # Stratified splitting and ROC AUC both require both classes.
    if len(np.unique(y)) < 2:
        print("Need both TP and FP samples to train classifiers")
        return None, None, None, None

    # Train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42, stratify=y
    )

    print("\nTrain/Test split:")
    print(f"   Training: {len(X_train):,} samples")
    print(f"   Testing: {len(X_test):,} samples")

    # Scale features (fit on the training split only to avoid test leakage)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Define classifiers to test
    classifiers = {
        'Logistic Regression': LogisticRegression(
            random_state=42, max_iter=1000, class_weight='balanced'
        ),
        'Random Forest': RandomForestClassifier(
            n_estimators=100, random_state=42, n_jobs=-1, class_weight='balanced'
        ),
        'Random Forest (Large)': RandomForestClassifier(
            n_estimators=300, max_depth=15, random_state=42, n_jobs=-1, class_weight='balanced'
        )
    }

    results = {}

    print("\nTraining classifiers...")

    for name, classifier in classifiers.items():
        print(f"   Training {name}...")

        classifier.fit(X_train_scaled, y_train)

        y_pred = classifier.predict(X_test_scaled)
        y_pred_proba = classifier.predict_proba(X_test_scaled)[:, 1]

        f1 = f1_score(y_test, y_pred)
        auc = roc_auc_score(y_test, y_pred_proba)
        precision, recall, _, _ = precision_recall_fscore_support(y_test, y_pred, average='binary')
        accuracy = accuracy_score(y_test, y_pred)

        results[name] = {
            'f1_score': f1,
            'auc': auc,
            'precision': precision,
            'recall': recall,
            'accuracy': accuracy,
            'classifier': classifier,
            'y_pred': y_pred,
            'y_pred_proba': y_pred_proba
        }

        print(f"      F1: {f1:.3f}, AUC: {auc:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}")

    # Save results
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    results_summary = {
        'timestamp': timestamp,
        'feature_dimension': int(X.shape[1]),
        'training_samples': len(X_train),
        'test_samples': len(X_test),
        'excluded_failed_samples': n_excluded,
        'classifier_results': {}
    }

    for name, result in results.items():
        results_summary['classifier_results'][name] = {
            'f1_score': float(result['f1_score']),
            'auc': float(result['auc']),
            'precision': float(result['precision']),
            'recall': float(result['recall']),
            'accuracy': float(result['accuracy'])
        }

    results_file = os.path.join(SAVE_DIR, f'diffusion_latents_classification_{timestamp}.json')
    with open(results_file, 'w') as f:
        json.dump(results_summary, f, indent=2)

    print(f"\nResults saved: {results_file}")

    return results, X_test, y_test, scaler

def compare_with_threshold_optimization(latent_results, results_dir=None, tie_tolerance=0.005):
    """Compare the best latent-feature classifier F1 against threshold optimization.

    Reads the most recently modified ``threshold_optimization_results_*.json``
    in ``results_dir``. Its ``optimal_threshold_results.f1.f1`` value is
    compared with the best ``f1_score`` among the dict entries of
    ``latent_results``, and the verdict is printed.

    Args:
        latent_results: Mapping of classifier name to result dict; entries
            that are not dicts with an ``'f1_score'`` key are ignored.
        results_dir: Directory to search; defaults to ``SAVE_DIR``.
        tie_tolerance: F1 differences within ``±tie_tolerance`` count as a tie.

    Returns:
        dict with ``threshold_f1``, ``best_latent_f1``, ``best_classifier``,
        ``improvement``, ``verdict`` (``'latent'``, ``'tie'`` or
        ``'threshold'``) and ``threshold_file``. Returns ``None`` if no usable
        threshold results were found.
    """
    import glob

    print("\nCOMPARISON WITH THRESHOLD OPTIMIZATION")
    print("="*50)

    # Load threshold optimization results
    search_dir = SAVE_DIR if results_dir is None else results_dir
    threshold_files = glob.glob(os.path.join(search_dir, 'threshold_optimization_results_*.json'))

    if not threshold_files:
        print("No threshold optimization results found")
        print("Please run Notebook 2 (threshold optimization) first")
        return None

    # Modification time: os.path.getctime is inode-change time on Unix, not creation time.
    latest_threshold_file = max(threshold_files, key=os.path.getmtime)
    with open(latest_threshold_file, 'r') as f:
        threshold_results = json.load(f)

    try:
        threshold_f1 = threshold_results['optimal_threshold_results']['f1']['f1']
    except (KeyError, TypeError):
        print(f"Unexpected format in {latest_threshold_file}: missing optimal_threshold_results.f1.f1")
        return None

    print(f"Threshold Optimization F1: {threshold_f1:.3f}")
    print("\nDiffusion Latent Features:")

    best_latent_f1 = 0
    best_classifier = ""

    for name, result in latent_results.items():
        if isinstance(result, dict) and 'f1_score' in result:
            f1 = result['f1_score']
            print(f"   {name}: F1={f1:.3f} (diff: {f1 - threshold_f1:+.3f})")

            if f1 > best_latent_f1:
                best_latent_f1 = f1
                best_classifier = name

    print("\nBEST APPROACH:")
    improvement = best_latent_f1 - threshold_f1

    if improvement > tie_tolerance:  # Meaningful improvement
        verdict = 'latent'
        print("   LATENT FEATURES WIN!")
        print(f"   Best: {best_classifier} with F1={best_latent_f1:.3f}")
        if threshold_f1 > 0:
            print(f"   Improvement: +{improvement:.3f} ({improvement/threshold_f1*100:.1f}%)")
        else:
            # Relative improvement is undefined against a zero baseline.
            print(f"   Improvement: +{improvement:.3f}")
        print("   Conclusion: Diffusion internal representations are better features")
    elif improvement > -tie_tolerance:  # Essentially tied
        verdict = 'tie'
        print("   ESSENTIALLY TIED")
        print(f"   Threshold F1: {threshold_f1:.3f}")
        print(f"   Latent F1: {best_latent_f1:.3f}")
        print("   Conclusion: Both approaches are equally effective")
    else:
        verdict = 'threshold'
        print("   THRESHOLD OPTIMIZATION WINS")
        print(f"   Threshold F1: {threshold_f1:.3f}")
        print(f"   Best Latent F1: {best_latent_f1:.3f}")
        print("   Conclusion: Simple threshold optimization is better")

    return {
        'threshold_f1': threshold_f1,
        'best_latent_f1': best_latent_f1,
        'best_classifier': best_classifier,
        'improvement': improvement,
        'verdict': verdict,
        'threshold_file': latest_threshold_file,
    }

def create_latent_results_visualization(latent_results):
    """Draw side-by-side F1 and AUC bar charts for the latent-feature classifiers.

    Only entries of ``latent_results`` that are dicts containing ``'f1_score'``
    are plotted; their ``'auc'`` value is read too. The figure is saved as a
    timestamped 300-dpi PNG under ``FIGURES_DIR``, shown, and returned.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    plotted = [
        (name, result['f1_score'], result['auc'])
        for name, result in latent_results.items()
        if isinstance(result, dict) and 'f1_score' in result
    ]
    names = [entry[0] for entry in plotted]
    palette = ['lightblue', 'lightgreen', 'lightcoral']

    panels = (
        (ax1, [entry[1] for entry in plotted], 'F1-Score', 'F1-Score Comparison\nDiffusion Latent Features'),
        (ax2, [entry[2] for entry in plotted], 'AUC Score', 'AUC Comparison\nDiffusion Latent Features'),
    )
    for ax, scores, y_label, title in panels:
        bars = ax.bar(names, scores, alpha=0.7, color=palette)
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.set_ylim(0, 1)
        ax.tick_params(axis='x', rotation=45)
        # Annotate each bar with its score just above the bar top.
        for bar, score in zip(bars, scores):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    f'{score:.3f}', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()

    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    plot_file = os.path.join(FIGURES_DIR, f'diffusion_latents_comparison_{stamp}.png')
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Visualization saved: {plot_file}")

    return fig

In [None]:
# Main execution functions

def run_full_latent_pipeline(max_samples=None, force_reextract=False):
    """Run extraction (or load cached features), training, comparison and plotting.

    Args:
        max_samples: Passed to ``extract_and_save_diffusion_latents`` when
            features are (re-)extracted. It is ignored when cached features
            are loaded, and a notice is printed in that case.
        force_reextract: Re-extract even if the cached HDF5 file exists. Use
            this after a ``max_samples`` test run; otherwise a later full run
            silently reuses the smaller cache.

    Returns:
        ``(latent_results, features, metadata_df)``, or ``None`` if feature
        loading/extraction or classifier training failed.
    """
    print("RUNNING FULL DIFFUSION LATENT PIPELINE")
    print("="*70)

    # Step 1: Extract and save features (unless a cache exists and re-extraction is not forced)
    print("STEP 1: Extract and save diffusion latent features")
    features_file = os.path.join(LATENTS_DIR, 'diffusion_latent_features.h5')

    if os.path.exists(features_file) and not force_reextract:
        print("   Features already extracted, loading from disk...")
        if max_samples is not None:
            print(f"   Note: max_samples={max_samples} is ignored for cached features; "
                  f"pass force_reextract=True to re-extract")
        features, metadata_df, _ = load_saved_diffusion_features()
    else:
        print("   Extracting features from diffusion model...")
        features, metadata_df, _ = extract_and_save_diffusion_latents(max_samples)

    if features is None:
        print("Feature extraction failed")
        return None

    # Step 2: Train classifiers. Older versions of the trainer return a bare
    # None on failure, newer ones a tuple of Nones; handle both before unpacking.
    print("\nSTEP 2: Train classifiers on latent features")
    training = train_classifiers_on_latents()

    if training is None or training[0] is None:
        print("Classifier training failed")
        return None
    latent_results, X_test, y_test, scaler = training

    # Step 3: Compare with threshold optimization
    print("\nSTEP 3: Compare with threshold optimization")
    compare_with_threshold_optimization(latent_results)

    # Step 4: Create visualizations
    print("\nSTEP 4: Create visualizations")
    create_latent_results_visualization(latent_results)

    print("\nFULL LATENT PIPELINE COMPLETE!")
    print(f"Feature dimension: {features.shape[1]}")
    print(f"Total samples: {len(features):,}")

    return latent_results, features, metadata_df

In [None]:
# Usage reference for the pipeline functions defined above.
print("EXTRACTION AND SAVING:")
print("   # Extract raw latent vectors")
print("   features, metadata, summary = extract_and_save_diffusion_latents()")
print()
print("LOADING AND CLASSIFICATION:")
print("   # Load saved features and train classifiers")
print("   results, X_test, y_test, scaler = train_classifiers_on_latents()")
print()
print("FULL PIPELINE:")
print("   # Run complete pipeline")
print("   results, features, metadata = run_full_latent_pipeline()")
print()
print("   # Test with limited samples first")
print("   results, features, metadata = run_full_latent_pipeline(max_samples=5000)")

In [None]:
# run_full_latent_pipeline() returns None when a step fails; fall back to
# Nones so this cell does not raise a TypeError while unpacking.
pipeline_output = run_full_latent_pipeline()
results, features, metadata = pipeline_output if pipeline_output is not None else (None, None, None)