In [None]:
import sys
sys.path.insert(0, "/gpfs0/bgu-ofircohen/users/likhtepi/proj/scvi-tools/")

In [11]:
import scanpy as sc
import scvelo as scv
import numpy as np
from pathlib import Path
from scvi.external.velovi import VELOVI
from scvi.experimental.velovi_improvements.transformer_velocity import (
    TransformerConfig,
    refine_velocities_with_transformer,
    _filter_neighbors_by_distance,
    _VelocitySequenceDataset,
    VelocityTransformer
)
import torch

In [3]:
# --------------------------
# 1. Paths / configuration
# --------------------------
DATA_PATH = Path("/gpfs0/bgu-ofircohen/users/likhtepi/CellRanger/files_adata/merged_2311_labeled.h5ad")
GROUP_KEY = "cluster"        # obs column with values {"H","S","R"}
COLOR_KEY = "celltype"       # obs column for coloring stream plots

BASELINE_CKPT = Path(
    "/gpfs0/bgu-ofircohen/users/likhtepi/proj/scvi-tools/checkpoints/"
    "adata_combined_nopre_nh256_nl1_nz10_bs256_ep400_encmlp_baseline"
)
TRANSFORMER_PT = Path(
    "/gpfs0/bgu-ofircohen/users/likhtepi/proj/scvi-tools/results/"
    "velovi_adata_combined_basic_10/velovi_adata_combined_transformer.pt"
)

In [4]:
transformer_cfg = TransformerConfig(
    n_layers=3,
    n_heads=8,
    hidden_dim=256,
    dropout=0.1,
    batch_size=128,
    epochs=20,
    learning_rate=1e-3,
    weight_alignment=1.0,
    weight_smooth=0.02,
    weight_smooth_same=0.02,
    weight_boundary_align=0.2,
    weight_boundary_contrast=0.05,
    weight_direction=0.7,
    weight_celltype=0.5,
    weight_celltype_dir=0.35,
    weight_celltype_mag=0.0,
    aux_cluster_loss_weight=0.2,
    neighbor_max_distance=4.0,
    max_neighbors=15,
    residual_to_baseline=True,
)

In [5]:
# --------------------------
# 2. Load AnnData + subsets
# --------------------------
adata_full = sc.read_h5ad(DATA_PATH)
assert "X_umap" in adata_full.obsm_keys(), "The AnnData must already contain X_umap."

subsets = {
    "all": adata_full,
    # "H": adata_full[adata_full.obs[GROUP_KEY] == "H"].copy(),
    # "S": adata_full[adata_full.obs[GROUP_KEY] == "S"].copy(),
    # "R": adata_full[adata_full.obs[GROUP_KEY] == "R"].copy(),
}


In [6]:
def load_pretrained_baseline(adata):
    """Load pretrained VELOVI baseline and infer baseline velocity + latent."""
    VELOVI.setup_anndata(
        adata,
        spliced_layer="Ms",
        unspliced_layer="Mu",
    )
    model = VELOVI.load(BASELINE_CKPT, adata=adata)
    velocity = model.get_velocity(adata=adata, return_numpy=True)
    latent = model.get_latent_representation(adata=adata)
    return model, velocity, latent

In [7]:
def build_neighbor_indices(adata, k=30):
    """Use existing connectivities on the subset; no UMAP recomputation."""
    graph = adata.obsp["connectivities"].tocsr()
    n = adata.n_obs
    idx = np.zeros((n, k), dtype=np.int64)
    for i in range(n):
        row = graph[i].toarray().ravel()
        row[i] = 0.0
        if np.allclose(row, 0):
            idx[i] = i
        else:
            idx[i] = np.argsort(-row)[:k]
    return idx


In [8]:
def apply_pretrained_transformer(adata, latent, baseline_velocity, cfg, state_path):
    """Run inference through the pretrained transformer."""
    embedding = np.asarray(adata.obsm["X_umap"], dtype=np.float32)
    neighbor_indices = build_neighbor_indices(adata, k=cfg.max_neighbors or 30)
    if cfg.neighbor_max_distance is not None:
        neighbor_indices = _filter_neighbors_by_distance(
            neighbor_indices,
            embedding,
            cfg.neighbor_max_distance,
        )

    dataset = _VelocitySequenceDataset(
        latent=latent,
        embedding=embedding,
        baseline_velocity=baseline_velocity,
        neighbor_indices=neighbor_indices,
        velocity_components=None,
        projection=None,
        config=cfg,
        cell_type_ids=None,
        type_means=None,
        cluster_labels=None,
        cluster_edge_list=None,
        pseudotime=None,
        alignment_vectors=None,
        supervised_target=None,
        supervised_weight=None,
    )
    seq_len = dataset.neighbor_indices.shape[1] + 1
    input_dim = dataset.feature_matrix.shape[1]
    output_dim = baseline_velocity.shape[1]
    direction_dim = dataset.embedding.shape[1]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VelocityTransformer(
        input_dim=input_dim,
        output_dim=output_dim,
        direction_dim=direction_dim,
        config=cfg,
        seq_len=seq_len,
        num_clusters=dataset.n_clusters,
    ).to(device)
    state = torch.load(state_path, map_location=device)
    model.load_state_dict(state, strict=False)
    model.eval()

    preds = []
    loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, shuffle=False)
    with torch.no_grad():
        for batch in loader:
            feats = batch["features"].to(device)
            seq_vel = batch["sequence_velocity"].to(device)
            seq_dir = batch["sequence_direction"].to(device)
            token_type = batch["token_type"].to(device)
            pred, _, _ = model(feats, seq_vel, seq_dir, token_type)
            preds.append(pred.cpu().numpy())
    refined = np.vstack(preds).astype(np.float32)
    return refined

In [9]:
def plot_baseline_vs_transformer(adata, color_key, title):
    fig, axes = plt.subplots(1, 2, figsize=(10, 4), dpi=150)
    scv.tl.velocity_graph(adata, vkey="velocity_baseline", n_jobs=32, mode_neighbors="distances")
    scv.pl.velocity_embedding_stream(
        adata,
        basis="umap",
        vkey="velocity_baseline",
        color=color_key,
        colorbar=True,
        ax=axes[0],
        title="Baseline",
        show=False,
    )
    scv.tl.velocity_graph(adata, vkey="velocity_transformer", n_jobs=32, mode_neighbors="distances")
    scv.pl.velocity_embedding_stream(
        adata,
        basis="umap",
        vkey="velocity_transformer",
        color=color_key,
        colorbar=True,
        ax=axes[1],
        title="Transformer refiner",
        show=False,
    )
    fig.suptitle(title, fontsize=13)
    plt.tight_layout()
    plt.show()

In [12]:
trained = {}
for label, subset in subsets.items():
    print(f"\n=== {label.upper()} subset ({subset.n_obs} cells) ===")
    _, vel_base, latent = load_pretrained_baseline(subset)
    subset.layers["velocity_baseline"] = vel_base
    vel_trans = apply_pretrained_transformer(subset, latent, vel_base, transformer_cfg, TRANSFORMER_PT)
    subset.layers["velocity_transformer"] = vel_trans
    trained[label] = subset



=== ALL subset (30802 cells) ===
[34mINFO    [0m File                                                                                                      
         [35m/gpfs0/bgu-ofircohen/users/likhtepi/proj/scvi-tools/checkpoints/adata_combined_nopre_nh256_nl1_nz10_bs256_[0m
         [35mep400_encmlp_baseline/[0m[95mmodel.pt[0m already downloaded                                                         


In [13]:
import pandas as pd

results = []

for subset_label, subset in trained.items():
    # Iterate through cells
    for cell_idx in range(subset.n_obs):
        cluster = subset.obs.iloc[cell_idx]["cluster"]
        celltype = subset.obs.iloc[cell_idx]["celltype"]
        
        # Get velocity values for this cell across all genes
        vel_base = subset.layers["velocity_baseline"][cell_idx, :]
        vel_trans = subset.layers["velocity_transformer"][cell_idx, :]
        
        # Create a row for each gene
        for gene_idx, gene_name in enumerate(subset.var_names):
            results.append({
                "subset_type": cluster,
                "celltype": celltype,
                "gene_name": gene_name,
                "velocity_baseline": vel_base[gene_idx],
                "velocity_transformer": vel_trans[gene_idx]
            })

# Create DataFrame and save to CSV
df = pd.DataFrame(results)
df.to_csv("velocity_results.csv", index=False)

print(f"Saved {len(df)} rows to velocity_results.csv")

Saved 61604000 rows to velocity_results.csv


In [31]:
import pandas as pd
import gseapy as gp
import matplotlib.pyplot as plt
import numpy as np

# ============= LOAD AND FILTER DATA =============

print("Loading data from trained subsets...")

# Create results dataframe with ALL genes
results = []

for subset_label, subset in trained.items():
    for cell_idx in range(subset.n_obs):
        cluster = subset.obs.iloc[cell_idx]["cluster"]
        celltype = subset.obs.iloc[cell_idx]["celltype"]
        
        vel_base = subset.layers["velocity_baseline"][cell_idx, :]
        vel_trans = subset.layers["velocity_transformer"][cell_idx, :]
        
        for gene_idx, gene_name in enumerate(subset.var_names):
            results.append({
                "subset_type": cluster,
                "celltype": celltype,
                "gene_name": gene_name,
                "velocity_baseline": vel_base[gene_idx],
                "velocity_transformer": vel_trans[gene_idx],
            })

df = pd.DataFrame(results)

# Filter to specific celltypes
celltypes_of_interest = ["aHSC", "dHSC", "cyclHSC"]
df = df[df["celltype"].isin(celltypes_of_interest)]

print(f"Filtered to celltypes: {celltypes_of_interest}")
print(f"Total cells after filtering: {len(df)}\n")

# Get unique subset types
subset_types = sorted(df["subset_type"].unique())
print(f"Available subset types: {subset_types}\n")

# ============= AGGREGATE BY GENE FOR EACH SUBSET TYPE =============

subset_data = {}
for subset_type in subset_types:
    subset_df = df[df["subset_type"] == subset_type].groupby("gene_name").agg({
        "velocity_baseline": "mean",
        "velocity_transformer": "mean",
    }).reset_index()
    subset_data[subset_type] = subset_df
    print(f"{subset_type}: {len(subset_df)} genes")

# ============= ENRICHMENT FUNCTION =============

def create_nes_like_score(p_val, overlap_size, total_size):
    """Calculate a normalized enrichment score-like metric from enrichment results."""
    return overlap_size / total_size * -np.log10(p_val)

def run_enrichment(gene_velocity_df, subset_name, method_name, comparison_name, organism="Mouse"):
    """
    Run pathway enrichment using WikiPathways and Reactome.
    Returns dataframe with results.
    """
    # Sort genes by absolute velocity score (most significant first)
    ranked_genes = gene_velocity_df.reindex(
        gene_velocity_df["score"].abs().sort_values(ascending=False).index
    )
    
    print(f"  {subset_name} ({method_name}): {len(ranked_genes)} genes")
    
    try:
        # Run enrichment on WikiPathways and Reactome (Mouse databases)
        results = gp.enrichr(
            gene_list=ranked_genes["gene_name"].tolist(),
            organism=organism,
            gene_sets=["WikiPathways_2024_Mouse", "Reactome_2022"]
        )
        
        if len(results.results) == 0:
            print(f"    No pathways found")
            return None
        
        # Process results
        plot_df = results.results.copy()
        
        # Calculate NES (Normalized Enrichment Score)
        plot_df["NES"] = plot_df.apply(
            lambda row: create_nes_like_score(
                float(row["Adjusted P-value"]), 
                int(row["Overlap"].split("/")[0]),
                int(row["Overlap"].split("/")[1])
            ), 
            axis=1
        )
        
        # Extract p-adjusted value and overlap size
        plot_df["padj"] = plot_df["Adjusted P-value"].astype(float)
        plot_df["size"] = plot_df["Overlap"].apply(lambda x: int(x.split("/")[0]))
        
        plot_df = plot_df.sort_values("NES")
        
        # Save enrichment results
        fname = f"enrichment_{comparison_name}_{subset_name}_{method_name}.csv"
        plot_df[["Term", "NES", "padj", "size", "Overlap"]].to_csv(fname, index=False)
        print(f"    ✓ Saved {fname} ({len(plot_df)} pathways)")
        
        # Create scatter plot
        fig, ax = plt.subplots(figsize=(11, 10))
        
        scatter = ax.scatter(
            plot_df["NES"], 
            range(len(plot_df)), 
            c=-np.log10(plot_df["padj"]), 
            s=plot_df["size"] * 5,
            cmap="RdYlBu_r", 
            alpha=0.7, 
            edgecolors='black', 
            linewidth=0.5
        )
        
        ax.set_yticks(range(len(plot_df)))
        ax.set_yticklabels([name[:50] for name in plot_df["Term"]], fontsize=8)
        ax.set_xlabel("NES", fontsize=11, fontweight='bold')
        ax.set_ylabel("Pathway", fontsize=11, fontweight='bold')
        ax.set_title(f"{comparison_name.upper()} - {subset_name} ({method_name})", 
                    fontsize=12, fontweight='bold')
        ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, linewidth=1)
        
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label("-log10(padj)", fontsize=10, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='x')
        
        plt.tight_layout()
        pngname = f"dotplot_{comparison_name}_{subset_name}_{method_name}.png"
        plt.savefig(pngname, dpi=300, bbox_inches='tight')
        plt.close()
        
        return plot_df
        
    except Exception as e:
        print(f"    ✗ Error: {e}")
        return None

# ============= RUN ALL COMPARISONS =============

comparisons = [("S", "R"), ("S", "H"), ("R", "H")]

for subset1, subset2 in comparisons:
    comparison_name = f"{subset1.lower()}vs{subset2.lower()}"
    
    print(f"\n{'='*70}")
    print(f"COMPARISON: {subset1} vs {subset2}")
    print(f"{'='*70}")
    
    df1 = subset_data[subset1]
    df2 = subset_data[subset2]
    
    # ---- BASELINE METHOD ----
    print(f"\nBaseline velocity:")
    
    baseline_1 = df1[["gene_name", "velocity_baseline"]].copy()
    baseline_1.columns = ["gene_name", "score"]
    run_enrichment(baseline_1, subset1.lower(), "baseline", comparison_name)
    
    baseline_2 = df2[["gene_name", "velocity_baseline"]].copy()
    baseline_2.columns = ["gene_name", "score"]
    run_enrichment(baseline_2, subset2.lower(), "baseline", comparison_name)
    
    # ---- TRANSFORMER METHOD ----
    print(f"\nTransformer velocity:")
    
    transformer_1 = df1[["gene_name", "velocity_transformer"]].copy()
    transformer_1.columns = ["gene_name", "score"]
    run_enrichment(transformer_1, subset1.lower(), "transformer", comparison_name)
    
    transformer_2 = df2[["gene_name", "velocity_transformer"]].copy()
    transformer_2.columns = ["gene_name", "score"]
    run_enrichment(transformer_2, subset2.lower(), "transformer", comparison_name)

# ============= FILTER TO CHOSEN PATHWAYS =============

chosen_pathways = {
    "svsr": ["WP_TYPE_II_INTERFERON_SIGNALING_IFNG", "REACTOME_DNA_STRAND_ELONGATION", 
             "REACTOME_SCF_SKP2_MEDIATED_DEGRADATION_OF_P27_P21", "REACTOME_ANTIGEN_PROCESSING_CROSS_PRESENTATION",
             "REACTOME_DNA_REPLICATION", "REACTOME_LAGGING_STRAND_SYNTHESIS", "REACTOME_G1_S_TRANSITION",
             "REACTOME_G2_M_CHECKPOINTS", "REACTOME_PYROPTOSIS"],
    "svsh": ["WP_TYPE_II_INTERFERON_SIGNALING_IFNG", "REACTOME_DNA_STRAND_ELONGATION",
             "REACTOME_SCF_SKP2_MEDIATED_DEGRADATION_OF_P27_P21", "REACTOME_ANTIGEN_PROCESSING_CROSS_PRESENTATION",
             "REACTOME_DNA_REPLICATION", "REACTOME_LAGGING_STRAND_SYNTHESIS", "REACTOME_G1_S_TRANSITION",
             "REACTOME_G2_M_CHECKPOINTS", "REACTOME_PYROPTOSIS"],
    "rvsh": ["WP_TYPE_II_INTERFERON_SIGNALING_IFNG", "REACTOME_GENERATION_OF_SECOND_MESSENGER_MOLECULES",
             "REACTOME_PHOSPHORYLATION_OF_CD3_AND_TCR_ZETA_CHAINS", "REACTOME_TRANSLOCATION_OF_ZAP_70_TO_IMMUNOLOGICAL_SYNAPSE",
             "REACTOME_MHC_CLASS_II_ANTIGEN_PRESENTATION", "REACTOME_MITOCHONDRIAL_TRANSLATION",
             "BIOCARTA_MONOCYTE_PATHWAY", "REACTOME_PD_1_SIGNALING"]
}

print(f"\n\n{'='*70}")
print("FILTERING TO CHOSEN PATHWAYS")
print(f"{'='*70}\n")

def filter_and_plot_chosen(comparison_name, chosen_list, subset1, subset2):
    """Filter enrichment results to chosen pathways and create comparison plots."""
    print(f"\n{comparison_name.upper()}:")
    
    all_results = []
    
    for subset in [subset1.lower(), subset2.lower()]:
        for method in ["baseline", "transformer"]:
            fname = f"enrichment_{comparison_name}_{subset}_{method}.csv"
            try:
                df = pd.read_csv(fname)
                print(f"\n  {subset.upper()} - {method}:")
                print(f"    Total pathways: {len(df)}")
                
                # Try to match pathways
                filtered = df[df["Term"].str.contains("|".join(chosen_list), case=False, na=False)]
                
                if len(filtered) > 0:
                    filtered["subset"] = subset.upper()
                    filtered["method"] = method
                    all_results.append(filtered)
                    print(f"    ✓ Matched {len(filtered)} chosen pathways")
                else:
                    print(f"    ✗ No matches (showing first 5 available pathways):")
                    for idx, term in enumerate(df["Term"].head(5)):
                        print(f"       - {term}")
                    
            except FileNotFoundError:
                print(f"  {subset.upper()} - {method}: File not found")
    
    if all_results:
        combined = pd.concat(all_results, ignore_index=True)
        combined = combined.sort_values("NES")
        combined.to_csv(f"chosen_pathways_{comparison_name}.csv", index=False)
        
        # Create 4-panel comparison plot
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        fig.suptitle(f"Chosen Pathways: {comparison_name.upper()}", fontsize=14, fontweight='bold')
        
        for row, subset in enumerate([subset1.lower(), subset2.lower()]):
            for col, method in enumerate(["baseline", "transformer"]):
                ax = axes[row, col]
                subset_data_plot = combined[(combined["subset"] == subset.upper()) & (combined["method"] == method)]
                
                if len(subset_data_plot) > 0:
                    subset_data_plot = subset_data_plot.sort_values("NES")
                    
                    scatter = ax.scatter(
                        subset_data_plot["NES"], 
                        range(len(subset_data_plot)), 
                        c=-np.log10(subset_data_plot["padj"]), 
                        s=subset_data_plot["size"] * 8,
                        cmap="RdYlBu_r", 
                        alpha=0.7, 
                        edgecolors='black', 
                        linewidth=1
                    )
                    
                    ax.set_yticks(range(len(subset_data_plot)))
                    ax.set_yticklabels([name[:45] for name in subset_data_plot["Term"]], fontsize=9)
                    ax.set_xlabel("NES", fontsize=10, fontweight='bold')
                    ax.set_title(f"{subset.upper()} - {method}", fontsize=11, fontweight='bold')
                    ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, linewidth=1)
                    ax.grid(True, alpha=0.3, axis='x')
                    
                    cbar = plt.colorbar(scatter, ax=ax)
                    cbar.set_label("-log10(padj)", fontsize=9)
        
        plt.tight_layout()
        plt.savefig(f"chosen_pathways_{comparison_name}_comparison.png", dpi=300, bbox_inches='tight')
        plt.close()
        print(f"\n  ✓ Saved chosen_pathways_{comparison_name}.csv")
        print(f"  ✓ Saved chosen_pathways_{comparison_name}_comparison.png")
    else:
        print(f"\n  ✗ No matching pathways found")

# Run filtering for all comparisons
filter_and_plot_chosen("svsr", chosen_pathways["svsr"], "S", "R")
filter_and_plot_chosen("svsh", chosen_pathways["svsh"], "S", "H")
filter_and_plot_chosen("rvsh", chosen_pathways["rvsh"], "R", "H")

print(f"\n\n{'='*70}")
print("✓ ANALYSIS COMPLETE!")
print(f"{'='*70}")
print("\nOutput files:")
print("  - enrichment_*.csv: All pathway enrichment results")
print("  - dotplot_*.png: Individual enrichment scatter plots")
print("  - chosen_pathways_*.csv: Filtered to your chosen pathways")
print("  - chosen_pathways_*_comparison.png: 2x2 comparison plots")

Loading data from trained subsets...
Filtered to celltypes: ['aHSC', 'dHSC', 'cyclHSC']
Total cells after filtering: 6754000

Available subset types: ['H', 'R', 'S']

H: 2000 genes
R: 2000 genes
S: 2000 genes

COMPARISON: S vs R

Baseline velocity:
  s (baseline): 2000 genes
    ✓ Saved enrichment_svsr_s_baseline.csv (2970 pathways)
  r (baseline): 2000 genes
    ✓ Saved enrichment_svsr_r_baseline.csv (2970 pathways)

Transformer velocity:
  s (transformer): 2000 genes
    ✓ Saved enrichment_svsr_s_transformer.csv (2970 pathways)
  r (transformer): 2000 genes
    ✓ Saved enrichment_svsr_r_transformer.csv (2970 pathways)

COMPARISON: S vs H

Baseline velocity:
  s (baseline): 2000 genes
    ✓ Saved enrichment_svsh_s_baseline.csv (2970 pathways)
  h (baseline): 2000 genes
    ✓ Saved enrichment_svsh_h_baseline.csv (2970 pathways)

Transformer velocity:
  s (transformer): 2000 genes
    ✓ Saved enrichment_svsh_s_transformer.csv (2970 pathways)
  h (transformer): 2000 genes
    ✓ Saved enr

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

print("="*70)
print("RANK IMPROVEMENT: BASELINE vs TRANSFORMER")
print("="*70 + "\n")

# Load data
df_baseline = pd.read_csv("/gpfs0/bgu-ofircohen/users/likhtepi/proj/scvi-tools/results/top_100_genes_per_cluster_baseline.csv")
df_transformer = pd.read_csv("/gpfs0/bgu-ofircohen/users/likhtepi/proj/scvi-tools/results/top_100_genes_per_cluster_transformer.csv")

print(f"Baseline data shape: {df_baseline.shape}")
print(f"Transformer data shape: {df_transformer.shape}")
print(f"Columns: {df_baseline.columns.tolist()}\n")

# Filter to HSCs and subset_type=='all'
celltypes = ['aHSC', 'dHSC', 'cyclHSC']
df_baseline = df_baseline[(df_baseline['celltype'].isin(celltypes)) & (df_baseline['subset_type'] == 'all')]
df_transformer = df_transformer[(df_transformer['celltype'].isin(celltypes)) & (df_transformer['subset_type'] == 'all')]

print(f"After filtering to {celltypes} with subset_type=='all':")
print(f"  Baseline: {len(df_baseline)} entries")
print(f"  Transformer: {len(df_transformer)} entries\n")

# Define genes of interest
genes_of_interest = {
    "Myeloid Priming": ["S100a9", "S100a8"],
    "Atypical Effectors": ["Nkg7", "Rhoh"],
    "Novel Cycling": ["Top2a", "Prtn3"],
    "Housekeeping": ["Malat1", "Tmsb4x", "Actb"],
    "Myeloid Priming": ["Elane"]
}

all_target_genes = [g for genes in genes_of_interest.values() for g in genes]

print(f"Tracking {len(all_target_genes)} genes of interest:\n")
for category, genes in genes_of_interest.items():
    print(f"  {category}: {genes}")

# ============= RANK GENES BY VELOCITY =============

# Aggregate velocity by gene (mean across cells)
baseline_agg = df_baseline.groupby("gene_name")["velocity"].mean().reset_index()
baseline_agg.columns = ["gene", "velocity"]
baseline_agg["rank_baseline"] = baseline_agg["velocity"].rank(ascending=False, method='min')

transformer_agg = df_transformer.groupby("gene_name")["velocity"].mean().reset_index()
transformer_agg.columns = ["gene", "velocity"]
transformer_agg["rank_transformer"] = transformer_agg["velocity"].rank(ascending=False, method='min')

print(f"\nTotal genes ranked:")
print(f"  Baseline: {len(baseline_agg)}")
print(f"  Transformer: {len(transformer_agg)}\n")

# ============= EXTRACT TARGET GENES =============

comparison_data = []

for gene in all_target_genes:
    rank_base = baseline_agg[baseline_agg["gene"] == gene]["rank_baseline"].values
    rank_trans = transformer_agg[transformer_agg["gene"] == gene]["rank_transformer"].values
    
    vel_base = baseline_agg[baseline_agg["gene"] == gene]["velocity"].values
    vel_trans = transformer_agg[transformer_agg["gene"] == gene]["velocity"].values
    
    rank_base = rank_base[0] if len(rank_base) > 0 else np.nan
    rank_trans = rank_trans[0] if len(rank_trans) > 0 else np.nan
    vel_base = vel_base[0] if len(vel_base) > 0 else np.nan
    vel_trans = vel_trans[0] if len(vel_trans) > 0 else np.nan
    
    category = [cat for cat, genes in genes_of_interest.items() if gene in genes][0]
    
    comparison_data.append({
        "gene": gene,
        "category": category,
        "rank_baseline": rank_base,
        "rank_transformer": rank_trans,
        "velocity_baseline": vel_base,
        "velocity_transformer": vel_trans,
        "rank_change": rank_base - rank_trans if not np.isnan(rank_base) and not np.isnan(rank_trans) else np.nan
    })

df_comparison = pd.DataFrame(comparison_data)
df_comparison = df_comparison.sort_values("rank_baseline")

print(df_comparison[["gene", "category", "rank_baseline", "rank_transformer", "rank_change"]].to_string(index=False))
print()

# Save comparison
df_comparison.to_csv("rank_improvement_baseline_vs_transformer.csv", index=False)
print(f"✓ Saved: rank_improvement_baseline_vs_transformer.csv\n")

# ============= VISUALIZATIONS =============

print("Generating visualizations...\n")

# Color map for categories
category_colors = {
    "Myeloid Priming": "#e74c3c",
    "Atypical Effectors": "#3498db", 
    "Novel Cycling": "#2ecc71",
    "Housekeeping": "#95a5a6"
}

# 1. Grouped bar plot
fig, ax = plt.subplots(figsize=(14, 8))

genes = df_comparison["gene"].values
x = np.arange(len(genes))
width = 0.35

gene_colors = [category_colors[cat] for cat in df_comparison["category"].values]

bars1 = ax.bar(x - width/2, df_comparison["rank_baseline"].values, width, 
              label='Baseline VeloVI', color=gene_colors, alpha=0.7, edgecolor='black', linewidth=1.5)
bars2 = ax.bar(x + width/2, df_comparison["rank_transformer"].values, width,
              label='Transformer VeloVI', color=gene_colors, alpha=1.0, edgecolor='black', linewidth=1.5, hatch='//')

# Add rank numbers
for bar1, bar2 in zip(bars1, bars2):
    h1 = bar1.get_height()
    h2 = bar2.get_height()
    if not np.isnan(h1):
        ax.text(bar1.get_x() + bar1.get_width()/2., h1 + 1,
               f'{int(h1)}', ha='center', va='bottom', fontsize=9, fontweight='bold')
    if not np.isnan(h2):
        ax.text(bar2.get_x() + bar2.get_width()/2., h2 + 1,
               f'{int(h2)}', ha='center', va='bottom', fontsize=9, fontweight='bold')

ax.set_xlabel("Gene", fontsize=12, fontweight='bold')
ax.set_ylabel("Rank", fontsize=12, fontweight='bold')
ax.set_title("Gene Rank Comparison: Baseline vs Transformer VeloVI\n(HSCs: aHSC + dHSC + cyclHSC, subset_type='all')", 
            fontsize=13, fontweight='bold', pad=20)
ax.set_xticks(x)
ax.set_xticklabels(genes, fontsize=11, fontweight='bold', rotation=45, ha='right')
ax.invert_yaxis()
ax.legend(fontsize=11, loc='lower left')
ax.grid(True, alpha=0.3, axis='y', linestyle='--')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig("rank_comparison_barplot.png", dpi=300, bbox_inches='tight')
plt.close()
print("✓ Saved: rank_comparison_barplot.png")

# 2. Scatter plot: Baseline vs Transformer rank
fig, ax = plt.subplots(figsize=(10, 8))

for category in genes_of_interest.keys():
    cat_data = df_comparison[df_comparison["category"] == category]
    ax.scatter(cat_data["rank_baseline"], cat_data["rank_transformer"],
              s=300, color=category_colors[category], alpha=0.7, edgecolors='black', linewidth=1.5,
              label=category)

# Diagonal line (no change)
max_rank = max(df_comparison["rank_baseline"].max(), df_comparison["rank_transformer"].max())
ax.plot([0, max_rank], [0, max_rank], 'k--', alpha=0.5, linewidth=2, label='No change')

# Add gene labels
for idx, row in df_comparison.iterrows():
    ax.annotate(row["gene"], (row["rank_baseline"], row["rank_transformer"]),
               fontsize=9, ha='center', va='bottom')

ax.set_xlabel("Rank (Baseline VeloVI)", fontsize=12, fontweight='bold')
ax.set_ylabel("Rank (Transformer VeloVI)", fontsize=12, fontweight='bold')
ax.set_title("Baseline vs Transformer Ranks\n(Points below diagonal = Promoted)", 
            fontsize=13, fontweight='bold')
ax.invert_yaxis()
ax.legend(fontsize=10, loc='upper left')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("scatter_rank_baseline_vs_transformer.png", dpi=300, bbox_inches='tight')
plt.close()
print("✓ Saved: scatter_rank_baseline_vs_transformer.png")

# 3. Rank change bar plot
fig, ax = plt.subplots(figsize=(10, 8))

df_sort = df_comparison.sort_values("rank_change", ascending=False)
colors = ['green' if x > 0 else 'red' for x in df_sort["rank_change"].fillna(0)]

bars = ax.barh(range(len(df_sort)), df_sort["rank_change"].fillna(0), 
              color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)

ax.set_yticks(range(len(df_sort)))
ax.set_yticklabels(df_sort["gene"].values, fontsize=11, fontweight='bold')
ax.set_xlabel("Rank Change (Baseline - Transformer)", fontsize=12, fontweight='bold')
ax.set_title("Rank Improvements: Baseline → Transformer\n(Positive = Promoted, Negative = Demoted)", 
            fontsize=13, fontweight='bold')
ax.axvline(x=0, color='black', linestyle='-', linewidth=1.5)
ax.grid(True, alpha=0.3, axis='x')

# Add values
for i, (bar, val) in enumerate(zip(bars, df_sort["rank_change"].fillna(0).values)):
    ax.text(val, i, f' {int(val)}', va='center', 
           ha='left' if val > 0 else 'right', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig("barplot_rank_changes.png", dpi=300, bbox_inches='tight')
plt.close()
print("✓ Saved: barplot_rank_changes.png")

# ============= SUMMARY =============

print("\n" + "="*70)
print("SUMMARY BY CATEGORY")
print("="*70 + "\n")

for category in genes_of_interest.keys():
    cat_data = df_comparison[df_comparison["category"] == category]
    promoted = (cat_data['rank_change'] > 0).sum()
    demoted = (cat_data['rank_change'] < 0).sum()
    print(f"{category}:")
    print(f"  Mean rank change: {cat_data['rank_change'].mean():.1f}")
    print(f"  Promoted: {promoted}/{len(cat_data)}")
    print(f"  Demoted: {demoted}/{len(cat_data)}")
    print()

print("="*70)
print("✓ ANALYSIS COMPLETE!")
print("="*70)

RANK IMPROVEMENT: BASELINE vs TRANSFORMER

Baseline data shape: (6400, 5)
Transformer data shape: (6400, 5)
Columns: ['Unnamed: 0', 'subset_type', 'celltype', 'gene_name', 'velocity']

After filtering to ['aHSC', 'dHSC', 'cyclHSC'] with subset_type=='all':
  Baseline: 200 entries
  Transformer: 200 entries

Tracking 8 genes of interest:

  Myeloid Priming: ['Elane']
  Atypical Effectors: ['Nkg7', 'Rhoh']
  Novel Cycling: ['Top2a', 'Prtn3']
  Housekeeping: ['Malat1', 'Tmsb4x', 'Actb']

Total genes ranked:
  Baseline: 106
  Transformer: 107

  gene           category  rank_baseline  rank_transformer  rank_change
 Prtn3      Novel Cycling            2.0               2.0          0.0
Malat1       Housekeeping            3.0               6.0         -3.0
Tmsb4x       Housekeeping            4.0               3.0          1.0
  Actb       Housekeeping           12.0              13.0         -1.0
 Elane    Myeloid Priming           53.0              31.0         22.0
  Rhoh Atypical Effect

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

print("="*70)
print("FINDING TOP 100 DIFFERENTIALLY EXPRESSED GENES")
print("(Baseline VeloVI vs Transformer VeloVI)")
print("="*70 + "\n")

# Collect all velocity data
all_data = []

for subset_label, subset in trained.items():
    for cell_idx in range(subset.n_obs):
        vel_base = subset.layers["velocity_baseline"][cell_idx, :]
        vel_trans = subset.layers["velocity_transformer"][cell_idx, :]
        
        for gene_idx, gene_name in enumerate(subset.var_names):
            all_data.append({
                "gene": gene_name,
                "velocity_baseline": vel_base[gene_idx],
                "velocity_transformer": vel_trans[gene_idx]
            })

df = pd.DataFrame(all_data)

# Aggregate by gene (mean velocity across all cells)
df_agg = df.groupby("gene")[["velocity_baseline", "velocity_transformer"]].mean().reset_index()

# Calculate differential expression metrics
df_agg["velocity_delta"] = df_agg["velocity_transformer"] - df_agg["velocity_baseline"]
df_agg["abs_delta"] = np.abs(df_agg["velocity_delta"])
df_agg["fold_change"] = df_agg["velocity_transformer"] / (np.abs(df_agg["velocity_baseline"]) + 1e-6)

# Get top 100 by absolute change
top100 = df_agg.nlargest(100, "abs_delta").copy()
top100 = top100.sort_values("velocity_delta", ascending=False).reset_index(drop=True)
top100["rank"] = range(1, len(top100) + 1)

print(f"Total genes analyzed: {len(df_agg)}")
print(f"Top 100 genes with largest velocity changes:\n")
print(top100[["rank", "gene", "velocity_baseline", "velocity_transformer", "velocity_delta"]].to_string())

# Save results
top100.to_csv("top100_differentially_expressed_genes.csv", index=False)
print(f"\n✓ Saved: top100_differentially_expressed_genes.csv")

# ============= VISUALIZATIONS =============

# 1. Scatter plot: Baseline vs Transformer
print("\nGenerating visualizations...")

fig, ax = plt.subplots(figsize=(12, 10))

# Plot all genes (light gray)
scatter_all = ax.scatter(df_agg["velocity_baseline"], df_agg["velocity_transformer"], 
                         alpha=0.1, s=20, color='gray', label=f'All genes (n={len(df_agg)})')

# Highlight top 100
scatter_top = ax.scatter(top100["velocity_baseline"], top100["velocity_transformer"],
                        alpha=0.7, s=100, color='red', edgecolors='black', linewidth=1, 
                        label='Top 100 DE genes')

# Diagonal line (no change)
min_val = min(df_agg["velocity_baseline"].min(), df_agg["velocity_transformer"].min())
max_val = max(df_agg["velocity_baseline"].max(), df_agg["velocity_transformer"].max())
ax.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.5, linewidth=2, label='No change')

ax.set_xlabel("Velocity (Baseline VeloVI)", fontsize=12, fontweight='bold')
ax.set_ylabel("Velocity (Transformer VeloVI)", fontsize=12, fontweight='bold')
ax.set_title("Top 100 Differentially Expressed Genes\nBaseline vs Transformer VeloVI", 
            fontsize=13, fontweight='bold')
ax.legend(fontsize=10, loc='upper left')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("scatter_baseline_vs_transformer_top100.png", dpi=300, bbox_inches='tight')
plt.close()
print("✓ Saved: scatter_baseline_vs_transformer_top100.png")

# 2. Volcano plot: Velocity Delta vs Magnitude
fig, ax = plt.subplots(figsize=(12, 10))

# Calculate magnitude
df_agg["magnitude_transformer"] = df_agg["velocity_transformer"] ** 2
df_agg["magnitude_baseline"] = df_agg["velocity_baseline"] ** 2
df_agg["log_magnitude_ratio"] = np.log2((df_agg["magnitude_transformer"] + 1e-6) / 
                                       (df_agg["magnitude_baseline"] + 1e-6))

# Scatter plot
scatter_all = ax.scatter(df_agg["velocity_delta"], df_agg["log_magnitude_ratio"],
                        alpha=0.2, s=30, color='gray', label=f'All genes (n={len(df_agg)})')

scatter_top = ax.scatter(top100["velocity_delta"], 
                        np.log2((top100["velocity_transformer"]**2 + 1e-6) / 
                               (top100["velocity_baseline"]**2 + 1e-6)),
                        alpha=0.7, s=100, color='red', edgecolors='black', linewidth=1,
                        label='Top 100 DE genes')

# Add reference lines
ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, linewidth=1.5)
ax.axhline(y=0, color='black', linestyle='--', alpha=0.5, linewidth=1.5)

ax.set_xlabel("Velocity Delta (Transformer - Baseline)", fontsize=12, fontweight='bold')
ax.set_ylabel("Log2(Magnitude Ratio: Transformer/Baseline)", fontsize=12, fontweight='bold')
ax.set_title("Volcano-like Plot: Velocity Changes in Top 100 Genes", fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("volcano_top100_de_genes.png", dpi=300, bbox_inches='tight')
plt.close()
print("✓ Saved: volcano_top100_de_genes.png")

# 3. Bar plot: Top 20 genes by velocity delta
top20 = top100.head(20).copy()

fig, ax = plt.subplots(figsize=(12, 8))

colors = ['green' if x > 0 else 'red' for x in top20["velocity_delta"]]
bars = ax.barh(range(len(top20)), top20["velocity_delta"], color=colors, alpha=0.7, edgecolor='black', linewidth=1)

ax.set_yticks(range(len(top20)))
ax.set_yticklabels(top20["gene"].values, fontsize=10, fontweight='bold')
ax.set_xlabel("Velocity Delta (Transformer - Baseline)", fontsize=12, fontweight='bold')
ax.set_title("Top 20 Differentially Expressed Genes by Velocity Change", fontsize=13, fontweight='bold')
ax.axvline(x=0, color='black', linestyle='-', linewidth=1.5)
ax.grid(True, alpha=0.3, axis='x')

# Add values on bars
for i, (bar, val) in enumerate(zip(bars, top20["velocity_delta"].values)):
    ax.text(val, i, f' {val:.4f}', va='center', ha='left' if val > 0 else 'right', fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig("barplot_top20_velocity_delta.png", dpi=300, bbox_inches='tight')
plt.close()
print("✓ Saved: barplot_top20_velocity_delta.png")

# ============= SUMMARY STATISTICS =============

print("\n" + "="*70)
print("SUMMARY STATISTICS (Top 100 Genes)")
print("="*70 + "\n")

print(f"Mean velocity change: {top100['velocity_delta'].mean():.6f}")
print(f"Median velocity change: {top100['velocity_delta'].median():.6f}")
print(f"Std deviation: {top100['velocity_delta'].std():.6f}")
print(f"\nGenes promoted (Transformer > Baseline): {(top100['velocity_delta'] > 0).sum()}")
print(f"Genes demoted (Transformer < Baseline): {(top100['velocity_delta'] < 0).sum()}")

print(f"\nTop 10 promoted genes:")
print(top100[top100["velocity_delta"] > 0][["rank", "gene", "velocity_baseline", "velocity_transformer", "velocity_delta"]].head(10).to_string(index=False))

print(f"\nTop 10 demoted genes:")
print(top100[top100["velocity_delta"] < 0][["rank", "gene", "velocity_baseline", "velocity_transformer", "velocity_delta"]].tail(10).to_string(index=False))

print(f"\n✓ ANALYSIS COMPLETE!")

FINDING TOP 100 DIFFERENTIALLY EXPRESSED GENES
(Baseline VeloVI vs Transformer VeloVI)



In [78]:
for g1, g2 in pairs:
    res = pre.res2d
    sig = res[res["FDR q-val"] < 0.1].sort_values("NES", ascending=False)
    sig.to_csv(f"gsea_{g1}_vs_{g2}_mice/significant_fdr1.csv", index=False)

In [79]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def plot_gsea_dotplot(group, csv_path, top_n=15, fdr_cutoff=0.05):
    df = pd.read_csv(csv_path)
    df = df[df["FDR q-val"] < fdr_cutoff].copy()
    if df.empty:
        print(f"{group}: no terms passed FDR<{fdr_cutoff}")
        return

    df = df.sort_values("NES", ascending=False).head(top_n)
    df["-log10FDR"] = -np.log10(df["FDR q-val"].clip(lower=1e-300))
    top_term = df.iloc[0]["Term"]

    plt.figure(figsize=(6, 0.35 * len(df) + 1))
    ax = sns.scatterplot(
        data=df,
        x="NES",
        y="Term",
        size="-log10FDR",
        hue="NES",
        palette="coolwarm",
        sizes=(40, 220),
        edgecolor="none",
        legend="brief",
    )

    # Emphasize the highest NES term
    top_df = df[df["Term"] == top_term]
    ax.scatter(
        top_df["NES"],
        top_df["Term"],
        s=260,
        facecolors="none",
        edgecolors="black",
        linewidths=1.5,
        zorder=10,
        label="Top NES",
    )

    ax.set_title(f"GSEA DotPlot – {group}")
    ax.set_xlabel("NES")
    ax.set_ylabel("")
    plt.tight_layout()
    plt.show()


