
# DisentAE Batch Correction Pipeline (Notebook)

这个笔记本是将你的 `main.py` 转换成的 `ipynb` 版本。  
与命令行参数不同，这里用一个 **Parameters** 单元格来设置参数（等价于 `argparse` 的各项）。

> 依赖：`numpy`, `scanpy`, `anndata`, `torch`，以及你本地的 `model.py`、`file.py`（包含 `DisentAE`、`preprocess_batches`、`train_disentae`、`apply_mnn_correction_iterative`、`decode_corrected_expression`、`inputs_signature`、`load_manifest`、`save_manifest` 等）。


In [1]:

import os, time
import numpy as np
import scanpy as sc
from anndata import AnnData
import torch

# 你的项目内模块
from model import *
from file import *  # 假设包含 inputs_signature / load_manifest / save_manifest 等

print(f"[Info] Using torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
sc.settings.verbosity = 2


  from .autonotebook import tqdm as notebook_tqdm


[Info] Using torch 2.8.0+cu128, CUDA available: False


In [17]:
adata = sc.read_h5ad('bct_raw.h5ad')

# 输出行名 (Observed names)，即细胞名
print("细胞/样本名 (行名):")
print(adata.obs_names)

# 输出列名 (Variable names)，即基因名
print("\n基因名 (列名):")
print(adata.var_names)

print("adata.obs 中的所有列（元数据）:")
print(adata.obs.columns)

# # 查看 batch 列的前几行
# print("batch 列的前5个值:")
# print(adata.obs['batch'].head())

# # 统计每个批次有多少个细胞
# print("每个批次的细胞数量:")
# print(adata.obs['batch'].value_counts())

细胞/样本名 (行名):
Index(['vis1', 'vis2', 'vis3', 'vis4', 'vis5', 'vis6', 'vis7', 'vis8', 'vis9',
       'vis10',
       ...
       'wal4363', 'wal4364', 'wal4365', 'wal4366', 'wal4367', 'wal4370',
       'wal4371', 'wal4372', 'wal4374', 'wal4375'],
      dtype='object', length=9288)

基因名 (列名):
Index(['ENSMUSG00000092341', 'ENSMUSG00000029580', 'ENSMUSG00000023043',
       'ENSMUSG00000064341', 'ENSMUSG00000031765', 'ENSMUSG00000017009',
       'ENSMUSG00000016559', 'ENSMUSG00000064370', 'ENSMUSG00000024661',
       'ENSMUSG00000049382',
       ...
       'ENSMUSG00000036580', 'ENSMUSG00000010663', 'ENSMUSG00000040605',
       'ENSMUSG00000039431', 'ENSMUSG00000028634', 'ENSMUSG00000022218',
       'ENSMUSG00000020593', 'ENSMUSG00000030905', 'ENSMUSG00000022949',
       'ENSMUSG00000028545'],
      dtype='object', length=1222)
adata.obs 中的所有列（元数据）:
Index(['orig.ident', 'nCount_originalexp', 'nFeature_originalexp', 'study',
       'cell.class', 'library_size', 'detected_genes', 'BATCH', 'cell

## Parameters（请在此处修改）

In [2]:

# 等价于命令行：
# python main.py --inputs neurips2021_s1d3.h5ad neurips2021_s2d1.h5ad neurips2021_s3d7.h5ad --outdir out
# 若只传一个 .h5ad，可配合 --batch_key 使用（例如 'batch'）

ARGS = {
    "inputs": [
        # 在这里填入你的 .h5ad 路径，支持多个；或只填一个并设置 batch_key
        # "neurips2021_s1d3.h5ad",
        # "neurips2021_s2d1.h5ad",
        # "neurips2021_s3d7.h5ad",
        "bct_raw.h5ad",
    ],
    "batch_key": None,              # 例如: "batch" 或 "replicate"
    "outdir": "outputs",

    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "zc_dim": 32,
    "zb_dim": 16,
    "epochs": 50,
    "batch_size": 256,
    "lr": 1e-3,
    "grl_lambda": 1.0,
    "inject_noise_std": 0.1,

    "n_top_genes": 2000,
    "k_mnn": 20,

    # 训练阶段：默认启用缓存（仅训练阶段）；若 force=True 则无视缓存强制重训
    "resume": True,
    "force": False,
}

os.makedirs(ARGS["outdir"], exist_ok=True)
print("[Params]")
for k, v in ARGS.items():
    if k != "inputs":
        print(f"  {k}: {v}")
print(f"  inputs: {len(ARGS['inputs'])} files -> {ARGS['inputs']}")


[Params]
  batch_key: None
  outdir: outputs
  device: cpu
  zc_dim: 32
  zb_dim: 16
  epochs: 50
  batch_size: 256
  lr: 0.001
  grl_lambda: 1.0
  inject_noise_std: 0.1
  n_top_genes: 2000
  k_mnn: 20
  resume: True
  force: False
  inputs: 1 files -> ['bct_raw.h5ad']


## 构建运行签名（用于训练缓存判定）

In [3]:

curr_manifest = {
    "time": time.strftime("%Y-%m-%d %H:%M:%S"),
    "inputs": inputs_signature(ARGS["inputs"]),
    "params": {
        "zc_dim": ARGS["zc_dim"],
        "zb_dim": ARGS["zb_dim"],
        "epochs": ARGS["epochs"],
        "batch_size": ARGS["batch_size"],
        "lr": ARGS["lr"],
        "grl_lambda": ARGS["grl_lambda"],
        "inject_noise_std": ARGS["inject_noise_std"],
        "n_top_genes": ARGS["n_top_genes"],
        "k_mnn": ARGS["k_mnn"],
        "batch_key": ARGS["batch_key"],
    }
}
old_manifest = load_manifest(ARGS["outdir"])
curr_manifest


{'time': '2025-08-31 22:45:42',
 'inputs': [{'path': '/home/luke/project/dachuang/bct_raw.h5ad',
   'size': 18279891,
   'mtime': 1660820490,
   'sha1': '583ece94c4f1ad8b8ac282604b88d0f654dbcb1d'}],
 'params': {'zc_dim': 32,
  'zb_dim': 16,
  'epochs': 50,
  'batch_size': 256,
  'lr': 0.001,
  'grl_lambda': 1.0,
  'inject_noise_std': 0.1,
  'n_top_genes': 2000,
  'k_mnn': 20,
  'batch_key': None}}

## 读取数据为按批次的 `AnnData` 列表

In [4]:

adata_list = []
if len(ARGS["inputs"]) == 0:
    raise ValueError("请在 ARGS['inputs'] 中填写至少一个 .h5ad 路径")

if len(ARGS["inputs"]) == 1 and ARGS["batch_key"] is not None:
    ad = sc.read_h5ad(ARGS["inputs"][0])
    if ARGS["batch_key"] not in ad.obs:
        raise ValueError(f"batch_key '{ARGS['batch_key']}' 不在 adata.obs 中")
    for b in ad.obs[ARGS["batch_key"]].astype(str).unique():
        ad_b = ad[ad.obs[ARGS["batch_key"]].astype(str) == b].copy()
        adata_list.append(ad_b)
    print(f"[Info] 从单个 h5ad 依据 '{ARGS['batch_key']}' 拆出 {len(adata_list)} 个批次")
else:
    for p in ARGS["inputs"]:
        ad = sc.read_h5ad(p)
        adata_list.append(ad)
    print(f"[Info] 读取多个批次文件，共 {len(adata_list)} 个")

len(adata_list)


[Info] 读取多个批次文件，共 1 个


1

## 预处理（每次执行，不跳过）

In [5]:

X, batches, hvg_list = preprocess_batches(
    adata_list, n_top_genes=ARGS["n_top_genes"]
)
print(f"[Info] X shape: {X.shape}, batches shape: {batches.shape}, |HVG∩|={len(hvg_list)}")

np.save(os.path.join(ARGS["outdir"], "X_preprocessed.npy"), X)
np.save(os.path.join(ARGS["outdir"], "batches.npy"), batches)
with open(os.path.join(ARGS["outdir"], "HVG_intersection.tsv"), "w") as f:
    f.write("\n".join(hvg_list))

# === 将整数批次索引 -> 友好批次名称 ===
""" batch_labels_str 的长度与样本数一致，顺序与 batches 相同；之后凡是画 UMAP 或算“批次相关”指标时，都用它替代裸的 batches。 """
# if len(ARGS["inputs"]) == 1 and ARGS.get("batch_key"):
#     # 单文件 + batch_key：用 obs[batch_key] 的实际取值
#     batch_names = []
#     for ad in adata_list:
#         vals = ad.obs[ARGS["batch_key"]].astype(str).unique().tolist()
#         batch_names.append(vals[0] if len(vals) else "batch")
# else:
#     # 多文件：用文件名（去后缀）
#     batch_names = [os.path.splitext(os.path.basename(p))[0] for p in ARGS["inputs"]]

# batch_labels_str = np.array([batch_names[int(i)] for i in batches], dtype=str)
# batch_labels_str = adata.obs["batch"].astype(str).values



filtered out 715 cells that have less than 200 genes expressed


normalizing counts per cell
    finished (0:00:01)
extracting highly variable genes
`n_top_genes` > `adata.n_var`, returning all genes.
    finished (0:00:00)
[Info] X shape: (8573, 1222), batches shape: (8573,), |HVG∩|=1222


  return dispatch(args[0].__class__)(*args, **kw)


' batch_labels_str 的长度与样本数一致，顺序与 batches 相同；之后凡是画 UMAP 或算“批次相关”指标时，都用它替代裸的 batches。 '

## 训练 DisentAE（可跳过/继续）

In [6]:

pth_model = os.path.join(ARGS["outdir"], "model.pt")
pth_Zc = os.path.join(ARGS["outdir"], "Zc.npy")
pth_Zb = os.path.join(ARGS["outdir"], "Zb.npy")

can_skip_train = (
    ARGS["resume"]
    and old_manifest is not None
    and old_manifest.get("inputs") == curr_manifest["inputs"]
    and old_manifest.get("params") == curr_manifest["params"]
    and os.path.exists(pth_model) and os.path.exists(pth_Zc) and os.path.exists(pth_Zb)
)

torch.manual_seed(0)
np.random.seed(0)

device = ARGS["device"]
print(f"[Info] Using device: {device}")

if can_skip_train and not ARGS["force"]:
    print("[Cache] 发现训练产物，直接加载")
    n_batch = int(batches.max()) + 1
    model = DisentAE(X.shape[1], n_batch, ARGS["zc_dim"], ARGS["zb_dim"]).to(device)
    model.load_state_dict(torch.load(pth_model, map_location=device))
    model.eval()
    Zc = np.load(pth_Zc)
    Zb = np.load(pth_Zb)
else:
    model, Zc, Zb = train_disentae(
        X, batches,
        zc_dim=ARGS["zc_dim"], zb_dim=ARGS["zb_dim"],
        batch_size=ARGS["batch_size"], epochs=ARGS["epochs"],
        lr=ARGS["lr"], device=device,
        grl_lambda=ARGS["grl_lambda"], inject_noise_std=ARGS["inject_noise_std"]
    )
    print(f"[Info] Zc shape: {Zc.shape}, Zb shape: {Zb.shape}")
    torch.save(model.state_dict(), pth_model)
    np.save(pth_Zc, Zc)
    np.save(pth_Zb, Zb)
    save_manifest(ARGS["outdir"], curr_manifest)

Zc.shape, Zb.shape


[Info] Using device: cpu
Epoch 1/50 | Recon 29.1623 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 2/50 | Recon 27.1003 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 3/50 | Recon 26.6052 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 4/50 | Recon 26.1605 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 5/50 | Recon 25.8523 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 6/50 | Recon 25.4821 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 7/50 | Recon 25.1949 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 8/50 | Recon 24.8978 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 9/50 | Recon 24.6191 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 10/50 | Recon 24.3467 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 11/50 | Recon 24.0790 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 12/50 | Recon 23.8284 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 13/50 | Recon 23.5908 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoch 14/50 | Recon 23.3410 | Content(adv) 0.0000 | Batch(z_b) 0.0000
Epoc

((8573, 32), (8573, 16))

## MNN 迭代对齐（每次执行）

In [7]:

Zc_list = [Zc[batches == i] for i in range(int(batches.max()) + 1)]
Zc_corr_list = apply_mnn_correction_iterative(Zc_list, k=ARGS["k_mnn"])

Zc_corr = np.vstack(Zc_corr_list)
print(f"[Info] Zc_corr shape: {Zc_corr.shape}")
np.save(os.path.join(ARGS["outdir"], "Zc_corrected.npy"), Zc_corr)
Zc_corr.shape


[Info] Zc_corr shape: (8573, 32)


(8573, 32)

## 解码到“去批次”的表达矩阵（每次执行）

In [8]:

Xcorr_batches = decode_corrected_expression(model, Zc_corr_list, device=device)
Xcorr = np.vstack(Xcorr_batches)
print(f"[Info] Xcorr shape: {Xcorr.shape}")
np.save(os.path.join(ARGS["outdir"], "X_corrected.npy"), Xcorr)

# 也输出 TSV（大文件可能较慢/较大，请按需保留或注释掉）
np.savetxt(os.path.join(ARGS["outdir"], "Zc_corrected.tsv"), Zc_corr, delimiter="\t")
np.savetxt(os.path.join(ARGS["outdir"], "X_corrected.tsv"), Xcorr, delimiter="\t")
Xcorr.shape


[Info] Xcorr shape: (8573, 1222)


(8573, 1222)

## 可选 UMAP 可视化（Before/After）

In [9]:
adata_X = AnnData(X)
print("adata.obs 中的所有列（元数据）:")
print(adata_X.obs.columns)

adata.obs 中的所有列（元数据）:
Index([], dtype='object')


In [10]:
try:
    # === 先从原数据里取出标签（顺序需与 X / Zc_corr 保持一致）===
    # 如果你的 preprocess_batches 是按 adata_list 的顺序拼的行，
    # 下面这样拼接就与 X/Zc_corr 对齐：
    batch_labels = np.concatenate([ad.obs["batch"].astype(str).values for ad in adata_list])
    cell_types   = np.concatenate([ad.obs["cell_type"].astype(str).values for ad in adata_list])

    # ========== X 原始空间（Before）==========
    adata_X = AnnData(X)
    adata_X.obs["batch"]     = batch_labels
    adata_X.obs["cell_type"] = cell_types

    sc.pp.scale(adata_X, max_value=10)
    sc.tl.pca(adata_X)
    sc.pp.neighbors(adata_X)
    sc.tl.umap(adata_X)
    sc.pl.umap(adata_X, color="batch", title="Before correction (by batch)", save="_before_batch.png")
    sc.pl.umap(adata_X, color="cell_type", title="Before correction (by cell type)", save="_before_celltype.png")

    # ========== Zc_corr 空间（After）==========
    adata_Zc = AnnData(Zc_corr)
    adata_Zc.obs["batch"]     = batch_labels
    adata_Zc.obs["cell_type"] = cell_types

    outdir = "embeddings"
    os.makedirs(outdir, exist_ok=True)
    out_path = os.path.join(outdir, "my.npy")
    X = np.asarray(Zc_corr)
    np.save(out_path, X)
    print(f"[OK] Harmony -> {out_path} shape={X.shape}")

    sc.pp.scale(adata_Zc, max_value=10)
    sc.tl.pca(adata_Zc)
    sc.pp.neighbors(adata_Zc)
    sc.tl.umap(adata_Zc)
    sc.pl.umap(adata_Zc, color="batch", title="After correction (by batch)", save="_after_batch.png")
    sc.pl.umap(adata_Zc, color="cell_type", title="After correction (by cell type)", save="_after_celltype.png")

    print("[Info] 可视化保存至当前工作目录（或 scanpy 配置的 figure 目录）")
except Exception as e:
    print(f"[Warn] 可视化失败，已跳过。错误信息：{e}")


[Warn] visualize_correction 调用失败，已跳过。错误信息：'batch'


## 完成

In [11]:

print("[Done] All artifacts are saved in:", ARGS["outdir"])


[Done] All artifacts are saved in: outputs



## 评估与可视化（ARI / NMI / ASW_celltype / ASW_batch / iLISI / KL）

如果你有**细胞类型标签**（与样本顺序一致），请在下方把 `CELL_TYPES` 设为：
- 一个 `list/ndarray`（长度 N），或
- 指向一个包含 N 行、每行一个标签的 `TSV/CSV/TXT` 文件路径（自动读取第一列）。

若未提供细胞类型，代码将自动跳过 ARI/NMI 与 ASW_celltype，仅计算批次相关指标（ASW_batch、iLISI、KL）。


In [12]:

# 安装/导入评估工具（已生成在 /mnt/data/）
import sys, os, numpy as np
sys.path.append("/mnt/data")
from sc_integration_metrics import (
    compute_ari_nmi, compute_asw_celltype, compute_asw_batch,
    compute_ilisi, compute_local_kl, evaluate_all_with_plots,
    scatter_2d, bar_metrics, histogram
)

# === 参数：提供细胞类型标签（可选） ===
# CELL_TYPES = None  # 1) 直接给一个与样本顺序一致的list/ndarray；或 2) 给路径字符串（见下）


# 如果给了路径，尝试读取第一列作为标签
def _maybe_load_labels(x):
    if x is None:
        return None
    if isinstance(x, (list, tuple, np.ndarray)):
        return np.asarray(x)
    if isinstance(x, str) and os.path.exists(x):
        # 读第一列
        try:
            import pandas as pd
            s = pd.read_csv(x, sep=None, engine="python", header=None).iloc[:,0].astype(str).values
            return s
        except Exception:
            pass
        try:
            s = np.loadtxt(x, dtype=str)
            return s
        except Exception:
            pass
    return None

# cell_types = _maybe_load_labels(CELL_TYPES)
# print("[Eval] cell_types provided:", "YES" if cell_types is not None else "NO")

# === 选择要评估的嵌入 ===
# 你可以选择 Zc_before / Zc_after(Zc_corr) / UMAP坐标 等
X_EMB = Zc_corr  # 默认使用校正后的 Zc


In [13]:

# === 计算并出图 ===
if cell_types is None:
    # 仅批次相关的指标
    # asw_b = compute_asw_batch(X_EMB, batches)
    # ilisi, ilisi_norm, ilisi_dist = compute_ilisi(X_EMB, batches, k=90, return_distribution=True)
    # mean_kl, dkl_dist = compute_local_kl(X_EMB, batches, k=50, return_distribution=True)

    asw_b = compute_asw_batch(X_EMB, batch_labels_str)
    ilisi, ilisi_norm, ilisi_dist = compute_ilisi(X_EMB, batch_labels_str, k=90, return_distribution=True)
    mean_kl, dkl_dist = compute_local_kl(X_EMB, batch_labels_str, k=50, return_distribution=True)

    scatter_2d(X_EMB, batches, title="Embedding colored by Batch", assume_2d=False)
    bar_metrics(
        {
            "ASW_batch": asw_b,
            "iLISI_norm": ilisi_norm,
            "KL_local": mean_kl,
        },
        title="Batch-mixing Metrics Summary",
    )
    if ilisi_dist is not None:
        histogram(ilisi_dist, title="iLISI (per-cell effective #batches)", xlabel="effective batches")
    if dkl_dist is not None:
        histogram(dkl_dist, title="Local KL divergence (neighborhood vs global)", xlabel="D_KL")

    metrics = {
        "ARI": np.nan,
        "NMI": np.nan,
        "ASW_celltype": np.nan,
        "ASW_batch": asw_b,
        "iLISI": ilisi,
        "iLISI_norm": ilisi_norm,
        "KL_local_mean": mean_kl,
    }
else:
    # 计算全部六项指标并可视化
    results = evaluate_all_with_plots(
        X_emb=X_EMB,
        celltype_labels=cell_types,
        batch_labels=batches,
        assume_2d=False,
        k_ilisi=90,
        k_kl=50,
        clustering="kmeans",
        n_clusters=None,  # 默认=细胞类型的数量
    )
    metrics = {k: results[k] for k in [
        "ARI","NMI","ASW_celltype","ASW_batch","iLISI","iLISI_norm","KL_local_mean"
    ]}

print("\n[Metrics]")
for k, v in metrics.items():
    print(f"{k:15s} : {v}")


NameError: name 'cell_types' is not defined


> 备注：
> - `ASW_batch` 已将 Silhouette 映射到 [0,1]（越大批次越混合）。
> - `iLISI` 与 `iLISI_norm` 分别是“有效批次数”的均值与其按真实批次数归一化的版本。
> - `KL_local_mean` 是每个点邻域批次分布相对全局分布的平均 KL 散度（越小越好）。
> - ARI / NMI / ASW_celltype 需要真实细胞类型标签。
