# In-silico Pairwise Gene Regulatory Network Inference with Pre-trained Tabula

In this tutorial, we illustrate the zero-shot workflow for the downstream task of in-silico pairwise gene regulatory network (GRN) inference.

Here we take the cardiogenesis system as an example, restricting the analysis to cells from the FHF progenitor domain. Please refer to our preprint for more information regarding the dataset and the task.

In [1]:
import sys
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')
import warnings
warnings.filterwarnings('ignore')

import os
import anndata
import torch
from tabula.finetune.tokenizer import GeneVocab
from pytorch_lightning import seed_everything
from tabula.finetune.preprocessor import get_pretrained_model, Preprocessor, check_vocab, FinetuneConfig
from tabula.finetune.model.insilico_grn import pairwise_inference
from tabula.finetune.dataloader import GRNDataset
from tabula import logger
warnings.filterwarnings("ignore")

## Pre-define parameters 
- For detailed finetuning parameters, please refer to (and modify) the YAML file at `params['config_path']`.
- For the model weights, please download them from this link: https://drive.google.com/drive/folders/19uG3hmvBZr2Zr4mWgIU-8SQ1dSg8GZuJ?usp=sharing
- For `data_params['data_dir']`, please download the curated h5ad file for the cardiogenesis system from this link: https://drive.google.com/drive/folders/1G-y6PYaF1nTocjXYGdLz7uzQHD_SLc2v?usp=sharing

In [2]:
# Run-level configuration.
params = {
    'seed': 0,  # global random seed (passed to seed_everything and FinetuneConfig)
    'config_path': '../resource/finetune_framework_perturbation.yaml',  # finetuning YAML config
    'save_folder': 'finetune_out/pairwise_grn/cardiogenesis/one2one/FHF',  # output directory for results
    'model_path': '../weight/heart.pth',  # pre-trained Tabula weights
    'device': 'cuda:0',  # torch device for the model and inference
}

# Data / preprocessing configuration.
data_params = {
    'data_dir': '../data/GRN/cardiogenesis.h5ad',  # path to the curated .h5ad file (a file, not a directory)
    'vocab_path': '../resource/vocab.json',  # gene vocabulary used to filter genes and map them to ids
    'batch_size': 128,
    'n_workers': 4,
    'n_bins': 51,  # number of expression bins used by the preprocessor
    'n_hvg': 2400,  # number of highly variable genes to keep
    # False: the preprocessor skips log1p and uses the "cell_ranger" HVG flavor.
    'data_is_raw': False,
}

# Gene pairs to score. Each entry is (source gene, target gene, relation label);
# presumably the source regulates the target -- TODO confirm against pairwise_inference.
# Each pair is listed once so no pair is scored twice.
pair_list = [
    ('BMP2', 'NKX2-5', 'activate'),
    ('NKX2-5', 'GATA4', 'activate'),
    ('NKX2-5', 'TBX5', 'activate'),
    ('GATA4', 'NKX2-5', 'activate'),
    ('TBX5', 'NKX2-5', 'activate'),
    ('TBX5', 'GATA4', 'activate'),
    ('TBX5', 'TBX5', 'activate'),  # self-pair; NOTE(review): confirm this self-regulation test is intended
]
# Guard against accidental duplicates when the list is edited (keeps first occurrence, preserves order).
pair_list = list(dict.fromkeys(pair_list))

# Unique genes involved in any pair, in first-seen order. dict.fromkeys keeps the order
# deterministic across kernels, unlike iterating a set of strings (hash randomization).
all_testing_genes = list(dict.fromkeys([k[0] for k in pair_list] + [k[1] for k in pair_list]))

In [3]:
# Seed all RNGs, ensure the output directory exists, and load the finetuning config.
run_seed = params['seed']
cfg_path = params['config_path']
out_dir = params['save_folder']

seed_everything(run_seed)
os.makedirs(out_dir, exist_ok=True)  # re-running the cell must not fail if the folder exists
finetune_config = FinetuneConfig(seed=run_seed, config_path=cfg_path)
logger.info(f'Configuration loaded from {cfg_path}, save finetuning result to {out_dir}')

Global seed set to 0


Tabula - INFO - Configuration loaded from ../resource/finetune_framework_perturbation.yaml, save finetuning result to finetune_out/pairwise_grn/cardiogenesis/one2one/FHF


## Downstream data preprocessing

In [4]:
# Load the curated dataset and drop genes that are not in the pre-trained vocabulary.
adata = anndata.read_h5ad(data_params['data_dir'])
adata, removed_gene_labels = check_vocab(adata, data_params['vocab_path'])
vocab = GeneVocab.from_file(data_params['vocab_path'])
logger.info(f"Removed {len(removed_gene_labels)} genes from the dataset.")

preprocessor = Preprocessor(
    use_key="X",  # input matrix; "X" presumably selects adata.X rather than a layer -- confirm in Preprocessor
    filter_gene_by_counts=5,  # 1. gene filtering by counts (threshold 5)
    filter_cell_by_counts=5,  # 2. cell filtering by counts (threshold 5)
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_params['data_is_raw'],  # 4. apply log1p only when the input is raw counts
    result_log1p_key="X_log1p",
    subset_hvg=data_params['n_hvg'],  # 5. subset to this many highly variable genes
    # "seurat_v3" expects raw counts; "cell_ranger" is used for already-normalized data.
    hvg_flavor="seurat_v3" if data_params['data_is_raw'] else "cell_ranger",
    binning=data_params['n_bins'],  # 6. number of bins for value binning
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)

# Preprocessing (including HVG selection and binning) runs on all cells before subsetting.
preprocessor(adata, batch_key=None)

# Keep only cells from the 'FHF' progenitor domain.
adata = adata[adata.obs['Progenitor Domain'] == 'FHF', :]
assert adata.n_obs > 0, "No cells with Progenitor Domain == 'FHF' remain after preprocessing"

gene_name = adata.var["gene_name"].values
finetune_config.set_model_param('in_feature', len(gene_name))
finetune_config.set_model_param('reconstruction_out_feature', len(gene_name))
logger.info(f'adata shape: {adata.shape}')

# Report which test genes survived vocabulary + HVG filtering; pairs involving a missing
# gene presumably cannot be scored, so surface them instead of dropping them silently.
gene_name_set = set(gene_name)
exist_gene = [i for i in all_testing_genes if i in gene_name_set]
logger.info(f"Exist gene: {exist_gene} after filtering with HVG")
missing_gene = [i for i in all_testing_genes if i not in gene_name_set]
if missing_gene:
    logger.warning(f"Genes not found after filtering with HVG: {missing_gene}")

# Dense copy of adata.X; to_df() works whether X is sparse or dense (X.toarray() fails on dense X).
original_expression_table = adata.to_df().to_numpy()
expression_table = adata.layers["X_binned"]
gene_ids = gene_name.tolist()
gene_vocab_ids = [vocab[gene] for gene in gene_ids]
Grn_dataset = GRNDataset(expression_table=expression_table)
grn_dataloader = torch.utils.data.DataLoader(
    Grn_dataset,
    batch_size=data_params['batch_size'],
    num_workers=data_params['n_workers'],
    shuffle=False,  # keep cell order fixed
    drop_last=False,  # keep the final partial batch so every cell is used
)

Tabula - INFO - Removed 9518 genes from the dataset.
Tabula - INFO - Filtering genes by counts ...
Tabula - INFO - Filtering cells by counts ...
Tabula - INFO - Normalizing total counts ...
Tabula - INFO - Subsetting highly variable genes ...
Tabula - INFO - Binning data ...
Tabula - INFO - adata shape: (233, 2400)
Tabula - INFO - Exist gene: ['GATA4', 'TBX5', 'NKX2-5', 'BMP2'] after filtering with HVG


## Load pre-trained Tabula

In [None]:
# Presumably builds the Tabula model from the finetuning config and loads the pre-trained
# weights from `model_path` onto `device` -- confirm in get_pretrained_model.
tabula_pl_model = get_pretrained_model(
    finetune_config=finetune_config,
    model_path=params['model_path'],
    device=params['device'],
)
# Switch to evaluation mode for zero-shot inference. The trailing `;` stops Jupyter from
# printing the full module repr that .eval() returns.
tabula_pl_model.eval();

## Pairwise Inference

In [None]:
# Number of "evolve" epochs passed to pairwise_inference.
# NOTE(review): exact meaning is defined by pairwise_inference -- confirm before tuning.
EVOLVE_EPOCH = 10

# Score every (source, target) pair in pair_list with the pre-trained model and write
# results under the configured save folder.
pairwise_inference(
    model=tabula_pl_model,
    vocab=vocab,
    gene_ids=gene_vocab_ids,
    Grn_dataloader=grn_dataloader,
    pair_list=pair_list,
    evolve_epoch=EVOLVE_EPOCH,
    device=params['device'],
    save_path=params['save_folder'],
)