# What is Fine-Tuning?


Fine-tuning is the process of taking a pretrained model (such as BERT, GPT, RoBERTa, or T5) and training it further on a specific task using a smaller, labeled dataset. Instead of training a model from scratch (which can take weeks and massive amounts of data), fine-tuning starts from a model that already understands the structure and meaning of language.

In [40]:
from helical.models.geneformer import GeneformerConfig, GeneformerFineTuningModel
import anndata as ad


* Upload single cell data in annadata format
* Define the labels column

In [41]:
# Load the data
ann_data = ad.read_h5ad("../yolksac_human.h5ad")
print(ann_data)
# Get the column for fine-tuning
cell_types = list(ann_data.obs["LVL1"][:1000])
label_set = set(cell_types)
label_set

AnnData object with n_obs × n_vars = 31680 × 37318
    obs: 'component', 'stage', 'sex', 'sort.ids', 'fetal.ids', 'orig.dataset', 'sequencing.type', 'lanes', 'LVL1', 'LVL2', 'LVL3', 'LVL3_for_embryo_study'
    var: 'n_cells'
    obsm: 'X_umap'


{'ERYTHROID', 'LYMPHOID', 'MK', 'MYELOID', 'PROGENITOR', 'STROMA'}

* Define Geneformer configuration parameters


- fine_tuning_head : Literal["classification", "regression"] | HelicalBaseFineTuningHead

- output_size : Optional[int]
    The output size of the fine-tuning model. This is required if the `fine_tuning_head` is a string specified task. For a classification task this is number of unique classes.

In [42]:
# Create a GeneformerConfig object

geneformer_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10, nproc = 32, accelerator = True)

Define the parameters required for Geneformer finetuning
- fine_tuning_head : Literal["classification", "regression"] | HelicalBaseFineTuningHead

- output_size : Optional[int]
    The output size of the fine-tuning model. This is required if the `fine_tuning_head` is a string specified task. For a classification task this is number of unique classes.

In [43]:
# Create a GeneformerFineTuningModel object
geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=geneformer_config, fine_tuning_head="classification", output_size=len(label_set))

INFO:helical.models.geneformer.model:Model finished initializing.
INFO:helical.models.geneformer.model:'gf-12L-95M-i4096' model is in 'eval' mode, on device 'cpu' with embedding mode 'cell'.


* Process the single cell data (map ensemble_ids and tokenization)

In [44]:
# Process the data
dataset = geneformer_fine_tune.process_data(ann_data[:1000])

# Add column to the dataset
dataset = dataset.add_column('cell_types', cell_types)

dataset

INFO:helical.models.geneformer.model:Processing data for Geneformer.
  adata.var["index"] = adata.var.index

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle
INFO:helical.utils.mapping:Mapped 21359 genes to Ensembl IDs from a total of 37318 genes.
INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 1000 × 37318
    obs: 'component', 'stage', 'sex', 'sort.ids', 'fetal.ids', 'orig.dataset', 'sequencing.type', 'lanes', 'LVL1', 'LVL2', 'LVL3', 'LVL3_for_embryo_stu

Map (num_proc=32):   0%|          | 0/1000 [00:00<?, ? examples/s]

INFO:helical.models.geneformer.model:Successfully processed the data for Geneformer.


Dataset({
    features: ['input_ids', 'length', 'cell_types'],
    num_rows: 1000
})

* convert text labels (cell type names) into numeric class IDs, which is a common preprocessing step when preparing data for machine learning models.

In [45]:
# Create a dictionary to map cell types to ids
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))

def classes_to_ids(example):
    example["cell_types"] = class_id_dict[example["cell_types"]]
    return example

# Convert cell types to ids
dataset = dataset.map(classes_to_ids, num_proc=1)

class_id_dict

# Check the unique values in the "cell_types" column to find the number of classes
num_classes = len(set(dataset["cell_types"]))  # Or use np.unique() if you're working with NumPy arrays

print(f"Number of classes in the dataset: {num_classes}")


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Number of classes in the dataset: 6


In [46]:
# Fine-tune the model
geneformer_fine_tune.train(train_dataset=dataset, label="cell_types")

INFO:helical.models.geneformer.fine_tuning_model:Freezing the first 2 encoder layers of the Geneformer model during fine-tuning.
INFO:helical.models.geneformer.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|█████| 100/100 [22:31<00:00, 13.51s/it, loss=0.309]
INFO:helical.models.geneformer.fine_tuning_model:Fine-Tuning Complete. Epochs: 1


In [47]:
# Save the fine-tuned model to a specified directory
import torch

# Save the model's state_dict (weights)
torch.save(geneformer_fine_tune.state_dict(), './fine_tuned_geneformer.pth')

import pickle

# Save the configuration (such as model architecture, hyperparameters)
with open('./fine_tuned_geneformer_config.pkl', 'wb') as f:
    pickle.dump(geneformer_config, f)

* Outputs: These are typically the logits from the final classification layer of the model.

shape = number of cells x number of classes

Embeddings: These are dense vector representations of the input
shape = number of cells x number of dimensions


In [29]:
# Get logits from the fine-tuned model
outputs = geneformer_fine_tune.get_outputs(dataset)
#print(outputs[:100])

# Get embeddings from the fine-tuned model
embeddings = geneformer_fine_tune.get_embeddings(dataset)
#print(embeddings[:100])

Generating Outputs: 100%|███████████████████████| 10/10 [00:40<00:00,  4.02s/it]
INFO:helical.models.geneformer.model:Started getting embeddings:


  0%|          | 0/10 [00:00<?, ?it/s]

INFO:helical.models.geneformer.model:Finished getting embeddings.


In [30]:
# Print dimensions of model outputs (logits)
print("Outputs shape:", outputs.shape)

# Print dimensions of embeddings
print("Embeddings shape:", embeddings.shape)

Outputs shape: (100, 4)
Embeddings shape: (100, 512)


In [48]:
# Load the model configuration
with open('./fine_tuned_geneformer_config.pkl', 'rb') as f:
    loaded_config = pickle.load(f)

# Recreate the model using the configuration
geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=loaded_config, fine_tuning_head="classification", output_size=len(label_set))

# Load the saved weights
geneformer_fine_tune.load_state_dict(torch.load('./fine_tuned_geneformer.pth'))

INFO:helical.models.geneformer.model:Model finished initializing.
INFO:helical.models.geneformer.model:'gf-12L-95M-i4096' model is in 'eval' mode, on device 'cpu' with embedding mode 'cell'.


<All keys matched successfully>

In [61]:
# Load the data
ann_data = ad.read_h5ad("../yolksac_human.h5ad")

# Get the column for fine-tuning
cell_types = list(ann_data.obs["LVL1"][:10])
label_set = set(cell_types)
label_set
# Process the data
dataset = geneformer_fine_tune.process_data(ann_data[:10])

dataset = dataset.add_column('cell_types', cell_types)

# Create a dictionary to map cell types to ids
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))

def classes_to_ids(example):
    example["cell_types"] = class_id_dict[example["cell_types"]]
    return example

# Convert cell types to ids
dataset = dataset.map(classes_to_ids, num_proc=1)

INFO:helical.models.geneformer.model:Processing data for Geneformer.
  adata.var["index"] = adata.var.index

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle
INFO:helical.utils.mapping:Mapped 21359 genes to Ensembl IDs from a total of 37318 genes.
INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 10 × 37318
    obs: 'component', 'stage', 'sex', 'sort.ids', 'fetal.ids', 'orig.dataset', 'sequencing.type', 'lanes', 'LVL1', 'LVL2', 'LVL3', 'LVL3_for_embryo_study

Map (num_proc=10):   0%|          | 0/10 [00:00<?, ? examples/s]

INFO:helical.models.geneformer.model:Successfully processed the data for Geneformer.


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [56]:
# Get logits from the fine-tuned model
outputs = geneformer_fine_tune.get_outputs(dataset)
#print(outputs[:10])

# Get embeddings from the fine-tuned model
#embeddings = geneformer_fine_tune.get_embeddings(dataset)
#print(embeddings[:10])

Generating Outputs: 100%|█████████████████████████| 1/1 [00:06<00:00,  6.33s/it]


In [57]:
import torch
import numpy as np

# Convert logits to probabilities using softmax
probs = torch.nn.functional.softmax(torch.tensor(outputs), dim=-1)
#probs

In [58]:
# Get predicted class (the class with the highest probability)
predicted_classes = probs.argmax(dim=-1).numpy()  # Convert to numpy for easier handling
predicted_classes

array([2, 2, 2, 1, 1, 5, 5, 5, 2, 5])

In [62]:
# True labels (cell types from dataset)
true_labels = np.array(dataset["cell_types"])
true_labels

array([1, 1, 1, 0, 0, 2, 2, 2, 1, 2])

In [63]:
# Compare predicted labels with true labels to calculate accuracy
accuracy = (predicted_classes == true_labels).mean()
print(f"Accuracy: {accuracy * 100:.2f}%")

Accuracy: 0.00%
