# Transformer Mapping between GAT Embeddings of RNA and ADT

This notebook learns a mapping between GAT embeddings from RNA data and GAT embeddings from ADT data using a Transformer Encoder architecture.

In [None]:
# System check: configure the CUDA allocator and report available GPU resources
import torch
import os

# Set memory management environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

_BYTES_PER_GB = 1024**3

print("=== System Resources ===")
if not torch.cuda.is_available():
    print("CUDA not available - will use CPU")
    print("Note: Training will be slower but should work with larger graphs")
else:
    device = torch.cuda.current_device()
    gpu_props = torch.cuda.get_device_properties(device)
    total_memory = gpu_props.total_memory / _BYTES_PER_GB  # bytes -> GB

    print(f"GPU: {gpu_props.name}")
    print(f"Total GPU Memory: {total_memory:.1f} GB")
    print(f"GPU Compute Capability: {gpu_props.major}.{gpu_props.minor}")

    # Release cached blocks so the usage numbers below reflect live tensors only
    torch.cuda.empty_cache()

    allocated = torch.cuda.memory_allocated(device) / _BYTES_PER_GB
    reserved = torch.cuda.memory_reserved(device) / _BYTES_PER_GB

    print(f"Currently allocated: {allocated:.2f} GB")
    print(f"Currently reserved: {reserved:.2f} GB")
    print(f"Available: {total_memory - reserved:.2f} GB")

    # Tiered advice based on the card's total memory
    if total_memory >= 16:
        print("\n✅ Sufficient GPU memory available")
    elif total_memory >= 8:
        print("\n💡 Moderate GPU memory - will use optimized settings")
    else:
        print("\n⚠️  WARNING: Low GPU memory detected!")
        print("Recommendations:")
        for tip in ("- Use CPU fallback if needed", "- Reduce batch sizes", "- Use graph sparsification"):
            print(tip)

print("=" * 50)

In [None]:
%load_ext autoreload
%autoreload 2

# Set environment variables for better memory management.
# The allocator option is assigned here, ahead of this cell's torch imports below.
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
import pandas as pd

import scanpy as sc
import scanpy.external as sce
from scipy import sparse

from DeepOMAPNet.Preprocess import prepare_train_test_anndata

# Report the accelerator this session will use (and drop any cached GPU blocks)
if not torch.cuda.is_available():
    print("CUDA not available, using CPU")
else:
    torch.cuda.empty_cache()
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name()}")
    # NOTE(review): memory is read from device 0, not necessarily the current device
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")

## 1. Load and Prepare Data

In [None]:
# Fetch the preprocessed data; entries 0 and 2 are used as the RNA and ADT training sets
data = prepare_train_test_anndata()
trainGene, trainADT = data[0], data[2]

for label, adata in (("RNA", trainGene), ("ADT", trainADT)):
    print(f"{label} data shape: {adata.shape}")

## 2. Preprocess RNA Data

In [None]:
# RNA preprocessing
# Normalize every cell to a total of 1e4 counts, then apply log1p
sc.pp.normalize_total(trainGene, target_sum=1e4)
sc.pp.log1p(trainGene)
# Flag the top 2000 highly variable genes, using the "samples" obs column as batch key
sc.pp.highly_variable_genes(trainGene, n_top_genes=2000, batch_key="samples")
# Keep only the flagged genes; .copy() turns the subset into an independent object
trainGene = trainGene[:, trainGene.var.highly_variable].copy()

# Scale genes (values capped at 10) and compute a 50-component PCA
sc.pp.scale(trainGene, max_value=10)
sc.tl.pca(trainGene, n_comps=50, svd_solver="arpack")

# Build neighbor graph for RNA
# (15 neighbors on the 50 PCs), then Leiden-cluster it; the clusters are stored in obs['leiden']
sc.pp.neighbors(trainGene, n_neighbors=15, n_pcs=50)
sc.tl.leiden(trainGene, resolution=1.0)

print(f"RNA data after preprocessing: {trainGene.shape}")
print(f"Number of RNA clusters: {trainGene.obs['leiden'].nunique()}")

## 3. Preprocess ADT Data

### 3.1 Centered Log-Ratio (CLR) Normalization for ADT Data

Before we apply standard preprocessing steps, we'll perform Centered Log-Ratio (CLR) normalization on the ADT data. CLR normalization is particularly suited for ADT/CITE-seq data because:

1. It handles the compositional nature of the data
2. It preserves relative differences between markers
3. It reduces technical noise while maintaining biological signal

The CLR transformation is defined as:

$$\text{CLR}(x) = \log(x) - \frac{1}{D}\sum_{i=1}^{D}\log(x_i)$$

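Where $D$ is the number of features (ADT markers) and $x_i$ is the count of marker $i$ in the cell. Because $\log(0)$ is undefined, a pseudocount (1 by default) is added to every count before the transformation is applied.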

In [None]:
import scipy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def clr_normalize(adata, axis=1, pseudo_count=1):
    """
    Apply centered log-ratio (CLR) normalization to the data.

    Each value becomes ``log(x + pseudo_count)`` minus the mean of
    ``log(x + pseudo_count)`` taken along ``axis``, i.e. the log of the value
    divided by the geometric mean of its row or column.

    Parameters:
    -----------
    adata : AnnData
        AnnData object with raw counts in ``.X`` (dense or scipy sparse).
        It is not modified.
    axis : int, default=1
        1 = center each cell (row) by its geometric mean across features;
        any other value = center each feature (column) across cells.
    pseudo_count : float, default=1
        Value added to the counts before the log transform to avoid log(0).
        May be fractional even when the counts are stored as integers.

    Returns:
    --------
    AnnData
        A copy of ``adata`` with dense CLR-normalized values in ``.X`` and the
        original object stored in ``.raw``.
    """
    print("Applying CLR normalization to ADT data...")

    # Work on a copy so the caller's object keeps its raw counts
    adata_clr = adata.copy()

    # Get raw counts (densify if sparse)
    X = adata_clr.X.toarray() if scipy.sparse.issparse(adata_clr.X) else adata_clr.X
    X = np.asarray(X)

    # Integer count matrices cannot absorb a fractional pseudo-count in place,
    # so promote them to float first; floating inputs keep their dtype.
    if not np.issubdtype(X.dtype, np.floating):
        X = X.astype(np.float64)

    # Out-of-place addition: never writes into the matrix held by adata_clr
    log_X = np.log(X + pseudo_count)

    # log(x / geometric_mean) == log(x) - mean(log(x)); subtracting in log space
    # avoids the exp/log round trip.
    reduce_axis = 1 if axis == 1 else 0
    X_clr = log_X - log_X.mean(axis=reduce_axis, keepdims=True)

    # Update data
    adata_clr.X = X_clr

    # Store original data in raw slot
    adata_clr.raw = adata

    print(f"CLR normalization complete. Shape: {adata_clr.X.shape}")
    return adata_clr

# Normalize the ADT counts with CLR
trainADT_clr = clr_normalize(trainADT)

# Quality check: compare value distributions before and after normalization
fig, ax = plt.subplots(1, 2, figsize=(14, 5))

# At most 10,000 values per panel (stored non-zeros only when the matrix is sparse)
raw_matrix = trainADT.X
if scipy.sparse.issparse(raw_matrix):
    sample_values = raw_matrix.data[:10000]
else:
    sample_values = raw_matrix.flatten()[:10000]
sample_values_clr = trainADT_clr.X.flatten()[:10000]

panels = zip(
    ax,
    (sample_values, sample_values_clr),
    ("Original ADT Values", "CLR-Normalized ADT Values"),
)
for panel, values, title in panels:
    sns.histplot(values, bins=50, kde=True, ax=panel)
    panel.set_title(title)
    panel.set_xlabel("Value")

plt.tight_layout()
plt.show()

# From here on, trainADT refers to the CLR-normalized object
trainADT = trainADT_clr

print("ADT data now uses CLR normalization")
print(f"ADT data shape: {trainADT.shape}")

In [None]:
## 4. Additional ADT Preprocessing

In [None]:
# Additional preprocessing for ADT data
print("Computing PCA and neighbor graph for ADT data...")

# An ADT panel may hold 50 or fewer markers, while arpack PCA needs
# n_comps < min(n_obs, n_vars); cap the component count instead of failing.
adt_n_comps = min(50, min(trainADT.shape) - 1)

# Check if PCA has been computed already
if 'X_pca' not in trainADT.obsm:
    print("Computing PCA for ADT data...")
    sc.tl.pca(trainADT, n_comps=adt_n_comps, svd_solver="arpack")

# Never request more PCs than X_pca actually contains
adt_n_pcs = min(50, trainADT.obsm['X_pca'].shape[1])

# Compute neighbors for ADT data
print("Computing neighbor graph for ADT data...")
sc.pp.neighbors(trainADT, n_neighbors=15, n_pcs=adt_n_pcs)

# Run leiden clustering if not already run
if 'leiden' not in trainADT.obs:
    print("Running leiden clustering on ADT data...")
    sc.tl.leiden(trainADT, resolution=1.0)

# Verify neighbor graph was computed successfully
if 'connectivities' in trainADT.obsp:
    print(f"Neighbor graph computed successfully")
    print(f"Graph size: {trainADT.shape[0]} nodes, {trainADT.obsp['connectivities'].nnz} edges")
    print(f"Average degree: {trainADT.obsp['connectivities'].nnz / trainADT.shape[0]:.1f}")
else:
    print("ERROR: Failed to compute neighbor graph")

# You can also run UMAP for visualization
if 'X_umap' not in trainADT.obsm:
    print("Computing UMAP embedding for ADT data...")
    sc.tl.umap(trainADT)

print("ADT preprocessing complete.")

## 5. Build PyTorch Geometric Data Objects

In [None]:
def sparsify_graph(adata, max_edges_per_node=50):
    """Prune ``adata.obsp["connectivities"]`` to the strongest neighbours per node.

    If the average degree is already at most ``max_edges_per_node`` the object
    is returned untouched. Otherwise each row keeps its ``max_edges_per_node``
    largest weights, the pruned matrix is averaged with its transpose, and the
    result is written back to ``adata.obsp["connectivities"]``.

    A neighbour graph is computed first when none is present. Returns the same
    ``adata`` object that was passed in.
    """
    if "connectivities" not in adata.obsp:
        print("No connectivity graph found. Computing neighbors first...")
        sc.pp.neighbors(adata, n_neighbors=15, n_pcs=50)

    graph = adata.obsp["connectivities"].tocsr()
    n_nodes = graph.shape[0]

    # Nothing to do when the graph is sparse enough on average
    avg_degree = graph.nnz / n_nodes
    if avg_degree <= max_edges_per_node:
        print(f"Graph already sparse enough (avg degree: {avg_degree:.1f})")
        return adata

    print(f"Sparsifying graph from avg degree {avg_degree:.1f} to max {max_edges_per_node}")

    # Collect the surviving (row, col, weight) triplets row by row
    kept_rows, kept_cols, kept_weights = [], [], []
    row_bounds = zip(graph.indptr[:-1], graph.indptr[1:])
    for node, (lo, hi) in enumerate(row_bounds):
        nbrs = graph.indices[lo:hi]
        wts = graph.data[lo:hi]

        if len(nbrs) > max_edges_per_node:
            # argpartition places the k largest weights in the last k slots
            strongest = np.argpartition(wts, -max_edges_per_node)[-max_edges_per_node:]
            nbrs, wts = nbrs[strongest], wts[strongest]

        kept_rows.extend([node] * len(nbrs))
        kept_cols.extend(nbrs)
        kept_weights.extend(wts)

    pruned = sparse.csr_matrix(
        (kept_weights, (kept_rows, kept_cols)),
        shape=(n_nodes, n_nodes),
    )

    # Symmetrize by averaging with the transpose
    pruned = (pruned + pruned.T) / 2

    adata.obsp["connectivities"] = pruned

    print(f"New average degree: {pruned.nnz / n_nodes:.1f}")

    return adata

def build_pyg_data(adata, use_pca=True, sparsify_large_graphs=True, max_edges_per_node=50):
    """Build a PyTorch Geometric ``Data`` object from an AnnData object.

    Parameters
    ----------
    adata : AnnData
        Source object. Missing PCA, neighbour graph or leiden clusters are
        computed on it, and its graph may be sparsified in place.
    use_pca : bool, default=True
        Use ``adata.obsm["X_pca"]`` as node features; otherwise use a dense
        copy of ``adata.X``.
    sparsify_large_graphs : bool, default=True
        Call ``sparsify_graph`` when the average degree exceeds
        ``max_edges_per_node``.
    max_edges_per_node : int, default=50
        Degree threshold / cap used for sparsification.

    Returns
    -------
    Data
        ``x`` holds float32 node features, ``y`` the leiden cluster ids, and
        ``edge_index`` every connectivity edge in BOTH directions.
    """

    # Ensure PCA and neighbor graph are computed
    if use_pca and "X_pca" not in adata.obsm:
        print("Computing PCA first...")
        sc.tl.pca(adata, n_comps=50, svd_solver="arpack")

    if "connectivities" not in adata.obsp:
        print("Computing neighbor graph first...")
        sc.pp.neighbors(adata, n_neighbors=15, n_pcs=50 if use_pca else None)

    # Sparsify if needed
    if sparsify_large_graphs:
        A = adata.obsp["connectivities"]
        avg_degree = A.nnz / A.shape[0]
        if avg_degree > max_edges_per_node:
            print(f"Large graph detected (avg degree: {avg_degree:.1f}), applying sparsification...")
            adata = sparsify_graph(adata, max_edges_per_node)

    # Features
    if use_pca:
        X = adata.obsm["X_pca"]
    else:
        X = adata.X.toarray() if sparse.issparse(adata.X) else adata.X.copy()

    # Labels (leiden clusters)
    if "leiden" not in adata.obs:
        print("Computing leiden clusters first...")
        sc.tl.leiden(adata, resolution=1.0)

    y = adata.obs["leiden"].astype(int).to_numpy()

    # Edge index from connectivities: take each undirected edge once from the
    # strict upper triangle (this also drops self-loops) ...
    A = adata.obsp["connectivities"].tocsr()
    A_triu = sparse.triu(A, k=1)
    row, col = A_triu.nonzero()

    # ... then list it in both directions. PyG treats edge_index as directed
    # (source -> target), so a one-directional list would only pass messages
    # from lower-index to higher-index nodes. num_edges is therefore twice the
    # number of undirected edges.
    sources = np.concatenate([row, col])
    targets = np.concatenate([col, row])
    edge_index = torch.tensor(np.vstack([sources, targets]), dtype=torch.long)

    # Create PyG Data object
    data = Data(
        x=torch.tensor(X, dtype=torch.float32),
        edge_index=edge_index,
        y=torch.tensor(y, dtype=torch.long),
    )

    return data

# Build the PyG graphs, choosing per-node edge caps from the graph sizes
print("Building PyG data objects...")

if not torch.cuda.is_available():
    print("Using CPU - no memory constraints")
    max_edges_rna, max_edges_adt = 200, 100
else:
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # GB
    print(f"Available GPU memory: {gpu_memory:.1f} GB")

    # Both modalities need a neighbour graph before edges can be counted
    for modality, adata in (("RNA", trainGene), ("ADT", trainADT)):
        if "connectivities" not in adata.obsp:
            print(f"Computing neighbors for {modality} data...")
            sc.pp.neighbors(adata, n_neighbors=15, n_pcs=50)

    rna_edges = trainGene.obsp["connectivities"].nnz
    adt_edges = trainADT.obsp["connectivities"].nnz

    print(f"RNA graph edges: {rna_edges:,}")
    print(f"ADT graph edges: {adt_edges:,}")

    # Larger graphs get a tighter cap
    if rna_edges > 5000000:
        max_edges_rna = 100
    else:
        max_edges_rna = 200
    if adt_edges > 10000000:
        max_edges_adt = 50
    else:
        max_edges_adt = 100

    print(f"Using max edges per node - RNA: {max_edges_rna}, ADT: {max_edges_adt}")

rna_data = build_pyg_data(
    trainGene, use_pca=True, sparsify_large_graphs=True, max_edges_per_node=max_edges_rna
)
adt_data = build_pyg_data(
    trainADT, use_pca=True, sparsify_large_graphs=True, max_edges_per_node=max_edges_adt
)

for modality, graph in (("RNA", rna_data), ("ADT", adt_data)):
    print(f"{modality} PyG data - Nodes: {graph.num_nodes}, Edges: {graph.num_edges}, Features: {graph.num_node_features}")

## 6. Define GAT Model

In [None]:
class SimpleGAT(torch.nn.Module):
    """One-layer GAT used when memory is tight.

    ``hidden_channels`` is accepted so the constructor matches ``GAT`` but is
    not used: the single attention layer maps ``in_channels`` directly to
    ``out_channels``, averaging its heads (``concat=False``).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.6):
        super().__init__()
        self.dropout = dropout

        # Single GAT layer for memory efficiency
        self.conv1 = GATConv(
            in_channels,
            out_channels,
            heads=heads,
            dropout=dropout,
            concat=False,
        )

    def forward(self, x, edge_index, return_embeddings=False):
        """Apply input dropout followed by the attention layer.

        With a single layer the embeddings and the final output are the same
        tensor, so ``return_embeddings`` does not change the result.
        """
        dropped = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv1(dropped, edge_index)

    def get_embeddings(self, x, edge_index):
        """Return the embeddings used for the cross-modality mapping."""
        return self.forward(x, edge_index, return_embeddings=True)

class GAT(torch.nn.Module):
    """Two-layer graph attention network.

    The first layer concatenates ``heads`` attention heads of width
    ``hidden_channels``; the second layer maps that to ``out_channels`` with a
    single head.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout

        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, dropout=dropout)

    def forward(self, x, edge_index, return_embeddings=False):
        """Return the second-layer output, or the hidden representation
        (after ELU and dropout) when ``return_embeddings`` is true."""
        hidden = F.dropout(x, p=self.dropout, training=self.training)
        hidden = F.elu(self.conv1(hidden, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        if not return_embeddings:
            return self.conv2(hidden, edge_index)

        # Embeddings are taken before the final layer
        return hidden

    def get_embeddings(self, x, edge_index):
        """Get intermediate embeddings for mapping"""
        return self.forward(x, edge_index, return_embeddings=True)

## 7. Train GAT Models

In [None]:
def train_gat_model(data, model_name="GAT", epochs=200, use_cpu_fallback=False):
    """Train a GAT node classifier on ``data`` and return the trained model.

    Steps: pick a model size from the edge count, attach a stratified 80/10/10
    train/val/test split to ``data`` as boolean masks, train full-batch with
    Adam and cross-entropy, and restore the weights of the checkpoint with the
    best validation accuracy (evaluated at epoch 1 and every 50th epoch).
    A CUDA out-of-memory error while moving the data or during a training step
    switches the run to CPU.

    Parameters
    ----------
    data : torch_geometric.data.Data
        Graph with ``x``, ``edge_index`` and integer class labels ``y``.
    model_name : str
        Label used in the progress messages.
    epochs : int
        Number of full-batch training epochs.
    use_cpu_fallback : bool
        Force CPU even when CUDA is available.

    Returns
    -------
    (model, data)
        The trained model and the data object carrying ``train_mask``,
        ``val_mask`` and ``test_mask``, on the device used for training.

    Raises
    ------
    RuntimeError
        Any runtime error other than a first CUDA out-of-memory failure.
    """

    device = torch.device('cuda' if torch.cuda.is_available() and not use_cpu_fallback else 'cpu')
    print(f"Using device: {device}")

    # Check memory requirements and adjust accordingly
    num_edges = data.num_edges
    num_nodes = data.num_nodes

    print(f"Graph stats - Nodes: {num_nodes}, Edges: {num_edges}")

    # Memory optimization: reduce model size if too many edges
    use_simple_model = False
    if num_edges > 2000000:  # If more than 2M edges
        print("Very large graph detected, using simplified GAT architecture...")
        hidden_dim = 32
        heads = 4
        use_simple_model = True
    elif num_edges > 1000000:  # If more than 1M edges
        print("Large graph detected, reducing model complexity...")
        hidden_dim = 32
        heads = 4
    else:
        hidden_dim = 64
        heads = 8

    # Create train/val/test masks
    N = data.num_nodes
    y_np = data.y.cpu().numpy()

    from sklearn.model_selection import StratifiedShuffleSplit

    # Split 80/10/10, stratified on the class labels; fixed seeds keep it reproducible
    sss1 = StratifiedShuffleSplit(n_splits=1, train_size=0.8, random_state=42)
    train_idx, temp_idx = next(sss1.split(np.zeros(N), y_np))

    # Halve the remaining 20% into validation and test
    y_temp = y_np[temp_idx]
    sss2 = StratifiedShuffleSplit(n_splits=1, train_size=0.5, random_state=43)
    val_rel, test_rel = next(sss2.split(np.zeros(len(temp_idx)), y_temp))
    val_idx = temp_idx[val_rel]
    test_idx = temp_idx[test_rel]

    train_mask = torch.zeros(N, dtype=torch.bool)
    val_mask = torch.zeros(N, dtype=torch.bool)
    test_mask = torch.zeros(N, dtype=torch.bool)
    train_mask[train_idx] = True
    val_mask[val_idx] = True
    test_mask[test_idx] = True

    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    # Initialize model
    in_dim = data.x.size(1)
    n_class = int(data.y.max().item() + 1)

    if use_simple_model:
        model = SimpleGAT(in_dim, hidden_dim, n_class, heads=heads).to(device)
        print(f"Using SimpleGAT: {in_dim} -> {n_class} (hidden: {hidden_dim}, heads: {heads})")
    else:
        model = GAT(in_dim, hidden_dim, n_class, heads=heads).to(device)
        print(f"Using GAT: {in_dim} -> {hidden_dim} -> {n_class} (heads: {heads})")

    # Move data to device with memory management
    cpu_fallback_triggered = False
    try:
        data = data.to(device)
        print(f"Successfully moved data to {device}")
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print(f"GPU memory insufficient, falling back to CPU...")
            device = torch.device('cpu')
            model = model.cpu()
            data = data.cpu()
            cpu_fallback_triggered = True
        else:
            raise

    # Initialize optimizer and criterion
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    def train():
        """Run one full-batch optimisation step and return the training loss."""
        nonlocal model, data, optimizer, device, cpu_fallback_triggered

        model.train()
        optimizer.zero_grad()

        try:
            if device.type == 'cuda':
                torch.cuda.empty_cache()  # Clear cache before forward pass

            out = model(data.x, data.edge_index)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            if device.type == 'cuda':
                torch.cuda.empty_cache()  # Clear cache after backward pass

            return loss

        except RuntimeError as e:
            if "out of memory" in str(e).lower() and not cpu_fallback_triggered:
                print(f"GPU OOM during training, switching to CPU...")
                # Move everything to CPU; the optimizer is rebuilt so it tracks
                # the CPU parameters (its accumulated state starts over)
                device = torch.device('cpu')
                model = model.cpu()
                data = data.cpu()
                optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
                cpu_fallback_triggered = True

                # Retry the forward pass on CPU
                optimizer.zero_grad()
                out = model(data.x, data.edge_index)
                loss = criterion(out[data.train_mask], data.y[data.train_mask])
                loss.backward()
                optimizer.step()
                return loss
            else:
                raise

    def test(mask):
        """Return the classification accuracy on the nodes selected by ``mask``."""
        model.eval()
        with torch.no_grad():
            if device.type == 'cuda':
                torch.cuda.empty_cache()

            out = model(data.x, data.edge_index)
            pred = out.argmax(dim=1)
            correct = pred[mask] == data.y[mask]
            acc = int(correct.sum()) / int(mask.sum())
            return acc

    print(f"Training {model_name} model...")
    best_val_acc = 0
    best_model_state = None

    for epoch in range(1, epochs + 1):
        loss = train()

        if epoch % 50 == 0 or epoch == 1:
            val_acc = test(data.val_mask)
            test_acc = test(data.test_mask)
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                # state_dict() returns references to the live parameters, and
                # dict.copy() is shallow, so later optimizer steps would
                # overwrite the "best" weights. Clone each tensor to freeze it.
                best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    final_test_acc = test(data.test_mask)
    print(f"Final {model_name} test accuracy: {final_test_acc:.4f}")

    return model, data

In [None]:
# Train one GAT per modality, retrying on CPU if the GPU runs out of memory
_gat_runs = (
    ("RNA", rna_data, "RNA GAT", "=== Training RNA GAT ===",
     "GPU memory insufficient for RNA GAT, trying CPU..."),
    ("ADT", adt_data, "ADT GAT", "\n=== Training ADT GAT ===",
     "GPU memory insufficient for ADT GAT, trying CPU..."),
)
_gat_results = {}
for _key, _graph, _label, _banner, _oom_message in _gat_runs:
    print(_banner)
    try:
        _gat_results[_key] = train_gat_model(_graph, _label, epochs=200)
    except RuntimeError as err:
        # Anything other than an out-of-memory failure is a real error
        if "out of memory" not in str(err).lower():
            raise
        print(_oom_message)
        _gat_results[_key] = train_gat_model(_graph, _label, epochs=200, use_cpu_fallback=True)

rna_gat_model, rna_data_with_masks = _gat_results["RNA"]
adt_gat_model, adt_data_with_masks = _gat_results["ADT"]

print("\n=== GAT Training Complete ===")
print("RNA GAT model trained successfully")
print("ADT GAT model trained successfully")

## 8. Extract GAT Embeddings

In [None]:
def extract_embeddings(model, data):
    """Compute ``model.get_embeddings`` for ``data`` and return the result on CPU.

    The model is put in eval mode, the data is moved to the model's device if
    it lives elsewhere, and the forward pass runs without gradient tracking.
    """
    model.eval()

    # The first parameter tells us where the model lives
    model_device = next(model.parameters()).device
    on_gpu = model_device.type == 'cuda'

    if data.x.device != model_device:
        print(f"Moving data from {data.x.device} to {model_device}")
        data = data.to(model_device)

    with torch.no_grad():
        if on_gpu:
            torch.cuda.empty_cache()

        # Bring the result back to CPU for downstream processing
        embeddings = model.get_embeddings(data.x, data.edge_index).cpu()

        if on_gpu:
            torch.cuda.empty_cache()

    return embeddings

# Start from an empty CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Run both trained GATs over their graphs (RNA first, then ADT)
print("Extracting embeddings...")
rna_embeddings, adt_embeddings = (
    extract_embeddings(gat, graph)
    for gat, graph in (
        (rna_gat_model, rna_data_with_masks),
        (adt_gat_model, adt_data_with_masks),
    )
)

print(f"RNA embeddings shape: {rna_embeddings.shape}")
print(f"ADT embeddings shape: {adt_embeddings.shape}")

# The mapping pairs rows one-to-one, so the cell counts have to agree
assert rna_embeddings.shape[0] == adt_embeddings.shape[0], "Number of cells must match"

## 9. Define Transformer Encoder Mapping Model

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer models.

    Even feature indices hold ``sin(position * div_term)`` and odd indices the
    matching cosine. The table is registered as the buffer ``pe`` with shape
    ``[max_len, 1, d_model]``.

    Args:
        d_model: Embedding dimension (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed in the table.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        # Local import so the class does not depend on a later module-level import
        import math

        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd index than even index, so trim
        # the cosine columns to fit instead of raising a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Add the encoding for the first ``x.size(0)`` positions, then apply dropout.

        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TransformerMapping(nn.Module):
    """Map embeddings of size ``input_dim`` to size ``output_dim``.

    Pipeline: linear projection to ``d_model``, a stack of ``num_layers``
    batch-first Transformer encoder layers (feed-forward width ``4 * d_model``),
    and a linear projection to ``output_dim``.
    """

    def __init__(self, input_dim, output_dim, d_model=256, nhead=4, num_layers=3, dropout=0.1):
        super().__init__()

        self.input_proj = nn.Linear(input_dim, d_model)

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

        self.output_proj = nn.Linear(d_model, output_dim)

    def forward(self, x):
        """Project, encode and project back.

        A 2-D input ``[batch, input_dim]`` is treated as a batch of length-1
        sequences. A size-1 dimension at index 1 is squeezed from the encoder
        output before the final projection.
        """
        tokens = self.input_proj(x)

        # 2-D input: insert a sequence axis -> [batch_size, 1, d_model]
        if tokens.dim() == 2:
            tokens = tokens.unsqueeze(1)

        encoded = self.transformer_encoder(tokens)

        return self.output_proj(encoded.squeeze(1))

# Build the transformer mapping model on the best available device
import math  # For the positional encoding
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Input/output widths follow the embedding sizes of the two GATs
input_dim, output_dim = rna_embeddings.shape[1], adt_embeddings.shape[1]

transformer_model = TransformerMapping(
    input_dim=input_dim,
    output_dim=output_dim,
    d_model=256,
    nhead=4,
    num_layers=3,
)
transformer_model = transformer_model.to(device)

print(f"Transformer Model: {input_dim} -> {output_dim}")
print(transformer_model)

## 9. Prepare Training Data for Transformer

In [None]:
# Embeddings as numpy arrays on CPU
rna_emb_np, adt_emb_np = (emb.cpu().numpy() for emb in (rna_embeddings, adt_embeddings))

# Reuse the split masks attached to the RNA graph during GAT training
train_mask_np, val_mask_np, test_mask_np = (
    getattr(rna_data_with_masks, f"{split}_mask").cpu().numpy()
    for split in ("train", "val", "test")
)

# For each split: inputs are RNA embeddings, targets are ADT embeddings
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    tuple(torch.tensor(emb[mask], dtype=torch.float32) for emb in (rna_emb_np, adt_emb_np))
    for mask in (train_mask_np, val_mask_np, test_mask_np)
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Validation set: {X_val.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")

# Datasets and loaders; only the training loader shuffles
batch_size = 64  # Smaller batch size for transformer
train_dataset, val_dataset, test_dataset = (
    TensorDataset(X_train, y_train),
    TensorDataset(X_val, y_val),
    TensorDataset(X_test, y_test),
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for dataset in (val_dataset, test_dataset)
)

## 10. Train Transformer Mapping Model

In [None]:
# Optimisation hyperparameters (learning rate kept low for the transformer)
num_epochs, learning_rate, weight_decay = 300, 0.0005, 1e-4

# Scheduler class used by the warmup scheduler factory
from torch.optim.lr_scheduler import LambdaLR

def get_lr_scheduler(optimizer, warmup_steps=1000, max_steps=10000):
    """Build a linear warmup / linear decay learning-rate schedule.

    The multiplier applied to the optimizer's base learning rate rises
    linearly from 0 to 1 over the first `warmup_steps` scheduler steps, then
    falls linearly so that it reaches 0 at `max_steps`; it never goes below 0.

    Args:
        optimizer: Optimizer whose learning rate is scaled.
        warmup_steps: Number of steps spent ramping up.
        max_steps: Step at which the multiplier reaches zero.

    Returns:
        A `LambdaLR` scheduler wrapping the multiplier function.
    """
    # Both spans are floored at 1 so the divisions below are always defined.
    warmup_span = float(max(1, warmup_steps))
    decay_span = float(max(1, max_steps - warmup_steps))

    def lr_lambda(current_step):
        still_warming_up = current_step < warmup_steps
        if still_warming_up:
            return float(current_step) / warmup_span
        remaining_fraction = float(max_steps - current_step) / decay_span
        return max(0.0, remaining_fraction)

    return LambdaLR(optimizer, lr_lambda)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(transformer_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Halve the learning rate once the validation loss has stalled for 10 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

# Training loop
train_losses = []
val_losses = []
best_val_loss = float('inf')
best_model_state = None
patience = 30  # epochs without validation improvement before stopping early
patience_counter = 0

print("Training Transformer mapping model...")

for epoch in range(num_epochs):
    # Training phase
    transformer_model.train()
    train_loss = 0.0

    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        outputs = transformer_model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(transformer_model.parameters(), max_norm=1.0)

        optimizer.step()

        train_loss += loss.item()

    # Epoch loss is the mean of the per-batch mean losses.
    train_loss /= len(train_loader)
    train_losses.append(train_loss)

    # Validation phase
    transformer_model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            outputs = transformer_model(batch_x)
            loss = criterion(outputs, batch_y)
            val_loss += loss.item()

    val_loss /= len(val_loader)
    val_losses.append(val_loss)

    # Learning rate scheduling (stepped once per epoch on the validation loss)
    scheduler.step(val_loss)

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() returns references to the live parameter tensors and
        # dict.copy() is shallow, so the snapshot must clone every tensor;
        # otherwise later optimizer steps overwrite the "best" weights.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in transformer_model.state_dict().items()
        }
        patience_counter = 0
    else:
        patience_counter += 1

    if (epoch + 1) % 25 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')

    if patience_counter >= patience:
        print(f'Early stopping at epoch {epoch+1}')
        break

# Load best model
if best_model_state is not None:
    transformer_model.load_state_dict(best_model_state)
else:
    # No epoch ever improved on +inf (e.g. zero epochs or NaN validation
    # losses); keep the current weights instead of crashing on None.
    print("Warning: no best checkpoint was recorded; keeping the final model weights.")
print(f'Best validation loss: {best_val_loss:.6f}')

## 11. Evaluate Transformer Model

In [None]:
# Run the trained mapper over the held-out split, collecting outputs and targets.
transformer_model.eval()
test_loss = 0.0
predictions, ground_truth = [], []

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        outputs = transformer_model(batch_x)
        test_loss += criterion(outputs, batch_y).item()

        predictions.append(outputs.cpu().numpy())
        ground_truth.append(batch_y.cpu().numpy())

# Mean of the per-batch losses, then stack the batches into (samples, dims) arrays.
test_loss = test_loss / len(test_loader)
predictions = np.vstack(predictions)
ground_truth = np.vstack(ground_truth)

# Global regression metrics over all embedding dimensions.
mse = mean_squared_error(ground_truth, predictions)
r2 = r2_score(ground_truth, predictions)

# Correlation between true and predicted values, one embedding dimension at a time.
pearson_corrs, spearman_corrs = [], []
for true_col, pred_col in zip(ground_truth.T, predictions.T):
    pearson_r, _ = pearsonr(true_col, pred_col)
    spearman_r, _ = spearmanr(true_col, pred_col)
    pearson_corrs.append(pearson_r)
    spearman_corrs.append(spearman_r)

mean_pearson = np.mean(pearson_corrs)
mean_spearman = np.mean(spearman_corrs)

print(f"\n=== Transformer Mapping Results ===")
print(f"Test Loss (MSE): {test_loss:.6f}")
print(f"MSE: {mse:.6f}")
print(f"R² Score: {r2:.4f}")
print(f"Mean Pearson Correlation: {mean_pearson:.4f}")
print(f"Mean Spearman Correlation: {mean_spearman:.4f}")

## 12. Visualize Results

In [None]:
# Two side-by-side panels: loss curves (left) and correlation histograms (right).
curves_fig, (loss_ax, corr_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: per-epoch losses on a log-scaled y axis.
for series, series_label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    loss_ax.plot(series, label=series_label)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('MSE Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.set_yscale('log')

# Right panel: distribution of the per-dimension test correlations.
for values, values_label in ((pearson_corrs, 'Pearson'), (spearman_corrs, 'Spearman')):
    corr_ax.hist(values, bins=20, alpha=0.7, label=values_label)
corr_ax.set_xlabel('Correlation')
corr_ax.set_ylabel('Frequency')
corr_ax.set_title('Per-dimension Correlation Distribution')
corr_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# True-vs-predicted scatter for up to the first six embedding dimensions.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

n_panels = min(6, ground_truth.shape[1])
for i, ax in enumerate(axes[:n_panels]):
    true_col = ground_truth[:, i]
    pred_col = predictions[:, i]
    ax.scatter(true_col, pred_col, alpha=0.6, s=1)

    # Identity line spanning the joint range of both series (perfect prediction).
    min_val = min(true_col.min(), pred_col.min())
    max_val = max(true_col.max(), pred_col.max())
    diagonal = [min_val, max_val]
    ax.plot(diagonal, diagonal, 'r--', alpha=0.8)

    dim_number = i + 1
    ax.set_xlabel(f'True ADT Embedding Dim {dim_number}')
    ax.set_ylabel(f'Predicted ADT Embedding Dim {dim_number}')
    ax.set_title(f'Dim {dim_number}: r={pearson_corrs[i]:.3f}')

plt.tight_layout()
plt.show()

## 13. Compare with MLP Model (if available)

In [None]:
# Compare against saved MLP results if the file produced by the MLP notebook exists.


def _pct_improvement(gain, baseline):
    """Express `gain` as a percentage of the baseline's magnitude.

    Dividing by |baseline| keeps the sign meaningful when the baseline metric
    is negative (e.g. a negative R²); a zero baseline yields NaN instead of
    raising.
    """
    if baseline == 0:
        return float('nan')
    return gain / abs(baseline) * 100


# Keep the try body limited to the file access, and close the archive as soon
# as the arrays have been read out of it.
try:
    with np.load('mapping_predictions.npz') as mlp_results:
        mlp_predictions = mlp_results['predictions']
        mlp_ground_truth = mlp_results['ground_truth']
        mlp_pearson = mlp_results['pearson_corrs']
        mlp_spearman = mlp_results['spearman_corrs']
except FileNotFoundError:
    print("MLP results file not found. Cannot compare models.")
else:
    # Calculate metrics
    mlp_mse = mean_squared_error(mlp_ground_truth, mlp_predictions)
    mlp_r2 = r2_score(mlp_ground_truth, mlp_predictions)
    mlp_mean_pearson = np.mean(mlp_pearson)
    mlp_mean_spearman = np.mean(mlp_spearman)

    # Compare results (for MSE lower is better, so the gain is MLP minus Transformer)
    print("\n=== Transformer vs MLP Comparison ===")
    print(f"MSE: Transformer: {mse:.6f}, MLP: {mlp_mse:.6f}, Improvement: {_pct_improvement(mlp_mse - mse, mlp_mse):.2f}%")
    print(f"R²: Transformer: {r2:.4f}, MLP: {mlp_r2:.4f}, Improvement: {_pct_improvement(r2 - mlp_r2, mlp_r2):.2f}%")
    print(f"Pearson: Transformer: {mean_pearson:.4f}, MLP: {mlp_mean_pearson:.4f}, Improvement: {_pct_improvement(mean_pearson - mlp_mean_pearson, mlp_mean_pearson):.2f}%")
    print(f"Spearman: Transformer: {mean_spearman:.4f}, MLP: {mlp_mean_spearman:.4f}, Improvement: {_pct_improvement(mean_spearman - mlp_mean_spearman, mlp_mean_spearman):.2f}%")

    # Visualize correlation distribution comparison
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.hist(pearson_corrs, bins=20, alpha=0.7, label='Transformer')
    plt.hist(mlp_pearson, bins=20, alpha=0.7, label='MLP')
    plt.xlabel('Pearson Correlation')
    plt.ylabel('Frequency')
    plt.title('Pearson Correlation Comparison')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.hist(spearman_corrs, bins=20, alpha=0.7, label='Transformer')
    plt.hist(mlp_spearman, bins=20, alpha=0.7, label='MLP')
    plt.xlabel('Spearman Correlation')
    plt.ylabel('Frequency')
    plt.title('Spearman Correlation Comparison')
    plt.legend()

    plt.tight_layout()
    plt.show()

## 14. Save Models and Results

In [None]:
# Bundle the test metrics that travel with the checkpoint.
test_results = {
    'mse': mse,
    'r2': r2,
    'mean_pearson': mean_pearson,
    'mean_spearman': mean_spearman,
    'pearson_corrs': pearson_corrs,
    'spearman_corrs': spearman_corrs,
}

# One checkpoint holding both GAT encoders, the transformer mapper, the
# embedding widths, and the metrics above.
checkpoint = {
    'rna_gat_state_dict': rna_gat_model.state_dict(),
    'adt_gat_state_dict': adt_gat_model.state_dict(),
    'transformer_mapping_state_dict': transformer_model.state_dict(),
    'rna_input_dim': input_dim,
    'adt_output_dim': output_dim,
    'test_results': test_results,
}
torch.save(checkpoint, 'rna_adt_transformer_mapping_models.pth')

print("Models and results saved to 'rna_adt_transformer_mapping_models.pth'")

# Persist the raw test-set arrays and per-dimension correlations for later analysis.
prediction_arrays = dict(
    predictions=predictions,
    ground_truth=ground_truth,
    pearson_corrs=pearson_corrs,
    spearman_corrs=spearman_corrs,
)
np.savez('transformer_mapping_predictions.npz', **prediction_arrays)

print("Predictions saved to 'transformer_mapping_predictions.npz'")

## 15. Evaluate Cluster Preservation

To further evaluate the quality of the transformer encoder mapping, we can check if the predicted ADT embeddings preserve the cluster structure of the original ADT data. We'll apply Leiden clustering to both the true ADT embeddings and predicted ADT embeddings, then measure the agreement between these cluster assignments.

In [None]:
# Create AnnData objects for true and predicted ADT embeddings
import anndata as ad
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
import scanpy as sc

print("Evaluating cluster preservation in predicted embeddings...")

# Create AnnData objects (rows are the test-set samples, in the same order in both)
true_adt_adata = ad.AnnData(X=ground_truth)
pred_adt_adata = ad.AnnData(X=predictions)

# Process both datasets the same way: kNN graph built directly on the embedding matrix
for adata in [true_adt_adata, pred_adt_adata]:
    sc.pp.neighbors(adata, n_neighbors=15, use_rep='X')

# Run Leiden clustering with multiple resolutions
resolutions = [0.2, 0.5, 0.8, 1.0, 1.5, 2.0]
results = []

for res in resolutions:
    cluster_key = f'leiden_res{res}'

    # Cluster true and predicted ADT embeddings at the same resolution
    sc.tl.leiden(true_adt_adata, resolution=res, key_added=cluster_key)
    sc.tl.leiden(pred_adt_adata, resolution=res, key_added=cluster_key)

    # Get cluster labels
    true_labels = true_adt_adata.obs[cluster_key].astype(int).values
    pred_labels = pred_adt_adata.obs[cluster_key].astype(int).values

    # Agreement between the two partitions of the same samples
    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels, average_method='arithmetic')

    # Number of clusters
    true_n_clusters = len(np.unique(true_labels))
    pred_n_clusters = len(np.unique(pred_labels))

    # Store results
    results.append({
        'Resolution': res,
        'True Clusters': true_n_clusters,
        'Predicted Clusters': pred_n_clusters,
        'ARI': ari,
        'NMI': nmi
    })

# Create results dataframe
results_df = pd.DataFrame(results)
print(results_df)

# Visualize results
plt.figure(figsize=(14, 5))

plt.subplot(1, 2, 1)
plt.plot(results_df['Resolution'], results_df['ARI'], 'o-', label='ARI')
plt.plot(results_df['Resolution'], results_df['NMI'], 'o-', label='NMI')
plt.xlabel('Leiden Resolution')
plt.ylabel('Score')
plt.title('Clustering Agreement Metrics')
plt.legend()
plt.grid(alpha=0.3)

plt.subplot(1, 2, 2)
bar_width = 0.35
x = np.arange(len(resolutions))
plt.bar(x - bar_width/2, results_df['True Clusters'], bar_width, label='True ADT')
plt.bar(x + bar_width/2, results_df['Predicted Clusters'], bar_width, label='Predicted ADT')
plt.xlabel('Leiden Resolution')
plt.ylabel('Number of Clusters')
plt.title('Cluster Counts')
plt.xticks(x, resolutions)
plt.legend()
plt.grid(alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Use resolution with best ARI score for UMAP visualization
best_res_idx = results_df['ARI'].idxmax()
best_res = results_df.loc[best_res_idx, 'Resolution']
print(f"\nBest resolution: {best_res} (ARI: {results_df.loc[best_res_idx, 'ARI']:.4f})")

# UMAP visualization of both embeddings with cluster labels
for adata in [true_adt_adata, pred_adt_adata]:
    sc.tl.umap(adata)

# Draw both UMAPs into one figure. scanpy creates a fresh figure unless it is
# handed an Axes via `ax=`, which would leave pre-made subplots empty.
umap_fig, umap_axes = plt.subplots(1, 2, figsize=(16, 7))
umap_panels = (
    (true_adt_adata, 'True ADT Embeddings'),
    (pred_adt_adata, 'Predicted ADT Embeddings'),
)
for umap_ax, (adata, panel_title) in zip(umap_axes, umap_panels):
    sc.pl.umap(adata, color=f'leiden_res{best_res}', title=panel_title,
               show=False, legend_loc='on data', ax=umap_ax)
    umap_ax.axis('on')

plt.tight_layout()
plt.show()

In [None]:
# Create a confusion matrix to see how well clusters match between true and predicted embeddings
from sklearn.metrics import confusion_matrix  # kept importable for later cells
import seaborn as sns

# Get cluster labels at the best resolution
best_res = results_df.loc[results_df['ARI'].idxmax(), 'Resolution']
true_labels = true_adt_adata.obs[f'leiden_res{best_res}'].astype(int)
pred_labels = pred_adt_adata.obs[f'leiden_res{best_res}'].astype(int)

# Row-normalised contingency table: rows are true clusters, columns are
# predicted clusters. A cross-tabulation keeps the matrix rectangular when the
# two clusterings have different numbers of clusters; a confusion matrix over
# the union of labels would add all-zero rows/columns for missing labels.
conf_table = pd.crosstab(
    true_labels.to_numpy(),
    pred_labels.to_numpy(),
    rownames=['True'],
    colnames=['Predicted'],
    normalize='index',
)
conf_matrix = conf_table.to_numpy()

# Get the number of clusters for proper visualization
n_clusters_true = len(np.unique(true_labels))
n_clusters_pred = len(np.unique(pred_labels))

# Plot confusion matrix (cell annotations only when the grid is small enough to read)
plt.figure(figsize=(12, 10))
sns.heatmap(conf_matrix, cmap="YlGnBu", annot=n_clusters_true <= 20,
            fmt='.2f', xticklabels=conf_table.columns.tolist(),
            yticklabels=conf_table.index.tolist())
plt.xlabel('Predicted Clusters')
plt.ylabel('True Clusters')
plt.title(f'Confusion Matrix (Normalized) - Resolution {best_res}')
plt.tight_layout()
plt.show()

# Find the most preserved and least preserved clusters: for every true cluster,
# the largest share of its cells that lands in a single predicted cluster.
cluster_preservation = np.max(conf_matrix, axis=1)
most_preserved_idx = np.argmax(cluster_preservation)
least_preserved_idx = np.argmin(cluster_preservation)

print(f"Most preserved cluster: {conf_table.index[most_preserved_idx]} with {cluster_preservation[most_preserved_idx]:.2%} preservation")
print(f"Least preserved cluster: {conf_table.index[least_preserved_idx]} with {cluster_preservation[least_preserved_idx]:.2%} preservation")

# Stack true and predicted embeddings into one AnnData tagged by origin
adata_combined = ad.AnnData(
    X=np.concatenate([ground_truth, predictions]),
    obs=pd.DataFrame({
        'embedding_type': ['True ADT'] * ground_truth.shape[0] + ['Predicted ADT'] * predictions.shape[0],
    })
)

# Calculate UMAP for the combined embedding space
sc.pp.neighbors(adata_combined, n_neighbors=15, use_rep='X')
sc.tl.umap(adata_combined)

# Plot combined UMAP. Pass the Axes explicitly so scanpy draws into this
# figure instead of opening a new one and leaving this one blank.
combined_fig, combined_ax = plt.subplots(figsize=(12, 10))
sc.pl.umap(adata_combined, color='embedding_type', title='Combined UMAP - True vs Predicted ADT Embeddings',
           palette={'True ADT': 'blue', 'Predicted ADT': 'red'}, alpha=0.7, size=30, show=False,
           ax=combined_ax)
combined_ax.legend(loc='upper right', frameon=True)
plt.show()

# Quantify global structure preservation
from scipy.spatial import procrustes

# Perform Procrustes analysis on UMAP coordinates
# NOTE(review): the two UMAP layouts were fitted independently, so this
# disparity also reflects UMAP's own layout variability, not just the mapping.
true_coords = true_adt_adata.obsm['X_umap']
pred_coords = pred_adt_adata.obsm['X_umap']

# Procrustes analysis scales, rotates and translates the predicted coordinates to best match the true coordinates
mtx1, mtx2, disparity = procrustes(true_coords, pred_coords)

print(f"\nProcrustes analysis disparity (lower is better): {disparity:.4f}")
print("This value quantifies how well the global structure is preserved after optimal alignment.")

# Calculate silhouette scores to measure cluster quality
from sklearn.metrics import silhouette_score

try:
    true_silhouette = silhouette_score(true_adt_adata.X, true_labels)
    pred_silhouette = silhouette_score(pred_adt_adata.X, pred_labels)
except ValueError as err:
    # silhouette_score rejects degenerate inputs (e.g. a single cluster);
    # report the reason instead of hiding every possible error.
    print("Could not calculate silhouette scores, possibly due to cluster numbers or sample size.")
    print(f"Reason: {err}")
else:
    print(f"\nSilhouette score for true clusters: {true_silhouette:.4f}")
    print(f"Silhouette score for predicted clusters: {pred_silhouette:.4f}")
    print(f"Ratio (pred/true): {pred_silhouette/true_silhouette:.4f}")
    if pred_silhouette >= true_silhouette:
        print("The predicted embeddings have equally good or better defined clusters than the true embeddings.")
    else:
        print("The true embeddings have better defined clusters than the predicted embeddings.")

## 17. Comprehensive Performance Visualization

In this section, we'll create additional visualizations to better understand the model performance and assess the quality of the RNA to ADT mapping:

1. Detailed accuracy metrics and correlation heatmaps
2. Dimension reduction visualizations comparing true vs. predicted embeddings
3. Feature importance analysis (per-dimension prediction error and variance preservation)
4. Distribution of prediction errors
5. Performance across different cell types

In [None]:
# 1. Detailed Accuracy Metrics and Correlation Heatmap

n_dims = ground_truth.shape[1]

# Calculate per-dimension metrics
dim_metrics = []
for i in range(n_dims):
    true = ground_truth[:, i]
    pred = predictions[:, i]

    # Calculate metrics
    dim_mse = mean_squared_error(true, pred)
    dim_r2 = r2_score(true, pred)
    dim_pearson, _ = pearsonr(true, pred)
    dim_spearman, _ = spearmanr(true, pred)

    # Relative error: mean absolute error scaled by the mean magnitude of the
    # true values (infinite when the true values are all zero)
    mean_abs_error = np.mean(np.abs(true - pred))
    mean_true = np.mean(np.abs(true))
    rel_error = mean_abs_error / mean_true if mean_true > 0 else float('inf')

    dim_metrics.append({
        'Dimension': i + 1,
        'MSE': dim_mse,
        'R²': dim_r2,
        'Pearson r': dim_pearson,
        'Spearman r': dim_spearman,
        'Relative Error': rel_error
    })

# Create DataFrame for better visualization
metrics_df = pd.DataFrame(dim_metrics)

# Sort by correlation to identify best and worst predicted dimensions
metrics_df_sorted = metrics_df.sort_values(by='Pearson r', ascending=False)

# Display summary statistics
print("Summary Statistics for Dimension-wise Performance:")
print(metrics_df.describe())

# Plot distribution of metrics
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# MSE distribution
sns.histplot(metrics_df['MSE'], kde=True, ax=axes[0, 0])
axes[0, 0].set_title('MSE Distribution Across Dimensions')
axes[0, 0].set_xlabel('Mean Squared Error')

# R² distribution
sns.histplot(metrics_df['R²'], kde=True, ax=axes[0, 1])
axes[0, 1].set_title('R² Distribution Across Dimensions')
axes[0, 1].set_xlabel('R² Score')

# Pearson correlation distribution
sns.histplot(metrics_df['Pearson r'], kde=True, ax=axes[1, 0])
axes[1, 0].set_title('Pearson Correlation Distribution')
axes[1, 0].set_xlabel('Pearson r')

# Relative error distribution
sns.histplot(metrics_df['Relative Error'].clip(upper=2), kde=True, ax=axes[1, 1])
axes[1, 1].set_title('Relative Error Distribution (clipped at 2)')
axes[1, 1].set_xlabel('Relative Error')

plt.tight_layout()
plt.show()

# Display top 10 best and worst predicted dimensions
print("\nTop 10 Best Predicted Dimensions:")
display(metrics_df_sorted.head(10))

print("\nTop 10 Worst Predicted Dimensions:")
display(metrics_df_sorted.tail(10))

# Cross-correlation between every true dimension (rows) and every predicted
# dimension (columns). One corrcoef call over the column-stacked data replaces
# n_dims**2 separate pearsonr calls; the wanted values are the top-right block
# of the joint correlation matrix.
corr_matrix = np.corrcoef(ground_truth, predictions, rowvar=False)[:n_dims, n_dims:]

# Plot correlation heatmap. Label the axes with 1-based dimension numbers and
# let seaborn show every 5th label; passing a shortened list of labels would
# instead attach those labels to the first few cells.
dim_labels = np.arange(1, n_dims + 1)
corr_df = pd.DataFrame(corr_matrix, index=dim_labels, columns=dim_labels)

plt.figure(figsize=(12, 10))
sns.heatmap(corr_df, cmap='coolwarm', center=0,
            xticklabels=5,
            yticklabels=5)
plt.title('Correlation Between True and Predicted Dimensions')
plt.xlabel('Predicted Dimension')
plt.ylabel('True Dimension')
plt.tight_layout()
plt.show()

# Plot diagonal correlation strength (each true dimension vs its own prediction)
plt.figure(figsize=(10, 6))
plt.plot(dim_labels, np.diag(corr_matrix), 'o-')
plt.axhline(y=0.5, color='r', linestyle='--', alpha=0.7)
plt.axhline(y=0.7, color='g', linestyle='--', alpha=0.7)
plt.title('Diagonal Correlation Strength')
plt.xlabel('Dimension')
plt.ylabel('Correlation (r)')
plt.ylim(-0.1, 1.1)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# 2. Dimension Reduction Visualizations

# Combine true and predicted embeddings for unified visualization
combined_data = np.vstack([ground_truth, predictions])
data_labels = np.array(['True ADT'] * ground_truth.shape[0] + ['Predicted ADT'] * predictions.shape[0])

# Create color palette
palette = {'True ADT': '#1f77b4', 'Predicted ADT': '#ff7f0e'}

# Seeded generator so the subsampled panels are reproducible, matching the
# fixed random_state used for t-SNE and UMAP below.
rng = np.random.default_rng(42)

# PCA
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
pca_result = pca.fit_transform(combined_data)

# Extract coordinates for each group
pca_df = pd.DataFrame({
    'PC1': pca_result[:, 0],
    'PC2': pca_result[:, 1],
    'Type': data_labels
})

plt.figure(figsize=(12, 10))

# PCA plot
plt.subplot(2, 2, 1)
sns.scatterplot(
    data=pca_df, x='PC1', y='PC2', hue='Type',
    palette=palette, s=10, alpha=0.7
)
plt.title(f'PCA Projection\nExplained Variance: {pca.explained_variance_ratio_.sum():.2f}')

# t-SNE
from sklearn.manifold import TSNE

# This can be slow on large datasets, so we'll use a sample if needed
max_samples = 5000
if combined_data.shape[0] > max_samples:
    sample_idx = rng.choice(combined_data.shape[0], max_samples, replace=False)
    sample_data = combined_data[sample_idx]
    sample_labels = data_labels[sample_idx]
else:
    sample_data = combined_data
    sample_labels = data_labels

try:
    # The iteration count is left at the library default (1000): the keyword
    # for it was renamed from `n_iter` to `max_iter` in newer scikit-learn,
    # so naming it explicitly breaks on one side of that change.
    tsne = TSNE(n_components=2, perplexity=30, random_state=42)
    tsne_result = tsne.fit_transform(sample_data)

    # Extract coordinates for each group
    tsne_df = pd.DataFrame({
        'tSNE1': tsne_result[:, 0],
        'tSNE2': tsne_result[:, 1],
        'Type': sample_labels
    })

    plt.subplot(2, 2, 2)
    sns.scatterplot(
        data=tsne_df, x='tSNE1', y='tSNE2', hue='Type',
        palette=palette, s=10, alpha=0.7
    )
    plt.title('t-SNE Projection')
except Exception as e:
    # Best effort: keep the rest of the figure and show the failure in the panel.
    print(f"t-SNE failed: {e}")
    plt.subplot(2, 2, 2)
    plt.text(0.5, 0.5, f"t-SNE failed: {str(e)}",
             ha='center', va='center', transform=plt.gca().transAxes)

# UMAP (if available)
try:
    import umap

    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
    umap_result = reducer.fit_transform(sample_data)

    # Extract coordinates for each group
    umap_df = pd.DataFrame({
        'UMAP1': umap_result[:, 0],
        'UMAP2': umap_result[:, 1],
        'Type': sample_labels
    })

    plt.subplot(2, 2, 3)
    sns.scatterplot(
        data=umap_df, x='UMAP1', y='UMAP2', hue='Type',
        palette=palette, s=10, alpha=0.7
    )
    plt.title('UMAP Projection')
except ImportError:
    print("UMAP not available. Install with 'pip install umap-learn'")
    plt.subplot(2, 2, 3)
    plt.text(0.5, 0.5, "UMAP not installed\nInstall with 'pip install umap-learn'",
             ha='center', va='center', transform=plt.gca().transAxes)
except Exception as e:
    print(f"UMAP failed: {e}")
    plt.subplot(2, 2, 3)
    plt.text(0.5, 0.5, f"UMAP failed: {str(e)}",
             ha='center', va='center', transform=plt.gca().transAxes)

# Procrustes analysis - align predicted to true embeddings
from scipy.spatial import procrustes

# Use sample if data is large; the same rows are taken from both matrices so
# true/predicted pairs stay matched.
if ground_truth.shape[0] > max_samples:
    sample_idx = rng.choice(ground_truth.shape[0], max_samples, replace=False)
    true_sample = ground_truth[sample_idx]
    pred_sample = predictions[sample_idx]
else:
    true_sample = ground_truth
    pred_sample = predictions

# First reduce dimensions with PCA for visualization: the projection is fitted
# on the true embeddings and then applied to the predictions.
pca = PCA(n_components=2)
true_pca = pca.fit_transform(true_sample)
pred_pca = pca.transform(pred_sample)

# Perform Procrustes alignment
mtx1, mtx2, disparity = procrustes(true_pca, pred_pca)

# Create dataframe for plotting
procrustes_df = pd.DataFrame({
    'X_True': mtx1[:, 0],
    'Y_True': mtx1[:, 1],
    'X_Pred': mtx2[:, 0],
    'Y_Pred': mtx2[:, 1],
})

plt.subplot(2, 2, 4)
# Plot true points
plt.scatter(procrustes_df['X_True'], procrustes_df['Y_True'],
            color=palette['True ADT'], s=10, alpha=0.7, label='True ADT')
# Plot predicted points
plt.scatter(procrustes_df['X_Pred'], procrustes_df['Y_Pred'],
            color=palette['Predicted ADT'], s=10, alpha=0.7, label='Predicted ADT')

# Draw lines connecting corresponding points. plot() treats each column of a
# 2-D input as one line, so a single call with (2, n_points) arrays draws all
# connectors without a Python-level loop over rows.
connector_x = np.vstack([mtx1[:, 0], mtx2[:, 0]])
connector_y = np.vstack([mtx1[:, 1], mtx2[:, 1]])
plt.plot(connector_x, connector_y, color='gray', alpha=0.2)

plt.legend()
plt.title(f'Procrustes Analysis\nDisparity: {disparity:.4f}')

plt.tight_layout()
plt.show()

In [None]:
# 3. Feature Importance Analysis

# 3.1 Dimension-wise error analysis: mean absolute error of every embedding dimension
error_magnitude = np.mean(np.abs(ground_truth - predictions), axis=0)

# Plot dimensions sorted by prediction error
plt.figure(figsize=(12, 6))
sorted_indices = np.argsort(error_magnitude)
plt.bar(range(len(error_magnitude)), error_magnitude[sorted_indices], alpha=0.7)
plt.xlabel('Dimension Index (Sorted by Error)')
plt.ylabel('Mean Absolute Error')
plt.title('Dimension-wise Prediction Error')
plt.grid(alpha=0.3, axis='y')
plt.tight_layout()
plt.show()

# 3.2 Variance explained analysis
# Per-dimension variance of the true and the predicted embeddings
true_variance = np.var(ground_truth, axis=0)
pred_variance = np.var(predictions, axis=0)
# Ratio of predicted to true variance. Dimensions whose true variance is zero
# (constant embedding dimensions) have no defined ratio; mark them NaN rather
# than dividing by zero and emitting inf/NaN with runtime warnings.
variance_ratio = np.divide(
    pred_variance,
    true_variance,
    out=np.full_like(true_variance, np.nan, dtype=float),
    where=true_variance > 0,
)

# Create DataFrame for better visualization
var_df = pd.DataFrame({
    'Dimension': range(1, len(true_variance) + 1),
    'True Variance': true_variance,
    'Predicted Variance': pred_variance,
    'Variance Ratio': variance_ratio,
    'Error': error_magnitude
})

# Plot variance comparison
plt.figure(figsize=(14, 8))

plt.subplot(2, 1, 1)
plt.scatter(var_df['True Variance'], var_df['Predicted Variance'], alpha=0.7)
# Identity line: points on it have identical true and predicted variance
plt.plot([0, var_df['True Variance'].max()], [0, var_df['True Variance'].max()], 'r--')
plt.xlabel('True Variance')
plt.ylabel('Predicted Variance')
plt.title('Variance Preservation in Predicted Embeddings')
plt.grid(alpha=0.3)

plt.subplot(2, 1, 2)
sns.histplot(var_df['Variance Ratio'].clip(0, 2), bins=30, kde=True)
plt.axvline(x=1, color='r', linestyle='--')
plt.xlabel('Variance Ratio (Predicted/True)')
plt.ylabel('Count')
plt.title('Distribution of Variance Ratios (clipped at 2)')

plt.tight_layout()
plt.show()

# 3.3 Analyze relationship between variance and prediction error
plt.figure(figsize=(10, 6))
plt.scatter(var_df['True Variance'], var_df['Error'], alpha=0.7)
plt.xlabel('True Variance')
plt.ylabel('Mean Absolute Error')
plt.title('Relationship Between Dimension Variance and Prediction Error')
plt.grid(alpha=0.3)

# Add best fit line (ordinary least squares of error on true variance)
from scipy.stats import linregress
slope, intercept, r_value, p_value, std_err = linregress(var_df['True Variance'], var_df['Error'])
x_line = np.linspace(var_df['True Variance'].min(), var_df['True Variance'].max(), 100)
y_line = slope * x_line + intercept
plt.plot(x_line, y_line, 'r--',
         label=f'Slope: {slope:.4f}, R²: {r_value**2:.4f}, p: {p_value:.4f}')
plt.legend()

plt.tight_layout()
plt.show()

# Display top dimensions by variance
print("\nTop 10 Dimensions by Variance:")
display(var_df.sort_values(by='True Variance', ascending=False).head(10))

In [None]:
# 4. Distribution of Prediction Errors

# Element-wise absolute error, and error relative to |ground truth|.
# Entries whose ground truth is exactly zero keep a relative error of 0.
absolute_errors = np.abs(ground_truth - predictions)
non_zero_mask = ground_truth != 0
relative_errors = np.zeros_like(absolute_errors)
np.divide(
    absolute_errors,
    np.abs(ground_truth),
    out=relative_errors,
    where=non_zero_mask,
)

# Collapse over axis 1 to obtain one value per sample
sample_mae = absolute_errors.mean(axis=1)  # mean absolute error
sample_max_error = absolute_errors.max(axis=1)  # largest absolute error
sample_mean_rel_error = np.clip(relative_errors, 0, 2).mean(axis=1)  # relative error, clipped to [0, 2]

# Collect the per-sample summaries for plotting
error_df = pd.DataFrame({
    'Mean Absolute Error': sample_mae,
    'Max Absolute Error': sample_max_error,
    'Mean Relative Error': sample_mean_rel_error
})

# 2x2 grid: three histograms plus a mean-vs-max scatter plot
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (column, target axes, title, x label) for each histogram panel
histogram_panels = [
    ('Mean Absolute Error', axes[0, 0],
     'Distribution of Mean Absolute Errors per Sample',
     'Mean Absolute Error'),
    ('Max Absolute Error', axes[0, 1],
     'Distribution of Maximum Absolute Errors per Sample',
     'Maximum Absolute Error'),
    ('Mean Relative Error', axes[1, 0],
     'Distribution of Mean Relative Errors per Sample',
     'Mean Relative Error (clipped at 2)'),
]
for column, panel, panel_title, panel_xlabel in histogram_panels:
    sns.histplot(error_df[column], kde=True, ax=panel)
    panel.set_title(panel_title)
    panel.set_xlabel(panel_xlabel)

# Bottom-right panel: per-sample mean error against per-sample max error
sns.scatterplot(
    x='Mean Absolute Error',
    y='Max Absolute Error',
    data=error_df,
    ax=axes[1, 1],
    alpha=0.5,
)
axes[1, 1].set_title('Relationship Between Mean and Max Errors')

plt.tight_layout()
plt.show()

# 4.2 Error distribution for the highest-variance dimensions
# Only the top dimensions (at most 20) are drawn to keep the plot readable.
n_dims_to_show = min(20, ground_truth.shape[1])

# 'Dimension' in var_df is 1-based, so subtract 1 to get column indices.
ranked_by_variance = var_df.sort_values(by='True Variance', ascending=False)
top_var_dims = ranked_by_variance['Dimension'].head(n_dims_to_show).values - 1

# Absolute errors restricted to the selected dimensions
selected_errors = absolute_errors[:, top_var_dims]

# Long-format table: one row per (sample, dimension) pair
boxplot_data = [
    pd.DataFrame({
        'Absolute Error': absolute_errors[:, dim_idx],
        'Dimension': f"Dim {dim_idx+1}"
    })
    for dim_idx in top_var_dims
]
boxplot_df = pd.concat(boxplot_data)

# One box per dimension, ordered by decreasing true variance
plt.figure(figsize=(14, 6))
sns.boxplot(x='Dimension', y='Absolute Error', data=boxplot_df)
plt.title('Error Distribution for Top High-Variance Dimensions')
plt.grid(axis='y', alpha=0.3)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# 4.3 Identify and analyze outlier predictions
# argsort is ascending, so the last 20 entries are the samples with the largest MAE
worst_samples_idx = np.argsort(sample_mae)[-20:]
worst_samples_true = ground_truth[worst_samples_idx]
worst_samples_pred = predictions[worst_samples_idx]

# Compare the error of the worst samples against the overall average
print(f"Mean absolute error for 20 worst predictions: {sample_mae[worst_samples_idx].mean():.4f}")
print(f"(Compared to overall mean absolute error: {sample_mae.mean():.4f})")

# One panel for each of the four worst samples
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
fig.suptitle("Examples of Worst Predictions", fontsize=16)

axes = axes.flatten()
for ax, idx in zip(axes, worst_samples_idx[-4:]):
    true = ground_truth[idx]
    pred = predictions[idx]
    dims = min(50, len(true))  # plot at most the first 50 dimensions
    x = range(dims)

    # True profile in blue, predicted profile in red
    for values, line_fmt, line_label in ((true, 'b-', 'True'), (pred, 'r-', 'Predicted')):
        ax.plot(x, values[:dims], line_fmt, label=line_label, alpha=0.7)

    ax.set_title(f"Sample #{idx}: MAE={sample_mae[idx]:.4f}")
    ax.set_xlabel('Dimension')
    ax.set_ylabel('Value')
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.subplots_adjust(top=0.9)  # leave room for the suptitle
plt.show()

In [None]:
# 5. Performance Across Different Cell Types (Clusters)

# Uses the 'leiden' labels stored in trainADT.obs (when present) to break the
# per-sample prediction error down by cluster.

# 5.1 Per-cluster error summary, UMAP overlay and significance test
try:
    if 'leiden' in trainADT.obs:
        # Keep only the labels of the rows selected by test_mask_np.
        # NOTE(review): assumes sample_mae is ordered like those rows — confirm.
        all_indices = np.arange(len(trainADT))
        cluster_labels = trainADT.obs['leiden'].values[all_indices[test_mask_np]]

        # One row per test sample: its cluster and its mean absolute error
        cell_type_df = pd.DataFrame({
            'Cluster': cluster_labels,
            'Mean Absolute Error': sample_mae
        })

        # Mean / std / count of the error per cluster, best cluster first.
        # observed=True drops categories with no samples in this subset, which
        # would otherwise show up as NaN rows if 'leiden' is categorical.
        cluster_performance = (
            cell_type_df
            .groupby('Cluster', observed=True)['Mean Absolute Error']
            .agg(['mean', 'std', 'count'])
            .reset_index()
        )
        cluster_performance = cluster_performance.sort_values(by='mean')

        # Bar chart of the mean error per cluster with +/- one std error bars
        plt.figure(figsize=(14, 6))
        sns.barplot(x='Cluster', y='mean', data=cluster_performance,
                    order=cluster_performance['Cluster'])
        plt.errorbar(x=range(len(cluster_performance)),
                     y=cluster_performance['mean'],
                     yerr=cluster_performance['std'],
                     fmt='none', ecolor='black', capsize=3)
        plt.xlabel('Cell Type (Cluster)')
        plt.ylabel('Mean Absolute Error')
        plt.title('Model Performance by Cell Type')
        plt.xticks(rotation=45)
        plt.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        plt.show()

        # Show detailed performance by cluster
        print("Performance by Cell Type:")
        display(cluster_performance.sort_values(by='mean'))

        # UMAP overlay: only possible when coordinates were stored on trainADT
        if 'X_umap' in trainADT.obsm:
            # UMAP coordinates of the same test rows
            test_umap = trainADT.obsm['X_umap'][all_indices[test_mask_np]]

            umap_df = pd.DataFrame({
                'UMAP1': test_umap[:, 0],
                'UMAP2': test_umap[:, 1],
                'Cluster': cluster_labels,
                'Error': sample_mae
            })

            plt.figure(figsize=(16, 7))

            # Left: points coloured by cluster
            plt.subplot(1, 2, 1)
            sns.scatterplot(x='UMAP1', y='UMAP2', hue='Cluster', data=umap_df,
                            palette='tab20', s=10, alpha=0.7)
            plt.title('UMAP Projection by Cell Type')

            # Right: points coloured by prediction error
            plt.subplot(1, 2, 2)
            scatter = plt.scatter(umap_df['UMAP1'], umap_df['UMAP2'],
                                  c=umap_df['Error'], cmap='viridis',
                                  s=10, alpha=0.7)
            plt.colorbar(scatter, label='Mean Absolute Error')
            plt.title('UMAP Projection by Prediction Error')

            plt.tight_layout()
            plt.show()

        # Statistical test for differences between clusters. It depends only on
        # cell_type_df, so it runs whether or not UMAP coordinates exist.
        from scipy.stats import kruskal

        try:
            # Kruskal-Wallis H-test for independent samples
            h_stat, p_value = kruskal(*[
                cell_type_df[cell_type_df['Cluster'] == c]['Mean Absolute Error'].values
                for c in cell_type_df['Cluster'].unique()
            ])

            print(f"\nKruskal-Wallis test for differences between clusters:")
            print(f"H-statistic: {h_stat:.4f}, p-value: {p_value:.4e}")

            if p_value < 0.05:
                print("There are statistically significant differences in performance between cell types")
            else:
                print("No statistically significant differences in performance between cell types")
        except ValueError as stat_err:
            # kruskal rejects degenerate input (e.g. fewer than two groups);
            # anything else falls through to the outer handler with a traceback.
            print("Could not perform statistical test - possibly due to insufficient data in some clusters")
            print(f"Reason: {stat_err}")
    else:
        print("No cluster information available - cannot analyze performance by cell type")

except Exception as e:
    # Best-effort section: report the failure and carry on with the notebook
    print(f"Could not analyze performance by cell type: {e}")
    import traceback
    traceback.print_exc()

# 5.2 Visualize combined data with error overlay using PCA
# Fit a single 2-component PCA on the true and predicted embeddings stacked together
n_test = ground_truth.shape[0]
pca = PCA(n_components=2)
test_pca = pca.fit_transform(np.vstack([ground_truth, predictions]))

# The first n_test rows are the ground truth, the remaining rows the predictions
true_pca, pred_pca = test_pca[:n_test], test_pca[n_test:]

plt.figure(figsize=(10, 8))

# Connect each true point to its prediction. Given 2-row arrays, matplotlib
# draws one line per column, i.e. one true -> predicted segment per sample.
plt.plot(
    np.vstack([true_pca[:, 0], pred_pca[:, 0]]),
    np.vstack([true_pca[:, 1], pred_pca[:, 1]]),
    'gray',
    alpha=0.1,
)

# True points, coloured by their per-sample mean absolute error
scatter = plt.scatter(
    true_pca[:, 0],
    true_pca[:, 1],
    c=sample_mae,
    cmap='viridis',
    s=30,
    alpha=0.7,
)
plt.colorbar(scatter, label='Mean Absolute Error')

# Predicted points as small red markers
plt.scatter(pred_pca[:, 0], pred_pca[:, 1], c='red', s=5, alpha=0.5)

plt.title('PCA Projection with Error Overlay and Prediction Lines')
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
plt.tight_layout()
plt.show()

# 5.3 Create a summary performance table
# Headline metrics (mse, r2, mean_pearson, mean_spearman are computed in earlier cells)
performance_summary = pd.DataFrame({
    'Metric': ['Overall MSE', 'R² Score', 'Mean Pearson Correlation', 'Mean Spearman Correlation'],
    'Value': [mse, r2, mean_pearson, mean_spearman]
})

# Add percentiles of the per-sample mean absolute error.
# DataFrame.append was removed in pandas 2.0, so the extra rows are built as
# their own frame and attached with a single pd.concat call.
percentiles = [10, 25, 50, 75, 90]
percentile_rows = pd.DataFrame({
    'Metric': [f'MAE {p}th Percentile' for p in percentiles],
    'Value': np.percentile(sample_mae, percentiles)
})
performance_summary = pd.concat(
    [performance_summary, percentile_rows], ignore_index=True
)

print("Overall Performance Summary:")
display(performance_summary)

In [None]:
# 6. Create Interactive Visualizations (using Plotly if available)

try:
    # Plotly is optional; a missing install is reported by the ImportError handler below
    import plotly.express as px
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    # ---- Interactive 3D PCA of the combined embedding matrix ----
    pca = PCA(n_components=3)
    pca_result = pca.fit_transform(combined_data)

    # One row per point: its three PCA coordinates and its group label
    pca_df = pd.DataFrame({
        'PC1': pca_result[:, 0],
        'PC2': pca_result[:, 1],
        'PC3': pca_result[:, 2],
        'Type': data_labels
    })

    # Attach the per-sample error to rows n_test, n_test + 1, ...
    # NOTE(review): presumably those rows are the predicted points — confirm
    # against how combined_data was assembled.
    n_test = ground_truth.shape[0]
    for idx, mae in enumerate(sample_mae, start=n_test):
        pca_df.loc[idx, 'Error'] = mae

    # Axis labels carry the explained-variance share of each component
    axis_labels = {
        'PC1': f'PC1 ({pca.explained_variance_ratio_[0]:.2%})',
        'PC2': f'PC2 ({pca.explained_variance_ratio_[1]:.2%})',
        'PC3': f'PC3 ({pca.explained_variance_ratio_[2]:.2%})',
    }
    fig = px.scatter_3d(
        pca_df,
        x='PC1', y='PC2', z='PC3',
        color='Type',
        symbol='Type',
        opacity=0.7,
        labels=axis_labels,
        title='Interactive 3D PCA Projection',
        height=700,
    )
    fig.update_traces(marker=dict(size=3))
    fig.show()

    # ---- Correlation heatmap between true and predicted dimensions ----
    # Pearson r of true dimension i against predicted dimension j, limited to
    # the first 30 dimensions (or fewer when the embedding is smaller).
    n_corr_dims = min(30, ground_truth.shape[1])
    interactive_corr_matrix = np.zeros((n_corr_dims, n_corr_dims))
    for i, j in np.ndindex(interactive_corr_matrix.shape):
        interactive_corr_matrix[i, j], _ = pearsonr(ground_truth[:, i], predictions[:, j])

    heatmap = go.Heatmap(
        z=interactive_corr_matrix,
        x=[f'Pred {i+1}' for i in range(interactive_corr_matrix.shape[1])],
        y=[f'True {i+1}' for i in range(interactive_corr_matrix.shape[0])],
        colorscale='RdBu_r',
        zmid=0,
        colorbar=dict(title='Correlation'),
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title='Correlation Between True and Predicted Dimensions (First 30)',
        xaxis_title='Predicted Dimension',
        yaxis_title='True Dimension',
        height=600,
        width=700,
    )
    fig.show()

    # ---- Side-by-side PCA scatter plots coloured by error ----
    # Cap the number of plotted points so the figure stays responsive
    max_points = 2000
    if ground_truth.shape[0] <= max_points:
        gt_sample = ground_truth
        pred_sample = predictions
        error_sample = sample_mae
    else:
        sample_idx = np.random.choice(ground_truth.shape[0], max_points, replace=False)
        gt_sample, pred_sample, error_sample = (
            ground_truth[sample_idx],
            predictions[sample_idx],
            sample_mae[sample_idx],
        )

    # PCA is fitted on the true embeddings only; predictions are projected
    # into that same 2D space.
    pca = PCA(n_components=2)
    gt_pca = pca.fit_transform(gt_sample)
    pred_pca = pca.transform(pred_sample)

    scatter_df = pd.DataFrame({
        'True_PC1': gt_pca[:, 0],
        'True_PC2': gt_pca[:, 1],
        'Pred_PC1': pred_pca[:, 0],
        'Pred_PC2': pred_pca[:, 1],
        'Error': error_sample
    })

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=('True Embeddings', 'Predicted Embeddings'),
    )

    # Both panels share the same error colouring; only the left one shows the colour bar
    true_marker = dict(
        size=5,
        color=scatter_df['Error'],
        colorscale='Viridis',
        colorbar=dict(title='Error'),
        showscale=True,
    )
    pred_marker = dict(
        size=5,
        color=scatter_df['Error'],
        colorscale='Viridis',
        showscale=False,
    )

    fig.add_trace(
        go.Scatter(
            x=scatter_df['True_PC1'],
            y=scatter_df['True_PC2'],
            mode='markers',
            marker=true_marker,
            name='True',
        ),
        row=1, col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=scatter_df['Pred_PC1'],
            y=scatter_df['Pred_PC2'],
            mode='markers',
            marker=pred_marker,
            name='Predicted',
        ),
        row=1, col=2,
    )

    fig.update_layout(
        title_text="PCA Projection with Error Coloring",
        height=500,
        width=900,
    )

    fig.show()

except ImportError:
    print("Plotly is not installed. Install with 'pip install plotly' for interactive visualizations.")
except Exception as e:
    # Any other failure is reported with a traceback instead of aborting the notebook
    print(f"Error creating interactive visualizations: {e}")
    import traceback
    traceback.print_exc()

## 18. Model Performance Summary

The comprehensive visualizations above provide detailed insights into the performance of our Transformer Encoder model for mapping RNA embeddings to ADT embeddings:

### Key Findings:

1. **Overall Performance**:
   - MSE (Mean Squared Error): Quantifies the average squared difference between predicted and true values
   - R² Score: Indicates how much variance in the true values is captured by the model
   - Pearson/Spearman Correlations: Measure the linear and rank correlation between predicted and true values

2. **Dimension-wise Analysis**:
   - Some embedding dimensions are predicted with higher accuracy than others
   - High-variance dimensions tend to be more challenging to predict
   - The diagonal correlation heatmap shows how well each individual dimension is preserved

3. **Embedding Space Structure**:
   - Dimension reduction visualizations (PCA, t-SNE, UMAP) show how well the global structure is preserved
   - Procrustes analysis quantifies the alignment between true and predicted embedding spaces
   - Connected point visualizations reveal how individual points move in the embedding space

4. **Error Analysis**:
   - Distribution of errors across samples identifies potential outliers
   - Feature importance analysis reveals which dimensions contribute most to errors
   - Error patterns across different cell types help identify model weaknesses

5. **Cell Type Performance**:
   - Performance varies across different cell populations (clusters)
   - Some cell types may be easier to predict than others
   - Error distribution in UMAP space reveals biological patterns in prediction accuracy

### Interpretation:

- **High R² and Correlation**: Indicates the model captures most of the variance in the ADT embeddings
- **Low MSE**: Suggests precise numerical prediction of embedding values
- **Preserved Cluster Structure**: Demonstrates the model maintains the biological meaning in the mapping
- **Consistent Performance Across Cell Types**: Shows robustness across diverse biological contexts

These visualizations provide a comprehensive assessment of the model's strengths and weaknesses, helping to understand not just how well it performs overall, but where it excels or needs improvement.

## 16. Interpretation of Cluster Preservation Results

The above analysis helps us understand how well our Transformer Encoder model preserves biological cell types when mapping from RNA to ADT embeddings. Here's how to interpret these results:

### Key Metrics:
- **Adjusted Rand Index (ARI)**: Measures the similarity between the true and predicted cluster assignments, adjusted for chance. Values range from -1 to 1, where 1 indicates perfect agreement, and values near 0 indicate random clustering.
- **Normalized Mutual Information (NMI)**: Quantifies the shared information between the two clusterings. Values range from 0 to 1, with 1 indicating perfect agreement.
- **Procrustes Disparity**: Measures how well the global structure is preserved after optimal alignment (lower is better).
- **Silhouette Scores**: Measure how well-defined the clusters are (higher is better).

### Interpretation Guidelines:

1. **Strong Cluster Preservation** (Good model performance):
   - High ARI (> 0.7) and NMI (> 0.8)
   - Similar number of clusters between true and predicted embeddings
   - Low Procrustes disparity
   - Similar silhouette scores between true and predicted embeddings
   - Clear diagonal pattern in confusion matrix

2. **Moderate Cluster Preservation** (Acceptable model performance):
   - Moderate ARI (0.4-0.7) and NMI (0.5-0.8)
   - Some differences in cluster numbers
   - Some off-diagonal elements in confusion matrix, but still showing structure
   - Visible separation of clusters in UMAP visualizations, though not identical

3. **Poor Cluster Preservation** (Model needs improvement):
   - Low ARI (< 0.4) and NMI (< 0.5)
   - Very different number of clusters
   - No clear pattern in confusion matrix
   - Poor separation in UMAP visualizations

### Biological Significance:
If the model demonstrates good cluster preservation, it suggests that:
1. The RNA expression data contains sufficient information to predict cell types as defined by surface protein markers
2. The transformer model has successfully learned to map between these two modalities
3. The predicted ADT embeddings could potentially be used for downstream analyses in place of actual ADT measurements

This evaluation framework provides a comprehensive assessment of how well the model maintains biological cell type structure when mapping between RNA and ADT modalities.

## Summary

This notebook implements a pipeline to learn mappings between GAT embeddings of RNA and ADT data:

1. **Data Preprocessing**: Both RNA and ADT data are normalized, scaled, and processed to create neighbor graphs
2. **GAT Training**: Separate GAT models are trained on RNA and ADT data for node classification
3. **Embedding Extraction**: Intermediate embeddings are extracted from the trained GAT models
4. **Transformer Mapping**: 
   - A Transformer Encoder architecture learns to map RNA embeddings to ADT embeddings
   - Self-attention mechanisms capture complex dependencies between embedding dimensions
   - Multi-headed attention allows the model to focus on different aspects of the data
5. **Evaluation**: The mapping quality is assessed using MSE, R², and correlation metrics

The Transformer architecture offers several advantages over MLPs:
- **Context awareness**: Self-attention mechanism captures global dependencies in the embeddings
- **Parameter efficiency**: Shared parameters across layers for better generalization
- **Representation power**: Better handling of complex relationships between features

The trained models can be used to predict ADT embeddings from RNA data, enabling cross-modal analysis and integration.

In [None]:
import anndata as ad
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

print("Evaluating cluster preservation in predicted embeddings...")

# Wrap the true and predicted embedding matrices in AnnData objects for scanpy
true_adt_adata = ad.AnnData(X=ground_truth)
pred_adt_adata = ad.AnnData(X=predictions)

# Build the same 15-neighbour graph on the raw matrix ('X') of each object
for adata in (true_adt_adata, pred_adt_adata):
    sc.pp.neighbors(adata, n_neighbors=15, use_rep='X')

# Cluster both objects at several Leiden resolutions and compare the partitions
resolutions = [0.2, 0.5, 0.8, 1.0, 1.5, 2.0]
results = []

for res in resolutions:
    # True embeddings first, then predicted; labels are stored under the same key
    for adata in (true_adt_adata, pred_adt_adata):
        sc.tl.leiden(adata, resolution=res, key_added=f'leiden_res{res}')

    # Integer cluster labels of both partitions
    true_labels = true_adt_adata.obs[f'leiden_res{res}'].astype(int).values
    pred_labels = pred_adt_adata.obs[f'leiden_res{res}'].astype(int).values

    # Agreement between the two clusterings
    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels, average_method='arithmetic')

    # How many distinct clusters each partition contains
    true_n_clusters = np.unique(true_labels).size
    pred_n_clusters = np.unique(pred_labels).size

    results.append({
        'Resolution': res,
        'True Clusters': true_n_clusters,
        'Predicted Clusters': pred_n_clusters,
        'ARI': ari,
        'NMI': nmi
    })

# One row per resolution
results_df = pd.DataFrame(results)
print(results_df)

# Left panel: agreement scores; right panel: cluster counts per resolution
fig, (ax_scores, ax_counts) = plt.subplots(1, 2, figsize=(14, 5))

ax_scores.plot(results_df['Resolution'], results_df['ARI'], 'o-', label='ARI')
ax_scores.plot(results_df['Resolution'], results_df['NMI'], 'o-', label='NMI')
ax_scores.set_xlabel('Leiden Resolution')
ax_scores.set_ylabel('Score')
ax_scores.set_title('Clustering Agreement Metrics')
ax_scores.legend()
ax_scores.grid(alpha=0.3)

# Grouped bars: the two series are offset by half a bar width around each tick
bar_width = 0.35
x = np.arange(len(resolutions))
ax_counts.bar(x - bar_width/2, results_df['True Clusters'], bar_width, label='True ADT')
ax_counts.bar(x + bar_width/2, results_df['Predicted Clusters'], bar_width, label='Predicted ADT')
ax_counts.set_xlabel('Leiden Resolution')
ax_counts.set_ylabel('Number of Clusters')
ax_counts.set_title('Cluster Counts')
ax_counts.set_xticks(x)
ax_counts.set_xticklabels(resolutions)
ax_counts.legend()
ax_counts.grid(alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('metrics_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Pick the resolution whose clustering agrees best (highest ARI) for the UMAP plots
best_res_idx = results_df['ARI'].idxmax()
best_res = results_df.loc[best_res_idx, 'Resolution']
print(f"\nBest resolution: {best_res} (ARI: {results_df.loc[best_res_idx, 'ARI']:.4f})")

# Compute a UMAP layout for both embedding sets
for adata in [true_adt_adata, pred_adt_adata]:
    sc.tl.umap(adata)

# The panels are drawn with plain matplotlib rather than scanpy's plotting
# helpers, which gives direct control over figure size and avoids truncation.

# One colour per cluster id, shared by both panels so the ids are comparable
n_clusters = max(
    len(np.unique(true_adt_adata.obs[f'leiden_res{best_res}'])),
    len(np.unique(pred_adt_adata.obs[f'leiden_res{best_res}']))
)
cluster_colors = plt.cm.tab20(np.linspace(0, 1, n_clusters))

# Integer cluster ids at the chosen resolution
true_clusters = true_adt_adata.obs[f'leiden_res{best_res}'].astype(int).values
pred_clusters = pred_adt_adata.obs[f'leiden_res{best_res}'].astype(int).values

plt.figure(figsize=(20, 10))

# (AnnData object, cluster ids, panel title): left = true, right = predicted
umap_panels = [
    (true_adt_adata, true_clusters, f'True ADT Embeddings (Resolution: {best_res})'),
    (pred_adt_adata, pred_clusters, f'Predicted ADT Embeddings (Resolution: {best_res})'),
]
for panel_no, (adata, panel_clusters, panel_title) in enumerate(umap_panels, start=1):
    plt.subplot(1, 2, panel_no)
    umap_xy = adata.obsm['X_umap']

    for i in range(n_clusters):
        mask = panel_clusters == i
        if not np.any(mask):
            # This cluster id does not occur in the current panel
            continue

        plt.scatter(
            umap_xy[mask, 0],
            umap_xy[mask, 1],
            s=10,
            color=cluster_colors[i],
            label=str(i)
        )

        # Write the cluster id at the centroid of its points
        centroid_x = np.mean(umap_xy[mask, 0])
        centroid_y = np.mean(umap_xy[mask, 1])
        plt.text(centroid_x, centroid_y, str(i), fontweight='bold', fontsize=12, ha='center')

    plt.title(panel_title)
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    # Uncomment the next line if you want a traditional legend
    # plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1)

plt.tight_layout()
plt.savefig('umap_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
import anndata as ad
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl

# Larger default figures and higher resolution for on-screen and saved output
plt.rcParams.update({
    'figure.figsize': (15, 10),
    'figure.dpi': 150,
    'savefig.dpi': 300,
})

print("Evaluating cluster preservation in predicted embeddings...")

# AnnData wrappers around the true and predicted embedding matrices
true_adt_adata = ad.AnnData(X=ground_truth)
pred_adt_adata = ad.AnnData(X=predictions)

# Identical neighbour-graph settings for both objects
for adata in (true_adt_adata, pred_adt_adata):
    sc.pp.neighbors(adata, n_neighbors=15, use_rep='X')

# Leiden clustering over a range of resolutions
resolutions = [0.2, 0.5, 0.8, 1.0, 1.5, 2.0]
results = []

for res in resolutions:
    # Cluster the true embeddings, then the predicted ones, under the same key
    for adata in (true_adt_adata, pred_adt_adata):
        sc.tl.leiden(adata, resolution=res, key_added=f'leiden_res{res}')

    # Cluster assignments as integer arrays
    true_labels = true_adt_adata.obs[f'leiden_res{res}'].astype(int).values
    pred_labels = pred_adt_adata.obs[f'leiden_res{res}'].astype(int).values

    # Partition-agreement scores
    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels, average_method='arithmetic')

    # Distinct cluster counts on each side
    true_n_clusters = np.unique(true_labels).size
    pred_n_clusters = np.unique(pred_labels).size

    results.append({
        'Resolution': res,
        'True Clusters': true_n_clusters,
        'Predicted Clusters': pred_n_clusters,
        'ARI': ari,
        'NMI': nmi
    })

# Tabulate the sweep; floats are printed with four decimals
results_df = pd.DataFrame(results)
print(results_df.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

# Two panels: agreement scores (left) and cluster counts (right)
fig, (ax_scores, ax_counts) = plt.subplots(1, 2, figsize=(15, 6))

ax_scores.plot(results_df['Resolution'], results_df['ARI'], 'o-', linewidth=2, markersize=8, label='ARI')
ax_scores.plot(results_df['Resolution'], results_df['NMI'], 'o-', linewidth=2, markersize=8, label='NMI')
ax_scores.set_xlabel('Leiden Resolution', fontsize=12)
ax_scores.set_ylabel('Score', fontsize=12)
ax_scores.set_title('Clustering Agreement Metrics', fontsize=14)
ax_scores.legend(fontsize=12)
ax_scores.grid(alpha=0.3)
ax_scores.set_xticks(resolutions)

# Paired bars, shifted by half a bar width on each side of the tick
bar_width = 0.35
x = np.arange(len(resolutions))
ax_counts.bar(x - bar_width/2, results_df['True Clusters'], bar_width, label='True ADT')
ax_counts.bar(x + bar_width/2, results_df['Predicted Clusters'], bar_width, label='Predicted ADT')
ax_counts.set_xlabel('Leiden Resolution', fontsize=12)
ax_counts.set_ylabel('Number of Clusters', fontsize=12)
ax_counts.set_title('Cluster Counts', fontsize=14)
ax_counts.set_xticks(x)
ax_counts.set_xticklabels(resolutions)
ax_counts.legend(fontsize=12)
ax_counts.grid(alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('cluster_metrics_comparison.png', bbox_inches='tight')
plt.show()

# Use resolution with best ARI score for UMAP visualization
best_res_idx = results_df['ARI'].idxmax()
best_res = results_df.loc[best_res_idx, 'Resolution']
print(f"\nBest resolution: {best_res} (ARI: {results_df.loc[best_res_idx, 'ARI']:.4f})")

# UMAP visualization of both embeddings with cluster labels
for adata in [true_adt_adata, pred_adt_adata]:
    # Fixed random_state plus explicit min_dist/spread for reproducible layouts
    sc.tl.umap(adata, random_state=42, min_dist=0.3, spread=1.0)

# sc.settings.set_figure_params changes matplotlib's global rcParams, so the
# plotting is wrapped in mpl.rc_context(), which restores them on exit.
# The previous save/restore called dict(sc.settings.figdir); figdir is a path,
# not a mapping, so that call raised TypeError before any figure was drawn.
with mpl.rc_context():
    # Set scanpy figure parameters to save high-quality figures
    sc.settings.set_figure_params(dpi=150, dpi_save=300, frameon=True, figsize=(7, 7))

    # Create a figure for UMAP visualization
    fig, axes = plt.subplots(1, 2, figsize=(18, 8))

    # Plot true ADT UMAP (cluster ids are written on top of the points)
    sc.pl.umap(true_adt_adata, color=f'leiden_res{best_res}', title=f'True ADT Embeddings (res={best_res})',
               show=False, legend_loc='on data', ax=axes[0], legend_fontsize=10)
    axes[0].set_aspect('equal')
    axes[0].set_xlabel('UMAP1', fontsize=12)
    axes[0].set_ylabel('UMAP2', fontsize=12)
    axes[0].grid(False)

    # Plot predicted ADT UMAP
    sc.pl.umap(pred_adt_adata, color=f'leiden_res{best_res}', title=f'Predicted ADT Embeddings (res={best_res})',
               show=False, legend_loc='on data', ax=axes[1], legend_fontsize=10)
    axes[1].set_aspect('equal')
    axes[1].set_xlabel('UMAP1', fontsize=12)
    axes[1].set_ylabel('UMAP2', fontsize=12)
    axes[1].grid(False)

    plt.tight_layout()
    plt.savefig(f'umap_comparison_res{best_res}.png', bbox_inches='tight', dpi=300)
    plt.show()

# Optional: Save the results as CSV for further analysis
results_df.to_csv('clustering_metrics.csv', index=False)

print("Analysis complete. Figures saved to disk.")

In [None]:
# Cell Type Performance Metrics: Accuracy, Sensitivity, and Precision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, accuracy_score, confusion_matrix
from scipy.stats import pearsonr

print("Calculating cell type performance metrics...")

# Use the clustering stored under the resolution with the highest ARI.
best_res = results_df.loc[results_df['ARI'].idxmax(), 'Resolution']

# Integer cluster labels for the true and the predicted embeddings
true_labels = true_adt_adata.obs[f'leiden_res{best_res}'].astype(int).values
pred_labels = pred_adt_adata.obs[f'leiden_res{best_res}'].astype(int).values

# 1. Build a mapping between predicted and true clusters using maximum overlap
# Passing the label set explicitly makes the row/column order of the matrix
# known, so matrix positions can be translated back to real cluster labels
# instead of assuming the labels are exactly 0..k-1.
cm_labels = np.union1d(true_labels, pred_labels)
cm = confusion_matrix(true_labels, pred_labels, labels=cm_labels)

# Row-normalise (per true cluster). Clusters that occur only in the
# prediction have an all-zero row; leave those rows at 0 rather than
# dividing by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm, row_totals,
                    out=np.zeros(cm.shape, dtype=float),
                    where=row_totals > 0)

# For each true cluster, find the predicted cluster with maximum overlap.
# Several true clusters may map to the same predicted cluster.
cluster_mapping = {}
for true_label, overlaps in zip(cm_labels, cm):
    if overlaps.sum() == 0:
        # No cell carries this label in the true clustering: nothing to map.
        continue
    cluster_mapping[int(true_label)] = int(cm_labels[np.argmax(overlaps)])

print(f"Cluster mapping from true to predicted: {cluster_mapping}")

# 2. Calculate metrics per cell type (cluster)
# Each true cluster is scored as a one-vs-rest binary problem against the
# predicted cluster it was mapped to. Note that the binary accuracy also
# counts cells that belong to neither cluster, so it tends to be high for
# small clusters; precision and recall are the more informative columns.
metrics = []
unique_clusters = np.unique(true_labels)

for cluster in unique_clusters:
    mapped_cluster = cluster_mapping[cluster]
    cluster_mask = (true_labels == cluster)
    cell_count = int(cluster_mask.sum())

    # Binary classification problem for this cluster
    true_binary = cluster_mask.astype(int)
    pred_binary = (pred_labels == mapped_cluster).astype(int)

    accuracy = accuracy_score(true_binary, pred_binary)
    precision = precision_score(true_binary, pred_binary, zero_division=0)
    recall = recall_score(true_binary, pred_binary, zero_division=0)

    # Mean per-dimension Pearson correlation between true and predicted
    # embeddings of the cells in this cluster (skipped for tiny clusters).
    # Assumes ground_truth and predictions are 2-D (cells x dims) arrays
    # aligned with true_labels -- TODO confirm against the cell defining them.
    avg_corr = np.nan
    if cell_count > 5:
        true_emb_cluster = ground_truth[cluster_mask]
        pred_emb_cluster = predictions[cluster_mask]

        corr_values = []
        for dim in range(true_emb_cluster.shape[1]):
            r, _ = pearsonr(true_emb_cluster[:, dim], pred_emb_cluster[:, dim])
            corr_values.append(r)

        # pearsonr yields NaN when a dimension is constant within the
        # cluster; average only the finite values so one such dimension does
        # not turn the whole cluster's correlation into NaN.
        corr_values = np.asarray(corr_values, dtype=float)
        finite_corr = corr_values[np.isfinite(corr_values)]
        if finite_corr.size:
            avg_corr = float(finite_corr.mean())

    metrics.append({
        'Cluster': int(cluster),
        'Cell Count': cell_count,
        'Accuracy': accuracy,
        'Precision': precision,
        'Recall': recall,  # Same as Sensitivity
        'Correlation': avg_corr,
        'Mapped To': mapped_cluster
    })

# Create DataFrame for visualization, largest clusters first
metrics_df = pd.DataFrame(metrics)
metrics_df = metrics_df.sort_values('Cell Count', ascending=False)

# 3. Plot metrics
plt.figure(figsize=(18, 10))

# Rank clusters by size once; both panels show the largest clusters.
ranked_metrics = metrics_df.sort_values('Cell Count', ascending=False)

# --- Top panel: grouped bars for the 15 largest clusters ---
ax1 = plt.subplot(2, 1, 1)
metrics_subset = ranked_metrics.head(15)

bar_width = 0.25
x = np.arange(len(metrics_subset))

# (DataFrame column, legend label, bar colour); position in this tuple
# determines the bar's horizontal offset within each group.
bar_series = (
    ('Accuracy', 'Accuracy', '#3498db'),
    ('Precision', 'Precision', '#2ecc71'),
    ('Recall', 'Sensitivity', '#e74c3c'),
)
for position, (column, legend_label, colour) in enumerate(bar_series):
    offset = (position - 1) * bar_width
    ax1.bar(x + offset, metrics_subset[column], bar_width,
            label=legend_label, color=colour)
    # Value label above each bar
    for slot, value in enumerate(metrics_subset[column]):
        ax1.text(slot + offset, value + 0.02, f"{value:.2f}",
                 ha='center', va='bottom', fontsize=9, rotation=0)

# Add labels and formatting
ax1.set_xlabel('Cell Cluster', fontsize=12)
ax1.set_ylabel('Score', fontsize=12)
ax1.set_title('Accuracy, Precision, and Sensitivity by Cell Cluster', fontsize=14)
ax1.set_xticks(x)
ax1.set_xticklabels([f"C{c}\n({n})" for c, n in zip(metrics_subset['Cluster'], metrics_subset['Cell Count'])])
ax1.legend(fontsize=12)
ax1.grid(axis='y', alpha=0.3)
# Scores top out at 1.0 and their labels sit 0.02 above the bar, so leave
# headroom above 1 to keep the labels inside the axes and off the title.
ax1.set_ylim(0, 1.1)

# --- Bottom panel: heatmap of all metrics for the 20 largest clusters ---
ax2 = plt.subplot(2, 1, 2)
metrics_plot = ranked_metrics.head(20)
metrics_heatmap = metrics_plot[['Accuracy', 'Precision', 'Recall', 'Correlation']].copy()
metrics_heatmap.index = [f"C{c} ({n})" for c, n in zip(metrics_plot['Cluster'], metrics_plot['Cell Count'])]

sns.heatmap(metrics_heatmap.T, annot=True, fmt='.2f', cmap='viridis', 
            linewidths=0.5, ax=ax2, cbar_kws={'label': 'Score'})
ax2.set_title('Performance Metrics Heatmap (Top 20 Clusters by Cell Count)', fontsize=14)
ax2.set_ylabel('Metric', fontsize=12)
ax2.set_xlabel('Cell Cluster (cell count)', fontsize=12)

plt.tight_layout()
plt.savefig('celltype_performance_metrics.png', dpi=300, bbox_inches='tight')
plt.show()

# 4. Create scatter plot showing relationship between cluster size and performance
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Shared colour scale for the correlation-coloured panels. The same norm is
# passed to every scatter call and to the colorbar so that point colours and
# the colorbar are guaranteed to agree.
cmap = plt.cm.viridis
norm = plt.Normalize(metrics_df['Correlation'].min(), metrics_df['Correlation'].max())

# (axes, metric column, y-axis label, panel title) for the three panels that
# plot a metric against cluster size.
size_panels = (
    (axes[0, 0], 'Accuracy', 'Accuracy', 'Accuracy vs Cluster Size'),
    (axes[0, 1], 'Precision', 'Precision', 'Precision vs Cluster Size'),
    (axes[1, 0], 'Recall', 'Sensitivity (Recall)', 'Sensitivity vs Cluster Size'),
)
for ax, column, y_label, panel_title in size_panels:
    ax.scatter(metrics_df['Cell Count'], metrics_df[column],
               c=metrics_df['Correlation'], cmap=cmap, norm=norm,
               s=100, alpha=0.7)
    ax.set_xlabel('Cluster Size (Cell Count)', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(panel_title, fontsize=14)
    ax.grid(alpha=0.3)

# Plot Correlation vs Accuracy, coloured by cluster size
scatter = axes[1, 1].scatter(metrics_df['Accuracy'], metrics_df['Correlation'],
                             c=metrics_df['Cell Count'], cmap='plasma', s=100, alpha=0.7)
axes[1, 1].set_xlabel('Accuracy', fontsize=12)
axes[1, 1].set_ylabel('Average Correlation', fontsize=12)
axes[1, 1].set_title('Correlation vs Accuracy', fontsize=14)
axes[1, 1].grid(alpha=0.3)

# Lay out the four subplots first, reserving the right 10% of the figure.
# This has to happen before the colorbar axes are added: axes created with
# add_axes are not managed by tight_layout and make it emit a warning.
fig.tight_layout(rect=[0, 0, 0.9, 1])

# Add a colorbar for correlation in the first 3 plots
cbar_ax1 = fig.add_axes([0.92, 0.55, 0.02, 0.3])  # [left, bottom, width, height]
cb1 = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax1)
cb1.set_label('Correlation', fontsize=12)

# Add a colorbar for cluster size in the last plot
cbar_ax2 = fig.add_axes([0.92, 0.15, 0.02, 0.3])  # [left, bottom, width, height]
cb2 = fig.colorbar(scatter, cax=cbar_ax2)
cb2.set_label('Cluster Size', fontsize=12)

plt.savefig('celltype_metrics_scatter.png', dpi=300, bbox_inches='tight')
plt.show()

# 5. Summary statistics
print("\nSummary Statistics for Cell Type Performance:")
# (printed name, DataFrame column); pandas' mean() skips NaN correlations.
summary_columns = (
    ('Accuracy', 'Accuracy'),
    ('Precision', 'Precision'),
    ('Sensitivity', 'Recall'),
    ('Correlation', 'Correlation'),
)
for printed_name, column in summary_columns:
    print(f"Overall Mean {printed_name}: {metrics_df[column].mean():.4f}")

# Best and worst performing clusters (by accuracy)
best_cluster = metrics_df.loc[metrics_df['Accuracy'].idxmax()]
worst_cluster = metrics_df.loc[metrics_df['Accuracy'].idxmin()]

# Selecting a single row from a frame that mixes integer and float columns
# yields a float Series, so the cluster id and cell count are cast back to
# int; otherwise they print as e.g. "Cluster 3.0 with 120.0 cells".
for heading, cluster_row in (("Best", best_cluster), ("Worst", worst_cluster)):
    print(f"\n{heading} performing cluster: Cluster {int(cluster_row['Cluster'])} with {int(cluster_row['Cell Count'])} cells")
    print(f"  Accuracy: {cluster_row['Accuracy']:.4f}, Precision: {cluster_row['Precision']:.4f}, Sensitivity: {cluster_row['Recall']:.4f}")

# Save metrics to CSV
metrics_df.to_csv('celltype_performance_metrics.csv', index=False)
print("\nMetrics saved to 'celltype_performance_metrics.csv'")