In [81]:
import numpy as np
import scanpy as sc
from scripts.EGGFM.eggfm import run_eggfm_dimred
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

In [82]:
def subset_anndata(ad: sc.AnnData, n_cells: int, random_state: int = 0) -> sc.AnnData:
    """Return an independent copy of a random subset of the rows of ``ad``.

    Parameters
    ----------
    ad
        Object to subsample; only ``ad.n_obs``, positional row indexing and
        ``.copy()`` of the indexed result are used.
    n_cells
        Requested number of rows. If it exceeds ``ad.n_obs`` every row is
        kept (still in shuffled order).
    random_state
        Seed for ``np.random.default_rng``; the same seed yields the same rows.

    Returns
    -------
    A copy of the selected rows, in the order they were drawn (not sorted).
    """
    total_cells = ad.n_obs
    # Cap the request at what is available so sampling without replacement
    # can never ask for more rows than exist.
    keep = n_cells if n_cells < total_cells else total_cells
    picked = np.random.default_rng(random_state).choice(
        np.arange(total_cells), size=keep, replace=False
    )
    return ad[picked].copy()

def compute_ari_fixed(X, labels, k, random_state: int = 0) -> float:
    """Score an embedding by k-means ARI against reference labels.

    K-means is run on the first ``k`` columns of ``X`` with the number of
    clusters fixed to the number of distinct values in ``labels``, and the
    resulting assignment is compared to ``labels`` with the adjusted Rand
    index.

    Parameters
    ----------
    X
        2-D array-like supporting ``X[:, :k]`` (one row per observation).
    labels
        Reference label per row of ``X``.
    k
        Number of leading columns of ``X`` to cluster on.
    random_state
        Seed forwarded to ``KMeans``.

    Returns
    -------
    The adjusted Rand index between ``labels`` and the k-means assignment.
    """
    n_label_groups = len(np.unique(labels))
    clusterer = KMeans(
        n_clusters=n_label_groups,
        n_init=10,
        random_state=random_state,
    )
    assignments = clusterer.fit(X[:, :k]).labels_
    return adjusted_rand_score(labels, assignments)

In [85]:
# Single configuration dict for the experiment. The whole dict is handed to
# run_eggfm_dimred; the cells in this notebook additionally read
# params["spec"], params["eggfm_train"]["n_cells_sample"] and
# params["eggfm_diffmap"] directly.
params = {
    # NOTE(review): not read by the benchmark loop in this notebook (runs are
    # seeded from the loop index); presumably consumed inside
    # run_eggfm_dimred — TODO confirm.
    "seed": 7,
    "pca_n_top_genes": 2000,

    "spec": {
        # Also the fallback for the number of ARI dimensions when
        # "ari_n_dims" is absent.
        "n_pcs": 20,
        "dcol_max_cells": 3000,
        "ari_label_key": "Cell type annotation",  # <-- this must match an obs column name
        # Alternative label column (paul15 dataset):
        # "ari_label_key": "paul15_clusters",         # <-- this must match an obs column name
        "ari_n_dims": 10,                           # how many dims to use for ARI per embedding
        # Alternative input file (paul15 dataset):
        # "ad_file": "data/paul15/paul15.h5ad",
        "ad_file": "data/prep/qc.h5ad",  # loaded with sc.read_h5ad
    },

    # presumably QC filter thresholds applied by the pipeline — TODO confirm
    "qc": {
        "min_cells": 500,
        "min_genes": 200,
        "max_pct_mt": 15,
    },

    "eggfm_model": {
        "hidden_dims": [512, 512, 512, 512],
    },

    "eggfm_train": {
        "batch_size": 2048,
        "num_epochs": 50,
        "lr": 1.0e-4,
        "sigma": 0.1,
        "device": "cuda",
        "latent_space": "hvg",
        "early_stop_patience": 10,
        "early_stop_min_delta": 0.0,
        # Number of cells drawn per run by subset_anndata in the benchmark loop.
        "n_cells_sample":2000
    },

    # Every key in this sub-dict is also written as a column of the results CSV.
    "eggfm_diffmap": {
        "geometry_source": "pca",          # "pca" or "hvg"
        "energy_source": "hvg",            # where SCM/Hessian read energies
        "metric_mode": "scm",    # "euclidean", "scm", or "hessian_mixed"
        "n_neighbors": 30,
        "n_comps": 30,
        "device": "cuda",
        "hvp_batch_size": 1024,
        "eps_mode": "median",
        "eps_value": 1.0,
        "eps_trunc": "no",
        "distance_power": 1.0,
        "t": 1.5,
        "norm_type": "l2",

        # SCM hyperparams
        "metric_gamma": 0.4,
        "metric_lambda": 8.0,
        "energy_clip_abs": 3.0,
        "energy_batch_size": 2048,

        # Hessian mixing hyperparams
        "hessian_mix_mode": "multiplicative",   # "additive" | "multiplicative" | "none"
        "hessian_mix_alpha": 0.3,
        "hessian_beta": 0.2,
        "hessian_clip_std": 2.0,
        "hessian_use_neg": True,
    },
}

# ---- Repeated-subsample ARI benchmark -------------------------------------
# For each run: draw a random subset of cells, run the EGGFM pipeline on it,
# build diffusion-map embeddings on top of PCA and (stacked up to four times)
# on top of the EGGFM embedding, and score every embedding with k-means ARI
# against the reference labels.
spec = params["spec"]
k = spec.get("ari_n_dims", spec.get("n_pcs", 10))

base_ad = sc.read_h5ad(spec.get("ad_file"))

N_NEIGHBORS = 30  # kNN size used for every scanpy neighbour graph below


def _diffmap_embedding(adata, use_rep, out_key, n_comps, n_neighbors=N_NEIGHBORS):
    """Diffusion-map ``adata.obsm[use_rep]`` and store the result under ``out_key``.

    Rebuilds the scanpy neighbour graph on ``use_rep``, runs ``sc.tl.diffmap``
    (which overwrites ``adata.obsm["X_diffmap"]``), keeps the first ``n_comps``
    columns, writes them to ``adata.obsm[out_key]`` and returns them.

    NOTE(review): per the scanpy docs the 0-th diffmap column is the
    non-informative steady-state component, so only ``n_comps - 1`` columns
    carry signal here. Left unchanged to keep scores comparable with earlier
    runs — confirm whether it should be dropped.
    """
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep=use_rep)
    sc.tl.diffmap(adata, n_comps=n_comps)
    embedding = adata.obsm["X_diffmap"][:, :n_comps]
    adata.obsm[out_key] = embedding
    return embedding


# One ARI score per run for each embedding.
scores_eggfm = []    # EGGFM embedding itself
scores_eggfm_2 = []  # EGGFM -> DM
scores_eggfm_3 = []  # EGGFM -> DM -> DM
scores_eggfm_4 = []  # EGGFM -> DM x3
scores_eggfm_5 = []  # EGGFM -> DM x4
scores_pca = []      # PCA -> DM
scores_pca_2 = []    # reserved for the currently disabled PCA -> DM -> DM variant

# (input obsm key, output obsm key, score list) for each stacked diffusion map.
# The "eggm" spelling of the last two keys is kept as-is so that anything
# reading these obsm entries keeps working.
eggfm_dm_chain = [
    ("X_eggfm", "X_diff_eggfm", scores_eggfm_2),
    ("X_diff_eggfm", "X_diff_eggfm_x2", scores_eggfm_3),
    ("X_diff_eggfm_x2", "X_diff_eggm_x3", scores_eggfm_4),
    ("X_diff_eggm_x3", "X_diff_eggm_x4", scores_eggfm_5),
]

total = 15
for run in range(total):
    run_seed = run  # a different subsample per run; params["seed"] is not used here
    ad_prep = subset_anndata(base_ad, params["eggfm_train"]["n_cells_sample"], run_seed)

    print(f"=== Run {run+1}/{total} ===")
    # NOTE(review): the recorded output shows the DSM loss printing as 0.0000
    # for every epoch and early-stopping at epoch 11 in every run — verify the
    # energy model is actually training before trusting the EGGFM scores.
    qc, _ = run_eggfm_dimred(ad_prep.copy(), params)

    # Read the labels from the object the pipeline returned (not from the
    # pre-pipeline subset) so they stay row-aligned with the embeddings even
    # if run_eggfm_dimred filters or reorders cells.
    labels = qc.obs[spec["ari_label_key"]].to_numpy()

    # PCA → Diffmap (baseline)
    X_diff_pca = _diffmap_embedding(qc, "X_pca", "X_diff_pca", k)
    scores_pca.append(compute_ari_fixed(X_diff_pca, labels, k))

    # EGGFM embedding as produced by the pipeline
    X_eggfm = qc.obsm["X_eggfm"][:, :k]
    scores_eggfm.append(compute_ari_fixed(X_eggfm, labels, k))

    # EGGFM → Diffmap, applied repeatedly; each stage consumes the previous one
    for source_key, target_key, stage_scores in eggfm_dm_chain:
        stage_embedding = _diffmap_embedding(qc, source_key, target_key, k)
        stage_scores.append(compute_ari_fixed(stage_embedding, labels, k))

# Report the across-run mean and standard deviation of the ARI for each
# embedding. Each heading is padded so that the "mean=" columns line up.
print("\n=== Variance results ===")
summary_rows = (
    ("PCA→DM:    ", scores_pca),
    ("EGGFM:     ", scores_eggfm),
    ("EGGFM DM:  ", scores_eggfm_2),
    ("EGGFM DM2: ", scores_eggfm_3),
    ("EGGFM DM3: ", scores_eggfm_4),
    ("EGGFM DM4: ", scores_eggfm_5),
)
for heading, ari_values in summary_rows:
    print(f"{heading}mean={np.mean(ari_values):.4f}, std={np.std(ari_values):.4f}")

=== Run 1/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-64.2084, raw_max=91.2926, norm_min=-5.8411, norm_max=8.4299, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=24.0497
[DiffusionMap] using eps = 35.13 (power p=1.0)
[DiffusionMap] computing eigenvectors...


  utils.warn_names_duplicates("obs")


[DiffusionMap] finished. Embedding shape: (2000, 30)
=== Run 2/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-77.2271, raw_max=95.6151, norm_min=-6.2907, norm_max=7.8594, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=22.2142
[DiffusionMap] using eps = 34.84 (power p=1.0)
[Diffusion

  utils.warn_names_duplicates("obs")


[DiffusionMap] finished. Embedding shape: (2000, 30)
=== Run 3/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-71.4489, raw_max=71.2718, norm_min=-5.9459, norm_max=5.8239, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=21.2567
[DiffusionMap] using eps = 33.97 (power p=1.0)
[Diffusion

  utils.warn_names_duplicates("obs")


=== Run 4/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-71.7746, raw_max=82.5563, norm_min=-6.5272, norm_max=7.4705, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=22.5624
[DiffusionMap] using eps = 35.51 (power p=1.0)
[DiffusionMap] computing eigenvectors...


  utils.warn_names_duplicates("obs")


[DiffusionMap] finished. Embedding shape: (2000, 30)
=== Run 5/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-84.4260, raw_max=70.5996, norm_min=-7.3653, norm_max=6.1188, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=22.4019
[DiffusionMap] using eps = 34.7 (power p=1.0)
[DiffusionM

  utils.warn_names_duplicates("obs")


=== Run 6/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-85.1413, raw_max=92.3135, norm_min=-7.3160, norm_max=7.8945, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=21.7776
[DiffusionMap] using eps = 34.85 (power p=1.0)
[DiffusionMap] computing eigenvectors...


  utils.warn_names_duplicates("obs")


[DiffusionMap] finished. Embedding shape: (2000, 30)
=== Run 7/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-69.6315, raw_max=73.2435, norm_min=-5.6735, norm_max=6.1299, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=23.7389
[DiffusionMap] using eps = 34.1 (power p=1.0)
[DiffusionM

  utils.warn_names_duplicates("obs")


[DiffusionMap] finished. Embedding shape: (2000, 30)
=== Run 8/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-89.5956, raw_max=83.1656, norm_min=-7.9785, norm_max=7.2908, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=21.1120
[DiffusionMap] using eps = 34.35 (power p=1.0)
[Diffusion

  utils.warn_names_duplicates("obs")


=== Run 9/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-85.7698, raw_max=66.9769, norm_min=-7.6910, norm_max=6.0292, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=23.1244
[DiffusionMap] using eps = 35.09 (power p=1.0)
[DiffusionMap] computing eigenvectors...
[DiffusionMap] finishe

  utils.warn_names_duplicates("obs")


=== Run 10/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-89.4617, raw_max=74.5783, norm_min=-8.2284, norm_max=6.8935, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=23.5191
[DiffusionMap] using eps = 34.32 (power p=1.0)
[DiffusionMap] computing eigenvectors...
[DiffusionMap] finish

  utils.warn_names_duplicates("obs")


=== Run 11/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-90.4777, raw_max=65.0756, norm_min=-7.7913, norm_max=5.6162, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=22.0548
[DiffusionMap] using eps = 34.67 (power p=1.0)
[DiffusionMap] computing eigenvectors...


  utils.warn_names_duplicates("obs")


[DiffusionMap] finished. Embedding shape: (2000, 30)
=== Run 12/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-78.5897, raw_max=64.0922, norm_min=-6.7557, norm_max=5.4593, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=21.6065
[DiffusionMap] using eps = 34.45 (power p=1.0)
[Diffusio

  utils.warn_names_duplicates("obs")


[DiffusionMap] finished. Embedding shape: (2000, 30)
=== Run 13/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-83.3199, raw_max=118.8716, norm_min=-7.2255, norm_max=10.3266, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=22.0875
[DiffusionMap] using eps = 34.57 (power p=1.0)
[Diffus

  utils.warn_names_duplicates("obs")


=== Run 14/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-93.6351, raw_max=74.5666, norm_min=-7.7310, norm_max=6.0637, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=20.9371
[DiffusionMap] using eps = 34.56 (power p=1.0)
[DiffusionMap] computing eigenvectors...
[DiffusionMap] finish

  utils.warn_names_duplicates("obs")


=== Run 15/15 ===
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Early stopping at epoch 11 (best_loss=0.0000)
[AnnDataViewProvider] using PCA for geometry with shape (2000, 20)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-100.3523, raw_max=92.3673, norm_min=-8.7095, norm_max=7.8984, clip=±3.0
[EGGFM SCM] metric G stats: min=0.7983, max=161.0843, mean=20.2584
[DiffusionMap] using eps = 34.09 (power p=1.0)
[DiffusionMap] computing eigenvectors...


  utils.warn_names_duplicates("obs")


[DiffusionMap] finished. Embedding shape: (2000, 30)

=== Variance results ===
PCA→DM:    mean=0.2726, std=0.0273
EGGFM:     mean=0.2588, std=0.0519
EGGFM DM:  mean=0.3160, std=0.0798
EGGFM DM2: mean=0.3596, std=0.1002
EGGFM DM3: mean=0.2997, std=0.0778
EGGFM DM4: mean=0.2793, std=0.0458


In [None]:
from datetime import datetime
import os
import subprocess

import pandas as pd

# ---- build the result row (config + score summaries) ----
# Writes a one-row CSV holding the ARI label key, every diffusion-map
# hyperparameter and the mean/std ARI of each embedding, then copies the file
# to the project's GCS bucket with gsutil.

GCS_BUCKET = "gs://medit-uml-prod-uscentral1-8e7a"


def _mean_std(scores, ndigits=4):
    """Return ``(mean, std)`` of ``scores`` as Python floats rounded for the CSV.

    ``np.std`` is called with its default ``ddof=0`` (population std).
    """
    return (
        round(float(np.mean(scores)), ndigits),
        round(float(np.std(scores)), ndigits),
    )


all_results = []
ed = params["eggfm_diffmap"]
ari_label_key = params["spec"]["ari_label_key"]

row = {
    # break out the ARI label explicitly
    "ari_label_key": ari_label_key,

    # all diffmap hyperparams become their own columns
    **ed,
}

# Score summaries: <prefix>_mean and <prefix>_std for every embedding. The PCA
# baseline now records its std as well, so it can be compared on the same
# footing as the EGGFM variants.
score_columns = [
    ("ari_pca", scores_pca),
    ("ari_eggfm", scores_eggfm),
    ("ari_eggfm_dm", scores_eggfm_2),
    ("ari_eggfm_dm2", scores_eggfm_3),
    ("ari_eggfm_dm3", scores_eggfm_4),
    ("ari_eggfm_dm4", scores_eggfm_5),
]
for column_prefix, run_scores in score_columns:
    row[f"{column_prefix}_mean"], row[f"{column_prefix}_std"] = _mean_std(run_scores)

all_results.append(row)

results_df = pd.DataFrame(all_results)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_path = f"out/eggfm_admr_layered_ablation_subset_{timestamp}.csv"
# to_csv does not create missing directories; make sure "out/" exists.
os.makedirs(os.path.dirname(results_path), exist_ok=True)
results_df.to_csv(results_path, index=False)

gcs_path = f"{GCS_BUCKET}/{results_path}"
# Argument-list form: no shell is involved, and the exit status is inspected
# so a failed copy is never reported as a successful upload.
try:
    upload_status = subprocess.run(["gsutil", "cp", results_path, gcs_path]).returncode
except FileNotFoundError:
    upload_status = None  # gsutil executable not found on PATH

if upload_status == 0:
    print("Uploaded to:", gcs_path)
else:
    print(
        f"Upload to {gcs_path} failed (gsutil exit status: {upload_status}); "
        f"results are still available locally at {results_path}"
    )

Copying file://out/eggfm_admr_layered_ablation_subset_20251128_222538.csv [Content-Type=text/csv]...
/ [1 files][  687.0 B/  687.0 B]                                                
Operation completed over 1 objects/687.0 B.                                      


Uploaded to: gs://medit-uml-prod-uscentral1-8e7a/out/eggfm_admr_layered_ablation_subset_20251128_222538.csv


In [None]:
# Final focused hyperparam sweep
#
# Each pattern type lists the values to try along three axes (number of
# layers, norm type, distance power); the sweep is the full cross product of
# those axes with every value in ``t_euclid_values``.

pattern_grid = {
    # Euclidean baseline: both l0 and linf are tested, since linf was a star on Paul15
    "eucl_only": {
        "n_layers": [1, 2, 3],
        "norm_types": ["l0", "linf"],
        "distance_powers": [0.0, 0.25, 0.5],
    },
    # SCM: winner on Weinreb at L3, p ~ 0.25, norm l0
    "scm_alt_euclid": {
        "n_layers": [1, 2, 3],
        "norm_types": ["l0"],
        "distance_powers": [0.0, 0.25, 0.5],
    },
    # Hessian-mixed: kept as a single EGGFM competitor
    "hessMult_alt_euclid": {
        "n_layers": [1, 2, 3],
        "norm_types": ["l0"],
        "distance_powers": [0.25, 0.5],
    },
}

t_euclid_values = [2.0]


def _make_config(pattern, depth, norm_name, power, t_val):
    """Build one sweep entry; ``exp_name`` encodes every swept value."""
    return dict(
        exp_name=f"{pattern}_L{depth}_norm{norm_name}_p{power}_teucl{t_val}",
        pattern_type=pattern,
        n_layers=depth,
        t_euclid=t_val,
        norm_type=norm_name,
        distance_power=power,
    )


# Flat cross product, ordered by pattern, then layers, norm, power and t.
config_list = [
    _make_config(pattern, depth, norm_name, power, t_val)
    for pattern, axes in pattern_grid.items()
    for depth in axes["n_layers"]
    for norm_name in axes["norm_types"]
    for power in axes["distance_powers"]
    for t_val in t_euclid_values
]

len(config_list)


33