In [28]:
# -*- coding: utf-8 -*-
"""
Single-file main script (accepts exactly one .h5ad input).
- Reads a single h5ad and splits it into several batches (adata_list) by the ARGS["batch_key"] column
- Preprocess -> train DisentAE -> iterative MNN alignment -> decode -> UMAP visualization
- The UMAP visualization handles cell type robustly (candidate column names + fill for missing
  columns + a unified column name "cell_type_std")

Dependencies:
    - scanpy, anndata, numpy, torch
    - Project-local modules: model.py / file.py (providing preprocess_batches, DisentAE,
      apply_mnn_correction_iterative, decode_corrected_expression, load_manifest, save_manifest, etc.)
"""

import os
import time
import numpy as np
import scanpy as sc
from anndata import AnnData
import torch

In [29]:
# ===== Project-local modules =====
# model.py is mandatory: any failure while importing it is re-raised as an
# ImportError carrying an actionable message (original error kept as the cause).
try:
    from model import (
        preprocess_batches,
        DisentAE,
        apply_mnn_correction_iterative,
        decode_corrected_expression,
        train_disentae,
    )
except Exception as e:
    raise ImportError(
        "未找到所需的项目内模块/函数，请确认 model.py 存在并包含 preprocess_batches/DisentAE/"
        "apply_mnn_correction_iterative/decode_corrected_expression/train_disentae 等。"
    ) from e

# file.py is optional: the notebook still runs without manifest support.
# The downgrade is deliberate best-effort, but it is reported so that a broken
# file.py does not go unnoticed.
try:
    from file import load_manifest, save_manifest
except Exception as manifest_import_error:
    print(
        "[Warn] file.py manifest helpers unavailable "
        f"({manifest_import_error!r}); manifests will not be read or written."
    )

    def load_manifest(_outdir: str):
        """Fallback stub used when file.py cannot be imported; always returns None."""
        return None

    def save_manifest(_outdir: str, _manifest: dict):
        """Fallback stub used when file.py cannot be imported; does nothing."""
        pass

print(f"[Info] Using torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
sc.settings.verbosity = 2

[Info] Using torch 2.8.0+cu128, CUDA available: False


In [30]:
# =========== Parameters ============
# Known dataset presets (input file / batch_key / cell_type):
#   bct_raw.h5ad     "BATCH"   "celltype"
#   mural_raw.h5ad   "batch"   "cell_type1"
# To switch dataset, change "inputs", "batch_key" and "cell_type" together.
ARGS = dict(
    # --- input / output ---
    inputs=["mural_raw.h5ad"],
    batch_key="batch",
    cell_type="cell_type1",
    outdir="outputs",
    # --- model / training ---
    device="cuda" if torch.cuda.is_available() else "cpu",
    zc_dim=32,
    zb_dim=16,
    epochs=35,
    batch_size=256,
    lr=1e-2,
    grl_lambda=1.0,
    inject_noise_std=0.1,
    # --- preprocessing / alignment ---
    n_top_genes=2000,
    k_mnn=20,
    # --- training-stage caching flags: cache on by default; force=True means
    #     ignore the cache and retrain ---
    resume=True,
    force=False,
)

os.makedirs(ARGS["outdir"], exist_ok=True)

# Echo every parameter; "inputs" is printed last together with its count.
print("[Params]")
for name in ARGS:
    if name == "inputs":
        continue
    print(f"  {name}: {ARGS[name]}")
n_inputs = len(ARGS["inputs"])
print(f"  inputs: {n_inputs} file -> {ARGS['inputs']}")

# Send all scanpy figures to <outdir>/figs instead of the default figures directory.
figs_dir = os.path.join(ARGS["outdir"], "figs")
sc.settings.figdir = figs_dir
os.makedirs(sc.settings.figdir, exist_ok=True)

[Params]
  batch_key: batch
  cell_type: cell_type1
  outdir: outputs
  device: cpu
  zc_dim: 32
  zb_dim: 16
  epochs: 35
  batch_size: 256
  lr: 0.01
  grl_lambda: 1.0
  inject_noise_std: 0.1
  n_top_genes: 2000
  k_mnn: 20
  resume: True
  force: False
  inputs: 1 file -> ['mural_raw.h5ad']


In [31]:
#============ Load the data as a per-batch list of `AnnData` (single file only) ============
adata_list = []

# Exactly one .h5ad file is allowed.
if len(ARGS["inputs"]) != 1:
    raise ValueError("此脚本现在只支持一个 .h5ad 文件，请在 ARGS['inputs'] 里仅保留一个路径。")

if ARGS["batch_key"] is None:
    raise ValueError("请在 ARGS['batch_key'] 中提供批次列名，例如 'batch' 或 'BATCH'。")

in_path = ARGS["inputs"][0]
if not os.path.exists(in_path):
    raise FileNotFoundError(f"找不到输入文件：{in_path}")

ad = sc.read_h5ad(in_path)
if ARGS["batch_key"] not in ad.obs:
    raise ValueError(
        f"batch_key '{ARGS['batch_key']}' 不在 adata.obs 中，可用列有：{list(ad.obs.columns)}"
    )

# Debug information (switch off as needed).
PRINT_DEBUG = True
if PRINT_DEBUG:
    print("\n[Debug] 细胞名（obs_names）示例:")
    print(ad.obs_names[:10])
    print("\n[Debug] 基因名（var_names）示例:")
    print(ad.var_names[:10])
    print("\n[Debug] adata.obs 列：")
    print(ad.obs.columns)

# Split into one AnnData per batch value. The string conversion of the batch
# column does not depend on the loop variable, so it is computed once instead
# of once per batch. Batches appear in adata_list in order of first appearance
# (the order returned by .unique()).
batch_as_str = ad.obs[ARGS["batch_key"]].astype(str)
for batch_name in batch_as_str.unique():
    adata_list.append(ad[batch_as_str == batch_name].copy())

print(f"[Info] 从单个 h5ad 依据 '{ARGS['batch_key']}' 拆出 {len(adata_list)} 个批次")



[Debug] 细胞名（obs_names）示例:
Index(['D28.1_1', 'D28.1_13', 'D28.1_15', 'D28.1_17', 'D28.1_2', 'D28.1_26',
       'D28.1_29', 'D28.1_3', 'D28.1_30', 'D28.1_37'],
      dtype='object')

[Debug] 基因名（var_names）示例:
Index(['A1BG-AS1', 'A1BG', 'A1CF', 'A2M-AS1', 'A2ML1', 'A2M', 'A4GALT',
       'A4GNT', 'AAAS', 'AACSP1'],
      dtype='object')

[Debug] adata.obs 列：
Index(['batch', 'cell_ontology_class', 'cell_ontology_id', 'cell_type1',
       'dataset_name', 'donor', 'organ', 'organism', 'platform'],
      dtype='object')
[Info] 从单个 h5ad 依据 'batch' 拆出 8 个批次


In [32]:
# ======================= Preprocessing =======================
# preprocess_batches is a project helper; here it is only known to return a
# matrix X, a per-row batch vector and a gene list (printed as |HVG∩|,
# presumably the intersection of per-batch highly variable genes — TODO confirm).
X, batches, hvg_list = preprocess_batches(adata_list, n_top_genes=ARGS["n_top_genes"])
print(f"[Info] X shape: {X.shape}, batches shape: {batches.shape}, |HVG∩|={len(hvg_list)}")

# Persist the preprocessed arrays and the gene list (one gene per line).
out_dir = ARGS["outdir"]
for npy_name, array in (("X_preprocessed.npy", X), ("batches.npy", batches)):
    np.save(os.path.join(out_dir, npy_name), array)

hvg_path = os.path.join(out_dir, "HVG_intersection.tsv")
with open(hvg_path, "w", encoding="utf-8") as hvg_file:
    hvg_file.write("\n".join(str(gene) for gene in hvg_list))

filtered out 3990 genes that are detected in less than 3 cells
filtered out 3514 genes that are detected in less than 3 cells


filtered out 3873 genes that are detected in less than 3 cells
filtered out 3720 genes that are detected in less than 3 cells
filtered out 4491 genes that are detected in less than 3 cells
filtered out 4312 genes that are detected in less than 3 cells
filtered out 4199 genes that are detected in less than 3 cells
filtered out 4284 genes that are detected in less than 3 cells
normalizing counts per cell
    finished (0:00:00)
extracting highly variable genes
    finished (0:00:00)
normalizing counts per cell
    finished (0:00:00)
extracting highly variable genes
    finished (0:00:00)
normalizing counts per cell
    finished (0:00:00)
extracting highly variable genes
    finished (0:00:00)
normalizing counts per cell
    finished (0:00:00)
extracting highly variable genes
    finished (0:00:00)
normalizing counts per cell
    finished (0:00:00)
extracting highly variable genes
    finished (0:00:00)
normalizing counts per cell
    finished (0:00:00)
extracting highly variable genes
   

  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)


In [33]:
#================ Train DisentAE ==========================
# Output paths for the trained weights and the two latent matrices.
pth_model = os.path.join(ARGS["outdir"], "model.pt")
pth_Zc = os.path.join(ARGS["outdir"], "Zc.npy")
pth_Zb = os.path.join(ARGS["outdir"], "Zb.npy")

# Seed torch and numpy before training.
torch.manual_seed(0)
np.random.seed(0)

device = ARGS["device"]
print(f"[Info] Using device: {device}")

# NOTE(review): ARGS["resume"], ARGS["force"] and ARGS["grl_lambda"] are not
# consumed by this cell; training always runs from scratch. Confirm whether
# train_disentae should receive them before wiring them in.
model, Zc, Zb = train_disentae(
    X, batches,
    zc_dim=ARGS["zc_dim"], zb_dim=ARGS["zb_dim"],
    batch_size=ARGS["batch_size"], epochs=ARGS["epochs"],
    lr=ARGS["lr"], device=device,
    inject_noise_std=ARGS["inject_noise_std"]
)
print(f"[Info] Zc shape: {Zc.shape}, Zb shape: {Zb.shape}")
torch.save(model.state_dict(), pth_model)
np.save(pth_Zc, Zc)
np.save(pth_Zb, Zb)

# Build the manifest explicitly. Previously `curr_manifest` was never defined,
# so save_manifest() always failed with a NameError that was silently swallowed
# and no manifest was ever written.
# NOTE(review): the schema expected by file.save_manifest is not visible from
# this notebook; only plain built-in types are used here — confirm the keys.
curr_manifest = {
    "args": dict(ARGS),
    "X_shape": [int(dim) for dim in X.shape],
    "n_batches": int(batches.max()) + 1,
    "artifacts": {"model": pth_model, "Zc": pth_Zc, "Zb": pth_Zb},
}
try:
    save_manifest(ARGS["outdir"], curr_manifest)
except Exception as manifest_error:
    # Saving the manifest stays best-effort, but the failure is reported.
    print(f"[Warn] manifest 保存失败，已跳过。错误信息：{manifest_error}")

[Info] Using device: cpu
Epoch 1/35 | Recon 3.0636 | Adv(z_c) 2.0757 | Batch(z_b) 2.0761 | XCov 0.6576 | Align 0.1511 | Center 2.0523
Epoch 2/35 | Recon 1.1502 | Adv(z_c) 2.0709 | Batch(z_b) 2.0731 | XCov 0.4526 | Align 0.3600 | Center 0.4592
Epoch 3/35 | Recon 0.8476 | Adv(z_c) 2.0742 | Batch(z_b) 2.0754 | XCov 0.1781 | Align 0.2717 | Center 0.2692
Epoch 4/35 | Recon 0.7755 | Adv(z_c) 2.0718 | Batch(z_b) 2.0737 | XCov 0.0295 | Align 0.1245 | Center 0.1726
Epoch 5/35 | Recon 0.7251 | Adv(z_c) 2.0745 | Batch(z_b) 2.0760 | XCov 0.0189 | Align 0.0762 | Center 0.1515
Epoch 6/35 | Recon 0.6780 | Adv(z_c) 2.0769 | Batch(z_b) 2.0763 | XCov 0.0146 | Align 0.0698 | Center 0.1451
Epoch 7/35 | Recon 0.6187 | Adv(z_c) 2.0731 | Batch(z_b) 2.0742 | XCov 0.0123 | Align 0.0693 | Center 0.1314
Epoch 8/35 | Recon 0.6263 | Adv(z_c) 2.0711 | Batch(z_b) 2.0712 | XCov 0.0108 | Align 0.0669 | Center 0.1170
Epoch 9/35 | Recon 0.5962 | Adv(z_c) 2.0740 | Batch(z_b) 2.0725 | XCov 0.0071 | Align 0.0394 | Center 0

In [34]:
#============= Iterative MNN alignment ===============
# Slice the content latent Zc into one array per batch id (0 .. max id), hand
# the list to the project's MNN routine, then stack the results back together.
n_batches = int(batches.max()) + 1
Zc_list = []
for batch_id in range(n_batches):
    Zc_list.append(Zc[batches == batch_id])
Zc_corr_list = apply_mnn_correction_iterative(Zc_list, k=ARGS["k_mnn"])

Zc_corr = np.vstack(Zc_corr_list)
print(f"[Info] Zc_corr shape: {Zc_corr.shape}")

# Save as .npy and also as TSV (optional; the text file can be large).
zc_corr_stem = os.path.join(ARGS["outdir"], "Zc_corrected")
np.save(zc_corr_stem + ".npy", Zc_corr)
np.savetxt(zc_corr_stem + ".tsv", Zc_corr, delimiter="\t")


[Info] Zc_corr shape: (2122, 32)


In [35]:
#========== Decode to the "batch-removed" expression matrix ============
# decode_corrected_expression is a project helper; it receives the trained
# model and the per-batch corrected latents and returns one array per batch.
Xcorr_batches = decode_corrected_expression(model, Zc_corr_list, device=device)
Xcorr = np.vstack(Xcorr_batches)
print(f"[Info] Xcorr shape: {Xcorr.shape}")

# Persist in both binary and tab-separated text form.
xcorr_stem = os.path.join(ARGS["outdir"], "X_corrected")
np.save(f"{xcorr_stem}.npy", Xcorr)
np.savetxt(f"{xcorr_stem}.tsv", Xcorr, delimiter="\t")


[Info] Xcorr shape: (2122, 811)


In [36]:
#========= UMAP visualization (Before/After) =========
def _pick_col(df, cands):
    """Return the first truthy candidate name that is a column of ``df``, else None."""
    for c in cands:
        if c and (c in df.columns):
            return c
    return None


def collect_cell_types(adatas, cands):
    """Concatenate per-batch cell-type labels in the order of ``adatas``.

    For each element, the first candidate column present in its ``.obs`` is
    used (converted to str). When none is present, the batch is filled with
    "Unknown" for each of its ``n_obs`` cells.
    """
    per_batch = []
    for ad_sub in adatas:
        key = _pick_col(ad_sub.obs, cands)
        if key is None:
            per_batch.append(np.array(["Unknown"] * ad_sub.n_obs))
        else:
            per_batch.append(ad_sub.obs[key].astype(str).values)
    return np.concatenate(per_batch)


def _plot_umap_panels(matrix, batch_labels, cell_types, batch_key, *, titles, saves, stage):
    """Scale -> PCA -> neighbors -> UMAP on ``matrix`` and save the colored plots.

    ``titles`` and ``saves`` are (by-batch, by-cell-type) pairs. The cell-type
    plot is skipped when there is only one distinct label. Returns the AnnData
    that was built so the caller can inspect it afterwards.
    """
    # Work on a copy: AnnData may keep a reference to the array it is given and
    # sc.pp.scale operates in place, which would otherwise rescale the
    # pipeline's own array and make re-running this cell change the result.
    adata = AnnData(np.array(matrix, copy=True))
    adata.obs[batch_key] = batch_labels
    adata.obs["cell_type_std"] = cell_types

    sc.pp.scale(adata, max_value=10)
    sc.tl.pca(adata)
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)

    sc.pl.umap(adata, color=batch_key, title=titles[0], save=saves[0])
    if len(set(cell_types)) > 1:
        sc.pl.umap(adata, color="cell_type_std", title=titles[1], save=saves[1])
    else:
        print(f"[Info] 未检测到有效的多类细胞类型标签，已跳过按 cell type 上色（{stage}）。")
    return adata


# Visualization is best-effort: a failure is reported (with traceback) but does
# not abort the notebook.
try:
    # Labels concatenated in adata_list order.
    # NOTE(review): this assumes preprocess_batches keeps every cell and stacks
    # batches in adata_list order — confirm; the row-count check below only
    # catches a size mismatch, not a reordering.
    batch_labels = np.concatenate([
        batch_ad.obs[ARGS['batch_key']].astype(str).values for batch_ad in adata_list
    ])

    # Robust cell-type extraction: candidate column names + fill for missing columns.
    _CTYPE_CANDIDATES = [ARGS.get("cell_type"), "cell_type", "celltype",
                         "CellType", "cell_types", "annotation", "cell_label"]
    cell_types = collect_cell_types(adata_list, _CTYPE_CANDIDATES)

    for _mat_name, _mat in (("X", X), ("Zc_corr", Zc_corr)):
        _n_rows = np.shape(_mat)[0]
        if len(batch_labels) != _n_rows:
            raise ValueError(
                f"标签数量（{len(batch_labels)}）与 {_mat_name} 的行数（{_n_rows}）不一致，无法对齐。"
            )

    # ========== Preprocessed X space (Before) ==========
    # Titles now say "Before correction": these figures show the uncorrected
    # data and are saved as *_before_*.png.
    adata_X = _plot_umap_panels(
        X, batch_labels, cell_types, ARGS['batch_key'],
        titles=("Before correction (by batch)", "Before correction (by cell type)"),
        saves=("_before_batch.png", "_before_celltype.png"),
        stage="Before",
    )

    # ========== Zc_corr space (After) ==========
    # Also save the embedding to <outdir>/embeddings.
    emb_dir = os.path.join(ARGS["outdir"], "embeddings")
    os.makedirs(emb_dir, exist_ok=True)
    emb_path = os.path.join(emb_dir, "MyModel.npy")
    Zc_to_save = np.asarray(Zc_corr)
    np.save(emb_path, Zc_to_save)
    print(f"[OK] Embedding saved -> {emb_path} shape={Zc_to_save.shape}")

    adata_Zc = _plot_umap_panels(
        Zc_corr, batch_labels, cell_types, ARGS['batch_key'],
        titles=("After correction (by batch)", "After correction (by cell type)"),
        saves=("_after_batch_z.png", "_after_celltype_z.png"),
        stage="After",
    )

    print("[Info] 可视化保存至:", sc.settings.figdir)
except Exception as e:
    import traceback

    print(f"[Warn] 可视化失败，已跳过。错误信息：{e}")
    # Show where it failed; the message alone is rarely enough to debug.
    traceback.print_exc()

computing PCA


    with n_comps=50
    finished (0:00:08)
computing neighbors
    using 'X_pca' with n_pcs = 50
    finished (0:00:00)
computing UMAP
    finished (0:00:01)
[OK] Embedding saved -> outputs/embeddings/MyModel.npy shape=(2122, 32)
computing PCA
    with n_comps=31
    finished (0:00:00)
computing neighbors
    using data matrix X directly


  plt.show()
  plt.show()


    finished (0:00:00)
computing UMAP
    finished (0:00:01)
[Info] 可视化保存至: outputs/figs


  plt.show()
  plt.show()


In [37]:
#============== Finished ============
# Final pointer to the directory that holds every saved artifact.
final_outdir = ARGS["outdir"]
print("[Done] All artifacts are saved in:", final_outdir)


[Done] All artifacts are saved in: outputs
