In [None]:
# =============================================================================
# WTKO RNA Velocity + Contrastive Learning Pipeline
# Usage Example Notebook
# 
# SETUP INSTRUCTIONS:
# 1. Update 'project_path' to point to your project directory containing models.py, trainers.py, etc.
# 2. Update 'adata_path' to point to your h5ad data file
# 3. Update 'murk_path' to point to your MURK genes CSV file (optional)
# 4. Ensure all required packages are installed (see requirements.txt)
# =============================================================================

# =============================================================================
# 1. Import Required Libraries
# =============================================================================

# --- Basic Libraries ---
import os
import numpy as np
import pandas as pd
from scipy import sparse
from scipy.stats import binned_statistic_2d
import warnings
warnings.filterwarnings('ignore')

# --- Visualization Libraries ---
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import Normalize
import seaborn as sns

# --- Machine Learning & Deep Learning ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

# --- Single-cell Analysis Libraries ---
import scanpy as sc
import scvelo as scv
import anndata

# --- RNA Velocity Specialized Library ---
from velovi import VELOVI

# --- Dimensionality Reduction & Visualization ---
import umap
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# --- Distance & Similarity Computation ---
import scipy.spatial.distance as dist
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# --- Statistical Analysis ---
from scipy.stats import pearsonr, spearmanr
from statsmodels.stats.multitest import multipletests

# --- Progress Bar (Optional) ---
from tqdm import tqdm

# --- Settings ---
# scanpy: log level 3, white-background figures at 80 dpi
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=80, facecolor='white')

# scvelo: same log level, presenter view, scvelo figure style
scv.settings.verbosity = 3
scv.settings.presenter_view = True
scv.set_figure_params('scvelo')

# matplotlib defaults (applied after scvelo's figure params, so these win)
plt.rcParams.update({
    'figure.figsize': (10, 8),
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
})

# PyTorch: select CUDA when available, otherwise CPU
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print(f"Using device: {device}")

# Seed the NumPy and PyTorch RNGs for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if cuda_available:
    torch.cuda.manual_seed(SEED)

# --- Import Models, Trainers, and Data Loaders ---
import os
import sys
# Add directory containing .py files to Python path
# Replace 'path/to/your/project' with the actual path to your project directory
project_path = 'path/to/your/project'  # Update this path

# Prepend (rather than append) so the project's generically named modules
# (models, trainers, data, utils) take precedence over any installed package
# with the same name. Removing an existing entry first keeps sys.path from
# growing each time this cell is re-run.
_project_dir = os.path.abspath(os.path.expanduser(project_path))
while _project_dir in sys.path:
    sys.path.remove(_project_dir)
sys.path.insert(0, _project_dir)

# Import modules
from models import WTKOContrastiveVAE
from trainers import WTKOTrainer
# NOTE(review): create_wt_ko_dataloaders is not used in this notebook as shown;
# loaders are built by the local create_dataloaders in section 6.
from data import create_wt_ko_dataloaders
# NOTE: these two are shadowed by the notebook's own definitions of
# plot_latent_space / plot_combined_latent_space in section 9 once that cell runs.
from utils import plot_latent_space, plot_combined_latent_space

print("=" * 60)
print("WTKO RNA Velocity + Contrastive Learning Pipeline")
print("All libraries imported successfully")
print("=" * 60)

In [None]:
# =============================================================================
# 2. Load and Preprocess Data
# =============================================================================

# Path to the data file
# Replace with your actual data file path
adata_path = "path/to/your/data.h5ad"  # Update this path

# Read the full AnnData object and show what it contains
adata = sc.read_h5ad(adata_path)
print("Data overview:")
print(adata)
print("obs columns:", adata.obs.columns.tolist())

# Subclusters (values of obs["haem_subclust_grouped"]) kept for the analysis
blood_subclusters = [
    "BP1", "BP2", "BP3", "BP4",          # Blood Progenitors
    "Ery1", "Ery2", "Ery3", "Ery4",      # Erythroid cells
    "Haem1", "Haem2", "Haem3", "Haem4",  # Hematopoietic cells
    "Mk", "My", "EC"                      # Megakaryocytes, Myeloid, Endothelial cells
]
is_blood_cell = adata.obs["haem_subclust_grouped"].isin(blood_subclusters)
adata_blood = adata[is_blood_cell].copy()

print(f"Number of blood subcluster cells: {adata_blood.n_obs}")
print("haem_subclust_grouped distribution after filtering:")
print(adata_blood.obs["haem_subclust_grouped"].value_counts())

# Split by the "tomato" column: "pos" -> KO, "neg" -> WT (other values are dropped)
tomato_status = adata_blood.obs["tomato"]
adata_ko = adata_blood[tomato_status == "pos"].copy()
adata_wt = adata_blood[tomato_status == "neg"].copy()

print(f"KO cell count: {adata_ko.n_obs}")
print(f"WT cell count: {adata_wt.n_obs}")

# === Remove MURK genes ===
# Path to a header-less, single-column CSV listing gene names to exclude.
# NOTE(review): an earlier comment described these as mitochondrial and
# ribosomal genes; in the erythroid RNA-velocity literature "MURK" genes are
# genes with MUltiple Rate Kinetics (Barile et al., 2021), which violate the
# constant-rate assumption of velocity models — confirm the list's origin.
# This step is optional: set murk_path = None (or leave a path that does not
# exist) to keep all genes.
murk_path = "path/to/murk_genes.csv"  # Update this path

if murk_path and os.path.isfile(murk_path):
    murk_genes = pd.read_csv(murk_path, header=None)[0].tolist()
else:
    print(f"MURK gene file not found ({murk_path!r}); skipping MURK gene removal.")
    murk_genes = []

# MURK genes actually present in each group (reported below)
valid_murk_genes_ko = [g for g in murk_genes if g in adata_ko.var_names]
valid_murk_genes_wt = [g for g in murk_genes if g in adata_wt.var_names]

# Index.isin is a hashed lookup, avoiding an O(n_genes * n_murk) list scan;
# the boolean mask preserves the original gene order.
adata_ko = adata_ko[:, ~adata_ko.var_names.isin(murk_genes)].copy()
adata_wt = adata_wt[:, ~adata_wt.var_names.isin(murk_genes)].copy()

print(f"MURK genes removed: KO {len(set(valid_murk_genes_ko))}, WT {len(set(valid_murk_genes_wt))}")
print(f"KO gene count after MURK removal: {adata_ko.n_vars}")
print(f"WT gene count after MURK removal: {adata_wt.n_vars}")

In [None]:
# =============================================================================
# 3. VELOVI Processing Functions
# =============================================================================

def apply_velovi_to_group(adata, group_name, max_epochs=10, min_r2=-10, gamma=-10):
    """
    Apply VELOVI to calculate velocity and return filtered AnnData

    The input is copied and preprocessed with scVelo (filtering/normalization
    and moments). A first scVelo velocity fit is used only to obtain per-gene
    ``velocity_r2`` / ``velocity_gamma`` for gene filtering. VELOVI is then
    trained on the ``Ms``/``Mu`` layers of the retained genes, and its
    velocities are stored in ``layers["velocity"]``.

    Parameters:
    -----------
    adata : AnnData
        Input single-cell data (not modified; a copy is processed)
    group_name : str
        Name of the group (for logging)
    max_epochs : int
        Maximum training epochs for VELOVI
    min_r2 : float
        Keep genes whose ``velocity_r2`` is strictly greater than this
        (NaN values never pass)
    gamma : float
        Keep genes whose ``velocity_gamma`` is strictly greater than this;
        also forwarded to ``scv.tl.velocity``

    Returns:
    --------
    filtered_adata : AnnData
        Filtered data with velocity estimates and a velocity graph
        (computed by ``scv.tl.velocity_graph``)
    model : VELOVI
        Trained VELOVI model

    Raises:
    -------
    ValueError
        If no gene passes both thresholds, since VELOVI cannot be set up on
        an AnnData object with zero genes.
    """
    print(f"\n=== Processing {group_name} with VELOVI ===")

    # Create copy to preserve original data
    adata_copy = adata.copy()

    # scVelo preprocessing; moments() produces the Ms / Mu layers VELOVI uses below
    scv.pp.filter_and_normalize(adata_copy, min_shared_counts=20, n_top_genes=40000, enforce=True)
    scv.pp.moments(adata_copy, n_neighbors=30, n_pcs=50, method='umap')

    # Initial velocity estimation (for filtering metrics only).
    # NOTE(review): scVelo's dynamical mode needs scv.tl.recover_dynamics to
    # have been run; without it scVelo is expected to fall back to the
    # stochastic model, which is what writes velocity_r2 / velocity_gamma read
    # below — confirm, and consider passing mode="stochastic" explicitly.
    scv.tl.velocity(
        adata_copy,
        mode="dynamical",
        min_r2=min_r2,
        gamma=gamma,
        filter_genes=False,
        use_highly_variable=False
    )

    # Define gene filtering conditions
    passes_r2 = adata_copy.var["velocity_r2"] > min_r2
    passes_gamma = adata_copy.var["velocity_gamma"] > gamma
    mask = passes_r2 & passes_gamma

    print(f"Original gene count: {adata_copy.n_vars}")
    print(f"velocity_r2 > {min_r2}: {np.sum(passes_r2)}")
    print(f"gamma > {gamma}: {np.sum(passes_gamma)}")
    print(f"Genes meeting both criteria: {mask.sum()}")

    if not mask.any():
        raise ValueError(
            f"{group_name}: no genes satisfy velocity_r2 > {min_r2} and "
            f"velocity_gamma > {gamma}; relax the thresholds"
        )

    # Apply filtering
    filtered_adata = adata_copy[:, mask].copy()

    # VELOVI setup and training
    VELOVI.setup_anndata(filtered_adata, spliced_layer="Ms", unspliced_layer="Mu")
    model = VELOVI(filtered_adata)
    model.train(max_epochs=max_epochs)

    # Store velocity estimates
    filtered_adata.layers["velocity"] = model.get_velocity()

    # Calculate velocity graph
    scv.tl.velocity_graph(filtered_adata)

    # Defensive: copy over any obsm entry (UMAP, etc.) missing from the result.
    # adata_copy started as a full copy of adata, so this is normally a no-op.
    for key in adata.obsm.keys():
        if key not in filtered_adata.obsm:
            filtered_adata.obsm[key] = adata.obsm[key]

    return filtered_adata, model

def align_gene_sets(adata_wt, adata_ko):
    """
    Align gene sets between WT and KO AnnData objects

    Only genes present in both datasets are kept. The shared genes keep their
    WT order. A set intersection would give an order that varies between
    Python runs under string hash randomisation, silently permuting the
    model's input features.

    Parameters:
    -----------
    adata_wt, adata_ko : AnnData
        WT and KO datasets

    Returns:
    --------
    wt_aligned, ko_aligned : AnnData
        Copies of the datasets restricted to the same genes in the same order

    Raises:
    -------
    ValueError
        If the datasets share no genes. No meaningful alignment exists, and
        any WT-only gene subset could not be used to index the KO data.
    """
    print("\n=== Aligning Gene Sets ===")

    wt_genes = list(adata_wt.var_names)
    ko_genes = list(adata_ko.var_names)

    print(f"WT gene count: {len(wt_genes)}")
    print(f"KO gene count: {len(ko_genes)}")

    # Deterministic, order-preserving intersection (dict.fromkeys also dedupes)
    ko_gene_set = set(ko_genes)
    common_genes = [g for g in dict.fromkeys(wt_genes) if g in ko_gene_set]
    print(f"Common gene count: {len(common_genes)}")

    if not common_genes:
        raise ValueError("No common genes found between WT and KO; cannot align gene sets")

    # Filter to common genes only
    print("Filtering to common genes...")
    wt_aligned = adata_wt[:, common_genes].copy()
    ko_aligned = adata_ko[:, common_genes].copy()

    print(f"Aligned WT gene count: {wt_aligned.n_vars}")
    print(f"Aligned KO gene count: {ko_aligned.n_vars}")

    return wt_aligned, ko_aligned

def process_and_align_data(adata_wt, adata_ko, **velovi_kwargs):
    """
    Main function for VELOVI application and gene set alignment

    Parameters:
    -----------
    adata_wt, adata_ko : AnnData
        WT and KO datasets. They are passed straight to apply_velovi_to_group,
        which processes its own copy, so no extra copy is made here.
    **velovi_kwargs
        Optional keyword arguments forwarded unchanged to apply_velovi_to_group
        for both groups (e.g. max_epochs, min_r2, gamma). Omitting them keeps
        that function's defaults.

    Returns:
    --------
    adata_wt_aligned, adata_ko_aligned : AnnData
        Processed and aligned datasets
    model_wt, model_ko : VELOVI
        Trained VELOVI models
    """
    # Apply VELOVI
    print("Applying VELOVI to WT group...")
    adata_wt_filtered, model_wt = apply_velovi_to_group(adata_wt, "WT", **velovi_kwargs)

    print("Applying VELOVI to KO group...")
    adata_ko_filtered, model_ko = apply_velovi_to_group(adata_ko, "KO", **velovi_kwargs)

    # Align gene sets
    print("Aligning gene sets...")
    adata_wt_aligned, adata_ko_aligned = align_gene_sets(
        adata_wt_filtered, adata_ko_filtered
    )

    return adata_wt_aligned, adata_ko_aligned, model_wt, model_ko

# =============================================================================
# 4. Process Data with VELOVI
# =============================================================================

# Run VELOVI on each genotype, then restrict both results to their shared genes
(adata_wt_aligned, adata_ko_aligned,
 model_wt, model_ko) = process_and_align_data(adata_wt=adata_wt, adata_ko=adata_ko)

In [None]:
# =============================================================================
# 5. Prepare Data for Contrastive VAE
# =============================================================================

def prepare_data_for_contrastive_vae(adata_wt, adata_ko):
    """
    Prepare data for VAE model

    Converts the expression matrices to dense NumPy arrays. When both objects
    carry a ``velocity`` layer, that layer is converted too.
    ``obs['haem_subclust_grouped']`` is encoded as integer labels using one
    sorted vocabulary shared by WT and KO.

    Parameters:
    -----------
    adata_wt, adata_ko : AnnData
        WT and KO datasets

    Returns:
    --------
    dict : Data dictionary with keys
        'wt_data', 'ko_data'         : dense ndarray (cells x genes)
        'wt_labels', 'ko_labels'     : int64 ndarray of cell-type indices
        'wt_velocity', 'ko_velocity' : dense ndarray, or None when either
                                       object lacks a 'velocity' layer
        'cell_types'                 : sorted list of label names (index -> name)
        'cell_type_to_idx'           : dict mapping name -> index
    """
    print("Preparing data for VAE model...")

    def _as_dense(matrix):
        # Sparse -> dense. np.asarray also turns an np.matrix or a DataFrame
        # (e.g. a velocity layer stored as a DataFrame) into a plain ndarray.
        if sparse.issparse(matrix):
            matrix = matrix.toarray()
        return np.asarray(matrix)

    # Use haem_subclust_grouped as labels
    wt_labels_raw = adata_wt.obs['haem_subclust_grouped'].astype(str)
    ko_labels_raw = adata_ko.obs['haem_subclust_grouped'].astype(str)

    # Encode common labels
    cell_types = sorted(set(wt_labels_raw.unique()) | set(ko_labels_raw.unique()))
    cell_type_to_idx = {ct: i for i, ct in enumerate(cell_types)}

    def _encode(labels_raw):
        # Explicit dtype so an empty group still yields an integer array.
        return np.array([cell_type_to_idx[ct] for ct in labels_raw], dtype=np.int64)

    wt_data = _as_dense(adata_wt.X)
    wt_labels = _encode(wt_labels_raw)
    ko_data = _as_dense(adata_ko.X)
    ko_labels = _encode(ko_labels_raw)

    # Prepare velocity data (if available)
    if 'velocity' in adata_wt.layers and 'velocity' in adata_ko.layers:
        wt_velocity = _as_dense(adata_wt.layers['velocity'])
        ko_velocity = _as_dense(adata_ko.layers['velocity'])
    else:
        print("Warning: velocity data not found")
        wt_velocity = None
        ko_velocity = None

    print(f"WT data shape: {wt_data.shape}, KO data shape: {ko_data.shape}")
    shared_label_idx = set(wt_labels) & set(ko_labels)
    print(f"Common cell types: {set([cell_types[i] for i in shared_label_idx])}")

    return {
        'wt_data': wt_data,
        'wt_labels': wt_labels,
        'ko_data': ko_data,
        'ko_labels': ko_labels,
        'wt_velocity': wt_velocity,
        'ko_velocity': ko_velocity,
        'cell_types': cell_types,
        'cell_type_to_idx': cell_type_to_idx
    }

# Build the arrays and label encodings consumed by the loaders and the model
data_dict = prepare_data_for_contrastive_vae(adata_wt=adata_wt_aligned, adata_ko=adata_ko_aligned)

# =============================================================================
# 6. Create Data Loaders
# =============================================================================

def create_dataloaders(data_dict, batch_size=64, shuffle=True, drop_last=False):
    """
    Create PyTorch DataLoader objects

    Parameters:
    -----------
    data_dict : dict
        Data dictionary from prepare_data_for_contrastive_vae. Only the
        'wt_data', 'wt_labels', 'ko_data' and 'ko_labels' entries are used.
    batch_size : int
        Batch size for training
    shuffle : bool, default True
        Reshuffle every epoch (use for training). Pass False for inference so
        batches follow the row order of ``data_dict``.
    drop_last : bool, default False
        Drop the final incomplete batch. NOTE(review): with a BatchNorm-based
        model (e.g. norm_type='batch'), a trailing batch of exactly one cell
        raises in training mode. Set this to True if n_cells % batch_size == 1.

    Returns:
    --------
    wt_loader, ko_loader : DataLoader
        Loaders yielding (float32 features, int64 labels) batches
    """
    def _make_loader(features, labels):
        # torch.tensor always copies, so later edits to data_dict do not leak in.
        dataset = TensorDataset(
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.long),
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
        )

    wt_loader = _make_loader(data_dict['wt_data'], data_dict['wt_labels'])
    ko_loader = _make_loader(data_dict['ko_data'], data_dict['ko_labels'])

    return wt_loader, ko_loader

# Mini-batch size shared by the WT and KO loaders (tune to dataset size / memory)
batch_size = 256
wt_loader, ko_loader = create_dataloaders(data_dict=data_dict, batch_size=batch_size)


In [None]:
# =============================================================================
# 7. Initialize and Train Model
# =============================================================================

# Model parameters
input_dim = data_dict['wt_data'].shape[1]  # number of genes (features per cell)
latent_dim = 10                            # latent space dimensions (adjustable)
hidden_dims = (256, 128, 64)               # hidden layer dimensions (adjustable)

vae_hparams = dict(
    tau=0.3,              # temperature parameter
    lambda_contrast=10,   # contrastive loss weight
    lambda_align=10,      # cluster alignment loss weight
    dropout_prob=0.2,     # dropout probability
    norm_type='batch',    # normalization type
)
model = WTKOContrastiveVAE(
    input_dim=input_dim,
    latent_dim=latent_dim,
    hidden_dims=hidden_dims,
    **vae_hparams,
)

trainer = WTKOTrainer(model)

train_config = dict(
    num_epochs=400,                 # number of epochs
    lr=1e-3,                        # learning rate
    weight_decay=1e-3,              # weight decay
    save_path='./models/wtko_vae',  # model save path (adjust as needed)
    verbose=True,
)
print("Starting model training...")
history = trainer.train(wt_loader=wt_loader, ko_loader=ko_loader, **train_config)

In [None]:
# =============================================================================
# 8. Extract Latent Representations and Visualize
# =============================================================================

# Get latent representations from trained model.
# The training loaders shuffle, so iterating them here would return cells in a
# fresh random order on every run. Non-shuffled loaders over the same datasets
# keep row i aligned with cell i of adata_wt_aligned / adata_ko_aligned. This
# assumes get_latent_representations concatenates batches in loader order
# (TODO confirm in trainers.py). A fixed order also makes the UMAP below
# reproducible, because UMAP's output can depend on input row order.
wt_eval_loader = DataLoader(wt_loader.dataset, batch_size=wt_loader.batch_size, shuffle=False)
ko_eval_loader = DataLoader(ko_loader.dataset, batch_size=ko_loader.batch_size, shuffle=False)

wt_latent, wt_labels = trainer.get_latent_representations(wt_eval_loader)
ko_latent, ko_labels = trainer.get_latent_representations(ko_eval_loader)

# Convert to numpy arrays
wt_latent_np = wt_latent.cpu().numpy()
ko_latent_np = ko_latent.cpu().numpy()
wt_labels_np = wt_labels.cpu().numpy()
ko_labels_np = ko_labels.cpu().numpy()

# Project WT and KO jointly to 2D so both share one embedding space
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
combined_latent = np.vstack([wt_latent_np, ko_latent_np])
combined_embedding = reducer.fit_transform(combined_latent)

# Split projection results back into WT / KO (WT rows come first)
n_wt = len(wt_latent_np)
wt_embedding = combined_embedding[:n_wt]
ko_embedding = combined_embedding[n_wt:]
assert len(wt_embedding) == len(wt_labels_np) and len(ko_embedding) == len(ko_labels_np)

# Cell type names list (index -> name)
cell_type_names = data_dict['cell_types']

# =============================================================================
# 9. Visualization Functions and Plots
# =============================================================================

def plot_latent_space(wt_embedding, ko_embedding, wt_labels, ko_labels, cell_type_names):
    """
    Visualize WT and KO latent-space projections in two side-by-side panels.

    Parameters:
    -----------
    wt_embedding, ko_embedding : ndarray
        2D UMAP embeddings for WT and KO cells (column 0 = x, column 1 = y)
    wt_labels, ko_labels : ndarray
        Integer cell-type labels indexing into ``cell_type_names``
    cell_type_names : list
        List of cell type names
    """
    n_cell_types = len(cell_type_names)

    # Qualitative colormaps have a fixed number of entries (tab10: 10,
    # tab20: 20). Sampling tab10 with np.linspace repeats colours once there
    # are more than 10 types (e.g. 15 blood subclusters), so index colours
    # directly. For 11-20 types, tab20's dark shades come first, then the
    # light ones, so the first 10 hues are all distinct.
    if n_cell_types <= 10:
        colors = [plt.cm.tab10(i) for i in range(n_cell_types)]
    elif n_cell_types <= 20:
        hue_first = list(range(0, 20, 2)) + list(range(1, 20, 2))
        colors = [plt.cm.tab20(i) for i in hue_first[:n_cell_types]]
    else:
        colors = plt.cm.hsv(np.linspace(0, 1, n_cell_types, endpoint=False))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))
    panels = (
        (ax1, wt_embedding, wt_labels, 'WT Cells - Latent Space (UMAP)'),
        (ax2, ko_embedding, ko_labels, 'KO Cells - Latent Space (UMAP)'),
    )

    for ax, embedding, labels, title in panels:
        for ct in range(n_cell_types):
            mask = labels == ct
            if mask.sum() > 0:
                ax.scatter(
                    embedding[mask, 0],
                    embedding[mask, 1],
                    c=[colors[ct]],
                    label=cell_type_names[ct],
                    s=50,
                    alpha=0.7
                )
        ax.set_title(title, fontsize=14)
        ax.set_xlabel('UMAP 1')
        ax.set_ylabel('UMAP 2')
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()

def plot_combined_latent_space(wt_embedding, ko_embedding, wt_labels, ko_labels, cell_type_names):
    """
    Visualize WT and KO cells in combined space (color=cell type, marker=WT/KO)

    WT cells are drawn as circles without edges, and KO cells as squares with
    a black edge.

    Parameters:
    -----------
    wt_embedding, ko_embedding : ndarray
        2D UMAP embeddings for WT and KO cells (column 0 = x, column 1 = y)
    wt_labels, ko_labels : ndarray
        Integer cell-type labels indexing into ``cell_type_names``
    cell_type_names : list
        List of cell type names
    """
    n_cell_types = len(cell_type_names)

    # Qualitative colormaps have a fixed number of entries (tab10: 10,
    # tab20: 20). Sampling tab10 with np.linspace repeats colours once there
    # are more than 10 types, so index colours directly (tab20 dark shades
    # first, then light).
    if n_cell_types <= 10:
        colors = [plt.cm.tab10(i) for i in range(n_cell_types)]
    elif n_cell_types <= 20:
        hue_first = list(range(0, 20, 2)) + list(range(1, 20, 2))
        colors = [plt.cm.tab20(i) for i in hue_first[:n_cell_types]]
    else:
        colors = plt.cm.hsv(np.linspace(0, 1, n_cell_types, endpoint=False))

    # Combine data
    combined_embedding = np.vstack([wt_embedding, ko_embedding])
    combined_labels = np.concatenate([wt_labels, ko_labels])
    combined_group = np.array(['WT'] * len(wt_labels) + ['KO'] * len(ko_labels))

    fig, ax = plt.subplots(figsize=(10, 8))

    for i in range(n_cell_types):
        for group, marker in (('WT', 'o'), ('KO', 's')):
            mask = (combined_labels == i) & (combined_group == group)
            if mask.sum() > 0:
                ax.scatter(
                    combined_embedding[mask, 0],
                    combined_embedding[mask, 1],
                    c=[colors[i]],
                    label=f"{cell_type_names[i]} ({group})",
                    s=50,
                    alpha=0.7,
                    marker=marker,
                    edgecolor='k' if group == 'KO' else 'none',
                    linewidths=0.5
                )

    ax.set_title('WT and KO Cells - Combined Latent Space (UMAP)', fontsize=14)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1)
    ax.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()

In [None]:
# Generate visualizations
print("Generating latent space visualizations...")

# Both plotting helpers take the same five inputs
embedding_inputs = {
    'wt_embedding': wt_embedding,
    'ko_embedding': ko_embedding,
    'wt_labels': wt_labels_np,
    'ko_labels': ko_labels_np,
    'cell_type_names': cell_type_names,
}

# Side-by-side WT / KO panels
plot_latent_space(**embedding_inputs)

# Single panel: colour = cell type, marker = WT (circle) / KO (square)
plot_combined_latent_space(**embedding_inputs)

banner = "=" * 60
print(banner)
print("Pipeline execution completed successfully!")
print(banner)