Geneformer for Cell Annotation Application

In [1]:
import os
import sys

# Put the local Geneformer checkout on the import path so that the later
# `from geneformer import ...` resolves.
sys.path.append("/home/amonell/Geneformer")

# GPU indices exposed to this kernel; must be set before CUDA is initialised.
GPU_NUMBER = [0]
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": ",".join(map(str, GPU_NUMBER)),
        "NCCL_DEBUG": "INFO",
    }
)

In [2]:
import torch

# Report whether PyTorch can see a CUDA device in this kernel.
print(
    "CUDA is available"
    if torch.cuda.is_available()
    else "CUDA is not available"
)

CUDA is available


In [3]:
# imports
from collections import Counter
from tqdm.notebook import tqdm
import datetime
import pickle
import subprocess
import seaborn as sns

sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification

import numpy as np
import scanpy as sc
import glob
import os
import pandas as pd
import matplotlib.pyplot as plt

2024-09-09 13:43:41.841600: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-09 13:43:41.841638: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-09 13:43:41.842873: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Set the path to the pretrained Geneformer model

In [None]:
# Model directory handed to `from_pretrained` and reused as the Trainer
# `output_dir` further down. The path is relative, so it is presumably resolved
# against the kernel's working directory -- confirm before running elsewhere.
geneformer_pretrained_path = (
    "Geneformer/"
    "240318_geneformer_CellClassifier_SI2_SI_L2048_B15_LR5e-05_LSlinear_WU500_E3_Oadamw_F0"
)

Load a trained Geneformer model

In [4]:
def compute_metrics(pred):
    """Score a prediction object with accuracy and macro-averaged F1.

    Args:
        pred: object exposing ``label_ids`` (reference labels) and
            ``predictions`` (per-class scores; the arg-max over the last
            axis is taken as the predicted class).

    Returns:
        dict with the keys ``"accuracy"`` and ``"macro_f1"``.
    """
    true_labels = pred.label_ids
    predicted_labels = pred.predictions.argmax(-1)
    # Both metrics come from scikit-learn; F1 is averaged per class ("macro").
    return {
        "accuracy": accuracy_score(true_labels, predicted_labels),
        "macro_f1": f1_score(true_labels, predicted_labels, average="macro"),
    }

In [5]:
# --- model parameters -------------------------------------------------------
max_input_size = 2048  # max input size (2**11)

# --- training parameters ----------------------------------------------------
max_lr = 5e-5  # max learning rate
freeze_layers = 0  # how many pretrained layers to freeze
num_gpus = 1  # number of GPUs
num_proc = 16  # number of CPU cores
geneformer_batch_size = 15  # batch size for training and eval
lr_schedule_fn = "linear"  # learning-rate schedule
warmup_steps = 500  # warmup steps
epochs = 3  # number of epochs
optimizer = "adamw"  # optimizer
logging_steps = 1  # log every N steps

Predict with Geneformer and save the predictions

In [6]:
from transformers import AutoModelForSequenceClassification

# Keyword arguments for Hugging Face `TrainingArguments`, built from the
# hyper-parameters defined in the configuration cell. Per-device train and eval
# batch sizes share the same value, and `output_dir` points at the model
# directory itself.
training_args = dict(
    learning_rate=max_lr,
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=logging_steps,
    group_by_length=True,
    length_column_name="length",
    disable_tqdm=False,
    lr_scheduler_type=lr_schedule_fn,
    warmup_steps=warmup_steps,
    weight_decay=0.001,
    per_device_train_batch_size=geneformer_batch_size,
    per_device_eval_batch_size=geneformer_batch_size,
    num_train_epochs=epochs,
    load_best_model_at_end=True,
    output_dir=geneformer_pretrained_path,
)
training_args_init = TrainingArguments(**training_args)

# Load the sequence-classification model stored at the configured path.
model = AutoModelForSequenceClassification.from_pretrained(
    geneformer_pretrained_path
)

In [7]:
# Organ identifier; NOTE(review): not referenced by any other cell visible in
# this notebook -- confirm whether it is still needed.
organ = 'SI'

# Wrap the loaded model in a Trainer. No train/eval datasets are attached:
# in this notebook the trainer is only used through `trainer.predict(...)`.
# The Geneformer collator is used to batch the tokenized cells.
trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=None,
    eval_dataset=None,
    compute_metrics=compute_metrics,
)


In [23]:
# Cell-type names indexed by predicted class id: position i in this list is the
# name assigned to class i when predictions are decoded. Order must not change.
target_names = [
    "Goblet",
    "Enterocyte_1",
    "Monocyte",
    "Cd8_T-Cell_P14",
    "Cd8_T-Cell_aa+",
    "Cd8_T-Cell_ab+",
    "MAIT",
    "T-Cell gd",
    "Enterocyte_2",
    "Macrophage",
    "ILC",
    "Cd4_T-Cell",
    "B-Cell",
    "Enteroendocrine",
    "T-Cell",
    "cDC1",
    "Early_Enterocyte",
    "Enterocyte_3",
    "Myofibroblast",
    "Eosinophil",
    "DC2",
    "Lymphatic",
    "Tuft",
    "NK-Cell",
    "Resting Fibroblast",
    "Fibroblast",
    "Transit_Amplifying",
    # NOTE(review): the trailing space is part of the label as written;
    # presumably it matches the training labels -- confirm before trimming.
    "Fibroblast_Pdgfrb+ ",
    "Vascular Endothelial",
    "Contaminated DCs",
    "ISC",
    "Paneth",
    "Neuron",
    "Fibroblast_Ncam1",
    "Fibroblast_Pdgfra+",
    "Complement_Fibroblast",
    "MegakaryocytePlatelet",
]

In [14]:


# Lookup table from predicted class index to cell-type name. Built once,
# because it is the same for every sample.
label_lookup = np.array(target_names)

# Annotate every segmentation output directory with Geneformer predictions.
# `glob` returns matches in arbitrary order, so sort for a reproducible run.
for filename in sorted(
    glob.glob("/mnt/sata1/Analysis_Alex/uninfected/segmentation_SI*")
):
    outname = os.path.basename(filename)

    # Location of the tokenized dataset for this sample. The file is named
    # "train_<sample>.dataset" on disk, but it is only used for prediction here.
    tokenized_path = (
        "/mnt/sata1/Analysis_Alex/Geneformer/loom_"
        + outname
        + "/tokenized/train_"
        + outname
        + ".dataset"
    )

    test_dataset = load_from_disk(tokenized_path)
    # Add an all-zero "label" column; the values are placeholders and are not
    # used when decoding the predictions below. Presumably the classification
    # collator / Trainer expects the column to exist -- TODO confirm.
    test_dataset = test_dataset.add_column("label", [0] * test_dataset.num_rows)

    predictions_test = trainer.predict(test_dataset)
    logits = np.asarray(predictions_test.predictions)

    # Fail loudly if the classifier head and the label list disagree; indexing
    # `label_lookup` with a mismatched head would silently mislabel cells.
    if logits.ndim != 2 or logits.shape[1] != len(label_lookup):
        raise ValueError(
            f"{outname}: prediction array has shape {logits.shape}, expected "
            f"(n_cells, {len(label_lookup)}) to match target_names"
        )

    ad = sc.read(
        os.path.join(
            filename,
            "adatas/06_reference_mapped.h5ad",
        )
    )

    # Predictions are assigned to cells by position, so the counts must agree.
    if ad.n_obs != logits.shape[0]:
        raise ValueError(
            f"{outname}: {logits.shape[0]} predictions for {ad.n_obs} cells; "
            "the tokenized dataset and the AnnData object are out of sync"
        )

    ad.obs["celltype_predicted"] = label_lookup[np.argmax(logits, axis=1)]
    ad.write(
        os.path.join(
            filename,
            "adatas/06_geneformer_celltypes.h5ad",
        )
    )