## Centroid rank accuracy (Norman) — all genes

This notebook computes the **centroid rank accuracy** metric across OOD perturbations for Norman, using **all genes**.

### Metric definition
For each perturbation $g$ and method $m$, let $\mu_g$ be the ground-truth mean expression vector (centroid) of stimulated cells for perturbation $g$, and $\hat\mu_{g,m}$ be the method’s predicted mean for $g$.

Define Euclidean distances to all GT centroids:

$$
d_{g\to h}^{(m)} = \lVert \hat\mu_{g,m} - \mu_h \rVert_2
$$

Then the **centroid rank accuracy** for \(g\) is:

$$
\mathrm{CRA}_g^{(m)} = \frac{1}{K-1}\sum_{h\neq g} \mathbf{1}\left[d_{g\to h}^{(m)} > d_{g\to g}^{(m)}\right]
$$

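Here $K$ is the number of perturbations being compared. The score lies in $[0,1]$: 1.0 means the prediction for $g$ is strictly closer to $\mu_g$ than to every other centroid $\mu_h$. The comparison is strict, so a centroid at exactly the same distance as $\mu_g$ does not count as farther.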

### Notes
- Uses **all genes** (fixed gene panel), so the metric is well-defined across perturbations.
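- GT centroids can optionally be restricted to the **OOD split** by setting `use_ood_only_for_gt = True` together with a valid `split_key`. As configured below (`use_ood_only_for_gt = False`, `split_key = None`), they are computed from **all cells** of each perturbation.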


In [1]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import scanpy as sc
from scipy.spatial.distance import cdist


def _to_dense(X):
    if isinstance(X, np.ndarray):
        return X
    return X.toarray()


def centroid_rank_accuracy(true_means: pd.DataFrame, pred_means: pd.DataFrame) -> pd.Series:
    """Score how well each predicted centroid singles out its own ground-truth centroid.

    Both frames must be laid out identically: rows are perturbations, columns
    are genes, with the same labels in the same order.

    For every perturbation g, the Euclidean distance from the predicted mean of
    g to each ground-truth centroid is computed. The score of g is the fraction
    of the other K-1 centroids that lie strictly farther away than g's own
    centroid.

    Returns: pd.Series named "CRA", indexed like `true_means`, values in [0, 1].

    Raises: ValueError if the gene labels or the perturbation labels of the two
    frames differ, or if fewer than 2 perturbations are supplied.
    """
    genes_match = list(true_means.columns) == list(pred_means.columns)
    if not genes_match:
        raise ValueError("Gene order mismatch between true_means and pred_means")
    perts_match = list(true_means.index) == list(pred_means.index)
    if not perts_match:
        raise ValueError("Perturbation order mismatch between true_means and pred_means")

    # dist[g, h] = distance from the prediction for g to the true centroid of h; shape (K, K)
    dist = cdist(pred_means.to_numpy(), true_means.to_numpy(), metric="euclidean")
    own_dist = dist.diagonal()
    n_perts = dist.shape[0]
    if n_perts < 2:
        raise ValueError("Need at least 2 perturbations to compute CRA")

    # Strict comparison: a centroid at exactly the same distance is not counted,
    # which also means the diagonal entry never counts towards its own row.
    n_farther = np.count_nonzero(dist > own_dist.reshape(-1, 1), axis=1)
    return pd.Series(n_farther / (n_perts - 1), index=true_means.index, name="CRA")


def list_available_perts(pred_dir: Path, method: str) -> list[str]:
    """Infer perturbation names from the `*.h5ad` filenames found in `pred_dir`.

    Files are visited in sorted path order and every filename stem yields one
    entry. When `method` is GEARS (compared case-insensitively), each underscore
    in the stem is replaced by a plus sign; for any other method the stem is
    returned unchanged.
    """
    return [
        h5ad.stem.replace("_", "+") if method.upper() == "GEARS" else h5ad.stem
        for h5ad in sorted(pred_dir.glob("*.h5ad"))
    ]


def load_pred_mean_from_file(fp: Path, pert: str, method: str) -> np.ndarray:
    """Load one prediction file and return the predicted mean vector for `pert`.

    The file is read with `sc.read_h5ad`, restricted to the rows whose
    `obs['cond_harm_pred']` equals `pert`, and `.X` is averaged over those rows.
    Every method is read the same way; there is no method-specific handling.

    Parameters
    ----------
    fp : path of the `.h5ad` prediction file.
    pert : perturbation label to select in `obs['cond_harm_pred']`.
    method : method name, used only to make error messages easier to trace.

    Returns
    -------
    1-D array with one entry per column of `.X`.

    Raises
    ------
    KeyError
        If the file has no `cond_harm_pred` column in `.obs`.
    ValueError
        If no row of the file is labelled `pert`. Averaging an empty selection
        would otherwise yield a NaN vector, and NaN distances compare as
        "not farther", silently turning the CRA of that perturbation into 0.
    """
    ad = sc.read_h5ad(fp)

    if 'cond_harm_pred' not in ad.obs:
        raise KeyError(f"'cond_harm_pred' not found in .obs of {fp} (method={method})")

    selected = ad[ad.obs['cond_harm_pred'] == pert]
    if selected.n_obs == 0:
        raise ValueError(
            f"No predicted cells with cond_harm_pred == {pert!r} in {fp} (method={method})"
        )

    # Mean over the selected rows, flattened to a 1-D gene vector.
    return _to_dense(selected.X).mean(axis=0).reshape(-1)



In [2]:
# --- Configuration ---
# Evaluation scenario: 'combinatorially_seen' or 'single_only'.
# It selects the prediction sub-folder of every method and is embedded in the output filename.
scenario = 'combinatorially_seen'

# Ground-truth AnnData file. The path is relative and climbs two directories ('../../'),
# so it resolves against the kernel's working directory.
# NOTE(review): this was previously described as "relative to repo root" — confirm where the
# notebook is expected to run from.
adata_path = Path('../../Datasets/preprocessed_datasets/norman.h5ad')

# One prediction folder per method, keyed by method name; each folder is scanned for *.h5ad files.
pred_dirs = {
    'scDisentangle': Path(f'../../Benchmarks/SCDISENTANGLE/Norman/predictions/{scenario}'),
    'CPA': Path(f'../../Benchmarks/CPA/Norman/predictions/{scenario}'),
    'GEARS': Path(f'../../Benchmarks/GEARS/Norman/predictions/{scenario}'),
}

# Which obs fields to use for GT
cond_key = 'condition'   # contains 'A+B' perturbation names

# NOTE: `Datasets/preprocessed_datasets/norman.h5ad` has no split column by default.
# If you have a split column in a different AnnData, set split_key accordingly.
# When use_ood_only_for_gt is True, GT cells are additionally restricted to rows where
# obs[split_key] == 'ood', and split_key must then name an existing obs column.
split_key = None         # e.g., 'split'
use_ood_only_for_gt = False

# Output: CSV that receives the per-perturbation CRA table (written in percent).
out_csv = Path(f'centroid_rank_accuracy__{scenario}__all_genes.csv')

# Methods to include (keys must exist in pred_dirs)
methods = list(pred_dirs.keys())
# Bare expression so the notebook displays the selected methods.
methods



['scDisentangle', 'CPA', 'GEARS']

In [3]:
# --- Discover perturbations available across all methods ---
if not methods:
    raise ValueError("No methods selected; `methods` must contain at least one key of pred_dirs")

available = {}
for m, d in pred_dirs.items():
    if not d.is_dir():
        raise FileNotFoundError(f"Missing prediction directory for {m}: {d.resolve()}")
    available[m] = set(list_available_perts(d, m))

# Only perturbations predicted by every selected method can be compared.
common_perts = sorted(set.intersection(*[available[m] for m in methods]))
print('Common perts across methods:', len(common_perts))
print('Example:', common_perts[:10])

# CRA ranks each prediction against the other centroids, so fewer than two shared
# perturbations cannot be scored. Stop here rather than after loading the dataset.
if len(common_perts) < 2:
    raise ValueError(
        f"Need at least 2 perturbations common to all methods, found {len(common_perts)}"
    )

# --- Load and normalize ground truth AnnData (match evaluation normalization) ---
adata = sc.read_h5ad(adata_path)

# normalize_total + log1p to match your compute_metrics_norman pipeline
# NOTE(review): the recorded output of this cell includes the warning
# "adata.X seems to be already log-transformed." — confirm that applying
# normalize_total + log1p to this file is intended and not a double transform.
sc.pp.normalize_total(adata, target_sum=adata.uns['single_perts_median'])
sc.pp.log1p(adata)

# Optional: if you have a split column and want OOD-only GT.
# The settings do not depend on the perturbation, so validate them and build the
# OOD mask once instead of on every loop iteration.
ood_mask = None
if use_ood_only_for_gt:
    if not split_key:
        raise ValueError("use_ood_only_for_gt=True requires setting split_key to a valid adata.obs column")
    if split_key not in adata.obs:
        raise KeyError(f"split_key '{split_key}' not found in adata.obs")
    ood_mask = (adata.obs[split_key] == 'ood')

# --- Compute GT centroids for each pert (all genes) ---
true_mean_rows = []
for pert in common_perts:
    mask = (adata.obs[cond_key] == pert)
    if ood_mask is not None:
        mask = mask & ood_mask

    ad_true = adata[mask]
    if ad_true.n_obs == 0:
        raise ValueError(f"No GT cells found for pert={pert} (use_ood_only_for_gt={use_ood_only_for_gt})")

    # Mean over the selected cells, flattened to a 1-D gene vector.
    mu = _to_dense(ad_true.X).mean(axis=0).reshape(-1)
    true_mean_rows.append(mu)

# Rows follow common_perts, columns follow the gene order of the ground-truth AnnData.
true_means = pd.DataFrame(true_mean_rows, index=common_perts, columns=adata.var_names)
true_means.shape



Common perts across methods: 128
Example: ['AHR+FEV', 'AHR+KLF1', 'BCL2L11+BAK1', 'BCL2L11+TGFBR2', 'BPGM+SAMD1', 'BPGM+ZBTB1', 'CBL+CNN1', 'CBL+PTPN12', 'CBL+PTPN9', 'CBL+TGFBR2']


adata.X seems to be already log-transformed.


(128, 5446)

In [None]:
# --- Load predicted centroids for each method ---
# For every method, read one prediction file per common perturbation and stack the
# predicted mean vectors into a (perturbation x gene) frame labelled like `true_means`.
# NOTE(review): columns are labelled with true_means' genes purely by position — this
# assumes every prediction file stores its genes in that same order; confirm.
pred_means_by_method = {}

for method in methods:
    pred_dir = pred_dirs[method]
    mus = []
    for pert in common_perts:
        # GEARS files are named with '_' where the perturbation name has '+'.
        if method.upper() == 'GEARS':
            fname = pert.replace('+', '_')
        else:
            fname = pert

        fp = pred_dir / f"{fname}.h5ad"
        if not fp.is_file():
            raise FileNotFoundError(f"Missing prediction file: {fp}")

        mus.append(load_pred_mean_from_file(fp, pert=pert, method=method))

    pred_means = pd.DataFrame(mus, index=common_perts, columns=true_means.columns)
    pred_means_by_method[method] = pred_means

# --- Compute CRA per pert and summarize ---
# One CRA series per method, in the order the methods were loaded.
scores = {
    name: centroid_rank_accuracy(true_means, frame)
    for name, frame in pred_means_by_method.items()
}

scores_df = pd.DataFrame(scores)
scores_df_percent = scores_df.mul(100.0)

display(scores_df.head())
print('Mean CRA (percent):')
mean_cra_percent = scores_df_percent.mean(axis=0)
print(mean_cra_percent.sort_values(ascending=False))

# Save the per-perturbation table (in percent).
scores_df_percent.to_csv(out_csv)
print('Saved:', out_csv.resolve())

