# Evaluating Geneformer in a zero-shot setting

In [None]:
import logging
import warnings
from pathlib import Path

# Silence deprecation and future warnings so they do not clutter the notebook
# output. The filters are installed before the project imports below.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Project modules used throughout this notebook.
from sc_foundation_evals import geneformer_forward as gf
from sc_foundation_evals import data, cell_embeddings, model_output
from sc_foundation_evals.helpers.custom_logging import log

# Show INFO-level messages from the project logger.
log.setLevel(logging.INFO)

## Setting up variables

Define some variables that we will rely on later, starting with paths and run configs.

In [None]:
# parameters for papermill
# model_name is appended to both the model directory and the output directory
# further down. It is left empty here (presumably overridden by papermill at
# execution time; confirm against the run scripts).
model_name = ""

In [None]:
# Root of the workspace that the data and output paths are built from.
base_dir = Path("/workspace")
# Geneformer directory inside the weights folder.
geneformer_data = base_dir.joinpath("data", "weights", "Geneformer")
# Dictionary files; this path is later handed to the vocabulary loader.
dict_dir = geneformer_data.joinpath("dicts")

In [None]:
# Directory for the selected model; passed to the Geneformer wrapper as
# saved_model_path further down.
# NOTE(review): the "/workspace" root is repeated here instead of being derived
# from base_dir, so keep the two in sync if the workspace location changes.
model_dir = Path("/workspace/models/") / model_name

In [None]:
# Batch size used for embedding extraction; choose it according to the
# available GPU memory.
batch_size = 24
# Number of workers for multithreaded steps; -1 means use all available.
num_workers = -1
# Directory under which the results of this run are saved.
output_dir = base_dir.joinpath("output/geneformer", model_name)
# Sub-directory for the embeddings and the other evaluation outputs.
model_out = output_dir.joinpath("model_outputs")

Data paths and configs.

I will be using the Pancreas dataset as an example, as described in the scGPT_zero-shot notebook.

In [None]:
# --- input and output locations ---
dataset_dir = base_dir.joinpath("data", "datasets")
# the anndata (.h5ad) object to evaluate on
in_dataset_path = dataset_dir.joinpath("pancreas_scib.h5ad")
# pre-processing output goes to a folder named after the input file (without its extension)
preprocessed_path = dataset_dir.joinpath("geneformer", in_dataset_path.stem)
# make sure that folder exists before anything is written into it
preprocessed_path.mkdir(exist_ok=True, parents=True)

# --- keys used to read the anndata object ---
# where the raw counts are stored
layer_key = "counts"  # "X"
# column in adata.obs holding the labels
label_col = "celltype"  # "str_labels"
# column in adata.obs holding the batch
batch_col = "batch"
# column holding the gene names; if they are in the index, the index will be copied to a column with this name
# (the original note says adata.obs; gene annotations usually live in adata.var, TODO confirm)
gene_col = "gene_symbols"

## Loading model and data

In [None]:
# Create the Geneformer wrapper used for the rest of the notebook: output_dir is
# given as its save directory and model_dir as the location of the saved model.
geneform = gf.Geneformer_instance(
    save_dir=output_dir,
    saved_model_path=model_dir,
    # presumably makes the wrapper use save_dir exactly as given (TODO confirm)
    explicit_save_dir=True,
    num_workers=num_workers,
)

In [None]:
# Load the pretrained model into the wrapper (presumably from the
# saved_model_path given at construction; confirm).
geneform.load_pretrained_model()

Load the vocabulary and the gene-name-to-Ensembl-ID mapping.

In [None]:
# Load the dictionaries from dict_dir. The preprocessing step below reads
# geneform.gene_name_id, which is presumably populated here (confirm).
geneform.load_vocab(dict_dir)

In [None]:
# Expected locations of the two cached artefacts, both inside preprocessed_path:
# the preprocessed .loom file and the tokenized dataset.
input_file_path = preprocessed_path / in_dataset_path.with_suffix(".loom").name
dataset_path = preprocessed_path / in_dataset_path.with_suffix(".dataset").name

# Step 1: preprocessing. Reuse the .loom file when it is already on disk,
# otherwise start from the original anndata file and preprocess it.
if input_file_path.exists():
    log.info(
        f"Loading preprocessed input data from {input_file_path} and skipping preprocessing"
    )
    input_data = data.InputData(adata_dataset_path=input_file_path)
    log.info("Loading complete")
else:
    log.info(f"Preprocessing input data from {in_dataset_path}")
    input_data = data.InputData(adata_dataset_path=in_dataset_path)
    input_data.preprocess_data(
        gene_col=gene_col,
        model_type="geneformer",
        save_ext="loom",
        gene_name_id_dict=geneform.gene_name_id,
        preprocessed_path=preprocessed_path,
    )
    log.info("Preprocessing complete")


# Step 2: tokenization. Reuse the tokenized dataset when it is already on disk,
# otherwise tokenize the .loom file and store the result at dataset_path.
if dataset_path.exists():
    log.info(f"Loading preprocessed dataset from {dataset_path} and skipping tokenization")
    geneform.load_tokenized_dataset(
        dataset_path
    )
    log.info("Loading complete")
else:
    log.info("Tokenizing input data")
    geneform.tokenize_data(
        adata_path=input_file_path,
        dataset_path=dataset_path,
        cell_type_col=label_col,
    )
    log.info("Tokenization complete")

If the data was already preprocessed or tokenized in an earlier run, the cell above just loads the saved files instead of recomputing them.

## Evaluating model outputs

First, we will perform a forward pass on the model and extract embeddings. We're interested in the second-to-last layer, as per the instructions in the codebase of Geneformer [here](https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/emb_extractor.py#L285). Using the argument `layer`, we can refer to layers according to Python indexing logic (i.e. 0 is the embedding layer, 1 is the first layer, 2 is the second layer, etc., and -1 is the last layer, -2 is the second-to-last layer, etc.).

*Note:* If you get a CUDA out of memory error, you can try reducing the batch size. As a rule of thumb, try batch sizes that are multiples of 8, to avoid potential issues with how approximations are handled in CUDA.

In [None]:
# Run the forward pass over input_data in batches of batch_size and extract the
# embeddings; layer=-2 selects the second-to-last layer.
geneform.extract_embeddings(data=input_data, batch_size=batch_size, layer=-2)

In [None]:
# Evaluation object for the model outputs of geneform; model_out is given as
# its output directory.
eval_pred = model_output.GeneExprPredEval(geneform, output_dir=model_out)

In [None]:
# Run the evaluation with n_cells=500 (presumably the number of cells the
# evaluation is subset to; confirm) and ask for the rankings to be saved.
eval_pred.evaluate(n_cells=500, save_rankings=True)

In [None]:
# Plot the evaluation with n_cells=100 and the "mako_r" colormap.
eval_pred.visualize(n_cells=100, cmap="mako_r")

## Evaluating the cell embeddings

First, creating cell embeddings evaluation object.

In [None]:
# Cell-embedding evaluation object for geneform and input_data. label_col and
# batch_col name the label and batch columns, and model_out is given as the
# output directory.
eval_ce = cell_embeddings.CellEmbeddingsEval(
    geneform,
    data=input_data,
    output_dir=model_out,
    label_key=label_col,
    batch_key=batch_col,
)

Then, evaluating the embeddings. Here, for speed we are subsetting the data to 1000 cells.

In [None]:
# with n_cells you can specify how much to subset the obs for
# embedding_key names the embedding that is evaluated ("geneformer")
eval_ce.evaluate(n_cells=1000, embedding_key="geneformer")

In [None]:
# with n_cells you can specify how much to subset the obs for
# NOTE(review): this cell repeats the preceding evaluate call verbatim; it looks
# redundant. Confirm whether the second run is intentional.
eval_ce.evaluate(n_cells=1000, embedding_key="geneformer")

In [None]:
# Visualize the "geneformer" embedding with the cell-embedding evaluation object.
eval_ce.visualize(embedding_key="geneformer")