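# LoRA + MMD alignment pipeline

This pipeline runs in five steps:
1. Fine-tune ST-Tahoe with LoRA on a target-domain alignment split.
2. Train an MMD adapter that aligns target-domain data to the source domain.
3. Apply the adapter to the held-out target test split.
4. Predict with the LoRA-fine-tuned model.
5. Evaluate with cell-eval.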

## 1. 数据准备

In [None]:
# 源域数据 (source domain)：/data3/fanpeishan/state/for_state/data/State-Tahoe-Filtered-processed/c37_prep.h5ad
# 目标域数据 (target domain)：/data3/fanpeishan/state/for_state/data/State-Tahoe-Filtered-processed/c38_prep.h5ad

In [None]:
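# 将目标域数据划分为对齐集和测试集
# (Split the target-domain data, c38_prep.h5ad, into an alignment set and a test set.)
# 对齐集：/data3/fanpeishan/state/for_state/run_results/run23/data_align.h5ad
# 测试集：/data3/fanpeishan/state/for_state/run_results/run23/data_test.h5ad
# TODO: the code that produces this split is not in the notebook. Record it (split ratio,
# seed, grouping) so the pipeline can be reproduced end to end.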

## 2. 在对齐集上进行LoRA微调

In [None]:
%%bash
# LoRA fine-tuning of the pretrained ST-Tahoe checkpoint on the target-domain alignment split.
# `%%bash` is required because the notebook kernel is Python (see the pandas cell at the end).
# In that kernel, a bare `export` or backslash-continued shell command is a SyntaxError.
# Each %%bash cell runs in its own subprocess, so the GPU selection and the variables below
# apply to this cell only.
set -euo pipefail
export CUDA_VISIBLE_DEVICES=4

ROOT=/data3/fanpeishan/state/for_state
MODEL_DIR="$ROOT/models/ST-Tahoe"
RUN_DIR="$ROOT/run_results/run23"

# Outputs go to $RUN_DIR/lora_state.pth (--output_lora) and $RUN_DIR/lora_out (--save_dir).
python "$ROOT/scripts/finetune_v2.py" \
  --model_dir "$MODEL_DIR" \
  --checkpoint "$MODEL_DIR/final_from_preprint.ckpt" \
  --adata "$RUN_DIR/data_align.h5ad" \
  --pert_col drugname_drugconc \
  --batch_col plate \
  --output_lora "$RUN_DIR/lora_state.pth" \
  --save_dir "$RUN_DIR/lora_out" \
  --epochs 10 \
  --batch_size 128 \
  --lr 2e-4 \
  --lora_rank 16 \
  --lora_alpha 32 \
  --use_delta_loss \
  --lambda_delta 1.0 \
  --lambda_pearson 0.4 \
  --lambda_mse 0.2 \
  --weight_decay 1e-4 \
  --target_modules pert_encoder,basal_encoder,transformer_backbone

## 3. 在源域数据和对齐集上训练MMD adapter

In [None]:
%%bash
# Train a "shift"-type MMD adapter from the source-domain data (c37) and the target-domain
# alignment split. This runs under %%bash because the notebook kernel is Python; the export
# below applies to this cell's subprocess only.
set -euo pipefail
export CUDA_VISIBLE_DEVICES=3

ROOT=/data3/fanpeishan/state/for_state
RUN_DIR="$ROOT/run_results/run23"

# --control_name uses the same control label as --control-pert in the cell-eval step; keep them in sync.
python "$ROOT/MMD_alignment_experiment/scripts/train_mmd_adapter.py" \
  --source_data "$ROOT/data/State-Tahoe-Filtered-processed/c37_prep.h5ad" \
  --target_data "$RUN_DIR/data_align.h5ad" \
  --output_dir "$RUN_DIR/adapters/" \
  --pert_col drugname_drugconc \
  --control_name "[('DMSO_TF', 0.0, 'uM')]" \
  --adapter_type shift \
  --epochs 1000 \
  --lr 2e-4 \
  --log_interval 100 \
  --source_sample_size 8000 \
  --target_sample_size 4000 \
  --seed 42

## 4. 将MMD adapter应用在测试集上

In [None]:
%%bash
# Use the trained adapter to align the target-domain test split to the source domain.
# This runs under %%bash because the notebook kernel is Python; the export below applies
# to this cell only.
set -euo pipefail
export CUDA_VISIBLE_DEVICES=3

ROOT=/data3/fanpeishan/state/for_state
RUN_DIR="$ROOT/run_results/run23"

# The adapter filename is presumably what train_mmd_adapter.py writes into its --output_dir
# for --adapter_type shift. Confirm this if that script's naming changes.
python "$ROOT/MMD_alignment_experiment/scripts/apply_mmd_alignment.py" \
  --input_data "$RUN_DIR/data_test.h5ad" \
  --adapter_path "$RUN_DIR/adapters/adapter_shift_final_weights.pt" \
  --output_data "$RUN_DIR/data_test_aligned.h5ad" \
  --adapter_type shift \
  --batch_size 10000 \
  --seed 42

## 5. 在对齐后的测试集上，应用LoRA微调后的ST模型进行推理

In [None]:
%%bash
# Predict with the LoRA-fine-tuned ST-Tahoe model on the MMD-aligned test split.
# This runs under %%bash because the notebook kernel is Python; the export below applies
# to this cell only.
set -euo pipefail
export CUDA_VISIBLE_DEVICES=4

ROOT=/data3/fanpeishan/state/for_state
MODEL_DIR="$ROOT/models/ST-Tahoe"
RUN_DIR="$ROOT/run_results/run23"
# NOTE(review): this loads the epoch-4 checkpoint from lora_out/, not lora_state.pth from the
# 10-epoch fine-tuning run. It is presumably a hand-picked best epoch; confirm.
LORA_CKPT="$RUN_DIR/lora_out/lora_epoch4.pth"

python "$ROOT/scripts/infer_lora_v2.py" \
  --model_dir "$MODEL_DIR" \
  --checkpoint "$MODEL_DIR/final_from_preprint.ckpt" \
  --lora_path "$LORA_CKPT" \
  --adata "$RUN_DIR/data_test_aligned.h5ad" \
  --pert_col drugname_drugconc \
  --output "$RUN_DIR/data_test_aligned_pred.h5ad" \
  --batch_size 1024

## 6. 评估结果

In [None]:
%%bash
# Evaluate the predictions with cell-eval. Results are written under $RUN_DIR/eval_results.
# This runs under %%bash because the notebook kernel is Python; the export below applies
# to this cell only.
set -euo pipefail
export CUDA_VISIBLE_DEVICES=4

RUN_DIR=/data3/fanpeishan/state/for_state/run_results/run23

# -ap is the predicted file (*_pred.h5ad) and -ar is the real test file.
# NOTE(review): the reference is the original, unaligned data_test.h5ad, while the predictions
# were made on data_test_aligned.h5ad. Confirm that this comparison is intended.
# --control-pert uses the same control label as --control_name in the MMD training step.
cell-eval run \
    -ap "$RUN_DIR/data_test_aligned_pred.h5ad" \
    -ar "$RUN_DIR/data_test.h5ad" \
    -o "$RUN_DIR/eval_results" \
    --control-pert "[('DMSO_TF', 0.0, 'uM')]" \
    --pert-col 'drugname_drugconc' \
    --profile full \
    --batch-size 1024 \
    --num-threads 12

In [None]:
import pandas as pd

# Aggregated metrics CSV inside the directory that the `cell-eval run` step writes to (-o).
AGG_RESULTS_CSV = '/data3/fanpeishan/state/for_state/run_results/run23/eval_results/agg_results.csv'
# Headline metrics reported for this run.
HEADLINE_METRICS = ['overlap_at_100', 'pearson_delta', 'mse']

results = pd.read_csv(AGG_RESULTS_CSV)
# Keep only the rows whose `statistic` column is 'mean', restricted to the headline metrics.
mean_results = results.loc[results['statistic'] == 'mean', HEADLINE_METRICS]
# Fail loudly instead of silently showing an empty table if the CSV layout changes.
assert not mean_results.empty, f"no 'mean' row found in {AGG_RESULTS_CSV}"

# Last expression gives the notebook's rich table display (instead of print()).
mean_results