# SCGEN:  Perturbation Prediction

In [1]:
import sys
#if branch is stable, will install via pypi, else will install from source
branch = "stable"
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB and branch == "stable":
    !pip install --quiet scgen[tutorials]
elif IN_COLAB and branch != "stable":
    !pip install --quiet --upgrade jsonschema
    !pip install --quiet git+https://github.com/theislab/scgen@$branch#egg=scgen[tutorials]

In [2]:
import logging
import scanpy as sc
import scgen
import numpy as np
import torch
from gears import PertData

In [3]:
print("torch:", torch.__version__)
print("scgen:", scgen.__version__)
print("numpy:", np.__version__)

torch: 2.9.0+cu128
scgen: 2.1.1
numpy: 2.3.4


### Loading Data

In [4]:
pert_data = PertData("data/")
pert_data.load(data_name="norman")

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['RHOXF2BB+ctrl' 'LYL1+IER5L' 'ctrl+IER5L' 'KIAA1804+ctrl' 'IER5L+ctrl'
 'RHOXF2BB+ZBTB25' 'RHOXF2BB+SET']
Local copy of pyg dataset is detected. Loading...
Done!


In [5]:
pert_adata = pert_data.adata # 전체 데이터

# train/test 나누기
train_idx = np.load("norman_simulation_seed1_train_idx.npy")
test_idx  = np.load("norman_simulation_seed1_test_idx.npy")

pert_adata.obs['split'] = 'unknown'
pert_adata.obs.iloc[train_idx, pert_adata.obs.columns.get_loc('split')] = 'train'
pert_adata.obs.iloc[test_idx, pert_adata.obs.columns.get_loc('split')] = 'test'

In [6]:
stimulated = "SAMD1+ZBTB1"
control = "ctrl"
cell_type = "A549"

### Preprocessing Data

In [7]:
train_adata = pert_adata[pert_adata.obs['split'] == 'train']
val_adata = pert_adata[pert_adata.obs['split'] == 'val']
test_adata = pert_adata[pert_adata.obs['split'] == 'test']

In [8]:
train_new = train_adata.concatenate(val_adata, index_unique=None) # train, val 합치기

In [9]:
idx_list = []
for cond, df in test_adata.obs.groupby('condition'):
    # test set에 있는 perturbation에 대해서 각 perturbation 당 하나씩 뽑기
    chosen_idx = np.random.choice(df.index, size=1, replace=False)[0] 
    idx_list.append(chosen_idx)
subset_adata = test_adata[idx_list].copy()

# train set에 뽑은 데이터 추가
train_new = train_new.concatenate(subset_adata, index_unique=None)

In [10]:
train_new = train_new.copy()

## Creating and Saving the model¶

In [11]:
model = scgen.SCGEN.load("./saved_models/scgen_norman_prediction.pt", train_new)
model.adata

[34mINFO    [0m File .[35m/saved_models/scgen_norman_prediction.pt/[0m[95mmodel.pt[0m already downloaded                                


AnnData object with n_obs × n_vars = 49956 × 5045
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'split', 'batch', '_scvi_batch', '_scvi_labels'
    var: 'gene_name'
    uns: '_scvi_uuid', '_scvi_manager_uuid'
    layers: 'counts'

## Prediction

After training the model you can pass the adata of the cells you want to perturb. Here we pass unperturbed CD4T cells


Here the 'adata' contains the cells that you want estimate the perturbation based on them. we set "ctrl" to our control labels and "stim" to our stimulated labels. If you apply it in another context just set "ctrl" :"your_control_label" and "stim":"your_stimulated_label". the returned value is a numpy matrix of our predicted cells and the second one is the difference vector between our conditions which might become useful later.

In [12]:
pert_list = test_adata.obs['condition'].unique().tolist()
pred_dict = {}

for gene in pert_list:
    pred, delta = model.predict(
        ctrl_key="ctrl",
        stim_key=gene,
        #celltype_to_predict="A549"
        adata_to_predict=test_adata,
    )
    pred.obs["condition"] = gene
    pred_dict[gene] = pred

[34mINFO    [0m Received view of anndata, making copy.                                                                    
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             
[34mINFO    [0m Received view of anndata, making copy.                                                                    
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             
[34mINFO    [0m Received view of anndata, making copy.                                                                    
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             
[34mINFO    [0m Received view of anndata, making copy.                                                                    
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


In the previous block, the difference between conditions is by default computed using all cells (obs_key="all"). However, some times you might have a rough idea that which groups (e.g. cell types) are close to your cell type of interest. This might give you more accurate predictions. For example, we can restrict the delta computation only to CD8T and NK cells. We provide dictionary in form of obs_key={"cell_type": ["CD8T", "NK"]} which is telling the model to look at "cell_type" labels in adata (here: train_new) and only compute the delta vector based on "CD8T" and "NK" cells :

pred, delta = scg.predict(adata=train_new, adata_to_predict=unperturbed_cd4t, conditions={"ctrl": "control", "stim": "stimulated"}, cell_type_key="cell_type", condition_key="condition", obs_key={"cell_type": ["CD8T", "NK"]})`

## Evaluation of the predcition¶


#### Extracting both control and real stimulated CD4T cells from our dataset

In [13]:
ctrl_adata = pert_adata[pert_adata.obs['condition'] == control]
stim_adata = test_adata[test_adata.obs['condition'] != control]

In [14]:
ctrl_adata

View of AnnData object with n_obs × n_vars = 7353 × 5045
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'split'
    var: 'gene_name'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20'
    layers: 'counts'

In [15]:
stim_adata

View of AnnData object with n_obs × n_vars = 28754 × 5045
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'split', '_scvi_batch', '_scvi_labels'
    var: 'gene_name'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', '_scvi_uuid', '_scvi_manager_uuid'
    layers: 'counts'

Merging predicted cells with real ones

In [16]:
eval_adata = ctrl_adata.concatenate(stim_adata, pred)

In [17]:
eval_adata.obs

Unnamed: 0_level_0,condition,cell_type,dose_val,control,condition_name,split,_scvi_batch,_scvi_labels,batch
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
AAACCTGCACGAAGCA-1-0,ctrl,A549,1,1,A549_ctrl_1,train,,,0
AAACCTGGTATAATGG-1-0,ctrl,A549,1,1,A549_ctrl_1,train,,,0
AAACCTGTCCGATATG-1-0,ctrl,A549,1,1,A549_ctrl_1,train,,,0
AAACGGGCAATGGACG-1-0,ctrl,A549,1,1,A549_ctrl_1,train,,,0
AAAGATGAGATGAGAG-1-0,ctrl,A549,1,1,A549_ctrl_1,train,,,0
...,...,...,...,...,...,...,...,...,...
TTTGTCAGTAGCTTGT-8-2,C3orf72+FOXL2,A549,1+1,0,A549_FOXL2+MEIS1_1+1,test,98.0,0.0,2
TTTGTCAGTAGGCATG-8-2,C3orf72+FOXL2,A549,1+1,0,A549_COL2A1+ctrl_1+1,test,54.0,0.0,2
TTTGTCAGTCACTTCC-8-2,C3orf72+FOXL2,A549,1+1,0,A549_ETS2+CEBPE_1+1,test,66.0,0.0,2
TTTGTCATCCACTCCA-8-2,C3orf72+FOXL2,A549,1+1,0,A549_CELF2+ctrl_1+1,test,45.0,0.0,2


## Mean correlation plot¶

You can also visualize your mean gene expression of your predicted cells vs control cells while highlighting your genes of interest (here top 10 differentially expressed genes)

In [20]:
pred_cell = pert_adata[pert_adata.obs["split"] == 'test']

In [None]:
sc.tl.rank_genes_groups(pred_cell, groupby="split", method="wilcoxon")
diff_genes = pred_cell.uns["rank_genes_groups_cov_all"]["names"][pert_list[0]]
print(diff_genes)

In [None]:
r2_value = model.reg_mean_plot(
    eval_adata,
    axis_keys={"x": "pred", "y": stimulated},
    gene_list=diff_genes[:10],
    labels={"x": "predicted", "y": "ground truth"},
    path_to_save="./reg_mean1.pdf",
    show=True,
    legend=False
)

You can also pass a list of differentially epxressed genes to compute correlation based on them

In [None]:
r2_value = model.reg_mean_plot(
    eval_adata,
    axis_keys={"x": "pred", "y": stimulated},
    gene_list=diff_genes[:10],
    top_100_genes= diff_genes,
    labels={"x": "predicted","y": "ground truth"},
    path_to_save="./reg_mean1.pdf",
    show=True,
    legend=False
)

Let's go deeper and compare the distribution of "ISG15", the top DEG between stimulated and control CD4T cells between predcited and real cells