# What is Fine-Tuning?


Fine-tuning is the process of taking a pretrained model and training it further on a specific task using a smaller, labeled dataset. Instead of training a model from scratch (which can take weeks and massive amounts of data), fine-tuning starts from a model that has already learned general patterns in its input domain. For Geneformer, that domain is single-cell gene-expression data rather than text, and in this notebook fine-tuning adapts it to classify cell types.

In [None]:
import pickle

import anndata as ad
import numpy as np
import torch
from helical.models.geneformer import GeneformerConfig, GeneformerFineTuningModel


* Load single-cell data in AnnData (`.h5ad`) format
* Choose the `obs` column that holds the labels (here `LVL1`, the cell-type annotation)

In [None]:
# Load the annotated single-cell dataset (AnnData, .h5ad)
ann_data = ad.read_h5ad("../yolksac_human.h5ad")
print(ann_data)

# Labels for fine-tuning: the "LVL1" annotation of the first 10 cells.
# .iloc makes the positional selection explicit instead of relying on how `[]`
# interprets an integer slice on a (possibly non-integer) obs index.
cell_types = list(ann_data.obs["LVL1"].iloc[:10])
label_set = set(cell_types)
label_set

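* Define the Geneformer configuration (pretrained checkpoint `model_name`, plus `batch_size`, `nproc` and `accelerator`; see the `GeneformerConfig` documentation for details)

The fine-tuning model created further below additionally takes:

- fine_tuning_head : Literal["classification", "regression"] | HelicalBaseFineTuningHead

- output_size : Optional[int]
    The output size of the fine-tuning model. This is required if `fine_tuning_head` is given as a task string. For a classification task it is the number of unique classes.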

In [None]:
# Configure Geneformer: which pretrained checkpoint to load plus batching/processing options
geneformer_config = GeneformerConfig(
    model_name="gf-12L-95M-i4096",
    batch_size=10,
    nproc=32,
    accelerator=True,
)

Define the parameters required for Geneformer fine-tuning:
- fine_tuning_head : Literal["classification", "regression"] | HelicalBaseFineTuningHead
    Either a task name or a custom head object.

- output_size : Optional[int]
    The output size of the fine-tuning model. This is required if `fine_tuning_head` is given as a task string ("classification" or "regression"). For a classification task it is the number of unique classes.

In [None]:
# Build the fine-tuning model: a classification head with one output per distinct cell type
geneformer_fine_tune = GeneformerFineTuningModel(
    geneformer_config=geneformer_config,
    fine_tuning_head="classification",
    output_size=len(label_set),
)

* Process the single-cell data (Ensembl ID mapping and tokenization)

In [None]:
# Process the data (Ensembl ID mapping + tokenization for Geneformer).
# Slice by len(cell_types) rather than a second hard-coded 10 so the processed
# cells are exactly the ones whose labels were taken in the loading cell.
dataset = geneformer_fine_tune.process_data(ann_data[:len(cell_types)])
print(dataset)

# Attach the text labels as a new column (one label per processed cell)
dataset2 = dataset.add_column('cell_types', cell_types)
print(dataset2)

* Convert text labels (cell-type names) into numeric class IDs, a common preprocessing step when preparing data for machine-learning models.

In [None]:
# Create a dictionary to map cell types to integer ids 0..n_classes-1.
# NOTE(review): ids follow the iteration order of `label_set`, a set of strings, whose
# order depends on the per-process hash seed. It is stable within one kernel but can
# differ in a new session. TODO: use a sorted mapping and save it with the model weights
# (together with the reload cell) before reusing a saved model in another session.
class_id_dict = {cell_type: class_id for class_id, cell_type in enumerate(label_set)}
print(class_id_dict)

def classes_to_ids(example):
    """Replace the text label in example["cell_types"] with its integer class id."""
    example["cell_types"] = class_id_dict[example["cell_types"]]
    return example

# Convert cell types to ids.
# Not idempotent: re-running this cell alone raises KeyError because the column already
# holds integer ids; re-run the previous cell first.
dataset2 = dataset2.map(classes_to_ids, num_proc=1)

# Check the unique values in the "cell_types" column to find the number of classes
num_classes = len(set(dataset2["cell_types"]))
print(f"Number of classes in the dataset: {num_classes}")

# The classification head was built with output_size=len(label_set); the data must agree.
assert num_classes == len(label_set), (
    f"{num_classes} classes in the data but the model head has {len(label_set)} outputs"
)


In [None]:
# Fine-tune the model on the labelled cells; labels are the integer ids in "cell_types".
geneformer_fine_tune.train(train_dataset=dataset2, label="cell_types", epochs=2)

# Optionally save the full model object (architecture + weights) with pickle.
# The weights-only state_dict and the config are saved in the next cell.
torch.save(geneformer_fine_tune, './full_geneformer_model.pth')

# Loading a whole pickled object requires weights_only=False (PyTorch 2.6 changed the
# default to True, which rejects arbitrary objects). This unpickles arbitrary code, so
# only load files you created yourself.
loaded_model = torch.load('./full_geneformer_model.pth', weights_only=False)


In [None]:
# Save the fine-tuned model so it can be restored later without retraining:
# the weights (state_dict) go to a .pth file and the GeneformerConfig used to
# build the model is pickled next to it.
# TODO(review): also save `class_id_dict` so the label-to-id mapping can be restored
# in a new kernel session.
torch.save(geneformer_fine_tune.state_dict(), './fine_tuned_geneformer.pth')

with open('./fine_tuned_geneformer_config.pkl', 'wb') as f:
    pickle.dump(geneformer_config, f)

* Outputs: typically the logits from the final classification layer of the model.
  Shape = number of cells × number of classes

* Embeddings: dense vector representations of the input cells.
  Shape = number of cells × number of embedding dimensions


In [None]:
# Run the fine-tuned model on the processed cells.
# `outputs` are presumably the classification-head logits and `embeddings` the dense
# per-cell representations; their shapes are printed in the next cell.
outputs = geneformer_fine_tune.get_outputs(dataset)

# Get embeddings from the fine-tuned model
embeddings = geneformer_fine_tune.get_embeddings(dataset)

In [None]:
# Report the dimensions of the logits and of the embeddings
print(f"Outputs shape: {outputs.shape}")
print(f"Embeddings shape: {embeddings.shape}")

In [None]:
# Restore the fine-tuned model: rebuild it from the pickled config, then load the saved weights.
# pickle.load can execute arbitrary code, so only load config files you created yourself.
with open('./fine_tuned_geneformer_config.pkl', 'rb') as f:
    loaded_config = pickle.load(f)

# Recreate the model with the same head used for training.
# NOTE(review): output_size comes from `label_set`; it must equal the number of classes
# the saved weights were trained with, or load_state_dict will fail on the head.
geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=loaded_config, fine_tuning_head="classification", output_size=len(label_set))

# Load the saved weights. A state_dict only holds tensors, so weights_only=True (the
# default since PyTorch 2.6) is sufficient and refuses arbitrary pickled objects.
geneformer_fine_tune.load_state_dict(torch.load('./fine_tuned_geneformer.pth', weights_only=True))

In [None]:
# Reload the data and rebuild the labelled dataset for the restored model
ann_data = ad.read_h5ad("../yolksac_human.h5ad")

# Get the column for fine-tuning: "LVL1" labels of the first 10 cells (by position)
cell_types = list(ann_data.obs["LVL1"].iloc[:10])
label_set = set(cell_types)

# Process exactly the cells whose labels were taken above so rows and labels line up
dataset = geneformer_fine_tune.process_data(ann_data[:len(cell_types)])

dataset = dataset.add_column('cell_types', cell_types)

# Create a dictionary to map cell types to ids.
# NOTE(review): these ids must equal the ones used during training. Both mappings follow
# set iteration order, which is only guaranteed to match within the same kernel session.
# TODO: load the mapping saved with the model instead of rebuilding it here.
class_id_dict = {cell_type: class_id for class_id, cell_type in enumerate(label_set)}

def classes_to_ids(example):
    """Replace the text label in example["cell_types"] with its integer class id."""
    example["cell_types"] = class_id_dict[example["cell_types"]]
    return example

# Convert cell types to ids
dataset = dataset.map(classes_to_ids, num_proc=1)

In [None]:
# Logits from the restored fine-tuned model (first 10 cells shown)
outputs = geneformer_fine_tune.get_outputs(dataset)
print(outputs[:10])

# Per-cell embeddings from the restored fine-tuned model (first 10 cells shown)
embeddings = geneformer_fine_tune.get_embeddings(dataset)
print(embeddings[:10])

In [None]:
# Convert logits to per-class probabilities with softmax over the last (class) axis.
# torch.as_tensor accepts a NumPy array or an existing tensor; torch.tensor would emit
# a copy-construct warning if get_outputs returned a tensor.
probs = torch.nn.functional.softmax(torch.as_tensor(outputs), dim=-1)
print(probs)

In [None]:
# Predicted class id for each cell: index of its largest probability (row-wise argmax)
predicted_ids = probs.argmax(dim=-1)
predicted_classes = predicted_ids.numpy()
print(predicted_classes)

In [None]:
# Ground-truth integer class ids stored in the dataset's "cell_types" column
true_labels = np.asarray(dataset["cell_types"])
print(true_labels)

In [None]:
# Fraction of cells whose predicted class id equals the true class id.
# NOTE: these are the same cells the model was fine-tuned on, so this is training
# accuracy, not an estimate of performance on unseen cells.
assert predicted_classes.shape == true_labels.shape, (
    f"shape mismatch: {predicted_classes.shape} predictions vs {true_labels.shape} labels"
)
accuracy = (predicted_classes == true_labels).mean()
print(f"Accuracy: {accuracy * 100:.2f}%")