## Imports

In [None]:
from geneformer import EmbExtractor
import os

## Path settings


In [None]:
fine_tuned_model = "/home/domino/geneformer_workflow/Geneformer/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224"

path_to_input_data = "/home/domino/geneformer_workflow/input/data/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset"

path_to_output_directory = "/home/domino/geneformer_workflow/results/cell_embeddings"
os.makedirs(path_to_output_directory, exist_ok = True)

output_prefix = "human_dcm_hcm_nf"

## Initialize EmbExtractor and extract embeddings from the input data


In [None]:
# Configure the embedding extractor for the cardiomyopathy example. Notes on the
# hard-coded settings (hedged where the meaning comes from the geneformer API,
# not from this notebook):
#   num_classes=3          presumably the three disease labels implied by the
#                          dataset name (DCM / HCM / NF); confirm against the
#                          fine-tuned model.
#   filter_data            presumably restricts extraction to cells whose
#                          "cell_type" is one of the listed cardiomyocyte types.
#   max_ncells=1000        presumably caps how many cells are embedded.
#   emb_layer=0            which model layer to read embeddings from; check the
#                          EmbExtractor docs for the meaning of 0 vs -1 (TODO confirm).
#   emb_label              metadata columns presumably kept alongside the embeddings.
#   labels_to_plot         label(s) used by plot_embs; should be a subset of emb_label.
#   forward_batch_size=50  previously 200 (per an old inline note); presumably
#                          lowered to fit GPU memory; confirm for your hardware.
embex = EmbExtractor(
    model_type="CellClassifier",
    num_classes=3,
    filter_data={"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"]},
    max_ncells=1000,
    emb_layer=0,
    emb_label=["disease", "cell_type"],
    labels_to_plot=["disease"],
    forward_batch_size=50,
    nproc=16,
)

# Example dataset:
# https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
embs = embex.extract_embs(
    model_directory=fine_tuned_model,
    input_data_file=path_to_input_data,
    output_directory=path_to_output_directory,
    output_prefix=output_prefix,
)

## Plot UMAP of cell embeddings

In [None]:

# UMAP of the extracted cell embeddings.
# NOTE: scanpy's UMAP plotting necessarily saves its figures to a `figures` directory.
embex.plot_embs(
    embs=embs,
    plot_style="umap",
    output_directory=path_to_output_directory,
    output_prefix="emb_plot",
)

## Plot heatmap of cell embeddings

In [None]:
# Heatmap of the same embeddings, written with the same "emb_plot" prefix as the UMAP cell.
embex.plot_embs(
    embs=embs,
    plot_style="heatmap",
    output_directory=path_to_output_directory,
    output_prefix="emb_plot",
)