### Zero-shot cell type annotation
Given the gene expression profiles of cells and textual descriptions of the candidate cell types, LangCell can annotate each cell in a zero-shot manner, that is, without training on labeled examples from the dataset.
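To make the idea concrete, the following is a minimal conceptual sketch of zero-shot annotation, assuming that cells and class descriptions are embedded in a shared vector space and that the most similar description wins. The function names and the 3-dimensional vectors are hypothetical stand-ins; this does not reproduce LangCell's actual scoring, which is handled inside the pipeline used in the cells that follow.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def zero_shot_annotate(cell_embedding, class_embeddings):
    """Return the best-matching class name and the score of every class."""
    scores = {name: cosine(cell_embedding, emb)
              for name, emb in class_embeddings.items()}
    best = max(scores, key=scores.get)
    return best, scores

# Hypothetical embeddings (assumption): real ones would come from the cell and text encoders
class_embeddings = {
    "B cells": [1.0, 0.0, 0.0],
    "NK cells": [0.0, 1.0, 0.0],
    "CD4 T cells": [0.0, 0.0, 1.0],
}
cell_embedding = [0.1, 0.9, 0.2]

predicted, scores = zero_shot_annotate(cell_embedding, class_embeddings)
print(predicted)  # NK cells
```

Because the class set is defined only by its text descriptions, adding or removing a candidate type in this sketch amounts to editing the dictionary; no retraining step is involved.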

In [None]:
import json
import os
import sys

# Make the OpenBioMed package importable and run from its root so relative paths resolve
sys.path.append('/path/to/OpenBioMed')
os.chdir('/path/to/OpenBioMed')

from datasets import load_from_disk
from sklearn.metrics import classification_report

from open_biomed.core.pipeline import InferencePipeline
from open_biomed.data import Cell, Text



In [None]:
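# Load the model
# Path of the model config, kept for reference (it is not passed to the pipeline below)
cfg_path = "./configs/model/langcell.yaml"
# Adjust `device` to a GPU that is available on your machine
pipeline = InferencePipeline(model='langcell', task='cell_annotation', device='cuda:2')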

05/23/2025 18:35:11 - INFO - root - The config of this process is:
{
    "model": {
        "name": "langcell",
        "cell_model": "./checkpoints/langcell/cell_bert",
        "cell_proj": "./checkpoints/langcell/cell_proj.bin",
        "text_tokenizer": "./checkpoints/langcell/pubmedbert-base",
        "text_model": "./checkpoints/langcell/text_bert",
        "text_proj": "./checkpoints/langcell/text_proj.bin",
        "ctm_head": "./checkpoints/langcell/ctm_head.bin"
    },
    "task": "cell_annotation",
    "model_ckpt": "",
    "device": "cuda:2",
    "logging_level": "info"
}
Some weights of BertModel were not initialized from the model checkpoint at ./checkpoints/langcell/cell_bert and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Load the dataset and the textual description of each cell type
# Download data: https://drive.google.com/drive/folders/1cuhVG9v0YoAnjW-t_WMpQQguajumCBTp
dataset = load_from_disk('/path/to/pbmc10k.dataset')
with open('/path/to/type2text.json') as f:
    type2text = json.load(f)
for cell_type, description in type2text.items():
    print(cell_type, '----', description, '\n')

B cells ---- cell type: b cell. a lymphocyte of b lineage that is capable of b cell mediated immunity.;  

CD8 T cells ---- cell type: cd8-positive, alpha-beta t cell. a t cell expressing an alpha-beta t cell receptor and the cd8 coreceptor.;  

CD14+ Monocytes ---- cell type: cd14-positive monocyte. a monocyte that expresses cd14 and is negative for the lineage markers cd3, cd19, and cd20.;  

Dendritic Cells ---- cell type: dendritic cell. a cell of hematopoietic origin, typically resident in particular tissues, specialized in the uptake, processing, and transport of antigens to lymph nodes for the purpose of stimulating an immune response via t cell activation. these cells are lineage negative (cd3-negative, cd19-negative, cd34-negative, and cd56-negative).;  

NK cells ---- cell type: natural killer cell. a lymphocyte that can spontaneously kill a variety of target cells without prior antigenic activation via germline encoded activation receptors and also regulate immune responses 

In [4]:
# Randomly sample 2000 cells (fixed seed for reproducibility)
dataset = dataset.shuffle(seed=42).select(range(2000))
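Since the subsample determines the per-class support reported later, it can help to check how many cells of each type were drawn. The sketch below is a minimal, self-contained illustration on a toy label list; `label_distribution` is a hypothetical helper, not part of OpenBioMed.

```python
from collections import Counter

def label_distribution(labels):
    """Map each label to (count, fraction), most frequent first."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {name: (n, n / total) for name, n in counts.most_common()}

# Toy labels (assumption): stand-ins for the dataset's `str_labels` column
toy_labels = ["CD4 T cells", "B cells", "CD4 T cells", "NK cells"]

for name, (n, frac) in label_distribution(toy_labels).items():
    print(f"{name}: {n} ({frac:.0%})")
```

In the notebook, passing `dataset['str_labels']` instead of `toy_labels` should give the distribution of the 2000 sampled cells.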

In [None]:
# Organize data into the format expected by the pipeline
# Note: `input` shadows the Python builtin; the name is kept because the next cell unpacks it
texts = []
type2label = {}
labels = []
for cell_type, description in type2text.items():
    texts.append(Text.from_str(description))
    type2label[cell_type] = len(texts) - 1
input = {'cell': [], 'class_texts': [], 'label': []}
for data in dataset:
    label = type2label[data['str_labels']]
    input['cell'].append(Cell.from_sequence(data['input_ids']))
    input['class_texts'].append(texts)  # every cell is scored against the same candidate texts
    input['label'].append(label)
    labels.append(label)

In [None]:
# Predict the cell type of each cell using the model
preds, _ = pipeline.run(batch_size=1, **input)
preds = [p.item() for p in preds]

Inference Steps: 100%|██████████| 2000/2000 [18:12<00:00,  1.83it/s]


In [None]:
# Analyze the results: per-class precision, recall and F1 over the sampled cells
print(classification_report(labels, preds, labels=list(range(len(type2text))), target_names=list(type2text.keys())))

                   precision    recall  f1-score   support

          B cells       1.00      0.98      0.99       279
      CD8 T cells       0.59      0.95      0.73       260
  CD14+ Monocytes       0.96      0.99      0.98       387
  Dendritic Cells       1.00      0.82      0.90        67
         NK cells       0.82      0.98      0.90        57
   Megakaryocytes       0.82      0.90      0.86        20
FCGR3A+ Monocytes       1.00      0.82      0.90        66
      CD4 T cells       0.98      0.81      0.89       864

         accuracy                           0.89      2000
        macro avg       0.90      0.91      0.89      2000
     weighted avg       0.93      0.89      0.90      2000
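The report shows a comparatively low precision for CD8 T cells (0.59) alongside a lower recall for CD4 T cells (0.81), which suggests, but does not by itself prove, that some CD4 T cells are predicted as CD8 T cells. A minimal sketch for checking which class pairs are confused most often is given below; the helper names and the toy labels are hypothetical and only illustrate the counting.

```python
from collections import Counter

def confusion_counts(y_true, y_pred, names):
    """Count (true class, predicted class) pairs, keyed by class name."""
    counts = Counter(zip(y_true, y_pred))
    return {(names[t], names[p]): n for (t, p), n in counts.items()}

def top_confusions(y_true, y_pred, names, k=3):
    """Return the k most frequent misclassifications as ((true, pred), count)."""
    pairs = confusion_counts(y_true, y_pred, names)
    errors = {pair: n for pair, n in pairs.items() if pair[0] != pair[1]}
    return sorted(errors.items(), key=lambda item: -item[1])[:k]

# Toy data (assumption): integer labels index into `names`
names = ["B cells", "CD8 T cells", "CD4 T cells"]
toy_labels = [0, 1, 2, 2, 2, 1, 0]
toy_preds = [0, 1, 1, 1, 2, 1, 0]

for (true_name, pred_name), n in top_confusions(toy_labels, toy_preds, names):
    print(f"{true_name} -> {pred_name}: {n}")  # CD4 T cells -> CD8 T cells: 2
```

In the notebook, the same call would take `labels`, `preds`, and `list(type2text.keys())` in place of the toy values.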

