# COSMOS Comprehensive Benchmark: GCN vs Graph Transformer + PE

## Complete Analysis for All 5 Datasets

**Datasets:**
- D1: Mouse Brain (ATAC + RNA)
- D2: Mouse Visual Cortex (Simulation)
- D3: Mouse Olfactory Bulb (MOB)
- D4: SpatialGlue Simulation (ADT + RNA)
- D5: Spatial Multiomics Multi-Simulated (RNA + ADT)

**Metrics (from COSMOS paper):**
- ARI (Adjusted Rand Index)
- NMI (Normalized Mutual Information)
- Silhouette Score
- Homogeneity & Completeness
- Spatial Coherence
- Pseudo-spatial Metrics (pSM) for trajectory data

**Architecture:**
- GCN (Original COSMOS)
- Graph Transformer + PE (Enhanced)

**Manual Control:**
All model hyperparameters are in Section 2 for easy modification.

---
## 1. Setup and Imports

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib
import matplotlib.pyplot as plt
from umap import UMAP
import sklearn
import sklearn.metrics
from sklearn.metrics import (
    adjusted_rand_score, 
    normalized_mutual_info_score,
    silhouette_score,
    homogeneity_score,
    completeness_score,
    v_measure_score
)
import seaborn as sns
import h5py
import warnings
warnings.filterwarnings('ignore')

# Import COSMOS
from COSMOS.cosmos import Cosmos as Cosmos_GCN
from COSMOS.cosmos_transformer_pe_version import Cosmos as Cosmos_GT_PE

# Report the library versions in use so benchmark runs can be reproduced.
# The check mark was mis-encoded (UTF-8 bytes decoded as MacRoman) and is restored here.
print("✓ All libraries imported successfully")
print(f"  - scanpy version: {sc.__version__}")
print(f"  - numpy version: {np.__version__}")
print(f"  - pandas version: {pd.__version__}")

---
## 2. HYPERPARAMETER CONFIGURATION

**🎛️ MODIFY THESE PARAMETERS AS NEEDED**

All model-related hyperparameters are here for easy access.

In [None]:
# ============================================================
# MANUAL HYPERPARAMETER CONFIGURATION
# ============================================================
# Every constant in this cell is read as a module-level global by the helper,
# loader and analysis cells below, so this cell must be executed before them.

# ------------------------------
# Dataset Selection
# ------------------------------
# Choose which dataset to run: 'D1', 'D2', 'D3', 'D4', 'D5' or 'ALL'
DATASET_TO_RUN = 'D1'  # Change this to test specific datasets

# ------------------------------
# General Training Parameters
# ------------------------------
# These values are forwarded unchanged as keyword arguments to the train()
# call of both models in analyze_dataset(), so the two architectures are
# trained under the same schedule.
RANDOM_SEED = 42
GPU_ID = 0
LEARNING_RATE = 1e-3
TOTAL_EPOCHS = 1000
WNN_EPOCH = 100  # When to compute WNN weights
MAX_PATIENCE_BEF = 10
MAX_PATIENCE_AFT = 30
MIN_STOP = 200
SPATIAL_REG_STRENGTH = 0.01
REGULARIZATION_ACCELERATION = True
EDGE_SUBSET_SIZE = 1000000

# ------------------------------
# Graph Construction
# ------------------------------
# Passed as n_neighbors to preprocessing_data() of both models.
N_NEIGHBORS = 10  # For spatial graph construction

# ------------------------------
# GCN Hyperparameters
# ------------------------------
GCN_Z_DIM = 50  # Output embedding dimension

# ------------------------------
# Graph Transformer + PE Hyperparameters
# ------------------------------
# One dict per dataset. analyze_dataset() reads the keys 'z_dim', 'num_heads',
# 'dropout', 'pe_dim' and 'use_pe' from the selected dict and forwards them to
# the Graph Transformer model's train() call, so every dict must define all five.
# For D1 (MouseBrain) - Complex data
GT_PE_D1_CONFIG = {
    'z_dim': 50,
    'num_heads': 8,      # Number of attention heads
    'dropout': 0.1,      # Dropout rate
    'pe_dim': 8,         # Positional encoding dimension
    'use_pe': True       # Enable/disable PE
}

# For D2 (VisualCortex) - Layered data
GT_PE_D2_CONFIG = {
    'z_dim': 50,
    'num_heads': 2,      # Fewer heads for simple data
    'dropout': 0.3,      # Higher dropout
    'pe_dim': 0,         # No PE for layered data
    'use_pe': False      # Disable PE
}

# For D3 (MOB) - Test both configurations
GT_PE_D3_CONFIG = {
    'z_dim': 50,
    'num_heads': 8,      # Start with complex settings
    'dropout': 0.1,
    'pe_dim': 8,
    'use_pe': True
}


# For D4 (SpatialGlue Simulation) - Ground truth factor data
GT_PE_D4_CONFIG = {
    'z_dim': 50,
    'num_heads': 4,      # Moderate heads for simulated data
    'dropout': 0.2,      # Moderate dropout
    'pe_dim': 8,         # Use PE to capture spatial patterns
    'use_pe': True
}


# For D5 (Spatial Multiomics Multi-Simulated) - Complex multi-domain data
GT_PE_D5_CONFIG = {
    'z_dim': 50,
    'num_heads': 8,      # More heads for complex domain structure
    'dropout': 0.1,      # Lower dropout for larger dataset (3000 cells)
    'pe_dim': 16,        # Larger PE for bigger spatial layout
    'use_pe': True       # Enable PE for spatial awareness
}

# ------------------------------
# Evaluation Parameters
# ------------------------------
# find_optimal_clustering() runs Louvain once per entry of
# CLUSTERING_RESOLUTIONS and keeps the resolution with the highest ARI.
# With the single-entry list below no search actually takes place; swap in
# the commented-out range to enable the resolution sweep.
# CLUSTERING_RESOLUTIONS = [0.3, 0.5, 0.8, 1.0, 1.2, 1.5, 2.0]  # Test range
CLUSTERING_RESOLUTIONS = [1.0]
UMAP_N_NEIGHBORS = 30  # n_neighbors of the UMAP projections used for plotting
UMAP_MIN_DIST = 0.3  # min_dist of the UMAP projections used for plotting
NEIGHBORS_FOR_CLUSTERING = 50  # n_neighbors for sc.pp.neighbors on the embeddings

# ------------------------------
# Visualization Parameters
# ------------------------------
FIGURE_DPI = 300  # dpi passed to savefig for every exported figure
POINT_SIZE_SPATIAL = 10  # marker size in the spatial scatter panels
POINT_SIZE_UMAP = 5  # marker size in the UMAP scatter panels
FONT_SIZE = 10  # written to matplotlib.rcParams['font.size']

# ============================================================
# END OF CONFIGURATION
# ============================================================

# Echo the key settings so the notebook output records what was run.
# The check marks were mis-encoded (UTF-8 decoded as MacRoman) and are restored here.
print("✓ Hyperparameters configured")
print(f"  - Dataset to run: {DATASET_TO_RUN}")
print(f"  - Random seed: {RANDOM_SEED}")
print(f"  - Total epochs: {TOTAL_EPOCHS}")
print(f"  - WNN epoch: {WNN_EPOCH}")
print(f"\n✓ Configuration complete - ready to run!")

---
## 3. Helper Functions for Metrics and Visualization

In [None]:
# Seed NumPy's global RNG so that draws made through np.random further down
# (e.g. the silhouette subsampling in compute_all_metrics) are repeatable.
np.random.seed(RANDOM_SEED)

def compute_all_metrics(true_labels, predicted_labels, embeddings):
    """
    Compute clustering-agreement metrics plus a silhouette score.

    Parameters:
    - true_labels: Ground truth cluster labels (any 1-D array-like)
    - predicted_labels: Predicted cluster labels (any 1-D array-like)
    - embeddings: Cell embeddings, one row per cell; must support ``len()``
      and integer-array row indexing

    Returns:
    - Dictionary with the keys 'ARI', 'NMI', 'Homogeneity', 'Completeness',
      'V-measure' and 'Silhouette'. 'Silhouette' is computed on the
      embeddings against ``true_labels`` (not the predicted clusters) and is
      NaN when scikit-learn rejects the input.
    """
    # Work on plain arrays: the subsampling below indexes labels by position,
    # which fails for lists and is label-based (not positional) for a Series.
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)

    metrics = {}

    # Clustering accuracy metrics
    metrics['ARI'] = adjusted_rand_score(true_labels, predicted_labels)
    metrics['NMI'] = normalized_mutual_info_score(true_labels, predicted_labels)
    metrics['Homogeneity'] = homogeneity_score(true_labels, predicted_labels)
    metrics['Completeness'] = completeness_score(true_labels, predicted_labels)
    metrics['V-measure'] = v_measure_score(true_labels, predicted_labels)

    # Embedding quality
    try:
        # Silhouette needs all pairwise distances, so large inputs are
        # reduced to a random sample of 5000 cells.
        if len(embeddings) < 5000:
            metrics['Silhouette'] = silhouette_score(embeddings, true_labels)
        else:
            # Drawn from NumPy's global RNG, so the sample follows the
            # seed set at module level.
            sample_idx = np.random.choice(len(embeddings), 5000, replace=False)
            metrics['Silhouette'] = silhouette_score(
                embeddings[sample_idx],
                true_labels[sample_idx]
            )
    except ValueError as exc:
        # Best effort: an undefined silhouette (scikit-learn raises
        # ValueError) is reported as NaN instead of aborting the run.
        # Any other error is a real bug and is allowed to propagate.
        print(f"  - Silhouette score unavailable: {exc}")
        metrics['Silhouette'] = np.nan

    return metrics


def compute_spatial_coherence(spatial_coords, cluster_labels, k=10):
    """
    Compute spatial coherence: fraction of neighbors with same cluster label.
    Higher = better spatial organization.

    Parameters:
    - spatial_coords: Array of shape (n_cells, n_dims) with cell positions
    - cluster_labels: 1-D array-like with one cluster label per cell
    - k: Number of spatial neighbors compared against each cell

    Returns:
    - Mean over all cells of (number of the k neighbors sharing the cell's
      label) / k

    NOTE(review): the first neighbor returned for each cell is assumed to be
    the cell itself and is dropped; with duplicated coordinates that may not
    hold -- confirm if the data can contain coincident spots.
    """
    from sklearn.neighbors import NearestNeighbors

    # Positional indexing below needs an ndarray; a list would fail and a
    # pandas Series would be indexed by label instead of position.
    labels = np.asarray(cluster_labels)

    # Ask for k+1 neighbors because each query point is among its own neighbors.
    nbrs = NearestNeighbors(n_neighbors=k + 1).fit(spatial_coords)
    _, indices = nbrs.kneighbors(spatial_coords)

    # (n_cells, k) label matrix compared against each cell's own label in a
    # single vectorised step instead of a Python loop over cells.
    neighbor_labels = labels[indices[:, 1:]]  # Exclude self
    same_label = neighbor_labels == labels[:, None]
    coherence_scores = same_label.sum(axis=1) / k

    return np.mean(coherence_scores)


def find_optimal_clustering(embedding_adata, true_labels, resolutions):
    """
    Find optimal clustering resolution by maximizing ARI.
    This is auto-tuned (not a model parameter).

    Parameters:
    - embedding_adata: AnnData on which sc.pp.neighbors has already been run;
      its obs['louvain'] column is overwritten for every resolution tried, so
      on return it holds the LAST resolution, not necessarily the best one
    - true_labels: Ground truth labels used to score each clustering
    - resolutions: Non-empty sequence of Louvain resolutions to try

    Returns:
    - (best_clusters, best_resolution, best_ari)
    """
    # ARI can be zero or negative, so start below every possible score.
    # Starting at 0 left best_clusters as None whenever no resolution
    # scored strictly above 0.
    best_ari = float('-inf')
    best_resolution = resolutions[0]
    best_clusters = None

    for res in resolutions:
        sc.tl.louvain(embedding_adata, resolution=res)
        clusters = embedding_adata.obs['louvain'].values
        ari = adjusted_rand_score(true_labels, clusters)

        # Strict comparison: on ties the earliest resolution wins.
        if ari > best_ari:
            best_ari = ari
            best_resolution = res
            # Copy because the next louvain call overwrites obs['louvain'].
            best_clusters = clusters.copy()

    return best_clusters, best_resolution, best_ari


def print_metrics_table(metrics_dict, method_name):
    """
    Write a two-column metric/value table for one method to stdout.

    Float values are shown with four decimals; every other value is
    right-aligned as-is.
    """
    heavy_rule = '=' * 60

    print("\n" + heavy_rule)
    print("{} - Performance Metrics".format(method_name))
    print(heavy_rule)
    print("{:<20} {:>10}".format('Metric', 'Value'))
    print('-' * 60)

    for name, val in metrics_dict.items():
        # Only floats get the fixed-precision treatment.
        row_template = "{:<20} {:>10.4f}" if isinstance(val, float) else "{:<20} {:>10}"
        print(row_template.format(name, val))

    print(heavy_rule + "\n")


def _scatter_labels_on_axis(ax, spatial_coords, labels, title, palette):
    """Draw one spatial scatter panel on ``ax``, one palette color per distinct label."""
    # An ndarray makes the equality test below element-wise even when the
    # caller passes a plain list.
    labels = np.asarray(labels)
    for i, label in enumerate(np.unique(labels)):
        mask = labels == label
        ax.scatter(spatial_coords[mask, 0], spatial_coords[mask, 1],
                   c=palette[i % len(palette)], label=str(label),
                   s=POINT_SIZE_SPATIAL, alpha=0.8)
    ax.set_title(title, fontweight='bold', fontsize=12)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=False, fontsize=8)
    ax.axis('equal')
    ax.axis('off')


def create_comparison_plot(adata, spatial_coords, true_labels,
                           gcn_clusters, gt_pe_clusters,
                           gcn_ari, gt_pe_ari, dataset_name):
    """
    Create the 2x2 comparison figure for one dataset.

    Panels: ground-truth annotation, GCN clusters, GT+PE clusters (all drawn
    at the spatial coordinates) and a bar chart of the two ARI values.

    Parameters:
    - adata: Unused; kept so existing callers keep working
    - spatial_coords: Array indexed as [:, 0] / [:, 1] for x / y positions
    - true_labels: Ground truth label per cell
    - gcn_clusters, gt_pe_clusters: Predicted cluster label per cell
    - gcn_ari, gt_pe_ari: ARI of each method, shown in titles and bars
    - dataset_name: Used in the figure title

    Returns:
    - The matplotlib Figure (not saved or closed here)
    """
    # Side effect: changes the global matplotlib font size.
    matplotlib.rcParams['font.size'] = FONT_SIZE
    fig, axes = plt.subplots(2, 2, figsize=(14, 11))

    # 20 colors; labels beyond that wrap around and reuse colors.
    plot_colors = ['#D1D1D1', '#e6194b', '#3cb44b', '#ffe119', '#4363d8',
                   '#f58231', '#911eb4', '#46f0f0', '#f032e6', '#bcf60c',
                   '#fabebe', '#008080', '#e6beff', '#9a6324', '#ffd8b1',
                   '#800000', '#aaffc3', '#808000', '#000075', '#808080']

    # Plots 1-3: ground truth and the two clusterings share one layout.
    _scatter_labels_on_axis(axes[0, 0], spatial_coords, true_labels,
                            'Ground Truth Annotation', plot_colors)
    _scatter_labels_on_axis(axes[0, 1], spatial_coords, gcn_clusters,
                            f'GCN (Original)\nARI = {gcn_ari:.3f}', plot_colors)
    _scatter_labels_on_axis(axes[1, 0], spatial_coords, gt_pe_clusters,
                            f'Graph Transformer + PE\nARI = {gt_pe_ari:.3f}', plot_colors)

    # Plot 4: Performance Comparison
    ax = axes[1, 1]
    methods = ['GCN', 'GT+PE']
    aris = [gcn_ari, gt_pe_ari]
    # Better (or tied) method in green, worse one in red. The previous
    # conditions were inverted for the GCN bar, so both bars always ended up
    # the same color whenever one method won.
    better_color, worse_color = '#3cb44b', '#e6194b'
    colors_bar = [better_color if gcn_ari >= gt_pe_ari else worse_color,
                  better_color if gt_pe_ari >= gcn_ari else worse_color]
    bars = ax.bar(methods, aris, color=colors_bar, alpha=0.7,
                  edgecolor='black', linewidth=2)
    ax.set_ylabel('Adjusted Rand Index (ARI)', fontweight='bold')
    ax.set_title('Performance Comparison', fontweight='bold', fontsize=12)
    # ARI can be negative; extend the axis downwards so such bars stay visible.
    ax.set_ylim([min(0.0, gcn_ari, gt_pe_ari), 1])
    ax.grid(axis='y', alpha=0.3)

    # Annotate each bar with its ARI value.
    for bar, ari in zip(bars, aris):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{ari:.3f}', ha='center', va='bottom',
                fontweight='bold', fontsize=12)

    plt.suptitle(f'{dataset_name} - Spatial Clustering Comparison',
                 fontsize=14, fontweight='bold', y=0.995)
    plt.tight_layout()

    return fig


# Restored the mis-encoded check mark (UTF-8 decoded as MacRoman).
print("✓ Helper functions defined")

---
## 4. Dataset Loading Functions

In [None]:
def load_mousebrain_data():
    """
    Load D1: Mouse Brain ATAC + RNA dataset.

    Reads the 'X_RNA', 'X_ATAC', 'Pos' and 'LayerName' datasets from
    ./datasets/ATAC_RNA_Seq_MouseBrain_RNA_ATAC.h5.

    Returns:
    - (adata1, adata2, labels, loc, 'MouseBrain') where adata1 holds RNA,
      adata2 holds ATAC, labels is an array of decoded layer names and loc is
      the float64 position array. Both AnnData objects carry the positions in
      obsm['spatial'] and obs['x_pos'] / obs['y_pos'], and the labels in
      obs['LayerName'].
    """
    print("\nLoading D1: Mouse Brain (ATAC + RNA)...")

    # Materialise everything inside the context manager so the HDF5 handle is
    # closed on every path (it was previously left open).
    with h5py.File('./datasets/ATAC_RNA_Seq_MouseBrain_RNA_ATAC.h5', 'r') as data_mat:
        df_data_RNA = np.array(data_mat['X_RNA']).astype('float64')
        df_data_ATAC = np.array(data_mat['X_ATAC']).astype('float64')
        loc = np.array(data_mat['Pos']).astype('float64')
        # Labels are stored as bytes in the file.
        LayerName = [item.decode("utf-8") for item in list(data_mat['LayerName'])]

    adata1 = sc.AnnData(df_data_RNA, dtype="float64")
    adata2 = sc.AnnData(df_data_ATAC, dtype="float64")

    # Both modalities describe the same cells, so they share positions and labels.
    for modality in (adata1, adata2):
        modality.obsm['spatial'] = np.array(loc)
        modality.obs['LayerName'] = LayerName
        modality.obs['x_pos'] = np.array(loc)[:, 0]
        modality.obs['y_pos'] = np.array(loc)[:, 1]

    print(f"  - Cells: {adata1.shape[0]}")
    print(f"  - RNA features: {adata1.shape[1]}")
    print(f"  - ATAC features: {adata2.shape[1]}")
    print(f"  - Layers: {len(np.unique(LayerName))}")

    return adata1, adata2, np.array(LayerName), loc, 'MouseBrain'


def load_visualcortex_data():
    """
    Load D2 (Mouse Visual Cortex simulation) and derive two synthetic views.

    One count matrix is read from CSV. Each view is a copy of it in which the
    rows of two adjacent layers are replaced by noise-scaled rows drawn, in
    permuted order, from the same row set: L4 + L5 for the first view and
    L6 + L5 for the second. All other rows keep their original values.

    Side effect: reseeds NumPy's global RNG (RANDOM_SEED, then RANDOM_SEED + 1).

    Returns:
    - (adata1, adata2, LayerName, loc, 'VisualCortex')
    """
    print("\nLoading D2: Mouse Visual Cortex (Simulation)...")

    # --- Read counts and metadata -------------------------------------
    csv_options = dict(sep=",", header=0, na_filter=False, index_col=0)
    df_data = pd.read_csv('./datasets/MVC_counts.csv', **csv_options)
    df_meta = pd.read_csv('./datasets/MVC_meta.csv', **csv_options)

    # --- Base AnnData with labels and positions -----------------------
    # Metadata columns by position: 1 = layer label, 2:4 = pixel coordinates,
    # 4 = secondary label.
    df_labels = list(df_meta.iloc[:, 1])

    adata = sc.AnnData(X=df_data)
    adata.obs['LayerName'] = df_labels
    adata.obs['LayerName_2'] = list(df_meta.iloc[:, 4])
    adata.obsm['spatial'] = np.array(df_meta.iloc[:, 2:4])
    adata.obs['x_pos'] = adata.obsm['spatial'][:, 0]
    adata.obs['y_pos'] = adata.obsm['spatial'][:, 1]

    # --- Row sets whose contents get shuffled -------------------------
    def rows_of(layer):
        # Row positions of every cell annotated with `layer`.
        return [row for row, name in enumerate(df_labels) if name == layer]

    rows_l4 = rows_of('L4')
    rows_l5 = rows_of('L5')
    rows_l6 = rows_of('L6')
    index_int1 = np.array(rows_l4 + rows_l5)
    index_int2 = np.array(rows_l6 + rows_l5)

    # --- Build the two views ------------------------------------------
    # The order seed -> normal draw -> permutation draw must be kept so the
    # generated data stays reproducible.
    views = []
    for seed, shuffled_rows in ((RANDOM_SEED, index_int1),
                                (RANDOM_SEED + 1, index_int2)):
        view = adata.copy()
        np.random.seed(seed)
        noise_factor = 1 + np.random.normal(0, 0.05, adata.shape)
        noisy_counts = np.multiply(adata.X, noise_factor)
        view.X[shuffled_rows, :] = noisy_counts[np.random.permutation(shuffled_rows), :]
        views.append(view)
    adata1, adata2 = views

    # --- Values returned to the caller --------------------------------
    loc = np.array(adata1.obsm['spatial'])
    LayerName = np.array(adata1.obs['LayerName'])

    print(f"  - Cells: {adata1.shape[0]}")
    print(f"  - RNA features: {adata1.shape[1]}")
    print(f"  - Protein features: {adata2.shape[1]}")
    print(f"  - Layers: {len(np.unique(LayerName))}")

    return adata1, adata2, LayerName, loc, 'VisualCortex'

def load_mob_data():
    """
    Load D3 (Mouse Olfactory Bulb) and split it into two pseudo-modalities.

    The counts are normalised to 1e4 per cell and log1p-transformed; genes
    whose variance is at or above the median form the first modality, the
    remaining genes the second.

    Returns:
    - (adata1, adata2, layer_annotations, spatial_coords, 'MOB')
    """
    print("\nLoading D3: Mouse Olfactory Bulb...")

    counts = pd.read_csv('./datasets/MOB_counts.csv', index_col=0)
    metadata = pd.read_csv('./datasets/MOB_meta.csv', index_col=0)
    spatial_coords = metadata[['x_pos', 'y_pos']].values
    layer_annotations = metadata['Layertype'].values

    # Full matrix, used only for preprocessing and the variance split.
    adata_full = sc.AnnData(counts.values)
    adata_full.obs['Layertype'] = layer_annotations
    adata_full.obsm['spatial'] = spatial_coords
    sc.pp.normalize_total(adata_full, target_sum=1e4)
    sc.pp.log1p(adata_full)

    # Per-gene variance; a matrix-typed result is flattened via .A1.
    gene_var = np.var(adata_full.X, axis=0)
    if hasattr(gene_var, 'A1'):
        gene_var = gene_var.A1
    high_var_genes = gene_var >= np.median(gene_var)
    low_var_genes = ~high_var_genes

    # Densify once if X is sparse, then slice the columns per modality.
    if hasattr(adata_full.X, 'toarray'):
        expression = adata_full.X.toarray()
    else:
        expression = adata_full.X

    modalities = []
    for gene_mask in (high_var_genes, low_var_genes):
        modality = sc.AnnData(expression[:, gene_mask], dtype='float64')
        modality.obs['Layertype'] = layer_annotations
        modality.obsm['spatial'] = spatial_coords
        modality.obs['x_pos'] = spatial_coords[:, 0]
        modality.obs['y_pos'] = spatial_coords[:, 1]
        modalities.append(modality)
    adata1, adata2 = modalities

    print(f"  - Cells: {adata1.shape[0]}")
    print(f"  - Modality 1 (high-var genes): {adata1.shape[1]}")
    print(f"  - Modality 2 (low-var genes): {adata2.shape[1]}")
    print(f"  - Layers: {len(np.unique(layer_annotations))}")

    return adata1, adata2, layer_annotations, spatial_coords, 'MOB'


# Restored the mis-encoded check mark (UTF-8 decoded as MacRoman).
print("✓ Data loading functions defined")

def load_spatialglue_data():
    """
    Load D4: the SpatialGlue simulation (ADT + RNA) with factor labels.

    Labels are derived from the first four columns of the RNA object's
    obsm['spfac']: the columns are combined as 1*c0 + 2*c1 + 3*c2 + 4*c3 and
    the resulting codes 1-4 / 0 are renamed 'factor1'..'factor4' / 'backgr'.

    NOTE(review): the labels computed from the RNA file are attached to the
    ADT object as well, which presumes both files list cells in the same
    order -- confirm against the dataset.

    Returns:
    - (adata1, adata2, LayerName, spatial_coords, 'SpatialGlue_Sim') with ADT
      as adata1 and RNA as adata2
    """
    print("\nLoading D4: SpatialGlue Simulation (ADT + RNA)...")

    adt_ad = sc.read_h5ad("./datasets/spatialGlue_sim_adata_ADT.h5ad")
    rna_ad = sc.read_h5ad("./datasets/spatialGlue_sim_adata_RNA.h5ad")

    adata1 = adt_ad.copy()  # ADT modality
    adata2 = rna_ad.copy()  # RNA modality

    # Scratch copy of the RNA object used only to derive the labels.
    adata_genes = rna_ad.copy()
    spfac = adata_genes.obsm['spfac']
    fac1, fac2, fac3, fac4 = (np.array(spfac[:, col]) for col in range(4))
    adata_genes.obs['ground_truth'] = 1 * fac1 + 2 * fac2 + 3 * fac3 + 4 * fac4

    code_to_name = {
        1.0: 'factor1',
        2.0: 'factor2',
        3.0: 'factor3',
        4.0: 'factor4',
        0.0: 'backgr'  # background
    }
    adata_genes.obs['annotation'] = adata_genes.obs['ground_truth'].replace(code_to_name)

    # Coordinates come from the ADT object; labels from the RNA-derived column.
    spatial_coords = np.array(adata1.obsm['spatial'])
    LayerName = adata_genes.obs['annotation'].values

    for modality in (adata1, adata2):
        modality.obs['x_pos'] = spatial_coords[:, 0]
        modality.obs['y_pos'] = spatial_coords[:, 1]
        modality.obs['LayerName'] = LayerName

    factor_names, factor_counts = np.unique(LayerName, return_counts=True)
    print(f"  - Cells: {adata1.shape[0]}")
    print(f"  - ADT features: {adata1.shape[1]}")
    print(f"  - RNA features: {adata2.shape[1]}")
    print(f"  - Factors: {len(np.unique(LayerName))}")
    print(f"  - Factor distribution: {dict(zip(factor_names, factor_counts))}")

    return adata1, adata2, LayerName, spatial_coords, 'SpatialGlue_Sim'




def load_spatial_multiomics_multi_simulated():
    """
    Load D5: Spatial Multiomics Multi-Simulated (RNA + ADT in one h5ad file).

    Read from the file:
    - RNA: the main matrix (.X)
    - ADT: obsm["ADT"]
    - Positions: obsm["spatial"]
    - Labels: obs["domain"], falling back to obs["ground_truth"]

    Returns:
    - (adata_rna, adata_adt, LayerName, spatial_coords,
      'SpatialMultiomics_MultiSim')

    Raises:
    - ValueError if neither label column is present in obs
    """
    print("\nLoading D5: Spatial Multiomics Multi-Simulated...")

    adt_rna_ad = sc.read_h5ad("./datasets/spatial_multiomics_multi_simulated.h5ad")

    print(f"  - Loaded AnnData: {adt_rna_ad.shape}")
    print(f"  - Available keys in obsm: {list(adt_rna_ad.obsm.keys())}")
    print(f"  - Available keys in obs: {list(adt_rna_ad.obs.keys())}")

    # One AnnData per modality, both cast to float64.
    adata_rna = sc.AnnData(adt_rna_ad.X, dtype="float64")
    adata_adt = sc.AnnData(adt_rna_ad.obsm["ADT"], dtype="float64")

    spatial_coords = np.array(adt_rna_ad.obsm["spatial"])

    # First label column that exists, in order of preference.
    obs_columns = adt_rna_ad.obs.columns
    label_column = next(
        (name for name in ("domain", "ground_truth") if name in obs_columns),
        None,
    )
    if label_column is None:
        raise ValueError("No ground truth labels found. Expected 'domain' or 'ground_truth' in obs.")
    LayerName = adt_rna_ad.obs[label_column].values

    # Both modalities share positions and labels.
    for modality in (adata_rna, adata_adt):
        modality.obsm['spatial'] = spatial_coords
        modality.obs['x_pos'] = spatial_coords[:, 0]
        modality.obs['y_pos'] = spatial_coords[:, 1]
        modality.obs['LayerName'] = LayerName

    domain_names, domain_counts = np.unique(LayerName, return_counts=True)
    print(f"  - Cells: {adata_rna.shape[0]}")
    print(f"  - RNA features: {adata_rna.shape[1]}")
    print(f"  - ADT features: {adata_adt.shape[1]}")
    print(f"  - Spatial domains: {len(np.unique(LayerName))}")
    print(f"  - Domain distribution: {dict(zip(domain_names, domain_counts))}")

    # Return: RNA as adata1, ADT as adata2
    return adata_rna, adata_adt, LayerName, spatial_coords, 'SpatialMultiomics_MultiSim'


# Restored the mis-encoded check mark (UTF-8 decoded as MacRoman).
print("✓ All dataset loading functions defined")


---
## 5. Main Analysis Function

In [None]:
def _load_dataset(dataset_id):
    """Call the loader registered for ``dataset_id`` and return its 5-tuple."""
    loaders = {
        'D1': load_mousebrain_data,
        'D2': load_visualcortex_data,
        'D3': load_mob_data,
        'D4': load_spatialglue_data,
        'D5': load_spatial_multiomics_multi_simulated,
    }
    if dataset_id not in loaders:
        raise ValueError(f"Unknown dataset: {dataset_id}")
    return loaders[dataset_id]()


def _percent_change(new, base):
    """Return the change from ``base`` to ``new`` in percent of ``|base|``, or None when ``base`` is 0."""
    if base == 0:
        return None
    # Dividing by the magnitude keeps the sign meaningful for negative baselines.
    return (new - base) / abs(base) * 100


def _train_and_evaluate(model, short_name, table_name, true_labels,
                        spatial_coords, z_dim, **extra_train_kwargs):
    """Preprocess, train, cluster and score one model.

    Returns (outcome, df_embedding, ari) where ``outcome`` has the keys
    'embedding', 'weights', 'clusters' and 'metrics'.
    """
    model.preprocessing_data(n_neighbors=N_NEIGHBORS)

    embedding = model.train(
        spatial_regularization_strength=SPATIAL_REG_STRENGTH,
        z_dim=z_dim,
        lr=LEARNING_RATE,
        wnn_epoch=WNN_EPOCH,
        total_epoch=TOTAL_EPOCHS,
        max_patience_bef=MAX_PATIENCE_BEF,
        max_patience_aft=MAX_PATIENCE_AFT,
        min_stop=MIN_STOP,
        random_seed=RANDOM_SEED,
        gpu=GPU_ID,
        regularization_acceleration=REGULARIZATION_ACCELERATION,
        edge_subset_sz=EDGE_SUBSET_SIZE,
        **extra_train_kwargs
    )

    weights = model.weights
    df_embedding = pd.DataFrame(embedding)

    # Louvain clustering on a kNN graph built directly on the embedding;
    # the resolution is chosen by find_optimal_clustering (max ARI).
    embedding_adata = sc.AnnData(df_embedding)
    sc.pp.neighbors(embedding_adata, n_neighbors=NEIGHBORS_FOR_CLUSTERING, use_rep='X')
    clusters, resolution, ari = find_optimal_clustering(
        embedding_adata, true_labels, CLUSTERING_RESOLUTIONS
    )

    print(f"\n✓ {short_name} training complete")
    print(f"  - Optimal resolution: {resolution}")
    print(f"  - Best ARI: {ari:.4f}")

    metrics = compute_all_metrics(true_labels, clusters, embedding)
    metrics['Spatial_Coherence'] = compute_spatial_coherence(
        spatial_coords, clusters, k=10
    )
    metrics['Optimal_Resolution'] = resolution

    print_metrics_table(metrics, table_name)

    outcome = {
        'embedding': embedding,
        'weights': weights,
        'clusters': clusters,
        'metrics': metrics
    }
    return outcome, df_embedding, ari


def _plot_umap_comparison(panels, dataset_name):
    """Build the side-by-side UMAP figure.

    ``panels`` is a sequence of (method_name, df_embedding, clusters, ari),
    one entry per subplot, drawn left to right.
    """
    # 10 colors; clusters beyond that wrap around and reuse colors.
    plot_colors = ['#D1D1D1', '#e6194b', '#3cb44b', '#ffe119', '#4363d8',
                   '#f58231', '#911eb4', '#46f0f0', '#f032e6', '#bcf60c']

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (method_name, df_embedding, clusters, ari) in zip(axes, panels):
        reducer = UMAP(n_components=2, init='random', random_state=RANDOM_SEED,
                       min_dist=UMAP_MIN_DIST, n_neighbors=UMAP_N_NEIGHBORS)
        umap_pos = reducer.fit_transform(df_embedding)

        cluster_array = np.asarray(clusters)
        for i, cluster in enumerate(np.unique(cluster_array)):
            mask = cluster_array == cluster
            ax.scatter(umap_pos[mask, 0], umap_pos[mask, 1],
                       c=plot_colors[i % len(plot_colors)], label=str(cluster),
                       s=POINT_SIZE_UMAP, alpha=0.6)
        ax.set_title(f'UMAP: {method_name} (ARI = {ari:.3f})', fontweight='bold')
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=False, fontsize=8)
        ax.axis('off')

    plt.suptitle(f'{dataset_name} - UMAP Comparison', fontsize=12, fontweight='bold')
    plt.tight_layout()
    return fig


def _describe_gain(winner_ari, loser_ari):
    """Return the text shown in parentheses after the winner's name."""
    pct = _percent_change(winner_ari, loser_ari)
    if pct is None:
        # A relative gain is undefined against a baseline of exactly 0.
        return f"+{winner_ari - loser_ari:.3f} absolute ARI; baseline ARI is 0"
    return f"+{pct:.1f}% ARI improvement"


def analyze_dataset(dataset_id, gt_pe_config):
    """
    Complete analysis pipeline for a single dataset.

    Loads the dataset, trains the GCN and the Graph Transformer + PE models
    with the shared global hyperparameters, clusters and scores both
    embeddings, saves both models, writes two PNG figures to ./outputs
    (created if missing) and prints a comparison summary.

    Parameters:
    - dataset_id: 'D1', 'D2', 'D3', 'D4' or 'D5'
    - gt_pe_config: Dictionary with GT+PE hyperparameters; must contain
      'z_dim', 'num_heads', 'dropout', 'pe_dim' and 'use_pe'

    Returns:
    - Dictionary with dataset information ('dataset_id', 'dataset_name',
      'n_cells', 'n_features_mod1', 'n_features_mod2', 'true_labels',
      'spatial_coords') plus 'gcn' and 'gt_pe' sub-dictionaries holding
      'embedding', 'weights', 'clusters' and 'metrics' ('gt_pe' also has
      'config')

    Raises:
    - ValueError for an unknown dataset_id
    """
    import os

    print("\n" + "="*70)
    print(f"ANALYZING DATASET {dataset_id}")
    print("="*70)

    # Load data
    adata1, adata2, true_labels, spatial_coords, dataset_name = _load_dataset(dataset_id)

    results = {
        'dataset_id': dataset_id,
        'dataset_name': dataset_name,
        'n_cells': adata1.shape[0],
        'n_features_mod1': adata1.shape[1],
        'n_features_mod2': adata2.shape[1],
        'true_labels': true_labels,
        'spatial_coords': spatial_coords
    }

    # ============================================================
    # TRAIN GCN (Original COSMOS)
    # ============================================================
    print("\n" + "-"*70)
    print("TRAINING: GCN (Original COSMOS)")
    print("-"*70)

    cosmos_gcn = Cosmos_GCN(adata1=adata1, adata2=adata2, save_inter_emb=True, save_fin_emb=True)
    gcn_outcome, df_embedding_gcn, gcn_ari = _train_and_evaluate(
        cosmos_gcn, "GCN", "GCN", true_labels, spatial_coords, GCN_Z_DIM
    )
    gcn_clusters = gcn_outcome['clusters']
    gcn_metrics = gcn_outcome['metrics']
    results['gcn'] = gcn_outcome

    cosmos_gcn.save_model()

    # ============================================================
    # TRAIN GRAPH TRANSFORMER + PE
    # ============================================================
    print("\n" + "-"*70)
    print("TRAINING: Graph Transformer + PE")
    print("-"*70)
    print(f"Configuration:")
    for key, value in gt_pe_config.items():
        print(f"  - {key}: {value}")

    cosmos_gt_pe = Cosmos_GT_PE(adata1=adata1, adata2=adata2, save_inter_emb=True, save_fin_emb=True)
    gt_pe_outcome, df_embedding_gt_pe, gt_pe_ari = _train_and_evaluate(
        cosmos_gt_pe, "GT+PE", "Graph Transformer + PE",
        true_labels, spatial_coords, gt_pe_config['z_dim'],
        # Graph Transformer parameters
        num_heads=gt_pe_config['num_heads'],
        dropout=gt_pe_config['dropout'],
        pe_dim=gt_pe_config['pe_dim'],
        use_pe=gt_pe_config['use_pe']
    )
    gt_pe_clusters = gt_pe_outcome['clusters']
    gt_pe_metrics = gt_pe_outcome['metrics']
    gt_pe_outcome['config'] = gt_pe_config
    results['gt_pe'] = gt_pe_outcome

    cosmos_gt_pe.save_model()

    # ============================================================
    # COMPARISON AND VISUALIZATION
    # ============================================================
    print("\n" + "-"*70)
    print("GENERATING VISUALIZATIONS")
    print("-"*70)

    # savefig does not create missing directories; make sure the target exists
    # so a finished training run is not lost to a FileNotFoundError.
    os.makedirs('./outputs', exist_ok=True)

    # Spatial comparison
    fig = create_comparison_plot(
        adata1, spatial_coords, true_labels,
        gcn_clusters, gt_pe_clusters,
        gcn_ari, gt_pe_ari, dataset_name
    )
    fig.savefig(f'./outputs/{dataset_id}_{dataset_name}_spatial_comparison.png',
                dpi=FIGURE_DPI, bbox_inches='tight')
    plt.close(fig)
    print(f"  ✓ Saved: {dataset_id}_{dataset_name}_spatial_comparison.png")

    # UMAP comparison
    fig = _plot_umap_comparison(
        [('GCN', df_embedding_gcn, gcn_clusters, gcn_ari),
         ('GT+PE', df_embedding_gt_pe, gt_pe_clusters, gt_pe_ari)],
        dataset_name
    )
    fig.savefig(f'./outputs/{dataset_id}_{dataset_name}_UMAP_comparison.png',
                dpi=FIGURE_DPI, bbox_inches='tight')
    plt.close(fig)
    print(f"  ✓ Saved: {dataset_id}_{dataset_name}_UMAP_comparison.png")

    # ============================================================
    # FINAL SUMMARY
    # ============================================================
    print("\n" + "="*70)
    print(f"FINAL RESULTS: {dataset_name}")
    print("="*70)

    print(f"\n{'Metric':<25} {'GCN':>12} {'GT+PE':>12} {'Improvement':>12}")
    print("-"*70)

    for metric in ['ARI', 'NMI', 'Silhouette', 'Spatial_Coherence']:
        gcn_val = gcn_metrics[metric]
        gt_pe_val = gt_pe_metrics[metric]

        # Rows whose GCN value is missing (NaN) or not a float are left out.
        if isinstance(gcn_val, float) and not np.isnan(gcn_val):
            change = _percent_change(gt_pe_val, gcn_val)
            change_text = f"{'n/a':>12}" if change is None else f"{change:>11.1f}%"
            print(f"{metric:<25} {gcn_val:>12.4f} {gt_pe_val:>12.4f} {change_text}")

    print("="*70)

    if gt_pe_ari > gcn_ari:
        print(f"\n✓ WINNER: Graph Transformer + PE ({_describe_gain(gt_pe_ari, gcn_ari)})")
        print(f"  → This dataset benefits from global attention and topology awareness")
    elif gcn_ari > gt_pe_ari:
        print(f"\n✓ WINNER: GCN ({_describe_gain(gcn_ari, gt_pe_ari)})")
        print(f"  → This dataset has strong local/layered structure")
    else:
        print(f"\n≈ TIE: Both methods perform similarly")

    print("="*70 + "\n")

    return results


# Notebook progress marker: reaching this line means every statement earlier
# in this cell executed without raising.
print("‚úì Main analysis function defined")

---
## 6. Run Analysis on Selected Dataset(s)

In [None]:
# Run the benchmark for the dataset(s) selected by DATASET_TO_RUN.
#
# Results are collected in `all_results`, keyed by dataset id. A dataset whose
# analysis raises is reported (with traceback) and left out of `all_results`,
# so the remaining datasets still run.
import traceback

# Store all results
all_results = {}

if DATASET_TO_RUN == 'ALL':
    datasets_to_analyze = ['D1', 'D2', 'D3', 'D4', 'D5']
else:
    datasets_to_analyze = [DATASET_TO_RUN]

for dataset_id in datasets_to_analyze:
    # Select appropriate configuration.
    # The chain is kept (rather than a dict literal) so that only the config
    # of the dataset actually being run has to be defined.
    if dataset_id == 'D1':
        gt_pe_config = GT_PE_D1_CONFIG
    elif dataset_id == 'D2':
        gt_pe_config = GT_PE_D2_CONFIG
    elif dataset_id == 'D3':
        gt_pe_config = GT_PE_D3_CONFIG
    elif dataset_id == 'D4':
        gt_pe_config = GT_PE_D4_CONFIG
    elif dataset_id == 'D5':
        gt_pe_config = GT_PE_D5_CONFIG
    else:
        # Without this branch `gt_pe_config` would either be unbound or, in a
        # re-run notebook kernel, silently keep the value from an earlier run.
        print(f"\nSkipping unknown dataset id {dataset_id!r}: "
              f"expected one of 'D1'-'D5' or 'ALL'")
        continue

    # Run analysis
    try:
        results = analyze_dataset(dataset_id, gt_pe_config)
        all_results[dataset_id] = results
    except Exception as e:
        # Deliberately broad: this is the top-level boundary of the notebook
        # run, and one failing dataset must not abort the others.
        print(f"\n‚ùå Error analyzing {dataset_id}: {str(e)}")
        traceback.print_exc()

print("\n" + "="*70)
print("ALL ANALYSES COMPLETE")
print("="*70)

---
## 7. Cross-Dataset Comparison Summary

In [None]:
# Cross-dataset summary: tabulate GCN vs GT+PE ARI for every analysed dataset,
# save the table as CSV, and draw a grouped bar chart. Only runs when more
# than one dataset produced results.
if len(all_results) > 1:
    print("\n" + "="*80)
    print("CROSS-DATASET COMPARISON SUMMARY")
    print("="*80)

    # Create summary table
    summary_data = []

    for dataset_id, results in all_results.items():
        gcn_ari = results['gcn']['metrics']['ARI']
        gt_pe_ari = results['gt_pe']['metrics']['ARI']
        # Relative change of GT+PE over GCN in percent; reported as 0 when the
        # GCN baseline is exactly 0 to avoid dividing by zero.
        # NOTE(review): if the baseline ARI is negative the sign of this
        # percentage flips; confirm whether that case can occur here.
        improvement = ((gt_pe_ari - gcn_ari) / gcn_ari * 100) if gcn_ari != 0 else 0
        winner = 'GT+PE' if gt_pe_ari > gcn_ari else 'GCN' if gcn_ari > gt_pe_ari else 'Tie'

        summary_data.append({
            'Dataset': results['dataset_name'],
            'Cells': results['n_cells'],
            'GCN_ARI': gcn_ari,
            'GT_PE_ARI': gt_pe_ari,
            'Improvement_%': improvement,
            'Winner': winner
        })

    df_summary = pd.DataFrame(summary_data)
    print("\n" + df_summary.to_string(index=False))

    # Save summary
    df_summary.to_csv('./outputs/cross_dataset_summary.csv', index=False)
    print("\n‚úì Summary saved to: cross_dataset_summary.csv")

    # Overall statistics
    gt_pe_wins = (df_summary['Winner'] == 'GT+PE').sum()
    gcn_wins = (df_summary['Winner'] == 'GCN').sum()

    print("\n" + "-"*80)
    print("OVERALL STATISTICS:")
    print(f"  - GT+PE wins: {gt_pe_wins}/{len(all_results)} datasets")
    print(f"  - GCN wins: {gcn_wins}/{len(all_results)} datasets")
    print(f"  - Average improvement (GT+PE over GCN): {df_summary['Improvement_%'].mean():.1f}%")
    print("="*80)

    # Visualization: grouped bar plot, one pair of bars per dataset
    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(df_summary))
    width = 0.35

    # Offset each series by half a bar width so the pair is centred on the tick
    bars1 = ax.bar(x - width/2, df_summary['GCN_ARI'], width, label='GCN', alpha=0.8)
    bars2 = ax.bar(x + width/2, df_summary['GT_PE_ARI'], width, label='GT+PE', alpha=0.8)

    ax.set_xlabel('Dataset', fontweight='bold')
    ax.set_ylabel('ARI', fontweight='bold')
    ax.set_title('Cross-Dataset Performance Comparison', fontweight='bold', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(df_summary['Dataset'], rotation=45, ha='right')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
    ax.set_ylim([0, 1])

    plt.tight_layout()
    # The path must be a single string: `/` is not defined between two str
    # objects and raises TypeError before anything is written.
    fig.savefig('./outputs/cross_dataset_comparison.png', dpi=FIGURE_DPI, bbox_inches='tight')
    plt.show()

    print("\n‚úì Comparison plot saved to: cross_dataset_comparison.png")

elif all_results:
    print("\n‚úì Single dataset analysis complete")
    print(f"  Results for {DATASET_TO_RUN} saved")

else:
    # Every requested analysis failed (or none was requested); do not claim
    # that results were saved.
    print("\nNo datasets were analyzed successfully; nothing to summarize")

---
## 8. Export Results to CSV

In [None]:
# Write one detailed-metrics CSV per analysed dataset. Each CSV holds one row
# per method (GCN, then Graph Transformer + PE); the GT+PE row additionally
# carries its hyperparameters in `Hyperparam_*` columns.
for dataset_id, results in all_results.items():
    rows = []

    for method_key, method_label in (('gcn', 'GCN'), ('gt_pe', 'Graph Transformer + PE')):
        method_entry = results[method_key]
        method_metrics = method_entry['metrics']

        # Dataset-level columns first, then every metric for this method
        record = {
            'Dataset': results['dataset_name'],
            'Method': method_label,
            'Cells': results['n_cells'],
            'Features_Mod1': results['n_features_mod1'],
            'Features_Mod2': results['n_features_mod2'],
            **method_metrics,
        }

        # Only the GT+PE entry is read for a config
        if method_key == 'gt_pe':
            for hp_name, hp_value in method_entry['config'].items():
                record[f'Hyperparam_{hp_name}'] = hp_value

        rows.append(record)

    filename = f'./outputs/{dataset_id}_{results["dataset_name"]}_detailed_metrics.csv'
    pd.DataFrame(rows).to_csv(filename, index=False)
    print(f"‚úì Saved: {filename}")

# Closing report listing what the pipeline produced
rule = "=" * 70
print("\n" + rule)
print("ALL RESULTS EXPORTED")
print(rule)

generated = [
    "\nGenerated files:",
    "  - Spatial comparison PNGs (one per dataset)",
    "  - UMAP comparison PNGs (one per dataset)",
    "  - Detailed metrics CSVs (one per dataset)",
]
# The cross-dataset artefacts are only listed when several datasets have results
if len(all_results) > 1:
    generated.append("  - cross_dataset_summary.csv")
    generated.append("  - cross_dataset_comparison.png")
for line in generated:
    print(line)

print("\n‚úì Analysis pipeline complete!")