In [2]:
import yaml
from pathlib import Path

import scanpy as sc

from scripts.EGGFM.train_energy import train_energy_model
from scripts.EGGFM.engine import EGGFMDiffusionEngine
from scripts.EGGFM.data_sources import AnnDataViewProvider
from scripts.EGGFM.utils import subsample_adata

# ---- user knobs for quick tests ----
PARAMS_PATH = "configs/params.yml"  # change if needed
MAX_CELLS = 2000                    # None = use all cells
SUBSAMPLE_SEED = 0                  # controls which cells get picked

# Load params.
# yaml.safe_load returns None for an empty file and for sections written with no body,
# so `or {}` normalises both cases to an empty dict and keeps the `.get(...)` calls
# below (and in later cells) working.
params = yaml.safe_load(Path(PARAMS_PATH).read_text()) or {}
if "spec" not in params:
    raise KeyError(f"'spec' section missing from {PARAMS_PATH}")
spec = params["spec"]
diff_cfg = params.get("eggfm_diffmap") or {}
model_cfg = params.get("eggfm_model") or {}
train_cfg = params.get("eggfm_train") or {}

diff_cfg


{'geometry_source': 'pca',
 'energy_source': 'hvg',
 'n_neighbors': 30,
 'n_comps': 30,
 'device': 'cuda',
 'hvp_batch_size': 1024,
 'eps_mode': 'median',
 'eps_value': 1.0,
 'eps_trunc': 'no',
 'distance_power': 1.0,
 't': 2.0,
 'norm_type': 'l2',
 'metric_mode': 'hessian_mixed',
 'metric_gamma': 0.2,
 'metric_lambda': 4.0,
 'energy_clip_abs': 3.0,
 'energy_batch_size': 2048,
 'hessian_mix_mode': 'none',
 'hessian_mix_alpha': 0.3,
 'hessian_beta': 0.2,
 'hessian_clip_std': 2.0,
 'hessian_use_neg': True}

In [3]:
# Path to the input AnnData comes from the `spec` section of the params file.
# Fail loudly here instead of letting read_h5ad(None) raise an opaque error.
ad_file = spec.get("ad_file")
if not ad_file:
    raise KeyError(f"'ad_file' is not set in the 'spec' section of {PARAMS_PATH}")

print("[notebook] loading paul15...", flush=True)
qc_ad = sc.read_h5ad(ad_file)

# prep_for_manifolds is not run here. Judging by the displayed AnnData (uns 'pca',
# 'neighbors', obsm 'X_pca'), the file appears to be preprocessed already — TODO confirm.

# Optionally subsample (MAX_CELLS=None keeps all cells; SUBSAMPLE_SEED fixes which ones).
print("[notebook] subsampling (if requested)...", flush=True)
qc_ad = subsample_adata(
    qc_ad,
    max_cells=MAX_CELLS,
    seed=SUBSAMPLE_SEED,
)

qc_ad


[notebook] loading paul15...
[notebook] subsampling (if requested)...
[subsample_adata] Subsampling 2000 / 2730 cells (seed=0)


AnnData object with n_obs × n_vars = 2000 × 2000
    obs: 'paul15_clusters', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'n_genes'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'diffmap_evals', 'eggfm_meta', 'hvg', 'iroot', 'log1p', 'neighbors', 'pca'
    obsm: 'X_diff_eggfm', 'X_diff_pca', 'X_diff_pca_x2', 'X_diffmap', 'X_eggfm', 'X_pca'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [4]:
# Train the EGGFM energy model on the (possibly subsampled) cells, configured by the
# `eggfm_model` and `eggfm_train` sections of the params file.
# NOTE(review): every epoch in the saved output logs "loss=0.0000". This is either a
# display artifact (loss rounding to zero at 4 decimals) or a sign that the DSM objective
# is degenerate. Check the unrounded loss before trusting the SCM / Hessian metrics
# built from this model.
print("[notebook] training energy model...", flush=True)
energy_model = train_energy_model(qc_ad, model_cfg, train_cfg)
# Displaying the model shows its layer structure.
energy_model


[notebook] training energy model...
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Epoch 12/50  loss=0.0000
[Energy DSM] Epoch 13/50  loss=0.0000
[Energy DSM] Epoch 14/50  loss=0.0000
[Energy DSM] Epoch 15/50  loss=0.0000
[Energy DSM] Epoch 16/50  loss=0.0000
[Energy DSM] Epoch 17/50  loss=0.0000
[Energy DSM] Epoch 18/50  loss=0.0000
[Energy DSM] Epoch 19/50  loss=0.0000
[Energy DSM] Epoch 20/50  loss=0.0000
[Energy DSM] Epoch 21/50  loss=0.0000
[Energy DSM] Epoch 22/50  loss=0.0000
[Energy DSM] Epoch 23/50  loss=0.0000
[Energy DSM] Epoch 24/50  loss=0.0000
[Energy DSM] Epoch 25/50  loss=0.0000
[Energy DSM] Epoch 26/5

EnergyMLP(
  (net): Sequential(
    (0): Linear(in_features=2000, out_features=512, bias=True)
    (1): Softplus(beta=1.0, threshold=20.0)
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): Softplus(beta=1.0, threshold=20.0)
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): Softplus(beta=1.0, threshold=20.0)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): Softplus(beta=1.0, threshold=20.0)
    (8): Linear(in_features=512, out_features=2000, bias=True)
  )
)

In [5]:
# Choose which AnnData representations the engine reads: `geometry_source` is the space
# the kNN graph is built in, and `energy_source` is the representation used when
# evaluating energies (per the config: pca / hvg).
view_provider = AnnDataViewProvider(
    geometry_source=diff_cfg.get("geometry_source", "pca"),
    energy_source=diff_cfg.get("energy_source", "hvg"),
)

# A single diffusion engine, shared by every metric mode below (and by the ADMR cell later).
engine = EGGFMDiffusionEngine(
    energy_model=energy_model,
    diff_cfg=diff_cfg,
    view_provider=view_provider,
)

# Build one embedding per metric mode and keep each under its own .obsm key.
metric_modes = ["euclidean", "scm", "hessian_mixed"]

for mode in metric_modes:
    obsm_key = f"X_eggfm_{mode}"
    print(f"[notebook] Building embedding for metric_mode='{mode}'", flush=True)
    embedding = engine.build_embedding(qc_ad, metric_mode=mode)
    qc_ad.obsm[obsm_key] = embedding
    print(f"[notebook] Stored embedding in .obsm['{obsm_key}'] with shape {embedding.shape}", flush=True)

qc_ad


[notebook] Building embedding for metric_mode='euclidean'
[AnnDataViewProvider] using PCA for geometry with shape (2000, 50)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[DiffusionMap] using eps = 16.66 (power p=1.0)
[DiffusionMap] computing eigenvectors...
[DiffusionMap] finished. Embedding shape: (2000, 30)
[notebook] Stored embedding in .obsm['X_eggfm_euclidean'] with shape (2000, 30)
[notebook] Building embedding for metric_mode='scm'
[AnnDataViewProvider] using PCA for geometry with shape (2000, 50)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] energy_source hvg
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-62.9081, raw_max=78.7025, norm_min=-5.2009, norm_max=4.8406, clip=±3.0
[EGGFM SCM] metric G stats: min=0.3991, max=80.5421, mean=9.6260
[DiffusionMap] using eps = 37.77 (power p=1.0)
[Diffusion

AnnData object with n_obs × n_vars = 2000 × 2000
    obs: 'paul15_clusters', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'n_genes'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'diffmap_evals', 'eggfm_meta', 'hvg', 'iroot', 'log1p', 'neighbors', 'pca'
    obsm: 'X_diff_eggfm', 'X_diff_pca', 'X_diff_pca_x2', 'X_diffmap', 'X_eggfm', 'X_pca', 'X_eggfm_euclidean', 'X_eggfm_scm', 'X_eggfm_hessian_mixed'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [6]:
from sklearn.metrics import adjusted_rand_score
import numpy as np

# Ground-truth annotation that every embedding is scored against.
label_key = "paul15_clusters"  # this is the usual key for paul15
if label_key not in qc_ad.obs.columns:
    raise KeyError(f"Label column '{label_key}' not found in qc_ad.obs")
labels = qc_ad.obs[label_key].to_numpy()

def kmeans_ari(X, labels, n_clusters=None, random_state=0, n_init=10):
    """Cluster an embedding with k-means and score the clustering against reference labels.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Embedding to cluster (e.g. an entry of ``qc_ad.obsm``).
    labels : array-like, shape (n_samples,)
        Reference labels for the same rows as ``X``.
    n_clusters : int, optional
        Number of k-means clusters. Defaults to the number of distinct ``labels``.
    random_state : int, default 0
        Seed forwarded to ``KMeans`` so the score is reproducible.
    n_init : int, default 10
        Number of k-means initialisations forwarded to ``KMeans``. The default keeps the
        previously hard-coded value.

    Returns
    -------
    float
        Adjusted Rand index between ``labels`` and the k-means cluster assignments.
    """
    from sklearn.cluster import KMeans

    if n_clusters is None:
        n_clusters = len(np.unique(labels))
    km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=n_init)
    preds = km.fit_predict(X)
    return adjusted_rand_score(labels, preds)

# Cluster each stored embedding with k-means and report its ARI against the labels.
for mode in metric_modes:
    key = f"X_eggfm_{mode}"
    ari = kmeans_ari(qc_ad.obsm[key], labels)
    print(f"metric_mode='{mode}': ARI={ari:.3f}")


metric_mode='euclidean': ARI=0.279
metric_mode='scm': ARI=0.149
metric_mode='hessian_mixed': ARI=0.316


In [7]:
# Persist the AnnData, including the new .obsm embeddings, tagged with its cell count.
OUT_DIR = "data/paul15"
out_path = f"{OUT_DIR}/paul15_eggfm_test_{qc_ad.n_obs}cells.h5ad"
Path(OUT_DIR).mkdir(exist_ok=True, parents=True)
print(f"[notebook] writing result to {out_path}", flush=True)
qc_ad.write_h5ad(out_path)


[notebook] writing result to data/paul15/paul15_eggfm_test_2000cells.h5ad


In [8]:
from scripts.EGGFM.admr import run_admr_layers
# Example for paul15: run several ADMR layers, switching the metric mode per layer.
# The labels are passed so run_admr_layers can score each layer
# (ARI seed / neighbor-overlap k below).
label_key = "paul15_clusters"  # this is the usual key for paul15
labels = qc_ad.obs[label_key].to_numpy()

metric_sequence = ["scm", "euclidean", "scm"]   # SCM → Euclidean → SCM
t_sequence = [1.0, 2.0, 2.0]                   # or [None, 2.0, 2.0] to use default t on layer 0

# Derive the layer count from the schedule instead of hard-coding it, so editing one
# sequence cannot silently desynchronise it from n_layers.
n_layers = len(metric_sequence)
if len(t_sequence) != n_layers:
    raise ValueError(
        f"t_sequence has {len(t_sequence)} entries but metric_sequence has {n_layers}"
    )

# NOTE(review): the saved run of this cell stopped with
# "NameError: name 'kmeans_ari' is not defined" right after layer 0's embedding finished.
# kmeans_ari is defined only in this notebook's namespace, and module code in
# scripts.EGGFM.admr cannot see notebook globals. The fix presumably belongs in that
# module (define or import kmeans_ari there). TODO confirm against the full traceback.
qc_ad, layer_embeddings, metrics_log = run_admr_layers(
    ad_prep=qc_ad,
    engine=engine,
    n_layers=n_layers,
    metric_sequence=metric_sequence,
    t_sequence=t_sequence,
    base_geometry_source="pca",    # or "hvg"
    store_prefix="X_admr_layer",
    labels=labels,
    label_key=label_key,
    n_clusters=None,               # default = num unique labels
    k_overlap=30,                  # neighbor overlap k
    ari_random_state=0,
)

metrics_log


[ADMR] Using PCA as base geometry with shape (2000, 50)
[ADMR] Layer 0: metric_mode='scm', t=1.0, geometry shape=(2000, 50)
[EGGFM Engine] using override geometry with shape (2000, 50)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] energy_source hvg
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-62.9081, raw_max=78.7025, norm_min=-5.2009, norm_max=4.8406, clip=±3.0
[EGGFM SCM] metric G stats: min=0.3991, max=80.5421, mean=9.6260
[DiffusionMap] using eps = 37.77 (power p=1.0)
[DiffusionMap] computing eigenvectors...
[DiffusionMap] finished. Embedding shape: (2000, 30)


NameError: name 'kmeans_ari' is not defined

In [None]:
!BUCKET="gs://medit-uml-prod-uscentral1-8e7a" && \
 gsutil -m rsync -r out/admr_logs "${BUCKET}/out/admr_logs"