## Geneformer Fine-Tuning for Cell Annotation Application

In [1]:
import torch.nn as nn
import os

In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [4]:

GPU_NUMBER = [0,1,2]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"



In [4]:
!pwd


/vsphhome/xwx/Geneformer/examples


In [3]:
!pip install -r requirements.txt
!pip install transformers==4.28.0
!pip install --upgrade accelerate -U

[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'[0m[31m
Collecting accelerate
  Downloading accelerate-0.28.0-py3-none-any.whl.metadata (18 kB)
Downloading accelerate-0.28.0-py3-none-any.whl (290 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m290.1/290.1 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hInstalling collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 0.27.2
    Uninstalling accelerate-0.27.2:
      Successfully uninstalled accelerate-0.27.2
Successfully installed accelerate-0.28.0


In [5]:
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification

## Prepare training and evaluation datasets

In [6]:
# load cell type dataset (includes all tissues)
train_dataset=load_from_disk("/vsphhome/xwx/Geneformer/token_data/MMD_snRNA.dataset/")

In [7]:
print(train_dataset)
import pandas as pd
# input_ids represent rank encodings
pd.DataFrame(train_dataset)
train_dataset['length']

Dataset({
    features: ['input_ids', 'label', 'length'],
    num_rows: 78886
})


[2048,
 1274,
 969,
 1666,
 1029,
 305,
 359,
 2048,
 1008,
 303,
 354,
 282,
 409,
 1478,
 400,
 369,
 1850,
 294,
 2048,
 1100,
 1748,
 2048,
 337,
 314,
 1244,
 991,
 408,
 345,
 298,
 2022,
 2048,
 2048,
 1631,
 331,
 1085,
 419,
 983,
 322,
 1413,
 1206,
 2046,
 2048,
 1291,
 1414,
 1608,
 311,
 1199,
 1109,
 1908,
 288,
 323,
 1075,
 347,
 328,
 2048,
 465,
 406,
 1063,
 2048,
 2048,
 1665,
 357,
 428,
 849,
 971,
 1783,
 2048,
 325,
 288,
 312,
 366,
 2048,
 415,
 1314,
 419,
 1064,
 383,
 2048,
 411,
 1352,
 1094,
 1069,
 940,
 1023,
 1274,
 1457,
 363,
 307,
 318,
 1062,
 1170,
 302,
 1545,
 1437,
 401,
 299,
 437,
 2048,
 905,
 1467,
 263,
 309,
 1132,
 394,
 398,
 347,
 325,
 1223,
 1015,
 411,
 2048,
 291,
 323,
 2048,
 2048,
 1245,
 2048,
 2048,
 396,
 2048,
 365,
 332,
 347,
 411,
 289,
 1022,
 2042,
 295,
 2048,
 282,
 1292,
 325,
 412,
 2003,
 1577,
 316,
 337,
 1297,
 325,
 2048,
 300,
 1889,
 373,
 1632,
 298,
 2048,
 1365,
 1227,
 1015,
 2048,
 1654,
 369,
 312,
 193

In [9]:
dataset_list = []
evalset_list = []
target_dict_list = []

train_dataset_shuffled = train_dataset.shuffle(seed=42)



# create dictionary of cell types : label ids 
target_names = list(Counter(train_dataset_shuffled["label"]).keys())
print(f"target name: {target_names}")

target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
target_dict_list += [target_name_id_dict]
print(target_name_id_dict)

# change labels to numerical ids
def classes_to_ids(example):
    example["label"] = target_name_id_dict[example["label"]]
    return example
labeled_trainset = train_dataset_shuffled.map(classes_to_ids, num_proc=16)

# create 80/20 train/eval splits
labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])
labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])

dataset_list += [labeled_train_split]
evalset_list += [labeled_eval_split]

target name: ['Control', 'Suicide']
{'Control': 0, 'Suicide': 1}


Map (num_proc=16):   0%|          | 0/78886 [00:00<?, ? examples/s]

In [11]:
trainset_dict=dataset_list

evalset_dict=evalset_list



## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

In [13]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
      'accuracy': acc,
      'macro_f1': macro_f1
    }

### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the "hyperparam_optimiz_for_disease_classifier" script for an example of how to tune hyperparameters for downstream applications.

In [23]:
# set model parameters
# max input size
max_input_size = 2048  # 2048 2 ** 11
 
# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 3
# number gpus
num_gpus = 3 #4
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 12 #12
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamw"

In [15]:
#for organ in organ_list:
organ_trainset = trainset_dict
organ_evalset = evalset_dict
# organ_label_dict = traintargetdict_dict

In [16]:
organ_trainset[0]


Dataset({
    features: ['input_ids', 'label', 'length'],
    num_rows: 63109
})

In [26]:
#for organ in organ_list:
organ_trainset = trainset_dict[0]
organ_evalset = evalset_dict[0]
# organ_label_dict = traintargetdict_dict

# set logging steps
logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)

# reload pretrained model
model = BertForSequenceClassification.from_pretrained("/vsphhome/xwx/Geneformer", 
                                                    num_labels=2,
                                                    output_attentions = False,
                                                    #from_tf=True,
                                                    output_hidden_states = False).to("cuda")
print(model)
# define output directory path
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
output_dir = f"/vsphhome/xwx/Geneformer/models/{datestamp}_geneformer_DepressionClassifier_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

# ensure not overwriting previously saved model
saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
if os.path.isfile(saved_model_test) == True:
    raise Exception("Model already saved to this directory.")

# make output directory
subprocess.call(f'mkdir {output_dir}', shell=True)

# set training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "logging_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": 0.001,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)

# create the trainer
trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=organ_trainset,
    eval_dataset=organ_evalset,
    compute_metrics=compute_metrics
)
# train the cell type classifier
trainer.train()
predictions = trainer.predict(organ_evalset)
with open(f"{output_dir}predictions.pickle", "wb") as fp:
    pickle.dump(predictions, fp)
trainer.save_metrics("eval",predictions.metrics)
trainer.save_model(output_dir)

Some weights of the model checkpoint at /vsphhome/xwx/Geneformer were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /vsphhome/xwx/Genefo

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12

Epoch,Training Loss,Validation Loss


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


In [28]:
print(predictions.label_ids)
print(predictions.predictions.shape, predictions.label_ids.shape)

[1 1 0 ... 1 1 0]
(15777, 2) (15777,)


In [29]:
predictions

PredictionOutput(predictions=array([[ 0.77346176, -0.77803445],
       [-1.2236983 ,  0.8859356 ],
       [ 1.440623  , -1.3611727 ],
       ...,
       [-1.7334038 ,  1.273731  ],
       [-2.879622  ,  2.3565166 ],
       [ 1.3992474 , -1.329712  ]], dtype=float32), label_ids=array([1, 1, 0, ..., 1, 1, 0]), metrics={'test_loss': 0.4160965383052826, 'test_accuracy': 0.8094060974836788, 'test_macro_f1': 0.8083044936447897, 'test_runtime': 48.8708, 'test_samples_per_second': 322.831, 'test_steps_per_second': 8.983})