# MMD

Alignment-group pipeline
```
data   :c37      -> splits   -> source_domain + target_align + target_test -> train adapter -> apply alignment -> infer -> eval

process:Original -> prepare  -> split                                      -> train_mmd     -> align           -> infer -> cell-eval

.X     :raw      -> prepared -> aligned                                    -> inferred      -> evaluated
```

Data differences
- Alignment group: inference runs on the aligned data
- Baseline group: inference runs on the unaligned data

## 1. Data preparation

In [None]:
import anndata as ad
import pickle
import pandas as pd
import scipy.sparse
import numpy as np

def prepare_test_data(data_in, data_out, model_dir):
    """
    Prepare inference data: subset to the model's highly variable genes (HVGs)
    and format the expression matrix.

    The processed matrix is stored in both adata.X and adata.obsm['X_hvg'],
    so later steps work whether or not --embed_key is specified.

    Args:
        data_in: input data path (.h5ad)
        data_out: output data path (.h5ad)
        model_dir: model directory (contains var_dims.pkl)
    """
    # 1. Load the test data
    adata_holdout = ad.read_h5ad(data_in)
    print("=== Loading raw data ===")
    print(f"Raw data shape: {adata_holdout.shape}")

    # 2. Load the pretrained model's HVG list from the model directory
    var_dims_path = f'{model_dir}/var_dims.pkl'
    with open(var_dims_path, 'rb') as f:  # close the file handle after loading
        hvg_names = pickle.load(f)['gene_names']
    print(f"Model HVG count: {len(hvg_names)}")

    # 3. Keep the HVGs present in the test set (same gene set and order as in training)
    valid_genes = [g for g in hvg_names if g in adata_holdout.var_names]
    print(f"HVGs matched in test set: {len(valid_genes)}/{len(hvg_names)}")

    if len(valid_genes) < len(hvg_names):
        missing_genes = set(hvg_names) - set(valid_genes)
        print(f"Warning: {len(missing_genes)} HVGs are missing")

    adata_subset = adata_holdout[:, valid_genes].copy()

    # 4. Reset the var DataFrame, keeping only the gene-name index
    adata_subset.var = pd.DataFrame(index=adata_subset.var_names)

    # 5. Convert the expression matrix to a dense array
    if scipy.sparse.issparse(adata_subset.X):
        X_dense = adata_subset.X.toarray()
    else:
        X_dense = np.array(adata_subset.X).copy()

    # 6. Store the matrix in both adata.X and adata.obsm['X_hvg'],
    # so downstream scripts get the same data from either location
    adata_subset.X = X_dense
    adata_subset.obsm['X_hvg'] = X_dense.copy()

    # 7. Remove all other obsm keys (keep only X_hvg)
    keys_to_remove = [k for k in adata_subset.obsm.keys() if k != 'X_hvg']
    for k in keys_to_remove:
        del adata_subset.obsm[k]

    # 8. Print data format checks
    print("=== Data format check ===")
    print(f"Shape of X: {adata_subset.X.shape}")
    print(f"X type: {type(adata_subset.X)}")
    print(f"X nonzero count: {np.count_nonzero(adata_subset.X)}")
    print(f"X max: {adata_subset.X.max():.4f}")
    print(f"obsm keys: {list(adata_subset.obsm.keys())}")
    print(f"Shape of obsm['X_hvg']: {adata_subset.obsm['X_hvg'].shape}")
    print(f"X and X_hvg match: {np.allclose(adata_subset.X, adata_subset.obsm['X_hvg'])}")

    # 9. Save the preprocessed data
    adata_subset.write_h5ad(data_out)
    print(f"\nSaved to: {data_out}")
    print("Save complete!")


if __name__ == '__main__':
    prepare_test_data(
        data_in='/data3/fanpeishan/state/for_state/data/State-Tahoe-Filtered/50_100/c37.h5ad',
        data_out='/data3/fanpeishan/state/for_state/run_results/run15/c37_prep.h5ad',
        model_dir='/data3/fanpeishan/state/for_state/models/ST-Tahoe'  # model directory
    )


=== Loading raw data ===
Raw data shape: (1835947, 62710)
Model HVG count: 2000
HVGs matched in test set: 2000/2000
=== Data format check ===
Shape of X: (1835947, 2000)
X type: <class 'numpy.ndarray'>
X nonzero count: 158717635
X max: 354.2545
obsm keys: ['X_hvg']
Shape of obsm['X_hvg']: (1835947, 2000)
X and X_hvg match: True

Saved to: /data3/fanpeishan/state/for_state/run_results/run15/c37_prep.h5ad
Save complete!
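
In this run all 2000 model genes were found, so the subset has exactly the width the model expects. If some genes were missing, subsetting alone would produce a narrower matrix. The sketch below shows one possible way to handle that case: reorder columns to the model's gene list and zero-fill the absent genes. The helper name `align_to_gene_list` is hypothetical, it works on plain Python lists to stay dependency-free, and whether zero-filling is appropriate for ST-Tahoe is an assumption that is not confirmed anywhere in this notebook.

```python
def align_to_gene_list(rows, data_genes, model_genes):
    """Reorder columns to model_genes; genes absent from data_genes become 0.0.

    Hypothetical helper for illustration only (zero-filling is an assumption).
    """
    col_of = {gene: j for j, gene in enumerate(data_genes)}
    cols = [col_of.get(gene) for gene in model_genes]  # None marks a missing gene
    aligned = [[row[j] if j is not None else 0.0 for j in cols] for row in rows]
    missing = [gene for gene, j in zip(model_genes, cols) if j is None]
    return aligned, missing


# Toy data: 2 cells x 3 genes; the "model" expects G3, G1 and an absent gene G9
rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
data_genes = ["G1", "G2", "G3"]
model_genes = ["G3", "G1", "G9"]

aligned, missing = align_to_gene_list(rows, data_genes, model_genes)
print(aligned)  # [[3.0, 1.0, 0.0], [6.0, 4.0, 0.0]]
print(missing)  # ['G9']
```

The output always has one column per model gene, in the model's order, so the feature dimension stays fixed regardless of how many genes matched.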


## 2. Data split

In [None]:
python /data3/fanpeishan/state/for_state/MMD_alignment_experiment/scripts/data_split.py \
  --input_data /data3/fanpeishan/state/for_state/run_results/run15/c37_prep.h5ad \
  --output_dir /data3/fanpeishan/state/for_state/run_results/run20/data \
  --source_ratio 0.7 \
  --align_ratio 0.5 \
  --seed 42

In [None]:
[2025-12-21 21:43:06] INFO - Random seed set: 42
[2025-12-21 21:43:06] INFO - ============================================================
[2025-12-21 21:43:06] INFO - Loading data...
[2025-12-21 21:43:06] INFO - Data path: /data3/fanpeishan/state/for_state/run_results/run15/c37_prep.h5ad
[2025-12-21 21:45:09] INFO - ============================================================
[2025-12-21 21:45:09] INFO - Basic data info:
[2025-12-21 21:45:09] INFO -   Cells: 1,835,947
[2025-12-21 21:45:09] INFO -   Genes: 2,000
[2025-12-21 21:45:09] INFO -   Data shape: (1835947, 2000)
[2025-12-21 21:45:09] INFO - 
obs columns: ['sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate']
[2025-12-21 21:45:09] INFO - 
Batch info (plate):
[2025-12-21 21:45:09] INFO -   Number of batches: 14
[2025-12-21 21:45:09] INFO -   Batch list: ['plate1', 'plate10', 'plate11', 'plate12', 'plate13', 'plate14', 'plate2', 'plate3', 'plate4', 'plate5', 'plate6', 'plate7', 'plate8', 'plate9']
[2025-12-21 21:45:09] INFO - 
Cells per batch:
[2025-12-21 21:45:09] INFO -   plate1: 100,194 cells
[2025-12-21 21:45:09] INFO -   plate2: 158,870 cells
[2025-12-21 21:45:09] INFO -   plate3: 96,140 cells
[2025-12-21 21:45:09] INFO -   plate4: 152,674 cells
[2025-12-21 21:45:09] INFO -   plate5: 144,567 cells
[2025-12-21 21:45:09] INFO -   plate6: 154,792 cells
[2025-12-21 21:45:09] INFO -   plate7: 97,348 cells
[2025-12-21 21:45:09] INFO -   plate8: 137,483 cells
[2025-12-21 21:45:09] INFO -   plate9: 91,365 cells
[2025-12-21 21:45:09] INFO -   plate10: 131,046 cells
[2025-12-21 21:45:09] INFO -   plate11: 126,186 cells
[2025-12-21 21:45:09] INFO -   plate12: 171,368 cells
[2025-12-21 21:45:09] INFO -   plate13: 137,815 cells
[2025-12-21 21:45:09] INFO -   plate14: 136,099 cells
[2025-12-21 21:45:09] INFO - 
Perturbation info (drugname_drugconc):
[2025-12-21 21:45:09] INFO -   Number of perturbation conditions: 1137
[2025-12-21 21:45:09] INFO -   Control (DMSO) cells: 48,159 (2.62%)
[2025-12-21 21:45:09] INFO -   Treated cells: 1,787,788 (97.38%)
[2025-12-21 21:45:09] INFO - 
First 10 perturbation conditions:
[2025-12-21 21:45:09] INFO -   [('DMSO_TF', 0.0, 'uM')]: 45,150 cells
[2025-12-21 21:45:09] INFO -   [('Adagrasib', 0.05, 'uM')]: 23,449 cells
[2025-12-21 21:45:09] INFO -   [('Afatinib', 0.5, 'uM')]: 6,042 cells
[2025-12-21 21:45:09] INFO -   [('Almonertinib (mesylate)', 0.5, 'uM')]: 5,277 cells
[2025-12-21 21:45:09] INFO -   [('Clonidine (hydrochloride)', 5.0, 'uM')]: 4,556 cells
[2025-12-21 21:45:09] INFO -   [('Naproxen', 0.5, 'uM')]: 4,551 cells
[2025-12-21 21:45:09] INFO -   [('Berberine (chloride hydrate)', 5.0, 'uM')]: 4,550 cells
[2025-12-21 21:45:09] INFO -   [('Berbamine (dihydrochloride)', 0.5, 'uM')]: 4,480 cells
[2025-12-21 21:45:09] INFO -   [('Belumosudil', 0.5, 'uM')]: 4,465 cells
[2025-12-21 21:45:09] INFO -   [('Gemfibrozil', 5.0, 'uM')]: 4,410 cells
[2025-12-21 21:45:09] INFO - ============================================================
[2025-12-21 21:45:09] INFO - 
============================================================
[2025-12-21 21:45:09] INFO - Splitting data by batch...
[2025-12-21 21:45:09] INFO - Total batches: 14
[2025-12-21 21:45:09] INFO - Source-domain batch count: 9 (70%)
[2025-12-21 21:45:09] INFO - Target-domain batch count: 5 (30%)
[2025-12-21 21:45:09] INFO - 
Source-domain batches: ['plate1', 'plate11', 'plate13', 'plate14', 'plate2', 'plate5', 'plate6', 'plate7', 'plate8']
[2025-12-21 21:45:09] INFO - Target-domain batches: ['plate10', 'plate12', 'plate3', 'plate4', 'plate9']
[2025-12-21 21:45:37] INFO - 
Source-domain data shape: (1193354, 2000) (1,193,354 cells)
[2025-12-21 21:45:37] INFO - Target-domain data shape: (642593, 2000) (642,593 cells)
[2025-12-21 21:45:37] INFO - 
Source-domain perturbation distribution:
[2025-12-21 21:45:37] INFO -   Control (DMSO): 32,617 (2.73%)
[2025-12-21 21:45:37] INFO -   Treated: 1,160,737 (97.27%)
[2025-12-21 21:45:37] INFO - 
Target-domain perturbation distribution:
[2025-12-21 21:45:37] INFO -   Control (DMSO): 15,542 (2.42%)
[2025-12-21 21:45:37] INFO -   Treated: 627,051 (97.58%)
[2025-12-21 21:45:37] INFO - ============================================================
[2025-12-21 21:45:37] INFO - 
============================================================
[2025-12-21 21:45:37] INFO - Splitting the target domain into alignment and test sets...
[2025-12-21 21:45:37] INFO - Total target-domain cells: 642,593
[2025-12-21 21:45:37] INFO - Alignment-set cells: 321,296 (50%)
[2025-12-21 21:45:37] INFO - Test-set cells: 321,297 (50%)
[2025-12-21 21:45:55] INFO - 
Alignment-set data shape: (321296, 2000)
[2025-12-21 21:45:55] INFO - Test-set data shape: (321297, 2000)
[2025-12-21 21:45:55] INFO - 
Alignment-set perturbation distribution:
[2025-12-21 21:45:55] INFO -   Control (DMSO): 7,775 (2.42%)
[2025-12-21 21:45:55] INFO -   Treated: 313,521
[2025-12-21 21:45:55] INFO - 
Test-set perturbation distribution:
[2025-12-21 21:45:55] INFO -   Control (DMSO): 7,767 (2.42%)
[2025-12-21 21:45:55] INFO -   Treated: 313,530
[2025-12-21 21:45:55] INFO - ============================================================
[2025-12-21 21:45:55] INFO - 
============================================================
[2025-12-21 21:45:55] INFO - Saving split data...
[2025-12-21 21:45:55] INFO - Saving source-domain data to: /data3/fanpeishan/state/for_state/run_results/run20/data/source_domain.h5ad
[2025-12-21 21:46:05] INFO - Saving alignment-set data to: /data3/fanpeishan/state/for_state/run_results/run20/data/target_align.h5ad
[2025-12-21 21:46:08] INFO - Saving test-set data to: /data3/fanpeishan/state/for_state/run_results/run20/data/target_test.h5ad
[2025-12-21 21:46:11] INFO - Saving split info to: /data3/fanpeishan/state/for_state/run_results/run20/data/split_info.json
[2025-12-21 21:46:11] INFO - ============================================================
[2025-12-21 21:46:11] INFO - Data split complete!
[2025-12-21 21:46:11] INFO - 
Generated files:
[2025-12-21 21:46:11] INFO -   - /data3/fanpeishan/state/for_state/run_results/run20/data/source_domain.h5ad
[2025-12-21 21:46:11] INFO -   - /data3/fanpeishan/state/for_state/run_results/run20/data/target_align.h5ad
[2025-12-21 21:46:11] INFO -   - /data3/fanpeishan/state/for_state/run_results/run20/data/target_test.h5ad
[2025-12-21 21:46:11] INFO -   - /data3/fanpeishan/state/for_state/run_results/run20/data/split_info.json
[2025-12-21 21:46:11] INFO - ============================================================
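
The split arithmetic in the log can be reproduced with a small sketch. `data_split.py` itself is not shown here, so the functions below are hypothetical stand-ins. The floor rounding is inferred from the logged counts (14 plates at ratio 0.7 give 9 source plates, and 642,593 cells at ratio 0.5 give 321,296 alignment cells). Python's `random` module will not reproduce the script's actual plate selection.

```python
import random


def split_batches(batches, source_ratio, seed):
    """Assign whole batches (plates) to the source or target domain."""
    rng = random.Random(seed)
    shuffled = sorted(batches)  # fixed starting order so the seed is reproducible
    rng.shuffle(shuffled)
    n_source = int(len(shuffled) * source_ratio)  # floor, inferred from the log
    return sorted(shuffled[:n_source]), sorted(shuffled[n_source:])


def split_counts(n_cells, align_ratio):
    """Return (alignment-set size, test-set size) for the target domain."""
    n_align = int(n_cells * align_ratio)  # floor, inferred from the log
    return n_align, n_cells - n_align


plates = [f"plate{i}" for i in range(1, 15)]
source, target = split_batches(plates, source_ratio=0.7, seed=42)
print(len(source), len(target))   # 9 5
print(split_counts(642593, 0.5))  # (321296, 321297)
```

Splitting by whole plates rather than by individual cells keeps every target-domain plate unseen during source-side training, which is what makes the target domain a genuine batch shift.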

## 3. Train the MMD adapter

In [None]:
# Train the MMD adapter on the source_domain and target_align data
export CUDA_VISIBLE_DEVICES=3
python /data3/fanpeishan/state/for_state/MMD_alignment_experiment/scripts/train_mmd_adapter.py \
  --source_data /data3/fanpeishan/state/for_state/run_results/run20/data/source_domain.h5ad \
  --target_data /data3/fanpeishan/state/for_state/run_results/run20/data/target_align.h5ad \
  --output_dir /data3/fanpeishan/state/for_state/run_results/run20/adapters/ \
  --pert_col drugname_drugconc \
  --control_name "[('DMSO_TF', 0.0, 'uM')]" \
  --adapter_type shift \
  --epochs 1000 \
  --lr 2e-4 \
  --log_interval 100 \
  --source_sample_size 8000 \
  --target_sample_size 4000 \
  --seed 42 

In [None]:
============================================================
Starting MMD adapter training (streamlined, consolidated version)
============================================================
Loading data: /data3/fanpeishan/state/for_state/run_results/run20/data/source_domain.h5ad
Raw data shape: (1193354, 2000)
Control cells: 30136 / 1193354
Randomly sampled 8000 cells
Feature shape: (8000, 2000)
Loading data: /data3/fanpeishan/state/for_state/run_results/run20/data/target_align.h5ad
Raw data shape: (321296, 2000)
Control cells: 7503 / 321296
Randomly sampled 4000 cells
Feature shape: (4000, 2000)
--------------------------------------------------
Device: cuda
Training configuration:
  Adapter type: shift
  Feature dimension: 2000
  Source-domain samples: 8000
  Target-domain samples: 4000
  Epochs: 1000
  Learning rate: 0.0002
Adapter parameter count: 2000
Estimating MMD bandwidths with the median heuristic...
Bandwidth set: ['7.6752', '15.3504', '30.7009', '61.4018', '122.8036']
Initial MMD²: 0.000674
--------------------------------------------------
Epoch [   1/1000] MMD²: 0.000674 (improvement: +0.00%) Best: 0.000674
Epoch [ 100/1000] MMD²: 0.000581 (improvement: +13.82%) Best: 0.000581
Epoch [ 200/1000] MMD²: 0.000520 (improvement: +22.83%) Best: 0.000520
Epoch [ 300/1000] MMD²: 0.000473 (improvement: +29.77%) Best: 0.000473
Epoch [ 400/1000] MMD²: 0.000435 (improvement: +35.47%) Best: 0.000435
Epoch [ 500/1000] MMD²: 0.000402 (improvement: +40.30%) Best: 0.000402
Epoch [ 600/1000] MMD²: 0.000374 (improvement: +44.50%) Best: 0.000374
Epoch [ 700/1000] MMD²: 0.000349 (improvement: +48.21%) Best: 0.000349
Epoch [ 800/1000] MMD²: 0.000327 (improvement: +51.50%) Best: 0.000327
Epoch [ 900/1000] MMD²: 0.000307 (improvement: +54.48%) Best: 0.000307
Epoch [1000/1000] MMD²: 0.000289 (improvement: +57.18%) Best: 0.000289
--------------------------------------------------
Training complete! Initial MMD²: 0.000674 → final: 0.000289
MMD reduction: 57.18%
Adapter saved to: /data3/fanpeishan/state/for_state/run_results/run20/adapters/adapter_shift_20251223_165802_final.pt
Training finished successfully!
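
To make the quantity in the log concrete, here is a minimal, dependency-free sketch of a multi-bandwidth Gaussian-kernel MMD² and a per-feature shift. Several parts are assumptions rather than the contents of `train_mmd_adapter.py`: the kernel form `exp(-d² / (2σ²))`, the biased estimator, and the bandwidth multipliers `(0.25, 0.5, 1, 2, 4)` around the median distance (the log only shows five bandwidths that double at each step). A "shift" adapter with 2000 parameters on 2000 features is read here as one additive bias per feature. The real script trains that bias by gradient descent; the sketch substitutes a closed-form mean-matching shift purely to show the effect on MMD².

```python
import math
import random


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kernel(a, b, bandwidths):
    """Average of Gaussian kernels over several bandwidths (assumed form)."""
    d2 = sq_dist(a, b)
    return sum(math.exp(-d2 / (2.0 * s * s)) for s in bandwidths) / len(bandwidths)


def mean_kernel(xs, ys, bandwidths):
    return sum(kernel(x, y, bandwidths) for x in xs for y in ys) / (len(xs) * len(ys))


def mmd2(xs, ys, bandwidths):
    """Biased estimate of MMD^2 = E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]."""
    return (mean_kernel(xs, xs, bandwidths) + mean_kernel(ys, ys, bandwidths)
            - 2.0 * mean_kernel(xs, ys, bandwidths))


def median_bandwidths(xs, ys, scales=(0.25, 0.5, 1.0, 2.0, 4.0)):
    """Median pairwise distance of the pooled sample, times assumed multipliers."""
    pooled = xs + ys
    dists = sorted(math.sqrt(sq_dist(a, b))
                   for i, a in enumerate(pooled) for b in pooled[i + 1:])
    median = dists[len(dists) // 2]
    return [median * s for s in scales]


# Toy data: the target domain is offset by +1 in every feature
rng = random.Random(0)
dim = 5
source = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(40)]
target = [[rng.gauss(1.0, 1.0) for _ in range(dim)] for _ in range(30)]

bandwidths = median_bandwidths(source, target)

# Stand-in for the trained adapter: shift = mean(source) - mean(target), per feature
shift = [sum(s[j] for s in source) / len(source) - sum(t[j] for t in target) / len(target)
         for j in range(dim)]
aligned = [[v + b for v, b in zip(t, shift)] for t in target]

before = mmd2(source, target, bandwidths)
after = mmd2(source, aligned, bandwidths)
print(f"MMD^2 before: {before:.4f}, after: {after:.4f}")
```

A pure shift can only move the target distribution, not reshape it, so some residual MMD² is expected even after training converges.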

## 4. Apply MMD alignment

In [None]:
# Use the trained adapter to align the target-domain data to the source domain
export CUDA_VISIBLE_DEVICES=3
python /data3/fanpeishan/state/for_state/MMD_alignment_experiment/scripts/apply_mmd_alignment.py \
  --input_data /data3/fanpeishan/state/for_state/run_results/run20/data/target_test.h5ad \
  --adapter_path /data3/fanpeishan/state/for_state/run_results/run20/adapters/adapter_shift_final_weights.pt \
  --output_data /data3/fanpeishan/state/for_state/run_results/run20/data/target_test_aligned.h5ad \
  --adapter_type shift \
  --batch_size 10000 \
  --seed 42

In [None]:
Device: cuda
============================================================
Starting MMD alignment (streamlined, consolidated version)
============================================================
Input data: /data3/fanpeishan/state/for_state/run_results/run20/data/target_test.h5ad
Adapter: /data3/fanpeishan/state/for_state/run_results/run20/adapters/adapter_shift_final_weights.pt (type: shift)
Output data: /data3/fanpeishan/state/for_state/run_results/run20/data/target_test_aligned.h5ad
------------------------------------------------------------
Loading input data...
Data shape: (321297, 2000)
Feature dimension: 2000
Feature range: [0.000, 184.255]
------------------------------------------------------------
Loading adapter...
Adapter loaded successfully
------------------------------------------------------------
Applying alignment transform...
Batched alignment: 321297 samples, 33 batches (batch_size=10000)
  Processing batch 10/33
  Processing batch 20/33
  Processing batch 30/33
  Processing batch 33/33
Alignment transform complete
Feature range after alignment: [-0.237, 184.017]
------------------------------------------------------------
Saving aligned result...
Warning: the original .X will be overwritten! Back up the original data first
Data saved: /data3/fanpeishan/state/for_state/run_results/run20/data/target_test_aligned.h5ad
============================================================
Alignment finished successfully!
============================================================
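
The batch count and the negative minimum in the log both follow from applying a per-feature shift in chunks. The sketch below illustrates this under the assumption that the "shift" adapter adds one learned bias per feature. `apply_shift_in_batches` is a hypothetical name, not a function from `apply_mmd_alignment.py`, and plain lists stand in for tensors.

```python
import math


def apply_shift_in_batches(rows, shift, batch_size):
    """Add a per-feature shift to every row, processing batch_size rows at a time."""
    n_batches = math.ceil(len(rows) / batch_size)
    out = []
    for b in range(n_batches):
        chunk = rows[b * batch_size:(b + 1) * batch_size]
        out.extend([v + s for v, s in zip(row, shift)] for row in chunk)
    return out, n_batches


rows = [[0.0, 2.0], [1.0, 0.0], [3.0, 4.0]]
shift = [-0.25, 0.5]  # toy values, not the trained adapter weights

aligned, n_batches = apply_shift_in_batches(rows, shift, batch_size=2)
print(aligned)    # [[-0.25, 2.5], [0.75, 0.5], [2.75, 4.5]]
print(n_batches)  # 2

# Batch count from the log: 321297 samples at batch_size=10000
print(math.ceil(321297 / 10000))  # 33
```

Because the shift is added to every cell, features that were exactly zero take the value of the bias, which is why an aligned matrix can contain small negative entries even though the input is non-negative.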

## 5. Alignment-group inference

In [None]:
# Zero-shot inference on the MMD-aligned data
export CUDA_VISIBLE_DEVICES=3
state tx infer \
  --model-dir /data3/fanpeishan/state/for_state/models/ST-Tahoe \
  --checkpoint /data3/fanpeishan/state/for_state/models/ST-Tahoe/final_from_preprint.ckpt \
  --adata /data3/fanpeishan/state/for_state/run_results/run20/data/target_test_aligned.h5ad \
  --output /data3/fanpeishan/state/for_state/run_results/run20/data/target_test_aligned_infer.h5ad \
  --pert-col drugname_drugconc \
  --control-pert "[('DMSO_TF', 0.0, 'uM')]" \
  --quiet \
  --batch-col plate

In [None]:
=== Inference complete ===
Input cells:         321297
Controls simulated:  7511
Treated simulated:   313786
Wrote predictions to adata.X
Saved:               /data3/fanpeishan/state/for_state/run_results/run20/data/target_test_aligned_infer.h5ad

## 6. Alignment-group evaluation

In [None]:
# Evaluate the alignment group's inference results with cell-eval
export CUDA_VISIBLE_DEVICES=3
cell-eval run \
    -ap /data3/fanpeishan/state/for_state/run_results/run20/data/target_test_aligned_infer.h5ad \
    -ar /data3/fanpeishan/state/for_state/run_results/run20/data/target_test.h5ad \
    -o /data3/fanpeishan/state/for_state/run_results/run20/run_results/ \
    --control-pert "[('DMSO_TF', 0.0, 'uM')]" \
    --pert-col 'drugname_drugconc' \
    --profile minimal \
    --batch-size 1024 \
    --num-threads 12 \
    --skip-metrics de_nsig_counts_real,de_nsig_counts_pred

In [1]:
import pandas as pd
results = pd.read_csv('/data3/fanpeishan/state/for_state/run_results/run20/run_results/agg_results.csv')
mean_results = results[results.statistic == 'mean'][['overlap_at_N', 'pearson_delta', 'mse']]
print(mean_results)

   overlap_at_N  pearson_delta       mse
2      0.129052       0.367432  0.126995
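
For intuition about two of the reported columns, the sketch below computes a Pearson correlation of perturbation deltas and a mean squared error on pseudobulk means for one toy perturbation. These definitions are assumptions made for illustration (delta = mean of perturbed cells minus mean of control cells, MSE taken over per-gene means). They are not taken from the cell-eval source, so exact values from `cell-eval run` may be computed differently.

```python
import math


def column_means(rows):
    """Per-gene mean across cells (a pseudobulk profile)."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def pearson(a, b):
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(va * vb)


def pearson_delta(pred_pert, real_pert, pred_ctrl, real_ctrl):
    """Correlation between predicted and real (perturbed - control) mean shifts."""
    pred_delta = [p - c for p, c in zip(column_means(pred_pert), column_means(pred_ctrl))]
    real_delta = [p - c for p, c in zip(column_means(real_pert), column_means(real_ctrl))]
    return pearson(pred_delta, real_delta)


def mse(pred_pert, real_pert):
    """Mean squared error between predicted and real pseudobulk profiles."""
    pm, rm = column_means(pred_pert), column_means(real_pert)
    return sum((p - r) ** 2 for p, r in zip(pm, rm)) / len(pm)


# Toy data: 2 cells x 3 genes per group
ctrl = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
real_pert = [[2.0, 1.0, 0.0], [2.0, 1.0, 0.0]]
pred_pert = [[1.5, 1.0, 0.5], [1.5, 1.0, 0.5]]  # right direction, half the magnitude

r = pearson_delta(pred_pert, real_pert, ctrl, ctrl)
m = mse(pred_pert, real_pert)
print(f"pearson_delta = {r:.4f}, mse = {m:.4f}")  # pearson_delta = 1.0000, mse = 0.1667
```

The toy case shows why the two columns are reported together: a prediction can track the direction of the perturbation perfectly while still being off in magnitude.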
