# Testing Trained Models on New Single Cell Dataset

This notebook demonstrates how to load a new single-cell RNA dataset and run it through the trained RNA-to-ADT transformer mapping models.

## Overview
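1. Set up paths and imports
2. Load new single-cell RNA data
3. Load pre-trained models
4. Preprocess new data
5. Extract embeddings and make predictions
6. Evaluate performance (if ground truth is available)
7. Visualize results
8. Save results
9. Summarize the run and next steps
10. Reusable helper for testing future datasets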


## 1. Setup and Imports


In [None]:
import sys, os, importlib

# --- Autoreload ---
%load_ext autoreload
%autoreload 2

# --- Paths ---
# If the working directory path contains "Notebooks", the project root is
# taken to be one level up; otherwise the working directory is the root.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir) if 'Notebooks' in current_dir else current_dir
scripts_path = os.path.join(parent_dir, 'scripts')

# Register the root first, then the scripts folder, skipping entries that
# are already present on sys.path.
for _entry in (parent_dir, scripts_path):
    if _entry not in sys.path:
        sys.path.append(_entry)

print("Added to Python path:")
print(f"- Parent directory: {parent_dir}")
print(f"- Scripts directory: {scripts_path}")


In [None]:
# Import required libraries
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import anndata as ad
from datetime import datetime
import json

# Import custom modules
import scripts.Embeddings_extract as Embeddings_extract
import scripts.GATmodel as GATmodel
import scripts.TransformerMap as TransformerMap

print("All imports successful!")


## 2. Load New Single Cell Data


In [None]:
# Load your new single cell dataset
# Replace with your actual data path
# NOTE: these two paths are only consumed by the commented-out example below;
# the demonstration path at the bottom of this cell does not read them.
new_rna_path = "path/to/your/new_rna_data.h5ad"
new_adt_path = "path/to/your/new_adt_data.h5ad"  # Optional, for validation

# Example: Load data (uncomment and modify paths as needed)
# new_rna_data = sc.read_h5ad(new_rna_path)
# new_adt_data = sc.read_h5ad(new_adt_path) if new_adt_path else None

# Option 1: Load GSE116256 dataset (AML samples)
# Uncomment the following lines to load the real GSE116256 dataset
# (The block below is a bare string literal, so none of it executes as written;
# remove the surrounding triple quotes to enable it.)
"""
import sys
sys.path.append('..')  # Add parent directory to path
from load_gse116256 import load_gse116256_dataset

print("Loading GSE116256 dataset...")
new_rna_data = load_gse116256_dataset(
    data_dir="/projects/vanaja_lab/satya/Datasets/GSE116256",
    output_file="GSE116256_combined.h5ad",
    force_reload=False  # Set to True to reload from raw files
)
new_adt_data = None  # No ADT data available for this dataset
print(f"GSE116256 dataset loaded: {new_rna_data.shape}")
print(f"Sample IDs: {new_rna_data.obs['sample_id'].unique()[:10]}")
"""

# Option 2: For demonstration, create sample data
# This is the only option that actually runs by default. It defines
# new_rna_data but NOT new_adt_data, so later cells that reference
# new_adt_data need one of the options above to be enabled first.
print("Creating sample data for demonstration...")
np.random.seed(42)  # fixed seed so the synthetic counts and labels are reproducible

# Create sample RNA data
n_cells = 5000
n_genes = 2000
# Synthetic integer count matrix of shape (n_cells, n_genes)
sample_rna_data = np.random.negative_binomial(5, 0.3, (n_cells, n_genes))

# Create AnnData object
new_rna_data = ad.AnnData(X=sample_rna_data)
new_rna_data.var_names = [f"Gene_{i}" for i in range(n_genes)]
new_rna_data.obs_names = [f"Cell_{i}" for i in range(n_cells)]

# Add some metadata
# Labels are drawn uniformly at random, independent of the counts, so they
# carry no biological signal; they exist only to exercise the plotting code.
new_rna_data.obs['sample_id'] = np.random.choice(['Sample_A', 'Sample_B', 'Sample_C'], n_cells)
new_rna_data.obs['cell_type'] = np.random.choice(['T_cell', 'B_cell', 'Monocyte', 'NK_cell'], n_cells)

print(f"Sample RNA data shape: {new_rna_data.shape}")
print(f"Sample metadata: {new_rna_data.obs.columns.tolist()}")

# If you have real data, uncomment the appropriate option above


## 3. Load Pre-trained Models


In [None]:
# Specify the path to your trained models
# Update this path to match your saved model checkpoint
models_dir = "trained_models"

# Find the most recent checkpoint.
# A missing directory is treated like an empty one so that this cell reports
# the problem through the messages below instead of raising FileNotFoundError
# from os.listdir().
if os.path.isdir(models_dir):
    checkpoint_files = [
        f for f in os.listdir(models_dir)
        if f.startswith('rna_adt_transformer_models_') and f.endswith('.pth')
    ]
else:
    checkpoint_files = []

if checkpoint_files:
    # Descending lexicographic sort; this selects the newest file only if the
    # part of the name after the prefix sorts chronologically
    # (e.g. a YYYYMMDD_HHMMSS timestamp) -- TODO confirm against the training script.
    checkpoint_files.sort(reverse=True)
    latest_checkpoint = checkpoint_files[0]
    checkpoint_path = os.path.join(models_dir, latest_checkpoint)
    print(f"Loading checkpoint: {latest_checkpoint}")
else:
    print("No checkpoint files found. Please train models first.")
    checkpoint_path = None

# Load the checkpoint
if checkpoint_path and os.path.exists(checkpoint_path):
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    print("✅ Checkpoint loaded successfully")

    # Display checkpoint information
    print(f"\n📊 Model Information:")
    print(f"   • RNA input dimension: {checkpoint['rna_input_dim']}")
    print(f"   • ADT output dimension: {checkpoint['adt_output_dim']}")
    print(f"   • RNA classes: {checkpoint['rna_num_classes']}")
    print(f"   • ADT classes: {checkpoint['adt_num_classes']}")
    print(f"   • Training timestamp: {checkpoint['training_config']['timestamp']}")

    # Display performance metrics
    perf = checkpoint['performance_metrics']
    print(f"\n🎯 Training Performance:")
    print(f"   • Final test loss: {perf['final_test_loss']:.4f}")
    print(f"   • MSE: {perf['mse']:.4f}")
    print(f"   • R² Score: {perf['r2']:.4f}")
    print(f"   • Mean Pearson correlation: {perf['mean_pearson']:.4f}")
    print(f"   • Mean Spearman correlation: {perf['mean_spearman']:.4f}")
else:
    # Downstream cells check `checkpoint is not None` before doing any work
    print("❌ No valid checkpoint found. Please train models first.")
    checkpoint = None


In [None]:
# Rebuild the three networks from the loaded checkpoint and restore their weights
if checkpoint is None:
    print("❌ Cannot initialize models without checkpoint")
else:
    from scripts.GATmodel import SimpleGAT
    from scripts.TransformerMap import TransformerMapping

    # GAT encoder for the RNA modality (output width fixed based on training)
    rna_gat_model = SimpleGAT(
        in_channels=checkpoint['rna_input_dim'],
        hidden_channels=64,
        out_channels=35,
        heads=4,
        dropout=0.6,
    )

    # GAT encoder for the ADT modality (output width fixed based on training)
    adt_gat_model = SimpleGAT(
        in_channels=50,
        hidden_channels=64,
        out_channels=51,
        heads=4,
        dropout=0.6,
    )

    # Transformer that maps RNA-side inputs to ADT-side outputs
    transformer_model = TransformerMapping(
        input_dim=checkpoint['rna_input_dim'],
        output_dim=checkpoint['adt_output_dim'],
        d_model=256,
        nhead=4,
        num_layers=3,
    )

    # Pair every network with the checkpoint key holding its parameters
    _model_state_keys = (
        (rna_gat_model, 'rna_gat_state_dict'),
        (adt_gat_model, 'adt_gat_state_dict'),
        (transformer_model, 'transformer_mapping_state_dict'),
    )

    # Restore all weights first ...
    for _model, _state_key in _model_state_keys:
        _model.load_state_dict(checkpoint[_state_key])

    # ... then switch every network to evaluation mode
    for _model, _state_key in _model_state_keys:
        _model.eval()

    print("✅ All models loaded and set to evaluation mode")


## 4. Preprocess New Data


In [None]:
# Preprocess new RNA data to match training data format
def preprocess_new_rna_data(adata):
    """
    Run the RNA preprocessing pipeline on ``adata`` and return the result.

    Steps, in order: keep a copy of ``X`` in ``layers["raw"]``, normalize each
    cell to a total of 1e4, log1p-transform, keep the top 2000 highly variable
    genes, scale (clipping at 10) and compute a 50-component PCA.

    The steps before gene selection are applied to the object that was passed
    in, so callers that need the original should pass a copy. The returned
    object is a new AnnData restricted to the selected genes, with the PCA
    stored in ``obsm['X_pca']``.
    """
    print(f"Original data shape: {adata.shape}")

    # Keep the untouched matrix around before any transformation
    adata.layers["raw"] = adata.X.copy()

    # Library-size normalization followed by log transform
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    # Restrict to highly variable genes; from here on we work on a new object
    sc.pp.highly_variable_genes(adata, n_top_genes=2000)
    hvg_mask = adata.var.highly_variable
    hvg_adata = adata[:, hvg_mask].copy()

    # Scaling and dimensionality reduction on the reduced gene set
    sc.pp.scale(hvg_adata, max_value=10)
    sc.tl.pca(hvg_adata, n_comps=50, svd_solver="arpack")

    print(f"Preprocessed data shape: {hvg_adata.shape}")
    print(f"PCA components: {hvg_adata.obsm['X_pca'].shape}")

    return hvg_adata

# Preprocess a copy of the new data so new_rna_data itself stays untouched
if checkpoint is None:
    print("❌ Cannot preprocess data without loaded models")
else:
    processed_rna_data = preprocess_new_rna_data(new_rna_data.copy())
    print("✅ New RNA data preprocessed successfully")


## 5. Extract Embeddings and Make Predictions


In [None]:
# Build PyTorch Geometric data object and extract embeddings
def predict_adt_embeddings(rna_data, rna_gat_model, transformer_model):
    """
    Embed the RNA data with the GAT model and map it to ADT embeddings.

    Both models are moved to CUDA when available (CPU otherwise). A graph
    object is built from ``rna_data`` via
    ``Embeddings_extract.process_data_with_graphs``, the GAT model produces RNA
    embeddings from it, and the transformer turns those into predicted ADT
    embeddings. Gradients are disabled for both forward passes.

    Returns a tuple ``(rna_embeddings, predicted_adt_embeddings, rna_pyg_data)``
    where the two embedding tensors have been moved to the CPU and the graph
    object is left on the compute device.
    """
    # Pick the compute device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Both networks must live on the same device as the graph data
    for network in (rna_gat_model, transformer_model):
        network.to(device)

    # Graph construction; the ADT slot is None because only RNA is available
    print("Building PyTorch Geometric data object...")
    graph_options = {
        'n_neighbors': 20,
        'rna_sparse_threshold': 10000000,
        'rna_max_edges_sparse': 75,
    }
    rna_pyg_data, _, _unused_config = Embeddings_extract.process_data_with_graphs(
        rna_data, None, **graph_options
    )

    print(f"PyG data - Nodes: {rna_pyg_data.num_nodes}, Edges: {rna_pyg_data.num_edges}")

    rna_pyg_data = rna_pyg_data.to(device)

    # Inference only: no gradients needed for either forward pass
    with torch.no_grad():
        print("Extracting RNA embeddings...")
        rna_embeddings = Embeddings_extract.extract_embeddings(rna_gat_model, rna_pyg_data)

        print("Predicting ADT embeddings...")
        predicted_adt_embeddings = transformer_model(rna_embeddings)

    # Hand results back on the CPU for downstream NumPy/AnnData use
    rna_embeddings = rna_embeddings.cpu()
    predicted_adt_embeddings = predicted_adt_embeddings.cpu()

    print(f"RNA embeddings shape: {rna_embeddings.shape}")
    print(f"Predicted ADT embeddings shape: {predicted_adt_embeddings.shape}")

    return rna_embeddings, predicted_adt_embeddings, rna_pyg_data

# Run the prediction pipeline only when models were loaded
if checkpoint is None:
    print("❌ Cannot make predictions without loaded models")
else:
    _prediction_outputs = predict_adt_embeddings(
        processed_rna_data, rna_gat_model, transformer_model
    )
    rna_embeddings, predicted_adt_embeddings, rna_pyg_data = _prediction_outputs
    print("✅ Predictions completed successfully")


## 6. Evaluate Performance (if ground truth available)


In [None]:
# Evaluate performance if ground truth ADT data is available
def _as_numpy(values):
    """Return ``values`` as a NumPy array, detaching torch tensors first."""
    if hasattr(values, 'detach'):
        values = values.detach().cpu().numpy()
    return np.asarray(values)


def evaluate_predictions(predicted_embeddings, true_embeddings=None, true_adt_data=None):
    """
    Evaluate predicted ADT embeddings against ground truth.

    Parameters
    ----------
    predicted_embeddings : tensor or array
        Predicted ADT embeddings, one row per cell.
    true_embeddings : tensor or array, optional
        Precomputed ground-truth ADT embeddings. Used only when
        ``true_adt_data`` is not given.
    true_adt_data : AnnData, optional
        Ground-truth ADT measurements. When given, true embeddings are
        computed from it with the notebook-level ``adt_gat_model`` and take
        precedence over ``true_embeddings``.

    Returns
    -------
    dict
        ``{'status': 'no_ground_truth'}`` when no ground truth is supplied,
        otherwise ``mse``, ``r2``, ``mean_correlation``, ``median_correlation``
        and the per-dimension ``correlations`` list, all as plain Python
        floats so the result can be written with ``json.dump``.

    Raises
    ------
    ValueError
        If the true and predicted embeddings do not have the same shape.
    """
    if true_adt_data is None and true_embeddings is None:
        print("No ground truth ADT data available - skipping performance evaluation")
        return {'status': 'no_ground_truth'}

    if true_adt_data is not None:
        print("Ground truth ADT data available - computing performance metrics...")

        # Preprocess true ADT data.
        # NOTE(review): this reuses the RNA pipeline (HVG selection, 50-component
        # PCA); confirm this matches how ADT was preprocessed during training.
        processed_adt_data = preprocess_new_rna_data(true_adt_data.copy())

        # Extract true ADT embeddings with the notebook-level ADT GAT model
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        adt_gat_model.to(device)

        adt_pyg_data, _, _ = Embeddings_extract.process_data_with_graphs(
            None, processed_adt_data,
            n_neighbors=20,
            adt_max_edges_sparse=75
        )

        adt_pyg_data = adt_pyg_data.to(device)

        with torch.no_grad():
            true_adt_embeddings = Embeddings_extract.extract_embeddings(adt_gat_model, adt_pyg_data)
    else:
        print("Ground truth ADT embeddings provided - computing performance metrics...")
        true_adt_embeddings = true_embeddings

    # Work on NumPy arrays so sklearn/scipy get a consistent input type
    y_true = _as_numpy(true_adt_embeddings)
    y_pred = _as_numpy(predicted_embeddings)

    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch between true embeddings {y_true.shape} "
            f"and predicted embeddings {y_pred.shape}"
        )

    # Calculate metrics (cast to float: NumPy scalars are not always JSON-serializable)
    mse = float(mean_squared_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))

    # Calculate correlation per embedding dimension
    correlations = []
    for dim in range(y_true.shape[1]):
        r, _ = pearsonr(y_true[:, dim], y_pred[:, dim])
        correlations.append(float(r))

    mean_correlation = float(np.mean(correlations))
    median_correlation = float(np.median(correlations))

    results = {
        'mse': mse,
        'r2': r2,
        'mean_correlation': mean_correlation,
        'median_correlation': median_correlation,
        'correlations': correlations
    }

    print(f"\n📊 Performance Metrics:")
    print(f"   • MSE: {mse:.4f}")
    print(f"   • R² Score: {r2:.4f}")
    print(f"   • Mean Correlation: {mean_correlation:.4f}")
    print(f"   • Median Correlation: {median_correlation:.4f}")

    return results

# Evaluate predictions (uncomment if you have ground truth data)
# evaluation_results = evaluate_predictions(predicted_adt_embeddings, true_adt_data=new_adt_data)
# Guarded like the other pipeline cells: without a checkpoint there are no
# predictions, and referencing predicted_adt_embeddings would raise NameError.
if checkpoint is not None:
    evaluation_results = evaluate_predictions(predicted_adt_embeddings, true_adt_data=None)
    print("✅ Evaluation completed")
else:
    evaluation_results = None
    print("❌ Cannot evaluate without predictions")


## 7. Visualize Results


In [None]:
# Create visualizations of the predicted embeddings
def visualize_predictions(predicted_embeddings, rna_data, metadata_columns=None):
    """
    Plot a 2x2 overview of the predicted ADT embeddings.

    Panels: UMAP colored by ``sample_id`` and by ``cell_type`` (each drawn only
    if that column exists in ``rna_data.obs``), a histogram of all embedding
    values, and a correlation heatmap over the first (up to) 20 embedding
    dimensions. ``metadata_columns`` is accepted but not used.

    Returns the AnnData object built from the predictions, including the
    computed neighbor graph and UMAP.
    """
    embedding_values = predicted_embeddings.numpy()

    # Wrap the predictions in AnnData and carry over the cell metadata
    pred_adata = ad.AnnData(X=embedding_values)
    pred_adata.obs = rna_data.obs.copy()

    # Neighbor graph directly on the embedding matrix, then UMAP
    sc.pp.neighbors(pred_adata, n_neighbors=15, use_rep='X')
    sc.tl.umap(pred_adata)

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Top row: UMAPs colored by whichever metadata columns are present
    umap_panels = (
        ('sample_id', axes[0, 0], 'Predicted ADT Embeddings - Sample'),
        ('cell_type', axes[0, 1], 'Predicted ADT Embeddings - Cell Type'),
    )
    for column, axis, panel_title in umap_panels:
        if column in pred_adata.obs:
            sc.pl.umap(pred_adata, color=column, ax=axis, show=False, title=panel_title)

    # Bottom-left: distribution of all embedding values
    hist_ax = axes[1, 0]
    hist_ax.hist(embedding_values.flatten(), bins=50, alpha=0.7)
    hist_ax.set_title('Distribution of Predicted Embedding Values')
    hist_ax.set_xlabel('Embedding Value')
    hist_ax.set_ylabel('Frequency')

    # Bottom-right: correlation between the leading embedding dimensions
    n_dims = min(20, predicted_embeddings.shape[1])
    corr_matrix = np.corrcoef(embedding_values[:, :n_dims].T)
    heat_ax = axes[1, 1]
    im = heat_ax.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
    heat_ax.set_title(f'Correlation Matrix (First {n_dims} Dimensions)')
    heat_ax.set_xlabel('Embedding Dimension')
    heat_ax.set_ylabel('Embedding Dimension')
    plt.colorbar(im, ax=heat_ax)

    plt.tight_layout()
    plt.show()

    return pred_adata

# Plot the predictions only when a model run actually produced them
if checkpoint is None:
    print("❌ Cannot create visualizations without predictions")
else:
    pred_adata = visualize_predictions(predicted_adt_embeddings, processed_rna_data)
    print("✅ Visualizations created successfully")


## 8. Save Results


In [None]:
# Save results for further analysis
def _json_default(obj):
    """Convert NumPy scalars and arrays to JSON-native values for json.dump."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_test_results(predicted_embeddings, rna_data, evaluation_results=None, output_dir="test_results"):
    """
    Save predicted embeddings and a JSON summary of the test run.

    Parameters
    ----------
    predicted_embeddings : torch.Tensor
        Predicted ADT embeddings (cells x dimensions); written as an .h5ad file
        whose ``obs`` is copied from ``rna_data``.
    rna_data : AnnData
        Source of the per-cell metadata attached to the predictions.
    evaluation_results : dict, optional
        Metrics to embed in the summary; a falsy value is recorded as
        ``{'status': 'no_evaluation'}``.
    output_dir : str
        Directory for the output files (created if missing).

    Returns
    -------
    tuple of (str, str)
        Paths of the written .h5ad file and the JSON summary.
    """
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Save predicted embeddings
    pred_adata = ad.AnnData(X=predicted_embeddings.numpy())
    pred_adata.obs = rna_data.obs.copy()
    pred_adata.var_names = [f"ADT_Dim_{i}" for i in range(predicted_embeddings.shape[1])]

    pred_path = os.path.join(output_dir, f"predicted_adt_embeddings_{timestamp}.h5ad")
    pred_adata.write(pred_path)

    # The checkpoint path lives in the notebook namespace; look it up
    # defensively so the function also works when that cell was never run.
    model_checkpoint = globals().get('checkpoint_path') or 'None'

    # Save results summary
    results_summary = {
        'test_info': {
            'timestamp': timestamp,
            'n_cells': predicted_embeddings.shape[0],
            'n_embedding_dims': predicted_embeddings.shape[1],
            'model_checkpoint': model_checkpoint
        },
        'evaluation_results': evaluation_results if evaluation_results else {'status': 'no_evaluation'},
        'file_paths': {
            'predicted_embeddings': pred_path
        }
    }

    summary_path = os.path.join(output_dir, f"test_summary_{timestamp}.json")
    with open(summary_path, 'w') as f:
        # Metrics may arrive as NumPy scalars/arrays, which json cannot encode natively
        json.dump(results_summary, f, indent=2, default=_json_default)

    print(f"✅ Results saved to: {output_dir}")
    print(f"   📄 Predicted embeddings: {pred_path}")
    print(f"   📊 Test summary: {summary_path}")

    return pred_path, summary_path

# Persist predictions and summary only when a model run produced them
if checkpoint is None:
    print("❌ Cannot save results without predictions")
else:
    pred_path, summary_path = save_test_results(
        predicted_adt_embeddings, processed_rna_data, evaluation_results
    )
    print("✅ All results saved successfully")


## 9. Summary and Next Steps


In [None]:
# Display final summary
_rule = "=" * 60
print("\n" + _rule)
print("🎯 TESTING COMPLETE - SUMMARY")
print(_rule)

if checkpoint is None:
    # Nothing was predicted, so there is nothing to summarize
    print("\n❌ Testing failed - no valid models loaded")
    print("   Please ensure you have trained models saved in the 'trained_models' directory")
else:
    print("\n📊 Dataset Information:")
    print(f"   • Total cells tested: {predicted_adt_embeddings.shape[0]:,}")
    print(f"   • Embedding dimensions: {predicted_adt_embeddings.shape[1]}")
    print(f"   • Model checkpoint: {os.path.basename(checkpoint_path)}")

    # Real metrics exist only when evaluation ran against ground truth;
    # a 'status' key marks a skipped evaluation.
    _has_metrics = bool(evaluation_results) and 'status' not in evaluation_results
    if not _has_metrics:
        print("\n⚠️  No ground truth available for performance evaluation")
    else:
        print("\n🎯 Performance Metrics:")
        print(f"   • MSE: {evaluation_results['mse']:.4f}")
        print(f"   • R² Score: {evaluation_results['r2']:.4f}")
        print(f"   • Mean Correlation: {evaluation_results['mean_correlation']:.4f}")

    print("\n📁 Output Files:")
    print(f"   • Predicted embeddings: {pred_path}")
    print(f"   • Test summary: {summary_path}")

    print("\n🚀 Next Steps:")
    for _step in (
        "   1. Load predicted embeddings for downstream analysis",
        "   2. Perform cell type annotation using predicted embeddings",
        "   3. Compare with known cell type markers",
        "   4. Use for cross-modal integration studies",
    ):
        print(_step)

print("\n" + _rule)


## 10. Example Usage for Future Testing

To test new datasets in the future, you can use this simplified function:


In [None]:
# Simplified function for testing new datasets
def test_new_dataset(rna_data_path, adt_data_path=None, models_dir="trained_models"):
    """
    Simplified function to test new single cell dataset

    Loads the data, restores the RNA GAT and transformer models from the most
    recent checkpoint in ``models_dir``, predicts ADT embeddings, evaluates
    them when ground truth is supplied, and saves the results.

    Parameters:
    -----------
    rna_data_path : str
        Path to new RNA data (.h5ad file)
    adt_data_path : str, optional
        Path to ground truth ADT data for validation
    models_dir : str
        Directory containing trained models

    Returns:
    --------
    dict : Results dictionary with predictions and metrics

    Raises:
    -------
    FileNotFoundError
        If ``models_dir`` does not exist or contains no matching checkpoint.
    """
    # Local imports, mirroring the notebook's model-initialization cell
    from scripts.GATmodel import SimpleGAT
    from scripts.TransformerMap import TransformerMapping

    # Load data
    rna_data = sc.read_h5ad(rna_data_path)
    adt_data = sc.read_h5ad(adt_data_path) if adt_data_path else None

    # Find latest checkpoint (same prefix/suffix filter as the checkpoint-loading cell)
    if not os.path.isdir(models_dir):
        raise FileNotFoundError(f"Models directory not found: {models_dir}")
    checkpoint_files = [
        f for f in os.listdir(models_dir)
        if f.startswith('rna_adt_transformer_models_') and f.endswith('.pth')
    ]
    if not checkpoint_files:
        raise FileNotFoundError("No trained models found")

    checkpoint_files.sort(reverse=True)
    checkpoint_path = os.path.join(models_dir, checkpoint_files[0])

    # Load models
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Build the prediction models from this checkpoint rather than relying on
    # whatever models happen to exist in the notebook namespace.
    # Hyperparameters mirror the notebook's model-initialization cell.
    rna_gat_model = SimpleGAT(
        in_channels=checkpoint['rna_input_dim'],
        hidden_channels=64,
        out_channels=35,  # Fixed based on training
        heads=4,
        dropout=0.6
    )
    transformer_model = TransformerMapping(
        input_dim=checkpoint['rna_input_dim'],
        output_dim=checkpoint['adt_output_dim'],
        d_model=256,
        nhead=4,
        num_layers=3
    )
    rna_gat_model.load_state_dict(checkpoint['rna_gat_state_dict'])
    transformer_model.load_state_dict(checkpoint['transformer_mapping_state_dict'])
    rna_gat_model.eval()
    transformer_model.eval()

    # Preprocess and predict
    processed_rna = preprocess_new_rna_data(rna_data)
    rna_emb, pred_adt_emb, _ = predict_adt_embeddings(processed_rna, rna_gat_model, transformer_model)

    # Evaluate if ground truth available.
    # NOTE(review): evaluate_predictions reads the notebook-level adt_gat_model,
    # so ground-truth evaluation still requires the model-initialization cell
    # to have been run.
    eval_results = evaluate_predictions(pred_adt_emb, true_adt_data=adt_data)

    # Save results
    pred_path, summary_path = save_test_results(pred_adt_emb, processed_rna, eval_results)

    return {
        'predictions': pred_adt_emb,
        'evaluation': eval_results,
        'files': {'predictions': pred_path, 'summary': summary_path}
    }

# Example usage:
# results = test_new_dataset('path/to/new_rna.h5ad', 'path/to/new_adt.h5ad')
# Confirm the helper is defined and remind the reader how to call it
for _message in (
    "✅ Simplified testing function defined",
    "   Use: results = test_new_dataset('path/to/rna_data.h5ad', 'path/to/adt_data.h5ad')",
):
    print(_message)
