In [1]:
import os

from tqdm import tqdm
import requests
import numpy as np
import matplotlib.pyplot as plt
import sys
import seaborn as sns; sns.set()
import pandas as pd
import json
import gdown
import scanpy as sc
import anndata as ad
from scipy import sparse

from collections import Counter
import datetime
import pickle
import gc
import subprocess
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification

import pyarrow as pa
import pyarrow.dataset as ds
import pandas as pd
from datasets import Dataset

  @jit
  @jit
  @jit


In this notebook we train the model for one epoch and save the trained model separately for later use.

In [2]:
# Restore the Hugging Face dataset that an earlier step saved to disk
hg_dataset = load_from_disk('reference_25000')

In [3]:
def compute_metrics(pred):
    """Compute accuracy and macro-averaged F1 for a set of predictions.

    Parameters
    ----------
    pred
        Object exposing ``label_ids`` (the reference labels) and
        ``predictions`` (an array supporting ``argmax(-1)``; the index of
        the largest value along the last axis is used as the predicted label).

    Returns
    -------
    dict
        Mapping with the keys ``'accuracy'`` and ``'macro_f1'``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    # Both scores are delegated to the sklearn.metrics helpers.
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'macro_f1': f1_score(y_true, y_pred, average='macro'),
    }
# --- Model parameters ---
max_input_size = 2048            # max input size (2 ** 11)

# --- Training parameters ---
max_lr = 5e-5                    # max learning rate
freeze_layers = 0                # how many pretrained layers to freeze
num_gpus = 1                     # number of GPUs
num_proc = 16                    # number of CPU cores
geneformer_batch_size = 6        # batch size for training and eval
lr_schedule_fn = "linear"        # learning schedule
warmup_steps = 500               # warmup steps
epochs = 1                       # number of epochs
optimizer = "adamw"              # optimizer

In [4]:
# Keyword arguments forwarded to TrainingArguments.
# Options left disabled for this run: evaluation_strategy, logging_steps,
# lr_scheduler_type, warmup_steps, per_device_eval_batch_size and
# load_best_model_at_end.
training_args = dict(
    learning_rate=max_lr,
    do_train=True,
    do_eval=False,
    save_strategy="epoch",
    group_by_length=True,
    length_column_name="length",
    disable_tqdm=False,
    weight_decay=0.001,
    per_device_train_batch_size=geneformer_batch_size,
    num_train_epochs=epochs,
    output_dir="./output",
)
training_args_init = TrainingArguments(**training_args)

In [5]:
# Load the pretrained checkpoint from ./model as a sequence classifier with
# 50 labels; attention maps and hidden states are not requested as outputs.
model = BertForSequenceClassification.from_pretrained(
    "./model",
    num_labels=50,
    output_attentions=False,
    output_hidden_states=False,
)
# Place the model on the GPU.
model = model.to("cuda")

Some weights of the model checkpoint at ./model were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./model and are newly initialized: ['

In [6]:
# Build the Trainer from the model, the training arguments and the training
# dataset. No eval_dataset is passed in this notebook.
trainer = Trainer(
    model=model,
    args=training_args_init,
    train_dataset=hg_dataset,
    data_collator=DataCollatorForCellClassification(),
    compute_metrics=compute_metrics,
)

In [7]:
# Run training; as the cell's last expression, the returned TrainOutput is displayed.
trainer.train()

  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Step,Training Loss
500,3.1164
1000,2.6524
1500,2.2741
2000,1.9498
2500,1.8432
3000,1.7031
3500,1.6285
4000,1.5696


TrainOutput(global_step=4167, training_loss=2.0732132664738385, metrics={'train_runtime': 2086.9822, 'train_samples_per_second': 11.979, 'train_steps_per_second': 1.997, 'total_flos': 970588333200000.0, 'train_loss': 2.0732132664738385, 'epoch': 1.0})

In [9]:
# Save the trainer's model to ./trained_model so it can be reused later.
trainer.save_model("./trained_model")