# Fine-tuning on Pre-trained Tabula for Multi-omics & Multi-batch integration
In this tutorial, we illustrate the fine-tuning steps for the downstream task of multi-omics and multi-batch integration.

Here we use the DC dataset, which contains scRNA-seq data from human blood dendritic cells (DCs). It consists of two batches, each with four distinct cell types. Please refer to our manuscript for more information regarding the dataset. The processed dataset can be downloaded from the following link: https://drive.google.com/drive/folders/12Wg6fUe2MG8UBpMVi6SZKsFUZo4ai-UR?usp=sharing

In [1]:
import sys
sys.path.append('..')
import os

import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers.wandb import WandbLogger
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scipy.sparse import issparse
from tabula import logger
from tabula.task.dataloader import MultiOmicsDataset
from tabula.task.ft_model import FinetuneModel
from tabula.task.preprocessor import Preprocessor, get_pretrained_model, get_ft_training_args
from tabula.task.utils import FinetuneConfig
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from anndata import AnnData
import scanpy as sc



## Pre-define parameters 
- For detailed finetuning parameters, please refer to and modify the yaml file in ```params['config_path']```
- For model weight, please download from this link: https://drive.google.com/drive/folders/19uG3hmvBZr2Zr4mWgIU-8SQ1dSg8GZuJ?usp=sharing

In [2]:
# Run-level settings: random seed, fine-tuning YAML, output folder,
# pre-trained weights and the device to train on.
params = {
    'seed': 23,
    'config_path': '../resource/finetune_framework_integration.yaml',
    'save_folder': 'finetune_out/integration_dc_test',
    'model_path': '../weight/brain.pth',
    'device': 'cuda:0',  # 'cuda:0' or 'cpu'
}

# Dataset location and preprocessing / data-loading settings.
data_params = {
    'data_path': '../data/integration/DC_raw.h5ad',
    'vocab_path': '../resource/vocab.json',
    'n_bins': 51,
    'n_hvg': 2000,
    'data_is_raw': True,
    'batch_size': 32,
    'n_workers': 4,
}

# Weights & Biases experiment tracking. Set `if_wandb = False` to disable it.
if_wandb = True
wandb_params = {
    # Prefer the WANDB_API_KEY environment variable so the secret never has
    # to be written into (and accidentally committed with) the notebook.
    # The placeholder is only used when the variable is not set.
    'key': os.environ.get('WANDB_API_KEY', 'your_wandb_key'),
    'project': 'Integration_tutorial_test',
    'entity': 'tabula-downstream',
    'task': 'integration_dc_test'
}


In [3]:
# Seed all RNGs up front so the run is reproducible.
seed_everything(params['seed'])

# Make sure the output folder exists before anything tries to write into it.
os.makedirs(params['save_folder'], exist_ok=True)

# Load the fine-tuning configuration from the YAML file and point its
# 'save_folder' entry at the output folder chosen above.
finetune_config = FinetuneConfig(
    seed=params['seed'],
    config_path=params['config_path'],
)
finetune_config.set_finetune_param('save_folder', params['save_folder'])

logger.info(
    f'Configuration loaded from {params["config_path"]}, '
    f'save finetuning result to {params["save_folder"]}'
)


Global seed set to 23


Tabula - INFO - Configuration loaded from ../resource/finetune_framework_integration.yaml, save finetuning result to finetune_out/integration_dc_test


In [4]:
# No experiment logger by default; it is replaced below when wandb is enabled.
wandb_logger = None

if not if_wandb:
    logger.info('Wandb logging disabled')
else:
    # Authenticate, open the run, then build the Lightning logger for the
    # same project.
    wandb.login(key=wandb_params['key'])
    wandb.init(
        project=wandb_params['project'],
        entity=wandb_params['entity'],
        name=wandb_params['task'],
    )
    wandb_logger = WandbLogger(
        project=wandb_params['project'],
        log_model=False,
        offline=False,
    )
    logger.info('Wandb logging enabled')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjianhuilin2001[0m ([33msctab-downstream[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /mnt/first19T/linjh/.netrc


Tabula - INFO - Wandb logging enabled


## Downstream data preprocessing

In [5]:
# ---------------------------------------------------------------------------
# 1. Load the dataset and build batch / gene annotations
# ---------------------------------------------------------------------------
adata = sc.read(data_params['data_path'])
ori_batch_col = "batch"
adata.obs["celltype"] = adata.obs["celltype"].astype("category")

# make the batch category column
adata.obs["str_batch"] = adata.obs[ori_batch_col].astype(str)
logger.info(f'Number of batches: {len(set(adata.obs["str_batch"]))}')
batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
adata.obs["batch_id"] = batch_id_labels
adata.var["gene_name"] = adata.var.index.tolist()

# retain the common gene set between the data and the pre-trained model for further fine-tuning
vocab = GeneVocab.from_file(data_params['vocab_path'])
# Flag column: 1 when the gene is in the vocabulary, -1 otherwise.
adata.var["id_in_vocab"] = [
    1 if gene in vocab else -1 for gene in adata.var["gene_name"]]
gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
logger.info(
    f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
    f"in vocabulary of size {len(vocab)}.")
# `.copy()` turns the sliced view into a standalone AnnData, so the
# preprocessing below does not write into a view of the original object.
adata = adata[:, adata.var["id_in_vocab"] >= 0].copy()

# ---------------------------------------------------------------------------
# 2. Preprocess: filter, normalize, log1p, HVG selection, binning
# ---------------------------------------------------------------------------
# set up the preprocessor, use the args to config the workflow
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_params['data_is_raw'],  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=data_params['n_hvg'],  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_params['data_is_raw'] else "cell_ranger",
    binning=data_params['n_bins'],  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key="str_batch")

# ---------------------------------------------------------------------------
# 3. Build the train / valid datasets and loaders
# ---------------------------------------------------------------------------
input_layer_key = "X_binned"
binned_layer = adata.layers[input_layer_key]
# Densify when the layer is sparse. `.toarray()` is used instead of the `.A`
# shorthand, which is deprecated and no longer available on sparse matrices
# in recent SciPy releases.
all_counts = binned_layer.toarray() if issparse(binned_layer) else binned_layer
genes = adata.var["gene_name"].tolist()

celltype_id_labels = adata.obs['celltype'].cat.codes.values

batch_ids = adata.obs["batch_id"].tolist()
num_batch_types = len(set(batch_ids))
batch_ids = np.array(batch_ids)

dataset = MultiOmicsDataset(
    expression_table=all_counts,
    gene_ids=genes,
    labels=celltype_id_labels,
    batch_id=batch_ids,
    vocab_file=data_params['vocab_path'],
    in_feature=finetune_config.get_model_param('in_feature')
)

# split train and valid (90% / 10%, seeded for reproducibility)
train_idx, valid_idx = train_test_split(np.arange(len(dataset)), test_size=0.1, random_state=params['seed'])
train_dataset = torch.utils.data.Subset(dataset, train_idx)
valid_dataset = torch.utils.data.Subset(dataset, valid_idx)
train_loader = DataLoader(train_dataset, num_workers=data_params['n_workers'], shuffle=True, batch_size=data_params['batch_size'], drop_last=True)
valid_loader = DataLoader(valid_dataset, num_workers=data_params['n_workers'], shuffle=False, batch_size=data_params['batch_size'], drop_last=False)
logger.info(f'Finish building train and valid loader, train size: {len(train_loader.dataset)}, valid size: {len(valid_loader.dataset)}')

# ---------------------------------------------------------------------------
# 4. Build the evaluation AnnData and the test loader (all cells)
# ---------------------------------------------------------------------------
# construct a adata_eval for evaluation based on all data
# The dense `all_counts` is used here (same object as the binned layer when
# that layer is already dense) so that `np.zeros_like` and the column
# assignment below also work if the preprocessor returned a sparse layer.
adata_eval = AnnData(
    X=all_counts,
    obs=adata.obs,
    var=pd.DataFrame(index=adata.var["gene_name"].tolist()),
    layers={"X_binned": all_counts, }
)
# add one layer for tokenized vocab of genes
adata_eval.layers["X_vocab"] = np.zeros_like(adata_eval.layers["X_binned"])
vocab_lookup_table = dataset.vocab
for i, gene in enumerate(adata_eval.var.index.tolist()):
    if gene in vocab_lookup_table:
        # every cell gets the same vocabulary id in this gene's column
        adata_eval.layers["X_vocab"][:, i] = vocab_lookup_table[gene]
adata_eval.var["gene_name"] = adata_eval.var.index.tolist()
adata_eval.obs["batch"] = adata.obs["batch_id"].tolist()
# assign str_batch
le = preprocessing.LabelEncoder()
encoded_batch = le.fit_transform(adata_eval.obs['batch'].values)  # integer-encoded batch labels
adata_eval.obs["batch_id"] = encoded_batch
adata_eval.obs["str_batch"] = adata_eval.obs["batch_id"].astype('category')

# Gene ids are passed as a plain list, exactly as for the train/valid dataset
# above, rather than as a pandas Series.
test_dataset = MultiOmicsDataset(expression_table=adata_eval.layers["X_binned"],
                                 gene_ids=adata_eval.var["gene_name"].tolist(),
                                 labels=adata_eval.obs['celltype'].cat.codes.values,
                                 batch_id=adata_eval.obs['batch_id'].values,
                                 vocab_file=data_params['vocab_path'],
                                 in_feature=finetune_config.get_model_param('in_feature')
                                 )
test_loader = DataLoader(test_dataset, batch_size=data_params['batch_size'], shuffle=False, num_workers=data_params['n_workers'], drop_last=False)
logger.info(f'Finish building test loader, test size: {len(test_loader.dataset)}')

Tabula - INFO - Number of batches: 2
Tabula - INFO - match 13348/17649 genes in vocabulary of size 60697.
Tabula - INFO - Filtering genes by counts ...
Tabula - INFO - Filtering cells by counts ...
Tabula - INFO - Normalizing total counts ...
Tabula - INFO - Log1p transforming ...
Tabula - INFO - Subsetting highly variable genes ...
Tabula - INFO - Binning data ...
Tabula - INFO - Finish building train and valid loader, train size: 512, valid size: 57
Tabula - INFO - Finish building test loader, test size: 569


## Load pre-trained Tabula

In [6]:
# If a non-CPU device was requested but CUDA cannot be used on this machine,
# report it and switch the run to CPU.
cuda_requested = params['device'] != 'cpu'
if cuda_requested and not torch.cuda.is_available():
    logger.error('Cuda is not available, change device to cpu')
    params['device'] = 'cpu'

# Load the pre-trained Tabula weights, with batch handling enabled for the
# number of batches found in the dataset.
pretrained_kwargs = dict(
    finetune_config=finetune_config,
    enable_batch=True,
    model_path=params['model_path'],
    device=params['device'],
    num_batches=num_batch_types,
)
tabula_pl_model = get_pretrained_model(**pretrained_kwargs)

Tabula - INFO - Loading FlashAttention Tabula from path: ../weight/brain.pth
Tabula - ERROR - Error loading model from path: ../weight/brain.pth, switch to load specific weights.
Tabula - INFO - Loading params feature_tokenizer.gene_encoder.embedding.weight with shape torch.Size([60697, 192])
Tabula - INFO - Loading params feature_tokenizer.gene_encoder.enc_norm.weight with shape torch.Size([192])
Tabula - INFO - Loading params feature_tokenizer.gene_encoder.enc_norm.bias with shape torch.Size([192])
Tabula - INFO - Loading params feature_tokenizer.value_encoder.weight with shape torch.Size([1200, 192])
Tabula - INFO - Loading params feature_tokenizer.value_encoder.bias with shape torch.Size([1200, 192])
Tabula - INFO - Loading params feature_tokenizer.value_encoder.enc_norm.weight with shape torch.Size([192])
Tabula - INFO - Loading params feature_tokenizer.value_encoder.enc_norm.bias with shape torch.Size([192])
Tabula - INFO - Loading params bn.weight with shape torch.Size([192])
Ta

## Fine-tune Tabula

In [7]:
# Wrap the pre-trained model in the fine-tuning Lightning module. The test
# loader and evaluation AnnData are handed over through `task_params`.
finetune_pl_model = FinetuneModel(finetune_type=finetune_config.get_finetune_param('finetune_task'),
                                  model=tabula_pl_model,
                                  seed=params['seed'],
                                  device=params['device'],
                                  record_best_model=True,
                                  finetune_config=finetune_config,
                                  task_params={'test_loader': test_loader, 'eval_adata': adata_eval},
                                  enable_wandb=if_wandb,
                                  )
training_args = get_ft_training_args(finetune_config)

# Derive the Trainer's GPU selection from params['device'] instead of
# hard-coding GPU 0, so that a CPU run (including the automatic CPU fallback
# applied when CUDA is unavailable) does not request a GPU.
# 'cuda:0' still yields gpus=[0], exactly as before.
# NOTE(review): assumes `training_args` does not itself pin an accelerator - confirm.
target_device = torch.device(params['device'])
if target_device.type == 'cuda':
    # A bare 'cuda' has no index; fall back to GPU 0 in that case.
    trainer_gpus = [target_device.index if target_device.index is not None else 0]
else:
    trainer_gpus = None

trainer = pl.Trainer(**training_args, logger=wandb_logger, gpus=trainer_gpus)
trainer.fit(finetune_pl_model, train_loader, valid_loader)
logger.info('Finetuning finished.')

Global seed set to 23


Tabula - INFO - Finetune method: heavy. Max epochs: 5000. Patience: 50


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                | Type               | Params
-----------------------------------------------------------
0 | model               | TabulaTransformer  | 13.7 M
1 | criterion_dab       | CrossEntropyLoss   | 0     
2 | contrastive_loss    | ContrastiveLoss    | 0     
3 | reconstruction_loss | ReconstructionLoss | 0     
-----------------------------------------------------------
13.7 M    Trainable params
0         Non-trainable p

Sanity Checking: 0it [00:00, ?it/s]

Tabula - INFO - Finetuning finished.
