# MLP Mapping between GAT Embeddings of RNA and ADT

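This notebook trains separate Graph Attention Network (GAT) models on the RNA and ADT (antibody-derived tag) measurements of the same cells, then learns a mapping from the RNA GAT embeddings to the ADT GAT embeddings using a Multi-Layer Perceptron (MLP).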

In [None]:
# Memory optimization and system check
import torch
import os

# The CUDA allocator configuration is expected to be read on first CUDA use, so it is
# set here before any CUDA work in this cell.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Report GPU (or CPU) resources and give rough guidance based on total GPU memory.
print("=== System Resources ===")
if not torch.cuda.is_available():
    print("CUDA not available - will use CPU")
    print("Note: Training will be slower but should work with larger graphs")
else:
    device = torch.cuda.current_device()
    gpu_props = torch.cuda.get_device_properties(device)
    bytes_per_gb = 1024 ** 3
    total_memory = gpu_props.total_memory / bytes_per_gb

    for line in (
        f"GPU: {gpu_props.name}",
        f"Total GPU Memory: {total_memory:.1f} GB",
        f"GPU Compute Capability: {gpu_props.major}.{gpu_props.minor}",
    ):
        print(line)

    # Release cached blocks so the figures below reflect live allocations.
    torch.cuda.empty_cache()

    allocated = torch.cuda.memory_allocated(device) / bytes_per_gb
    reserved = torch.cuda.memory_reserved(device) / bytes_per_gb
    print(f"Currently allocated: {allocated:.2f} GB")
    print(f"Currently reserved: {reserved:.2f} GB")
    # NOTE(review): total minus what *this* process has reserved; memory held by
    # other processes on the same GPU is not subtracted.
    print(f"Available: {total_memory - reserved:.2f} GB")

    if total_memory >= 16:
        print("\n✅ Sufficient GPU memory available")
    elif total_memory >= 8:
        print("\n💡 Moderate GPU memory - will use optimized settings")
    else:
        print("\n⚠️  WARNING: Low GPU memory detected!")
        print("Recommendations:")
        for tip in ("- Use CPU fallback if needed", "- Reduce batch sizes", "- Use graph sparsification"):
            print(tip)

print("=" * 50)

In [None]:
%load_ext autoreload
%autoreload 2

# Set environment variables for better memory management
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
import pandas as pd

import scanpy as sc
import scanpy.external as sce
from scipy import sparse

from DeepOMAPNet.Preprocess import prepare_train_test_anndata

# Report the CUDA setup, or note that the notebook will run on CPU.
# NOTE(review): the memory figure is read from device 0 while the name comes from
# the current device; these coincide on single-GPU machines.
if not torch.cuda.is_available():
    print("CUDA not available, using CPU")
else:
    torch.cuda.empty_cache()
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name()}")
    device0_total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU memory: {device0_total_gb:.1f} GB")

## 1. Load and Prepare Data

In [None]:
# prepare_train_test_anndata() returns an indexable result; position 0 holds the
# RNA AnnData and position 2 the ADT AnnData (other positions are unused here).
data = prepare_train_test_anndata()
trainGene, trainADT = data[0], data[2]

for modality, adata_obj in (("RNA", trainGene), ("ADT", trainADT)):
    print(f"{modality} data shape: {adata_obj.shape}")

## 2. Preprocess RNA Data

In [None]:
# RNA: library-size normalise, log-transform, keep 2000 batch-aware HVGs.
sc.pp.normalize_total(trainGene, target_sum=1e4)
sc.pp.log1p(trainGene)
sc.pp.highly_variable_genes(trainGene, n_top_genes=2000, batch_key="samples")
hvg_mask = trainGene.var.highly_variable
trainGene = trainGene[:, hvg_mask].copy()

# Scale (clipped at 10) and project to 50 PCs; the PCs feed both the kNN graph
# below and the GAT node features later on.
sc.pp.scale(trainGene, max_value=10)
sc.tl.pca(trainGene, n_comps=50, svd_solver="arpack")

# kNN graph on the PCs, then Leiden clusters used as GAT training labels.
sc.pp.neighbors(trainGene, n_neighbors=15, n_pcs=50)
sc.tl.leiden(trainGene, resolution=1.0)

n_rna_clusters = trainGene.obs['leiden'].nunique()
print(f"RNA data after preprocessing: {trainGene.shape}")
print(f"Number of RNA clusters: {n_rna_clusters}")

## 3. Preprocess ADT Data

In [None]:
# ADT preprocessing: library-size normalisation + log1p + scaling, mirroring RNA.
# NOTE(review): CLR normalisation is the more common choice for ADT counts; the
# original normalisation is kept so downstream results are unchanged.
sc.pp.normalize_total(trainADT, target_sum=1e4)
sc.pp.log1p(trainADT)
sc.pp.scale(trainADT, max_value=10)

# ADT panels frequently have fewer than 50 proteins, and ARPACK PCA requires
# n_comps < min(n_obs, n_vars). Clamp so small panels do not crash; for panels
# with more than 50 proteins this is still 50.
adt_n_comps = min(50, trainADT.n_vars - 1, trainADT.n_obs - 1)
sc.tl.pca(trainADT, n_comps=adt_n_comps, svd_solver="arpack")

# Build the ADT neighbour graph with BBKNN (batch-balanced kNN) for batch correction,
# using exactly the PCs computed above.
sce.pp.bbknn(
    trainADT,
    batch_key='samples',
    n_pcs=adt_n_comps,
    neighbors_within_batch=3,
    trim=0,
)

sc.tl.leiden(trainADT, resolution=1.0)

print(f"ADT data after preprocessing: {trainADT.shape}")
print(f"Number of ADT clusters: {trainADT.obs['leiden'].nunique()}")

## 4. Build PyTorch Geometric Data Objects

In [None]:
def sparsify_graph(adata, max_edges_per_node=50):
    """Keep only the ``max_edges_per_node`` strongest neighbours of each node.

    Operates on ``adata.obsp["connectivities"]``. For every row with more than
    ``max_edges_per_node`` stored entries, only the largest weights are kept. The
    pruned matrix is then re-symmetrised as ``(A + A.T) / 2``, and the AnnData is
    updated in place.

    Because symmetrisation is a union, an edge kept by only one endpoint survives
    at half weight. Some nodes can therefore still exceed the cap after
    symmetrisation.

    Parameters
    ----------
    adata : AnnData-like
        Object exposing ``obsp["connectivities"]`` as a square scipy sparse matrix.
    max_edges_per_node : int
        Per-row cap; must be >= 1.

    Returns
    -------
    The same ``adata`` object (modified in place if any row exceeded the cap).

    Raises
    ------
    ValueError
        If ``max_edges_per_node`` < 1.
    """
    if max_edges_per_node < 1:
        # With k=0, argpartition(...)[-0:] would select every entry, silently keeping
        # the whole graph.
        raise ValueError("max_edges_per_node must be >= 1")

    A = adata.obsp["connectivities"].tocsr()
    n_nodes = A.shape[0]
    if n_nodes == 0:
        print("Graph is empty; nothing to sparsify")
        return adata

    avg_degree = A.nnz / n_nodes
    row_degree = np.diff(A.indptr)
    # Decide on the *largest* row, not the average: a low average can hide hub nodes
    # that still exceed the per-node cap.
    if row_degree.max() <= max_edges_per_node:
        print(f"Graph already sparse enough (avg degree: {avg_degree:.1f})")
        return adata

    print(f"Sparsifying graph from avg degree {avg_degree:.1f} to max {max_edges_per_node}")

    kept_rows, kept_cols, kept_vals = [], [], []
    for i in range(n_nodes):
        start, end = A.indptr[i], A.indptr[i + 1]
        neighbors = A.indices[start:end]
        weights = A.data[start:end]

        if len(neighbors) > max_edges_per_node:
            # Indices of the k largest weights (unordered).
            top_k = np.argpartition(weights, -max_edges_per_node)[-max_edges_per_node:]
            neighbors = neighbors[top_k]
            weights = weights[top_k]

        kept_rows.append(np.full(len(neighbors), i, dtype=np.int64))
        kept_cols.append(neighbors)
        kept_vals.append(weights)

    A_sparse = sparse.csr_matrix(
        (np.concatenate(kept_vals), (np.concatenate(kept_rows), np.concatenate(kept_cols))),
        shape=(n_nodes, n_nodes),
    )

    # Re-symmetrise: row-wise pruning breaks symmetry.
    A_sparse = (A_sparse + A_sparse.T) / 2

    adata.obsp["connectivities"] = A_sparse

    new_avg_degree = A_sparse.nnz / n_nodes
    print(f"New average degree: {new_avg_degree:.1f}")

    return adata

def build_pyg_data(adata, use_pca=True, sparsify_large_graphs=True, max_edges_per_node=50):
    """Build an undirected PyTorch Geometric ``Data`` object from an AnnData.

    Parameters
    ----------
    adata : AnnData
        Must contain ``obsp["connectivities"]`` and ``obs["leiden"]``. It must also
        contain ``obsm["X_pca"]`` when ``use_pca`` is True.
    use_pca : bool
        Use ``obsm["X_pca"]`` as node features; otherwise use ``adata.X``.
    sparsify_large_graphs : bool
        When the average degree exceeds ``max_edges_per_node``, prune the graph
        with ``sparsify_graph`` first. This modifies ``adata`` in place.
    max_edges_per_node : int
        Degree threshold/cap passed to ``sparsify_graph``.

    Returns
    -------
    torch_geometric.data.Data
        - ``x``: float32 node features.
        - ``edge_index``: long tensor of shape (2, E) containing every neighbour
          pair in *both* directions, without self-loops.
        - ``y``: long Leiden cluster labels.
    """
    if sparsify_large_graphs:
        A = adata.obsp["connectivities"]
        avg_degree = A.nnz / max(A.shape[0], 1)
        if avg_degree > max_edges_per_node:
            print(f"Large graph detected (avg degree: {avg_degree:.1f}), applying sparsification...")
            adata = sparsify_graph(adata, max_edges_per_node)

    if use_pca:
        X = adata.obsm["X_pca"]
    else:
        # sc.pp.scale leaves adata.X dense, so only call .toarray() on sparse input.
        X = adata.X.toarray() if sparse.issparse(adata.X) else np.asarray(adata.X)

    # Leiden labels are stored as categorical strings ("0", "1", ...).
    y = adata.obs["leiden"].astype(int).to_numpy()

    # PyG reads edge_index as directed (row 0 = source, row 1 = target). Using only
    # the upper triangle would let node i receive messages only from nodes j < i,
    # so emit both directions. Taking the max with the transpose guarantees
    # symmetry, and self-loops are dropped because GATConv adds its own by default.
    A = adata.obsp["connectivities"].tocsr()
    A = A.maximum(A.T)
    row, col = A.nonzero()
    off_diag = row != col
    edge_index = torch.tensor(np.vstack([row[off_diag], col[off_diag]]), dtype=torch.long)

    data = Data(
        x=torch.tensor(X, dtype=torch.float32),
        edge_index=edge_index,
        y=torch.tensor(y, dtype=torch.long),
    )

    return data

# Choose per-node edge caps from the graph sizes, then build the PyG graphs.
print("Building PyG data objects...")

if not torch.cuda.is_available():
    print("Using CPU - no memory constraints")
    max_edges_rna, max_edges_adt = 200, 100
else:
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"Available GPU memory: {gpu_memory:.1f} GB")

    # Stored-entry counts of the connectivity matrices serve as a memory proxy.
    rna_edges = trainGene.obsp["connectivities"].nnz
    adt_edges = trainADT.obsp["connectivities"].nnz
    print(f"RNA graph edges: {rna_edges:,}")
    print(f"ADT graph edges: {adt_edges:,}")

    # Larger graphs get a tighter per-node neighbour cap.
    max_edges_rna = 200 if rna_edges <= 5000000 else 100
    max_edges_adt = 100 if adt_edges <= 10000000 else 50
    print(f"Using max edges per node - RNA: {max_edges_rna}, ADT: {max_edges_adt}")

rna_data = build_pyg_data(
    trainGene, use_pca=True, sparsify_large_graphs=True, max_edges_per_node=max_edges_rna
)
adt_data = build_pyg_data(
    trainADT, use_pca=True, sparsify_large_graphs=True, max_edges_per_node=max_edges_adt
)

for modality, graph in (("RNA", rna_data), ("ADT", adt_data)):
    print(f"{modality} PyG data - Nodes: {graph.num_nodes}, Edges: {graph.num_edges}, Features: {graph.num_node_features}")

## 5. Define GAT Model

In [None]:
class SimpleGAT(torch.nn.Module):
    """Single-layer GAT used when the graph is too large for the two-layer model.

    ``hidden_channels`` is accepted only so the constructor signature matches
    ``GAT``; it is not used. The single ``GATConv`` is built with
    ``concat=False`` (heads are averaged), so the output has ``out_channels``
    features. That same output is what ``get_embeddings`` returns.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GATConv(in_channels, out_channels, heads=heads, dropout=dropout, concat=False)

    def forward(self, x, edge_index, return_embeddings=False):
        # A single layer has no intermediate representation, so return_embeddings
        # does not change the result.
        dropped = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv1(dropped, edge_index)

    def get_embeddings(self, x, edge_index):
        """Return node embeddings (identical to the forward output for this model)."""
        return self.forward(x, edge_index, return_embeddings=True)

class GAT(torch.nn.Module):
    """Two-layer GAT classifier whose hidden layer supplies the embeddings.

    Layer 1 is a ``heads``-head GATConv producing ``hidden_channels * heads``
    features, followed by ELU. Layer 2 is single-head and produces
    ``out_channels`` logits. Dropout is applied to the input and to the hidden
    representation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, dropout=dropout)

    def _hidden(self, x, edge_index):
        """Input dropout -> conv1 -> ELU -> dropout (the embedding representation)."""
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = F.elu(self.conv1(h, edge_index))
        return F.dropout(h, p=self.dropout, training=self.training)

    def forward(self, x, edge_index, return_embeddings=False):
        h = self._hidden(x, edge_index)
        if return_embeddings:
            return h
        return self.conv2(h, edge_index)

    def get_embeddings(self, x, edge_index):
        """Return the hidden-layer embeddings (before the final GATConv)."""
        return self.forward(x, edge_index, return_embeddings=True)

## 6. Train GAT Models

In [None]:
def _make_split_masks(y_np, seed=42):
    """Return stratified 80/10/10 (train, val, test) boolean masks over ``len(y_np)`` nodes.

    The 20% hold-out is halved into val/test with a second split seeded ``seed + 1``.
    StratifiedShuffleSplit raises ValueError if any class has fewer than 2 members.
    """
    from sklearn.model_selection import StratifiedShuffleSplit

    N = len(y_np)
    sss1 = StratifiedShuffleSplit(n_splits=1, train_size=0.8, random_state=seed)
    train_idx, temp_idx = next(sss1.split(np.zeros(N), y_np))

    sss2 = StratifiedShuffleSplit(n_splits=1, train_size=0.5, random_state=seed + 1)
    val_rel, test_rel = next(sss2.split(np.zeros(len(temp_idx)), y_np[temp_idx]))

    masks = []
    for idx in (train_idx, temp_idx[val_rel], temp_idx[test_rel]):
        mask = torch.zeros(N, dtype=torch.bool)
        mask[idx] = True
        masks.append(mask)
    return tuple(masks)


def _clone_state_dict(model):
    """Return an independent copy of ``model.state_dict()``.

    ``state_dict().copy()`` only copies the dict. Its tensors still share storage
    with the live parameters, so later optimizer steps would overwrite the
    snapshot.
    """
    return {k: v.detach().clone() for k, v in model.state_dict().items()}


def train_gat_model(data, model_name="GAT", epochs=200, use_cpu_fallback=False):
    """Train a GAT node classifier on the graph's Leiden labels.

    Parameters
    ----------
    data : torch_geometric.data.Data
        Graph with ``x``, ``edge_index`` and integer labels ``y``. Stratified
        ``train_mask``/``val_mask``/``test_mask`` attributes are attached to it.
    model_name : str
        Label used in log messages.
    epochs : int
        Number of full-graph optimisation steps.
    use_cpu_fallback : bool
        Force CPU even when CUDA is available.

    Returns
    -------
    tuple
        ``(model, data)``: the model with its best-validation checkpoint loaded,
        and the data object as used for training (on the training device).

    Raises
    ------
    RuntimeError
        Non-OOM runtime errors propagate. A CUDA OOM triggers a single switch to CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() and not use_cpu_fallback else 'cpu')
    print(f"Using device: {device}")

    num_edges = data.num_edges
    num_nodes = data.num_nodes
    print(f"Graph stats - Nodes: {num_nodes}, Edges: {num_edges}")

    # Shrink the model as the edge count grows to limit attention memory.
    use_simple_model = False
    if num_edges > 2000000:  # more than 2M edges
        print("Very large graph detected, using simplified GAT architecture...")
        hidden_dim, heads = 32, 4
        use_simple_model = True
    elif num_edges > 1000000:  # more than 1M edges
        print("Large graph detected, reducing model complexity...")
        hidden_dim, heads = 32, 4
    else:
        hidden_dim, heads = 64, 8

    data.train_mask, data.val_mask, data.test_mask = _make_split_masks(data.y.cpu().numpy())

    in_dim = data.x.size(1)
    n_class = int(data.y.max().item() + 1)

    if use_simple_model:
        model = SimpleGAT(in_dim, hidden_dim, n_class, heads=heads).to(device)
        print(f"Using SimpleGAT: {in_dim} -> {n_class} (hidden: {hidden_dim}, heads: {heads})")
    else:
        model = GAT(in_dim, hidden_dim, n_class, heads=heads).to(device)
        print(f"Using GAT: {in_dim} -> {hidden_dim} -> {n_class} (heads: {heads})")

    cpu_fallback_triggered = False
    try:
        data = data.to(device)
        print(f"Successfully moved data to {device}")
    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise
        print(f"GPU memory insufficient, falling back to CPU...")
        device = torch.device('cpu')
        model = model.cpu()
        data = data.cpu()
        cpu_fallback_triggered = True

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    def _step():
        """One full-graph forward/backward/update on the current model/data/optimizer."""
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        return loss

    def train():
        """Run one training step and return the loss as a Python float; fall back to CPU once on OOM."""
        nonlocal model, data, optimizer, device, cpu_fallback_triggered

        model.train()
        try:
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            loss = _step()
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            # .item() so the caller does not keep the autograd graph alive.
            return loss.item()
        except RuntimeError as e:
            if "out of memory" not in str(e).lower() or cpu_fallback_triggered:
                raise
            print(f"GPU OOM during training, switching to CPU...")
            device = torch.device('cpu')
            model = model.cpu()
            data = data.cpu()
            # Fresh optimizer bound to the (now CPU) parameters.
            optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
            cpu_fallback_triggered = True
            return _step().item()

    def test(mask):
        """Accuracy of the current model on the nodes selected by ``mask``."""
        model.eval()
        with torch.no_grad():
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            pred = model(data.x, data.edge_index).argmax(dim=1)
            correct = pred[mask] == data.y[mask]
            return int(correct.sum()) / int(mask.sum())

    print(f"Training {model_name} model...")
    best_val_acc = 0
    best_model_state = None

    for epoch in range(1, epochs + 1):
        loss = train()

        # Also evaluate the final epoch so it can be selected as the best checkpoint.
        if epoch % 50 == 0 or epoch == 1 or epoch == epochs:
            val_acc = test(data.val_mask)
            test_acc = test(data.test_mask)
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_model_state = _clone_state_dict(model)

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    final_test_acc = test(data.test_mask)
    print(f"Final {model_name} test accuracy: {final_test_acc:.4f}")

    return model, data

In [None]:
def train_with_cpu_fallback(data, model_name, epochs=200):
    """Train via ``train_gat_model``; on a CUDA out-of-memory error, retry once on CPU.

    Non-OOM ``RuntimeError``s propagate unchanged. Returns ``(model, data)`` exactly
    as ``train_gat_model`` does.
    """
    try:
        return train_gat_model(data, model_name, epochs=epochs)
    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise
        print(f"GPU memory insufficient for {model_name}, trying CPU...")
        # Release whatever the failed attempt left in the CUDA cache before retrying.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return train_gat_model(data, model_name, epochs=epochs, use_cpu_fallback=True)


print("=== Training RNA GAT ===")
rna_gat_model, rna_data_with_masks = train_with_cpu_fallback(rna_data, "RNA GAT", epochs=200)

print("\n=== Training ADT GAT ===")
adt_gat_model, adt_data_with_masks = train_with_cpu_fallback(adt_data, "ADT GAT", epochs=200)

print("\n=== GAT Training Complete ===")
print("RNA GAT model trained successfully")
print("ADT GAT model trained successfully")

## 7. Extract GAT Embeddings

In [None]:
def extract_embeddings(model, data):
    """Return ``model.get_embeddings(data.x, data.edge_index)`` as a CPU tensor.

    The model is switched to eval mode and the forward pass runs under
    ``torch.no_grad()``. If ``data.x`` lives on a different device than the model's
    parameters, ``data`` is moved to the model's device first. On CUDA the
    allocator cache is emptied before and after the forward pass.
    """
    model.eval()
    target_device = next(model.parameters()).device

    if data.x.device != target_device:
        print(f"Moving data from {data.x.device} to {target_device}")
        data = data.to(target_device)

    on_gpu = target_device.type == 'cuda'
    with torch.no_grad():
        if on_gpu:
            torch.cuda.empty_cache()
        emb_cpu = model.get_embeddings(data.x, data.edge_index).cpu()
        if on_gpu:
            torch.cuda.empty_cache()

    return emb_cpu

if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Extracting embeddings...")
rna_embeddings = extract_embeddings(rna_gat_model, rna_data_with_masks)
adt_embeddings = extract_embeddings(adt_gat_model, adt_data_with_masks)

print(f"RNA embeddings shape: {rna_embeddings.shape}")
print(f"ADT embeddings shape: {adt_embeddings.shape}")

# The MLP below pairs RNA row i with ADT row i. Equal row counts are necessary but
# not sufficient: the two AnnData objects must list the same cells in the same order.
if rna_embeddings.shape[0] != adt_embeddings.shape[0]:
    raise ValueError("Number of cells must match")
if not trainGene.obs_names.equals(trainADT.obs_names):
    raise ValueError("RNA and ADT AnnData objects must contain the same cells in the same order")

## 8. Define MLP Mapping Model

In [None]:
class MLPMapping(nn.Module):
    """Feed-forward regressor from RNA-embedding space to ADT-embedding space.

    Each width in ``hidden_dims`` adds ``Linear -> BatchNorm1d -> ReLU -> Dropout``.
    A final ``Linear`` projects to ``output_dim`` with no activation.

    Parameters
    ----------
    input_dim : int
        Size of the input (RNA embedding) features.
    output_dim : int
        Size of the predicted (ADT embedding) features.
    hidden_dims : sequence of int, default (256, 128, 64)
        Hidden-layer widths. An empty sequence gives a single linear layer.
        The default is a tuple rather than a list to avoid a shared mutable default.
    dropout : float
        Dropout probability after each hidden layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(256, 128, 64), dropout=0.3):
        super().__init__()

        layers = []
        current_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            current_dim = hidden_dim

        layers.append(nn.Linear(current_dim, output_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of inputs (expected shape ``(batch, input_dim)``) to ``(batch, output_dim)``."""
        return self.mlp(x)

# Instantiate the RNA -> ADT embedding mapper on the best available device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

input_dim, output_dim = rna_embeddings.shape[1], adt_embeddings.shape[1]
mlp_model = MLPMapping(input_dim, output_dim, hidden_dims=[256, 128, 64]).to(device)

print(f"MLP Model: {input_dim} -> {output_dim}")
print(mlp_model)

## 9. Prepare Training Data for MLP

In [None]:
# Convert embeddings to CPU and numpy
rna_emb_np = rna_embeddings.cpu().numpy()
adt_emb_np = adt_embeddings.cpu().numpy()

# Reuse the RNA GAT's train/val/test masks, so MLP test cells were never in the RNA
# GAT's training split.
# NOTE(review): the ADT GAT drew its own split (stratified on ADT clusters), so some
# MLP test cells may have been ADT-GAT training nodes.
train_mask_np = rna_data_with_masks.train_mask.cpu().numpy()
val_mask_np = rna_data_with_masks.val_mask.cpu().numpy()
test_mask_np = rna_data_with_masks.test_mask.cpu().numpy()

X_train = torch.tensor(rna_emb_np[train_mask_np], dtype=torch.float32)
y_train = torch.tensor(adt_emb_np[train_mask_np], dtype=torch.float32)

X_val = torch.tensor(rna_emb_np[val_mask_np], dtype=torch.float32)
y_val = torch.tensor(adt_emb_np[val_mask_np], dtype=torch.float32)

X_test = torch.tensor(rna_emb_np[test_mask_np], dtype=torch.float32)
y_test = torch.tensor(adt_emb_np[test_mask_np], dtype=torch.float32)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Validation set: {X_val.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")

batch_size = 128
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

# MLPMapping uses BatchNorm1d, which raises on a single-sample batch in training
# mode. Drop the final training batch only when it would hold exactly one sample.
drop_last_train = len(train_dataset) % batch_size == 1

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last_train)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## 10. Train MLP Mapping Model

In [None]:
# Training parameters
num_epochs = 300
learning_rate = 0.001
weight_decay = 1e-5

# Loss function, optimizer and plateau-based LR schedule
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(mlp_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20)

train_losses = []
val_losses = []
best_val_loss = float('inf')
best_model_state = None
patience = 50
patience_counter = 0

print("Training MLP mapping model...")

for epoch in range(num_epochs):
    # Training phase. Losses are weighted by batch size so a short final batch
    # does not count as much as a full one.
    mlp_model.train()
    train_loss_sum, train_count = 0.0, 0
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        outputs = mlp_model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        train_loss_sum += loss.item() * batch_x.size(0)
        train_count += batch_x.size(0)

    train_loss = train_loss_sum / max(train_count, 1)
    train_losses.append(train_loss)

    # Validation phase (same sample weighting)
    mlp_model.eval()
    val_loss_sum, val_count = 0.0, 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            outputs = mlp_model(batch_x)
            loss = criterion(outputs, batch_y)
            val_loss_sum += loss.item() * batch_x.size(0)
            val_count += batch_x.size(0)

    val_loss = val_loss_sum / max(val_count, 1)
    val_losses.append(val_loss)

    scheduler.step(val_loss)

    # Early stopping with an independent snapshot of the best weights.
    # state_dict().copy() would alias the live tensors (including BatchNorm buffers).
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = {k: v.detach().clone() for k, v in mlp_model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1

    if (epoch + 1) % 25 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')

    if patience_counter >= patience:
        print(f'Early stopping at epoch {epoch+1}')
        break

# Restore the best checkpoint; the guard covers the case where no epoch improved
# (e.g. all validation losses were NaN).
if best_model_state is not None:
    mlp_model.load_state_dict(best_model_state)
print(f'Best validation loss: {best_val_loss:.6f}')

## 11. Evaluate MLP Model

In [None]:
# Evaluate the mapping on the held-out test cells.
mlp_model.eval()
test_loss_sum, test_count = 0.0, 0
predictions = []
ground_truth = []

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        outputs = mlp_model(batch_x)
        loss = criterion(outputs, batch_y)
        # Weight by batch size so the reported loss is the true per-sample mean.
        test_loss_sum += loss.item() * batch_x.size(0)
        test_count += batch_x.size(0)

        predictions.append(outputs.cpu().numpy())
        ground_truth.append(batch_y.cpu().numpy())

test_loss = test_loss_sum / max(test_count, 1)
predictions = np.vstack(predictions)
ground_truth = np.vstack(ground_truth)

mse = mean_squared_error(ground_truth, predictions)
r2 = r2_score(ground_truth, predictions)

# Per-dimension correlations. A constant column makes pearsonr/spearmanr return NaN,
# so summarise with nanmean and report how many dimensions were excluded.
pearson_corrs = []
spearman_corrs = []
for i in range(ground_truth.shape[1]):
    pearson_r, _ = pearsonr(ground_truth[:, i], predictions[:, i])
    spearman_r, _ = spearmanr(ground_truth[:, i], predictions[:, i])
    pearson_corrs.append(pearson_r)
    spearman_corrs.append(spearman_r)

mean_pearson = np.nanmean(pearson_corrs)
mean_spearman = np.nanmean(spearman_corrs)
n_undefined_dims = int(np.isnan(np.asarray(pearson_corrs, dtype=float)).sum())

print(f"\n=== MLP Mapping Results ===")
print(f"Test Loss (MSE): {test_loss:.6f}")
print(f"MSE: {mse:.6f}")
print(f"R² Score: {r2:.4f}")
print(f"Mean Pearson Correlation: {mean_pearson:.4f}")
print(f"Mean Spearman Correlation: {mean_spearman:.4f}")
if n_undefined_dims:
    print(f"Note: {n_undefined_dims} dimension(s) had undefined Pearson correlation and were excluded from the means")

## 12. Visualize Results

In [None]:
# Left: loss curves. Right: distribution of per-dimension test correlations.
fig, (ax_loss, ax_corr) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(train_losses, label='Training Loss')
ax_loss.plot(val_losses, label='Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('MSE Loss')
ax_loss.set_title('Training and Validation Loss')
ax_loss.legend()
ax_loss.set_yscale('log')

# Plot only finite correlations: NaNs (from constant dimensions) can make the
# automatic histogram bin range undefined and raise.
finite_pearson = [c for c in pearson_corrs if np.isfinite(c)]
finite_spearman = [c for c in spearman_corrs if np.isfinite(c)]
ax_corr.hist(finite_pearson, bins=20, alpha=0.7, label='Pearson')
ax_corr.hist(finite_spearman, bins=20, alpha=0.7, label='Spearman')
ax_corr.set_xlabel('Correlation')
ax_corr.set_ylabel('Frequency')
ax_corr.set_title('Per-dimension Correlation Distribution')
ax_corr.legend()

fig.tight_layout()
plt.show()

In [None]:
# True vs. predicted values for up to the first six ADT embedding dimensions.
n_show = min(6, ground_truth.shape[1])
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, ax in enumerate(axes[:n_show]):
    ax.scatter(ground_truth[:, i], predictions[:, i], alpha=0.6, s=1)

    # y = x reference line (perfect prediction)
    min_val = min(ground_truth[:, i].min(), predictions[:, i].min())
    max_val = max(ground_truth[:, i].max(), predictions[:, i].max())
    ax.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8)

    ax.set_xlabel(f'True ADT Embedding Dim {i+1}')
    ax.set_ylabel(f'Predicted ADT Embedding Dim {i+1}')
    ax.set_title(f'Dim {i+1}: r={pearson_corrs[i]:.3f}')

# Hide empty panels when there are fewer than six dimensions.
for ax in axes[n_show:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

## 13. Save Models and Results

In [None]:
# Metrics are stored as plain Python floats. torch.load defaults to
# weights_only=True in recent PyTorch releases, and numpy scalars are not on its
# allow-list, so keeping them would make this checkpoint fail to reload.
test_results = {
    'mse': float(mse),
    'r2': float(r2),
    'mean_pearson': float(mean_pearson),
    'mean_spearman': float(mean_spearman),
    'pearson_corrs': [float(c) for c in pearson_corrs],
    'spearman_corrs': [float(c) for c in spearman_corrs],
}

torch.save({
    'rna_gat_state_dict': rna_gat_model.state_dict(),
    'adt_gat_state_dict': adt_gat_model.state_dict(),
    'mlp_mapping_state_dict': mlp_model.state_dict(),
    'rna_input_dim': int(input_dim),
    'adt_output_dim': int(output_dim),
    'test_results': test_results,
}, 'rna_adt_mapping_models.pth')

print("Models and results saved to 'rna_adt_mapping_models.pth'")

# Save predictions for further analysis
np.savez('mapping_predictions.npz',
         predictions=predictions,
         ground_truth=ground_truth,
         pearson_corrs=pearson_corrs,
         spearman_corrs=spearman_corrs)

print("Predictions saved to 'mapping_predictions.npz'")

## Summary

This notebook implements a pipeline to learn mappings between GAT embeddings of RNA and ADT data:

1. **Data Preprocessing**: Both RNA and ADT data are normalized, scaled, and processed to create neighbor graphs
2. **GAT Training**: Separate GAT models are trained on RNA and ADT data for node classification
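3. **Embedding Extraction**: Node embeddings are extracted from the trained GAT models: the hidden-layer output of the two-layer GAT, or the output of the single-layer SimpleGAT used for very large graphs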
4. **MLP Mapping**: A multi-layer perceptron learns to map RNA embeddings to ADT embeddings
5. **Evaluation**: The mapping quality is assessed using MSE, R², and correlation metrics

The trained models can be used to predict ADT embeddings from RNA data, enabling cross-modal analysis and integration.