Multimodal Ensemble Training

Fusion of multiple feature modalities for genomic structural variant classification

Models tested:
- Attention Fusion: Multi-head attention across Diffusion + Linear features
- Diffusion Only: Diffusion features baseline
- Linear Only: Linear/scalar features baseline

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import h5py
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score, classification_report
import json
import pickle
from datetime import datetime
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Configuration: relative paths following the repository layout
DATA_DIR = '../data/processed'
SAVE_DIR = '../data/processed/multimodal_experiments'
MODELS_DIR = '../data/processed/multimodal_experiments/models'
FIGURES_DIR = '../figures'

# Make sure every output location exists before anything writes to it
for _output_dir in (SAVE_DIR, MODELS_DIR, FIGURES_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_available else 'cpu')
print(f"Using device: {device}")
if _cuda_available:
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    _total_gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"   Memory: {_total_gpu_bytes / 1e9:.1f}GB")

print(f"Save directory: {SAVE_DIR}")

In [None]:
# Data loading and alignment

def load_all_modalities():
    """Read every feature modality that is present under ``DATA_DIR``.

    Every source is optional: an entry is added to the result only when its
    file(s) exist on disk; anything missing is skipped without an error.

    Returns:
        dict containing a subset of ``resnet_features``/``resnet_metadata``,
        ``vicreg_features``/``vicreg_metadata``,
        ``diffusion_features``/``diffusion_metadata``, ``sae_data`` and
        ``scalar_df``.
    """

    print("Loading all feature modalities...")

    loaded = {}

    # HDF5 embeddings with a companion metadata CSV:
    # (features key, metadata key, features file, metadata file, display name)
    h5_sources = (
        ('resnet_features', 'resnet_metadata',
         'resnet_latents/resnet50_features.h5',
         'resnet_latents/resnet50_metadata.csv',
         'ResNet'),
        ('vicreg_features', 'vicreg_metadata',
         'vicreg_latents/vicreg_features.h5',
         'vicreg_latents/vicreg_metadata.csv',
         'VICReg'),
        ('diffusion_features', 'diffusion_metadata',
         'diffusion_latents/diffusion_raw_latents_features.h5',
         'diffusion_latents/diffusion_raw_latents_metadata.csv',
         'Diffusion'),
    )

    for features_key, metadata_key, features_rel, metadata_rel, display in h5_sources:
        features_file = os.path.join(DATA_DIR, features_rel)
        metadata_file = os.path.join(DATA_DIR, metadata_rel)

        # Both halves are required; a lone features or metadata file is ignored
        if not (os.path.exists(features_file) and os.path.exists(metadata_file)):
            continue

        with h5py.File(features_file, 'r') as handle:
            # [:] materialises the whole dataset in memory before the file closes
            loaded[features_key] = handle['features'][:]
        loaded[metadata_key] = pd.read_csv(metadata_file)
        print(f"   {display}: {loaded[features_key].shape}")

    # SAE latents: a single torch-serialised object
    sae_file = os.path.join(DATA_DIR, 'sae_latents/sae_latents_combined.pt')
    if os.path.exists(sae_file):
        # NOTE: weights_only=False performs full unpickling - only load trusted files
        sae_payload = torch.load(sae_file, map_location='cpu', weights_only=False)
        loaded['sae_data'] = sae_payload
        print(f"   SAE: {sae_payload['dense_features'].shape}")

    # Scalar (tabular) features
    scalar_file = os.path.join(DATA_DIR, 'SV_Features_CLEANED_Dataset.csv')
    if os.path.exists(scalar_file):
        loaded['scalar_df'] = pd.read_csv(scalar_file)
        print(f"   Scalar: {loaded['scalar_df'].shape}")

    return loaded

def build_coordinate_aligned_dataset(modalities, min_matches=10000):
    """Build a sample-aligned multimodal dataset via SAE <-> ResNet coordinate matching.

    GRCh38 SAE variants are matched to GRCh38 ResNet metadata rows on
    ``(dataset, TP/FP label, chrom, pos, svtype)`` with a +/-1 bp tolerance on
    the end coordinate. All other modalities are then gathered in the order of
    the matched pairs.

    Args:
        modalities: dict as returned by ``load_all_modalities``. ``sae_data``,
            ``resnet_metadata`` and ``resnet_features`` are required;
            ``vicreg_features``, ``diffusion_features`` and ``scalar_df`` are
            used when present.
        min_matches: minimum number of matched variants required to continue.
            Defaults to 10000, the previously hard-coded threshold.

    Returns:
        dict with ``features`` (modality name -> array), ``labels``,
        ``matched_pairs`` and ``n_matches``; or ``None`` when fewer than
        ``min_matches`` variants could be matched.

    Raises:
        KeyError: if a required modality is missing from ``modalities``.
    """

    print("Building coordinate-aligned dataset...")

    # Extract SAE data
    sae_data = modalities['sae_data']
    sae_variants = sae_data['sv_info']
    sae_features = sae_data['dense_features']
    sae_labels = sae_data['labels']

    # Use ResNet metadata for GRCh38 subset
    resnet_metadata = modalities['resnet_metadata']
    grch38_mask = resnet_metadata['dataset'].str.contains('GRCh38', na=False)
    # Row positions of the GRCh38 entries in the FULL metadata table. They must
    # be captured before reset_index(drop=True) throws the original positions
    # away, because the feature arrays are indexed by full-table row.
    grch38_row_positions = np.flatnonzero(grch38_mask.to_numpy())
    grch38_resnet_metadata = resnet_metadata[grch38_mask].copy().reset_index(drop=True)

    # Build ResNet lookup with correct filename parsing
    # NOTE: the key has no end coordinate, so if two rows share a key the later
    # row replaces the earlier one.
    resnet_lookup = {}
    for idx, row in grch38_resnet_metadata.iterrows():
        filename = row['filename']
        parts = filename.split('_')
        if len(parts) >= 8:
            dataset = f"{parts[0]}_{parts[1]}"  # HG002_GRCh38
            label = parts[2]  # TP or FP
            chrom = parts[3]  # chr1
            pos = int(parts[4])  # position
            svtype = parts[6]  # INS/DEL/etc
            end = int(parts[5])  # end position

            key = (dataset, label, chrom, pos, svtype)
            resnet_lookup[key] = {
                'resnet_idx': idx,  # position within the GRCh38 subset
                'filename': filename,
                'end': end,
                'row': row
            }

    # Find exact coordinate matches
    matched_pairs = []
    for sae_idx, sae_variant in enumerate(sae_variants):
        # Skip non-GRCh38 SAE variants
        if not sae_variant['dataset'].endswith('GRCh38'):
            continue

        # Anything that is not 'tp_comp_vcf' is treated as a false positive
        sae_label = 'TP' if sae_variant.get('truvari_class') == 'tp_comp_vcf' else 'FP'

        key = (
            sae_variant['dataset'],
            sae_label,
            sae_variant['chrom'],
            sae_variant['pos'],
            sae_variant['svtype'],
        )

        candidate = resnet_lookup.get(key)
        # Check coordinate tolerance (+/-1 bp for end position)
        if candidate is not None and abs(candidate['end'] - sae_variant['end']) <= 1:
            matched_pairs.append({
                'sae_idx': sae_idx,
                'resnet_idx': candidate['resnet_idx'],
                'sae_variant': sae_variant,
                'resnet_row': candidate['row']
            })

    n_matches = len(matched_pairs)
    print(f"   Found {n_matches:,} precise coordinate matches")

    if n_matches < min_matches:
        print(f"   Warning: Low match count. Expected ~40k matches.")
        return None

    # Extract matched indices
    sae_matched_indices = [pair['sae_idx'] for pair in matched_pairs]
    resnet_matched_indices = np.asarray(
        [pair['resnet_idx'] for pair in matched_pairs], dtype=np.intp
    )

    # Translate subset-local ResNet indices back to rows of the full arrays
    grch38_global_indices = grch38_row_positions[resnet_matched_indices]

    # Build aligned feature dictionary
    aligned_features = {}

    # SAE features
    aligned_features['sae'] = sae_features[sae_matched_indices]

    # ResNet features
    aligned_features['resnet'] = modalities['resnet_features'][grch38_global_indices]

    # VICReg / Diffusion features (if available).
    # NOTE(review): indexing them with ResNet row positions assumes these arrays
    # share the ResNet metadata's row order - confirm against the extraction code.
    if 'vicreg_features' in modalities:
        aligned_features['vicreg'] = modalities['vicreg_features'][grch38_global_indices]

    if 'diffusion_features' in modalities:
        aligned_features['diffusion'] = modalities['diffusion_features'][grch38_global_indices]

    # Linear/Scalar features (coordinate-based matching)
    if 'scalar_df' in modalities:
        scalar_df = modalities['scalar_df']

        # Column selection depends only on the table, so do it once:
        # numeric columns minus anything that looks like a coordinate/identifier.
        numeric_cols = scalar_df.select_dtypes(include=[np.number]).columns
        feature_cols = [col for col in numeric_cols
                        if not any(id_word in col.lower()
                                   for id_word in ('pos', 'end', 'chrom'))]
        # Fallback row sized to the real feature count so every row has the
        # same width (a fixed width would produce a ragged array otherwise).
        zero_row = np.zeros(len(feature_cols))

        aligned_scalar_features = []
        for pair in matched_pairs:
            sae_variant = pair['sae_variant']
            # Find scalar match by coordinate (+/-1 bp on pos); first hit wins
            scalar_matches = scalar_df[
                (scalar_df['dataset'] == sae_variant['dataset']) &
                (scalar_df['chrom'] == sae_variant['chrom']) &
                (abs(scalar_df['pos'] - sae_variant['pos']) <= 1)
            ]

            if len(scalar_matches) > 0:
                scalar_row = scalar_matches[feature_cols].iloc[0].fillna(0).values
            else:
                scalar_row = zero_row

            aligned_scalar_features.append(scalar_row)

        aligned_features['linear'] = np.array(
            aligned_scalar_features, dtype=np.float32
        ).reshape(len(matched_pairs), len(feature_cols))

    # Extract aligned labels
    aligned_labels = sae_labels[sae_matched_indices]

    print(f"Alignment complete:")
    print(f"   Aligned samples: {len(aligned_labels):,}")
    print(f"   TP: {np.sum(aligned_labels):,} ({np.mean(aligned_labels)*100:.1f}%)")
    print(f"   FP: {len(aligned_labels) - np.sum(aligned_labels):,} ({(1-np.mean(aligned_labels))*100:.1f}%)")
    print(f"   Feature dimensions:")
    for name, features in aligned_features.items():
        print(f"      {name}: {features.shape}")

    return {
        'features': aligned_features,
        'labels': aligned_labels,
        'matched_pairs': matched_pairs,
        'n_matches': n_matches
    }

def fix_linear_features(aligned_data):
    """Drop the first linear-feature column (raw svlen) from an aligned dataset.

    The dataset is modified in place and also returned. The call is NOT
    idempotent: every invocation removes one more leading column, so apply it
    exactly once per aligned dataset.

    Args:
        aligned_data: dict with a ``features`` dict whose ``linear`` entry is a
            2-D array, or ``None``.

    Returns:
        The same ``aligned_data`` object. ``None`` is passed through unchanged,
        and a dataset without a ``linear`` modality is returned untouched, so
        that later pipeline stages can report the problem themselves.
    """

    if aligned_data is None:
        print("Fixing linear features: no aligned data, nothing to do")
        return None

    features = aligned_data['features']
    if 'linear' not in features:
        print("Fixing linear features: no linear modality present, nothing to do")
        return aligned_data

    print("Fixing linear features: removing raw svlen...")
    print(f"   Before: {features['linear'].shape}")

    # Remove first column (raw svlen)
    features['linear'] = features['linear'][:, 1:]

    print(f"   After: {features['linear'].shape}")
    # Report the real width instead of assuming the table had 16 columns
    print(f"   Now have {features['linear'].shape[1]} clean linear features (removed raw svlen)")

    return aligned_data

In [None]:
# Model architectures

class AttentionFusion(nn.Module):
    """Fuse Diffusion and Linear features with multi-head self-attention.

    Each modality is projected to ``hidden_dim``, the two projections are
    treated as a length-2 token sequence, self-attention mixes them, and the
    flattened result (``2 * hidden_dim``) is classified by an MLP.
    """

    def __init__(self, diffusion_dim=2816, linear_dim=15, hidden_dim=256, num_heads=8, num_classes=2):
        super().__init__()

        def make_encoder(input_dim):
            # Linear projection -> LayerNorm -> ReLU -> Dropout
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3),
            )

        # One encoder per modality, both mapping into the shared hidden space
        self.diffusion_encoder = make_encoder(diffusion_dim)
        self.linear_encoder = make_encoder(linear_dim)

        # Self-attention over the two modality tokens (batch-first tensors)
        self.attention = nn.MultiheadAttention(
            hidden_dim,
            num_heads,
            dropout=0.3,
            batch_first=True,
        )

        # MLP head over the concatenated attended tokens
        fused_dim = 2 * hidden_dim
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, diffusion_features, linear_features):
        """Return ``(logits, attention_weights)`` for a batch of paired inputs."""
        # (batch, 2, hidden_dim): token 0 = diffusion, token 1 = linear
        tokens = torch.stack(
            (
                self.diffusion_encoder(diffusion_features),
                self.linear_encoder(linear_features),
            ),
            dim=1,
        )

        # Query, key and value are all the same token sequence
        attended, attention_weights = self.attention(tokens, tokens, tokens)

        # Flatten the two attended tokens into one vector per sample
        logits = self.classifier(attended.flatten(start_dim=1))
        return logits, attention_weights

class DiffusionOnly(nn.Module):
    """Baseline MLP that classifies from the Diffusion features alone."""

    def __init__(self, diffusion_dim=2816, num_classes=2):
        super().__init__()

        # Two hidden stages of Linear -> BatchNorm -> ReLU -> Dropout,
        # followed by the output projection.
        stack = []
        width_in = diffusion_dim
        for width_out, drop_rate in ((512, 0.4), (128, 0.3)):
            stack.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(drop_rate),
            ])
            width_in = width_out
        stack.append(nn.Linear(width_in, num_classes))

        self.classifier = nn.Sequential(*stack)

    def forward(self, diffusion_features, linear_features):
        """Return ``(logits, None)``; ``linear_features`` is accepted but ignored."""
        # Same call signature as the fusion model so the training loop can
        # treat every architecture uniformly; there are no attention weights.
        logits = self.classifier(diffusion_features)
        return logits, None

class LinearOnly(nn.Module):
    """Baseline MLP that classifies from the Linear/scalar features alone."""

    def __init__(self, linear_dim=15, num_classes=2):
        super().__init__()

        # Two hidden stages of Linear -> BatchNorm -> ReLU -> Dropout(0.3),
        # followed by the output projection.
        stack = []
        width_in = linear_dim
        for width_out in (128, 64):
            stack.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
            width_in = width_out
        stack.append(nn.Linear(width_in, num_classes))

        self.classifier = nn.Sequential(*stack)

    def forward(self, diffusion_features, linear_features):
        """Return ``(logits, None)``; ``diffusion_features`` is accepted but ignored."""
        # Same call signature as the fusion model so the training loop can
        # treat every architecture uniformly; there are no attention weights.
        logits = self.classifier(linear_features)
        return logits, None


In [None]:
# Dataset and training

class MultimodalDataset(Dataset):
    """Index-aligned triples of (diffusion features, linear features, label)."""

    def __init__(self, diffusion_features, linear_features, labels):
        # The three containers are stored as given; they are expected to be
        # indexable by the same sample index.
        self.diffusion_features, self.linear_features, self.labels = (
            diffusion_features,
            linear_features,
            labels,
        )

    def __len__(self):
        # The label container defines the dataset size
        return len(self.labels)

    def __getitem__(self, idx):
        sources = (self.diffusion_features, self.linear_features, self.labels)
        return tuple(source[idx] for source in sources)

def train_neural_model(model, train_loader, val_loader, epochs=50):
    """Train ``model`` with AdamW, LR-on-plateau scheduling and early stopping.

    Validation accuracy drives both the scheduler and early stopping. When
    training ends, the weights from the epoch with the best validation
    accuracy are loaded back into ``model`` so that the returned metrics
    describe the model the caller goes on to evaluate or save.

    Args:
        model: module whose forward takes ``(diffusion_batch, linear_batch)``
            and returns ``(logits, extra)``; expected to live on ``device``.
        train_loader: iterable of ``(diffusion, linear, labels)`` batches.
        val_loader: iterable of ``(diffusion, linear, labels)`` batches.
        epochs: maximum number of epochs.

    Returns:
        ``(best_val_acc, best_val_auc)`` where the AUC is the one measured at
        the epoch with the best accuracy. Both are 0 if no epoch improved.
    """

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    # The `verbose` argument is deliberately not passed: it is deprecated and
    # rejected by recent PyTorch releases, and silent is the default anyway.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=5, factor=0.5
    )

    best_val_acc = 0
    best_val_auc = 0
    best_state = None  # snapshot of the weights at the best epoch
    patience = 15
    patience_counter = 0

    for _ in range(epochs):
        # Training
        model.train()
        for diffusion_batch, linear_batch, labels_batch in train_loader:
            diffusion_batch = diffusion_batch.to(device)
            linear_batch = linear_batch.to(device)
            labels_batch = labels_batch.to(device)

            optimizer.zero_grad()
            outputs, _ = model(diffusion_batch, linear_batch)
            loss = criterion(outputs, labels_batch)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        val_probs = []
        val_targets = []

        with torch.no_grad():
            for diffusion_batch, linear_batch, labels_batch in val_loader:
                diffusion_batch = diffusion_batch.to(device)
                linear_batch = linear_batch.to(device)
                labels_batch = labels_batch.to(device)

                outputs, _ = model(diffusion_batch, linear_batch)
                _, predicted = outputs.max(1)
                val_total += labels_batch.size(0)
                val_correct += predicted.eq(labels_batch).sum().item()

                # Probability of class 1 (positive class) for the AUC
                probs = F.softmax(outputs, dim=1)
                val_probs.extend(probs[:, 1].cpu().numpy())
                val_targets.extend(labels_batch.cpu().numpy())

        # Guard against an empty validation loader
        val_acc = val_correct / val_total if val_total else 0.0
        # ROC AUC is undefined with a single class; use the chance level 0.5
        if len(np.unique(val_targets)) > 1:
            val_auc = roc_auc_score(val_targets, val_probs)
        else:
            val_auc = 0.5

        # Learning rate scheduling
        scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']

        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_auc = val_auc
            patience_counter = 0
            # Clone the tensors: state_dict() returns live references that the
            # following epochs would otherwise overwrite.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            patience_counter += 1

        if patience_counter >= patience or current_lr < 1e-7:
            break

    # Hand back the best model rather than whatever the last epoch produced
    if best_state is not None:
        model.load_state_dict(best_state)

    return best_val_acc, best_val_auc

def calculate_comprehensive_metrics(y_true, y_pred, y_proba):
    """Compute accuracy, precision, recall, F1 and ROC AUC for binary predictions.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted hard labels.
        y_proba: positive-class scores, or ``None`` when unavailable.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``auc``. ``auc`` falls back to 0.5 when no scores are given or when
        ``y_true`` contains a single class.
    """

    # AUC needs scores and at least two distinct classes in the ground truth
    if y_proba is not None and len(np.unique(y_true)) > 1:
        auc_value = roc_auc_score(y_true, y_proba)
    else:
        auc_value = 0.5

    # zero_division=0: a metric with an empty denominator is reported as 0
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0),
        'auc': auc_value,
    }

In [None]:
# Main experiment function

def run_multimodal_experiment(aligned_data):
    """Cross-validate the fusion model and both single-modality baselines.

    For every architecture, five seeds x five stratified folds are trained.
    Each (architecture, seed, fold) run is checkpointed under ``MODELS_DIR``;
    when a checkpoint already exists its stored metrics are reused and the
    fold is not retrained. Fold metrics are averaged per seed, then mean and
    standard deviation are taken across seeds.

    Args:
        aligned_data: dict with ``features`` (must contain ``diffusion`` and
            ``linear`` arrays) and ``labels``, or ``None``.

    Returns:
        List with one dict per architecture (``Architecture``, ``Type``,
        ``<metric>_mean``/``<metric>_std`` for accuracy, auc, precision,
        recall and f1, plus ``Raw_Results``), or ``None`` when the input is
        missing or lacks a required modality. The list is also pickled to
        ``SAVE_DIR/multimodal_results.pkl``.
    """

    print("Running multimodal experiment...")

    if aligned_data is None:
        print("No aligned data provided!")
        return None

    features_dict = aligned_data['features']
    labels = aligned_data['labels']

    # Extract Diffusion and Linear features
    if 'diffusion' not in features_dict or 'linear' not in features_dict:
        print("Missing Diffusion or Linear features!")
        return None

    diffusion_features = features_dict['diffusion']
    linear_features = features_dict['linear']

    print(f"Dataset: {len(labels):,} samples")
    print(f"Diffusion features: {diffusion_features.shape[1]:,}")
    print(f"Linear features: {linear_features.shape[1]:,}")
    print(f"Class distribution: TP={np.sum(labels==1)}, FP={np.sum(labels==0)}")

    # Architecture configurations (factories, so every fold gets a fresh model)
    architectures = {
        'Attention_Fusion': lambda: AttentionFusion(diffusion_features.shape[1], linear_features.shape[1]),
        'Diffusion_Only': lambda: DiffusionOnly(diffusion_features.shape[1]),
        'Linear_Only': lambda: LinearOnly(linear_features.shape[1])
    }

    # Cross-validation setup
    seeds = [42, 123, 456, 789, 999]
    n_splits = 5
    batch_size = 64
    metric_names = ['accuracy', 'auc', 'precision', 'recall', 'f1']

    all_results = []

    # Test neural architectures
    for arch_name, arch_factory in architectures.items():
        print(f"\nTesting {arch_name}")
        print("-" * 50)

        arch_results = []

        for seed_idx, seed in enumerate(seeds):
            print(f"   Seed {seed} ({seed_idx+1}/{len(seeds)})...")

            # Set seeds
            torch.manual_seed(seed)
            np.random.seed(seed)

            # Cross-validation
            skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
            fold_results = []

            for fold, (train_idx, val_idx) in enumerate(skf.split(diffusion_features, labels)):

                # Checkpoint path
                checkpoint_path = os.path.join(
                    MODELS_DIR,
                    f"multimodal_{arch_name}_seed{seed}_fold{fold}.pt"
                )

                # Check if already computed
                if os.path.exists(checkpoint_path):
                    # The checkpoint stores fitted StandardScaler objects, so it
                    # needs full unpickling: with PyTorch >= 2.6 the default
                    # weights_only=True refuses to load it. Only resume from
                    # checkpoints this pipeline wrote itself.
                    checkpoint = torch.load(
                        checkpoint_path, map_location='cpu', weights_only=False
                    )
                    fold_results.append({
                        'accuracy': checkpoint['val_accuracy'],
                        'auc': checkpoint['val_auc'],
                        'precision': checkpoint.get('precision', 0),
                        'recall': checkpoint.get('recall', 0),
                        'f1': checkpoint.get('f1', 0)
                    })
                    continue

                # Create model
                model = arch_factory().to(device)

                # Prepare data
                X_diffusion_train, X_diffusion_val = diffusion_features[train_idx], diffusion_features[val_idx]
                X_linear_train, X_linear_val = linear_features[train_idx], linear_features[val_idx]
                y_train, y_val = labels[train_idx], labels[val_idx]

                # Scale features: fit on the training fold only so that no
                # validation statistics leak into training
                diffusion_scaler = StandardScaler()
                linear_scaler = StandardScaler()

                X_diffusion_train_scaled = diffusion_scaler.fit_transform(X_diffusion_train)
                X_diffusion_val_scaled = diffusion_scaler.transform(X_diffusion_val)

                X_linear_train_scaled = linear_scaler.fit_transform(X_linear_train)
                X_linear_val_scaled = linear_scaler.transform(X_linear_val)

                # Convert to tensors
                X_diffusion_train_tensor = torch.FloatTensor(X_diffusion_train_scaled)
                X_diffusion_val_tensor = torch.FloatTensor(X_diffusion_val_scaled)
                X_linear_train_tensor = torch.FloatTensor(X_linear_train_scaled)
                X_linear_val_tensor = torch.FloatTensor(X_linear_val_scaled)
                y_train_tensor = torch.LongTensor(y_train)
                y_val_tensor = torch.LongTensor(y_val)

                # Create data loaders
                train_dataset = MultimodalDataset(X_diffusion_train_tensor, X_linear_train_tensor, y_train_tensor)
                val_dataset = MultimodalDataset(X_diffusion_val_tensor, X_linear_val_tensor, y_val_tensor)

                # BatchNorm1d cannot run in training mode on a one-sample batch,
                # so drop the trailing batch only when it would hold exactly one
                # sample.
                drop_stray_sample = len(train_dataset) % batch_size == 1
                train_loader = DataLoader(
                    train_dataset, batch_size=batch_size, shuffle=True,
                    drop_last=drop_stray_sample
                )
                val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

                # Train model
                val_acc, val_auc = train_neural_model(model, train_loader, val_loader, epochs=50)

                # Calculate additional metrics with the model as returned by training
                model.eval()
                all_preds = []
                all_probs = []
                all_targets = []

                with torch.no_grad():
                    for diffusion_batch, linear_batch, labels_batch in val_loader:
                        diffusion_batch = diffusion_batch.to(device)
                        linear_batch = linear_batch.to(device)

                        outputs, _ = model(diffusion_batch, linear_batch)
                        _, predicted = outputs.max(1)
                        probs = F.softmax(outputs, dim=1)

                        all_preds.extend(predicted.cpu().numpy())
                        all_probs.extend(probs[:, 1].cpu().numpy())
                        # Labels never left the CPU, so no transfer is needed
                        all_targets.extend(labels_batch.numpy())

                # Calculate comprehensive metrics
                metrics = calculate_comprehensive_metrics(
                    np.array(all_targets),
                    np.array(all_preds),
                    np.array(all_probs)
                )

                # Save checkpoint (weights, metrics and the fitted scalers)
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'val_accuracy': val_acc,
                    'val_auc': val_auc,
                    'precision': metrics['precision'],
                    'recall': metrics['recall'],
                    'f1': metrics['f1'],
                    'diffusion_scaler': diffusion_scaler,
                    'linear_scaler': linear_scaler,
                    'arch_name': arch_name,
                    'seed': seed,
                    'fold': fold
                }, checkpoint_path)

                fold_results.append({
                    'accuracy': val_acc,
                    'auc': val_auc,
                    'precision': metrics['precision'],
                    'recall': metrics['recall'],
                    'f1': metrics['f1']
                })

            # Average across folds for this seed
            seed_metrics = {
                metric: np.mean([r[metric] for r in fold_results])
                for metric in metric_names
            }

            arch_results.append(seed_metrics)

            print(f"      Average: {seed_metrics['accuracy']:.4f} acc, {seed_metrics['f1']:.4f} F1")

        # Compile results for this architecture (mean / std across seeds)
        final_metrics = {}
        for metric in metric_names:
            values = [r[metric] for r in arch_results]
            final_metrics[f'{metric}_mean'] = np.mean(values)
            final_metrics[f'{metric}_std'] = np.std(values)

        all_results.append({
            'Architecture': arch_name,
            'Type': 'Neural',
            **final_metrics,
            'Raw_Results': arch_results
        })

        print(f"   Final: {final_metrics['f1_mean']:.4f} ± {final_metrics['f1_std']:.4f} F1")

    # Save results
    results_path = os.path.join(SAVE_DIR, "multimodal_results.pkl")
    with open(results_path, 'wb') as f:
        pickle.dump(all_results, f)

    print(f"\nMultimodal results saved to {results_path}")

    return all_results

def analyze_multimodal_results(results):
    """Print a ranked summary of the experiment and a fusion-vs-single verdict.

    Args:
        results: list of per-architecture result dicts carrying
            ``Architecture`` and ``<metric>_mean``/``<metric>_std`` entries.

    Returns:
        A new list with the results ordered by descending mean F1, or ``None``
        when ``results`` is empty or ``None``.
    """

    if not results:
        print("No results to analyze!")
        return

    print("\nMULTIMODAL EXPERIMENT RESULTS")
    print("="*60)

    # Rank by mean F1, best first (the input list itself is left untouched)
    ranked = list(results)
    ranked.sort(key=lambda entry: entry['f1_mean'], reverse=True)

    # Report rows in display order: (label, metric prefix)
    report_rows = (
        ("F1 Score:", 'f1'),
        ("Precision:", 'precision'),
        ("Recall:", 'recall'),
        ("Accuracy:", 'accuracy'),
        ("AUC:", 'auc'),
    )

    for rank, entry in enumerate(ranked, start=1):
        print(f"\n{rank}. {entry['Architecture']}")
        for label, metric in report_rows:
            mean_value = entry[metric + '_mean']
            std_value = entry[metric + '_std']
            # Labels are padded to 13 characters so the numbers line up
            print(f"   {label:<13}{mean_value:.4f} ± {std_value:.4f}")

    # Best model
    winner = ranked[0]
    print(f"\nBEST MODEL: {winner['Architecture']}")
    print(f"   F1 Score: {winner['f1_mean']:.4f} ± {winner['f1_std']:.4f}")

    def first_named(arch_name):
        # First result for the given architecture, or None when it is absent
        for entry in results:
            if entry['Architecture'] == arch_name:
                return entry
        return None

    # Determine if fusion helps
    attention_fusion = first_named('Attention_Fusion')
    diffusion_only = first_named('Diffusion_Only')
    linear_only = first_named('Linear_Only')

    # The comparison needs all three architectures to be present
    if not (attention_fusion and diffusion_only and linear_only):
        return ranked

    print(f"\nFUSION ANALYSIS:")
    fusion_f1 = attention_fusion['f1_mean']
    best_single = max(diffusion_only['f1_mean'], linear_only['f1_mean'])
    improvement = fusion_f1 - best_single

    print(f"   Best single modality: {best_single:.4f}")
    print(f"   Attention Fusion:     {fusion_f1:.4f}")
    print(f"   Improvement:          {improvement:+.4f}")

    # Fusion counts as worthwhile only above an absolute F1 gain of 0.01
    if improvement > 0.01:
        print(f"   Multimodal fusion provides meaningful improvement!")
    else:
        print(f"   Single modality performance is competitive")

    return ranked

In [None]:
# Usage summary for this notebook, printed as a single block of text
_USAGE_LINES = [
    "MAIN FUNCTIONS:",
    "   modalities = load_all_modalities()",
    "   aligned_data = build_coordinate_aligned_dataset(modalities)",
    "   aligned_data = fix_linear_features(aligned_data)",
    "   results = run_multimodal_experiment(aligned_data)",
    "   analyze_multimodal_results(results)",
    "",
    "MODELS TESTED:",
    "   Attention Fusion: Multi-head attention across Diffusion + Linear",
    "   Diffusion Only: Diffusion features baseline",
    "   Linear Only: Linear/scalar features baseline",
    "",
    "OUTPUT:",
    "   Cross-validated performance with precision/recall/F1",
    "   Saved models for best checkpoints",
    "   Analysis of fusion vs single-modality performance",
]
print("\n".join(_USAGE_LINES))

# The complete pipeline, in call order:
# modalities = load_all_modalities()
# aligned_data = build_coordinate_aligned_dataset(modalities)
# aligned_data = fix_linear_features(aligned_data)
# results = run_multimodal_experiment(aligned_data)
# analyze_multimodal_results(results)

In [None]:
# Run the complete pipeline end-to-end
modalities = load_all_modalities()
aligned_data = build_coordinate_aligned_dataset(modalities)
# Alignment returns None when too few variants match. Skip the column fix in
# that case; the experiment and analysis steps report a missing dataset
# themselves instead of crashing.
if aligned_data is not None:
    aligned_data = fix_linear_features(aligned_data)
results = run_multimodal_experiment(aligned_data)
analyze_multimodal_results(results)