## Geneformer Fine-Tuning for Cell Annotation Application

In [1]:
# virtual-Geneformer FT
import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["WANDB_DISABLED"] = "true"

In [2]:
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import BertConfig, BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification
import sys
import re
import numpy as np

## Prepare training and evaluation datasets

In [4]:
# load cell type dataset (includes all tissues)

dataset_name = "/work/Sim_geneformer/arrow_dataset_sim"

train_dataset=load_from_disk(dataset_name)

print(train_dataset)

Dataset({
    features: ['input_ids', 'cell_types', 'organ_major', 'disease', 'length'],
    num_rows: 240000
})


## Cell Type Classification

In [4]:
# cell type classification

dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []

for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:  
        continue
    elif organ=="immune":
        organ_ids = ["immune","bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]
    
    print(organ)
    
    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)
    
    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k,v in celltype_counter.items() if v>(0.005*total_cells)]
    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)
      
    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type","label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")
    
    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]
    
    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
    
    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])
    
    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]

brain


Filter (num_proc=16):   0%|          | 0/152470 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/7189 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/7147 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/1429 [00:00<?, ? examples/s]

limb_muscle


Filter (num_proc=16):   0%|          | 0/152470 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/28710 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/28710 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/5742 [00:00<?, ? examples/s]

kidney


Filter (num_proc=16):   0%|          | 0/152470 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/21498 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/21304 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/4261 [00:00<?, ? examples/s]

thymus


Filter (num_proc=16):   0%|          | 0/152470 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/9260 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/9260 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/1852 [00:00<?, ? examples/s]

tongue


Filter (num_proc=16):   0%|          | 0/152470 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/20584 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/20584 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/4117 [00:00<?, ? examples/s]

mammary_gland


Filter (num_proc=16):   0%|          | 0/152470 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/12256 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/12256 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/2451 [00:00<?, ? examples/s]

heart


Filter (num_proc=16):   0%|          | 0/152470 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/9657 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/9657 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/1931 [00:00<?, ? examples/s]

spleen


Filter (num_proc=16):   0%|          | 0/152470 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/35006 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/34789 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/6958 [00:00<?, ? examples/s]

large_intestine


Filter (num_proc=16):   0%|          | 0/152470 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/8310 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/8310 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/1662 [00:00<?, ? examples/s]

## Disease Type Classification

In [5]:
# disease classification

dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []


for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:  
        continue
    elif organ=="immune":
        organ_ids = ["immune","bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]
    
    print(organ)

    
    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)
    
    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["disease"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k,v in celltype_counter.items() if v>(0.005*total_cells)]
    def if_not_rare_celltype(example):
        return example["disease"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)
      
    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("disease","label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")
    
    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]
    
    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
    
    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])
    
    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]

adrenal_cortex


Filter (num_proc=16):   0%|          | 0/240000 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/240000 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/240000 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/48000 [00:00<?, ? examples/s]

In [6]:
trainset_dict = dict(zip(organ_list,dataset_list))
traintargetdict_dict = dict(zip(organ_list,target_dict_list))

evalset_dict = dict(zip(organ_list,evalset_list))


print(trainset_dict)
print(traintargetdict_dict)

print(evalset_dict)


{'adrenal_cortex': Dataset({
    features: ['input_ids', 'cell_types', 'label', 'length'],
    num_rows: 192000
})}
{'adrenal_cortex': {'control_male': 0, 'control_female': 1, 'cas+oil': 2, 'ovx+e2': 3, 'cas+e2': 4, 'cas+dht': 5, 'ovx+oil': 6, 'ovx+dht': 7}}
{'adrenal_cortex': Dataset({
    features: ['input_ids', 'cell_types', 'label', 'length'],
    num_rows: 48000
})}


## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

In [7]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    
    acc = accuracy_score(labels, preds)
    pre = precision_score(labels, preds, average='macro')
    rec = recall_score(labels, preds, average='macro')
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
      'accuracy': acc,
      'macro_precision': pre,
      'macro_recall': rec,
      'macro_f1': macro_f1
    }

### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the "hyperparam_optimiz_for_disease_classifier" script for an example of how to tune hyperparameters for downstream applications.

In [8]:
# set model parameters
# max input size
max_input_size = 2**12  # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 12
# learning schedule
lr_schedule_fn = "cosine" #"polynomial", "linear", "cosine"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamW"


In [None]:

for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]
    print(organ_label_dict)
    
    # set logging steps
    logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)

    pretrain_model = "250405_125324_mouse-geneformer_PM-NUse_20M_DV-n1_TMLM_L6_emb256_SL4096_E250_B6_LR0.0001_LScosine_WU10000_DR0.02_ACTsilu_Oadamw_torch_DS8"

    model = BertForSequenceClassification.from_pretrained("/work/mouse-geneformer/models/{}/models/".format(pretrain_model), 
                                                            num_labels=len(organ_label_dict.keys()),
                                                            output_attentions = False,
                                                            output_hidden_states = False).to("cuda")

    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"/work/mouse-geneformer/models/mouse-geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}_ISP-{organ}/"

    # make output directory
    subprocess.call(f'mkdir {output_dir}', shell=True)

    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "fp16": True, 
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        
        "load_best_model_at_end": True,
        "output_dir": output_dir,
        #"max_position_embeddings": 2**11,
    }
    
    training_args_init = TrainingArguments(**training_args)

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics
    )
    # train the cell type classifier
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    with open(f"{output_dir}predictions.pickle", "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval",predictions.metrics)
    trainer.save_model(output_dir)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /work/mouse-geneformer/models/250405_125324_mouse-geneformer_PM-NUse_20M_DV-n1_TMLM_L6_emb256_SL4096_E250_B6_LR0.0001_LScosine_WU10000_DR0.02_ACTsilu_Oadamw_torch_DS8/models/ and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


adrenal_cortex
{'control_male': 0, 'control_female': 1, 'cas+oil': 2, 'ovx+e2': 3, 'cas+e2': 4, 'cas+dht': 5, 'ovx+oil': 6, 'ovx+dht': 7}


mkdir: cannot create directory ‘/work/mouse-geneformer/models/mouse-geneformer_CellClassifier_adrenal_cortex_L4096_B12_LR5e-05_LScosine_WU500_E10_OadamW_F0_ISP-adrenal_cortex/’: File exists
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro Precision,Macro Recall,Macro F1
1,0.2765,0.335247,0.885458,0.891012,0.885771,0.884674
2,0.1738,0.150022,0.95525,0.955222,0.955472,0.954883
3,0.1127,0.103452,0.971146,0.971362,0.971279,0.971133
4,0.0775,0.097934,0.976292,0.976479,0.976363,0.976316
5,0.0484,0.122619,0.974021,0.974585,0.974097,0.97405
6,0.0281,0.10291,0.981125,0.981304,0.981182,0.981175
7,0.0125,0.099412,0.983583,0.983633,0.983616,0.983617
8,0.0075,0.11022,0.983771,0.983806,0.983834,0.983789
9,0.0018,0.116085,0.984146,0.984255,0.984149,0.984189
10,0.0015,0.114247,0.984646,0.984664,0.984686,0.984672


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
Trainer.tokenizer is now dep