In [None]:
import anndata as ad
import numpy as np

from helical.models.geneformer import Geneformer, GeneformerFineTuningModel, GeneformerConfig

INFO:datasets:PyTorch version 2.8.0 available.
INFO:datasets:Polars version 1.25.2 available.
INFO:datasets:Duckdb version 1.3.2 available.
INFO:datasets:TensorFlow version 2.19.0 available.
INFO:datasets:JAX version 0.5.3 available.


# Geneformer (Pre-trained)

Generating cell embeddings using the pre-trained Geneformer for unperturbed data

In [None]:
# The gf-12L-95M-i4096 checkpoint is deprecated, so gf-12L-38M-i4096 is used instead.
model_config = GeneformerConfig(
    model_name="gf-12L-38M-i4096",
    batch_size=10,
    device="cuda:0",
)
geneformer = Geneformer(model_config)

INFO:helical.models.geneformer.model:Model finished initializing.
INFO:helical.models.geneformer.model:'gf-12L-38M-i4096' model is in 'eval' mode, on device 'cuda:0' with embedding mode 'cell'.


In [None]:
# Location of the .h5ad file that will be embedded with the pre-trained model
anndata_path = "/content/drive/MyDrive/counts_Astro_filtered_BA4_ALS_ITGB8.h5ad"

# Read the AnnData object, then expose the "ENSID" var column under the name
# "ensembl_id" (the column name handed to process_data later on)
anndata = ad.read_h5ad(anndata_path)
anndata.var.rename(
    columns={"ENSID": "ensembl_id"},
    inplace=True,
)

In [None]:
# Run Geneformer's preprocessing on the AnnData object; gene identifiers are
# taken from the "ensembl_id" column
dataset = geneformer.process_data(
    anndata,
    gene_names="ensembl_id",
)

INFO:helical.models.geneformer.model:Processing data for Geneformer.
INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 2295 × 22832
    obs: 'Sample_ID', 'Donor', 'Region', 'Sex', 'Condition', 'Group', 'C9_pos', 'CellClass', 'CellType', 'SubType', 'full_label', 'DGE_Group', 'Bakken_M1', 'data_merge_id', 'data_sample_id', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'Cellstates_LVL1', 'Cellstates_LVL2', 'Cellstates_LVL3', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_genes', 'split'
    var: 'Biotype', 'Chromosome', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'ensembl_id', 'mt', 'n_cells', 'biotype', 'ensembl_id_collapsed'
    uns: 'Gene_and_Scale' has no column attribute 'filter_pass'; tokenizing all cells.
INFO:helical.models.geneformer.geneformer_tokenizer:Creating dataset.
INFO:helical.models.geneformer.model:Su

In [None]:
# Embed the processed dataset with the pre-trained model and attach the result
# to the AnnData object under obsm["X_Geneformer"]
embeddings, genes = geneformer.get_embeddings(
    dataset,
    output_genes=True,
)
anndata.obsm["X_Geneformer"] = embeddings

# Persist the AnnData object, embeddings included, back to the path it was
# loaded from (this overwrites the input file)
anndata.write(anndata_path)

INFO:helical.models.geneformer.model:Started getting embeddings:


  0%|          | 0/230 [00:00<?, ?it/s]

INFO:helical.models.geneformer.model:Finished getting embeddings.
... storing 'ensembl_id_collapsed' as categorical


# Geneformer (Fine-tuned)

Fine-tuning Geneformer on astrocytes (`CellType == 'Astro'`) to classify each cell's disease state (`Condition`).

In [65]:
# Load the combined dataset and expose the "ENSID" var column as "ensembl_id",
# the column name passed to process_data further down.
anndata_path = "/content/drive/MyDrive/counts_combined_filtered_BA4_sALS_PN.h5ad"
anndata = ad.read_h5ad(anndata_path)
anndata.var.rename(columns={"ENSID" : "ensembl_id"}, inplace=True)

# Restrict fine-tuning to astrocytes.
anndata_finetuning = anndata[anndata.obs['CellType'] == 'Astro']

# Materialise the predefined train / test splits as independent copies so that
# later writes to .obsm do not touch a view of the parent object.
anndata_train = anndata_finetuning[anndata_finetuning.obs['split'] == 'train'].copy()
anndata_test = anndata_finetuning[anndata_finetuning.obs['split'] == 'test'].copy()

# Per-cell condition labels, coerced to plain strings so they can be sorted and
# used as dictionary keys regardless of the column's dtype.
condition_labels_train = anndata_train.obs['Condition'].astype(str).tolist()
condition_labels_test = anndata_test.obs['Condition'].astype(str).tolist()

# Map every condition label seen in either split to an integer class ID.
# The labels are sorted first: iterating a set of strings directly gives an
# order that can change between Python processes (string hash randomisation),
# which would make the label -> ID assignment differ from run to run.
label_set = set(condition_labels_train) | set(condition_labels_test)
class_id_dict = {label: class_id for class_id, label in enumerate(sorted(label_set))}

# Integer class IDs in the same order as the cells of each split; these are
# attached to the tokenized datasets as the training label column.
cell_types_train = [class_id_dict[label] for label in condition_labels_train]
cell_types_test = [class_id_dict[label] for label in condition_labels_test]

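Load the Geneformer fine-tuning model with a classification head whose output size equals the number of condition classes.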

In [71]:
# Same checkpoint, batch size and device as the pre-trained model above
geneformer_ft_config = GeneformerConfig(
    model_name="gf-12L-38M-i4096",
    batch_size=10,
    device='cuda:0',
)

# Classification head with one output per entry in class_id_dict
geneformer_ft = GeneformerFineTuningModel(
    geneformer_config=geneformer_ft_config,
    fine_tuning_head="classification",
    output_size=len(class_id_dict),
)

INFO:helical.models.geneformer.model:Model finished initializing.
INFO:helical.models.geneformer.model:'gf-12L-38M-i4096' model is in 'eval' mode, on device 'cuda:0' with embedding mode 'cell'.


Process the train and test AnnData objects for Geneformer, then add a `Condition_ID` label column to each resulting dataset.

In [72]:
# Process each split and attach its integer class IDs as the "Condition_ID"
# column in one step per split
geneformer_train_dataset = (
    geneformer_ft
    .process_data(anndata_train, gene_names='ensembl_id')
    .add_column("Condition_ID", cell_types_train)
)
geneformer_test_dataset = (
    geneformer_ft
    .process_data(anndata_test, gene_names='ensembl_id')
    .add_column("Condition_ID", cell_types_test)
)

INFO:helical.models.geneformer.model:Processing data for Geneformer.
INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 9022 × 22832
    obs: 'Sample_ID', 'Donor', 'Region', 'Sex', 'Condition', 'Group', 'C9_pos', 'CellClass', 'CellType', 'SubType', 'full_label', 'DGE_Group', 'Bakken_M1', 'data_merge_id', 'data_sample_id', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'Cellstates_LVL1', 'Cellstates_LVL2', 'Cellstates_LVL3', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_genes', 'split'
    var: 'Biotype', 'Chromosome', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'ensembl_id', 'mt', 'n_cells', 'biotype', 'ensembl_id_collapsed' has no column attribute 'filter_pass'; tokenizing all cells.
INFO:helical.models.geneformer.geneformer_tokenizer:Creating dataset.
INFO:helical.models.geneformer.model:Successfully processed the d

Fine-tune the model on specified data and save unperturbed embeddings

In [None]:
# Fine-tune all layers (freeze_layers=0) for one epoch on the training split,
# using the test split for validation. The shuffle is seeded so that the order
# in which training examples are presented is the same on every run.
geneformer_ft.train(
    train_dataset=geneformer_train_dataset.shuffle(seed=42),
    validation_dataset=geneformer_test_dataset,
    label="Condition_ID",
    freeze_layers=0,
    epochs=1,
    optimizer_params={"lr": 1e-4},
    # NOTE(review): num_training_steps=1 presumably counts scheduler steps;
    # confirm how often the library steps the scheduler, otherwise the linear
    # schedule may reach a learning rate of 0 earlier than intended.
    lr_scheduler_params={"name":"linear", "num_warmup_steps":0, 'num_training_steps':1},
)

# Embed the test split with the fine-tuned model.
embeds, genes = geneformer_ft.get_embeddings(geneformer_test_dataset, output_genes=True)

# Store the embeddings on a fresh copy of the test AnnData object and save it
# to its own file, leaving the input file untouched.
anndata_test = anndata_test.copy()
anndata_test.obsm['X_Geneformer_FT'] = embeds

anndata_test.write("/content/drive/MyDrive/counts_Astro_EmbeddingsFT_filtered_BA4_sALS.h5ad")

INFO:helical.models.geneformer.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 903/903 [12:24<00:00,  1.21it/s, loss=0.296]
Fine-Tuning Validation: 100%|██████████| 416/416 [01:57<00:00,  3.53it/s, val_loss=0.585]
INFO:helical.models.geneformer.fine_tuning_model:Fine-Tuning Complete. Epochs: 1


Generate embeddings for anndata file with perturbation effects and save


In [86]:
# Location of the .h5ad file to embed with the fine-tuned model
anndata_path = "/content/drive/MyDrive/counts_Astro_filtered_BA4_ALS_CTNNA2.h5ad"

# Read the AnnData object and expose the "ENSID" var column as "ensembl_id"
anndata = ad.read_h5ad(anndata_path)
anndata.var.rename(
    columns={"ENSID": "ensembl_id"},
    inplace=True,
)

In [87]:
# Run the fine-tuning model's preprocessing on the AnnData object; gene
# identifiers are taken from the "ensembl_id" column
dataset_FT = geneformer_ft.process_data(
    anndata,
    gene_names='ensembl_id',
)

INFO:helical.models.geneformer.model:Processing data for Geneformer.
INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 2295 × 22832
    obs: 'Sample_ID', 'Donor', 'Region', 'Sex', 'Condition', 'Group', 'C9_pos', 'CellClass', 'CellType', 'SubType', 'full_label', 'DGE_Group', 'Bakken_M1', 'data_merge_id', 'data_sample_id', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'Cellstates_LVL1', 'Cellstates_LVL2', 'Cellstates_LVL3', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_genes', 'split'
    var: 'Biotype', 'Chromosome', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'ensembl_id', 'mt', 'n_cells', 'biotype', 'ensembl_id_collapsed'
    uns: 'Genes_and_Scale' has no column attribute 'filter_pass'; tokenizing all cells.
INFO:helical.models.geneformer.geneformer_tokenizer:Creating dataset.
INFO:helical.models.geneformer.model:S

In [88]:
# Embed the processed dataset with the fine-tuned model and attach the result
# to the loaded AnnData object under obsm["X_Geneformer_FT"]
embeddings, genes = geneformer_ft.get_embeddings(
    dataset_FT,
    output_genes=True,
)
anndata.obsm["X_Geneformer_FT"] = embeddings

# Write the AnnData object back to the path it was loaded from (this
# overwrites the input file)
anndata.write(anndata_path)

INFO:helical.models.geneformer.model:Started getting embeddings:


  0%|          | 0/230 [00:00<?, ?it/s]

INFO:helical.models.geneformer.model:Finished getting embeddings.
... storing 'ensembl_id_collapsed' as categorical
