# Zero-shot integration tutorial with scGPT

## Introduction

This tutorial covers zero-shot integration with the continual pre-trained scGPT model. The workflow applies to scRNA-seq datasets and requires no fine-tuning (or any other extensive training) of scGPT.

Continual pre-trained scGPT (scGPT_CP) inherits the pre-trained scGPT whole-human model checkpoint and is further supervised with cell type labels (from the [Tabula Sapiens](https://tabula-sapiens-portal.ds.czbiohub.org/) dataset) during the continual pre-training stage. We observed that the scGPT_CP model achieves comparable or better zero-shot performance than the original checkpoint on tasks that rely on cell embeddings, especially on datasets with observable technical batch effects.

This tutorial will show how to use the latent space of scGPT to integrate scRNA-seq datasets. We use the `scGPT_CP` model to provide embeddings out of the box. You may download it from [here](https://drive.google.com/drive/folders/1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB).

We will use the [scIB](https://www.nature.com/articles/s41592-021-01336-8) pancreas dataset as an example. This dataset is publicly available [here](https://figshare.com/ndownloader/files/24539828). Place the dataset in the top-level `data` directory (`../../data` relative to this notebook).
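
Before running the rest of the notebook, it can help to confirm that the downloads landed where the later cells expect them. The sketch below is a hypothetical helper, not part of scGPT. It assumes the checkpoint folder contains `args.json`, `best_model.pt`, and `vocab.json`, and that the dataset is saved as `human_pancreas_norm_complexBatch.h5ad`, the file name used in the Colab cell of this tutorial. Adjust the names if your copies differ.

```python
from pathlib import Path

# Assumed checkpoint layout; change these names if your download differs.
EXPECTED_MODEL_FILES = ("args.json", "best_model.pt", "vocab.json")


def missing_files(model_dir, data_file):
    """Return the expected paths that do not exist yet (hypothetical helper)."""
    model_dir = Path(model_dir)
    expected = [model_dir / name for name in EXPECTED_MODEL_FILES]
    expected.append(Path(data_file))
    return [path for path in expected if not path.is_file()]


# Paths are relative to this notebook, as in the cells of this tutorial.
for path in missing_files(
    "../../save/scGPT_CP",
    "../../data/human_pancreas_norm_complexBatch.h5ad",
):
    print(f"missing: {path}")
```

If nothing is printed, every expected file was found.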


The zero-shot integration workflow is as follows:

 1. [Load and pre-process the dataset](#prepare-the-datasets)
    
 2. [Generate scGPT embeddings for each cell](#generate-the-cell-embeddings)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bowang-lab/scGPT/blob/main/tutorials/zero-shot/Tutorial_ZeroShot_Integration_Continual_Pretraining.ipynb)

In [None]:
# Specifically for Google Colab, install dependencies and download data

import os
import sys

if "google.colab" in sys.modules:
    print("Running on Google Colab")
    print("Installing dependencies...")
    !pip install -U scgpt "torch<=2.2.2" "numpy<2" "umap-learn<0.5.7"
    # the optional flash-attn dependency is skipped on Colab
    !pip install wandb louvain

    # NOTE: MAY NEED TO RESTART RUNTIME AFTER THE INSTALLATION

    print("Downloading data and model ckpt...")
    !pip install -q -U gdown
    import gdown

    data_dir = "../../data"
    # makedirs also creates missing parent directories and is a no-op if the path exists
    os.makedirs(data_dir, exist_ok=True)
    if not os.path.exists(os.path.join(data_dir, "human_pancreas_norm_complexBatch.h5ad")):
        !wget --content-disposition https://figshare.com/ndownloader/files/24539828 -O $data_dir/human_pancreas_norm_complexBatch.h5ad

    print("Downloading model ckpt...")
    model_dir = "../../save/scGPT_CP"
    if not os.path.exists(model_dir):
        !mkdir -p $model_dir
        gdown.download_folder(
            "https://drive.google.com/drive/folders/1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB?usp=sharing",
            output=model_dir,
        )
    model_dir = "../../save/scGPT_human"
    if not os.path.exists(model_dir):
        !mkdir -p $model_dir
        gdown.download_folder(
            "https://drive.google.com/drive/folders/1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
            output=model_dir,
        )

Running on Google Colab
Installing dependencies...
Collecting scgpt
  Downloading scgpt-0.2.4-py3-none-any.whl.metadata (10.0 kB)
Collecting torch<=2.2.2
  Downloading torch-2.2.2-cp312-cp312-manylinux1_x86_64.whl.metadata (25 kB)
Collecting numpy<2
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting umap-learn<0.5.7
  Downloading umap_learn-0.5.6-py3-none-any.whl.metadata (21 kB)
Collecting cell-gears<0.0.3 (from scgpt)
  Downloading cell-gears-0.0.2.tar.gz (25 kB)
  Preparing metadata (setup.py) ... done
Collecting datasets<3.0.0,>=2.3.0 (from scgpt)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting leidenalg>=0.8.10 (from scgpt)
  Downloading leidenalg-0.10.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting orbax<0.1.8 (from s

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Import scGPT and dependencies

In [4]:
from pathlib import Path
import warnings

import scanpy as sc
import scib
import numpy as np
import sys

sys.path.insert(0, "../")

import scgpt as scg
import matplotlib.pyplot as plt

plt.style.use('default')  # style.context() only takes effect inside a `with` block
warnings.simplefilter("ignore", ResourceWarning)

model_dir = Path("../../save/scGPT_CP")



## Set up evaluation function

We set up the evaluation function. We mainly compare integration performance on two summary scores, avgBIO and avgBATCH. Refer to our manuscript for more details.

In [5]:
def scib_eval(adata, batch_key, cell_type_key, embed_key):
    """Calculate the metrics for integration results."""
    results = scib.metrics.metrics(
        adata,
        adata_int=adata,
        batch_key=batch_key,
        label_key=cell_type_key,
        embed=embed_key,
        isolated_labels_asw_=False,
        silhouette_=True,
        hvg_score_=False,
        graph_conn_=True,
        pcr_=True,
        isolated_labels_f1_=False,
        trajectory_=False,
        nmi_=True,  # computed on the clustering; biased toward the best-matching clusters
        ari_=True,  # computed on the clustering; biased toward the best-matching clusters
        cell_cycle_=False,
        kBET_=False,  # kBET sometimes returns NaN; needs further examination
        ilisi_=False,
        clisi_=False,
    )
    result_dict = results[0].to_dict()

    # compute avgBIO metrics
    result_dict["avg_bio"] = np.mean(
        [
            result_dict["NMI_cluster/label"],
            result_dict["ARI_cluster/label"],
            result_dict["ASW_label"],
        ]
    )

    # compute avgBATCH metrics
    result_dict["avg_batch"] = np.mean(
        [
            result_dict["graph_conn"],
            result_dict["ASW_label/batch"],
        ]
    )

    result_dict = {k: v for k, v in result_dict.items() if not np.isnan(v)}

    return result_dict
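
To make the aggregation step concrete, here is a minimal standalone sketch of how the two summary scores are formed from the individual metrics and how NaN entries are dropped. It mirrors the arithmetic in `scib_eval` using only the standard library. The `summarize` helper is hypothetical, and the input values are placeholders chosen to keep the arithmetic easy to follow, not results from any dataset.

```python
import math
from statistics import mean

# Metric names match the keys used in scib_eval.
BIO_KEYS = ("NMI_cluster/label", "ARI_cluster/label", "ASW_label")
BATCH_KEYS = ("graph_conn", "ASW_label/batch")


def summarize(metrics):
    """Add avg_bio and avg_batch, then drop entries whose value is NaN."""
    out = dict(metrics)
    out["avg_bio"] = mean([out[key] for key in BIO_KEYS])
    out["avg_batch"] = mean([out[key] for key in BATCH_KEYS])
    return {key: value for key, value in out.items() if not math.isnan(value)}


# Placeholder values for illustration only; not measured results.
example = {
    "NMI_cluster/label": 0.5,
    "ARI_cluster/label": 0.25,
    "ASW_label": 0.75,
    "graph_conn": 0.5,
    "ASW_label/batch": 1.0,
    "kBET": float("nan"),  # stands in for a metric that was not computed
}
summary = summarize(example)
print(summary["avg_bio"], summary["avg_batch"])  # 0.5 0.75
```

Both scores are plain unweighted means, so each component metric contributes equally to its summary.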

## Prepare the datasets

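Load the dataset to integrate. The standard workflow loads the Pancreas dataset (download it from [here](https://figshare.com/ndownloader/files/24539828)) and sets the column storing gene names, the batch key, and the cell type key (optional, used only for evaluation). The cell below instead reads an expression matrix from a CSV file on the mounted Google Drive, with cells as rows and genes as columns, and wraps it in an AnnData object whose `var` index holds the gene names.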

In [8]:
import pandas as pd
from scipy.sparse import csr_matrix

# Rows are cells and columns are genes
expr_df = pd.read_csv('/content/drive/MyDrive/expression.csv', index_col=0)
print(expr_df.shape)

# Store the matrix in sparse form and use the gene names as the var index
adata = sc.AnnData(
    X=csr_matrix(expr_df.values),
    var=pd.DataFrame(index=expr_df.columns.astype(str)),
)


(3639, 13859)




## Generate the cell embeddings

Now we generate the cell embeddings for the dataset using the `embed_data` function, which computes an embedding for each cell with the given scGPT model. Passing `gene_col='index'` tells it to read gene names from the `var` index. The extracted embeddings are stored in the `X_scGPT` field of `obsm` in the returned AnnData.

In [None]:
embed_adata = scg.tasks.embed_data(
    adata,
    model_dir,
    gene_col='index',
    batch_size=64,
)

# Save the embedding matrix (one row per cell) to Google Drive for later use
embedding_matrix = embed_adata.obsm["X_scGPT"]
np.save('/content/drive/MyDrive/embedding.npy', embedding_matrix)

scGPT - INFO - match 13116/13859 genes in vocabulary of size 60697.


Embedding cells:   4%|▎         | 2/57 [06:38<3:02:02, 198.60s/it]
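
The log line above (`match 13116/13859 genes in vocabulary of size 60697`) reports how many gene names in the dataset were found in the model's vocabulary. The sketch below illustrates that kind of check on toy data. It is not scGPT's implementation: the `count_vocab_matches` helper, the toy vocabulary, and the gene list are all made up for illustration.

```python
def count_vocab_matches(gene_names, vocab):
    """Return (matched, total) for gene names present in a token vocabulary."""
    matched = sum(1 for gene in gene_names if gene in vocab)
    return matched, len(gene_names)


# Toy vocabulary and gene list for illustration only; a real vocabulary
# would come from the model checkpoint.
toy_vocab = {"<pad>": 0, "INS": 1, "GCG": 2, "SST": 3}
toy_genes = ["INS", "GCG", "MY_UNKNOWN_GENE"]

matched, total = count_vocab_matches(toy_genes, toy_vocab)
print(f"match {matched}/{total} genes in vocabulary of size {len(toy_vocab)}.")
```

A low match rate usually points to a naming mismatch, for example Ensembl IDs in the data where the vocabulary uses gene symbols, and is worth checking before interpreting the embeddings.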