# Geneformer

In [None]:
import os
import logging
import warnings
import sys

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

from sc_foundation_evals import geneformer_forward as gf
from sc_foundation_evals import data, cell_embeddings, model_output
from sc_foundation_evals.helpers.custom_logging import log
log.setLevel(logging.INFO)

In [None]:
geneformer_data = "model path"
# path to the pre-trained model, can work with the huggingface model hub
# i.e. ctheodoris/Geneformer
model_dir = os.path.join(geneformer_data)
# path to dictionaries in geneformer repo
dict_dir = "Pretrain_data/"

# batch_size depends on available GPU memory
batch_size = 24
# output_dir is the path to which the results should be saved
output_dir = "zero_shot_results/"
# path to where we will store the embeddings and other evaluation outputs
model_out = os.path.join(output_dir, "model_outputs")
# if you can use multithreading specify num_workers, -1 means use all available
num_workers = -1

In [None]:
# specify the path to anndata object
in_dataset_path = "Zero_shot_batch_data/pbmc.h5ad"
# dataset_name is inferred from in_dataset_path
dataset_name = os.path.basename(in_dataset_path).split(".")[0]
# specify the path for the output of the pre-processing
preprocessed_path = f"zero_shot_preprocess/{dataset_name}/"
# create the preprocessed path if it does not exist
os.makedirs(preprocessed_path, exist_ok=True)
# in which column in adata.obs are gene names stored? if they are in index, the index will be copied to a column with this name
gene_col = "gene_symbols"
# batch column found in adata.obs
batch_col = "batch"
# where are labels stored in adata.obs? 
label_col = "celltype" #"str_labels"
# where the raw counts are stored?
layer_key = "counts" #"X" 

In [None]:
geneform = gf.Geneformer_instance(save_dir = output_dir, 
                                  saved_model_path = model_dir,
                                  explicit_save_dir = True,
                                  num_workers = num_workers)

geneform.load_pretrained_model()
geneform.load_vocab(dict_dir)
# input_data = data.InputData(adata_dataset_path = in_dataset_path)

## Create dataset

In [5]:
# input_data.preprocess_data(gene_col = gene_col,
#                            model_type = "geneformer",
#                            save_ext = "loom",
#                            gene_name_id_dict = geneform.gene_name_id,
#                            preprocessed_path = preprocessed_path)

# geneform.tokenize_data(adata_path = os.path.join(preprocessed_path, 
#                                                  f"{dataset_name}.loom"),
#                        dataset_path = preprocessed_path,
#                        cell_type_col = label_col)

## Load dataset

In [6]:
geneform.load_tokenized_dataset(os.path.join(preprocessed_path, f"{dataset_name}.dataset"))
input_data = data.InputData(adata_dataset_path = os.path.join(preprocessed_path, f"{dataset_name}.loom"))

[32mINFO    [0m | 2025-07-17 11:57:03 | [32mLoading data from /ibex/user/chenj0i/Geneformer/zero_shot_preprocess/pbmc/pbmc.loom[0m


## Embeddings extraction

In [7]:
geneform.extract_embeddings(data = input_data,
                            batch_size = batch_size, 
                            layer = -2
                            # layer = -1
                            # layer = 0
                            )

Geneformer (extracting embeddings):   0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
input_data.adata

AnnData object with n_obs × n_vars = 11990 × 3226
    obs: 'adata_order', 'batch', 'celltype', 'labels', 'n_counts', 'n_genes', 'n_genes_by_counts', 'obs_names', 'str_labels', 'total_counts'
    var: 'ensembl_id', 'gene_symbols', 'has_ensembl_match', 'mean_counts', 'n_cells', 'n_cells_by_counts', 'n_counts', 'n_counts-0', 'n_counts-1', 'pct_dropout_by_counts', 'total_counts', 'var_names'
    obsm: 'geneformer'

In [None]:
from typing import Dict, Optional
import numpy as np
import scanpy as sc
import scib
from anndata import AnnData
from sklearn.metrics import silhouette_score
from tqdm import tqdm
import pandas as pd
import logging

log = logging.getLogger(__name__)


def eval_clustering_metrics(
    adata: AnnData,
    batch_key: Optional[str] = "str_batch",
    label_key: str = "cell_type",
    embedding_key: str = "X",  # "X" for raw, or embedding key in .obsm
    resolutions: Optional[list] = None,
    use_progress_bar: bool = True,
    verbose: bool = False,
    subsample_frac: Optional[float] = 0.25,
) -> Dict[str, float]:
    """Evaluate biological and batch mixing metrics on an embedding or raw expression."""
    
    results_dict = {}

    if subsample_frac is not None and 0 < subsample_frac < 1:
        adata = adata.copy()
        sc.pp.subsample(adata, fraction=subsample_frac, copy=False)
        if verbose:
            log.info(f"Subsampled adata to {subsample_frac * 100:.1f}% of original cells.")

    # Determine whether to use .X or .obsm[embedding_key]
    if embedding_key == "X":
        use_rep = "X"
        adata.obsm["X"] = adata.X
    elif embedding_key in adata.obsm:
        use_rep = embedding_key
    else:
        raise ValueError(f"embedding_key '{embedding_key}' not found in adata.obsm or is not 'X'")

    # Clear stale neighbors
    if "neighbors" in adata.uns:
        if verbose:
            log.warning(f"Removing stale neighbors computed from other representations.")
        adata.uns.pop("neighbors", None)

    sc.pp.neighbors(adata, use_rep=use_rep)

    # Run Louvain across multiple resolutions
    if resolutions is None:
        resolutions = [2 * i / 20 for i in range(1, 21)]  # Default: 20 steps from 0.1 to 2.0
        # resolutions = [4 * i / 40 for i in range(1, 41)]  # Default: 20 steps from 0.1 to 2.0

    best_nmi = -1
    best_res = None
    best_clustering = None

    if verbose:
        log.info(f"Searching for optimal clustering resolution on {use_rep}...")

    for res in tqdm(resolutions, disable=not use_progress_bar, desc="Louvain clustering"):
        sc.tl.louvain(adata, resolution=res, key_added="temp_cluster")
        nmi = scib.metrics.nmi(adata, "temp_cluster", label_key)
        if nmi > best_nmi:
            best_nmi = nmi
            best_res = res
            best_clustering = adata.obs["temp_cluster"].copy()
        del adata.obs["temp_cluster"]

    if verbose:
        log.info(f"Best resolution: {best_res:.2f} with NMI = {best_nmi:.4f}")

    adata.obs["cluster"] = best_clustering

    # Biological conservation metrics
    results_dict["NMI_cluster/label"] = scib.metrics.nmi(adata, "cluster", label_key, "arithmetic")
    results_dict["ARI_cluster/label"] = scib.metrics.ari(adata, "cluster", label_key)
    results_dict["ASW_label"] = scib.metrics.silhouette(adata, label_key, use_rep, "euclidean")

    # Batch effect metrics (if batch_key valid)
    if batch_key is not None and batch_key in adata.obs and adata.obs[batch_key].nunique() > 1:
        adata.obs[label_key] = adata.obs[label_key].astype("category")
        results_dict["graph_conn"] = scib.metrics.graph_connectivity(adata, label_key)
        results_dict["ASW_batch"] = scib.metrics.silhouette(adata, batch_key, use_rep, "euclidean")
        results_dict["ASW_label/batch"] = scib.metrics.silhouette_batch(
            adata, batch_key, label_key, embed=use_rep, metric="euclidean", return_all=False
        )
        results_dict["PCR_batch"] = scib.metrics.pcr(
            adata, covariate=batch_key, embed=use_rep, recompute_pca=True, n_comps=50, verbose=False
        )
        results_dict["Average_Batch_Score"] = (
            results_dict["ASW_batch"] + results_dict["PCR_batch"]
        ) / 2
    else:
        if verbose:
            log.info("Skipping batch metrics — only one batch present or invalid batch_key.")
    
    results_dict["avg_bio"] = np.mean([
        results_dict["NMI_cluster/label"],
        results_dict["ARI_cluster/label"],
        results_dict["ASW_label"]
    ])

    # Filter NaNs
    results_dict = {k: v for k, v in results_dict.items() if not np.isnan(v)}

    return results_dict


# Embeddings metrics

In [None]:
results_dict = eval_clustering_metrics(adata=input_data.adata, 
                                        batch_key="batch",
                                        label_key="celltype",
                                        embedding_key="geneformer",  # or "X_scGPT", etc.
                                        verbose=True)
results_dict

Louvain clustering: 100%|██████████| 20/20 [00:02<00:00,  7.68it/s]


mean silhouette per group:                    silhouette_score
group                              
B cells                    0.990590
CD14+ Monocytes            0.979706
CD4 T cells                0.987594
CD8 T cells                0.991305
Dendritic Cells            0.958009
FCGR3A+ Monocytes          0.990665
Megakaryocytes             0.857295
NK cells                   0.977292
Other                      0.933587


{'NMI_cluster/label': 0.6061048617613637,
 'ARI_cluster/label': 0.503784927975462,
 'ASW_label': 0.510432125069201,
 'graph_conn': 0.8852579724762832,
 'ASW_batch': 0.5012279110960662,
 'ASW_label/batch': 0.9628935503212096,
 'PCR_batch': 0.0007131078007747846,
 'Average_Batch_Score': 0.25097050944842053,
 'avg_bio': 0.5401073049353422}

In [None]:
from scGraph import scGraph

scg = scGraph(adata=input_data.adata, batch_key="batch", label_key="celltype", 
                trim_rate=0.05, thres_batch=1, thres_celltype=1)
scg.preprocess()
scg.compute()
results = scg.evaluate()
print(results)

  0%|          | 0/2 [00:00<?, ?it/s]

  Rank-Geneformer
0        0.805556


# OOD Dataset raw metrics

In [None]:
# import scanpy as sc 

# cdata = sc.read_h5ad("zero_shot_data/ood_celltype_data1_expand.h5ad")
# adata = cdata.copy()
# sc.pp.subsample(adata, fraction=0.05, copy=False)

In [19]:
# use_rep = "X"
# adata.obsm["X"] = adata.X
# adata.uns.pop("neighbors", None)

# sc.pp.neighbors(adata, use_rep=use_rep)
# resolutions = [2 * i / 20 for i in range(1, 21)]  # Default: 20 steps from 0.1 to 2.0
# best_nmi = -1
# best_res = None
# best_clustering = None

In [20]:
# label_key = "celltype"
# results_dict = {}
# for res in tqdm(resolutions, disable=not True, desc="Louvain clustering"):
#     sc.tl.louvain(adata, resolution=res, key_added="temp_cluster")
#     nmi = scib.metrics.nmi(adata, "temp_cluster", label_key)
#     if nmi > best_nmi:
#         best_nmi = nmi
#         best_res = res
#         best_clustering = adata.obs["temp_cluster"].copy()
#     del adata.obs["temp_cluster"]

# adata.obs["cluster"] = best_clustering
# # Biological conservation metrics
# results_dict["NMI_cluster/label"] = scib.metrics.nmi(adata, "cluster", label_key, "arithmetic")
# results_dict["ARI_cluster/label"] = scib.metrics.ari(adata, "cluster", label_key)
# results_dict["ASW_label"] = scib.metrics.silhouette(adata, label_key, use_rep, "euclidean")

# # Batch effect metrics (if batch_key valid)
# batch_key = "batch"
# if batch_key is not None and batch_key in adata.obs and adata.obs[batch_key].nunique() > 1:
#     adata.obs[label_key] = adata.obs[label_key].astype("category")
#     results_dict["graph_conn"] = scib.metrics.graph_connectivity(adata, label_key)
#     results_dict["ASW_batch"] = (1 - scib.metrics.silhouette(adata, batch_key, use_rep, "euclidean"))
#     results_dict["ASW_label/batch"] = scib.metrics.silhouette_batch(
#         adata, batch_key, label_key, embed=use_rep, metric="euclidean", return_all=False
#     )
#     results_dict["PCR_batch"] = scib.metrics.pcr(
#         adata, covariate=batch_key, embed=use_rep, recompute_pca=True, n_comps=50, verbose=False
#     )
#     results_dict["Average_Batch_Score"] = (
#         results_dict["ASW_batch"] + results_dict["PCR_batch"]
#     ) / 2
# else:
#     if verbose:
#         log.info("Skipping batch metrics — only one batch present or invalid batch_key.")

# results_dict["avg_bio"] = np.mean([
#     results_dict["NMI_cluster/label"],
#     results_dict["ARI_cluster/label"],
#     results_dict["ASW_label"]
# ])

# # Filter NaNs
# results_dict = {k: v for k, v in results_dict.items() if not np.isnan(v)}

# results_dict

Louvain clustering: 100%|██████████| 20/20 [00:22<00:00,  1.14s/it]


mean silhouette per group:             silhouette_score
group                       
CL:0000077          0.951371
CL:0000091          0.905183
CL:0000099          0.856871
CL:0000164          0.913159
CL:0000189          0.934462
CL:0000312          0.933951
CL:0000453          0.966310
CL:0000575          0.779139
CL:0000750          0.991985
CL:0000767          0.977141
CL:0000771          0.893556
CL:0000776          0.932994
CL:0000810          0.913306
CL:0000817          0.931130
CL:0000837          0.967683
CL:0000843          0.948814
CL:0000861          0.841148
CL:0000915          0.945803
CL:0000957          0.970545
CL:0001029          0.950351
CL:0001057          0.946863
CL:0001074          0.936960
CL:0002028          0.935891
CL:0002045          0.950375
CL:0002064          0.926107
CL:0002075          0.759782
CL:0002201          0.973459
CL:0002393          0.966944
CL:0002518          0.911847
CL:0005012          0.961174
CL:0009009          0.957441
CL:0009010      

{'NMI_cluster/label': 0.7833172618112929,
 'ARI_cluster/label': 0.5728303202672791,
 'ASW_label': 0.4911566338564166,
 'graph_conn': 0.7769019941103583,
 'ASW_batch': 0.5006964505924973,
 'ASW_label/batch': 0.9306380360099057,
 'PCR_batch': 0.757978241899424,
 'Average_Batch_Score': 0.6293373462459606,
 'avg_bio': 0.6157680719783295}

# Raw data metrics

In [None]:
results_dict_raw = eval_clustering_metrics(adata=input_data.adata, 
                                        batch_key="batch",
                                        label_key="celltype",
                                        embedding_key="X",  # or "X_scGPT", etc.
                                        verbose=True)
results_dict_raw

Louvain clustering: 100%|██████████| 20/20 [00:02<00:00,  6.97it/s]


mean silhouette per group:                    silhouette_score
group                              
B cells                    0.971033
CD14+ Monocytes            0.942456
CD4 T cells                0.988742
CD8 T cells                0.987412
Dendritic Cells            0.938792
FCGR3A+ Monocytes          0.950513
Megakaryocytes             0.752894
NK cells                   0.890206
Other                      0.914109


{'NMI_cluster/label': 0.6505152890434263,
 'ARI_cluster/label': 0.5759899223104351,
 'ASW_label': 0.5245759263634682,
 'graph_conn': 0.8891452955038966,
 'ASW_batch': 0.4964794989209622,
 'ASW_label/batch': 0.9262396008669715,
 'PCR_batch': 0.0007824623021499673,
 'Average_Batch_Score': 0.24863098061155608,
 'avg_bio': 0.5836937125724432}

In [None]:
from scGraph import scGraph

scg = scGraph(adata=input_data.adata, batch_key="batch", label_key="celltype", 
                trim_rate=0.05, thres_batch=1, thres_celltype=1, embedding_key="X")
scg.preprocess()
scg.compute()
results = scg.evaluate()
print(results)

# HVG & scVI

## HVG

In [None]:
import os
import logging

import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse
import scvi

import sys
sys.path.append("zero_shot_batch_effect")
from sc_foundation_evals import utils
from sc_foundation_evals.helpers.custom_logging import log

log.setLevel(logging.INFO)

import warnings
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings("ignore")

In [None]:
# specify the path to anndata object
adata_path = in_dataset_path
# dataset_name is inferred from in_dataset_path
dataset_name = os.path.basename(adata_path).split(".")[0]

# batch column found in adata.obs
batch_col = "batch"
# where are labels stored in adata.obs? 
label_col = "celltype"
# where the raw counts are stored?
layer_key = "counts"

adata = sc.read(adata_path)
adata

AnnData object with n_obs × n_vars = 11990 × 3346
    obs: 'n_counts', 'batch', 'labels', 'str_labels', 'celltype'
    var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts'
    uns: 'cell_types'
    obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc'

In [None]:
if layer_key == "X":
    adata.layers["counts"] = adata.X
elif layer_key != "counts":
    adata.layers["counts"] = adata.layers[layer_key]

In [None]:
sc.pp.filter_cells(adata, min_genes=10)
sc.pp.filter_genes(adata, min_cells=10)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
sc.pp.highly_variable_genes(adata, flavor='seurat', subset=False, n_top_genes=2000)

# hvg_mask = adata.var["highly_variable"].values

adata.obsm["X_genes"] = adata.X[:, adata.var.highly_variable.values]

# check if adata.obsm["X_genes"] is sparse and if so, convert to dense
if sparse.issparse(adata.obsm["X_genes"]):
    adata.obsm["X_genes"] = np.asarray(adata.obsm["X_genes"].todense())

In [None]:
results_dict_hvg = eval_clustering_metrics(adata=adata, 
                                        batch_key=batch_col,
                                        label_key=label_col,
                                        embedding_key="X_genes",  # or "X_scGPT", etc.
                                        verbose=True)
results_dict_hvg

[32mINFO    [0m | 2025-06-22 14:32:11 | [32mSubsampled adata to 25.0% of original cells.[0m
[32mINFO    [0m | 2025-06-22 14:32:12 | [32mSearching for optimal clustering resolution on X_genes...[0m
Louvain clustering: 100%|██████████| 20/20 [00:02<00:00,  8.92it/s]
[32mINFO    [0m | 2025-06-22 14:32:14 | [32mBest resolution: 0.70 with NMI = 0.6944[0m


mean silhouette per group:                    silhouette_score
group                              
B cells                    0.990475
CD14+ Monocytes            0.994091
CD4 T cells                0.994429
CD8 T cells                0.996067
Dendritic Cells            0.990181
FCGR3A+ Monocytes          0.997131
Megakaryocytes             0.973109
NK cells                   0.997118
Other                      0.982645


{'NMI_cluster/label': 0.6944194464119003,
 'ARI_cluster/label': 0.6730602977338459,
 'ASW_label': 0.513224795460701,
 'graph_conn': 0.8757625892165339,
 'ASW_batch': 0.4997675784834428,
 'ASW_label/batch': 0.9905828886755944,
 'PCR_batch': 0.0008402505807411988,
 'Average_Batch_Score': 0.250303914532092,
 'avg_bio': 0.626901513202149}

In [None]:
from scGraph import scGraph

scg = scGraph(adata=input_data.adata, batch_key="batch", label_key="celltype", 
                trim_rate=0.05, thres_batch=1, thres_celltype=1, embedding_key="X_genes")
scg.preprocess()
scg.compute()
results = scg.evaluate()
print(results)

## scVI

In [None]:
if "counts" not in adata.layers.keys():
    adata.layers["counts"] = adata.X.copy()

In [None]:
adata

AnnData object with n_obs × n_vars = 11990 × 3345
    obs: 'n_counts', 'batch', 'labels', 'str_labels', 'celltype', 'n_genes'
    var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'cell_types', 'log1p', 'hvg'
    obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc', 'X_genes'
    layers: 'counts'

In [None]:
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key=batch_col)
model = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood="nb")
model.train()
adata.obsm["X_scVI"] = model.get_latent_representation()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Training:   0%|          | 0/400 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=400` reached.


In [None]:
adata.obsm["X_scVI"] = model.get_latent_representation()

In [None]:
results_dict_scvi = eval_clustering_metrics(adata=adata, 
                                        batch_key=batch_col,
                                        label_key=label_col,
                                        embedding_key="X_scVI",  # or "X_scGPT", etc.
                                        verbose=True)
results_dict_scvi

[32mINFO    [0m | 2025-06-22 14:36:48 | [32mSubsampled adata to 25.0% of original cells.[0m
[32mINFO    [0m | 2025-06-22 14:36:48 | [32mSearching for optimal clustering resolution on X_scVI...[0m
Louvain clustering: 100%|██████████| 20/20 [00:02<00:00,  7.97it/s]
[32mINFO    [0m | 2025-06-22 14:36:51 | [32mBest resolution: 1.20 with NMI = 0.7544[0m


mean silhouette per group:                    silhouette_score
group                              
B cells                    0.991501
CD14+ Monocytes            0.976939
CD4 T cells                0.987053
CD8 T cells                0.980696
Dendritic Cells            0.931121
FCGR3A+ Monocytes          0.974440
Megakaryocytes             0.910766
NK cells                   0.971491
Other                      0.899360


{'NMI_cluster/label': 0.7543923134993394,
 'ARI_cluster/label': 0.6471385261878778,
 'ASW_label': 0.482499361038208,
 'graph_conn': 0.9461266173017836,
 'ASW_batch': 0.5024425515439361,
 'ASW_label/batch': 0.9581518028443176,
 'PCR_batch': 0.00044665558752302455,
 'Average_Batch_Score': 0.25144460356572956,
 'avg_bio': 0.628010066908475}

In [None]:
from scGraph import scGraph

scg = scGraph(adata=input_data.adata, batch_key="batch", label_key="celltype", 
                trim_rate=0.05, thres_batch=1, thres_celltype=1, embedding_key="X_scVI")
scg.preprocess()
scg.compute()
results = scg.evaluate()
print(results)