# ST-MIL 评估（使用已保存结果）

本 Notebook 只读取已保存的 `results/st_mil_hest/{split}/pred_bag.npy` 与 `true_bag.npy` 进行评估，
不会重新运行模型推理，适用于已完成训练和推理的场景。项目根目录由环境变量 `MORPHO_VC_ROOT` 指定（默认为 `..`）。


In [1]:
from pathlib import Path
import os
import json
import numpy as np

# ====== Key paths (adjust for your machine) ======
# MORPHO_VC_ROOT overrides the project root; defaults to the parent directory.
ROOT = Path(os.environ.get('MORPHO_VC_ROOT', '..')).expanduser().resolve(strict=False)
results_dir = ROOT / 'results' / 'st_mil_hest'

# Split to evaluate ('val' or 'test')
split_name = 'test'
pred_path = results_dir / split_name / 'pred_bag.npy'
true_path = results_dir / split_name / 'true_bag.npy'

# Gene names: prefer data/spatial_data/common_genes.txt, fall back to the
# checkpoint's common_genes.json. Later cells label column i of the arrays
# with common_genes[i].
common_genes = None
common_txt = ROOT / 'data' / 'spatial_data' / 'common_genes.txt'
ckpt_genes = ROOT / 'checkpoints' / 'st_mil_hest' / 'common_genes.json'

# Decode explicitly as UTF-8: Path.read_text() otherwise uses the platform's
# locale encoding, which makes the notebook behave differently across OSes.
if common_txt.exists():
    common_genes = [g.strip() for g in common_txt.read_text(encoding='utf-8').splitlines() if g.strip()]
elif ckpt_genes.exists():
    common_genes = json.loads(ckpt_genes.read_text(encoding='utf-8'))
else:
    raise FileNotFoundError('找不到 common_genes.txt 或 common_genes.json')

if not pred_path.exists() or not true_path.exists():
    raise FileNotFoundError(f'结果文件不存在: {pred_path} 或 {true_path}')

pred_bag = np.load(pred_path)
true_bag = np.load(true_path)

# Every downstream cell indexes arrays as [:, gene] and compares pred/true
# element-wise. Fail here with a clear message rather than with an IndexError
# or silent broadcasting later on.
if pred_bag.ndim != 2 or true_bag.ndim != 2:
    raise ValueError(f'pred_bag/true_bag 应为二维数组，实际为 {pred_bag.shape} / {true_bag.shape}')
if pred_bag.shape != true_bag.shape:
    raise ValueError(f'pred_bag 与 true_bag 形状不一致: {pred_bag.shape} vs {true_bag.shape}')

print('split:', split_name)
print('pred_bag:', pred_bag.shape)
print('true_bag:', true_bag.shape)
print('common_genes:', len(common_genes))
if pred_bag.shape[1] != len(common_genes):
    print('警告：pred_bag 列数与 common_genes 长度不一致')


split: test
pred_bag: (3836, 17512)
true_bag: (3836, 17512)
common_genes: 17512


In [2]:
import numpy as np

def pearson_corr(a, b):
    """Pearson correlation coefficient between two equal-length vectors.

    Parameters
    ----------
    a, b : array_like
        Values to correlate, e.g. one gene's values across all bags.
        Lists and arrays of any numeric dtype are accepted.

    Returns
    -------
    float
        The correlation coefficient, or NaN when it is undefined: either
        vector has zero variance, or the inputs contain non-finite values.
    """
    # Work in float64: the saved arrays may be float32, and accumulating
    # thousands of float32 products noticeably loses precision.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    a = a - a.mean()
    b = b - b.mean()
    denom = np.sqrt((a * a).sum()) * np.sqrt((b * b).sum())
    # A zero denominator means a constant vector (correlation undefined); a
    # non-finite one means NaN/inf leaked in from the inputs.
    if denom == 0 or not np.isfinite(denom):
        return float('nan')
    return float((a * b).sum() / denom)

# Overall error on the raw (untransformed) scale. The difference is computed
# once instead of twice, and the means accumulate in float64 to limit rounding
# error over the full bag x gene matrix.
diff = pred_bag - true_bag
mae = float(np.mean(np.abs(diff), dtype=np.float64))
rmse = float(np.sqrt(np.mean(np.square(diff), dtype=np.float64)))
del diff  # release the full-size temporary

# Per-gene Pearson across bags. Genes whose correlation is NaN (e.g. a
# constant column on either side) are excluded from the mean; count them so
# the exclusion is visible rather than silent.
gene_corrs = []
n_nan_genes = 0
for i in range(pred_bag.shape[1]):
    corr = pearson_corr(pred_bag[:, i], true_bag[:, i])
    if np.isnan(corr):
        n_nan_genes += 1
    else:
        gene_corrs.append(corr)
mean_gene_corr = float(np.mean(gene_corrs)) if gene_corrs else float('nan')

print(f'MAE: {mae:.4f}')
print(f'RMSE: {rmse:.4f}')
print(f'平均 Pearson(按基因): {mean_gene_corr:.4f}')
print(f'Pearson 为 NaN 而被排除的基因数（如方差为 0）: {n_nan_genes}/{pred_bag.shape[1]}')


MAE: 0.6539
RMSE: 4.6864
平均 Pearson(按基因): 0.3592


In [3]:
import numpy as np

# ====== Evaluation on true highly-variable genes (HVG) ======
k = 1000  # try 500 / 1000 / 2000


def _log1p_nonneg(arr, name):
    """Return log1p(arr) after clipping negatives to 0, reporting how many were clipped."""
    # log1p is undefined below -1. The resulting NaN variances would be ranked
    # by argsort as the "most variable" genes. Negative expression is assumed
    # not to be meaningful here (TODO confirm the saved arrays are count-like).
    n_neg = int(np.count_nonzero(arr < 0))
    if n_neg:
        print(f'提示：{name} 中有 {n_neg} 个负值，log1p 前已截断为 0')
    return np.log1p(np.clip(arr, 0, None))


true_x = _log1p_nonneg(true_bag, 'true_bag')
pred_x = _log1p_nonneg(pred_bag, 'pred_bag')

n_genes = true_x.shape[1]
# Clamp k. Slice from n_genes - k_eff rather than using [-k:], because
# [-0:] selects every gene and k > n_genes would return fewer than k genes.
k_eff = max(0, min(k, n_genes))

true_var = true_x.var(axis=0)
hvg_idx = np.argsort(true_var)[n_genes - k_eff:]

# Pearson per HVG, computed once and reused for both the mean and the ranking.
gene_corr_pairs = []
for i in hvg_idx:
    c = pearson_corr(pred_x[:, i], true_x[:, i])
    if not np.isnan(c):
        gene_corr_pairs.append((common_genes[i], c))
corrs = [c for _, c in gene_corr_pairs]
mean_hvg_corr = float(np.mean(corrs)) if corrs else float('nan')
print(f'真实HVG({k_eff})上的平均Pearson: {mean_hvg_corr:.4f}')

# Overlap between the top-k genes ranked by predicted vs. true variance
pred_var = pred_x.var(axis=0)
pred_hvg_idx = np.argsort(pred_var)[n_genes - k_eff:]
overlap = len(set(hvg_idx.tolist()) & set(pred_hvg_idx.tolist())) / k_eff if k_eff else float('nan')
print(f'预测HVG与真实HVG重叠比例: {overlap:.4f}')

# Genes with the highest correlation within the HVG set
gene_corr_pairs.sort(key=lambda x: x[1], reverse=True)
print('相关性最高的基因（HVG内 Top10）：')
for g, c in gene_corr_pairs[:10]:
    print(f'{g}: {c:.4f}')


真实HVG(1000)上的平均Pearson: 0.8625
预测HVG与真实HVG重叠比例: 0.7480
相关性最高的基因（HVG内 Top10）：
KLK2: 0.9809
KLK3: 0.9793
EEF2: 0.9764
PPDPF: 0.9683
EEF1G: 0.9663
SERF2: 0.9642
UBA52: 0.9605
UBC: 0.9586
PTMS: 0.9573
PABPC1: 0.9572


In [4]:
# ====== Correlation for one named gene ======
# Computed on the raw (non-log) values, unlike the HVG cell, which applies
# log1p first, so the two numbers for the same gene need not match.
gene_name = 'KLK3'  # set to the gene you want to inspect

if gene_name not in common_genes:
    print('gene_name 不在 common_genes 中:', gene_name)
else:
    # Column position of the gene in pred_bag / true_bag
    idx = common_genes.index(gene_name)
    corr = pearson_corr(pred_bag[:, idx], true_bag[:, idx])
    print(f'{gene_name} Pearson: {corr:.4f}')


KLK3 Pearson: 0.9889


In [5]:
# ====== Prediction-collapse diagnostics (optional) ======
# Per-gene standard deviation across bags. Each array's std is computed once
# and reused, since a std pass over the full bag x gene matrix is expensive.
pred_gene_std = pred_bag.std(axis=0)
true_gene_std = true_bag.std(axis=0)
print('pred 每个基因标准差均值:', float(pred_gene_std.mean()))
print('pred 每个基因标准差最小:', float(pred_gene_std.min()))
print('true 每个基因标准差均值:', float(true_gene_std.mean()))
print('true 每个基因标准差最小:', float(true_gene_std.min()))

# A collapsed gene is predicted (near-)constant across bags. If the arrays are
# float32, rounding can give a truly constant column a tiny nonzero std. The
# count therefore uses a small heuristic tolerance instead of exact zero.
COLLAPSE_STD_TOL = 1e-6
n_collapsed = int((pred_gene_std < COLLAPSE_STD_TOL).sum())
print(f'pred 近似常数的基因数(std < {COLLAPSE_STD_TOL:g}): {n_collapsed}/{pred_gene_std.size}')

# Spot check: correlation between the first two predicted genes
if pred_bag.shape[1] >= 2:
    g1, g2 = 0, 1
    corr = pearson_corr(pred_bag[:, g1], pred_bag[:, g2])
    print(f'预测基因 {common_genes[g1]} vs {common_genes[g2]} Pearson: {corr:.4f}')


pred 每个基因标准差均值: 0.6540340185165405
pred 每个基因标准差最小: 2.6412854126078855e-08
true 每个基因标准差均值: 1.2009499073028564
true 每个基因标准差最小: 0.0
预测基因 A1BG vs A1CF Pearson: 0.3824
