# import

In [2]:
############################### import

import sys
import json
from anndata import AnnData
import cpa
import scanpy as sc
import importlib
import numpy as np
from tqdm import tqdm
import pickle
import os
import shutil
import torch
import pandas as pd
from scipy.spatial.distance import cdist

sys.path.append("/data1/lichen/code/single_cell_perturbation/scPerturb/Byte_Pert_Data/")
import v1
from v1.utils import *
from v1.dataloader import *

import argparse

# importlib.reload(v1)
# importlib.reload(v1.utils)
# importlib.reload(v1.dataloader)

# function

In [44]:
############################### function

# Gene-space projection modes accepted by `cpa_single_pert`.
_CPA_MODEL_MODES = ('subset', 'whole', 'zero_pad')


def _embed_prediction(x_pred, adata_rna_common, model_mode):
    """Wrap a prediction over the shared genes into an AnnData.

    Reads the notebook globals ``adata_rna``, ``common_var`` and (for
    ``'whole'``) ``common_idx``.

    Parameters
    ----------
    x_pred : np.ndarray
        Prediction with one row per cell of ``adata_rna`` and one column per
        gene in ``common_var``.
    adata_rna_common : AnnData
        ``adata_rna`` restricted to ``common_var``.
    model_mode : str
        ``'subset'``   - keep only the shared genes;
        ``'whole'``    - all genes of ``adata_rna``; every gene takes the
                         column selected by ``common_idx``;
        ``'zero_pad'`` - all genes of ``adata_rna``; genes outside
                         ``common_var`` are filled with zeros.

    Returns
    -------
    AnnData
        Copy of the target-dataset AnnData whose ``X`` holds the prediction.

    Raises
    ------
    ValueError
        If ``model_mode`` is not one of the three modes above.
    """
    if model_mode == 'subset':
        adata_out = adata_rna_common.copy()
        adata_out.X = x_pred
    elif model_mode == 'whole':
        adata_out = adata_rna.copy()
        adata_out.X = x_pred[:, common_idx]
    elif model_mode == 'zero_pad':
        adata_out = adata_rna.copy()
        # One dictionary lookup per gene instead of rebuilding
        # `list(common_var)` and scanning it for every gene.
        column_of = {gene: k for k, gene in enumerate(common_var)}
        padded = np.zeros(adata_out.X.shape)
        for i, gene in enumerate(adata_out.var_names):
            k = column_of.get(gene)
            if k is not None:
                padded[:, i] = x_pred[:, k]
        adata_out.X = padded
    else:
        raise ValueError(
            f"unknown model_mode {model_mode!r}; expected one of {_CPA_MODEL_MODES}"
        )
    return adata_out


def cpa_single_pert(pert, model_mode='whole'):
    """Train CPA on one perturbation and predict it for the target dataset.

    The function relies on notebook globals set by the driver cell:
    ``pert_data``, ``adata_split``, ``adata_rna``, ``common_var``,
    ``common_idx``, ``dataset``, ``model_prefix``, ``model_params`` and
    ``trainer_params``.

    Four groups of cells are stacked into one training AnnData:
    perturbed source cells, target cells labelled 'stimulated' (held out as
    'ood'), source control cells, and target cells labelled 'ctrl'.

    Parameters
    ----------
    pert : str
        Perturbation name; matched against ``perturbation_group`` as
        ``f"{pert} | {cell_type}"``.
    model_mode : str, default 'whole'
        How the prediction is mapped back to gene space; see
        `_embed_prediction`.

    Returns
    -------
    AnnData
        For ``model_prefix == 'CPA_v3'`` ``X`` is the predicted log1p
        difference (stimulated - ctrl) of the target cells; otherwise it is
        the predicted log1p expression of the stimulated target cells
        (rescaled per cell for 'CPA_v2').

    Raises
    ------
    ValueError
        If ``model_mode`` is unknown (checked before any training happens).
    """
    # Fail fast: previously an invalid mode was only detected after training.
    if model_mode not in _CPA_MODEL_MODES:
        raise ValueError(
            f"unknown model_mode {model_mode!r}; expected one of {_CPA_MODEL_MODES}"
        )

    cell_type = pert_data.filter_perturbation_list[0].split(' | ')[1]

    # - get adata_pert and its matched controls, restricted to the shared genes
    adata_pert = adata_split[adata_split.obs['perturbation_group'] == pert + ' | ' + cell_type].copy()
    adata_ctrl = adata_split[list(adata_pert.obs['control_barcode'])].copy()
    adata_pert = adata_pert[:, common_var]
    adata_ctrl = adata_ctrl[:, common_var]

    # - target dataset restricted to the shared genes
    adata_rna_common = adata_rna[:, common_var]

    # - generate adata_train to input to the CPA model
    np_list, obs_list, pert_list, celltype_list = [], [], [], []
    adata_list = [adata_pert, adata_rna_common, adata_ctrl, adata_rna_common]
    for j, adata_ in enumerate(adata_list):
        # groups 0/1 are 'stimulated', groups 2/3 are 'ctrl'
        condition = 'stimulated' if j in [0, 1] else 'ctrl'
        pert_list.extend([condition] * len(adata_))

        # groups 0/2 come from the perturbation dataset, 1/3 from the target
        group_label = cell_type if j in [0, 2] else dataset
        celltype_list.extend([group_label] * len(adata_))

        # suffix keeps obs names unique (the target cells appear twice)
        obs_list.extend([obs + f'_{j}' for obs in adata_.obs_names])

        if not isinstance(adata_.X, np.ndarray):
            np_list.append(adata_.X.toarray())
        else:
            np_list.append(adata_.X)

    adata_train = AnnData(X=np.vstack(np_list))
    adata_train.obs_names = obs_list
    adata_train.var_names = adata_pert.var_names

    adata_train.obs['condition'] = pert_list
    adata_train.obs['cell_type'] = celltype_list
    adata_train.obs['cov_cond'] = adata_train.obs['cell_type'] + '_' + adata_train.obs['condition']

    # - transform adata_train.X back to count scale; expm1 is the exact
    #   inverse of the log1p applied to the predictions below
    #   (assumes the inputs are log1p-transformed - TODO confirm)
    adata_train.X = np.expm1(adata_train.X)

    if model_prefix == 'CPA_v1':
        # - add norm
        sc.pp.normalize_per_cell(adata_train, key_n_counts='n_counts_all')

    # - register the data with CPA
    cpa.CPA.setup_anndata(adata_train,
                          perturbation_key='condition',
                          control_group='ctrl',
                          categorical_covariate_keys=['cell_type'],
                          is_count_data=True,
                          deg_uns_cat_key='cov_cond',
                          max_comb_len=1,
                          )

    # - train/valid split: everything except the 'stimulated' target cells,
    #   which stay 'ood' because they are what we want to predict
    is_target_stimulated = ((adata_train.obs["cell_type"] == dataset) &
                            (adata_train.obs["condition"] == "stimulated"))
    obs_df_sub_idx = np.array(adata_train[~is_target_stimulated].obs.index)

    np.random.seed(2024)
    np.random.shuffle(obs_df_sub_idx)

    # -- 90% train / remaining 10% valid
    split_point_1 = int(len(obs_df_sub_idx) * 0.9)
    split_point_2 = int(len(obs_df_sub_idx) * (0.9 + 0.1))
    train = obs_df_sub_idx[:split_point_1]
    valid = obs_df_sub_idx[split_point_1:split_point_2]

    adata_train.obs['split_key'] = 'ood'
    adata_train.obs.loc[train, 'split_key'] = 'train'
    adata_train.obs.loc[valid, 'split_key'] = 'valid'

    # - initialise the model and train
    model = cpa.CPA(adata=adata_train,
                    split_key='split_key',
                    train_split='train',
                    valid_split='valid',
                    test_split='ood',
                    **model_params,
                    )

    batch_size = 500
    model.train(max_epochs=2000,
                use_gpu=True,
                batch_size=batch_size,
                plan_kwargs=trainer_params,
                early_stopping_patience=5,
                check_val_every_n_epoch=5,
                progress_bar_refresh_rate=0
                )

    # - predict; the result is read back from adata_train.obsm['CPA_pred']
    model.predict(adata_train, batch_size=2048)

    def _log_prediction(condition):
        """log1p of the CPA prediction for the target cells under `condition`."""
        cat = dataset + '_' + condition
        cat_adata = adata_train[adata_train.obs['cov_cond'] == cat].copy()
        return np.log1p(cat_adata.obsm['CPA_pred'])

    if model_prefix == 'CPA_v3':
        # "minus" version: predicted shift between the stimulated and ctrl
        # copies of the same target cells (both copies keep the row order of
        # adata_rna_common, so the subtraction is cell-aligned)
        x_pred = _log_prediction('stimulated') - _log_prediction('ctrl')
    else:
        # The target rows were labelled with `dataset` above, so the category
        # must be built from `dataset`; the previous `cell_line_bulk` label
        # never matched (and is None in the driver cell).
        x_pred = _log_prediction('stimulated')

        if model_prefix == 'CPA_v2':  # normalize output
            x_pred = x_pred / x_pred.mean(1).reshape(-1, 1) * adata_rna_common.X.mean(1).reshape(-1, 1)

    return _embed_prediction(x_pred, adata_rna_common, model_mode)

# init

In [4]:
############################### init
# - model initial
# Keyword arguments forwarded to `cpa.CPA(...)` via `**model_params`.
model_params = dict(
    # latent space / likelihood
    n_latent=64,
    recon_loss="nb",
    doser_type="linear",
    # encoder
    n_hidden_encoder=128,
    n_layers_encoder=2,
    # decoder
    n_hidden_decoder=512,
    n_layers_decoder=2,
    # normalisation layers
    use_batch_norm_encoder=True,
    use_layer_norm_encoder=False,
    use_batch_norm_decoder=False,
    use_layer_norm_decoder=True,
    # regularisation
    dropout_rate_encoder=0.0,
    dropout_rate_decoder=0.1,
    variational=False,
    seed=6977,
)

# Training-plan options handed to `model.train(plan_kwargs=trainer_params)`.
trainer_params = dict(
    # warm-up schedule
    n_epochs_kl_warmup=None,
    n_epochs_pretrain_ae=30,
    n_epochs_adv_warmup=50,
    n_epochs_mixup_warmup=0,
    mixup_alpha=0.0,
    # adversarial classifier
    adv_steps=None,
    n_hidden_adv=64,
    n_layers_adv=3,
    use_batch_norm_adv=True,
    use_layer_norm_adv=False,
    dropout_rate_adv=0.3,
    reg_adv=20.0,
    pen_adv=5.0,
    # optimiser settings
    lr=0.0003,
    wd=4e-07,
    adv_lr=0.0003,
    adv_wd=4e-07,
    adv_loss="cce",
    doser_lr=0.0003,
    doser_wd=4e-07,
    do_clip_grad=True,
    gradient_clip_value=1.0,
    step_size_lr=10,
)

# - get cell line name
# Maps each L1000 cell-line name (key) to its single-cell name (value).
# The pair 'SKBR3' -> 'SK-BR-3' is intentionally left out.
common_cell_line = dict(
    A549='A549',
    HEPG2='HepG2',
    HT29='HT29',
    MCF7='MCF7',
    SW480='SW480',
    PC3='PC3',
    A375='A375',
)

# - read adata_L1000, this is processed data
# NOTE(review): in the cells visible here adata_L1000 is only referenced from
# commented-out code (`common_var_2`); confirm it is still needed before
# paying for the load.
adata_L1000 = sc.read('/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/GSE92742/adata_gene_pert.h5ad')
# Bare expression: it is not the last statement of the cell, so nothing is displayed.
adata_L1000

# Selects the branch taken inside cpa_single_pert ('CPA_v1' / 'CPA_v2' / 'CPA_v3').
# The main cell assigns the same value again before running.
model_prefix = 'CPA_v3'

In [5]:
# Perturbed genes to predict for each real-case dataset.
# direct_transfer is only available for the first two datasets, so the
# remaining ones stay disabled here:
#   'OSKM': ['SOX2', 'POU5F1', 'KLF4', 'MYC']
#   'ADM':  ['PTF1A']
dataset_pert_dict = dict(
    CAR_T=['PDCD1'],
    blood=['GATA1', 'SPI1'],
)

# Cell-type label associated with each dataset.
dataset_celltype_dict = dict(
    CAR_T='Tex',
    blood='LMPP',
    OSKM='Fibroblast-like',
    ADM='Acinar',
)

# Direction label ('up' / 'down') recorded for each dataset.
dataset_dire_dict = dict(
    CAR_T='down',
    blood='down',
    OSKM='up',
    ADM='down',
)

# Datasets that will actually be processed, in insertion order.
datasets = [name for name in dataset_pert_dict]

In [6]:
# Perturbed gene -> prefix of the scPerturb dataset used to train CPA for it.
gene_dataset_dict = dict(
    SPI1='AdamsonWeissman2016_GSM2406675_10X001',
    GATA1='ReplogleWeissman2022_K562_essential',
    PDCD1='ShifrutMarson2018',
)

In [7]:
# - init dataloader para
# Root directory of the scPerturb statistics.
data_dir = '/nfs/public/lichen/data/single_cell/perturb_data/scPerturb/raw/scPerturb_rna/statistic_20240520'

# Filtering and splitting:
#   pert_cell_filter - perturbations with fewer cells than this are dropped
#   seed             - random seed
#   split_type       - 1 for unseen perts; 0 for unseen celltypes
pert_cell_filter, seed, split_type = 100, 2024, 1
# train:test:val; val is used to choose data, test is for final validation
split_ratio = [0.7, 0.2, 0.1]

# Gene selection: number of HVGs and number of DE genes.
var_num, num_de_genes = 5000, 20

# Batch sizes of the train / test loaders, and the learning rate.
bs_train = bs_test = 32
lr = 1e-4

# main

In [49]:
# Driver: for every dataset / perturbed gene, train CPA on the matching
# scPerturb dataset, predict the perturbation for the real-case control cells
# and write the result to disk.
parser = argparse.ArgumentParser(description="CPA on real case")
# parser.add_argument('--cell_line_bulk', type=str, default=None)
# parser.add_argument('--model_mode', type=str, default='whole') # pretrain, init
args = parser.parse_args([])  # empty argv: the options are set by hand below

adata_mode = 'non_minus'
# minus: save adata as the minus delta
# non_minus: save adata, add the minus delta on the original gene exp
if adata_mode not in ('minus', 'non_minus'):
    # Previously an unknown mode trained every model and then wrote nothing.
    raise ValueError(f"unknown adata_mode {adata_mode!r}; expected 'minus' or 'non_minus'")

args.cell_line_bulk = None
args.model_mode = 'zero_pad'
# whole: find similar genes and pad;
# subset: discard non overlap;
# zero_pad: for non overlap, pad zero values

model_prefix = 'CPA_v3'

# CPA_v3: minus version

torch.cuda.set_device(1)

# - get paras
cell_line_bulk = args.cell_line_bulk
# cell_line_single = common_cell_line[cell_line_bulk]

model_mode = args.model_mode

for dataset in datasets[:]:
    print('='*20, dataset, '='*20)
    save_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/real_case/data'
    adata_rna = sc.read(os.path.join(save_dir, dataset, 'adata_ctrl_v2.h5ad'))

    # work on a dense matrix from here on
    if not isinstance(adata_rna.X, np.ndarray):
        adata_rna.X = adata_rna.X.toarray()

    for pert in tqdm(dataset_pert_dict[dataset]):

        if pert not in gene_dataset_dict:
            print(f'{pert} can not be found in scPerturb pert_data')
            continue

        print('='*20, pert, '='*20)

        prefix = gene_dataset_dict[pert]

        # - read dataset
        tmp_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/scPerturb'
        save_prefix = f'GEARS_v2-prefix_{prefix}-pert_cell_filter_{pert_cell_filter}-\
seed_{seed}-split_type_{split_type}-var_num_{var_num}-num_de_genes_{num_de_genes}-bs_train_{bs_train}-\
bs_test_{bs_test}'
        save_dir = os.path.join(tmp_dir, prefix, save_prefix)
        # - load pert_data
        # NOTE: pickle executes code on load; only open files produced by this pipeline.
        with open(os.path.join(save_dir, 'pert_data.pkl'), 'rb') as pert_data_file:
            pert_data = pickle.load(pert_data_file)

        adata_split = pert_data.adata_split

        # - get common var
        common_var = np.intersect1d(adata_rna.var_names, adata_split.var_names)
        # common_var_2 = np.intersect1d(common_var, adata_L1000.var_names)

        print('common var of direct change and single-cell data is: ', len(common_var))
        # print('common var to L1000 data is: ', len(common_var_2))

        # - nearest-gene computation: for every gene of adata_rna find the
        #   closest (cosine distance over cells) gene among the shared genes
        matrix = adata_rna.X.T
        rna_position = {}
        for k, gene in enumerate(adata_rna.var_names):
            rna_position.setdefault(gene, k)  # first occurrence, like list.index
        index_list = np.array([rna_position[gene] for gene in common_var])

        distance_matrix = cdist(matrix, matrix, metric='cosine')
        np.fill_diagonal(distance_matrix, np.inf)  # a gene is never its own neighbour
        mask = np.ones(distance_matrix.shape, dtype=bool)
        mask[:, index_list] = False
        distance_matrix[mask] = np.inf  # only shared genes are eligible neighbours
        # NOTE(review): all-zero genes give NaN cosine distances and np.argmin
        # treats NaN as the minimum - confirm this is acceptable.
        nearest_indices = np.argmin(distance_matrix, axis=1)
        nearest_indices_list = nearest_indices.tolist()

        # Column of x_pred to use for every gene of adata_rna: the gene itself
        # when it is shared, otherwise its nearest shared gene.
        # (The previous test `i in common_var` compared the integer position
        # with gene names, was never true, and so mapped shared genes to their
        # nearest *other* shared gene.)
        common_position = {gene: k for k, gene in enumerate(common_var)}
        common_idx = [
            common_position[gene] if gene in common_position
            else common_position[adata_rna.var_names[nearest_indices_list[i]]]
            for i, gene in enumerate(adata_rna.var_names)
        ]

        torch.cuda.set_device(1)
        # - run CPA
        adata_pert = cpa_single_pert(pert, model_mode)

        tmp_dir = f'/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark_202410/real_case/result/{dataset}'
        save_prefix = f'CPA/{pert}' # use result of K562 to do the direct transfer
        os.makedirs(os.path.join(tmp_dir, save_prefix), exist_ok=True)

        if adata_mode == 'minus':
            adata_pert.write(os.path.join(tmp_dir, save_prefix, 'adata_pert_minus.h5ad'))
        elif adata_mode == 'non_minus':
            # In 'subset' mode adata_pert only holds the shared genes, so the
            # baseline has to be restricted to the same genes before adding.
            if adata_pert.n_vars == adata_rna.n_vars:
                baseline = adata_rna.X
            else:
                baseline = np.asarray(adata_rna[:, adata_pert.var_names].X)
            adata_pert.X = baseline + adata_pert.X
            adata_pert.write(os.path.join(tmp_dir, save_prefix, 'adata_pert.h5ad'))

    #     break
    # break



  0%|          | 0/1 [00:00<?, ?it/s]

common var of direct change and single-cell data is:  750


100%|██████████| 4148/4148 [00:00<00:00, 91608.78it/s]
100%|██████████| 4148/4148 [00:00<00:00, 507288.69it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 646.77it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.5749346758297766
disnt_after = 2.0
val_r2_mean = 0.6618771553039551
val_r2_var = -0.46952927112579346



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.4977304192713927
disnt_after = 2.0
val_r2_mean = 0.6649097601572672
val_r2_var = -0.4202553828557332



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.4346504331177627
disnt_after = 2.0
val_r2_mean = 0.6800881425539652
val_r2_var = -0.3028413454691569



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.2482252003695335
disnt_after = 2.0
val_r2_mean = 0.7104568282763163
val_r2_var = -0.1482564608256022

disnt_basal = 1.5883297868054402
disnt_after = 2.0
val_r2_mean = 0.771066149075826
val_r2_var = 0.048276940981547035



Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.327965921977321
disnt_after = 2.0
val_r2_mean = 0.7902522683143616
val_r2_var = 0.16724008321762085



Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.2182651937153377
disnt_after = 2.0
val_r2_mean = 0.8297659357388815
val_r2_var = 0.2861541112263997



Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.1674344678929511
disnt_after = 2.0
val_r2_mean = 0.8415791988372803
val_r2_var = 0.32634154955546063



Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.1744286511176978
disnt_after = 2.0
val_r2_mean = 0.8594867785771688
val_r2_var = 0.3807335098584493



Epoch 00094: cpa_metric reached. Module best state updated.



disnt_basal = 1.1052602861236802
disnt_after = 2.0
val_r2_mean = 0.8656089901924133
val_r2_var = 0.41238588094711304

disnt_basal = 1.0718043248120916
disnt_after = 2.0
val_r2_mean = 0.8649522662162781
val_r2_var = 0.4377317229906718



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.0469504074284761
disnt_after = 2.0
val_r2_mean = 0.8785508275032043
val_r2_var = 0.45979710419972736



Epoch 00124: cpa_metric reached. Module best state updated.

Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.0417310078035242
disnt_after = 2.0
val_r2_mean = 0.8797484238942465
val_r2_var = 0.47189627091089886



Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.0419406304307635
disnt_after = 2.0
val_r2_mean = 0.8843962152798971
val_r2_var = 0.48448437452316284



Epoch 00144: cpa_metric reached. Module best state updated.

Epoch 00149: cpa_metric reached. Module best state updated.



disnt_basal = 1.0352091696656793
disnt_after = 2.0
val_r2_mean = 0.881819466749827
val_r2_var = 0.4930091102917989

disnt_basal = 1.0378484792147784
disnt_after = 2.0
val_r2_mean = 0.8811887502670288
val_r2_var = 0.49894193808237713



Epoch 00169: cpa_metric reached. Module best state updated.



disnt_basal = 1.0319951824246372
disnt_after = 2.0
val_r2_mean = 0.885300079981486
val_r2_var = 0.5251136223475138



Epoch 00174: cpa_metric reached. Module best state updated.



disnt_basal = 1.0369625292243212
disnt_after = 2.0
val_r2_mean = 0.888872762521108
val_r2_var = 0.5277415911356608

disnt_basal = 1.035637651983126
disnt_after = 2.0
val_r2_mean = 0.8887220422426859
val_r2_var = 0.5329270561536154



Epoch 00194: cpa_metric reached. Module best state updated.

Epoch 00199: cpa_metric reached. Module best state updated.



disnt_basal = 1.0304959402515255
disnt_after = 2.0
val_r2_mean = 0.8912355502446493
val_r2_var = 0.5381438533465067



Epoch 00204: cpa_metric reached. Module best state updated.



disnt_basal = 1.030754611706222
disnt_after = 2.0
val_r2_mean = 0.8937287529309591
val_r2_var = 0.5465837915738424



Epoch 00214: cpa_metric reached. Module best state updated.



disnt_basal = 1.0313411439437008
disnt_after = 2.0
val_r2_mean = 0.8932034174601237
val_r2_var = 0.560086190700531

disnt_basal = 1.031263211786393
disnt_after = 2.0
val_r2_mean = 0.8941017190615336
val_r2_var = 0.5645580887794495



Epoch 00234: cpa_metric reached. Module best state updated.



disnt_basal = 1.0286424590977965
disnt_after = 2.0
val_r2_mean = 0.8899584213892618
val_r2_var = 0.5775222380956014



Epoch 00244: cpa_metric reached. Module best state updated.



disnt_basal = 1.0286213298579707
disnt_after = 2.0
val_r2_mean = 0.8981737494468689
val_r2_var = 0.5782150030136108



Epoch 00259: cpa_metric reached. Module best state updated.



disnt_basal = 1.0261316505943157
disnt_after = 2.0
val_r2_mean = 0.8979706565539042
val_r2_var = 0.5880561073621114



Epoch 00269: cpa_metric reached. Module best state updated.



disnt_basal = 1.0275848015988154
disnt_after = 2.0
val_r2_mean = 0.9023659030596415
val_r2_var = 0.5882585644721985



Epoch 00274: cpa_metric reached. Module best state updated.

Epoch 00279: cpa_metric reached. Module best state updated.



disnt_basal = 1.022066428053114
disnt_after = 2.0
val_r2_mean = 0.9027782678604126
val_r2_var = 0.5861150423685709



Epoch 00289: cpa_metric reached. Module best state updated.



disnt_basal = 1.0207695396846548
disnt_after = 2.0
val_r2_mean = 0.9020123680432638
val_r2_var = 0.5978950659434

disnt_basal = 1.0219635118229569
disnt_after = 2.0
val_r2_mean = 0.9010432958602905
val_r2_var = 0.6019090612729391



Epoch 00304: cpa_metric reached. Module best state updated.



disnt_basal = 1.0206779360818339
disnt_after = 2.0
val_r2_mean = 0.9021538694699606
val_r2_var = 0.6117414832115173



Epoch 00319: cpa_metric reached. Module best state updated.



disnt_basal = 1.0168564584025979
disnt_after = 2.0
val_r2_mean = 0.9029973745346069
val_r2_var = 0.6175216833750407



Epoch 00324: cpa_metric reached. Module best state updated.

Epoch 00329: cpa_metric reached. Module best state updated.



disnt_basal = 1.0133534736165961
disnt_after = 2.0
val_r2_mean = 0.907687246799469
val_r2_var = 0.621389369169871



Epoch 00339: cpa_metric reached. Module best state updated.



disnt_basal = 1.0144801582350063
disnt_after = 2.0
val_r2_mean = 0.9084329803784689
val_r2_var = 0.6277651786804199

disnt_basal = 1.0169129193777113
disnt_after = 2.0
val_r2_mean = 0.9075831373532613
val_r2_var = 0.6199437777201334



Epoch 00354: cpa_metric reached. Module best state updated.



disnt_basal = 1.0152847220019137
disnt_after = 2.0
val_r2_mean = 0.9046849211057028
val_r2_var = 0.6364588936169943



Epoch 00364: cpa_metric reached. Module best state updated.

Epoch 00369: cpa_metric reached. Module best state updated.



disnt_basal = 1.0128239525594727
disnt_after = 2.0
val_r2_mean = 0.9065301815668741
val_r2_var = 0.6360171238581339



Epoch 00379: cpa_metric reached. Module best state updated.



disnt_basal = 1.0122781831943595
disnt_after = 2.0
val_r2_mean = 0.9099411765734354
val_r2_var = 0.6397326985994974



Epoch 00389: cpa_metric reached. Module best state updated.



disnt_basal = 1.0096176267566097
disnt_after = 2.0
val_r2_mean = 0.9073919256528219
val_r2_var = 0.645405093828837

disnt_basal = 1.0139571066405573
disnt_after = 2.0
val_r2_mean = 0.9006768862406412
val_r2_var = 0.6505118211110433



Epoch 00409: cpa_metric reached. Module best state updated.



disnt_basal = 1.00549554491887
disnt_after = 2.0
val_r2_mean = 0.9113807876904806
val_r2_var = 0.6510932445526123

disnt_basal = 1.0088301102500448
disnt_after = 2.0
val_r2_mean = 0.9128730098406473
val_r2_var = 0.6594099005063375

disnt_basal = 1.0082134709175747
disnt_after = 2.0
val_r2_mean = 0.912967840830485
val_r2_var = 0.6562577287356058



Epoch 00434: cpa_metric reached. Module best state updated.



disnt_basal = 1.0100928653796721
disnt_after = 2.0
val_r2_mean = 0.9124264319737753
val_r2_var = 0.6529348691304525



Epoch 00444: cpa_metric reached. Module best state updated.



disnt_basal = 1.005425684883802
disnt_after = 2.0
val_r2_mean = 0.914214034875234
val_r2_var = 0.6644651691118876

disnt_basal = 1.0081357069464616
disnt_after = 2.0
val_r2_mean = 0.9117849469184875
val_r2_var = 0.6715217034022013

disnt_basal = 1.0126526853509954
disnt_after = 2.0
val_r2_mean = 0.916390856107076
val_r2_var = 0.6691354115804037


100%|██████████| 3/3 [00:00<00:00, 58.83it/s]
100%|██████████| 1/1 [01:56<00:00, 116.20s/it]




  0%|          | 0/2 [00:00<?, ?it/s]

common var of direct change and single-cell data is:  494


100%|██████████| 14362/14362 [00:00<00:00, 41898.52it/s]
100%|██████████| 14362/14362 [00:00<00:00, 419789.92it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 170.42it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.0694425517908046
disnt_after = 1.274846490390486
val_r2_mean = 0.7312578658262889
val_r2_var = -0.6477146744728088



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.0799994002203113
disnt_after = 1.2645305525719057
val_r2_mean = 0.7810930709044139
val_r2_var = -0.37496981024742126



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.0613946097480351
disnt_after = 1.2605579582042317
val_r2_mean = 0.8030420343081157
val_r2_var = -0.15025739868481955



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.0159836900233963
disnt_after = 1.2585126441347103
val_r2_mean = 0.7845711608727772
val_r2_var = -0.0709395706653595



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.0126758231913813
disnt_after = 1.2573876263531436
val_r2_mean = 0.8404724498589834
val_r2_var = -0.059096972147623696



Epoch 00054: cpa_metric reached. Module best state updated.



disnt_basal = 1.0037486341619517
disnt_after = 1.252275821439687
val_r2_mean = 0.8351262708504995
val_r2_var = -0.04307033618291219

disnt_basal = 1.0047229380455178
disnt_after = 1.2564642564139685
val_r2_mean = 0.8023407955964407
val_r2_var = -0.011872212092081705



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.0027057415880583
disnt_after = 1.2589122342986447
val_r2_mean = 0.8141951262950897
val_r2_var = 0.0490286648273468



Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 0.9990744367399101
disnt_after = 1.2610386005641225
val_r2_mean = 0.8418309390544891
val_r2_var = 0.07316219806671143

disnt_basal = 0.9948269678483843
disnt_after = 1.2568683548146025
val_r2_mean = 0.8346656461556752
val_r2_var = 0.029072205225626625



Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.017829335016835
disnt_after = 1.2731007996632997
val_r2_mean = 0.8522086888551712
val_r2_var = 0.10852158069610596

disnt_basal = 1.0020058464392616
disnt_after = 1.257022054313076
val_r2_mean = 0.8231609960397085
val_r2_var = 0.09130229552586873

disnt_basal = 1.018511340278663
disnt_after = 1.2612565355902834
val_r2_mean = 0.8156916797161102
val_r2_var = 0.07555624842643738


100%|██████████| 8/8 [00:00<00:00, 82.49it/s]
 50%|█████     | 1/2 [04:25<04:25, 265.87s/it]

common var of direct change and single-cell data is:  441


100%|██████████| 15538/15538 [00:00<00:00, 106016.67it/s]
100%|██████████| 15538/15538 [00:00<00:00, 900650.85it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 375.77it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6067034702612197
disnt_after = 2.0
val_r2_mean = 0.924451728661855
val_r2_var = -0.42099299033482873



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.591637574511886
disnt_after = 1.9583333333333333
val_r2_mean = 0.9512549738089243
val_r2_var = -0.048510223627090454



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.581436523504678
disnt_after = 2.0
val_r2_mean = 0.962794045607249
val_r2_var = 0.15578149755795795



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.4790529980642808
disnt_after = 1.975
val_r2_mean = 0.9616301556428273
val_r2_var = 0.20164329806963602



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.3833205435868674
disnt_after = 1.975
val_r2_mean = 0.965000331401825
val_r2_var = 0.2536863287289937

disnt_basal = 1.4762289867445764
disnt_after = 1.9333333333333333
val_r2_mean = 0.9681967496871948
val_r2_var = 0.2769266366958618

disnt_basal = 1.5825009416631644
disnt_after = 1.9500000000000002
val_r2_mean = 0.9607940316200256
val_r2_var = 0.2851107021172842


100%|██████████| 8/8 [00:00<00:00, 108.20it/s]
100%|██████████| 2/2 [05:34<00:00, 167.04s/it]




100%|██████████| 4/4 [00:00<00:00, 32017.59it/s]


SOX2 can not be found in scPerturb pert_data
POU5F1 can not be found in scPerturb pert_data
KLF4 can not be found in scPerturb pert_data
MYC can not be found in scPerturb pert_data


100%|██████████| 1/1 [00:00<00:00, 9845.78it/s]

PTF1A can not be found in scPerturb pert_data





: 

# debug

In [30]:
# Debug scratch cell: a top-level copy of the first part of cpa_single_pert
# (data assembly, training, prediction) so that the intermediates
# (adata_train, model, ...) stay in the notebook namespace for inspection.
# It relies on variables left over from the main cell: pert_data, adata_split,
# adata_rna, common_var, pert, dataset, model_prefix, model_params and
# trainer_params.
cell_type = pert_data.filter_perturbation_list[0].split(' | ')[1]
# - get adata_pert and adata_ctrl
adata_pert = adata_split[adata_split.obs['perturbation_group']==pert+' | '+cell_type].copy()
adata_ctrl = adata_split[list(adata_pert.obs['control_barcode'])].copy()

# restrict both to the genes shared with adata_rna
adata_pert = adata_pert[:, common_var]
adata_ctrl = adata_ctrl[:, common_var]


# - get adata_rna_common
adata_rna_common = adata_rna[:, common_var]

# - generate adata_train to input to the CPA model
np_list, obs_list, pert_list, celltype_list = [], [], [], []
pert_list_2 = []  # not used below
# order matters: 0 = perturbed cells, 1 = target cells labelled 'stimulated',
# 2 = control cells, 3 = the same target cells labelled 'ctrl'
adata_list = [adata_pert, adata_rna_common, adata_ctrl, adata_rna_common]
for j, adata_ in enumerate(adata_list):
    if j in [0, 1]:
        pert_list.extend(['stimulated']*len(adata_))
    else:
        pert_list.extend(['ctrl']*len(adata_))

    if j in [0, 2]:
        celltype_list.extend([cell_type]*len(adata_))
    else:
        celltype_list.extend([dataset]*len(adata_))
    # the suffix keeps obs names unique (adata_rna_common appears twice)
    obs_list.extend([obs+f'_{j}' for obs in adata_.obs_names])
    
    # densify sparse matrices so np.vstack can stack them
    if not isinstance(adata_.X, np.ndarray):
        np_list.append(adata_.X.toarray())
    else:
        np_list.append(adata_.X)

adata_train = AnnData(X = np.vstack(np_list))
adata_train.obs_names = obs_list
adata_train.var_names = adata_pert.var_names

adata_train.obs['condition'] = pert_list
# adata_train.obs['condition_2'] = pert_list_2
adata_train.obs['cell_type'] = celltype_list

# - transform the adata_train.X to count
adata_train.obs['cov_cond'] = adata_train.obs['cell_type'] + '_' + adata_train.obs['condition']
# presumably X is log1p-transformed, so this undoes the transform - TODO confirm
adata_train.X = np.exp(adata_train.X)-1

if model_prefix == 'CPA_v1':
    # - add norm
    sc.pp.normalize_per_cell(adata_train, key_n_counts='n_counts_all')

# - initial model
cpa.CPA.setup_anndata(adata_train, 
                    perturbation_key='condition',
                    control_group='ctrl',
                    #   dosage_key='dose',
                    categorical_covariate_keys=['cell_type'],
                    is_count_data=True,
                    #   deg_uns_key='rank_genes_groups_cov',
                    deg_uns_cat_key='cov_cond',
                    max_comb_len=1,
                    )

# - set the train and validation for cpa
# -- get total obs_names of the pert
# everything except the target cells labelled 'stimulated' (those stay 'ood')
adata_train_new = adata_train[~((adata_train.obs["cell_type"] == dataset) &
                    (adata_train.obs["condition"] == "stimulated"))].copy()
# obs_df_split = adata_train_new.obs
obs_df_sub_idx = np.array(adata_train_new.obs.index)

# fixed seed so the shuffle (and therefore the split) is reproducible
np.random.seed(2024)
np.random.shuffle(obs_df_sub_idx)

# -- data split
# first 90% of the shuffled names -> train, the rest -> valid
split_point_1 = int(len(obs_df_sub_idx) * 0.9)
split_point_2 = int(len(obs_df_sub_idx) * (0.9+0.1))
train = obs_df_sub_idx[:split_point_1]
valid = obs_df_sub_idx[split_point_1:split_point_2]


# default label; rows not assigned to train/valid below keep 'ood'
adata_train.obs['split_key'] = 'ood'

# -- set the test row
adata_train.obs.loc[train,'split_key'] = 'train'
adata_train.obs.loc[valid,'split_key'] = 'valid'

# - initial the model and training   
model = cpa.CPA(adata=adata_train, 
                split_key='split_key',
                train_split='train',
                valid_split='valid',
                test_split='ood',
                **model_params,
            )

# if cell_line_bulk == 'PC3' and adata_train.shape[0] == 726:
#     batch_size = 512
# else:
#     batch_size = 500
batch_size = 500
model.train(max_epochs=2000,
            use_gpu=True, 
            batch_size=batch_size,
            plan_kwargs=trainer_params,
            early_stopping_patience=5,
            check_val_every_n_epoch=5,
            # save_path='../../datasets/',
            progress_bar_refresh_rate = 0
        )

# - predict result
# cpa_single_pert reads the prediction back from adata_train.obsm['CPA_pred']
model.predict(adata_train, batch_size=2048)



100%|██████████| 14362/14362 [00:00<00:00, 58012.37it/s]
100%|██████████| 14362/14362 [00:00<00:00, 506683.55it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 335.57it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.0694425517908046
disnt_after = 1.274846490390486
val_r2_mean = 0.7312578658262889
val_r2_var = -0.6477146744728088



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.0799994002203113
disnt_after = 1.2645305525719057
val_r2_mean = 0.7810930709044139
val_r2_var = -0.37496981024742126



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.0613946097480351
disnt_after = 1.2605579582042317
val_r2_mean = 0.8030420343081157
val_r2_var = -0.15025739868481955



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.0159836900233963
disnt_after = 1.2585126441347103
val_r2_mean = 0.7845711608727772
val_r2_var = -0.0709395706653595



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.0126758231913813
disnt_after = 1.2573876263531436
val_r2_mean = 0.8404724498589834
val_r2_var = -0.059096972147623696



Epoch 00054: cpa_metric reached. Module best state updated.



disnt_basal = 1.0037486341619517
disnt_after = 1.252275821439687
val_r2_mean = 0.8351262708504995
val_r2_var = -0.04307033618291219

disnt_basal = 1.0047229380455178
disnt_after = 1.2564642564139685
val_r2_mean = 0.8023407955964407
val_r2_var = -0.011872212092081705



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.0027057415880583
disnt_after = 1.2589122342986447
val_r2_mean = 0.8141951262950897
val_r2_var = 0.0490286648273468



Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 0.9990744367399101
disnt_after = 1.2610386005641225
val_r2_mean = 0.8418309390544891
val_r2_var = 0.07316219806671143

disnt_basal = 0.9948269678483843
disnt_after = 1.2568683548146025
val_r2_mean = 0.8346656461556752
val_r2_var = 0.029072205225626625



Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.017829335016835
disnt_after = 1.2731007996632997
val_r2_mean = 0.8522086888551712
val_r2_var = 0.10852158069610596

disnt_basal = 1.0020058464392616
disnt_after = 1.257022054313076
val_r2_mean = 0.8231609960397085
val_r2_var = 0.09130229552586873

disnt_basal = 1.018511340278663
disnt_after = 1.2612565355902834
val_r2_mean = 0.8156916797161102
val_r2_var = 0.07555624842643738


100%|██████████| 8/8 [00:00<00:00, 11.14it/s]


In [31]:
# Display the AnnData summary of adata_train (its obs / uns / obsm keys) for inspection.
adata_train

AnnData object with n_obs × n_vars = 14362 × 494
    obs: 'condition', 'cell_type', 'cov_cond', 'CPA_dose_val', 'CPA_cat', 'CPA_ctrl', '_scvi_condition', '_scvi_cell_type', '_scvi_CPA_cat', 'split_key'
    uns: '_scvi_uuid', '_scvi_manager_uuid'
    obsm: 'perts', 'perts_doses', 'CPA_pred'

In [32]:
# - collect the CPA predictions of the `dataset` rows for each condition label,
#   apply log1p to them, and keep the stimulated-minus-ctrl difference in x_pred
_log_pred = {}
for _label in ('stimulated', 'ctrl'):
    # rows are selected through their 'cov_cond' tag ("<cell_type>_<condition>")
    cat = dataset + '_' + _label
    cat_adata = adata_train[adata_train.obs['cov_cond'] == cat].copy()
    _log_pred[_label] = np.log1p(cat_adata.obsm['CPA_pred'])

x_pred_sti = _log_pred['stimulated']
x_pred_ctrl = _log_pred['ctrl']

# - per-cell, per-gene difference between the two log1p predictions
x_pred = x_pred_sti - x_pred_ctrl

In [33]:
# Check the dimensions of the x_pred matrix.
x_pred.shape

(7073, 494)

In [36]:
# Wrap the predicted difference x_pred into AnnData objects, depending on model_mode:
#   'subset'   - keep only the common_var genes (adata_rna_common layout)
#   'whole'    - adata_rna layout, columns of x_pred selected through common_idx
#   'zero_pad' - adata_rna layout, genes missing from common_var are left at 0
# In every mode adata_ctrl is a copy of the reference AnnData and adata_pert is a
# copy of adata_ctrl whose X is replaced.
if model_mode == 'subset':
    # - get pert_gene_rank_dict
    adata_ctrl = adata_rna_common.copy()
    adata_pert = adata_ctrl.copy()
    adata_pert.X = x_pred
    
elif model_mode == 'whole':
    # - get pert_gene_rank_dict
    adata_ctrl = adata_rna.copy()
    adata_pert = adata_ctrl.copy()
    adata_pert.X = x_pred[:, common_idx]

elif model_mode == 'zero_pad':
    adata_ctrl = adata_rna.copy()
    adata_pert = adata_ctrl.copy()

    # gene -> column position in x_pred, built once instead of rebuilding
    # list(common_var) and scanning it for every gene; setdefault keeps the
    # first occurrence, matching what list.index() returned
    _col_of_gene = {}
    for _k, _gene in enumerate(common_var):
        _col_of_gene.setdefault(_gene, _k)

    # fill a plain array first and assign it once, rather than writing in place
    # through the adata_pert.X accessor
    _padded = np.zeros(adata_pert.X.shape)
    for i, gene in enumerate(adata_pert.var_names):
        _k = _col_of_gene.get(gene)
        if _k is not None:
            _padded[:, i] = x_pred[:, _k]
    adata_pert.X = _padded

else:
    raise ValueError(
        f"unknown model_mode {model_mode!r}; expected 'subset', 'whole' or 'zero_pad'"
    )

In [38]:
# Peek at the values currently stored in adata_pert.X.
adata_pert.X

array([[-0.22718215,  0.        , -0.01641595, ...,  0.05895928,
        -0.05172372,  0.09957126],
       [-0.26976117,  0.        ,  0.01202253, ...,  0.10973807,
         0.67712152,  0.40142328],
       [-0.23892549,  0.        , -0.02298968, ...,  0.01848671,
         0.70593727, -0.46685663],
       ...,
       [ 0.01390265,  0.        ,  0.01464741, ..., -0.07788485,
        -0.98207021,  0.16043766],
       [-0.45531112,  0.        , -0.01492572, ...,  0.27272016,
        -0.34711146,  0.00320973],
       [-0.37197357,  0.        ,  0.01045347, ..., -0.09800789,
         0.61474955,  0.02261361]])

In [41]:
# Show the adata_rna summary next to the number of entries in common_var.
adata_rna, len(common_var)

(AnnData object with n_obs × n_vars = 7073 × 1206
     obs: 'cid', 'seq_tech', 'donor_ID', 'donor_gender', 'donor_age', 'donor_status', 'original_name', 'organ', 'region', 'subregion', 'sample_status', 'treatment', 'ethnicity', 'cell_type', 'cell_id', 'study_id', 'age_bin', 'celltype', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt'
     var: 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
     uns: 'hvg',
 494)

In [27]:
# List the var_names index of pert_data.adata_split.
pert_data.adata_split.var_names

Index(['A1BG', 'AAAS', 'AAGAB', 'AAK1', 'AAMDC', 'AARS', 'AARS2', 'AASDH',
       'AASDHPPT', 'AATF',
       ...
       'ZSCAN32', 'ZSCAN9', 'ZSWIM6', 'ZSWIM7', 'ZSWIM8', 'ZUP1', 'ZW10',
       'ZWINT', 'ZYX', 'ZZEF1'],
      dtype='object', name='gene_name', length=5642)

In [16]:
# - the cell type is the part after ' | ' in the first filtered perturbation label
cell_type = pert_data.filter_perturbation_list[0].split(' | ')[1]

# - perturbed cells: rows whose 'perturbation_group' equals "<pert> | <cell_type>"
_group_label = pert + ' | ' + cell_type
_is_pert = adata_split.obs['perturbation_group'] == _group_label
adata_pert = adata_split[_is_pert].copy()

# - control cells: looked up by the 'control_barcode' stored on each perturbed cell
_ctrl_barcodes = list(adata_pert.obs['control_barcode'])
adata_ctrl = adata_split[_ctrl_barcodes].copy()

# - keep only the common_var genes in both objects
adata_pert, adata_ctrl = adata_pert[:, common_var], adata_ctrl[:, common_var]

In [14]:
# Display the AnnData summary of adata_pert.
adata_pert

View of AnnData object with n_obs × n_vars = 108 × 5642
    obs: 'batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id', 'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex', 'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type', 'ncounts', 'ngenes', 'nperts', 'percent_ribo', 'perturbation_new', 'perturbation_type_new', 'nperts_new', 'celltype', 'celltype_new', 'sgRNA_new', 'perturbation_group', 'data_split', 'retain', 'n_genes', 'control_barcode', 'sgRNA_ID', 'pert_sgRNA'
    var: 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano', 'ensembl_id', 'ncounts', 'ncells'
    uns: 'rank_genes_groups'

In [15]:
# Number of entries in common_var.
len(common_var)

494

In [43]:
# Write adata_pert to <tmp_dir>/CPA/<pert>/adata_pert_minus.h5ad, creating the
# directory when it does not exist yet.
# Earlier result root, kept for reference:
# tmp_dir = f'/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/real_case/data/{dataset}'
tmp_dir = f'/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark_202410/real_case/result/{dataset}'
save_prefix = f'CPA/{pert}' # original author note: use result of K562 to do the direct transfer

_out_dir = os.path.join(tmp_dir, save_prefix)
os.makedirs(_out_dir, exist_ok=True)

# Alternative file name used previously:
# adata_pert.write(os.path.join(tmp_dir, save_prefix, 'adata_pert.h5ad'))
adata_pert.write(os.path.join(_out_dir, 'adata_pert_minus.h5ad'))