In [2]:
from typing import Dict, Optional
import numpy as np
import scanpy as sc
import scib
from anndata import AnnData
from sklearn.metrics import silhouette_score
from tqdm import tqdm
import pandas as pd
import logging

log = logging.getLogger(__name__)


def eval_clustering_metrics(
    adata: AnnData,
    batch_key: Optional[str] = "str_batch",
    label_key: str = "cell_type",
    embedding_key: str = "X",  # "X" for raw, or embedding key in .obsm
    resolutions: Optional[list] = None,
    use_progress_bar: bool = True,
    verbose: bool = False,
) -> Dict[str, float]:
    """Evaluate biological and batch mixing metrics on an embedding or raw expression."""
    
    results_dict = {}

    # Determine whether to use .X or .obsm[embedding_key]
    if embedding_key == "X":
        use_rep = "X"
        adata.obsm["X"] = adata.X
    elif embedding_key in adata.obsm:
        use_rep = embedding_key
    else:
        raise ValueError(f"embedding_key '{embedding_key}' not found in adata.obsm or is not 'X'")

    # Clear stale neighbors
    if "neighbors" in adata.uns:
        if verbose:
            log.warning(f"Removing stale neighbors computed from other representations.")
        adata.uns.pop("neighbors", None)

    sc.pp.neighbors(adata, use_rep=use_rep)

    # Run Louvain across multiple resolutions
    if resolutions is None:
        resolutions = [2 * i / 20 for i in range(1, 21)]  # Default: 20 steps from 0.1 to 2.0

    best_nmi = -1
    best_res = None
    best_clustering = None

    if verbose:
        log.info(f"Searching for optimal clustering resolution on {use_rep}...")

    for res in tqdm(resolutions, disable=not use_progress_bar, desc="Louvain clustering"):
        sc.tl.louvain(adata, resolution=res, key_added="temp_cluster")
        nmi = scib.metrics.nmi(adata, "temp_cluster", label_key)
        if nmi > best_nmi:
            best_nmi = nmi
            best_res = res
            best_clustering = adata.obs["temp_cluster"].copy()
        del adata.obs["temp_cluster"]

    if verbose:
        log.info(f"Best resolution: {best_res:.2f} with NMI = {best_nmi:.4f}")

    adata.obs["cluster"] = best_clustering

    # Biological conservation metrics
    results_dict["NMI_cluster/label"] = scib.metrics.nmi(adata, "cluster", label_key, "arithmetic")
    results_dict["ARI_cluster/label"] = scib.metrics.ari(adata, "cluster", label_key)
    results_dict["ASW_label"] = scib.metrics.silhouette(adata, label_key, use_rep, "euclidean")

    # Batch effect metrics (if batch_key valid)
    if batch_key is not None and batch_key in adata.obs and adata.obs[batch_key].nunique() > 1:
        results_dict["graph_conn"] = scib.metrics.graph_connectivity(adata, label_key)
        results_dict["ASW_batch"] = scib.metrics.silhouette(adata, batch_key, use_rep, "euclidean")
        results_dict["ASW_label/batch"] = scib.metrics.silhouette_batch(
            adata, batch_key, label_key, embed=use_rep, metric="euclidean", return_all=False
        )
        results_dict["PCR_batch"] = scib.metrics.pcr(
            adata, covariate=batch_key, embed=use_rep, recompute_pca=True, n_comps=50, verbose=False
        )
    else:
        if verbose:
            log.info("Skipping batch metrics — only one batch present or invalid batch_key.")
    
    results_dict["avg_bio"] = np.mean([
        results_dict["NMI_cluster/label"],
        results_dict["ARI_cluster/label"],
        results_dict["ASW_label"]
    ])

    # Filter NaNs
    results_dict = {k: v for k, v in results_dict.items() if not np.isnan(v)}

    return results_dict


In [None]:
import scanpy as sc 
adata = sc.read_h5ad("zero_shot_batch_data/pbmc.h5ad") 

results_dict = eval_clustering_metrics(adata=adata, 
                                        batch_key="batch",
                                        label_key="celltype",
                                        embedding_key="X",  # or "X_scGPT", etc.
                                        verbose=True)

Louvain clustering: 100%|██████████| 20/20 [00:15<00:00,  1.32it/s]
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)


mean silhouette per group:                    silhouette_score
group                              
B cells                    0.986484
CD14+ Monocytes            0.943531
CD4 T cells                0.980745
CD8 T cells                0.951482
Dendritic Cells            0.956119
FCGR3A+ Monocytes          0.986242
Megakaryocytes             0.856766
NK cells                   0.953083
Other                      0.930244




In [12]:
results_dict

{'NMI_cluster/label': 0.7043350648326699,
 'ARI_cluster/label': 0.6456273245075416,
 'ASW_label': 0.5333220548927784,
 'graph_conn': 0.9038879996225364,
 'ASW_batch': 0.4965497492812574,
 'ASW_label/batch': 0.9494108132303586,
 'PCR_batch': 0.0009914006163016576,
 'avg_bio': 0.6277614814109966}

In [5]:
results_dict_ood = eval_clustering_metrics(adata=adata_ood[:15000],
                                        batch_key="batch",
                                        label_key="cell_type",
                                        embedding_key="X", 
                                        verbose=True)

  adata.obsm["X"] = adata.X
Louvain clustering: 100%|██████████| 20/20 [00:11<00:00,  1.68it/s]
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)


mean silhouette per group: nan




In [6]:
results_dict_ood

{'NMI_cluster/label': 0.9334102174490695,
 'ARI_cluster/label': 0.9699361136567832,
 'ASW_label': 0.5538543930108312,
 'graph_conn': 0.9231509101914211,
 'ASW_batch': 0.6438532075334105,
 'PCR_batch': 0.042066597759588056,
 'avg_bio': 0.8190669080388946}