# Fine-tuning on Pre-trained Tabula for Genetic Perturbation Prediction 
In this tutorial, we illustrate the fine-tuning steps for the downstream genetic perturbation prediction task.

Here we take the Norman dataset, which contains both single- and double-gene perturbations, as an example. Please refer to our preprint for more information regarding the dataset.

In [1]:
import sys
sys.path.append('..')
import os

import numpy as np
import torch
import wandb
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers.wandb import WandbLogger
from tabula.finetune.tokenizer import GeneVocab
from tabula import logger
from tabula.finetune.setup.perturbation import GenePerturbationPrediction
from tabula.finetune.preprocessor import get_pretrained_model
from tabula.finetune.utils import FinetuneConfig
from gears import PertData



## Pre-define parameters 
- For the detailed fine-tuning parameters, please refer to (and modify as needed) the YAML file given in ```params['config_path']```.
- For the model weights, please download them from this link: https://drive.google.com/drive/folders/19uG3hmvBZr2Zr4mWgIU-8SQ1dSg8GZuJ?usp=sharing

In [2]:
# Run-level settings: random seed, fine-tuning YAML, output folder,
# pre-trained checkpoint and compute device.
params = {
    'seed': 23,
    'config_path': '../resource/finetune_framework_perturbation.yaml',
    'save_folder': 'finetune_out/perturbation_norman_test',
    'model_path': '../weight/blood.pth',
    'device': 'cuda:0',  # 'cuda:0' or 'cpu'
}

# Dataset settings: dataset name and split strategy, gene vocabulary file,
# batch size, and the perturbation(s) to plot.
# NOTE(review): 'n_workers' is not read by any cell visible in this notebook.
data_params = {
    'data_name': 'norman',
    'split': 'simulation',
    'vocab_path': '../resource/vocab.json',
    'batch_size': 64,
    'n_workers': 4,
    'perts_to_plot': ['FOSB+IKZF3']
}

# Weights & Biases settings. The API key is taken from the WANDB_API_KEY
# environment variable so that a real key never has to be saved inside the
# notebook; the original placeholder is kept as the fallback value.
if_wandb = True
wandb_params = {
    'key': os.environ.get('WANDB_API_KEY', 'your_wandb_key'),
    'project': 'Perturbation_tutorial_test',
    'entity': 'tabula-downstream',
    'task': 'perturbation_norman_tutorial_test'
}


In [3]:
# Seed the run, make sure the output folder exists, then load the fine-tuning
# configuration and record the output folder and W&B switch on it.
run_seed = params['seed']
config_path = params['config_path']
save_folder = params['save_folder']

seed_everything(run_seed)
os.makedirs(save_folder, exist_ok=True)

finetune_config = FinetuneConfig(seed=run_seed, config_path=config_path)
for param_name, param_value in (('save_folder', save_folder), ('if_wandb', if_wandb)):
    finetune_config.set_finetune_param(param_name, param_value)

logger.info(f'Configuration loaded from {config_path}, save finetuning result to {save_folder}')

Global seed set to 23


Tabula - INFO - Configuration loaded from ../resource/finetune_framework_perturbation.yaml, save finetuning result to finetune_out/perturbation_norman_test


In [4]:
# Build the W&B logger only when logging is switched on; otherwise
# `wandb_logger` is left as None.
if not if_wandb:
    wandb_logger = None
    logger.info('Wandb logging disabled')
else:
    wandb.login(key=wandb_params['key'])
    wandb.init(
        name=wandb_params['task'],
        project=wandb_params['project'],
        entity=wandb_params['entity'],
    )
    wandb_logger = WandbLogger(
        project=wandb_params['project'],
        offline=False,
        log_model=False,
    )
    logger.info('Wandb logging enabled')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mjianhuilin2001[0m ([33msctab-downstream[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /mnt/first19T/linjh/.netrc
  from IPython.core.display import HTML, display  # type: ignore


Tabula - INFO - Wandb logging enabled


  rank_zero_warn(


## Downstream data preprocessing
Please refer to GEARS for data loading and data construction: https://github.com/snap-stanford/GEARS

In [5]:
# Load the perturbation dataset through GEARS, create the train/val/test
# split with the run seed, and build dataloaders that use the same batch
# size for training and testing.
loader_batch_size = data_params['batch_size']

pert_data = PertData("./data_Gears_norman")
pert_data.load(data_name=data_params['data_name'])
pert_data.prepare_split(seed=params['seed'], split=data_params['split'])
pert_data.get_dataloader(test_batch_size=loader_batch_size, batch_size=loader_batch_size)

Downloading...
100%|██████████| 169M/169M [07:12<00:00, 390kiB/s]     
Extracting zip file...
Done!
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iosp

In [6]:
# Load the gene vocabulary and add every dataset gene it does not contain yet.
vocab = GeneVocab.from_file(data_params['vocab_path'])
genes = pert_data.adata.var["gene_name"].tolist()
for gene_name in genes:
    if gene_name in vocab:
        continue
    vocab.expand_token(gene_name)

# Map each dataset gene to its vocabulary id; a gene that is still missing
# from the vocabulary is mapped to the "<pad>" id.
token_ids = []
for gene_name in genes:
    if gene_name in vocab:
        token_ids.append(vocab[gene_name])
    else:
        token_ids.append(vocab["<pad>"])
gene_ids = np.array(token_ids, dtype=int)

# Size the model from the expanded vocabulary and the number of input genes.
vocab_size = len(vocab)
n_genes = len(gene_ids)
finetune_config.set_model_param('embedding_in_feature', vocab_size)
finetune_config.set_model_param('in_feature', n_genes)
finetune_config.set_model_param('reconstruction_out_feature', n_genes)
logger.info(f'Embeedding feature size after expanding vocab: {vocab_size}')
logger.info(f'Model input feature length is: {n_genes}')

Tabula - INFO - Embeedding feature size after expanding vocab: 61195
Tabula - INFO - Model input feature length is: 5045


## Load pre-trained Tabula

In [7]:
# Switch to CPU when a non-CPU device was requested but CUDA is unavailable.
requested_device = params['device']
if requested_device != 'cpu':
    if not torch.cuda.is_available():
        logger.error('Cuda is not available, change device to cpu')
        params['device'] = 'cpu'

# Load the pre-trained Tabula model onto the (possibly updated) device.
tabula_pl_model = get_pretrained_model(
    device=params['device'],
    model_path=params['model_path'],
    finetune_config=finetune_config,
)

Tabula - INFO - Loading FlashAttention Tabula from path: /mnt/first19T/linjh/program/Tabula/weight/blood.pth




Tabula - ERROR - Error loading model from path: /mnt/first19T/linjh/program/Tabula/weight/blood.pth, switch to load specific weights.
Tabula - INFO - Loading params feature_tokenizer.gene_encoder.enc_norm.weight with shape torch.Size([192])
Tabula - INFO - Loading params feature_tokenizer.gene_encoder.enc_norm.bias with shape torch.Size([192])
Tabula - INFO - Loading params feature_tokenizer.value_encoder.enc_norm.weight with shape torch.Size([192])
Tabula - INFO - Loading params feature_tokenizer.value_encoder.enc_norm.bias with shape torch.Size([192])
Tabula - INFO - Loading params bn.weight with shape torch.Size([192])
Tabula - INFO - Loading params bn.bias with shape torch.Size([192])
Tabula - INFO - Loading params bn.running_mean with shape torch.Size([192])
Tabula - INFO - Loading params bn.running_var with shape torch.Size([192])
Tabula - INFO - Loading params bn.num_batches_tracked with shape torch.Size([])
Tabula - INFO - Loading params cls.weight with shape torch.Size([1, 192

## Fine-tune Tabula

In [8]:
# Collect everything the perturbation trainer needs, build it, then start
# fine-tuning.
trainer_kwargs = {
    'config': finetune_config,
    'pert_data': pert_data,
    'tabula_model': tabula_pl_model,
    'wandb_logger': wandb_logger,
    'device': params['device'],
    'batch_size': data_params['batch_size'],
    'gene_ids': gene_ids,
    'perts_to_plot': data_params['perts_to_plot'],
}
gene_perturb_trainer = GenePerturbationPrediction(**trainer_kwargs)

gene_perturb_trainer.finetune()

Global seed set to 23


Tabula - INFO - Finetune method: heavy. Max epochs: 1000. Patience: 5 


  rank_zero_deprecation(
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type              | Params
--------------------------------------------
0 | model | TabulaTransformer | 15.7 M
--------------------------------------------
15.7 M    Trainable params
0         Non-trainable params
15.7 M    Total params
31.304    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Tabula - INFO - Start evaluating perturbation...


100%|██████████| 463/463 [03:22<00:00,  2.28it/s]
  val = fct(results['pred_de'][p_idx].mean(0), results['truth_de'][p_idx].mean(0))[0]
  val = fct(results['pred_de'][p_idx].mean(0), results['truth_de'][p_idx].mean(0))[0]
  val = fct(results['pred_de'][p_idx].mean(0), results['truth_de'][p_idx].mean(0))[0]
  val = fct(results['pred_de'][p_idx].mean(0), results['truth_de'][p_idx].mean(0))[0]
  val = fct(results['pred_de'][p_idx].mean(0), results['truth_de'][p_idx].mean(0))[0]
  val = fct(results['pred_de'][p_idx].mean(0), results['truth_de'][p_idx].mean(0))[0]
  val = fct(results['pred_de'][p_idx].mean(0), results['truth_de'][p_idx].mean(0))[0]
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_

Tabula - INFO - test_combo_seen0_pearson_delta: 0.01571675017903333
Tabula - INFO - test_combo_seen0_pearson_delta_de: -0.15264978595028722
Tabula - INFO - test_combo_seen0_pearson_delta_top20_de_non_dropout: -0.07525355124927158
Tabula - INFO - test_combo_seen0_pearson_top20_de_non_dropout: 0.06324112945352459
Tabula - INFO - test_combo_seen1_pearson_delta: 0.01761110570932917
Tabula - INFO - test_combo_seen1_pearson_delta_de: -0.10516396410153779
Tabula - INFO - test_combo_seen1_pearson_delta_top20_de_non_dropout: -0.034400903120964575
Tabula - INFO - test_combo_seen1_pearson_top20_de_non_dropout: 0.03684886811445299
Tabula - INFO - test_combo_seen2_pearson_delta: 0.013015020044460751
Tabula - INFO - test_combo_seen2_pearson_delta_de: -0.16437978782944856
Tabula - INFO - test_combo_seen2_pearson_delta_top20_de_non_dropout: -0.10698829267148088
Tabula - INFO - test_combo_seen2_pearson_top20_de_non_dropout: 0.002541366274692802
Tabula - INFO - test_unseen_single_pearson_delta: 0.018008

  if LooseVersion(mpl.__version__) >= "3.0":
  other = LooseVersion(other)
100%|██████████| 10/10 [00:01<00:00,  5.13it/s]
  if LooseVersion(mpl.__version__) < "3.0":
  other = LooseVersion(other)
  return fn(*args, **kwargs)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Tabula - INFO - Start evaluating perturbation...



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Tabula - INFO - test_combo_seen0_pearson_delta: 0.39217945401619286
Tabula - INFO - test_combo_seen0_pearson_delta_de: 0.6740261891984564
Tabula - INFO - test_combo_seen0_pearson_delta_top20_de_non_dropout: 0.6881664593033884
Tabula - INFO - test_combo_seen0_pearson_top20_de_non_dropout: 0.8416452800091868
Tabula - INFO - test_combo_seen1_pearson_delta: 0.4085470143272695
Tabula - INFO - test_combo_seen1_pearson_delta_de: 0.5824758299666011
Tabula - INFO - test_combo_seen1_pearson_delta_top20_de_non_dropout: 0.6435007356242387
Tabula - INFO - test_combo_seen1_pearson_top20_de_non_dropout: 0.8464699723682579
Tabula - INFO - test_combo_seen2_pearson_delta: 0.38034491800714026
Tabula - INFO - test_combo_seen2_pearson_delta_de: 0.4912358540141448
Tabula - INFO - test_combo_seen2_pearson_delta_top20_de_non_dropout: 0.5588197376843353
Tabula - INFO - test_combo_seen2_pearson_top20_de_non_dropout: 0.8254950566174883
Tabula - INFO - test_unseen_single_pearson_delta: 0.3349614118439068
Tabula -

  if LooseVersion(mpl.__version__) >= "3.0":
  other = LooseVersion(other)

[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 10/10 [00:01<00:00,  5.22it/s]
  if LooseVersion(mpl.__version__) < "3.0":
  other = LooseVersion(other)
  return fn(*args, **kwargs)
