This hands-on exercise demonstrates how to train and use the GEARS model (Roohani Y, et al. Nature Biotechnology, 2024) with a gene perturbation dataset. You can plug an LCM into the perturbation task by replacing GEARS' initial gene embedding with the LCM's gene representation. 

Due to limited computational resources, we won't call the LCMs here to get cell representations. If you are interested, you can try it on your own after class.

We will first specifically show the content of the perturbation data, then construct the GEARS model and train it, and finally use the trained model for perturbation prediction and evaluate the effectiveness of the prediction.

# Import necessary libraries

In [None]:
import os
import torch
import pickle

import sys
sys.path.append('/kaggle/input/gears-tutorial/pytorch/default/1/GEARS')
from gears import GEARS
from gears.inference import evaluate
from gears.model import GEARS_Model
from gears.utils import get_similarity_network, GeneSimNetwork
from copy import deepcopy
                  
from sklearn.metrics import r2_score
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import mean_absolute_error as mae

In [None]:
sys.path.append("/kaggle/input/pert_dataset/pytorch/default/1/Pert_Data/")
import v1
from v1.utils import *
from v1.dataloader import *

# Load perturbation dataset

**Load and process data**

In [None]:
# - load pert_data and preprocess
# The calls below run in sequence on the same `pert_data` object; each step's
# purpose is given by its trailing comment.
pert_data = Byte_Pert_Data(data_dir='/kaggle/input/example-data-pert/',prefix='NormanWeissman2019_filtered',) # NormanWeissman2019_filtered or XuCao2023
pert_data.read_files()
pert_data.filter_perturbation()   # filter out perturbations with too few cells
pert_data.get_and_process_adata(var_num=1000)    # process the data and obtain the highly variable genes
pert_data.data_split(split_type=1)   # split the data into train and test set
pert_data.set_control_barcode()   # set the control cell's barcode for each perturbed cell
pert_data.filter_sgRNA()  # for each perturbation, get the number of sgRNAs
pert_data.get_de_genes()  # calculate the DE genes for each perturbed cell; used for evaluation

In [None]:
# Fix the random seed before building the GEARS-format data and dataloaders
# (presumably for reproducibility; `fix_seed` comes from the wildcard imports).
fix_seed(2024)

# - get GO genes; special set for GEARS; these are used for constructing the GO graph of GEARS
pert_data.get_gene2go()   # get gene2go dict: {'gene1': [go1, go2, ...], 'gene2': [go1, go3, ...], ...}
pert_data.set_pert_genes()  # get the list of genes that can be perturbed to be included in perturbation graph

# - transform dataset into the format required by GEARS
#   (built for the 'train', 'test' and 'val' splits, with add_control disabled)
pert_data.get_Data_gears(num_de_genes = pert_data.num_de_genes,
                        dataset_name = ['train', 'test', 'val'],
                        add_control = False)
# - add necessary elements for GEARS
pert_data.modify_gears()

# - get one dataloader per split; `testloader` is used later for evaluation
trainloader, testloader, valloader = pert_data.get_dataloader(mode='all')

**Dataset details**

This dataset has 96852 cells and each of them has 1167 genes. It has 170 different perturbations. We have pre-processed the data and calculated the differential genes. These differential genes can be considered as significantly changed genes after perturbation and will be used to assess the effect of prediction.

In [None]:
# Display the processed dataset object held by `pert_data`.
pert_data.adata

In [None]:
# Unique perturbation labels among the cells assigned to the 'test' split.
pert_data.adata[pert_data.adata.obs['data_split']=='test'].obs['perturbation_new'].unique()

In [None]:
# Unique perturbation labels among the cells assigned to the 'train' split.
pert_data.adata[pert_data.adata.obs['data_split']=='train'].obs['perturbation_new'].unique()

In [None]:
# In the perturbation dataset, each cell has one (or more than one) perturbed gene.
# Show the perturbation label of the first five cells.
pert_data.adata.obs['perturbation_new'][:5]

In [None]:
# Each cell is also paired with a control-group cell, representing the cell before perturbation.
# Show the control barcode of the first five cells.
pert_data.adata.obs['control_barcode'][:5]

In [None]:
# We obtain the differentially expressed genes (DEGs) for each perturbation.
# Here we chose the top 20 non-zero expression DEGs for each perturbation from the adata.uns['rank_genes_groups_cov_all'].
# Print the first five entries of that mapping.
print( dict(list(pert_data.adata.uns['top_non_zero_de_20'].items())[:5]) )

# Initialize the GEARS model

Set GEARS model parameters

In [None]:
# - init gears model
# The expression matrix is converted to a dense array before it is handed to
# GEARS. The conversion is guarded so this cell can be re-run: after the first
# run X is already dense and no longer has a `.toarray()` method.
if hasattr(pert_data.adata_split.X, 'toarray'):
    pert_data.adata_split.X = pert_data.adata_split.X.toarray()

# Use the first GPU when one is available and fall back to CPU otherwise, so
# the notebook also runs on CPU-only runtimes.
gears_model = GEARS(pert_data,
                    device = 'cuda:0' if torch.cuda.is_available() else 'cpu',
                    weight_bias_track = False,
                    proj_name = 'pertnet',
                    exp_name = 'pertnet')

# - set model configuration
# The four graph entries (G_go, G_go_weight, G_coexpress, G_coexpress_weight)
# start as None and are filled in when the graphs are constructed.
# device / num_genes / num_perts are copied from the GEARS wrapper so the
# model is built with matching values.
gears_model.config = {'hidden_size': 64,
                'num_go_gnn_layers' : 1,
                'num_gene_gnn_layers' : 1,
                'decoder_hidden_size' : 16,
                'num_similar_genes_go_graph' : 20,
                'num_similar_genes_co_express_graph' : 20,
                'coexpress_threshold': 0.4,
                'uncertainty' : False,
                'uncertainty_reg' : 1,
                'direction_lambda' : 1e-1,
                'G_go': None,
                'G_go_weight': None,
                'G_coexpress': None,
                'G_coexpress_weight': None,
                'device': gears_model.device,
                'num_genes': gears_model.num_genes,
                'num_perts': gears_model.num_perts,
                'no_perturb': False
                }

Construct the co-expression graph and the GO graph. These two graphs are used to build the graph neural network.

In [None]:
# - Set the gene co-expression network (green graph)
# Only computed when no precomputed graph was supplied in the config.
if gears_model.config['G_coexpress'] is None:
    ## calculating co-expression similarity graph
    edge_list = get_similarity_network(network_type='co-express',
                                        adata=gears_model.adata,
                                        threshold=gears_model.config['coexpress_threshold'],
                                        k=gears_model.config['num_similar_genes_co_express_graph'],
                                        data_path=gears_model.data_path,
                                        data_name=gears_model.dataset_name,
                                        split=gears_model.split, seed=gears_model.seed,
                                        train_gene_set_size=gears_model.train_gene_set_size,
                                        set2conditions=gears_model.set2conditions)

    # Nodes of this graph are genes (gene_list / node_map).
    sim_network = GeneSimNetwork(edge_list, gears_model.gene_list, node_map = gears_model.node_map)
    gears_model.config['G_coexpress'] = sim_network.edge_index
    gears_model.config['G_coexpress_weight'] = sim_network.edge_weight

# - Set the gene ontology network (red graph)
# Only computed when no precomputed graph was supplied in the config.
if gears_model.config['G_go'] is None:
    ## calculating gene ontology similarity graph
    # The neighbour count `k` is taken from the GO-specific setting
    # 'num_similar_genes_go_graph', so changing it in the config takes effect
    # (previously the co-expression setting was used here as well).
    # NOTE(review): `threshold` is still the co-expression threshold, as in the
    # call above; whether the 'go' branch uses it is not visible here — confirm.
    edge_list = get_similarity_network(network_type='go',
                                        adata=gears_model.adata,
                                        threshold=gears_model.config['coexpress_threshold'],
                                        k=gears_model.config['num_similar_genes_go_graph'],
                                        pert_list=gears_model.pert_list,
                                        data_path=gears_model.data_path,
                                        data_name=gears_model.dataset_name,
                                        split=gears_model.split, seed=gears_model.seed,
                                        train_gene_set_size=gears_model.train_gene_set_size,
                                        set2conditions=gears_model.set2conditions,
                                        default_pert_graph=gears_model.default_pert_graph)

    # Nodes of this graph are perturbations (pert_list / node_map_pert).
    sim_network = GeneSimNetwork(edge_list, gears_model.pert_list, node_map = gears_model.node_map_pert)
    gears_model.config['G_go'] = sim_network.edge_index
    gears_model.config['G_go_weight'] = sim_network.edge_weight

# - finally obtain the model
# Build the network from the completed config and move it to the target device;
# `best_model` starts as an independent copy of the freshly initialised model.
gears_model.model = GEARS_Model(gears_model.config).to(gears_model.device)
gears_model.best_model = deepcopy(gears_model.model)

# Training

Here we use a simplified pseudocode to demonstrate the core part of GEARS model. For the details of each line of code, you can refer to the source code.

**Get Base Gene Embeddings**  
In the following loop we use the co-expression graph to obtain the final gene embeddings. (Green part in the figure above)
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
base_emb = self.gene_emb(self.num_genes)
pos_emb = self.emb_pos(self.num_genes)
for idx, gnn_layer in enumerate(self.gnn_layers_exp):
    pos_emb = gnn_layer(pos_emb, self.G_coexpress, self.G_coexpress_weight)
base_emb = base_emb + pos_emb
</code></pre>
</div>
Note that the gene embedding here can be directly replaced with an LCM gene embedding (such as scFoundation's embedding). These LCM embeddings are pre-trained on a huge amount of data and carry richer and more comprehensive cellular information, which can enhance performance on the perturbation task.

**Get Perturbation Embeddings**
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
pert_global_emb = self.pert_emb(self.num_perts)
</code></pre>
</div>     

In the following loop we use the GO graph to obtain the final perturbation embeddings. (Red part in the figure above)
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
for idx, gnn_layer in enumerate(self.gnn_layers_go):
    pert_global_emb = gnn_layer(pert_global_emb, self.G_sim, self.G_sim_weight)
</code></pre>
</div>

**Add Global Perturbation Embedding to Each Gene in Each Cell in the Batch**  
(Composition Operator in the figure above)  
Select the perturbation embedding of the corresponding gene
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
pert_track = {}
for i, j in enumerate(pert_index[0]):
    pert_track[j.item()] = pert_global_emb[pert_index[1][i]]
</code></pre>
</div> 
Add the selected perturbation embedding to the gene embedding
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
emb_total = self.pert_fuse(torch.stack(list(pert_track.values())))
for idx, j in enumerate(pert_track.keys()):
    base_emb[j] = base_emb[j] + emb_total[idx]
</code></pre>
</div> 

**Finally Go Through A Fully Connected Network to Obtain the Predicted Expression**  
(Blue part in the figure above)  
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
predict_expression = self.fcn(base_emb)
</code></pre>
</div> 

Here we do not delve into the training process of the model. For the details of the training process, refer directly to __gears_model.train__.

In [None]:
# Train for a single epoch with a learning rate of 1e-4
# (presumably kept to one epoch so the tutorial runs quickly; increase `epochs` for a real run).
gears_model.train(epochs = 1, lr = 1e-4)

# Evaluate the performance  
Here we use the test set to evaluate the model's performance. None of the perturbed genes in the test set are seen in the training set.

In [None]:
# - Model testing
# Run the best model on the test dataloader, passing the 'uncertainty' flag from
# the config and the model's device. The result is indexed below by the keys
# 'pred', 'truth' and 'pert_cat'.
test_output = evaluate(testloader, gears_model.best_model,
                    gears_model.config['uncertainty'], gears_model.device)

We can then calculate the evaluation metric

In [None]:
# Collects the metric values computed for the example perturbation.
pert_metric = {}
# here we take one perturbation as an example (the first of the unique 'pert_cat' values)
# this function will first prepare the data for metric calculation
# NOTE(review): `np` and `prepare_for_metric` are not imported explicitly in the
# visible cells — presumably they come from the wildcard imports of v1.utils / v1.dataloader; confirm.
de_idx_map, ctrl, p_idx = prepare_for_metric(pert_data.adata, test_output, pert = np.unique(test_output['pert_cat'])[0],
                                              most_variable_genes=None, p_thre_1=0.01, p_thre_2=0.1)
# de_idx_map: the gene index of top20, top50 and top100 DEGs
# ctrl: the control group cell for this perturbed gene
# p_idx: the perturbed cell index for this perturbed gene

We use the following metrics for evaluation

**pearson correlation**: measures the linear correlation between predicted and true values  
**mean squared error**: measures the average of the squares of the errors  
**mean absolute error**: measures the average of the absolute errors  
**change ratio**: measures the relative change between predicted and true values  
**spearman correlation**: measures the monotonic relationship between predicted and true values  

In [None]:
# Metric name -> callable that computes it from (prediction, truth).
metric2fct = {
    'pearson': pearsonr,
    'mse': mse,
    'mae': mae,
    'change_ratio': get_change_ratio,
    'spearman': spearmanr,
}
name = 'DE_'
# The correlation functions return a pair; only the first element (the
# coefficient) is kept, and a NaN coefficient is stored as 0.
correlation_metrics = ('pearson', 'spearman')

for prefix, de_idx in de_idx_map.items():  # ['top20', 'top50', 'top100']
    # Mean predicted / true expression over this perturbation's cells,
    # restricted to the current DEG set; computed once and reused by every metric.
    pred_de = test_output['pred'][p_idx].mean(0)[de_idx]
    truth_de = test_output['truth'][p_idx].mean(0)[de_idx]
    # Same profiles expressed as change relative to the control.
    pred_delta = pred_de - ctrl[de_idx]
    truth_delta = truth_de - ctrl[de_idx]

    for m, fct in metric2fct.items():
        if m in correlation_metrics:
            # Correlation on the change relative to control ...
            val = fct(pred_delta, truth_delta)[0]
            if np.isnan(val):
                val = 0
            pert_metric[name + m + f'_delta_{prefix}'] = val

            # ... and on the expression values themselves.
            val = fct(pred_de, truth_de)[0]
            if np.isnan(val):
                val = 0
            pert_metric[name + m + f'_{prefix}'] = val
            continue

        if m == 'change_ratio':
            # The change ratio is computed on the expression values.
            val = fct(pred_de, truth_de)
        else:
            # mse / mae are computed on the change relative to control.
            val = fct(pred_delta, truth_delta)
        pert_metric[name + m + f'_{prefix}'] = val

In [None]:
# Display the metric values collected for the example perturbation.
pert_metric