In [5]:
from helical.models.geneformer import GeneformerConfig, GeneformerFineTuningModel
import anndata as ad

# Load the AnnData object.
# NOTE(review): absolute cluster path — adjust when running elsewhere.
ann_data = ad.read_h5ad("/iridisfs/ddnb/Ahmed/AI_hackathon25/yolksac_human.h5ad")

# Number of cells used for this quick fine-tuning demo. A single constant keeps
# the label slice and the AnnData slice below aligned row-for-row.
n_cells = 10

# Labels for fine-tuning, taken positionally (.iloc) from the first n_cells rows.
cell_types = list(ann_data.obs["LVL1"].iloc[:n_cells])
label_set = set(cell_types)

# Create a GeneformerConfig object
geneformer_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10)

# Classification fine-tuning model with one output per distinct label.
geneformer_fine_tune = GeneformerFineTuningModel(
    geneformer_config=geneformer_config,
    fine_tuning_head="classification",
    output_size=len(label_set),
)

# Tokenize the same n_cells cells. `.copy()` materialises the slice so that
# process_data's write to `adata.var` presumably acts on a real AnnData
# instead of a view (avoids the implicit-modification warning).
dataset = geneformer_fine_tune.process_data(ann_data[:n_cells].copy())

# Attach the string labels; row order matches the AnnData slice above.
dataset = dataset.add_column('cell_types', cell_types)

# Map each label to an integer class id. Iterating a set of strings yields an
# order that depends on the per-process hash seed, so sort first to keep the
# label -> id mapping reproducible across kernel restarts. key=str guards
# against mixed types (e.g. a NaN among string labels) making sorted() raise.
class_id_dict = {
    label: class_id
    for class_id, label in enumerate(sorted(label_set, key=str))
}

def classes_to_ids(example):
    """Encode the ``"cell_types"`` label of one dataset row as its integer id.

    Used as a ``Dataset.map`` callback: the row dict is updated in place and
    returned. The lookup goes through the module-level ``class_id_dict``; a
    label absent from that dict raises ``KeyError``.
    """
    label_name = example["cell_types"]
    example["cell_types"] = class_id_dict[label_name]
    return example

# Replace every string label in the "cell_types" column with its class id,
# using a single worker process.
dataset = dataset.map(classes_to_ids, num_proc=1)

# Fine-tune the classification head; "cell_types" names the label column.
geneformer_fine_tune.train(label="cell_types", train_dataset=dataset)

# Model outputs per cell (the original comment calls these logits).
# NOTE(review): computed on the same cells used for training, so they say
# nothing about generalisation — use a held-out split for evaluation.
outputs = geneformer_fine_tune.get_outputs(dataset)
print(outputs[:10])

# Embeddings from the fine-tuned model for the same cells.
embeddings = geneformer_fine_tune.get_embeddings(dataset)
print(embeddings[:10])

INFO:helical.models.geneformer.model:Model finished initializing.
INFO:helical.models.geneformer.model:'gf-12L-95M-i4096' model is in 'eval' mode, on device 'cpu' with embedding mode 'cell'.
INFO:helical.models.geneformer.model:Processing data for Geneformer.
  adata.var["index"] = adata.var.index

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /iridisfs/ddnb/Ahmed/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /iridisfs/ddnb/Ahmed/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /iridisfs/ddnb/Ahmed/AI_hackathon25/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle
INFO:helical.utils.mapping:Mapped 21359 genes to Ensembl IDs from a total of 37318 genes.
INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 10 

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

INFO:helical.models.geneformer.model:Successfully processed the data for Geneformer.


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

INFO:helical.models.geneformer.fine_tuning_model:Freezing the first 2 encoder layers of the Geneformer model during fine-tuning.
INFO:helical.models.geneformer.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:24<00:00, 24.18s/it, loss=1.2]
INFO:helical.models.geneformer.fine_tuning_model:Fine-Tuning Complete. Epochs: 1
Generating Outputs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.96s/it]
INFO:helical.models.geneformer.model:Started getting embeddings:


[[-0.05119071  0.14384666  0.11044433]
 [-0.06255999 -0.01253583  0.24082741]
 [-0.08798311 -0.13323031  0.19078472]
 [-0.08693075 -0.27726153  0.22992265]
 [ 0.23822522 -0.46375433 -0.17303208]
 [ 0.21388966 -0.04218999  0.07147676]
 [-0.03459525 -0.14606042  0.22730944]
 [ 0.31956843 -0.10422018  0.11569984]
 [ 0.04160344 -0.49710146  0.44002718]
 [-0.01202716 -0.0746004   0.11259264]]


  0%|          | 0/1 [00:00<?, ?it/s]

INFO:helical.models.geneformer.model:Finished getting embeddings.


[[ 0.02975997  0.11766281  0.19413415 ...  0.8038816  -0.3138102
   0.0783158 ]
 [ 0.01784887  0.07231951  0.10296704 ...  0.45810795 -0.18964255
  -0.01404279]
 [ 0.01978923  0.08934405  0.1727128  ...  0.5229969  -0.3132735
   0.10580335]
 ...
 [-0.05103347  0.17012133  0.13277583 ...  0.8351886  -0.2965814
  -0.09796386]
 [-0.03087549  0.04334023  0.12429308 ...  0.37981436 -0.17387764
   0.02804179]
 [ 0.02256808  0.09870598  0.20592609 ...  0.53165424 -0.35759747
   0.04842021]]
