# What is Fine-Tuning?


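Fine-tuning is the process of taking a pretrained model and training it further on a specific task using a smaller, labeled dataset. Instead of training a model from scratch (which can take weeks and requires massive amounts of data), fine-tuning starts from a model that has already learned general patterns during pretraining. In this notebook the pretrained model is Geneformer, which works on single-cell gene expression rather than natural language, and the task is cell-type classification.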

In [1]:
from helical.models.geneformer import GeneformerConfig, GeneformerFineTuningModel
import anndata as ad


INFO:datasets:PyTorch version 2.6.0 available.


* Load single-cell data in AnnData format
* Define the label column

In [2]:
# Load the data
ann_data = ad.read_h5ad("../yolksac_human.h5ad")
print(ann_data)
# Get the column for fine-tuning
cell_types = list(ann_data.obs["LVL1"][:10])
label_set = set(cell_types)
label_set

AnnData object with n_obs × n_vars = 31680 × 37318
    obs: 'component', 'stage', 'sex', 'sort.ids', 'fetal.ids', 'orig.dataset', 'sequencing.type', 'lanes', 'LVL1', 'LVL2', 'LVL3', 'LVL3_for_embryo_study'
    var: 'n_cells'
    obsm: 'X_umap'


{'ERYTHROID', 'MYELOID', 'STROMA'}
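
Note that `label_set` above is built from only the first 10 cells, so it may not contain every class present in the `LVL1` column. A minimal sketch of a class-balance check follows. It uses a hypothetical list of labels in place of `list(ann_data.obs["LVL1"])`, and the `ENDOTHELIUM` label is made up for illustration:

```python
from collections import Counter

# Hypothetical labels standing in for list(ann_data.obs["LVL1"]).
all_labels = ["ERYTHROID", "MYELOID", "STROMA", "ERYTHROID", "MYELOID",
              "ENDOTHELIUM", "STROMA", "ERYTHROID", "MYELOID", "ERYTHROID"]
subset_labels = all_labels[:5]  # analogous to taking only the first few cells

full_counts = Counter(all_labels)
subset_classes = set(subset_labels)

# Classes that exist in the full data but are absent from the subset.
missing = sorted(set(full_counts) - subset_classes)

print(full_counts.most_common())
print("Classes missing from the subset:", missing)
```

If any classes are missing, `output_size` derived from the subset would be smaller than the number of classes in the full dataset.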

* Define Geneformer configuration parameters



In [3]:
# Create a GeneformerConfig object

geneformer_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10, nproc=32, accelerator=True)

* Define the parameters required for Geneformer fine-tuning

- fine_tuning_head : Literal["classification", "regression"] | HelicalBaseFineTuningHead

- output_size : Optional[int]
    The output size of the fine-tuning model. This is required if `fine_tuning_head` is given as a string. For a classification task, this is the number of unique classes.

In [4]:
# Create a GeneformerFineTuningModel object
geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=geneformer_config, fine_tuning_head="classification", output_size=len(label_set))

INFO:helical.models.geneformer.model:Model finished initializing.
INFO:helical.models.geneformer.model:'gf-12L-95M-i4096' model is in 'eval' mode, on device 'cpu' with embedding mode 'cell'.


* Process the single-cell data (map gene names to Ensembl IDs and tokenize)

In [5]:
# Process the data
dataset = geneformer_fine_tune.process_data(ann_data[:10])

print(dataset)
# Add column to the dataset
dataset2 = dataset.add_column('cell_types', cell_types)

print(dataset2)

INFO:helical.models.geneformer.model:Processing data for Geneformer.
  adata.var["index"] = adata.var.index

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle
INFO:helical.utils.mapping:Mapped 21359 genes to Ensembl IDs from a total of 37318 genes.
INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 10 × 37318
    obs: 'component', 'stage', 'sex', 'sort.ids', 'fetal.ids', 'orig.dataset', 'sequencing.type', 'lanes', 'LVL1', 'LVL2', 'LVL3', 'LVL3_for_embryo_study

Map (num_proc=10):   0%|          | 0/10 [00:00<?, ? examples/s]

INFO:helical.models.geneformer.model:Successfully processed the data for Geneformer.


Dataset({
    features: ['input_ids', 'length'],
    num_rows: 10
})
Dataset({
    features: ['input_ids', 'length', 'cell_types'],
    num_rows: 10
})


* Convert text labels (cell type names) into numeric class IDs, a common preprocessing step when preparing data for machine learning models.

In [6]:
# Create a dictionary to map cell types to ids
# Note: set iteration order can differ between Python sessions, so keep this mapping for later evaluation
class_id_dict = dict(zip(label_set, range(len(label_set))))
print(class_id_dict)

def classes_to_ids(example):
    example["cell_types"] = class_id_dict[example["cell_types"]]
    return example

# Convert cell types to ids
dataset2 = dataset2.map(classes_to_ids, num_proc=1)


# Check the unique values in the "cell_types" column to find the number of classes
num_classes = len(set(dataset2["cell_types"]))  # Or use np.unique() if you're working with NumPy arrays

print(f"Number of classes in the dataset: {num_classes}")


{'MYELOID': 0, 'ERYTHROID': 1, 'STROMA': 2}


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Number of classes in the dataset: 3
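
Because the mapping above is built from a `set`, the id assigned to each class can change from one Python session to the next. One way to make it reproducible is to sort the labels first and keep an inverse mapping for decoding predictions. This is a minimal sketch with hypothetical example labels, and the ids it produces differ from the ones printed above:

```python
labels = ["STROMA", "MYELOID", "ERYTHROID", "MYELOID", "STROMA"]  # hypothetical example labels

# Sorting the unique labels makes the mapping independent of set iteration order.
class_to_id = {name: idx for idx, name in enumerate(sorted(set(labels)))}
id_to_class = {idx: name for name, idx in class_to_id.items()}

encoded = [class_to_id[name] for name in labels]   # text labels -> integer ids
decoded = [id_to_class[idx] for idx in encoded]    # integer ids -> text labels

print(class_to_id)
print(encoded)
```

The inverse mapping is what turns predicted class ids back into cell type names after inference.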


In [7]:
# Fine-tune the model
geneformer_fine_tune.train(train_dataset=dataset2, label="cell_types", epochs=2)

INFO:helical.models.geneformer.fine_tuning_model:Freezing the first 2 encoder layers of the Geneformer model during fine-tuning.
INFO:helical.models.geneformer.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/2: 100%|██████████| 1/1 [00:22<00:00, 22.30s/it, loss=1.11]
Fine-Tuning: epoch 2/2: 100%|██████████| 1/1 [00:19<00:00, 19.01s/it, loss=0.97]
INFO:helical.models.geneformer.fine_tuning_model:Fine-Tuning Complete. Epochs: 2
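
The training log reports that the first 2 encoder layers are frozen during fine-tuning. How helical implements this internally is not shown here. The sketch below only illustrates the general PyTorch pattern, using a hypothetical toy encoder (`encoder`, `head`, and all sizes are assumptions, not Geneformer's architecture):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained encoder: four small layers.
encoder = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
head = nn.Linear(8, 3)  # new task head for 3 classes

# Freeze the first 2 encoder layers so their weights receive no gradient.
n_frozen = 2
for layer in encoder[:n_frozen]:
    for p in layer.parameters():
        p.requires_grad = False

frozen_before = encoder[0].weight.clone()
trainable_before = encoder[3].weight.clone()

# Only parameters that still require gradients are passed to the optimizer.
params = [p for p in list(encoder.parameters()) + list(head.parameters()) if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=1e-2)

# One training step on random toy data.
x = torch.randn(10, 8)
y = torch.randint(0, 3, (10,))
h = x
for layer in encoder:
    h = torch.tanh(layer(h))
loss = nn.functional.cross_entropy(head(h), y)
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(encoder[0].weight, frozen_before)
trainable_changed = not torch.equal(encoder[3].weight, trainable_before)
print(frozen_unchanged, trainable_changed)
```

Freezing the early layers keeps the general features learned in pretraining fixed while the later layers and the new head adapt to the task.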



import pickle
import torch

# Option 1: save only the model weights (state_dict)
torch.save(geneformer_fine_tune.state_dict(), './fine_tuned_geneformer.pth')

# Save the model configuration alongside the weights
with open('./fine_tuned_geneformer_config.pkl', 'wb') as f:
    pickle.dump(geneformer_config, f)

# Option 2: save the full model (architecture + weights)
torch.save(geneformer_fine_tune, './full_geneformer_model.pth')

# Load the full model (architecture + weights).
# PyTorch 2.6 defaults to weights_only=True, which rejects pickled model objects,
# so weights_only=False is needed here. Only use it for files you trust.
loaded_model = torch.load('./full_geneformer_model.pth', weights_only=False)


In [12]:
# Save the fine-tuned model to a specified directory

import torch
# Save the model's state_dict (weights)
torch.save(geneformer_fine_tune.state_dict(), './fine_tuned_geneformer.pth')

import pickle

# Save the configuration (such as model architecture, hyperparameters)
with open('./fine_tuned_geneformer_config.pkl', 'wb') as f:
    pickle.dump(geneformer_config, f)

* Outputs: typically the logits from the final classification layer of the model.
  shape = number of cells x number of classes

* Embeddings: dense vector representations of the input.
  shape = number of cells x number of dimensions


In [13]:
# Get logits from the fine-tuned model
outputs = geneformer_fine_tune.get_outputs(dataset)
#print(outputs[:100])

# Get embeddings from the fine-tuned model
embeddings = geneformer_fine_tune.get_embeddings(dataset)
#print(embeddings[:100])

Generating Outputs: 100%|█████████████████████████| 1/1 [00:07<00:00,  7.46s/it]
INFO:helical.models.geneformer.model:Started getting embeddings:


  0%|          | 0/1 [00:00<?, ?it/s]

INFO:helical.models.geneformer.model:Finished getting embeddings.


In [14]:
# Print dimensions of model outputs (logits)
print("Outputs shape:", outputs.shape)

# Print dimensions of embeddings
print("Embeddings shape:", embeddings.shape)

Outputs shape: (10, 3)
Embeddings shape: (10, 512)


In [15]:
# Load the model configuration
with open('./fine_tuned_geneformer_config.pkl', 'rb') as f:
    loaded_config = pickle.load(f)

# Recreate the model using the configuration
geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=loaded_config, fine_tuning_head="classification", output_size=len(label_set))

# Load the saved weights
geneformer_fine_tune.load_state_dict(torch.load('./fine_tuned_geneformer.pth'))

INFO:helical.models.geneformer.model:Model finished initializing.
INFO:helical.models.geneformer.model:'gf-12L-95M-i4096' model is in 'eval' mode, on device 'cpu' with embedding mode 'cell'.


<All keys matched successfully>
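
The saved weights only make sense together with the class-to-id mapping used during training. A minimal sketch of storing that mapping next to the weights as JSON, so a later session can reuse it instead of rebuilding it from a `set`. The file name is hypothetical, and the mapping shown is the one printed earlier in this notebook:

```python
import json
import os
import tempfile

# Mapping as printed earlier; in the notebook this would be class_id_dict.
class_id_dict = {"MYELOID": 0, "ERYTHROID": 1, "STROMA": 2}

# A temporary directory is used here only to keep the example self-contained.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "class_id_dict.json")  # hypothetical file name
    with open(path, "w") as f:
        json.dump(class_id_dict, f)
    with open(path) as f:
        restored = json.load(f)

print(restored == class_id_dict)
```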

In [17]:
# Load the data
ann_data = ad.read_h5ad("../yolksac_human.h5ad")

# Get the column for fine-tuning
cell_types = list(ann_data.obs["LVL1"][:10])
label_set = set(cell_types)

# Process the data
dataset = geneformer_fine_tune.process_data(ann_data[:10])

dataset = dataset.add_column('cell_types', cell_types)

# Create a dictionary to map cell types to ids (must match the mapping used during training)
class_id_dict = dict(zip(label_set, range(len(label_set))))

def classes_to_ids(example):
    example["cell_types"] = class_id_dict[example["cell_types"]]
    return example

# Convert cell types to ids
dataset = dataset.map(classes_to_ids, num_proc=1)

INFO:helical.models.geneformer.model:Processing data for Geneformer.
  adata.var["index"] = adata.var.index

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /scratch/aazd1f17/shared_space/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle
INFO:helical.utils.mapping:Mapped 21359 genes to Ensembl IDs from a total of 37318 genes.
INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 10 × 37318
    obs: 'component', 'stage', 'sex', 'sort.ids', 'fetal.ids', 'orig.dataset', 'sequencing.type', 'lanes', 'LVL1', 'LVL2', 'LVL3', 'LVL3_for_embryo_study

Map (num_proc=10):   0%|          | 0/10 [00:00<?, ? examples/s]

INFO:helical.models.geneformer.model:Successfully processed the data for Geneformer.


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [18]:
# Get logits from the fine-tuned model
outputs = geneformer_fine_tune.get_outputs(dataset)
print(outputs[:10])

# Get embeddings from the fine-tuned model
embeddings = geneformer_fine_tune.get_embeddings(dataset)
print(embeddings[:10])

Generating Outputs: 100%|█████████████████████████| 1/1 [00:07<00:00,  7.07s/it]
INFO:helical.models.geneformer.model:Started getting embeddings:


[[ 0.33764753 -0.63448006  0.42256805]
 [-0.0047673  -0.31886572  0.3476621 ]
 [ 0.4034266  -0.47087127  0.47804916]
 [-0.28821072 -0.18367608  0.20982245]
 [-0.03870388 -0.30438682  0.35442135]
 [ 0.6792635  -0.4365916   0.19478877]
 [ 0.6383105  -0.34320307  0.34863678]
 [ 0.48446694 -0.51842535  0.12969084]
 [ 0.10496124 -0.5309649   0.45663705]
 [ 0.6005198  -0.40026966  0.1865585 ]]


  0%|          | 0/1 [00:00<?, ?it/s]

INFO:helical.models.geneformer.model:Finished getting embeddings.


[[ 0.03350314  0.13316047  0.20343904 ...  0.8142347  -0.3028932
   0.08662338]
 [ 0.02586686  0.08685685  0.10706159 ...  0.508063   -0.20464893
  -0.00271893]
 [ 0.0270964   0.09861639  0.18019074 ...  0.5467254  -0.32423416
   0.11923523]
 ...
 [-0.01033951  0.01030621  0.08626219 ...  0.47623956 -0.1535618
  -0.10797301]
 [-0.0342384   0.06156068  0.12234074 ...  0.41991723 -0.19008005
   0.0402521 ]
 [ 0.02662977  0.10814787  0.21575229 ...  0.5291545  -0.3622686
   0.05420014]]


In [19]:
import torch
import numpy as np

# Convert logits to probabilities using softmax
probs = torch.nn.functional.softmax(torch.tensor(outputs), dim=-1)
print(probs)

tensor([[0.4054, 0.1533, 0.4413],
        [0.3172, 0.2317, 0.4512],
        [0.4009, 0.1672, 0.4319],
        [0.2663, 0.2956, 0.4381],
        [0.3079, 0.2360, 0.4561],
        [0.5145, 0.1686, 0.3169],
        [0.4710, 0.1765, 0.3525],
        [0.4835, 0.1774, 0.3391],
        [0.3389, 0.1794, 0.4817],
        [0.4929, 0.1812, 0.3259]])
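
To make the softmax step concrete, here is a minimal pure-Python sketch applied to the first row of logits printed earlier. The helper name `softmax` is ours. Subtracting the maximum before exponentiating is a standard trick for numerical stability and does not change the result:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

row = [0.33764753, -0.63448006, 0.42256805]  # first row of logits printed earlier
probs_row = softmax(row)
print([round(p, 4) for p in probs_row])
```

The result should agree with the first row of the tensor above, and the probabilities sum to 1.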


In [20]:
# Get predicted class (the class with the highest probability)
predicted_classes = probs.argmax(dim=-1).numpy()  # Convert to numpy for easier handling
print(predicted_classes)

[2 2 2 2 2 0 0 0 2 0]


In [21]:
# True labels (cell types from dataset)
true_labels = np.array(dataset["cell_types"])
print(true_labels)

[2 2 2 1 1 0 0 0 2 0]


In [22]:
# Compare predicted labels with true labels to calculate accuracy
# (these are the same 10 cells used for fine-tuning, so this is not a held-out estimate)
accuracy = (predicted_classes == true_labels).mean()
print(f"Accuracy: {accuracy * 100:.2f}%")

Accuracy: 80.00%
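
Overall accuracy can hide which classes the model gets wrong. A minimal sketch of a per-class breakdown, using the predicted and true label arrays printed above (the variable names here are ours):

```python
from collections import Counter

predicted = [2, 2, 2, 2, 2, 0, 0, 0, 2, 0]  # predicted_classes printed above
true = [2, 2, 2, 1, 1, 0, 0, 0, 2, 0]       # true_labels printed above

# Count (true, predicted) pairs: a sparse confusion matrix.
confusion = Counter(zip(true, predicted))

# Per-class recall: correct predictions for a class / number of cells of that class.
support = Counter(true)
recall = {c: confusion[(c, c)] / n for c, n in sorted(support.items())}

print(dict(confusion))
print(recall)
```

Here both cells of class 1 are predicted as class 2, so the two errors behind the 80% accuracy all fall on a single class.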
