# Fine-tuning on Pre-trained Tabula for Cell Type Annotation
In this tutorial, we illustrate the finetuning steps for the downstream task cell type annotation.

Here we takes Pancreas dataset as an example. Please refer to our manuscript for more information regarding the dataset. 

In [1]:
import sys
sys.path.append('..')
import os

import numpy as np
from torch.utils.data import DataLoader
import torch
import wandb
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers.wandb import WandbLogger
from tabula import logger
from tabula.finetune.dataloader import CellAnnotationDataset
from tabula.finetune.setup.annotation import CellTypeAnnotation
from tabula.finetune.preprocessor import check_vocab, Preprocessor, get_pretrained_model
from tabula.finetune.utils import FinetuneConfig
from sklearn.model_selection import train_test_split
import scanpy as sc

## Pre-define parameters 
- For detailed finetuning parameters, please refer to and modify the yaml file in ```params['config_path']```
- For model weight, please download from this link: https://drive.google.com/drive/folders/19uG3hmvBZr2Zr4mWgIU-8SQ1dSg8GZuJ?usp=sharing
- For ```data_params['finetune_data_path']``` and ```data_params['test_data_path']```, please download the zip file curated by scGPT from this link: https://drive.google.com/drive/folders/1biD__KaE_fhNry7U3d9XkCMRvtpa3xw5?usp=sharing

In [2]:
params = {
    'seed': 23,
    'config_path': '../resource/finetune_framework_annotation.yaml',
    'save_folder': 'finetune_out/annotation_pancreas_test',
    'model_path': '../weight/pancreas.pth',
    'device': 'cuda:0',  # 'cuda:0' or 'cpu'
}

data_params = {
    'finetune_data_path': '../data/annotation/pancreas/demo_train.h5ad',
    'test_data_path': '../data/annotation/pancreas/demo_test.h5ad',
    'vocab_path': '../resource/vocab.json',
    'n_bins': 51,
    'n_hvg': False,
    'data_is_raw': False,
    'batch_size': 32,
    'n_workers': 4,
}

if_wandb = True
wandb_params = {
    'key': '644b123473f38af040ef215020d8e45acdf48fda',
    'project': 'Annotation_tutorial_test',
    'entity': 'sctab-downstream',
    'task': 'annotation_pancreas_test'
}

In [3]:
seed_everything(params['seed'])
os.makedirs(params['save_folder'], exist_ok=True)
finetune_config = FinetuneConfig(seed=params['seed'], config_path=params['config_path'])

finetune_config.set_finetune_param('enable_wandb', if_wandb)

finetune_config.set_finetune_param('save_folder', params['save_folder'])
logger.info(f'Configuration loaded from {params["config_path"]}, save finetuning result to {params["save_folder"]}')

Global seed set to 23


Tabula - INFO - Configuration loaded from ../resource/finetune_framework_annotation.yaml, save finetuning result to finetune_out/annotation_pancreas_test


In [4]:
if if_wandb:
    wandb.login(key=wandb_params['key'])
    wandb.init(project=wandb_params['project'], entity=wandb_params['entity'], name=wandb_params['task'])
    wandb_logger = WandbLogger(project=wandb_params['project'], log_model=False, offline=False)
    logger.info(f'Wandb logging enabled')
else:
    wandb_logger = None
    logger.info(f'Wandb logging disabled')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjianhuilin2001[0m ([33msctab-downstream[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /mnt/first19T/linjh/.netrc
  from IPython.core.display import HTML, display  # type: ignore


Tabula - INFO - Wandb logging enabled


  rank_zero_warn(


## Downstream data preprocessing

In [5]:
# load evaluation dataset as query set
test_adata = sc.read(data_params['test_data_path'])
# load finetune dataset as reference set
finetune_adata = sc.read(data_params['finetune_data_path'])

logger.info(f'Finetune data cell type: {finetune_adata.obs["Celltype"].unique()} -- {len(finetune_adata.obs["Celltype"].unique())} types')
logger.info(f"Test data cell type: {test_adata.obs['Celltype'].unique()} -- {len(test_adata.obs['Celltype'].unique())} types")



Tabula - INFO - Finetune data cell type: ['acinar', 'delta', 'beta', 'PSC', 'alpha', ..., 'endothelial', 'macrophage', 'schwann', 'mast', 't_cell']
Length: 13
Categories (13, object): ['PP', 'PSC', 'acinar', 'alpha', ..., 'macrophage', 'mast', 'schwann', 't_cell'] -- 13 types
Tabula - INFO - Test data cell type: ['beta', 'PSC', 'ductal', 'alpha', 'acinar', ..., 'PP', 'MHC class II', 'endothelial', 'epsilon', 'mast']
Length: 11
Categories (11, object): ['MHC class II', 'PP', 'PSC', 'acinar', ..., 'ductal', 'endothelial', 'epsilon', 'mast'] -- 11 types


In [6]:
finetune_adata.obs["celltype"] = finetune_adata.obs["Celltype"].astype("category")
finetune_adata.obs["batch_id"] = finetune_adata.obs['str_batch'] = "0"
finetune_adata.var["gene_name"] = finetune_adata.var["Gene Symbol"].tolist()
test_adata.obs["celltype"] = test_adata.obs["Celltype"].astype("category")
test_adata.obs["batch_id"] = test_adata.obs['str_batch'] = "1"
test_adata.var["gene_name"] = test_adata.var["Gene Symbol"].tolist()

adata_test_temp = test_adata.copy()
adata = finetune_adata.concatenate(test_adata, batch_key="str_batch")

# make the batch category column
batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
adata.obs["batch_id"] = batch_id_labels
celltype_id_labels = adata.obs["celltype"].astype("category").cat.codes.values
celltypes = adata.obs["celltype"].unique()
num_types = len(np.unique(celltype_id_labels))
finetune_config.set_model_param('supervised_out_feature', num_types)

adata.obs["celltype_id"] = celltype_id_labels
adata.var["gene_name"] = adata.var.index.tolist()
# get cell type id to cell type name mapping
id2type = dict(enumerate(adata.obs["celltype"].astype("category").cat.categories))

adata, removed_gene_labels = check_vocab(adata, data_params['vocab_path'])
logger.info(f"Removed gene labels: {removed_gene_labels}")

  adata = finetune_adata.concatenate(test_adata, batch_key="str_batch")


Tabula - INFO - Removed gene labels: ['RGS5']


In [7]:
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=False,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_params['data_is_raw'],  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=data_params['n_hvg'],  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_params['data_is_raw'] else "cell_ranger",
    binning=data_params['n_bins'],  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
# split train and test data
adata_test = adata[adata.obs["str_batch"] == "1"]
adata_finetune = adata[adata.obs["str_batch"] == "0"]
preprocessor(adata_test, batch_key=None)
preprocessor(adata_finetune, batch_key=None)
adata_test_raw = adata_test.copy()

Tabula - INFO - Filtering cells by counts ...
Tabula - INFO - Normalizing total counts ...
Tabula - INFO - Binning data ...


  adata.obs['n_counts'] = number


Tabula - INFO - Filtering cells by counts ...


  adata.obs['n_counts'] = number


Tabula - INFO - Normalizing total counts ...
Tabula - INFO - Binning data ...


In [8]:
# get batch labels
batch_id_labels = adata_finetune.obs["str_batch"].astype("category").cat.codes.values
adata_finetune.obs["batch_id"] = batch_id_labels
# get cell type labels mapped by id2type
celltype_str_list = adata_finetune.obs["celltype"].tolist()
# use id2type to convert cell type string to cell type id
celltype_id_labels = [list(id2type.keys())[list(id2type.values()).index(celltype_str)] for celltype_str in celltype_str_list]

# get expression table
binned_expression_table = adata_finetune.layers["X_binned"]

# get gene ids
adata_finetune.var["gene_name"] = adata_finetune.var["gene_name"].tolist()
gene_ids = adata_finetune.var["gene_name"].tolist()

dataset = CellAnnotationDataset(expression_table=binned_expression_table,
                                masked_expression_table=None,
                                gene_ids=gene_ids,
                                labels=celltype_id_labels,
                                batch_strings=batch_id_labels,
                                x_umap=None,
                                in_feature=finetune_config.get_model_param('in_feature'),
                                vocab_file=data_params['vocab_path'])
train_indices, valid_indices = train_test_split(range(len(dataset)), test_size=0.1, shuffle=True)
logger.info(f"Train data size: {len(train_indices)}")
logger.info(f"Valid data size: {len(valid_indices)}")

# split train and valid dataset
train_set = torch.utils.data.Subset(dataset, train_indices)
valid_set = torch.utils.data.Subset(dataset, valid_indices)
train_loader = DataLoader(dataset=train_set, batch_size=data_params['batch_size'], 
                          shuffle=True, num_workers=data_params['n_workers'], drop_last=True)
valid_loader = DataLoader(dataset=valid_set, batch_size=data_params['batch_size'], 
                          shuffle=False,num_workers=data_params['n_workers'], drop_last=True)
logger.info(f"Train data loader size: {len(train_loader)}")
logger.info(f"Valid data loader size: {len(valid_loader)}")
logger.info(f'Number of cell type in training data: {len(np.unique(celltype_id_labels))}')

Tabula - INFO - Train data size: 9540
Tabula - INFO - Valid data size: 1060
Tabula - INFO - Train data loader size: 298
Tabula - INFO - Valid data loader size: 33
Tabula - INFO - Number of cell type in training data: 13


In [9]:
# set eval dataset
adata_test_raw, removed_gene_labels = check_vocab(adata_test_raw, data_params['vocab_path'])
logger.info(f"Test data removed gene number: {len(removed_gene_labels)} from vocab")
    
label_list = adata_test_raw.obs["celltype"].tolist()
label_list = [list(id2type.keys())[list(id2type.values()).index(celltype_str)] for celltype_str in label_list]
adata_test_raw.obs["label"] = label_list
    
eval_dataset = CellAnnotationDataset(
    expression_table=adata_test_raw.layers["X_binned"],
    masked_expression_table=None,
    gene_ids=adata_test_raw.var["gene_name"].tolist(),
    labels=adata_test_raw.obs["label"],
    batch_strings=adata_test_raw.obs["batch_id"].astype("category").cat.codes.values,
    x_umap=adata_test_temp.obsm["X_umap"],
    in_feature=finetune_config.get_model_param('in_feature'),
    vocab_file=data_params['vocab_path'])
test_loader = DataLoader(
    dataset=eval_dataset, 
    batch_size=data_params['batch_size'], 
    shuffle=False,
    num_workers=data_params['n_workers'], 
    drop_last=False)
logger.info(f"Test data size: {len(eval_dataset)}")
logger.info(f"Test data loader size: {len(test_loader)}")

Tabula - INFO - Test data removed gene number: 0 from vocab


  adata_test_raw.obs["label"] = label_list


Tabula - INFO - Test data size: 4218
Tabula - INFO - Test data loader size: 132


## Load pre-trained Tabula

In [None]:
if params['device'] != 'cpu' and not torch.cuda.is_available():
    logger.error(f'Cuda is not available, change device to cpu')
    params['device'] = 'cpu'
tabula_pl_model = get_pretrained_model(
    finetune_config=finetune_config,
    model_path=params['model_path'],
    device=params['device']
)

In [None]:
annotation_trainer = CellTypeAnnotation(
    config=finetune_config,
    tabula_model=tabula_pl_model,
    wandb_logger=wandb_logger,
    device=params['device'],
    batch_size=data_params['batch_size'],
    id2celltype=id2type,
    dataloaders={'train_loader': train_loader, 
                 'val_loader': valid_loader,
                 'test_loader': test_loader}
    )

annotation_trainer.finetune()