In [None]:
import sys

from sklearn.metrics import r2_score
import numpy as np

import os
# os.chdir('/home/mohsen/projects/cpa/')
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import cpa
import scanpy as sc

In [None]:
# Render scanpy figures at 100 dpi.
sc.settings.set_figure_params(dpi=100)

# Local path of the preprocessed dataset read by the loading cell.
data_path = './combo_sciplex_prep_hvg_filtered.h5ad'

In [None]:
# Read the dataset from `data_path`; if that fails, download it and read the
# downloaded copy instead.
try:
    adata = sc.read(data_path)
except Exception:
    # `Exception` (rather than a bare `except:`) keeps the download fallback
    # for any read failure while letting KeyboardInterrupt / SystemExit
    # propagate, so interrupting the cell no longer triggers a download.
    import gdown

    data_path = 'combo_sciplex_prep_hvg_filtered.h5ad'
    # Write to an explicit output path so the read below does not depend on
    # the file name reported by the server.
    gdown.download(
        'https://drive.google.com/uc?export=download&id=1RRV0_qYKGTvD3oCklKfoZQFYqKJy4l6t',
        output=data_path,
    )
    adata = sc.read(data_path)

adata

In [None]:
# Show how many cells carry each label of the split column. The result is
# printed explicitly: as a bare expression in the middle of the cell it would
# be evaluated and silently discarded.
print(adata.obs['split_1ct_MEC'].value_counts())

# Use the 'counts' layer as the model input (setup below declares
# is_count_data=True).
adata.X = adata.layers['counts'].copy()

# Register the AnnData fields CPA should read: perturbation / dosage columns,
# the control compound, the categorical covariate and the DEG lookup in .uns.
cpa.CPA.setup_anndata(
    adata,
    perturbation_key='condition_ID',
    dosage_key='log_dose',
    control_group='CHEMBL504',
    batch_key=None,
    is_count_data=True,
    categorical_covariate_keys=['cell_type'],
    deg_uns_key='rank_genes_groups_cov',
    deg_uns_cat_key='cov_drug_dose',
    max_comb_len=2,
)

In [None]:
# Hyperparameters unpacked into the cpa.CPA constructor (**ae_hparams).
ae_hparams = dict(
    n_latent=1536,
    recon_loss="nb",
    doser_type="logsigm",
    n_hidden_encoder=512,
    n_layers_encoder=3,
    n_hidden_decoder=512,
    n_layers_decoder=3,
    use_batch_norm_encoder=True,
    use_layer_norm_encoder=False,
    use_batch_norm_decoder=True,
    use_layer_norm_decoder=False,
    dropout_rate_encoder=0.1,
    dropout_rate_decoder=0.1,
    variational=False,
    seed=434,
)

# Settings handed to model.train through its plan_kwargs argument.
trainer_params = dict(
    n_epochs_kl_warmup=None,
    n_epochs_pretrain_ae=30,
    n_epochs_adv_warmup=50,
    n_epochs_mixup_warmup=3,
    mixup_alpha=0.1,
    adv_steps=2,
    n_hidden_adv=64,
    n_layers_adv=2,
    use_batch_norm_adv=True,
    use_layer_norm_adv=False,
    dropout_rate_adv=0.3,
    reg_adv=20.0,
    pen_adv=20.0,
    lr=0.0003,
    wd=4e-07,
    adv_lr=0.0003,
    adv_wd=4e-07,
    adv_loss="cce",
    doser_lr=0.0003,
    doser_wd=4e-07,
    do_clip_grad=False,
    gradient_clip_value=1.0,
    step_size_lr=45,
)

In [None]:
# Important: index genes by the 'symbol-0' column (gene name instead of
# Ensembl ID) so they can be looked up in the embedding dictionary below.
adata.var_names = adata.var['symbol-0'].values

import pickle

# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only point this at an embedding file from a trusted source.
with open("/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/ensem_emb_gpt3.5all.pickle", "rb") as embedding_file:
    GPT_3_5_gene_embeddings = pickle.load(embedding_file)

gene_names = list(adata.var.index)
EMBED_DIM = 1536  # embedding dim from GPT-3.5

# One row per gene; genes without an embedding keep an all-zero row.
lookup_embed = np.zeros(shape=(len(gene_names), EMBED_DIM))
count_missing = 0
for row, symbol in enumerate(gene_names):
    if symbol not in GPT_3_5_gene_embeddings:
        count_missing += 1
        continue
    lookup_embed[row, :] = GPT_3_5_gene_embeddings[symbol].flatten()

# Project each cell's expression onto the gene embeddings, scaled by the
# number of genes.
genePT_w_emebed = (adata.X @ lookup_embed) / len(gene_names)
print(f"Unable to match {count_missing} out of {len(gene_names)} genes in the GenePT-w embedding")

In [None]:
# Column in adata.obs holding the split labels, plus the label used for each
# of the train / validation / test partitions.
split_config = dict(
    split_key='split_1ct_MEC',
    train_split='train',
    valid_split='valid',
    test_split='ood',
)

# Instantiate CPA with the per-gene embedding matrix built above and the
# architecture hyperparameters.
model = cpa.CPA(
    adata=adata,
    **split_config,
    gene_embeddings=lookup_embed,  # add the embeddings
    use_gene_emb=True,
    **ae_hparams,
)

In [None]:
# NOTE(review): max_epochs (20) is lower than n_epochs_pretrain_ae (30) and
# n_epochs_adv_warmup (50) in trainer_params; confirm this short run is
# intentional.
train_options = dict(
    max_epochs=20,
    use_gpu=True,
    batch_size=128,
    plan_kwargs=trainer_params,
    early_stopping_patience=10,
    check_val_every_n_epoch=5,
    save_path='./cpa_out_gpt35/',
)

# Fit the model with the options above.
model.train(**train_options)

In [None]:
# to load model
# model = cpa.CPA.load(dir_path='/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/cpa_out_gemb_new/',
#                      adata=adata, use_gpu=True)

In [None]:
# Plot the history of the trained model using cpa's plotting helper.
cpa.pl.plot_history(model)

In [None]:
## Latent space UMAP visualization
# Compute latent representations for all cells in batches of 1024.
latent_outputs = model.get_latent_representation(adata, batch_size=1024)

sc.settings.verbosity = 3

# Pull the two representations of interest out of the returned mapping.
latent_basal_adata, latent_adata = (
    latent_outputs['latent_basal'],
    latent_outputs['latent_after'],
)

# Neighbour graph -> UMAP -> plot coloured by condition, for the basal latent.
sc.pp.neighbors(latent_basal_adata)
sc.tl.umap(latent_basal_adata)
sc.pl.umap(
    latent_basal_adata,
    color=['condition_ID'],
    frameon=False,
    wspace=0.2,
)

In [None]:

# Same neighbour-graph / UMAP pipeline, applied to the 'latent_after' output.
sc.pp.neighbors(latent_adata)
sc.tl.umap(latent_adata)

# Colour the embedding by condition.
sc.pl.umap(latent_adata, color=['condition_ID'], frameon=False, wspace=0.2)

In [None]:
# Save both latent AnnData objects (with their UMAP results) to h5ad files in
# the current working directory.
latent_basal_adata.write_h5ad("cpa_geneptnew_example_basal.h5ad")

latent_adata.write_h5ad("cpa_geneptnew_example_perturb.h5ad")

# Make prediction

In [None]:
# Run prediction on all cells; the return value is not used here. Later cells
# read adata.obsm['CPA_pred'], so this call presumably stores its output
# there. TODO confirm.
model.predict(adata, batch_size=1024)
# Switch var_names from gene symbols to the 'ensembl_id-0' column.
# NOTE(review): presumably needed so var_names match the gene identifiers in
# adata.uns['rank_genes_groups_cov']; confirm.
adata.var_names = adata.var['ensembl_id-0'].values

In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score
from collections import defaultdict
from tqdm import tqdm

# Score predictions per perturbation category: R^2 between observed and
# predicted per-gene mean / variance (log1p scale), both raw and relative to
# control, restricted to the top-N DEGs of that category.
n_top_degs = [10, 20, 50, None]  # None means all genes

results = defaultdict(list)

# The control cells are identical for every category, so densify and
# log-transform their counts once instead of once per loop iteration.
ctrl_adata = adata[adata.obs['condition_ID'] == 'CHEMBL504'].copy()
x_ctrl = np.log1p(ctrl_adata.layers['counts'].toarray())

for cat in tqdm(adata.obs['cov_drug_dose'].unique()):
    # Categories whose label mentions the control compound are not scored.
    if 'CHEMBL504' in cat:
        continue

    cat_adata = adata[adata.obs['cov_drug_dose'] == cat].copy()
    deg_list = adata.uns['rank_genes_groups_cov'][f'{cat}']

    # The label is split into exactly three '_'-separated fields; this does
    # not depend on the DEG cutoff, so do it once per category.
    cov, cond, dose = cat.split('_')

    x_true = np.log1p(cat_adata.layers['counts'].toarray())
    x_pred = np.log1p(cat_adata.obsm['CPA_pred'])

    for n_top_deg in n_top_degs:
        if n_top_deg is None:
            degs = np.arange(adata.n_vars)
            deg_label = 'all'
        else:
            # Column indices of the top-N DEGs within adata.var_names.
            degs = np.where(np.isin(adata.var_names, deg_list[:n_top_deg]))[0]
            deg_label = n_top_deg

        x_true_deg = x_true[:, degs]
        x_pred_deg = x_pred[:, degs]
        x_ctrl_deg = x_ctrl[:, degs]

        r2_mean_deg = r2_score(x_true_deg.mean(0), x_pred_deg.mean(0))
        r2_var_deg = r2_score(x_true_deg.var(0), x_pred_deg.var(0))

        # Same statistics after subtracting the control mean / variance.
        r2_mean_lfc_deg = r2_score(x_true_deg.mean(0) - x_ctrl_deg.mean(0), x_pred_deg.mean(0) - x_ctrl_deg.mean(0))
        r2_var_lfc_deg = r2_score(x_true_deg.var(0) - x_ctrl_deg.var(0), x_pred_deg.var(0) - x_ctrl_deg.var(0))

        results['cell_type'].append(cov)
        results['condition'].append(cond)
        results['dose'].append(dose)
        results['n_top_deg'].append(deg_label)
        results['r2_mean_deg'].append(r2_mean_deg)
        results['r2_var_deg'].append(r2_var_deg)
        results['r2_mean_lfc_deg'].append(r2_mean_lfc_deg)
        results['r2_var_lfc_deg'].append(r2_var_lfc_deg)

df = pd.DataFrame(results)

In [None]:
# Show the rows of the results table computed with the top-20 DEG cutoff.
df[df['n_top_deg'] == 20]

In [None]:
# For each category that does not mention the control compound, call
# cpa.pl.mean_plot on log1p-transformed observed counts vs. predictions,
# passing the category's first 20 DEGs.
for cat in adata.obs["cov_drug_dose"].unique():
    if "CHEMBL504" in cat:
        continue

    cat_adata = adata[adata.obs["cov_drug_dose"] == cat].copy()

    # `.toarray()` replaces the deprecated sparse-matrix `.A` shorthand, which
    # recent SciPy releases no longer provide; it also matches the metrics cell.
    cat_adata.X = np.log1p(cat_adata.layers["counts"].toarray())
    cat_adata.obsm["CPA_pred"] = np.log1p(cat_adata.obsm["CPA_pred"])

    deg_list = adata.uns["rank_genes_groups_cov"][f'{cat}'][:20]

    print(cat, f"{cat_adata.shape}")
    cpa.pl.mean_plot(
        cat_adata,
        pred_obsm_key="CPA_pred",
        path_to_save=None,
        deg_list=deg_list,
        # gene_list=deg_list[:5],
        show=True,
        verbose=True,
    )

# Display drug information

In [None]:
# Keys telling the API where the DEG lists and perturbation categories live,
# and which compound is the control.
api_options = dict(
    de_genes_uns_key='rank_genes_groups_cov',
    pert_category_key='cov_drug_dose',
    control_group='CHEMBL504',
)

# Wrap the dataset and trained model in cpa's ComPertAPI helper.
cpa_api = cpa.ComPertAPI(adata, model, **api_options)

In [None]:
# Build the plotting helper on top of the API object; fileprefix=None is
# passed. NOTE(review): presumably this means figures are not written to
# disk; confirm.
cpa_plots = cpa.pl.CompertVisuals(cpa_api, fileprefix=None)

In [None]:
# Fetch the perturbation embeddings from the API and display their shape.
drug_adata = cpa_api.get_pert_embeddings()
drug_adata.shape

In [None]:
# Plot the perturbation embedding matrix, titled 'Drugs'.
cpa_plots.plot_latent_embeddings(drug_adata.X, kind='perturbations', titlename='Drugs')