# Tutorial notebook 4: Cell Type Prediction

In this tutorial, we will demonstrate how to use a pretrained Cell2Sentence (C2S) model to perform cell type prediction on single-cell RNA sequencing datasets. Cell type prediction is a crucial step in single-cell analysis, allowing researchers to identify and classify different cell populations within a dataset. A C2S model reads a cell sentence (the cell's gene names ordered from most to least expressed) and generates the cell type label as natural-language text.

In this tutorial, you will:
1. Load an immune tissue single-cell dataset from Domínguez Conde et al. (preprocessed in tutorial notebook 0, two sample donors)
    - Citation: Domínguez Conde, C., et al. "Cross-tissue immune cell analysis reveals tissue-specific features in humans." Science 376.6594 (2022): eabl5197.
2. Load a pretrained C2S model that is capable of making cell type predictions.
3. Use the model to predict cell types based on the cell sentences derived from the dataset.

We will begin by importing the necessary libraries. These include Python's built-in libraries, third-party libraries for handling numerical computations, progress tracking, and specific libraries for single-cell RNA sequencing data and C2S operations.

In [1]:
# Python built-in libraries
import os
import pickle
import random
from collections import Counter

# Third-party libraries
import numpy as np
from tqdm import tqdm

# Single-cell libraries
import anndata
import scanpy as sc

# Cell2Sentence imports
import cell2sentence as cs
from cell2sentence.tasks import predict_cell_types_of_data

In [2]:
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)

# Load Data

Next, we will load the preprocessed dataset from tutorial 0. This dataset has already been filtered and normalized, so it is ready for transformation into cell sentences.

<font color='red'>If you are using your own dataset, please make sure you have completed the preprocessing steps in Tutorial 0 before running the following code.</font> Ensure that <font color='gold'>DATA_PATH</font> points to where your preprocessed data from tutorial 0 was saved.

In [3]:
DATA_PATH = "/data/yy/test/D099_processed.h5ad"

In [4]:
import pandas as pd

adata = anndata.read_h5ad(DATA_PATH)

# Gene table with one row per column of adata.X (gene name, Ensembl ID, column index)
gene_df = pd.read_csv("/home/scbjtfy/RVQ-Alpha/data_utils/gene_name_list_with_index.csv")
gene_df["Index"] = gene_df["Index"].astype(int)
assert len(gene_df) == adata.n_vars, "gene table must have one row per gene in adata"

# Sort by Index so rows line up with the columns of adata.X
gene_df = gene_df.sort_values("Index")

# Rename columns
gene_df = gene_df.rename(columns={
    "Gene_Name": "gene_name",
    "Gene_ID": "ensembl_id"
})

# Set var
adata.var = gene_df.set_index("Index")[["gene_name", "ensembl_id"]]

# Replace var_names with gene_name
adata.var_names = adata.var["gene_name"].values

# Ensure var_names are unique
adata.var_names_make_unique()

adata




AnnData object with n_obs × n_vars = 15806 × 36601
    obs: 'donor_id', 'batch', 'author_cell_type', 'disease', 'tissue'
    var: 'gene_name', 'ensembl_id'
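Replacing Ensembl IDs with gene symbols can create repeated `var_names`, because some symbols are shared by more than one Ensembl ID. `var_names_make_unique()` resolves this by appending numeric suffixes to the repeats. The sketch below mimics that idea in plain Python; it is a simplified illustration (the function `make_names_unique` and the gene list are hypothetical), not AnnData's implementation, which also handles edge cases such as a generated name colliding with an existing one.

```python
from collections import defaultdict


def make_names_unique(names, join="-"):
    """Simplified sketch: append -1, -2, ... to repeated names.

    The first occurrence keeps its original name. Unlike a full
    implementation, this sketch does not check whether a generated
    name (e.g. "GENE-1") already exists elsewhere in the list.
    """
    seen = defaultdict(int)  # how many times each name has appeared so far
    unique = []
    for name in names:
        count = seen[name]
        unique.append(name if count == 0 else f"{name}{join}{count}")
        seen[name] += 1
    return unique


genes = ["TBCE", "MATR3", "TBCE", "MATR3", "TBCE", "CD3E"]
print(make_names_unique(genes))
# ['TBCE', 'MATR3', 'TBCE-1', 'MATR3-1', 'TBCE-2', 'CD3E']
```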

In [5]:
# Keep only the metadata columns present in this dataset
adata.obs = adata.obs[["donor_id", "batch", "author_cell_type", "disease", "tissue"]]

In [5]:
adata.obs

| | donor_id | batch | author_cell_type | disease | tissue |
|---|---|---|---|---|---|
| N2_O1_AAACGAAAGTCGCCAC-1 | R1 | O1_R1 | Differentiating.Basal | normal | lung |
| N2_O1_AAACGAAGTTAGCTAC-1 | R1 | O1_R1 | Differentiating.Basal | normal | lung |
| N2_O1_AAACGAATCCTCTCTT-1 | R1 | O1_R1 | Differentiating.Basal | normal | lung |
| N2_O1_AAACGCTCACTGCACG-1 | R1 | O1_R1 | Basal | normal | lung |
| N2_O1_AAAGAACGTAGCTCGC-1 | R1 | O1_R1 | Suprabasal | normal | lung |
| ... | ... | ... | ... | ... | ... |
| N3_ALIEX_TTTGATCTCAATCGGT-1 | R2 | ALIEX_R2 | Secretory | normal | lung |
| N3_ALIEX_TTTGGAGCAAGCAGGT-1 | R2 | ALIEX_R2 | Ciliated | normal | lung |
| N3_ALIEX_TTTGGTTGTATACCCA-1 | R2 | ALIEX_R2 | Secretory | normal | lung |
| N3_ALIEX_TTTGGTTTCTGAATGC-1 | R2 | ALIEX_R2 | Secretory | normal | lung |


In [6]:
adata.var.head()

| | gene_name | ensembl_id |
|---|---|---|
| MIR1302-2HG | MIR1302-2HG | ENSG00000243485 |
| FAM138A | FAM138A | ENSG00000237613 |
| OR4F5 | OR4F5 | ENSG00000186092 |
| AL627309.1 | AL627309.1 | ENSG00000238009 |
| AL627309.3 | AL627309.3 | ENSG00000239945 |


In [None]:
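# sc.pl.umap needs a UMAP embedding in adata.obsm["X_umap"]; compute one if it is missing
if "X_umap" not in adata.obsm:
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)

sc.pl.umap(
    adata,
    color="author_cell_type",
    size=8,
    title="Human Immune Tissue UMAP",
)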

In [None]:
adata.X.max()

We expect log1p-transformed data (base 10), with a maximum value somewhere around 3 or 4. Make sure to start from processed and normalized data when doing the cell sentence conversion!
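The expected maximum follows from simple arithmetic. Assuming each cell was scaled to a total of 10,000 counts (a common scanpy `target_sum`; the exact tutorial 0 settings may differ) and then log1p-transformed in base 10, no value can exceed log10(1 + 10,000), which is about 4. A minimal sketch of that transform on a toy cell (the helper name is hypothetical):

```python
import math


def normalize_and_log10p(counts, target_sum=10_000):
    """Scale one cell's counts to target_sum, then apply log10(1 + x)."""
    total = sum(counts)
    scaled = [c * target_sum / total for c in counts]
    return [math.log10(1 + x) for x in scaled]


# A toy cell whose counts are dominated by a single gene
counts = [9999, 1, 0]
values = normalize_and_log10p(counts)
print([round(v, 3) for v in values])
# [4.0, 0.301, 0.0]
```

Even a gene that takes almost all of a cell's counts lands at about 4, so a maximum in the hundreds or thousands usually means raw counts were loaded instead of processed data.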

# Cell2Sentence Conversion

In this section, we will transform the AnnData object containing our single-cell dataset into a Cell2Sentence (C2S) dataset by calling functions of the CSData class. Full documentation for the CSData class is available on the C2S documentation page.
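Conceptually, a cell sentence lists a cell's expressed genes in descending order of expression, joined by the sentence delimiter. The sketch below shows that ranking for a single cell. It is a simplified illustration, not the `adata_to_arrow` implementation: the function name and toy values are hypothetical, and the real C2S code may break ties between equally expressed genes differently.

```python
def expression_to_cell_sentence(expression, gene_names, delimiter=" "):
    """Rank genes by expression (highest first) and join their names.

    Genes with zero expression are dropped. Ties keep their original
    column order here, because Python's sort is stable even with
    reverse=True.
    """
    pairs = [(g, x) for g, x in zip(gene_names, expression) if x > 0]
    pairs.sort(key=lambda p: p[1], reverse=True)
    return delimiter.join(g for g, _ in pairs)


genes = ["CD3E", "MALAT1", "GAPDH", "CD19"]
expr = [0.8, 3.1, 1.5, 0.0]  # toy log-normalized values for one cell
print(expression_to_cell_sentence(expr, genes))
# MALAT1 GAPDH CD3E
```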

In [7]:
adata_obs_cols_to_keep = ["author_cell_type", "disease", "tissue"]

In [8]:
# Convert the AnnData object into an arrow dataset of cell sentences plus a gene vocabulary
arrow_ds, vocabulary = cs.CSData.adata_to_arrow(
    adata=adata,
    random_state=SEED,
    sentence_delimiter=' ',
    label_col_names=adata_obs_cols_to_keep
)

WARN: more variables (36601) than observations (15806)... did you mean to transpose the object (e.g. adata.T)?
WARN: more variables (36601) than observations (15806), did you mean to transpose the object (e.g. adata.T)?
100%|██████████| 15806/15806 [01:05<00:00, 241.59it/s]


In [9]:
arrow_ds

Dataset({
    features: ['cell_name', 'cell_sentence', 'author_cell_type', 'disease', 'tissue'],
    num_rows: 15806
})

In [10]:
sample_idx = 1199
arrow_ds[sample_idx]

{'cell_name': 'N2_O1_GAGTTGTTCTCAAAGC-1',
 'cell_sentence': 'MALAT1 MT-CO2 MT-CO3 MT-CO1 MT-ATP6 KRT19 MT-CYB S100A6 S100A2 RPS2 S100A11 GAPDH RPS18 RPL13 EEF1A1 SERPINB1 KRT6A MT-ND4 RPL10 RPL7A ANXA2 RPS8 RPLP1 TMSB10 RPS14 SERPINB3 TPT1 RPL8 RPL17 RPL11 RPS6 RPS23 RPL13A KRT8 RPL12 RPL29 RPS12 RPS3A RPL15 RPS27A RPL26 S100A10 RPL41 RPS19 RPL3 RPL18 RPS3 RPS7 KRT18 S100A14 RPL5 RACK1 MIF RPL19 RPL6 RPS4X RPL28 GSTP1 RPL32 RPL18A CSTB MT-ND1 RPS24 MT-ND3 RPS15 RPS16 RPL9 S100A16 RPL23A RPLP0 ACTG1 RPSA RPS15A RPL30 MT-ND2 RPL10A RPL14 RPL27A RPS9 RPS28 RPL35A KRT17 RPL34 RPL37A TXN RPS13 TPI1 RPL24 RPL35 RPL7 LDHA RPS27 ENO1 FTH1 RPS5 SAT1 RPL21 NACA RPL39 RPL36 RPLP2 NEAT1 RPS11 FXYD3 RPL37 ACTB EEF1G EIF1 KRT13 RPS25 RPS29 PPIA HINT1 TMSB4X RPS17 EEF1B2 CALM1 FAU PTMA H3F3B H3F3A PERP COX4I1 UBA52 RPL22 NPM1 DHRS9 KRT15 RPS21 NME2 PPDPF EIF4A1 PKM PHLDA2 HSP90AA1 ATP5MC3 BTF3 RPL38 FGFBP1 GAS5 TKT SNHG29 KRT7 RPL27 RPS20 RPL36A RPL23 MT-ND5 RPL4 IGFBP2 SLC25A6 CAST MIR31HG SERF2 ADI

In [11]:
arrow_ds['author_cell_type']

Column(['Differentiating.Basal', 'Differentiating.Basal', 'Differentiating.Basal', 'Basal', 'Suprabasal'])

This time, we will wait to create our CSData object until after we load our C2S model. This is because we saved the indices of the train, val, and test set cells alongside the model checkpoint, which lets us select only test set cells for inference.

# Load C2S Model

Now, we will load the C2S model we will use for cell type annotation. For this tutorial, this is the last checkpoint saved during the training session in <font color="red">tutorial notebook 3</font>, where we finetuned a cell type prediction model on our immune tissue dataset. When loading it, we specify the same save_dir that we used during training.
- <font color="red">Note:</font> If you are using your own data for this tutorial, make sure to switch to the model checkpoint you saved in tutorial notebook 3.
- If you want to annotate cell types without finetuning your own C2S model, then tutorial notebook 6 demonstrates how to load the C2S-Pythia-410M cell type prediction foundation model and use it to predict cell types without any finetuning.

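We can define our CSModel object with our pretrained cell type prediction model as follows, specifying the same save_dir as we used in tutorial 3. In this run, cell_type_prediction_model_path points to a local Hugging Face cache snapshot of the vandijklab C2S-Pythia-410m cell type prediction model; replace it with the path to your own checkpoint from tutorial 3 if you finetuned one: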

In [15]:
# Define CSModel object
cell_type_prediction_model_path = "/data/Mamba/Data/hf_cache/hub/models--vandijklab--C2S-Pythia-410m-cell-type-prediction/snapshots/5a4dc3b949b5868ca63752b37bc22e3b0216e435"
save_dir = "./"
save_name = "cell_type_pred_pythia_410M_inference"
csmodel = cs.CSModel(
    model_name_or_path=cell_type_prediction_model_path,
    save_dir=save_dir,
    save_name=save_name
)

Using device: cuda
[2025-09-29 15:53:57,236] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)



[2025-09-29 15:53:57,981] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False


We will also load the data split indices saved alongside the C2S model checkpoint, so that we know which cells were part of the training and validation sets. We will run inference on unseen test set cells, which make up 10% of the original data.
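As an illustration of what saved split indices look like, here is a minimal sketch of a seeded 80/10/10 split that is pickled next to a checkpoint and reloaded. This tutorial does not show the file name or dictionary keys used in tutorial 3, so `split_indices.pkl`, the keys `"train"`, `"val"`, `"test"`, and the helper `split_indices` are all hypothetical; only the 10% test fraction comes from the text above.

```python
import os
import pickle
import random
import tempfile


def split_indices(n_cells, seed=1234, train_frac=0.8, val_frac=0.1):
    """Shuffle cell indices with a fixed seed and cut them into train/val/test."""
    indices = list(range(n_cells))
    random.Random(seed).shuffle(indices)  # same seed -> same split every run
    n_train = int(train_frac * n_cells)
    n_val = int(val_frac * n_cells)
    return {
        "train": sorted(indices[:n_train]),
        "val": sorted(indices[n_train:n_train + n_val]),
        "test": sorted(indices[n_train + n_val:]),  # remainder, roughly 10%
    }


splits = split_indices(15806)

# Save next to a (temporary) checkpoint directory and reload it later
with tempfile.TemporaryDirectory() as ckpt_dir:
    path = os.path.join(ckpt_dir, "split_indices.pkl")
    with open(path, "wb") as f:
        pickle.dump(splits, f)
    with open(path, "rb") as f:
        reloaded = pickle.load(f)

print({name: len(idx) for name, idx in reloaded.items()})
```

With real split indices in hand, `arrow_ds.select(reloaded["test"])` would pick out exactly the held-out cells, since Hugging Face `Dataset.select` accepts a list of row indices.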

In [16]:
base_path = os.path.dirname(cell_type_prediction_model_path)  # parent directory of the checkpoint
print(cell_type_prediction_model_path)
print(base_path)

/data/Mamba/Data/hf_cache/hub/models--vandijklab--C2S-Pythia-410m-cell-type-prediction/snapshots/5a4dc3b949b5868ca63752b37bc22e3b0216e435
/data/Mamba/Data/hf_cache/hub/models--vandijklab--C2S-Pythia-410m-cell-type-prediction/snapshots


Select the evaluation cells from the full arrow dataset:

In [17]:
arrow_ds

Dataset({
    features: ['cell_name', 'cell_sentence', 'author_cell_type', 'disease', 'tissue'],
    num_rows: 15806
})

In [18]:
# First 1,000 cells in dataset order (not a random sample or the saved test split)
test_ds = arrow_ds.select(range(1000))
test_ds

Dataset({
    features: ['cell_name', 'cell_sentence', 'author_cell_type', 'disease', 'tissue'],
    num_rows: 1000
})

Now, we will create our CSData object using only the test set cells:

In [19]:
c2s_save_dir = "./"  # C2S dataset will be saved into this directory
c2s_save_name = "dominguez_immune_tissue_tutorial4"  # This will be the name of our C2S dataset on disk

In [20]:
csdata = cs.CSData.csdata_from_arrow(
    arrow_dataset=test_ds, 
    vocabulary=vocabulary,
    save_dir=c2s_save_dir,
    save_name=c2s_save_name,
    dataset_backend="arrow"
)

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
print(csdata)

# Predict cell types

Now that we have loaded our finetuned cell type prediction model and have our test set, we will run cell type prediction inference with our C2S model. We can use the function predict_cell_types_of_data() from tasks.py, which takes a CSModel object and applies it to predict cell types for the cells in a CSData object.
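Because a cell sentence is ordered from most to least expressed gene, the `n_genes=200` argument presumably restricts the model input to the 200 top-ranked genes of each cell. That is an assumption about `predict_cell_types_of_data`, whose internals this tutorial does not show. A minimal sketch of such truncation (the helper name is hypothetical):

```python
def truncate_cell_sentence(cell_sentence, n_genes, delimiter=" "):
    """Keep only the n_genes highest-ranked genes of a cell sentence."""
    genes = cell_sentence.split(delimiter)
    return delimiter.join(genes[:n_genes])


sentence = "MALAT1 MT-CO2 MT-CO3 MT-CO1 MT-ATP6 KRT19"
print(truncate_cell_sentence(sentence, 3))
# MALAT1 MT-CO2 MT-CO3
```

Truncating from the front keeps the most informative, highest-expressed genes while bounding prompt length.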

In [None]:
predicted_cell_types = predict_cell_types_of_data(
    csdata=csdata,
    csmodel=csmodel,
    n_genes=200
)

In [None]:
len(predicted_cell_types)

In [None]:
predicted_cell_types[:3]

In [None]:
test_ds

In [27]:
total_correct = 0
for model_pred, gt_label in zip(predicted_cell_types, test_ds["author_cell_type"]):
    # C2S might add whitespace or a period at the end of the cell type, which we remove
    model_pred = model_pred.strip()
    if model_pred.endswith("."):
        model_pred = model_pred[:-1]

    if model_pred == gt_label:
        total_correct += 1

accuracy = total_correct / len(predicted_cell_types)

In [None]:
print("Accuracy:", accuracy)

In [None]:
for idx in range(0, 100, 10):
    print(f"Model pred: {predicted_cell_types[idx]}, GT label: {test_ds[idx]['author_cell_type']}")

We can see that our model achieves high accuracy, correctly predicting the cell type of unseen cells from the immune tissue data 83.4% of the time! The model learned to predict cell type annotations in natural language effectively from a short finetuning period on the new data.
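A single accuracy number can hide weak classes, especially when one cell type dominates the evaluation cells. The sketch below breaks accuracy down per ground-truth cell type and computes a majority-class baseline for comparison. The helpers `normalize_label`, `per_class_accuracy`, and `majority_baseline` are hypothetical, not part of cell2sentence, and the labels are toy values; in the notebook you would pass `predicted_cell_types` and `test_ds["author_cell_type"]`.

```python
from collections import Counter


def normalize_label(label):
    """Strip whitespace and a single trailing period from a model prediction."""
    label = label.strip()
    return label[:-1] if label.endswith(".") else label


def per_class_accuracy(predictions, labels):
    """Return {cell_type: (n_correct, n_total)} keyed by ground-truth label."""
    totals = Counter(labels)
    correct = Counter(
        gt for pred, gt in zip(predictions, labels) if normalize_label(pred) == gt
    )
    return {ct: (correct[ct], totals[ct]) for ct in totals}


def majority_baseline(labels):
    """Accuracy of always predicting the most common cell type."""
    _, count = Counter(labels).most_common(1)[0]
    return count / len(labels)


preds = ["Basal.", "Secretory", "Basal", "Ciliated."]
gts = ["Basal", "Secretory", "Suprabasal", "Ciliated"]
for cell_type, (n_ok, n) in sorted(per_class_accuracy(preds, gts).items()):
    print(f"{cell_type}: {n_ok}/{n}")
print("Majority-class baseline:", majority_baseline(gts))
```

If the overall accuracy is only modestly above the majority-class baseline, the model may be leaning on the most frequent label rather than distinguishing rarer cell types.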