In [13]:
"""
计算 DES, PDS, MAE 以及最终综合得分 S 的函数集合。

依赖：
    pip install anndata scipy statsmodels numpy pandas

用法（示例）：
    scores = compute_vcc_scores(true_adata, pred_adata,
                                groupby='target_gene', control_label='non-targeting',
                                baseline_scores={'DES':0.2,'PDS':0.3,'MAE':0.5})
    print(scores)
"""
import json
import warnings
from typing import Optional, Dict

import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests
import anndata as ad

def _get_matrix_row_mean(adata, mask):
    """
    返回 mask（布尔数组）对应细胞的每个基因的平均表达（1D numpy float array）。
    兼容稀疏矩阵。
    """
    X = adata.X
    if sp.issparse(X):
        sub = X[mask]
        # mean(axis=0) 返回 1xG sparse/dense matrix
        mean_vec = np.array(sub.mean(axis=0)).ravel()
    else:
        sub = X[mask, :]
        mean_vec = np.asarray(sub.mean(axis=0)).ravel()
    return mean_vec.astype(float)

def _extract_expr_vectors_for_gene(adata, mask):
    """返回 mask 对应细胞在所有基因上的表达子矩阵（稀疏或密集）"""
    X = adata.X
    if sp.issparse(X):
        sub = X[mask].toarray()
    else:
        sub = X[mask, :]
    return sub  # shape (n_cells, n_genes)

def _wilcoxon_de_genes(adata, group_mask, ntc_mask, gene_names, fdr=0.05, sample_n:int=1000):
    """
    对每个基因做 Wilcoxon rank-sum（Mann-Whitney U）检验：group vs ntc。
    为控制时间，当每组细胞非常多时，会进行随机抽样（sample_n）。
    返回：DataFrame 包含 index=gene_names, columns=['pval','logfc','mean_group','mean_ntc']
    """
    n_group = int(group_mask.sum())
    n_ntc = int(ntc_mask.sum())
    if n_group < 2 or n_ntc < 2:
        # 无法检验
        return pd.DataFrame(index=gene_names, data={
            'pval': np.ones(len(gene_names)),
            'logfc': np.zeros(len(gene_names)),
            'mean_group': np.zeros(len(gene_names)),
            'mean_ntc': np.zeros(len(gene_names))
        })

    # 抽样索引
    rng = np.random.default_rng(0)
    if sample_n is not None:
        if n_group > sample_n:
            g_idx = rng.choice(np.nonzero(group_mask)[0], sample_n, replace=False)
            group_mask_sample = np.zeros_like(group_mask, dtype=bool)
            group_mask_sample[g_idx] = True
        else:
            group_mask_sample = group_mask
        if n_ntc > sample_n:
            n_idx = rng.choice(np.nonzero(ntc_mask)[0], sample_n, replace=False)
            ntc_mask_sample = np.zeros_like(ntc_mask, dtype=bool)
            ntc_mask_sample[n_idx] = True
        else:
            ntc_mask_sample = ntc_mask
    else:
        group_mask_sample = group_mask
        ntc_mask_sample = ntc_mask

    # 取表达矩阵
    group_mat = _extract_expr_vectors_for_gene(adata, group_mask_sample)  # (n_g, G)
    ntc_mat = _extract_expr_vectors_for_gene(adata, ntc_mask_sample)      # (n_n, G)

    G = group_mat.shape[1]
    pvals = np.ones(G)
    mean_group = group_mat.mean(axis=0)
    mean_ntc = ntc_mat.mean(axis=0)
    # logFC 使用 log2(mean_group + 1) - log2(mean_ntc + 1)
    logfc = np.log2(mean_group + 1) - np.log2(mean_ntc + 1)

    # 对每个基因做 Mann-Whitney U（两侧）
    for gi in range(G):
        try:
            u = mannwhitneyu(group_mat[:, gi], ntc_mat[:, gi], alternative='two-sided')
            pvals[gi] = u.pvalue if u.pvalue is not None else 1.0
        except Exception:
            pvals[gi] = 1.0

    # FDR 校正
    reject, pvals_adj, _, _ = multipletests(pvals, alpha=fdr, method='fdr_bh')
    df = pd.DataFrame({
        'pval': pvals,
        'pval_adj': pvals_adj,
        'significant': reject,
        'logfc': logfc,
        'mean_group': mean_group,
        'mean_ntc': mean_ntc
    }, index=gene_names)
    return df

def _des_overlap(true_sig, pred_table):
    """Return DES_k: the fraction of true DE genes recovered by the prediction.

    Parameters
    ----------
    true_sig : set
        Genes called significant in the ground-truth group-vs-control test.
    pred_table : pandas.DataFrame or None
        Result of ``_wilcoxon_de_genes`` on the predicted data (uses the
        ``'significant'`` and ``'logfc'`` columns), or None when the predicted
        group/control was too small to test. None counts as "no predicted DE genes".

    Returns
    -------
    float
        NaN when ``true_sig`` is empty (the perturbation is left out of the mean);
        otherwise |selected predicted DE genes & true_sig| / |true_sig|.
    """
    if not true_sig:
        return np.nan
    if pred_table is None:
        return 0.0
    sig_mask = pred_table['significant'].to_numpy(dtype=bool)
    pred_sig = set(pred_table.index[sig_mask])
    if len(pred_sig) > len(true_sig):
        # Too many predicted DE genes: keep only the |true_sig| *predicted DE genes*
        # with the largest |logFC|. Ranking over all genes would let non-significant
        # genes displace significant ones. Use a stable sort so ties are deterministic.
        ranked = (pred_table.loc[sig_mask, 'logfc']
                  .abs()
                  .sort_values(ascending=False, kind='mergesort'))
        pred_sig = set(ranked.index[:len(true_sig)])
    return len(pred_sig & true_sig) / len(true_sig)


def _pds_per_perturbation(true_delta, pred_delta):
    """Return the per-perturbation PDS score (higher is better, 1.0 is best).

    ``true_delta`` and ``pred_delta`` are 2-D arrays with one row per
    perturbation (same row order), holding the pseudobulk shift vs control.
    For row i, the L1 distance from ``pred_delta[i]`` to every true row is
    computed. With rank = 1 + (number of true rows strictly closer than row i),
    PDS_i = 1 - (rank - 1) / N.
    """
    n_pert = true_delta.shape[0]
    scores = np.empty(n_pert, dtype=float)
    for i in range(n_pert):
        dists = np.abs(true_delta - pred_delta[i]).sum(axis=1)
        # Compare against dists[i] itself so the matching pair is ranked consistently.
        n_closer = int((dists < dists[i]).sum())
        scores[i] = 1.0 - n_closer / float(n_pert)
    return scores


def compute_vcc_scores(true_adata,
                       pred_adata,
                       baseline_scores: Optional[Dict[str,float]] = None,
                       baseline_adata = None,
                       groupby: str = 'target_gene',
                       control_label: str = 'non-targeting',
                       fdr: float = 0.05,
                       sample_n: Optional[int] = 1000,
                       min_cells_per_group: int = 3):
    """
    Compute DES, PDS, MAE and the combined score S.

    Parameters
    ----------
    true_adata, pred_adata : AnnData
        Must share genes (var_names) or have a non-empty overlap. Genes are
        aligned to the intersection, in ``true_adata`` order.
    baseline_scores : dict, optional
        ``{'DES': float, 'PDS': float, 'MAE': float}`` used for rescaling.
    baseline_adata : AnnData, optional
        If given (and ``baseline_scores`` is None), the three raw metrics of
        this baseline against ``true_adata`` are computed and used for
        rescaling. ``baseline_scores`` takes precedence when both are given.
    groupby : str
        ``.obs`` column holding the perturbation label (e.g. 'target_gene').
    control_label : str
        Label of the control group (e.g. 'non-targeting').
    fdr : float
        Benjamini-Hochberg FDR threshold for the DE tests.
    sample_n : int or None
        Max cells per side used in each Wilcoxon test; None disables subsampling.
    min_cells_per_group : int
        Perturbation/control groups with fewer cells are not DE-tested.

    Returns
    -------
    dict with raw 'DES', 'PDS', 'MAE'; 'DES_scaled', 'PDS_scaled', 'MAE_scaled';
    'overall_score_percent' (mean of the three scaled scores x 100); the
    per-perturbation Series 'DES_per_perturbation', 'PDS_per_perturbation',
    'MAE_per_perturbation'; and the sorted 'perturbations' list.

    Scaling: DES/PDS -> (x - base) / (1 - base) (0 if base ~ 1);
    MAE -> (base - x) / base (0 if base == 0). Scaled scores are NOT clipped,
    so a prediction far worse than baseline can drive the overall score far below 0.

    Raises
    ------
    ValueError
        No common genes; ``groupby`` missing; no perturbations; control label
        absent from either dataset; or neither ``baseline_scores`` nor
        ``baseline_adata`` given.
    """
    # --- Align genes (var_names) ---
    genes_true = list(true_adata.var_names)
    genes_pred = list(pred_adata.var_names)
    if genes_true != genes_pred:
        # Keep true_adata's order. The set makes each membership test O(1);
        # the former list lookup was O(G^2).
        pred_gene_set = set(genes_pred)
        common = [g for g in genes_true if g in pred_gene_set]
        if not common:
            raise ValueError("No common genes between true_adata and pred_adata.")
        true_adata = true_adata[:, common]
        pred_adata = pred_adata[:, common]

    gene_names = list(true_adata.var_names)
    G = len(gene_names)

    # --- Perturbation column and groups ---
    if groupby not in true_adata.obs.columns:
        raise ValueError(f"groupby '{groupby}' not found in true_adata.obs")
    if groupby not in pred_adata.obs.columns:
        raise ValueError(f"groupby '{groupby}' not found in pred_adata.obs")

    true_groups = true_adata.obs[groupby].astype(str)
    pred_groups = pred_adata.obs[groupby].astype(str)

    ntc_mask_true = (true_groups == control_label).values
    ntc_mask_pred = (pred_groups == control_label).values
    ntc_ok_true = ntc_mask_true.sum() >= min_cells_per_group
    ntc_ok_pred = ntc_mask_pred.sum() >= min_cells_per_group
    if not (ntc_ok_true and ntc_ok_pred):
        warnings.warn("Control (ntc) cell count is small.")

    perturbations = sorted(g for g in true_groups.unique() if g != control_label)
    N = len(perturbations)
    if N == 0:
        raise ValueError("No perturbations found (after excluding control_label).")
    # Fail fast: without control cells no delta/PDS can be computed, so don't run
    # the expensive per-perturbation DE tests first.
    if not ntc_mask_true.any() or not ntc_mask_pred.any():
        raise ValueError("Control label not present in pseudobulk results for true or pred.")

    # --- 1) DES per perturbation (NaN = not scored, excluded from the mean) ---
    DES_list = []
    for p in perturbations:
        g_mask_true = (true_groups == p).values
        g_mask_pred = (pred_groups == p).values

        if not ntc_ok_true or g_mask_true.sum() < min_cells_per_group:
            DES_list.append(np.nan)
            continue

        df_true = _wilcoxon_de_genes(true_adata, g_mask_true, ntc_mask_true, gene_names,
                                     fdr=fdr, sample_n=sample_n)
        true_sig = set(df_true.index[df_true['significant'].to_numpy(dtype=bool)])
        if not true_sig:
            # No true DE genes: the perturbation is skipped, so the predicted test is unnecessary.
            DES_list.append(np.nan)
            continue

        if ntc_ok_pred and g_mask_pred.sum() >= min_cells_per_group:
            df_pred = _wilcoxon_de_genes(pred_adata, g_mask_pred, ntc_mask_pred, gene_names,
                                         fdr=fdr, sample_n=sample_n)
        else:
            df_pred = None  # too few predicted cells: no predicted DE genes
        DES_list.append(_des_overlap(true_sig, df_pred))

    DES_array = np.array(DES_list, dtype=float)
    DES_mean = np.nanmean(DES_array)

    # --- 2) PDS on pseudobulk deltas vs control ---
    true_pseudobulk = {grp: _get_matrix_row_mean(true_adata, (true_groups == grp).values)
                       for grp in true_groups.unique()}
    pred_pseudobulk = {grp: _get_matrix_row_mean(pred_adata, (pred_groups == grp).values)
                       for grp in pred_groups.unique()}

    # Perturbations come from true_groups, so only the prediction can lack one.
    for p in perturbations:
        if p not in pred_pseudobulk:
            warnings.warn(f"Perturbation {p} missing in predicted data; using zeros vector.")
            pred_pseudobulk[p] = np.zeros(G)

    true_ctrl = true_pseudobulk[control_label]
    pred_ctrl = pred_pseudobulk[control_label]
    true_delta = np.vstack([true_pseudobulk[p] - true_ctrl for p in perturbations])
    pred_delta = np.vstack([pred_pseudobulk[p] - pred_ctrl for p in perturbations])

    PDS_array = _pds_per_perturbation(true_delta, pred_delta)
    PDS_mean = PDS_array.mean()

    # --- 3) MAE between predicted and true pseudobulk profiles ---
    MAE_array = np.array(
        [np.mean(np.abs(pred_pseudobulk[p] - true_pseudobulk[p])) for p in perturbations],
        dtype=float,
    )
    MAE_mean = MAE_array.mean()

    # --- Baseline: explicit scores win; otherwise score baseline_adata ---
    if baseline_scores is None and baseline_adata is not None:
        # Dummy baseline_scores stop the recursion from recursing again.
        baseline_res = compute_vcc_scores(true_adata, baseline_adata,
                                          baseline_scores={'DES':0.0,'PDS':0.0,'MAE':1.0},
                                          groupby=groupby, control_label=control_label,
                                          fdr=fdr, sample_n=sample_n, min_cells_per_group=min_cells_per_group)
        baseline_scores = {'DES': baseline_res['DES'],
                           'PDS': baseline_res['PDS'],
                           'MAE': baseline_res['MAE']}
    if baseline_scores is None:
        raise ValueError("You must provide either baseline_scores or baseline_adata to compute scaled scores.")

    DES_baseline = float(baseline_scores.get('DES', 0.0))
    PDS_baseline = float(baseline_scores.get('PDS', 0.0))
    MAE_baseline = float(baseline_scores.get('MAE', 1.0))

    # --- Scaled scores ---
    def safe_scale_score(pred, base):
        """(pred - base) / (1 - base); 0 when the baseline is already perfect."""
        if base >= 1.0 - 1e-12:
            return 0.0
        return (pred - base) / (1.0 - base)

    DES_scaled = safe_scale_score(DES_mean, DES_baseline)
    PDS_scaled = safe_scale_score(PDS_mean, PDS_baseline)
    # Relative MAE reduction vs baseline (lower MAE is better).
    MAE_scaled = 0.0 if MAE_baseline == 0 else (MAE_baseline - MAE_mean) / MAE_baseline

    overall_score = (DES_scaled + PDS_scaled + MAE_scaled) / 3.0 * 100.0

    return {
        'DES': float(DES_mean),
        'PDS': float(PDS_mean),
        'MAE': float(MAE_mean),
        'DES_scaled': float(DES_scaled),
        'PDS_scaled': float(PDS_scaled),
        'MAE_scaled': float(MAE_scaled),
        'overall_score_percent': float(overall_score),
        # Per-perturbation detail for inspection.
        'DES_per_perturbation': pd.Series(DES_array, index=perturbations),
        'PDS_per_perturbation': pd.Series(PDS_array, index=perturbations),
        'MAE_per_perturbation': pd.Series(MAE_array, index=perturbations),
        'perturbations': perturbations
    }

In [12]:
# Sanity check: score a dataset against itself; every metric should come out perfect.
true_file = "../small_set.h5ad"
pred_file = "../small_set.h5ad"

print("Loading AnnData files...")
true_adata, pred_adata = ad.read_h5ad(true_file), ad.read_h5ad(pred_file)

# Reference scores of the baseline model, used to rescale each metric.
baseline_scores = {'DES': 0.106, 'PDS': 0.514, 'MAE': 0.027}

print("Computing scores (this may take time depending on dataset size)...")
scores = compute_vcc_scores(
    true_adata,
    pred_adata,
    baseline_scores=baseline_scores,
    groupby='target_gene',
    control_label='non-targeting',
    fdr=0.05,
    sample_n=1000,
    min_cells_per_group=3,
)

# Human-readable summary, one metric per line.
print("\n===== Summary =====")
print("\n".join([
    f"DES (mean): {scores['DES']:.6f}",
    f"PDS (mean): {scores['PDS']:.6f}",
    f"MAE (mean): {scores['MAE']:.6f}",
    f"DES_scaled: {scores['DES_scaled']:.6f}",
    f"PDS_scaled: {scores['PDS_scaled']:.6f}",
    f"MAE_scaled: {scores['MAE_scaled']:.6f}",
    f"Overall score (percent): {scores['overall_score_percent']:.4f}%",
]))

# Only the scalar metrics go to JSON; per-perturbation values go to the CSV below.
out_json = {key: scores[key] for key in (
    'DES', 'PDS', 'MAE', 'DES_scaled', 'PDS_scaled', 'MAE_scaled', 'overall_score_percent')}
with open('evaluation_results.json', 'w') as f:
    json.dump(out_json, f, indent=2)

# One row per perturbation, one column per metric.
df_des, df_pds, df_mae = (
    scores[f'{metric}_per_perturbation'].rename(metric).to_frame()
    for metric in ('DES', 'PDS', 'MAE')
)
df_all = pd.concat([df_des, df_pds, df_mae], axis=1)
df_all.to_csv('evaluation_per_perturbation.csv')

print("\nSaved evaluation_results.json and evaluation_per_perturbation.csv")

Loading AnnData files...
Computing scores (this may take time depending on dataset size)...

===== Summary =====
DES (mean): 1.000000
PDS (mean): 1.000000
MAE (mean): 0.000000
DES_scaled: 1.000000
PDS_scaled: 1.000000
MAE_scaled: 1.000000
Overall score (percent): 100.0000%

Saved evaluation_results.json and evaluation_per_perturbation.csv


In [14]:
# Score the predictions in example.h5ad against the held-out test set and
# write the results under eval_outcome/.
from pathlib import Path

true_file = "../test_set_1119.h5ad"
pred_file = "example.h5ad"
out_dir = Path("eval_outcome")

print("Loading AnnData files...")
true_adata = ad.read_h5ad(true_file)
pred_adata = ad.read_h5ad(pred_file)

# NOTE(review): a previous run of this cell reported MAE ~7.0 against a baseline
# MAE of 0.027 (MAE_scaled ~ -258). That may mean the true and predicted
# expression are on different scales (e.g. raw counts vs log-normalized).
# Confirm both files use the same normalization before trusting the score.
baseline_scores = {'DES': 0.106, 'PDS': 0.514, 'MAE': 0.027}

print("Computing scores (this may take time depending on dataset size)...")
scores = compute_vcc_scores(true_adata, pred_adata,
                            baseline_scores=baseline_scores,
                            groupby='target_gene',
                            control_label='non-targeting',
                            fdr=0.05,
                            sample_n=1000,
                            min_cells_per_group=3)

# Print summary
print("\n===== Summary =====")
print(f"DES (mean): {scores['DES']:.6f}")
print(f"PDS (mean): {scores['PDS']:.6f}")
print(f"MAE (mean): {scores['MAE']:.6f}")
print(f"DES_scaled: {scores['DES_scaled']:.6f}")
print(f"PDS_scaled: {scores['PDS_scaled']:.6f}")
print(f"MAE_scaled: {scores['MAE_scaled']:.6f}")
print(f"Overall score (percent): {scores['overall_score_percent']:.4f}%")

# The output directory may not exist on a fresh checkout; open() would fail.
out_dir.mkdir(parents=True, exist_ok=True)
results_path = out_dir / "evaluation_results.json"
# Was 'valuation_per_perturbation.csv' (typo), which did not match the message below.
per_pert_path = out_dir / "evaluation_per_perturbation.csv"

# Only the scalar metrics go to JSON; per-perturbation values go to the CSV.
out_json = {key: scores[key] for key in (
    'DES', 'PDS', 'MAE', 'DES_scaled', 'PDS_scaled', 'MAE_scaled', 'overall_score_percent')}
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(out_json, f, indent=2)

# One row per perturbation, one column per metric.
df_des = scores['DES_per_perturbation'].rename("DES").to_frame()
df_pds = scores['PDS_per_perturbation'].rename("PDS").to_frame()
df_mae = scores['MAE_per_perturbation'].rename("MAE").to_frame()
df_all = pd.concat([df_des, df_pds, df_mae], axis=1)
df_all.to_csv(per_pert_path)

print(f"\nSaved {results_path} and {per_pert_path}")

Loading AnnData files...
Computing scores (this may take time depending on dataset size)...

===== Summary =====
DES (mean): 0.042667
PDS (mean): 0.516667
MAE (mean): 7.005270
DES_scaled: -0.070842
PDS_scaled: 0.005487
MAE_scaled: -258.454442
Overall score (percent): -8617.3266%

Saved evaluation_results.json and evaluation_per_perturbation.csv
