<a href="https://colab.research.google.com/github/Bio-MingChen/DL_practice_by_Colab/blob/main/Raman2single_cell_refinement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from IPython.display import clear_output

# Install dependencies, then clear_output() wipes the pip log from this cell's output
! pip install scanpy scikit-learn
clear_output()

In [2]:
# Mount Google Drive; the .h5ad inputs are read from /content/drive/MyDrive later on
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
import sys
sys.path.extend([".", ".."])

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as U
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

import numpy as np
import scipy as sp
import pandas as pd

# -------------------------
# helpers
# -------------------------
def turn_on_model(model: nn.Module):
    """Unfreeze ``model``: mark every parameter as requiring gradients (returns None)."""
    model.requires_grad_(True)

def turn_off_model(model: nn.Module):
    """Freeze ``model``: stop every parameter from requiring gradients (returns None)."""
    model.requires_grad_(False)


# -------------------------
# Encoder / Decoder
# -------------------------
class StandardEncoder(nn.Module):
    """
    Fully connected encoder: five (Linear -> BatchNorm1d -> ReLU) stages
    followed by two linear heads producing ``mean`` and ``logvar``.

    The stack depth follows the original reference architecture; ``hidden_dim``
    defaults to 512 and is configurable.

    Args:
        input_dim: number of input features.
        latent_dim: size of each head's output (the latent code).
        hidden_dim: width of every hidden layer.
    """
    def __init__(self, input_dim: int, latent_dim: int, hidden_dim: int = 512):
        super().__init__()
        stages = []
        # Five hidden stages; only the first one changes the feature width.
        for in_features in [input_dim] + [hidden_dim] * 4:
            stages += [
                nn.Linear(in_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
        # Same attribute name and flat Sequential layout as before, so saved
        # state_dict keys (part1.0.weight, part1.1.running_mean, ...) still load.
        self.part1 = nn.Sequential(*stages)
        self.to_mean = nn.Linear(hidden_dim, latent_dim)
        self.to_logvar = nn.Linear(hidden_dim, latent_dim)

        self.latent_dim = latent_dim

    def forward(self, x):
        """Encode batch ``x`` and return ``(mean, logvar)``."""
        hidden = self.part1(x)
        mean = self.to_mean(hidden)
        logvar = self.to_logvar(hidden)
        return mean, logvar


class StandardDecoder(nn.Module):
    """
    Fully connected decoder mirroring ``StandardEncoder``.

    It has five (Linear -> BatchNorm1d -> ReLU) stages starting from the latent
    code, then a Linear projection to ``input_dim`` features. Layer widths are
    written as plain numbers instead of the ``1 << k`` form of the reference code.

    Args:
        input_dim: number of reconstructed output features.
        latent_dim: size of the incoming latent code.
        hidden_dim: width of every hidden layer (default 512).
        no_final_relu: when truthy, the output projection stays linear (typical
            when regressing real-valued data); otherwise a final ReLU clamps the
            output to be non-negative.
    """
    def __init__(self, input_dim: int, latent_dim: int, hidden_dim: int = 512, no_final_relu: bool = False):
        super().__init__()

        modules = []
        for width_in in (latent_dim, hidden_dim, hidden_dim, hidden_dim, hidden_dim):
            modules.extend((
                nn.Linear(width_in, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ))
        modules.append(nn.Linear(hidden_dim, input_dim))
        # Optional output clamp.
        modules += [] if no_final_relu else [nn.ReLU(inplace=True)]

        self.net = nn.Sequential(*modules)
        self.latent_dim = latent_dim

    def forward(self, z):
        """Map latent batch ``z`` back to input space."""
        out = self.net(z)
        return out


class Discriminator(nn.Module):
    """
    MLP discriminator over latent codes: latent_dim -> 64 -> 32 -> 32 -> end_dim.

    Every Linear layer is optionally wrapped in spectral normalisation. The
    network emits raw logits (no sigmoid/softmax), so it pairs with
    ``binary_cross_entropy_with_logits``. With the default ``end_dim=2`` the two
    logits are matched against one-hot source/target labels.
    """
    def __init__(self, latent_dim: int, spectral: bool = True, end_dim: int = 2):
        super().__init__()
        widths = [latent_dim, 64, 32, 32, end_dim]
        n_layers = len(widths) - 1
        blocks = []
        for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            fc = nn.Linear(n_in, n_out)
            blocks.append(U.spectral_norm(fc) if spectral else fc)
            # ReLU between layers only; the last layer outputs logits.
            if i < n_layers - 1:
                blocks.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*blocks)

    def forward(self, z):
        """Return ``end_dim`` logits for each latent vector in ``z``."""
        logits = self.net(z)
        return logits


# -------------------------
# VAE wrapper
# -------------------------
class VAE(nn.Module):
    """
    Encoder/decoder pair wrapped as a (variational) autoencoder.

    Args:
        encoder: module returning ``(mean, logvar)`` and exposing ``latent_dim``.
        decoder: module mapping a latent batch back to input space.
        is_vae: True -> sample z with the reparameterisation trick;
            False -> use ``mean`` directly as z.
        latent_norm: ``'batch'`` applies ``BatchNorm1d`` to z after sampling;
            ``'none'`` leaves z untouched. Any other value raises ValueError.

    Note: sampling does not depend on train/eval mode, so ``get_latent`` and
    ``forward`` are stochastic whenever ``is_vae`` is True.
    """
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        is_vae: bool = True,
        latent_norm: str = 'batch',   # 'batch' or 'none'
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.is_vae = is_vae

        if latent_norm not in ('batch', 'none'):
            raise ValueError("latent_norm must be 'batch' or 'none'")
        self.use_latent_norm = latent_norm == 'batch'
        # Only registered as a submodule for 'batch', keeping state_dict keys unchanged.
        self.latent_normalizer = nn.BatchNorm1d(self.encoder.latent_dim) if self.use_latent_norm else None

    @staticmethod
    def reparam_trick(mean, logvar, is_vae: bool):
        """Return ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``, or ``mean`` if not a VAE."""
        if is_vae:
            std = torch.exp(0.5 * logvar)
            return mean + torch.randn_like(std) * std
        return mean

    def _latent_from_stats(self, mean, logvar, latent_noise_std):
        # Shared by get_latent/forward: sample -> optional BatchNorm -> optional extra noise.
        z = self.reparam_trick(mean, logvar, self.is_vae)
        if self.use_latent_norm:
            z = self.latent_normalizer(z)
        if latent_noise_std and latent_noise_std > 0:
            z = z + latent_noise_std * torch.randn_like(z)
        return z

    def get_latent(self, x, latent_noise_std: float = 0.0):
        """Encode ``x`` and return the (optionally normalised / noised) latent ``z``."""
        mean, logvar = self.encoder(x)
        return self._latent_from_stats(mean, logvar, latent_noise_std)

    def forward(self, x, latent_noise_std: float = 0.0):
        """Return ``(recon_x, mean, logvar, z)`` for batch ``x``."""
        mean, logvar = self.encoder(x)
        z = self._latent_from_stats(mean, logvar, latent_noise_std)
        return self.decoder(z), mean, logvar, z


# -------------------------
# Losses（保持原接口/语义）
# -------------------------
def old_mse_loss(x, recon_x, weights=None):
    """
    Reconstruction loss between ``x`` and ``recon_x``.

    Without ``weights`` this is the mean squared error (``F.mse_loss``). With
    ``weights`` it delegates to ``weighted_mse``, which sums weighted squared
    errors. The constant multiplier used by the reference implementation is
    intentionally dropped.
    """
    if weights is not None:
        return weighted_mse(recon_x, x, weights=weights)
    return F.mse_loss(recon_x, x)

def weighted_mse(a, b, weights=None):
    """
    Squared-error loss between ``a`` and ``b``.

    With ``weights=None`` this is the plain mean squared error. Otherwise the
    element-wise squared errors are multiplied by ``weights`` (broadcast) and
    SUMMED, not averaged. Normalise the weights yourself if you want a mean.
    """
    if weights is not None:
        squared = (a - b) ** 2
        return (squared * weights).sum()
    return F.mse_loss(a, b)

def old_vae_loss(x, recon_x, mean, logvar, weights=None, this_lambda=0.0):
    """
    VAE objective: reconstruction MSE + ``this_lambda`` * KL(q(z|x) || N(0, I)).

    The reconstruction term is ``F.mse_loss`` (a mean) when ``weights`` is None,
    and ``weighted_mse`` otherwise. The KL term is summed over all batch and
    latent elements. With ``this_lambda=0`` (the default) this is plain MSE.
    """
    if weights is None:
        recon_term = F.mse_loss(recon_x, x)
    else:
        recon_term = weighted_mse(recon_x, x, weights=weights)
    kl_term = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).sum()
    return recon_term + this_lambda * kl_term

def discrim_loss(pred_logits, true_onehot):
    """
    Discriminator loss on two-logit outputs.

    Args:
        pred_logits: (B, 2) raw logits.
        true_onehot: (B, 2) one-hot (or smoothed) targets, e.g. [1, 0] or [0, 1].

    Returns:
        Mean binary cross-entropy with logits.
    """
    loss = F.binary_cross_entropy_with_logits(pred_logits, true_onehot)
    return loss

def adv_vae_loss(
    x, recon_x,
    mean, logvar, discrim_preds,
    alpha: float, beta: float, device=None, weights=None,
):
    """
    Combined generator loss: ``alpha`` * VAE reconstruction + ``beta`` * adversarial BCE.

    The adversarial term asks the discriminator to label every sample as the
    "source" class, i.e. the one-hot target ``[1, 0]``.

    Args:
        x, recon_x: input batch and its reconstruction.
        mean, logvar: encoder outputs, forwarded to ``old_vae_loss``. Its KL weight
            is left at the default (0), so only reconstruction counts.
        discrim_preds: (B, 2) discriminator logits for the latent batch.
        alpha, beta: weights of the reconstruction and adversarial terms.
        device: device for the target labels. Defaults to ``discrim_preds.device``;
            the old default built them on CPU, which crashed for CUDA logits.
        weights: optional per-element reconstruction weights.

    Returns:
        ``(total, vae_part, disc_part)``.
    """
    vae_part = old_vae_loss(x, recon_x, mean, logvar, weights=weights)

    # One-hot "source" target, built on the logits' device/dtype so BCE never
    # sees mismatched tensors.
    if device is None:
        device = discrim_preds.device
    source_label = torch.tensor([1.0, 0.0], device=device, dtype=discrim_preds.dtype)
    discrim_labels = source_label.expand(x.shape[0], 2)

    disc_part = F.binary_cross_entropy_with_logits(discrim_preds, discrim_labels)
    total = alpha * vae_part + beta * disc_part
    return total, vae_part, disc_part


In [5]:
# ================================
# Preprocess → Balanced split → DataLoaders
# for RNA / ATAC(gene activity) / Raman
# ================================
import numpy as np
import scanpy as sc
import torch
import scipy.sparse as sp
from torch.utils.data import TensorDataset, DataLoader
from sklearn.utils import check_random_state

# --------- Configuration ---------
BATCH_SIZE = 64
TEST_FRACTION = 0.20
RANDOM_SEED = 0

# RNA
RNA_USE_X_AS_IS = True     # True: use adata_n.X as-is (already scaled); False: layers['data'] as-is, else layers['counts'] + log1p, else .X
RNA_USE_HVG = True         # if adata_n.var['highly_variable'] exists, keep only the HVGs
RNA_N_TOP_HVG = 2000       # number of HVGs if they are re-selected (not used by the code in this cell)
RNA_ZSCORE = False         # usually unnecessary (data already scaled)

# ATAC (gene activity)
ATAC_USE_X_AS_IS = True    # True: use adata_a.X (already-normalized gene activity)
ATAC_ZSCORE = False        # optional per-gene z-score (usually off to avoid double normalization)

# Raman
RAMAN_ZSCORE = True        # per-wavenumber z-score (common; friendlier for the model)

# Cross-modal (RNA vs ATAC)
USE_GENE_INTERSECTION = False  # set True to restrict RNA/ATAC to shared genes (needed for a shared feature space)
# -----------------------

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device =', device)

def to_dense_if_sparse(X):
    """
    Convert an AnnData matrix (``.X`` or a layer) to a dense float32 ndarray.

    Sparse inputs (per ``scipy.sparse.issparse``) are densified with ``toarray()``.
    Everything else (ndarray, ``np.matrix``, array-like) goes through
    ``np.asarray``, which also strips the ``np.matrix`` subclass.
    """
    if not sp.issparse(X):
        return np.asarray(X, dtype=np.float32)
    return X.toarray().astype(np.float32)

def zscore_features(X):
    """Standardise each column of ``X`` to zero mean / unit std (1e-6 guards constant columns)."""
    centered = X - X.mean(axis=0, keepdims=True)
    scale = X.std(axis=0, keepdims=True) + 1e-6
    return centered / scale

def encode_labels(celltypes):
    """
    Map labels to contiguous int ids following ``np.unique`` (sorted) order.

    Returns:
        (y, label2id, id2label): int64 id array, label->id dict, id->label dict.
    """
    classes = np.unique(celltypes)
    label2id = dict(zip(classes, range(len(classes))))
    id2label = {idx: lab for lab, idx in label2id.items()}
    y = np.asarray([label2id[c] for c in celltypes], dtype=np.int64)
    return y, label2id, id2label

def balanced_stratified_indices(labels, test_fraction=0.2, seed=0):
    """
    Class-balanced train/test split.

    Each class is randomly (seeded) down-sampled to the size of the smallest
    class. Per class, ``round(test_fraction * n_per_class)`` samples go to the
    test set and the rest to train.

    Returns:
        (train_idx, test_idx): int64 index arrays into ``labels``. They are always
        integer-typed, so they stay valid fancy indices even when empty (e.g.
        ``test_fraction=0``). Previously an empty list became a float64 array and
        ``X[idx]`` raised IndexError. Empty ``labels`` yields two empty arrays.
    """
    rng = check_random_state(seed)
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    if classes.size == 0:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    n_per_class = counts.min()
    n_test = int(round(test_fraction * n_per_class))
    train_idx, test_idx = [], []
    for c in classes:
        idx_c = np.where(labels == c)[0]
        idx_c = rng.permutation(idx_c)[:n_per_class]
        test_idx.extend(idx_c[:n_test])
        train_idx.extend(idx_c[n_test:])
    return np.asarray(train_idx, dtype=np.int64), np.asarray(test_idx, dtype=np.int64)

def make_loaders(X, y, batch_size=64, shuffle_train=True):
    """
    Wrap numpy features/labels in a ``TensorDataset`` and a ``DataLoader``.

    Despite the plural name this builds ONE loader; ``shuffle_train`` becomes
    its ``shuffle`` flag (True for training, False for evaluation).

    Returns:
        (loader, dataset), with features as float32 and labels as int64 tensors.
    """
    dataset = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(y).long())
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle_train)
    return loader, dataset

# ---------- Load the AnnData objects ----------
adata_a = sc.read_h5ad('/content/drive/MyDrive/Colab Notebooks/data/Raman_single_cell/adata_a.h5ad') # gene activity
adata_n = sc.read_h5ad('/content/drive/MyDrive/Colab Notebooks/data/Raman_single_cell/adata_n.h5ad') # RNA
adata_r = sc.read_h5ad('/content/drive/MyDrive/Colab Notebooks/data/Raman_single_cell/adata_r.h5ad') # Raman

# Each modality needs a 'cell_type' column: it drives the balanced split below.
assert 'cell_type' in adata_a.obs.columns
assert 'cell_type' in adata_n.obs.columns
assert 'cell_type' in adata_r.obs.columns

# =========================
# 1) RNA
# =========================
if RNA_USE_X_AS_IS:
    Xn = to_dense_if_sparse(adata_n.X)
else:
    # Alternative sources: layers['data'] is used as-is, layers['counts'] is
    # log1p-transformed, otherwise fall back to .X.
    if 'data' in adata_n.layers:
        Xn = to_dense_if_sparse(adata_n.layers['data'])
    elif 'counts' in adata_n.layers:
        Xn = to_dense_if_sparse(adata_n.layers['counts'])
        Xn = np.log1p(Xn)
    else:
        Xn = to_dense_if_sparse(adata_n.X)

gene_names_rna = np.array(adata_n.var_names)
# Keep only highly variable genes when that annotation is present.
if RNA_USE_HVG and ('highly_variable' in adata_n.var.columns):
    hvg_mask = adata_n.var['highly_variable'].values
    Xn = Xn[:, hvg_mask]
    gene_names_rna = gene_names_rna[hvg_mask]

if RNA_ZSCORE:
    Xn = zscore_features(Xn)

# Integer labels (np.unique order) plus label<->id lookup dicts.
yn, n_label2id, n_id2label = encode_labels(adata_n.obs['cell_type'].values)

# =========================
# 2) ATAC (gene activity)
# =========================
if ATAC_USE_X_AS_IS:
    Xa = to_dense_if_sparse(adata_a.X)
else:
    Xa = to_dense_if_sparse(adata_a.X)  # hook: both branches read .X today; switch layer / add a log transform here if needed

gene_names_atac = np.array(adata_a.var_names)

if ATAC_ZSCORE:
    Xa = zscore_features(Xa)

ya, a_label2id, a_id2label = encode_labels(adata_a.obs['cell_type'].values)

# =========================
# 3) Optional cross-modal gene alignment
# =========================
if USE_GENE_INTERSECTION:
    # Restrict both modalities to the shared genes, in the same column order.
    common_genes, rn_idx, aa_idx = np.intersect1d(gene_names_rna, gene_names_atac, return_indices=True)
    if common_genes.size == 0:
        raise ValueError("RNA 与 ATAC 的基因集合没有交集，请检查 var_names。")
    Xn = Xn[:, rn_idx]
    Xa = Xa[:, aa_idx]
    gene_names_rna = common_genes
    gene_names_atac = common_genes
    print(f"[Gene Intersection] common genes: {len(common_genes)}")
else:
    print("[Gene Intersection] disabled: RNA/ATAC 特征不强制一致")

# =========================
# 4) Raman
# =========================
Xr = to_dense_if_sparse(adata_r.X)  # normalized spectra
# NOTE(review): the z-score statistics are computed on all cells before the
# train/test split below, so test cells influence the scaling.
if RAMAN_ZSCORE:
    Xr = zscore_features(Xr)

yr, r_label2id, r_id2label = encode_labels(adata_r.obs['cell_type'].values)

# =========================
# 5) Class-balanced split + DataLoaders
# =========================
# Every class is down-sampled to the smallest class size before splitting, so
# train+test can be far smaller than the number of cells.
# RNA
n_train_idx, n_test_idx = balanced_stratified_indices(yn, TEST_FRACTION, RANDOM_SEED)
Xn_train, Xn_test = Xn[n_train_idx], Xn[n_test_idx]
yn_train, yn_test = yn[n_train_idx], yn[n_test_idx]
rna_train_loader, rna_train_ds = make_loaders(Xn_train, yn_train, BATCH_SIZE, shuffle_train=True)
rna_test_loader,  rna_test_ds  = make_loaders(Xn_test,  yn_test,  BATCH_SIZE, shuffle_train=False)
print(f"[RNA] X:{Xn.shape}  train:{len(rna_train_ds)}  test:{len(rna_test_ds)}  classes:{len(n_label2id)}")

# ATAC
a_train_idx, a_test_idx = balanced_stratified_indices(ya, TEST_FRACTION, RANDOM_SEED)
Xa_train, Xa_test = Xa[a_train_idx], Xa[a_test_idx]
ya_train, ya_test = ya[a_train_idx], ya[a_test_idx]
atac_train_loader, atac_train_ds = make_loaders(Xa_train, ya_train, BATCH_SIZE, shuffle_train=True)
atac_test_loader,  atac_test_ds  = make_loaders(Xa_test,  ya_test,  BATCH_SIZE, shuffle_train=False)
print(f"[ATAC(gene activity)] X:{Xa.shape}  train:{len(atac_train_ds)}  test:{len(atac_test_ds)}  classes:{len(a_label2id)}")

# Raman
r_train_idx, r_test_idx = balanced_stratified_indices(yr, TEST_FRACTION, RANDOM_SEED)
Xr_train, Xr_test = Xr[r_train_idx], Xr[r_test_idx]
yr_train, yr_test = yr[r_train_idx], yr[r_test_idx]
raman_train_loader, raman_train_ds = make_loaders(Xr_train, yr_train, BATCH_SIZE, shuffle_train=True)
raman_test_loader,  raman_test_ds  = make_loaders(Xr_test,  yr_test,  BATCH_SIZE, shuffle_train=False)
print(f"[RAMAN] X:{Xr.shape}  train:{len(raman_train_ds)}  test:{len(raman_test_ds)}  classes:{len(r_label2id)}")

# =========================
# 6) Bundle the outputs for later cells
# =========================
preprocessed = {
    "rna": {
        "train_loader": rna_train_loader,
        "test_loader":  rna_test_loader,
        "y_map": {"label2id": n_label2id, "id2label": n_id2label},
        "feature_names": gene_names_rna,
        "shape": Xn.shape,
    },
    "atac": {
        "train_loader": atac_train_loader,
        "test_loader":  atac_test_loader,
        "y_map": {"label2id": a_label2id, "id2label": a_id2label},
        "feature_names": gene_names_atac,
        "shape": Xa.shape,
    },
    "raman": {
        "train_loader": raman_train_loader,
        "test_loader":  raman_test_loader,
        "y_map": {"label2id": r_label2id, "id2label": r_id2label},
        "feature_names": np.array(adata_r.var_names, dtype=str),
        "shape": Xr.shape,
    }
}

print(preprocessed)

device = cpu
[Gene Intersection] disabled: RNA/ATAC 特征不强制一致
[RNA] X:(938, 2000)  train:492  test:124  classes:4
[ATAC(gene activity)] X:(10738, 19039)  train:2660  test:664  classes:4
[RAMAN] X:(1415, 432)  train:1040  test:260  classes:4
{'rna': {'train_loader': <torch.utils.data.dataloader.DataLoader object at 0x794c7c079d10>, 'test_loader': <torch.utils.data.dataloader.DataLoader object at 0x794cb39c1490>, 'y_map': {'label2id': {'HSC': 0, 'Naive B': 1, 'Pre B': 2, 'Pro B': 3}, 'id2label': {0: 'HSC', 1: 'Naive B', 2: 'Pre B', 3: 'Pro B'}}, 'feature_names': array(['FAM132A', 'ACAP3', 'DVL1', ..., 'C21orf58', 'DIP2A', 'AC145212.2'],
      dtype=object), 'shape': (938, 2000)}, 'atac': {'train_loader': <torch.utils.data.dataloader.DataLoader object at 0x794c7c07ce50>, 'test_loader': <torch.utils.data.dataloader.DataLoader object at 0x794c7c072890>, 'y_map': {'label2id': {'HSC': 0, 'Naive B': 1, 'Pre B': 2, 'Pro B': 3}, 'id2label': {0: 'HSC', 1: 'Naive B', 2: 'Pre B', 3: 'Pro B'}}, 'featu

In [None]:
# =========================
# Train ref_vae for RNA & ATAC (gene activity)
# =========================
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.utils import check_random_state

# ---- 引入你之前整理好的模型/损失（保持与文献一致的结构） ----
# 请确保以下类/函数已在环境中（来自我们上一条“基础模型代码”）:
# StandardEncoder, StandardDecoder, VAE, old_vae_loss
# 如果不在，请把那段模型代码粘过来。

# -----------------------
# Tunable hyperparameters
# NOTE: BATCH_SIZE / RANDOM_SEED / TEST_FRACTION rebind the preprocessing cell's globals.
# -----------------------
LATENT_DIM  = 128
HIDDEN_DIM  = 2048
EPOCHS      = 30
BATCH_SIZE  = 128
LR          = 1e-4          # the reference work used 1e-5 / 1e-4; 1e-4 converges faster
KL_WEIGHT   = 1e-4          # weight of the KL term (0 drops KL, leaving reconstruction only)
NUM_WORKERS = 0
PIN_MEMORY  = True

USE_RNA_HVG = True          # if adata_n.var['highly_variable'] exists, keep only the HVGs
RANDOM_SEED = 0
TEST_FRACTION = 0.2         # 80/20 train/test split

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device =', device)

# -----------------------
# 小工具
# -----------------------
def to_dense_if_sparse(X):
    """
    Return ``X`` as a dense float32 ndarray.

    scipy sparse matrices are densified with ``toarray()``; any other input
    (ndarray, ``np.matrix``, array-like) is converted with ``np.asarray``.
    """
    if not sp.issparse(X):
        return np.asarray(X, dtype=np.float32)
    return X.toarray().astype(np.float32)

def balanced_stratified_indices(labels, test_fraction=0.2, seed=0):
    """
    Class-balanced train/test split.

    Each class is randomly (seeded) down-sampled to the smallest class size.
    Per class, ``round(test_fraction * n_per)`` samples go to test and the rest
    to train.

    Returns:
        (train_idx, test_idx) as int64 arrays. Integer dtype keeps them valid
        fancy indices even when empty; ``np.array([])`` would be float64 and make
        ``X[idx]`` raise. Empty ``labels`` yields two empty arrays instead of
        failing on ``counts.min()``.
    """
    rng = check_random_state(seed)
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    if classes.size == 0:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    n_per = counts.min()
    n_test = int(round(test_fraction * n_per))
    tr, te = [], []
    for c in classes:
        idx = np.where(labels == c)[0]
        idx = rng.permutation(idx)[:n_per]
        te.extend(idx[:n_test])
        tr.extend(idx[n_test:])
    return np.asarray(tr, dtype=np.int64), np.asarray(te, dtype=np.int64)

def make_loader_X(X, batch_size=128, shuffle=True):
    """
    Build a features-only ``DataLoader``; each batch is a 1-tuple ``(x,)``.

    ``num_workers`` comes from the global ``NUM_WORKERS``. The global
    ``PIN_MEMORY`` is honoured only when CUDA is available. Pinned memory only
    speeds up host-to-GPU copies, and on a CPU-only run PyTorch warns that it
    cannot be used.

    Returns:
        (loader, dataset)
    """
    dataset = TensorDataset(torch.from_numpy(X).float())  # X only, no labels
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY and torch.cuda.is_available(),
    )
    return loader, dataset

def build_ref_vae(input_dim, latent_dim=128, hidden_dim=2048, kl_weight=1e-4, lr=1e-4):
    """
    Create a reference VAE on the global ``device`` plus its Adam optimiser.

    The model is a StandardEncoder/StandardDecoder pair (linear output) with
    ``is_vae=True`` (reparameterised sampling + KL) and a BatchNorm on the latent
    (``latent_norm='batch'``, as in the original code).

    ``kl_weight`` is accepted for signature compatibility but not used here; the
    KL weight is applied by the training loss.

    Returns:
        (vae, optimizer)
    """
    encoder = StandardEncoder(input_dim, latent_dim, hidden_dim=hidden_dim)
    decoder = StandardDecoder(input_dim, latent_dim, hidden_dim=hidden_dim, no_final_relu=True)
    model = VAE(encoder=encoder, decoder=decoder, is_vae=True, latent_norm='batch')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer

def train_ref_vae(name, X_train, X_test, epochs=30, kl_weight=1e-4, lr=1e-4):
    """
    Train a reference VAE on ``X_train`` and print train/test losses each epoch.

    Model sizes come from the globals ``LATENT_DIM`` / ``HIDDEN_DIM``, the batch
    size from ``BATCH_SIZE``, and the device from ``device``. The objective is
    ``old_vae_loss`` (mean MSE + ``kl_weight`` * summed KL).

    Reported losses are per-sample averages of the batch objective: each batch
    loss is weighted by its batch size before dividing by the sample count.
    (Previously raw batch losses were summed and divided by the sample count,
    which shrank the printed numbers by roughly BATCH_SIZE.)

    Args:
        name: label used in the log lines.
        X_train, X_test: float arrays (N, F) with the same feature count.
        epochs, kl_weight, lr: training hyperparameters.

    Returns:
        The trained VAE (in eval mode after the last evaluation pass when epochs >= 1).
    """
    in_dim = X_train.shape[1]
    vae, opt = build_ref_vae(in_dim, LATENT_DIM, HIDDEN_DIM, kl_weight, lr)

    train_loader, _ = make_loader_X(X_train, BATCH_SIZE, shuffle=True)
    test_loader,  _ = make_loader_X(X_test,  BATCH_SIZE, shuffle=False)

    print(f"\n=== Train {name}_ref_vae ===")
    for ep in range(1, epochs+1):
        vae.train()
        loss_sum, n = 0.0, 0
        for (xb,) in train_loader:
            if xb.size(0) < 2:
                # BatchNorm1d raises on a single sample in train mode; skip a
                # trailing size-1 batch instead of crashing the whole run.
                continue
            xb = xb.to(device)
            opt.zero_grad()
            recon, mean, logvar, _ = vae(xb)
            loss = old_vae_loss(xb, recon, mean, logvar, weights=None, this_lambda=kl_weight)
            loss.backward()
            opt.step()
            loss_sum += loss.item() * xb.size(0)
            n += xb.size(0)
        tr_loss = loss_sum / n if n else float('nan')

        # Held-out reconstruction loss (no gradient updates).
        vae.eval()
        te_sum, m = 0.0, 0
        with torch.no_grad():
            for (xb,) in test_loader:
                xb = xb.to(device)
                recon, mean, logvar, _ = vae(xb)
                loss = old_vae_loss(xb, recon, mean, logvar, weights=None, this_lambda=kl_weight)
                te_sum += loss.item() * xb.size(0)
                m += xb.size(0)
        te_loss = te_sum / m if m else float('nan')
        print(f"[{name} ep {ep:02d}] train_loss={tr_loss:.6f}  test_loss={te_loss:.6f}")

    return vae

# -----------------------
# 1) RNA → rna_ref_vae
# -----------------------
# X: use adata_n.X directly (already scaled); RNA_ZSCORE is not applied here.
Xn = to_dense_if_sparse(adata_n.X)
if USE_RNA_HVG and ('highly_variable' in adata_n.var.columns):
    hvg = adata_n.var['highly_variable'].values
    if hvg.sum() > 0:
        Xn = Xn[:, hvg]
        print(f"[RNA] use HVG: {hvg.sum()} genes")

# Class-balanced split (labels only decide the split; they are not used for training).
# NOTE: this rebinds yn to the raw string labels. yn_train/yn_test from the
# preprocessing cell are untouched. Both cells use np.unique label order, the
# same seed and the same split routine, so those labels line up with
# Xn_train/Xn_test below.
yn = np.array(adata_n.obs['cell_type'])
# Map labels to ints for stratification
classes = {c:i for i,c in enumerate(np.unique(yn))}
yn_id = np.array([classes[c] for c in yn], dtype=np.int64)
n_tr_idx, n_te_idx = balanced_stratified_indices(yn_id, TEST_FRACTION, RANDOM_SEED)
Xn_train, Xn_test = Xn[n_tr_idx], Xn[n_te_idx]

rna_ref_vae = train_ref_vae("RNA", Xn_train, Xn_test,
                            epochs=EPOCHS, kl_weight=KL_WEIGHT, lr=LR)

# Optional: save the weights (state_dict) to the working directory
torch.save(rna_ref_vae.state_dict(), "rna_ref_vae.pt")
print("Saved rna_ref_vae.pt")

# -----------------------
# 2) ATAC(gene activity) → atac_ref_vae
# -----------------------
Xa = to_dense_if_sparse(adata_a.X)  # gene activity; features are genes
ya = np.array(adata_a.obs['cell_type'])
classes_a = {c:i for i,c in enumerate(np.unique(ya))}
ya_id = np.array([classes_a[c] for c in ya], dtype=np.int64)
a_tr_idx, a_te_idx = balanced_stratified_indices(ya_id, TEST_FRACTION, RANDOM_SEED)
Xa_train, Xa_test = Xa[a_tr_idx], Xa[a_te_idx]

atac_ref_vae = train_ref_vae("ATAC", Xa_train, Xa_test,
                             epochs=EPOCHS, kl_weight=KL_WEIGHT, lr=LR)

torch.save(atac_ref_vae.state_dict(), "atac_ref_vae.pt")
print("Saved atac_ref_vae.pt")

# -----------------------
# Usage example: extract latents (same interface as the reference work).
# NOTE: both ref VAEs use is_vae=True, so get_latent samples z and these
# latents are stochastic even in eval mode.
# -----------------------
rna_ref_vae.eval(); atac_ref_vae.eval()
with torch.no_grad():
    # RNA latent
    rna_latent = rna_ref_vae.get_latent(torch.from_numpy(Xn_test).to(device)).cpu().numpy()
    # ATAC latent
    atac_latent = atac_ref_vae.get_latent(torch.from_numpy(Xa_test).to(device)).cpu().numpy()
print("Latent shapes -> RNA:", rna_latent.shape, " ATAC:", atac_latent.shape)


In [None]:
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1) Prepare RNA latents and labels (reuses rna_ref_vae and Xn_train/Xn_test from earlier cells)
rna_ref_vae.eval()

def get_latent_batches(model, X, batch=256):
    """
    Encode numpy array ``X`` in chunks of ``batch`` rows via ``model.get_latent``
    (no gradients) and stack the chunk results into one numpy array.

    Uses the global ``device``. It does NOT switch ``model`` to eval mode, so
    call ``model.eval()`` first.
    """
    chunks = []
    with torch.no_grad():
        start = 0
        while start < len(X):
            xb = torch.from_numpy(X[start:start + batch]).float().to(device)
            chunks.append(model.get_latent(xb).cpu().numpy())
            start += batch
    return np.vstack(chunks)

Z_train = get_latent_batches(rna_ref_vae, Xn_train)
Z_test  = get_latent_batches(rna_ref_vae, Xn_test)

# Labels: the int yn_train/yn_test from the balanced split in the preprocessing
# cell (if unavailable, map adata_n.obs['cell_type'] to ints instead).
y_train = yn_train
y_test  = yn_test
# Number of classes = largest label id seen in either split, plus one.
n_classes = int(max(yn_train.max(), yn_test.max()) + 1)

# 2) A small MLP classifier (the earlier Discriminator(end_dim=n_classes) would also work)
class LatentClassifier(nn.Module):
    """Three-layer MLP: in_dim -> hidden -> hidden // 2 -> n_classes logits."""
    def __init__(self, in_dim, n_classes, hidden=256):
        super().__init__()
        half = hidden // 2
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, half),
            nn.ReLU(inplace=True),
            nn.Linear(half, n_classes),  # raw logits for CrossEntropyLoss
        )

    def forward(self, z):
        logits = self.net(z)
        return logits

clf = LatentClassifier(in_dim=Z_train.shape[1], n_classes=n_classes).to(device)
opt = torch.optim.Adam(clf.parameters(), lr=1e-3)

# Optional class weighting: inverse class frequency in the training labels
class_counts = np.bincount(y_train, minlength=n_classes)
weights = (len(y_train) / (class_counts + 1e-6)).astype(np.float32)
w_t = torch.tensor(weights, dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(weight=w_t)

# 3) DataLoaders over the precomputed latents
train_loader = DataLoader(TensorDataset(torch.from_numpy(Z_train).float(),
                                        torch.from_numpy(y_train).long()),
                          batch_size=128, shuffle=True)
test_loader  = DataLoader(TensorDataset(torch.from_numpy(Z_test).float(),
                                        torch.from_numpy(y_test).long()),
                          batch_size=256, shuffle=False)

# 4) Training
# NOTE: this rebinds the global EPOCHS (30 in the ref-VAE training cell).
EPOCHS = 20
for ep in range(1, EPOCHS+1):
    clf.train(); losses=[]
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        logits = clf(xb)
        loss = criterion(logits, yb)
        loss.backward(); opt.step()
        losses.append(loss.item())
    print(f"[clf ep {ep:02d}] loss={np.mean(losses):.4f}")

# 5) Evaluation
def eval_loader(model, loader):
    """
    Run ``model`` over ``loader`` without gradients and score argmax predictions.

    Uses the global ``device``; switches ``model`` to eval mode.

    Returns:
        (accuracy, macro-F1, confusion matrix)
    """
    model.eval()
    labels, preds = [], []
    with torch.no_grad():
        for xb, yb in loader:
            logits = model(xb.to(device))
            preds.append(logits.argmax(dim=1).cpu().numpy())
            labels.append(yb.numpy())
    y_true = np.concatenate(labels)
    y_pred = np.concatenate(preds)
    return (
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average='macro'),
        confusion_matrix(y_true, y_pred),
    )

# Report train/test accuracy and macro-F1, plus the test confusion matrix.
acc_tr, f1_tr, _ = eval_loader(clf, train_loader)
acc_te, f1_te, cm = eval_loader(clf, test_loader)
print(f"Train acc={acc_tr:.3f}  macroF1={f1_tr:.3f}")
print(f"Test  acc={acc_te:.3f}  macroF1={f1_te:.3f}")
print("Confusion matrix:\n", cm)


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

# ===== Helper: extract latents in batches =====
@torch.no_grad()
def get_latent_batches(ref_vae, X, device=None, batch_size=256):
    """
    Encode numpy array ``X`` with ``ref_vae.get_latent`` in chunks of ``batch_size``.

    Switches ``ref_vae`` to eval mode and runs without gradients. A falsy
    ``device`` falls back to CUDA when available, else CPU.

    Returns:
        numpy array stacking the per-chunk latents along axis 0.
    """
    if not device:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    ref_vae.eval()
    pieces = [
        ref_vae.get_latent(torch.from_numpy(X[start:start + batch_size]).float().to(device))
        .detach().cpu().numpy()
        for start in range(0, len(X), batch_size)
    ]
    return np.vstack(pieces)

# ===== Default classification head (MLP); a custom head / Discriminator can be used instead =====
class LatentClassifier(nn.Module):
    """MLP head on latent codes: in_dim -> hidden -> hidden // 2 -> n_classes (logits)."""
    def __init__(self, in_dim, n_classes, hidden=256):
        super().__init__()
        mid = hidden // 2
        layers = [
            nn.Linear(in_dim, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, mid), nn.ReLU(inplace=True),
            nn.Linear(mid, n_classes),  # logits
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        return self.net(z)

def _build_classifier(in_dim, n_classes, hidden=256, use_discriminator=False, discriminator_class=None):
    if use_discriminator:
        assert discriminator_class is not None, "use_discriminator=True 需要传 discriminator_class（如你之前的 Discriminator）"
        # 你的 Discriminator(latent_dim, end_dim=n_classes)
        return discriminator_class(in_dim, end_dim=n_classes)
    else:
        return LatentClassifier(in_dim, n_classes, hidden=hidden)

# ===== Train + evaluate (reusable for RNA / ATAC) =====
def train_latent_classifier(
    ref_vae,
    X_train, y_train,
    X_test,  y_test,
    *,
    device=None,
    epochs=20,
    batch_size=128,
    lr=1e-3,
    hidden=256,
    class_weighting=True,
    use_discriminator=False,
    discriminator_class=None,   # e.g. the Discriminator class defined earlier
    return_latents=False,
):
    """
    Train a cell-type classifier on the latent space of an already-trained ``ref_vae``.

    Args:
        ref_vae: trained VAE/AE exposing ``.get_latent(x)``. It only encodes
            ``X_train`` / ``X_test`` once (eval mode, no gradients) and is never updated.
        X_train, X_test: numpy arrays (N, F).
        y_train, y_test: integer labels in 0..K-1 (array-like; converted to int64).
        device: torch device. A falsy value falls back to CUDA if available, else CPU.
        epochs, batch_size, lr: classifier training settings.
        hidden: hidden width of the default MLP head.
        class_weighting: weight CrossEntropyLoss by inverse training-class frequency.
        use_discriminator, discriminator_class: use ``discriminator_class(in_dim,
            end_dim=K)`` as the head instead of the MLP.
        return_latents: also return the computed ``Z_train`` / ``Z_test``.

    Returns:
        dict with ``classifier`` and ``metrics`` (train/test accuracy, macro-F1
        and the test confusion matrix), plus ``Z_train`` / ``Z_test`` if
        requested. The confusion matrix is always K x K with row/column i equal
        to class id i. Without pinned labels, sklearn drops classes absent from
        both y and predictions, silently misaligning rows with class ids.
    """
    device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
    y_train = np.asarray(y_train, dtype=np.int64)
    y_test = np.asarray(y_test, dtype=np.int64)
    eval_bs = max(256, batch_size)

    # 1) Latents
    Z_train = get_latent_batches(ref_vae, X_train, device=device, batch_size=eval_bs)
    Z_test  = get_latent_batches(ref_vae, X_test,  device=device, batch_size=eval_bs)

    in_dim = Z_train.shape[1]
    n_classes = int(max(y_train.max(), y_test.max()) + 1)
    class_ids = np.arange(n_classes)

    # 2) Classifier head
    clf = _build_classifier(in_dim, n_classes, hidden=hidden,
                            use_discriminator=use_discriminator,
                            discriminator_class=discriminator_class).to(device)
    opt = torch.optim.Adam(clf.parameters(), lr=lr)

    # Optional class weighting: inverse frequency of each class in y_train
    if class_weighting:
        counts = np.bincount(y_train, minlength=n_classes)
        weights = (len(y_train) / (counts + 1e-6)).astype(np.float32)
        w_t = torch.tensor(weights, dtype=torch.float32, device=device)
        criterion = nn.CrossEntropyLoss(weight=w_t)
    else:
        criterion = nn.CrossEntropyLoss()

    # 3) DataLoaders over the precomputed latents (cheap to train on)
    train_loader = DataLoader(
        TensorDataset(torch.from_numpy(Z_train).float(), torch.from_numpy(y_train).long()),
        batch_size=batch_size, shuffle=True
    )
    test_loader = DataLoader(
        TensorDataset(torch.from_numpy(Z_test).float(), torch.from_numpy(y_test).long()),
        batch_size=eval_bs, shuffle=False
    )

    # 4) Training
    for ep in range(1, epochs+1):
        clf.train(); losses=[]
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            logits = clf(xb)
            loss = criterion(logits, yb)
            loss.backward(); opt.step()
            losses.append(loss.item())
        print(f"[latent-clf ep {ep:02d}] loss={np.mean(losses):.4f}")

    # 5) Evaluation
    @torch.no_grad()
    def _eval(loader):
        clf.eval(); all_y=[]; all_p=[]
        for xb, yb in loader:
            xb = xb.to(device)
            pred = clf(xb).argmax(dim=1).cpu().numpy()
            all_p.append(pred); all_y.append(yb.numpy())
        y = np.concatenate(all_y); p = np.concatenate(all_p)
        # Pin the label set so the matrix is always K x K and indexable by class id.
        return (accuracy_score(y, p), f1_score(y, p, average='macro'),
                confusion_matrix(y, p, labels=class_ids))

    acc_tr, f1_tr, _ = _eval(train_loader)
    acc_te, f1_te, cm = _eval(test_loader)
    print(f"Train acc={acc_tr:.3f}  macroF1={f1_tr:.3f}")
    print(f"Test  acc={acc_te:.3f}  macroF1={f1_te:.3f}")

    out = {
        "classifier": clf,
        "metrics": {
            "train_acc": acc_tr, "train_macro_f1": f1_tr,
            "test_acc": acc_te,  "test_macro_f1": f1_te,
            "confusion_matrix": cm
        }
    }
    if return_latents:
        out["Z_train"] = Z_train
        out["Z_test"]  = Z_test
    return out


In [None]:
# Assumes rna_ref_vae, Xn_train, yn_train, Xn_test, yn_test already exist (earlier cells)
res_rna = train_latent_classifier(
    rna_ref_vae, Xn_train, yn_train, Xn_test, yn_test,
    epochs=20, batch_size=128, lr=1e-3, hidden=256,
    class_weighting=True, use_discriminator=False,  # or True together with discriminator_class=...
    return_latents=False
)
rna_clf = res_rna["classifier"]


In [None]:
# Assumes atac_ref_vae, Xa_train, ya_train, Xa_test, ya_test already exist (earlier cells)
res_atac = train_latent_classifier(
    atac_ref_vae, Xa_train, ya_train, Xa_test, ya_test,
    epochs=20, batch_size=128, lr=1e-3, hidden=256,
    class_weighting=True, use_discriminator=False
)
atac_clf = res_atac["classifier"]


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from itertools import cycle

# 依赖：你前面已经定义过的模块/函数：
# StandardEncoder, StandardDecoder, VAE, Discriminator, old_vae_loss, discrim_loss
# 若不在当前环境，请先粘贴我们之前整理过的“基础模型代码”。

# ----------------------
# 构建 Raman 侧 AAE 模型
# ----------------------
def build_raman_aae(input_dim, latent_dim=128, hidden_dim=2048, latent_norm='batch'):
    """Build the Raman-side autoencoder and its latent-space discriminator.

    Parameters
    ----------
    input_dim : number of Raman features (spectral bins) fed to the encoder.
    latent_dim : size of the shared latent space.
    hidden_dim : width of the encoder/decoder hidden layers.
    latent_norm : forwarded unchanged to ``VAE(latent_norm=...)``.

    Returns
    -------
    raman_vae : ``VAE`` wrapping a ``StandardEncoder`` and a
        ``StandardDecoder`` (built with ``no_final_relu=True``), with
        ``is_vae=False`` (deterministic autoencoder, as in the reference work).
    discrim : ``Discriminator`` built with ``spectral=True`` that emits
        2 logits (reference/"source" vs Raman/"target").
    """
    # Keep the construction order encoder -> decoder -> wrapper -> discriminator
    # so the random initialisation consumes the RNG in the same sequence.
    encoder = StandardEncoder(input_dim, latent_dim, hidden_dim=hidden_dim)
    decoder = StandardDecoder(input_dim, latent_dim, hidden_dim=hidden_dim, no_final_relu=True)
    autoencoder = VAE(
        encoder=encoder,
        decoder=decoder,
        is_vae=False,
        latent_norm=latent_norm,
    )
    critic = Discriminator(latent_dim=latent_dim, spectral=True, end_dim=2)
    return autoencoder, critic


# ----------------------
# 训练一步：判别器
# ----------------------
def d_step(discrim, ref_z, ram_z, opt_d, label_smooth=0.1):
    """Perform one optimisation step of the domain discriminator.

    Reference latents are labelled as class 0 ("source") and Raman latents as
    class 1 ("target"). The labels are given as label-smoothed two-column soft
    targets and scored with per-logit binary cross-entropy. Both latent
    batches are detached, so no gradient flows back into the encoders.

    Parameters
    ----------
    discrim : module mapping a latent batch to 2 logits per sample.
    ref_z, ram_z : reference and Raman latent batches (same feature size).
    opt_d : optimizer over the discriminator's parameters.
    label_smooth : amount moved from the "true" class to the other one.

    Returns
    -------
    float : the discriminator loss for this step.
    """
    hi_p = 1.0 - label_smooth
    lo_p = 0.0 + label_smooth
    dev = ref_z.device
    n_ref, n_ram = ref_z.size(0), ram_z.size(0)

    soft_targets = torch.cat(
        (
            torch.tensor([hi_p, lo_p], device=dev).expand(n_ref, 2),
            torch.tensor([lo_p, hi_p], device=dev).expand(n_ram, 2),
        ),
        dim=0,
    )
    batch = torch.cat((ref_z.detach(), ram_z.detach()), dim=0)

    discrim.train()
    opt_d.zero_grad()
    d_loss = F.binary_cross_entropy_with_logits(discrim(batch), soft_targets)
    d_loss.backward()
    opt_d.step()
    return float(d_loss.item())


# ----------------------
# 训练一步：生成器（Raman VAE）
# ----------------------
def g_step(raman_vae, discrim, xb_raman, opt_g,
           recon_weight=1.0, adv_weight=0.3,
           aux_clf=None, y_raman=None, aux_weight=50.0):
    """Perform one update of the Raman autoencoder (the "generator").

    The loss is a weighted sum of:
      * the MSE reconstruction of ``xb_raman``;
      * an adversarial BCE term that pushes the discriminator to label the
        Raman latents as the reference ("source") class;
      * optionally, a cross-entropy cell-type term from ``aux_clf`` applied
        to the Raman latents. It is used only when both ``aux_clf`` and
        ``y_raman`` are given.

    Only ``opt_g`` is stepped. The discriminator is switched to eval mode but
    its parameters are not frozen here, so ``backward`` populates their
    ``.grad``. The caller's discriminator step is expected to zero them.

    Returns
    -------
    dict with float entries "recon" and "adv", and "aux" (float or None).
    """
    raman_vae.train()
    discrim.eval()
    opt_g.zero_grad()

    # With is_vae=False, mean/logvar are returned only for interface parity.
    recon, _mean, _logvar, z = raman_vae(xb_raman)
    recon_loss = F.mse_loss(recon, xb_raman)

    # Target = "reference" class, i.e. try to fool the discriminator.
    n = z.size(0)
    fool_target = torch.tensor([1.0, 0.0], device=xb_raman.device).expand(n, 2)
    adv_loss = F.binary_cross_entropy_with_logits(discrim(z), fool_target)

    total = recon_weight * recon_loss + adv_weight * adv_loss

    aux_loss = None
    if aux_clf is not None and y_raman is not None:
        aux_clf.eval()
        aux_loss = F.cross_entropy(aux_clf(z), y_raman)
        total = total + aux_weight * aux_loss

    total.backward()
    opt_g.step()

    return {
        "recon": float(recon_loss.item()),
        "adv": float(adv_loss.item()),
        "aux": None if aux_loss is None else float(aux_loss.item()),
    }


# -------------------------------------------------
# 主函数：训练 Raman-AAF，把 Raman 对齐到 ref_vae 的潜空间
# -------------------------------------------------
def _as_float_tensor(X):
    """Return ``X`` as a float32 CPU tensor.

    Accepts NumPy arrays, other array-likes (lists, pandas objects) and scipy
    sparse matrices (anything exposing ``.toarray()``), which
    ``torch.from_numpy`` alone rejects.
    """
    if hasattr(X, "toarray"):  # densify scipy.sparse input
        X = X.toarray()
    return torch.from_numpy(np.asarray(X)).float()


def _endless(loader):
    """Yield batches from ``loader`` indefinitely, re-iterating it whenever it is exhausted.

    Unlike ``itertools.cycle``, this does not store the batches of the first
    pass. Every new pass re-enters the DataLoader, so ``shuffle=True``
    reshuffles, and the dataset is not duplicated in memory.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:  # avoid spinning forever on an empty loader
            raise ValueError("DataLoader yielded no batches")


def train_raman_aae(
    *,
    ref_vae,            # pretrained reference-side AE/VAE (RNA or ATAC); frozen here
    X_ref,              # reference-modality matrix (N_ref, F_ref); only used to produce reference latents
    X_raman,            # Raman matrix (N_raman, n_bins)
    y_raman=None,       # optional integer cell-type labels for the Raman samples (N_raman,)
    batch_size=64,
    epochs=50,
    latent_dim=128,
    hidden_dim=2048,
    lr_g=2e-4,
    lr_d=5e-5,
    recon_weight=1.0,
    adv_weight=0.3,
    aux_weight=50.0,
    label_smooth=0.1,
    use_latent_norm='batch',   # 'batch' or 'none'
    num_workers=0,
    pin_memory=True,
    device=None,
    pretrained_aux_clf=None,   # optional classifier trained on reference latents (kept frozen)
):
    """Train a Raman adversarial autoencoder whose latent space is aligned to ``ref_vae``.

    Each step:
      1. encodes a reference batch with the frozen ``ref_vae`` and a Raman
         batch with the Raman autoencoder (both without gradients);
      2. updates the discriminator to tell the two apart (``d_step``);
      3. updates the Raman autoencoder with reconstruction + adversarial loss,
         plus an auxiliary cell-type loss when both ``y_raman`` and
         ``pretrained_aux_clf`` are given (``g_step``).

    Side effects: ``ref_vae`` is moved to ``device``, switched to eval mode and
    has ``requires_grad`` disabled on all parameters. ``pretrained_aux_clf``
    (if given) is moved to ``device`` and has ``requires_grad`` disabled.

    Returns
    -------
    raman_vae : trained Raman autoencoder (eval mode).
    discrim : trained discriminator (eval mode).
    transfer_vae : ``VAE`` with encoder = Raman encoder and decoder =
        ``ref_vae.decoder`` (eval mode).
    logs : list of per-epoch dicts with keys "epoch", "D", "Recon", "Adv", "Aux".

    Raises
    ------
    ValueError
        If fewer than 2 Raman samples or no reference samples are given. The
        encoder's BatchNorm layers cannot train on a single sample.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    else:
        device = torch.device(device)

    ref_vae = ref_vae.to(device).eval()
    for p in ref_vae.parameters():
        p.requires_grad = False

    # ----- data -----
    X_raman_t = _as_float_tensor(X_raman)
    X_ref_t = _as_float_tensor(X_ref)
    n_ram = X_raman_t.shape[0]
    if n_ram < 2:
        raise ValueError(f"need at least 2 Raman samples for BatchNorm training, got {n_ram}")
    if X_ref_t.shape[0] == 0:
        raise ValueError("X_ref is empty")

    if y_raman is None:
        ram_ds = TensorDataset(X_raman_t)
    else:
        ram_ds = TensorDataset(X_raman_t, torch.from_numpy(np.asarray(y_raman)).long())
    ref_ds = TensorDataset(X_ref_t)

    # Pinned memory only helps host->GPU copies; on CPU it just triggers a warning.
    pin = bool(pin_memory) and device.type == 'cuda'
    # BatchNorm1d in training mode fails on a batch of exactly one sample.
    # Drop the trailing batch only when it would be such a singleton.
    drop_last = (n_ram % batch_size == 1)

    ram_loader = DataLoader(ram_ds, batch_size=batch_size, shuffle=True, drop_last=drop_last,
                            num_workers=num_workers, pin_memory=pin)
    ref_loader = DataLoader(ref_ds, batch_size=batch_size, shuffle=True,
                            num_workers=num_workers, pin_memory=pin)

    # ----- models & optimizers -----
    input_dim = X_raman_t.shape[1]
    raman_vae, discrim = build_raman_aae(input_dim, latent_dim, hidden_dim, latent_norm=use_latent_norm)
    raman_vae = raman_vae.to(device)
    discrim = discrim.to(device)

    opt_g = torch.optim.Adam(raman_vae.parameters(), lr=lr_g)
    opt_d = torch.optim.Adam(discrim.parameters(), lr=lr_d)

    # Optional auxiliary classification head (frozen, never trained here).
    aux_clf = None
    if pretrained_aux_clf is not None:
        aux_clf = pretrained_aux_clf.to(device)
        for p in aux_clf.parameters():
            p.requires_grad = False
    use_aux = (y_raman is not None) and (aux_clf is not None)

    print("Begin Raman AAE training ...")
    logs = []
    for ep in range(1, epochs + 1):
        ep_d, ep_r, ep_a, ep_aux, nb = 0.0, 0.0, 0.0, 0.0, 0
        # The epoch length is set by the Raman loader. The reference loader
        # starts a fresh (reshuffled) pass each epoch and is re-entered when exhausted.
        for batch_r, batch_ref in zip(ram_loader, _endless(ref_loader)):
            if y_raman is None:
                (xb_r,) = batch_r
                yb = None
            else:
                xb_r, yb = batch_r
                yb = yb.to(device)

            xb_r = xb_r.to(device)
            (xb_ref,) = batch_ref
            xb_ref = xb_ref.to(device)

            with torch.no_grad():
                # Reference latents: frozen model, no gradient.
                z_ref = ref_vae.get_latent(xb_ref)
                # Raman latents for the discriminator step. NOTE: raman_vae is
                # in train mode here, so this extra forward also updates its
                # BatchNorm running statistics.
                _, _, _, z_ram = raman_vae(xb_r)

            dloss = d_step(discrim, z_ref, z_ram, opt_d, label_smooth=label_smooth)

            gout = g_step(raman_vae, discrim, xb_r, opt_g,
                          recon_weight=recon_weight, adv_weight=adv_weight,
                          aux_clf=aux_clf, y_raman=yb, aux_weight=aux_weight)

            ep_d += dloss
            ep_r += gout["recon"]
            ep_a += gout["adv"]
            ep_aux += (gout["aux"] if gout["aux"] is not None else 0.0)
            nb += 1

        # nb >= 1 is guaranteed: n_ram >= 2 and only a singleton tail batch is dropped.
        msg = f"[{ep:03d}] D:{ep_d/nb:.4f}  Recon:{ep_r/nb:.5f}  Adv:{ep_a/nb:.5f}"
        if use_aux:
            msg += f"  Aux:{ep_aux/nb:.5f}"
        print(msg)
        logs.append({"epoch": ep, "D": ep_d/nb, "Recon": ep_r/nb, "Adv": ep_a/nb,
                     "Aux": (ep_aux/nb if use_aux else None)})

    # Transfer model: Raman encoder -> reference decoder (modules are shared, not copied).
    # NOTE(review): a new VAE wrapper is built with latent_norm=use_latent_norm.
    # If VAE owns its latent-norm layer (rather than the encoder), that layer is
    # freshly initialised here instead of reused from raman_vae — confirm
    # against the VAE definition.
    transfer_vae = VAE(
        encoder=raman_vae.encoder,
        decoder=ref_vae.decoder,
        is_vae=False,
        latent_norm=use_latent_norm,
    ).to(device).eval()

    return raman_vae.eval(), discrim.eval(), transfer_vae, logs


In [None]:
# Reference side: RNA
ref_vae = rna_ref_vae
X_ref   = Xn_train  # must use the same preprocessing as when ref_vae was trained
# Raman
X_raman = Xr        # Raman matrix
y_raman = yr        # optional: integer cell-type labels for Raman (only needed for the auxiliary loss)

# NOTE(review): y_raman is passed below but pretrained_aux_clf=None, so the
# auxiliary cell-type loss is never computed (g_step requires both). A
# classifier trained on RNA latents (e.g. rna_clf) could enable it, provided
# its input size matches latent_dim and its label encoding matches yr — verify.
raman_vae, D, transfer_raman2rna, logs = train_raman_aae(
    ref_vae=ref_vae,
    X_ref=X_ref,
    X_raman=X_raman,
    y_raman=y_raman,                 # may be None if the auxiliary loss is not used
    batch_size=64,
    epochs=50,
    latent_dim=128,
    hidden_dim=2048,
    lr_g=2e-4, lr_d=5e-5,
    recon_weight=1.0, adv_weight=0.3,
    aux_weight=50.0,
    label_smooth=0.1,
    use_latent_norm='batch',
    pretrained_aux_clf=None          # pass a classifier trained on RNA latents here if available
)


In [None]:
# Reference side: ATAC. This overwrites the ref_vae / X_ref / X_raman / y_raman
# globals set by the Raman->RNA cell.
ref_vae = atac_ref_vae
X_ref   = Xa_train
X_raman = Xr
y_raman = yr  # NOTE(review): assigned but not used below (y_raman=None is passed)

# All other hyperparameters use train_raman_aae's defaults.
raman_vae_A, D_A, transfer_raman2atac, logs_A = train_raman_aae(
    ref_vae=ref_vae,
    X_ref=X_ref,
    X_raman=X_raman,
    y_raman=None,           # None because the auxiliary loss is not used
    epochs=50
)


In [None]:
# Reconstruction MSE of the Raman->RNA model (raman_vae, not raman_vae_A) over
# all Raman spectra. Xr is sent to the model's device as one full batch.
# train_raman_aae returns raman_vae in eval mode, so BatchNorm uses running stats.
with torch.no_grad():
    xb = torch.from_numpy(Xr).float().to(next(raman_vae.parameters()).device)
    xr_hat, _, _, _ = raman_vae(xb)
    recon_mse = F.mse_loss(xr_hat, xb).item()
print("Raman recon MSE =", recon_mse)


In [None]:
# Predict the reference (RNA) profile from Raman spectra: Raman encoder ->
# RNA decoder, as one full batch.
with torch.no_grad():
    xb = torch.from_numpy(Xr).float().to(next(transfer_raman2rna.parameters()).device)
    x_ref_pred, _, _, _ = transfer_raman2rna(xb)   # expected shape = (N_raman, F_ref)
