
# Single-cell Integration Benchmark
**Methods:** DESC / fastMNN / Harmony / iMAP / LIGER / Scanorama / Seurat / Uncorrect  
**Metrics:** ARI, NMI, ASW_celltype, ASW_batch, iLISI, KL divergence

- **ARI / NMI**：sklearn（以 KMeans 聚类与真值 cell type 对齐）  
- **ASW_celltype / ASW_batch**：silhouette；其中 ASW_batch 使用 `(1 - silhouette)/2 ∈ [0,1]`（越大越混合）  
- **iLISI**：优先使用 R 包 `lisi::compute_lisi`（需 rpy2 + R）；否则回退为 Python 近似（邻域批次比例熵的指数）  
- **KL**：本地邻域批次分布相对**全局**批次分布的 D_KL 均值（越小越好）

> 你只需提供各方法的**嵌入矩阵**（与细胞顺序一致），Notebook 会统一评测与作图。未提供的会跳过。`Uncorrect` 若留空，会用 Scanpy 预处理得到 50 PCs 作为基线。


In [1]:

import os, json, numpy as np, pandas as pd
import scanpy as sc
from anndata import AnnData
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

# iLISI (R) 
R_AVAILABLE = False
try:
    import rpy2.robjects as ro
    from rpy2.robjects import pandas2ri
    from rpy2.robjects.packages import importr
    pandas2ri.activate()
    lisi_pkg = importr("lisi")
    R_AVAILABLE = True
except Exception:
    print("[Info] rpy2 或 R 包 lisi 不可用：iLISI 将使用 Python 近似。")

import matplotlib
matplotlib.rcParams['figure.dpi'] = 120
sc.settings.verbosity = 2


[Info] rpy2 或 R 包 lisi 不可用：iLISI 将使用 Python 近似。


## Parameters（请在此修改）

In [2]:

# === 数据输入（两种方式选其一）===
# 方式1：已合并的 .h5ad（需包含 obs['cell_type'], obs['batch']）
# H5AD = "bct_raw.h5ad" 
H5AD = "mural_raw.h5ad" 
# H5AD = "macaque_raw.h5ad" 
MULTI_H5AD = None
BATCH_KEY  = None
# 方式2：多个 h5ad + 批次列名（若设置 BATCH_KEY，将忽略 H5AD）
# MULTI_H5AD = ["neurips2021_s1d3.h5ad",
#               "neurips2021_s2d1.h5ad",
#               "neurips2021_s3d7.h5ad"]  # 例如 ["s1d3.h5ad", "s2d1.h5ad", "s3d7.h5ad"]
# BATCH_KEY  = "batch"  # 例如 "batch"
# "inputs": [
#     "mural_raw.h5ad",
# ],
# "batch_key": "batch",           
# "cell_type": "cell_type1",  

# obs 列名
# CELLTYPE_COL = "celltype"
# BATCH_COL    = "BATCH"
# CELLTYPE_COL = "cell_type"
# BATCH_COL    = "batch"
CELLTYPE_COL = "cell_type1"
BATCH_COL    = "batch"
# 各方法嵌入文件（行顺序必须与 adata 一致）。支持 .npy/.npz/.csv/.tsv/.txt
EMBEDDING_FILES = {
    "Uncorrect": "embeddings/uncorrect.npy",      # 若 None，将自动计算 50 PCs
    "DESC":      None,
    "fastMNN":   None,
    "Harmony":   "embeddings/harmony.npy",
    "iMAP":      None,
    "LIGER":     None,
    "Scanorama": "embeddings/scanorama.npy",
    "Seurat":    "embeddings/seurat.npy",
    "MyModel":   "outputs/embeddings/MyModel.npy"
    # "scDML":     "embeddings/scdml.npy"
}

SCDML_CELLTYPE = "embeddings/labels_celltype.npy"
SCDML_REASSIGN = "embeddings/labels_reassign.npy"

# 预处理/评测参数
N_TOP_GENES = 2000
N_PCS_BASE  = 50    # Uncorrect 的 PCA 维度
N_NEIGHBORS = 15
K_ILISI     = 90
K_KL        = 50
UMAP_SEED   = 0
OUTDIR      = "benchmark_out"
os.makedirs(OUTDIR, exist_ok=True)

print("OUTPUT DIR:", OUTDIR)


OUTPUT DIR: benchmark_out


## 数据读取与 Uncorrect 基线（50 PCs）

In [3]:

def _load_embedding_file(path: str):
    path = str(path)
    ext = os.path.splitext(path)[1].lower()
    if ext == ".npy":
        return np.load(path)
    if ext == ".npz":
        data = np.load(path)
        if "X" in data: return data["X"]
        for k in data.files: return data[k]
        raise ValueError("npz 文件中未找到数组")
    if ext in [".csv", ".tsv", ".txt"]:
        sep = "," if ext == ".csv" else None
        df = pd.read_csv(path, sep=sep, engine="python", header=None)
        return df.values
    raise ValueError(f"不支持的嵌入文件格式: {ext}")

# 载入 AnnData
if MULTI_H5AD and BATCH_KEY:
    adata_list = [sc.read_h5ad(p) for p in MULTI_H5AD]
    for ad in adata_list:
        if BATCH_KEY not in ad.obs:
            raise ValueError(f"'{BATCH_KEY}' 不在某个 h5ad 的 obs 中")
    # 用 scanpy 合并，并把批次写在 BATCH_COL
    adata = sc.concat(adata_list, join="outer", label=BATCH_COL,
                      keys=[os.path.splitext(os.path.basename(p))[0] for p in MULTI_H5AD])
else:
    if not H5AD:
        raise ValueError("请设置 H5AD，或使用 MULTI_H5AD + BATCH_KEY")
    adata = sc.read_h5ad(H5AD)

# 提取标签
cell_types = adata.obs[CELLTYPE_COL].astype(str).values if CELLTYPE_COL in adata.obs else None
batch_labels = adata.obs[BATCH_COL].astype(str).values if BATCH_COL in adata.obs else None
print("N cells:", adata.n_obs, "| N genes:", adata.n_vars)

# 计算 Uncorrect 基线：50 PCs
adata_unc = adata.copy()
sc.pp.filter_cells(adata_unc, min_genes=200)
sc.pp.filter_genes(adata_unc, min_cells=3)
sc.pp.normalize_total(adata_unc, target_sum=1e4)
sc.pp.log1p(adata_unc)
sc.pp.highly_variable_genes(adata_unc, n_top_genes=N_TOP_GENES, subset=True)
sc.pp.scale(adata_unc, max_value=10)
sc.tl.pca(adata_unc, n_comps=N_PCS_BASE, svd_solver="arpack")
X_uncorrect = adata_unc.obsm["X_pca"]
print("Uncorrect embedding:", X_uncorrect.shape)

# === 关键：将评测基准切换为“过滤后”的细胞，并与 adata_unc 顺序完全一致 ===
# 保存原始细胞顺序（9288）
obs_before = pd.Index(adata.obs_names)

# 过滤后细胞名（顺序与 adata_unc 一致）
cells_new = adata_unc.obs_names.to_numpy()

# 计算过滤后细胞在原始顺序里的位置索引
idx = obs_before.get_indexer(cells_new)   # shape=(8573,)
if (idx < 0).any():
    missing = cells_new[idx < 0]
    raise ValueError(f"在原始 adata 中找不到这些细胞：{missing[:5]} ... 共 {len(missing)} 个")

# 2) 用这个集合对子集化原始 adata，并按 adata_unc 的顺序重排
adata = adata[cells_new].copy()   # 先按名称取交集
adata = adata[adata_unc.obs_names].copy()  # 再确保顺序一致（保险起见）

# 3) 重新提取标签，并把 n_cells 更新为过滤后的数量
cell_types = adata.obs[CELLTYPE_COL].astype(str).values if CELLTYPE_COL in adata.obs else None
batch_labels = adata.obs[BATCH_COL].astype(str).values if BATCH_COL in adata.obs else None
n_cells = adata.n_obs
print("[Info] 评测将基于过滤后细胞数:", n_cells)



N cells: 30302 | N genes: 36162
filtered out 11558 genes that are detected in less than 3 cells
normalizing counts per cell
    finished (0:00:00)
extracting highly variable genes
    finished (0:00:00)


  return dispatch(args[0].__class__)(*args, **kw)


computing PCA
    with n_comps=50
    finished (0:00:34)
Uncorrect embedding: (30302, 50)
[Info] 评测将基于过滤后细胞数: 30302


## 指标函数

In [4]:

def asw_celltype(X, y):
    if y is None or len(np.unique(y)) < 2: return np.nan
    return float(silhouette_score(X, y, metric="euclidean"))

def asw_batch(X, b):
    if b is None or len(np.unique(b)) < 2: return np.nan
    sil = silhouette_score(X, b, metric="euclidean")
    return float((1.0 - sil)/2.0)

def ilisi_python(X, b, k=90):
    if b is None: return np.nan, np.full(X.shape[0], np.nan)
    b = np.asarray(b); n = X.shape[0]; B = len(np.unique(b))
    k_eff = min(k, n)
    nn = NearestNeighbors(n_neighbors=k_eff).fit(X)
    nbrs = nn.kneighbors(return_distance=False)
    _, b_int = np.unique(b, return_inverse=True)
    local_eff = np.zeros(n)
    for i in range(n):
        neigh = nbrs[i]
        counts = np.bincount(b_int[neigh], minlength=B).astype(float)
        p = counts / counts.sum(); p = np.clip(p, 1e-12, 1.0)
        H = -(p * np.log(p)).sum()
        local_eff[i] = np.exp(H)
    return float(local_eff.mean()), local_eff

def ilisi_r(X, b, perplexity=30):
    emb_df = pd.DataFrame(X, columns=[f"dim{i+1}" for i in range(X.shape[1])])
    meta_df = pd.DataFrame({"batch": b})
    emb_r  = ro.conversion.py2rpy(emb_df)
    meta_r = ro.conversion.py2rpy(meta_df)
    res = lisi_pkg.compute_lisi(emb_r, meta_r, ro.StrVector(["batch"]), perplexity=perplexity)
    lis = np.array(res.rx2("batch")).reshape(-1)
    return float(lis.mean()), lis

def local_kl(X, b, k=50):
    if b is None: return np.nan, np.full(X.shape[0], np.nan)
    b = np.asarray(b)
    batches, b_int = np.unique(b, return_inverse=True)
    B = len(batches)
    Pg = np.bincount(b_int, minlength=B).astype(float); Pg /= Pg.sum()
    Pg = np.clip(Pg, 1e-12, 1.0)
    k_eff = min(k, X.shape[0])
    nn = NearestNeighbors(n_neighbors=k_eff).fit(X)
    nbrs = nn.kneighbors(return_distance=False)
    dkl = np.zeros(X.shape[0])
    for i in range(X.shape[0]):
        neigh = nbrs[i]
        Pl = np.bincount(b_int[neigh], minlength=B).astype(float); Pl /= Pl.sum()
        Pl = np.clip(Pl, 1e-12, 1.0)
        dkl[i] = np.sum(Pl * np.log(Pl / Pg))
    return float(dkl.mean()), dkl

def ari_nmi(X, y):
    if y is None or len(np.unique(y)) < 2: return np.nan, np.nan
    n_clusters = len(np.unique(y))
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(X)
    pred = km.labels_
    return float(adjusted_rand_score(y, pred)), float(normalized_mutual_info_score(y, pred))


## 载入各方法嵌入并计算指标

In [5]:
from collections import OrderedDict
import numpy as np

def _maybe_transpose(X, n_cells, name=""):
    """如果发现 X 的第二维等于细胞数，则自动转置"""
    # 转成 numpy，兼容 DataFrame
    Xv = X.values if hasattr(X, "values") else X
    if hasattr(Xv, "shape") and Xv.ndim == 2:
        r, c = Xv.shape
        if r != n_cells and c == n_cells:
            print(f"[Fix] {name}: 检测到形状为({r},{c})，自动转置为({c},{r})")
            Xv = Xv.T
    return Xv

# [+] 新增：鲁棒读取 .npy / .pyr(pickle) / .csv / .tsv / .txt
def _load_label_any(path, n_cells=None):
    arr = None
    try:
        # 先试 npy
        arr = np.load(path, allow_pickle=True)
    except Exception:
        # 再试 pickle（.pyr/.pkl）
        try:
            s = pd.read_pickle(path)
            arr = s.values if hasattr(s, "values") else np.asarray(s)
        except Exception:
            # 最后试文本
            try:
                sep = "\t" if path.endswith(".tsv") else ","
                s = pd.read_csv(path, header=None, sep=sep)
                arr = s.iloc[:, 0].to_numpy()
            except Exception as e:
                print(f"[Warn] 无法读取标签文件 {path}: {e}")
                return None
    arr = np.asarray(arr).reshape(-1)
    arr = arr.astype(str)
    if n_cells is not None and len(arr) != n_cells:
        print(f"[Warn] 标签 {os.path.basename(path)} 长度({len(arr)}) ≠ 细胞数({n_cells})，请确认顺序是否一致")
    return arr

In [6]:
embeddings = OrderedDict()
n_cells = adata.n_obs
labels_len = len(batch_labels) if batch_labels is not None else n_cells
if labels_len != n_cells:
    print(f"[Warn] 标签长度({labels_len}) ≠ 细胞数({n_cells})，请先对齐标签！")

# --- Uncorrect ---
Xu = None
if EMBEDDING_FILES.get("Uncorrect"):
    Xu = _load_embedding_file(EMBEDDING_FILES["Uncorrect"])
    Xu = _maybe_transpose(Xu, n_cells, name="Uncorrect")
else:
    Xu = X_uncorrect
    Xu = _maybe_transpose(Xu, n_cells, name="Uncorrect(fallback)")

if Xu is None or Xu.shape[0] != n_cells:
    print(f"[Warn] Uncorrect 行数({None if Xu is None else Xu.shape[0]}) ≠ 细胞数({n_cells})，已跳过该方法")
else:
    embeddings["Uncorrect"] = Xu

# --- 其它方法 ---
for name in ["MyModel","DESC","fastMNN","Harmony","iMAP","LIGER","Scanorama","Seurat"]:
    p = EMBEDDING_FILES.get(name)
    if not p:
        print(f"[Skip] {name}: 未提供嵌入文件")
        continue
    try:
        Xm = _load_embedding_file(p)
        Xm = _maybe_transpose(Xm, n_cells, name=name)

        # ★ 关键对齐：若矩阵行数等于“原始 9288”，先按 idx 子集到 8573 并保持顺序一致
        if Xm.shape[0] == len(obs_before):
            Xm = Xm[idx]

        if Xm.shape[0] != n_cells:
            print(f"[Warn] {name} 行数({Xm.shape[0]}) ≠ 细胞数({n_cells})，跳过")
        else:
            embeddings[name] = Xm
    except Exception as e:
        print(f"[Warn] 无法加载 {name}: {e}")

[Skip] DESC: 未提供嵌入文件
[Skip] fastMNN: 未提供嵌入文件
[Skip] iMAP: 未提供嵌入文件
[Skip] LIGER: 未提供嵌入文件


In [7]:
# --- 计算指标（再次保护）---
results = []
for method, X in embeddings.items():
    if X.shape[0] != n_cells:
        print(f"[Skip] {method}: X 行数({X.shape[0]}) ≠ 细胞数({n_cells})")
        continue

    print(f"==> {method}: X shape = {X.shape}")
    row = {"method": method}

    ARI, NMI = ari_nmi(X, cell_types) if cell_types is not None else (np.nan, np.nan)
    row["ARI"], row["NMI"] = ARI, NMI

    # ASW
    row["ASW_celltype"] = asw_celltype(X, cell_types) if cell_types is not None else np.nan
    row["ASW_batch"]    = asw_batch(X, batch_labels)

    # iLISI
    if R_AVAILABLE and batch_labels is not None:
        try:
            ilisi_mean, ilisi_dist = ilisi_r(X, batch_labels, perplexity=30)
        except Exception as e:
            print(f"[Warn] iLISI(R) 失败，回退 Python 近似：{e}")
            ilisi_mean, ilisi_dist = ilisi_python(X, batch_labels, k=K_ILISI)
    else:
        ilisi_mean, ilisi_dist = ilisi_python(X, batch_labels, k=K_ILISI)
    row["iLISI"] = ilisi_mean
    row["_ilisi_dist"] = ilisi_dist

    # KL
    kl_mean, kl_dist = local_kl(X, batch_labels, k=K_KL)
    row["KL"] = kl_mean
    row["_kl_dist"] = kl_dist

    results.append(row)

df = pd.DataFrame(results).set_index("method")
display(df[["ARI","NMI","ASW_celltype","ASW_batch","iLISI","KL"]])
df.drop(columns=["_ilisi_dist","_kl_dist"]).to_csv(os.path.join(OUTDIR, "metrics_summary.csv"))
print("保存:", os.path.join(OUTDIR, "metrics_summary.csv"))

==> Uncorrect: X shape = (30302, 50)


==> MyModel: X shape = (30302, 32)
==> Harmony: X shape = (30302, 50)
==> Scanorama: X shape = (30302, 50)
==> Seurat: X shape = (30302, 50)


Unnamed: 0_level_0,ARI,NMI,ASW_celltype,ASW_batch,iLISI,KL
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Uncorrect,0.864393,0.87035,0.180905,0.495052,1.462206,1.059711
MyModel,0.08591,0.100163,-0.043339,0.519164,3.337639,0.197512
Harmony,0.939202,0.937747,0.213231,0.518558,3.291692,0.305639
Scanorama,0.424838,0.572141,0.112822,0.474185,1.341958,1.114245
Seurat,0.117358,0.128822,-0.026501,0.496428,1.972205,0.731611


保存: benchmark_out/metrics_summary.csv


## UMAP（按批次与细胞类型）

In [8]:

def scatter_umap(X, labels, title, out_png):
    X = np.asarray(X)
    assert X.ndim == 2
    if labels is not None:
        labels = np.asarray(labels)
        assert len(labels) == X.shape[0], f"labels({len(labels)}) != n_cells({X.shape[0]})"

    ad = AnnData(X)  # ad.X 不再用于 PCA
    ad.obs["label"] = (labels.astype(str) if labels is not None
                       else pd.Series(["NA"]*X.shape[0], dtype=str).values)

    # 直接复用已有嵌入作为“PCA”的表示
    ad.obsm["X_pca"] = X

    # 用 X_pca 构图，不再调用 sc.tl.pca
    sc.pp.neighbors(ad, n_neighbors=N_NEIGHBORS, use_rep="X_pca")
    sc.tl.umap(ad, random_state=UMAP_SEED)

    old = sc.settings.figdir
    sc.settings.figdir = OUTDIR
    try:
        sc.pl.umap(ad, color="label", title=title, show=False, save=f"_{out_png}")
        auto = os.path.join(OUTDIR, f"umap_{out_png}")
        if os.path.exists(auto): print("保存:", auto)
    finally:
        sc.settings.figdir = old


for method, X in embeddings.items():
    scatter_umap(X, batch_labels, f"{method} — batch",    f"{method}_batch.png")
    if cell_types is not None:
        scatter_umap(X, cell_types, f"{method} — celltype", f"{method}_celltype.png")


computing neighbors


  from .autonotebook import tqdm as notebook_tqdm


    finished (0:00:20)
computing UMAP
    finished (0:00:13)
保存: benchmark_out/umap_Uncorrect_batch.png
computing neighbors
    finished (0:00:02)
computing UMAP
    finished (0:00:13)
保存: benchmark_out/umap_Uncorrect_celltype.png
computing neighbors
    finished (0:00:02)
computing UMAP
    finished (0:00:12)
保存: benchmark_out/umap_MyModel_batch.png
computing neighbors
    finished (0:00:02)
computing UMAP
    finished (0:00:12)
保存: benchmark_out/umap_MyModel_celltype.png
computing neighbors
    finished (0:00:02)
computing UMAP
    finished (0:00:12)
保存: benchmark_out/umap_Harmony_batch.png
computing neighbors
    finished (0:00:02)
computing UMAP
    finished (0:00:12)
保存: benchmark_out/umap_Harmony_celltype.png
computing neighbors
    finished (0:00:02)
computing UMAP
    finished (0:00:12)
保存: benchmark_out/umap_Scanorama_batch.png
computing neighbors
    finished (0:00:02)
computing UMAP
    finished (0:00:12)
保存: benchmark_out/umap_Scanorama_celltype.png
computing neighbors
    