## Geneformer Fine-Tuning for Cell Annotation Application

In [1]:
import os
import sys
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
import datasets
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification

import pandas as pd

2023-10-10 16:05:18.825588: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# --- model parameters ---
# maximum input size, encoded in the run directory name
max_input_size = 2048  # 2 ** 11

# --- training parameters ---
# peak learning rate handed to the trainer
max_lr = 5e-5
# number of pretrained layers to freeze
freeze_layers = 0
# number of GPUs
num_gpus = 4
# number of CPU cores (used as num_proc when mapping over the dataset)
num_proc = 16
# per-device batch size for both training and evaluation; original note:
# memory can look free during the train cycle, but evaluation will fill it
geneformer_batch_size = 6
# learning-rate schedule type
lr_schedule_fn = "linear"
# number of warmup steps
warmup_steps = 500
# number of training epochs
epochs = 10
# optimizer name, encoded in the run directory name
optimizer = "adamw"

In [3]:
# Root of the workflow directory. Defaults to the original location; set the
# GENEFORMER_WORKFLOW_DIR environment variable to run the notebook elsewhere
# without editing every path below.
workflow_dir = os.environ.get("GENEFORMER_WORKFLOW_DIR", "/home/domino/geneformer_workflow")
results_dir = os.path.join(workflow_dir, "results", "perturbation")

# Where to save model artifact
model_save_dir = os.path.join(results_dir, "finetuned_model")

# Where to load the pretrained model from. This is a local directory here;
# `from_pretrained` also accepts a Hugging Face Hub model id.
pretrained_model_path = os.path.join(workflow_dir, "Geneformer")

# Directory and file prefix of the tokenized input dataset
tokenized_input_dir = os.path.join(results_dir, "tokenized_files")
tokenized_input_prefix = "adata_SS2"


# Name of column storing cell class labels
label_colname = "celltype"

# Where the label-encoded dataset is written; created up front if missing
processed_dataset_output_dir = os.path.join(results_dir, "processed_dataset")
os.makedirs(processed_dataset_output_dir, exist_ok=True)

## Prepare training and evaluation datasets

In [4]:
# load the tokenized dataset from disk; it is split into train/eval later
tokenized_dataset_path = os.path.join(tokenized_input_dir, f"{tokenized_input_prefix}.dataset")
full_dataset = load_from_disk(tokenized_dataset_path)

In [5]:
# display the dataset summary (feature columns and row count)
full_dataset

Dataset({
    features: ['input_ids', 'n_counts', 'celltype', 'length'],
    num_rows: 1333
})

In [6]:
# count how many rows carry each value of the class-label column
Counter(full_dataset[label_colname])

Counter({'USM': 501, 'SM': 458, 'N': 374})

In [7]:
# map every distinct cell type to an integer label id;
# ids follow the order in which the cell types first appear in the dataset
target_names = list(Counter(full_dataset[label_colname]))
target_name_id_dict = {name: idx for idx, name in enumerate(target_names)}

In [8]:
# display the cell type -> label id mapping
target_name_id_dict

{'N': 0, 'USM': 1, 'SM': 2}

In [9]:
# change labels to numerical ids
def classes_to_ids(example):
    """Attach the integer class id of a row under the "label" key.

    Reads the class name stored under `label_colname`, looks it up in
    `target_name_id_dict`, writes the id into `example["label"]` and
    returns the same (mutated) example.
    """
    class_name = example[label_colname]
    example["label"] = target_name_id_dict[class_name]
    return example
# add the integer "label" column to every row, using num_proc worker processes
full_dataset = full_dataset.map(classes_to_ids, num_proc=num_proc)

In [10]:
# confirm that the "label" column was added
full_dataset

Dataset({
    features: ['input_ids', 'n_counts', 'celltype', 'length', 'label'],
    num_rows: 1333
})

In [11]:
# drop the original class-name column now that the integer "label" column exists
full_dataset = full_dataset.remove_columns(label_colname)

In [12]:
# write the label-encoded dataset to the processed-dataset output directory
processed_dataset_path = os.path.join(processed_dataset_output_dir, f"{tokenized_input_prefix}.dataset")
full_dataset.save_to_disk(processed_dataset_path)

Saving the dataset (0/1 shards):   0%|          | 0/1333 [00:00<?, ? examples/s]

In [13]:
# hold out 10% of the rows for evaluation; the fixed seed keeps the split repeatable
split_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]

In [14]:
# report the row counts of the full, training and evaluation sets, in that order
for subset in (full_dataset, train_dataset, eval_dataset):
    print(len(subset))

1333
1199
134


In [15]:
# class balance (by label id) of the training split
Counter(train_dataset['label'])

Counter({1: 455, 2: 415, 0: 329})

In [16]:
# class balance (by label id) of the evaluation split
Counter(eval_dataset['label'])

Counter({1: 46, 0: 45, 2: 43})

## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

In [17]:
def compute_metrics(pred):
    """Score a batch of predictions with accuracy and macro-averaged F1.

    `pred` must expose `label_ids` (the true class ids) and `predictions`
    (per-class scores; the arg-max over the last axis is taken as the
    predicted class). Returns a dict with keys 'accuracy' and 'macro_f1'.
    """
    true_ids = pred.label_ids
    predicted_ids = pred.predictions.argmax(-1)
    return {
      'accuracy': accuracy_score(true_ids, predicted_ids),
      'macro_f1': f1_score(true_ids, predicted_ids, average='macro'),
    }

In [18]:
# set logging steps: about ten log entries per pass over the training set
# (len / per-device batch size / 10). The max(1, ...) guard keeps the value
# positive when the training set is small relative to the batch size;
# round() alone would give 0, which step-based logging does not accept.
logging_steps = max(1, round(len(train_dataset) / geneformer_batch_size / 10))
logging_steps

20

In [19]:
# Load the pretrained weights with a classification head sized to the number
# of classes. id2label / label2id store the class-name mapping in the model
# config, so the saved model carries it and predictions can be translated
# back to cell types without re-deriving the mapping from the dataset.
model = BertForSequenceClassification.from_pretrained(
    pretrained_model_path,
    num_labels=len(target_name_id_dict),
    id2label={label_id: name for name, label_id in target_name_id_dict.items()},
    label2id=dict(target_name_id_dict),
    output_attentions=False,
    output_hidden_states=False,
)

# Honor `freeze_layers`: disable gradient updates for the first N encoder
# layers. With the configured value of 0 (or None) nothing is frozen, so the
# run name's F{freeze_layers} suffix always matches what was actually trained.
if freeze_layers:
    for frozen_layer in model.bert.encoder.layer[:freeze_layers]:
        for param in frozen_layer.parameters():
            param.requires_grad = False

model = model.to("cuda")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/domino/geneformer_workflow/Geneformer and are newly initialized: ['classifier.bias', 'bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
# define output directory path
# The timestamp is built from explicit strftime fields (YYMMDD_HHMMSS).
# The "%X" directive is locale-dependent and can add spaces or an AM/PM
# marker to the directory name under non-default locales.
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d_%H%M%S")
output_dir = f"{model_save_dir}/{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
output_dir

'/home/domino/geneformer_workflow/results/perturbation/finetuned_model/231010_160531_geneformer_CellClassifier_L2048_B6_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/'

In [21]:
# ensure not overwriting previously saved model
# Both weight-file names are checked: depending on the installed transformers
# version and save settings, the weights are written as pytorch_model.bin or
# model.safetensors.
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
saved_model_candidates = (saved_model_test, os.path.join(output_dir, "model.safetensors"))
if any(os.path.isfile(candidate) for candidate in saved_model_candidates):
    raise Exception("Model already saved to this directory.")

# exist_ok=True lets this cell be re-run for the same output_dir; the check
# above already refuses directories that contain a saved model.
os.makedirs(output_dir, exist_ok=True)

In [22]:
# Keyword arguments for TrainingArguments, grouped by purpose.
training_args = dict(
    # optimisation
    learning_rate=max_lr,
    lr_scheduler_type=lr_schedule_fn,
    warmup_steps=warmup_steps,
    weight_decay=0.001,
    num_train_epochs=epochs,
    fp16=False,
    # batching
    per_device_train_batch_size=geneformer_batch_size,
    per_device_eval_batch_size=geneformer_batch_size,
    group_by_length=True,
    length_column_name="length",
    # training / evaluation / checkpointing, once per epoch
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    eval_accumulation_steps=1,  # Otherwise runs out of memory during eval
    # logging and output location
    logging_steps=logging_steps,
    disable_tqdm=False,
    output_dir=output_dir,
)

training_args_init = TrainingArguments(**training_args)

In [23]:
# assemble the Trainer from the model, arguments, datasets, collator and metric function
cell_collator = DataCollatorForCellClassification()
trainer = Trainer(
    model=model,
    args=training_args_init,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=cell_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# run fine-tuning; evaluation and checkpointing are configured to happen once per epoch
trainer.train()

  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


In [None]:
# save the trainer's model into the run's output directory
trainer.save_model(output_dir)