# 02b - Python 批次校正方法

本 notebook 包含 4 种基于 Python 的批次校正方法：**scVI**, **scANVI**, **Scanorama**, **BBKNN**。

**使用方式**：
1. 先在 R notebook (02-Integration.ipynb) 中运行到"导出数据"cell
2. 在对应的 Python 环境中运行本 notebook
3. 运行完后回到 R notebook 导入结果

**输入文件**（由 R 导出）：
- `counts.mtx`, `genes.txt`, `barcodes.txt`
- `metadata.csv`
- `hvg.txt`
- `pca_embedding.csv`

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from scipy.io import mmread
from scipy.sparse import csr_matrix
import os
import time

## ========== Key parameters ==========
# Directory holding the files exported by the R notebook; results are
# written back to the same directory.
input_dir = output_dir = "/home/data/tanglei/project/prostate_altas/output/integration"

# obs column names: the batch column must stay in sync with `batch_var`
# in the R notebook; the cell-type column is needed by scANVI.
batch_key, celltype_key = "orig.ident", "celltype"

# Latent dimensionality, and training epochs (None = decided automatically).
n_latent, n_epochs = 30, None

print(f"输入/输出目录: {input_dir}")

In [None]:
## ========== Load the data exported from R ==========
def _read_lines(path):
    """Return the lines of a text file as a list of strings.

    The file is closed as soon as it has been read. Both LF and CRLF line
    endings are accepted, so carriage returns never end up inside gene or
    barcode names. An empty file yields an empty list.
    """
    with open(path, encoding="utf-8") as handle:
        return handle.read().strip().splitlines()


print("读取 counts 矩阵...")
# The matrix is transposed after reading so that rows are cells and
# columns are genes.
counts = mmread(os.path.join(input_dir, "counts.mtx")).T.tocsr()  # cells x genes

genes = _read_lines(os.path.join(input_dir, "genes.txt"))
barcodes = _read_lines(os.path.join(input_dir, "barcodes.txt"))

# Fail early with a readable message if the three exported files disagree.
if counts.shape != (len(barcodes), len(genes)):
    raise ValueError(
        f"counts.mtx has shape {counts.shape} (cells x genes) but "
        f"barcodes.txt lists {len(barcodes)} cells and genes.txt lists "
        f"{len(genes)} genes"
    )

print("读取 metadata...")
meta = pd.read_csv(os.path.join(input_dir, "metadata.csv"), index_col=0)

print("读取 HVG 列表...")
hvg = _read_lines(os.path.join(input_dir, "hvg.txt"))

print("读取 PCA embedding...")
pca_emb = pd.read_csv(os.path.join(input_dir, "pca_embedding.csv"), index_col=0)

## Build the AnnData object; metadata rows are reordered to the barcode order.
adata = ad.AnnData(
    X=counts,
    obs=meta.loc[barcodes],
    var=pd.DataFrame(index=genes)
)
adata.obs_names = barcodes
adata.var_names = genes

## Flag the highly variable genes listed in hvg.txt
adata.var["highly_variable"] = adata.var_names.isin(hvg)

## Attach the PCA embedding, aligned to the cell order of adata
adata.obsm["X_pca"] = pca_emb.loc[adata.obs_names].values

## Keep an untouched copy of the counts in the "counts" layer
adata.layers["counts"] = adata.X.copy()

## Normalize + log1p in X, for the methods that work on normalized data
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

print(f"AnnData: {adata.shape[0]} cells x {adata.shape[1]} genes")
print(f"HVG 数量: {sum(adata.var['highly_variable'])}")
print(f"批次数量: {adata.obs[batch_key].nunique()}")
adata

## 1. scVI

In [None]:
import scvi

print(">>> scVI 开始...")
t0 = time.time()

## Restrict to the HVGs *before* copying: only the HVG slice is duplicated,
## instead of copying the full cells x genes object and then discarding
## most of it.
adata_scvi = adata[:, adata.var["highly_variable"]].copy()

## scVI needs raw counts, so X is reset to the counts layer
adata_scvi.X = adata_scvi.layers["counts"].copy()

## Register the data, then build and train the model
scvi.model.SCVI.setup_anndata(
    adata_scvi,
    layer="counts",
    batch_key=batch_key
)

model_scvi = scvi.model.SCVI(
    adata_scvi,
    n_latent=n_latent,
    n_layers=2,
    gene_likelihood="zinb"
)

model_scvi.train(max_epochs=n_epochs)

## Extract the latent embedding (one row per cell of adata_scvi)
scvi_emb = model_scvi.get_latent_representation()
scvi_df = pd.DataFrame(
    scvi_emb,
    index=adata_scvi.obs_names,
    columns=[f"scVI_{i+1}" for i in range(scvi_emb.shape[1])]
)

## Save for import on the R side
scvi_df.to_csv(os.path.join(output_dir, "scvi_embedding.csv"))

elapsed = time.time() - t0
print(f"scVI 完成! 耗时: {elapsed:.1f} 秒")
print(f"Embedding shape: {scvi_df.shape}")

## Keep the embedding on adata for the summary plots later on.
## adata_scvi is a gene subset only, so its cell order matches adata.
adata.obsm["X_scVI"] = scvi_emb

## 2. scANVI（监督版 scVI）

scANVI 利用已知的细胞类型标签进行半监督整合。需要 `celltype` 列。  
如果部分细胞没有标签，设为 `"Unknown"` 即可。

In [None]:
print(">>> scANVI 开始...")
t0 = time.time()

## scANVI is initialised from the already-trained scVI model.
if celltype_key not in adata_scvi.obs.columns:
    raise KeyError(
        f"scANVI needs the obs column {celltype_key!r}, "
        "which is missing from the metadata"
    )

## Missing labels (NaN or empty string) become "Unknown".
## Cast to object first: fillna() on a categorical column raises when
## "Unknown" is not already one of its categories.
_labels = adata_scvi.obs[celltype_key].astype(object)
_labels = _labels.where(_labels.notna(), "Unknown").astype(str)
_labels[_labels == ""] = "Unknown"
adata_scvi.obs["celltype_scanvi"] = _labels

model_scanvi = scvi.model.SCANVI.from_scvi_model(
    model_scvi,
    adata=adata_scvi,
    labels_key="celltype_scanvi",
    unlabeled_category="Unknown"
)

model_scanvi.train(max_epochs=20)

## Extract the latent embedding
scanvi_emb = model_scanvi.get_latent_representation()
scanvi_df = pd.DataFrame(
    scanvi_emb,
    index=adata_scvi.obs_names,
    columns=[f"scANVI_{i+1}" for i in range(scanvi_emb.shape[1])]
)

## Save for import on the R side
scanvi_df.to_csv(os.path.join(output_dir, "scanvi_embedding.csv"))

elapsed = time.time() - t0
print(f"scANVI 完成! 耗时: {elapsed:.1f} 秒")
print(f"Embedding shape: {scanvi_df.shape}")

adata.obsm["X_scANVI"] = scanvi_emb

## Drop the temporary scVI/scANVI objects and reclaim their memory
del adata_scvi, model_scvi, model_scanvi
import gc
gc.collect()

## 3. Scanorama

In [None]:
import scanorama

print(">>> Scanorama 开始...")
t0 = time.time()

## Restrict to the HVGs before copying, so the full cells x genes object
## is never duplicated.
adata_sca = adata[:, adata.var["highly_variable"]].copy()

## A cell with a missing batch label would match none of the per-batch
## masks below and silently disappear; stop with a clear message instead.
if adata_sca.obs[batch_key].isna().any():
    raise ValueError(
        f"obs column {batch_key!r} contains missing values; "
        "every cell needs a batch label for Scanorama"
    )

## Scanorama needs a list with one AnnData per batch
batches = adata_sca.obs[batch_key].unique().tolist()
adatas_list = [adata_sca[adata_sca.obs[batch_key] == b].copy() for b in batches]

## Run Scanorama; the result is read back from obsm["X_scanorama"] of
## each list element.
scanorama.integrate_scanpy(adatas_list, dimred=n_latent)

## Stack the per-batch embeddings and cell names in the same (batch) order
scanorama_emb = np.concatenate([a.obsm["X_scanorama"] for a in adatas_list], axis=0)
scanorama_cells = np.concatenate([a.obs_names.tolist() for a in adatas_list])

scanorama_df = pd.DataFrame(
    scanorama_emb,
    index=scanorama_cells,
    columns=[f"Scanorama_{i+1}" for i in range(scanorama_emb.shape[1])]
)
## Restore the original cell order
scanorama_df = scanorama_df.loc[adata.obs_names]

## Save for import on the R side
scanorama_df.to_csv(os.path.join(output_dir, "scanorama_embedding.csv"))

elapsed = time.time() - t0
print(f"Scanorama 完成! 耗时: {elapsed:.1f} 秒")
print(f"Embedding shape: {scanorama_df.shape}")

adata.obsm["X_Scanorama"] = scanorama_df.values

del adata_sca, adatas_list
import gc
gc.collect()

## 4. BBKNN

BBKNN 不产生 corrected embedding，而是直接在 PCA 空间构建 batch-balanced kNN graph。  
因此我们导出基于 BBKNN graph 的 UMAP 坐标，以及由该 graph 计算得到的 diffusion map embedding（用于 benchmark）。

In [None]:
import bbknn

print(">>> BBKNN 开始...")
t0 = time.time()

adata_bbknn = adata.copy()

## Number of diffusion components written to bbknn_embedding.csv
bbknn_n_comps = 30

## BBKNN builds its batch-balanced graph from the PCA embedding
bbknn.bbknn(adata_bbknn, batch_key=batch_key, n_pcs=30)

## UMAP computed on the BBKNN graph
sc.tl.umap(adata_bbknn)

## Export the UMAP coordinates (BBKNN yields no embedding, and the graph
## itself cannot be handed to R directly)
bbknn_umap = pd.DataFrame(
    adata_bbknn.obsm["X_umap"],
    index=adata_bbknn.obs_names,
    columns=["BBKNN_UMAP_1", "BBKNN_UMAP_2"]
)
bbknn_umap.to_csv(os.path.join(output_dir, "bbknn_umap.csv"))

## Diffusion map on the BBKNN graph, as a higher-dimensional embedding for
## benchmarking. Per the scanpy documentation, column 0 of X_diffmap is the
## non-informative steady-state solution, so one extra component is
## computed and column 0 is dropped.
sc.tl.diffmap(adata_bbknn, n_comps=bbknn_n_comps + 1)
bbknn_diffmap = adata_bbknn.obsm["X_diffmap"][:, 1:bbknn_n_comps + 1]
bbknn_emb = pd.DataFrame(
    bbknn_diffmap,
    index=adata_bbknn.obs_names,
    # Names follow the actual width, so a shorter result cannot cause a
    # column-count mismatch.
    columns=[f"BBKNN_{i+1}" for i in range(bbknn_diffmap.shape[1])]
)
bbknn_emb.to_csv(os.path.join(output_dir, "bbknn_embedding.csv"))

elapsed = time.time() - t0
print(f"BBKNN 完成! 耗时: {elapsed:.1f} 秒")
print(f"UMAP shape: {bbknn_umap.shape}")
print(f"Diffmap embedding shape: {bbknn_emb.shape}")

del adata_bbknn
import gc
gc.collect()

## 汇总可视化（可选）

在 Python 中快速查看各方法的 UMAP。

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(20, 16))

## Panel title -> obsm key of the embedding to visualise
methods_obsm = {
    "scVI": "X_scVI",
    "scANVI": "X_scANVI",
    "Scanorama": "X_Scanorama"
}

## Drawing options shared by all four panels so they are directly comparable
umap_style = dict(color=batch_key, show=False, legend_loc="none",
                  frameon=True, size=1)

for ax, (name, key) in zip(axes.flat[:3], methods_obsm.items()):
    ## Neighbours + UMAP are recomputed per embedding; each pass overwrites
    ## the neighbour graph and X_umap stored on adata.
    sc.pp.neighbors(adata, use_rep=key, n_neighbors=30)
    sc.tl.umap(adata)
    sc.pl.umap(adata, ax=ax, title=name, **umap_style)

## BBKNN already has UMAP coordinates on disk. Load them into X_umap
## (aligned by cell name) so this panel is coloured by batch like the
## other three, instead of being drawn in a single colour.
ax = axes.flat[3]
bbknn_umap_vals = pd.read_csv(os.path.join(output_dir, "bbknn_umap.csv"), index_col=0)
bbknn_umap_vals.index = bbknn_umap_vals.index.astype(str)
adata.obsm["X_umap"] = bbknn_umap_vals.loc[adata.obs_names].to_numpy()
sc.pl.umap(adata, ax=ax, title="BBKNN", **umap_style)
ax.set_xlabel("UMAP1")
ax.set_ylabel("UMAP2")

plt.tight_layout()
plt.savefig(os.path.join(output_dir, "python_methods_umap.png"), dpi=150, bbox_inches="tight")
plt.show()
print("Python 方法全部完成！结果已保存到:", output_dir)