In [100]:
import numpy as np
import scanpy as sc
import yaml
from pathlib import Path
from scripts.EGGFM.eggfm import run_eggfm_dimred
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score


def compute_ari(X, labels, k):
    km = KMeans(n_clusters=len(np.unique(labels)), n_init=10)
    km.fit(X[:, :k])
    return adjusted_rand_score(labels, km.labels_)


def main():
    params = {
    "seed": 7,
    "hvg_n_top_genes": 2000,

    "spec": {
        "n_pcs": 20,
        "dcol_max_cells": 3000,
        # "ari_label_key": "Cell type annotation",  # <-- this must match an obs column name
        "ari_label_key": "paul15_clusters",         # <-- this must match an obs column name
        "ari_n_dims": 10,                           # how many dims to use for ARI per embedding
        "ad_file": "data/paul15/paul15.h5ad",
        # "ad_file": "data/prep/qc.h5ad",
    },

    "qc": {
        "min_cells": 500,
        "min_genes": 200,
        "max_pct_mt": 15,
    },

    "eggfm_model": {
        "hidden_dims": [512, 512, 512, 512],
    },

    "eggfm_train": {
        "batch_size": 2048,
        "num_epochs": 50,
        "lr": 1.0e-4,
        "sigma": 0.1,
        "device": "cuda",
        "latent_space": "hvg",
    },

    "eggfm_diffmap": {
        "geometry_source": "pca",          # "pca" or "hvg"
        "energy_source": "hvg",            # where SCM/Hessian read energies
        "metric_mode": "scm",    # "euclidean", "scm", or "hessian_mixed"
        "n_neighbors": 30,
        "n_comps": 30,
        "device": "cuda",
        "hvp_batch_size": 1024,
        "eps_mode": "median",
        "eps_value": 1.0,
        "eps_trunc": "no",
        "distance_power": 0.0,
        "t": 2.0,
        "norm_type": "l0",

        # SCM hyperparams
        "metric_gamma": 0.2,
        "metric_lambda": 4.0,
        "energy_clip_abs": 3.0,
        "energy_batch_size": 2048,

        # Hessian mixing hyperparams
        "hessian_mix_mode": "none",   # "additive" | "multiplicative" | "none"
        "hessian_mix_alpha": 0.3,
        "hessian_beta": 0.2,
        "hessian_clip_std": 2.0,
        "hessian_use_neg": True,
    },
}


    spec = params["spec"]
    k = spec.get("ari_n_dims", spec.get("n_pcs", 10))

    base = sc.read_h5ad(spec.get("ad_file"))
    labels = base.obs[spec["ari_label_key"]].to_numpy()

    scores_eggfm = []
    scores_eggfm_2 = []
    scores_eggfm_3 = []
    scores_eggfm_4 = []
    scores_eggfm_5 = []
    # scores_eggfm_6 = []
    scores_pca = []
    scores_pca_2 = []
    total = 5
    qc = base.copy()
    for run in range(total):
        print(f"=== Run {run+1}/{total} ===")
        qc, _ = run_eggfm_dimred(qc, params)        

        # PCA → Diffmap
        # sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_pca")
        # sc.tl.diffmap(qc, n_comps=k)
        # X_diff_pca = qc.obsm["X_diffmap"][:, :k]
        # qc.obsm["X_diff_pca"] = X_diff_pca

        # # PCA → Diffmap → Diffmap
        # sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_diff_pca")
        # sc.tl.diffmap(qc, n_comps=k)
        # X_diff_pca_double = qc.obsm["X_diffmap"][:, :k]
        # qc.obsm["X_diff_pca_x2"] = X_diff_pca_double

        # EGGFM
        X_eggfm = qc.obsm["X_eggfm"][:, :k]

        # EGGFM DM
        sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_eggfm")
        sc.tl.diffmap(qc, n_comps=k)
        X_diff_eggfm = qc.obsm["X_diffmap"][:, :k]
        qc.obsm["X_diff_eggfm"] = X_diff_eggfm

        # EGGFM DM DM
        sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_diff_eggfm")
        sc.tl.diffmap(qc, n_comps=k)
        X_diff_eggfm_x2 = qc.obsm["X_diffmap"][:, :k]
        qc.obsm["X_diff_eggfm_x2"] = X_diff_eggfm_x2

        # EGGFM DM DM DM
        sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_diff_eggfm_x2")
        sc.tl.diffmap(qc, n_comps=k)
        X_diff_eggm_x3 = qc.obsm["X_diffmap"][:, :k]
        qc.obsm["X_diff_eggm_x3"] = X_diff_eggm_x3

        # EGGFM DM DM DM DM
        sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_diff_eggm_x3")
        sc.tl.diffmap(qc, n_comps=k)
        X_diff_eggm_x4 = qc.obsm["X_diffmap"][:, :k]
        qc.obsm["X_diff_eggm_x4"] = X_diff_eggm_x4

        # scores_pca.append(compute_ari(X_diff_pca, labels, k))
        # scores_pca_2.append(compute_ari(X_diff_pca_double, labels, k))
        scores_eggfm.append(compute_ari(X_eggfm, labels, k))
        scores_eggfm_2.append(compute_ari(X_diff_eggfm, labels, k))
        scores_eggfm_3.append(compute_ari(X_diff_eggfm_x2, labels, k))
        scores_eggfm_4.append(compute_ari(X_diff_eggm_x3, labels, k))
        scores_eggfm_5.append(compute_ari(X_diff_eggm_x4, labels, k))

    print("\n=== Variance results ===")
    # print(f"PCA→DM:    mean={np.mean(scores_pca):.4f}, std={np.std(scores_pca):.4f}")
    # print(
    #     f"PCA→DM2:   mean={np.mean(scores_pca_2):.4f}, std={np.std(scores_pca_2):.4f}"
    # )
    print(
        f"EGGFM:     mean={np.mean(scores_eggfm):.4f}, std={np.std(scores_eggfm):.4f}"
    )
    print(
        f"EGGFM DM:  mean={np.mean(scores_eggfm_2):.4f}, std={np.std(scores_eggfm_2):.4f}"
    )
    print(
        f"EGGFM DM2: mean={np.mean(scores_eggfm_3):.4f}, std={np.std(scores_eggfm_3):.4f}"
    )
    print(
        f"EGGFM DM3: mean={np.mean(scores_eggfm_4):.4f}, std={np.std(scores_eggfm_4):.4f}"
    )
    print(
        f"EGGFM DM4: mean={np.mean(scores_eggfm_5):.4f}, std={np.std(scores_eggfm_5):.4f}")
    
main()

=== Run 1/5 ===
[Energy DSM] Epoch 1/50  loss=150402.5788
[Energy DSM] Epoch 2/50  loss=150336.5392
[Energy DSM] Epoch 3/50  loss=150405.1575
[Energy DSM] Epoch 4/50  loss=150401.6410
[Energy DSM] Epoch 5/50  loss=150445.5267
[Energy DSM] Epoch 6/50  loss=150250.5495
[Energy DSM] Epoch 7/50  loss=150316.6359
[Energy DSM] Epoch 8/50  loss=150244.8410
[Energy DSM] Epoch 9/50  loss=150330.4908
[Energy DSM] Epoch 10/50  loss=150395.9912
[Energy DSM] Epoch 11/50  loss=150388.7473
[Energy DSM] Epoch 12/50  loss=150430.5934
[Energy DSM] Epoch 13/50  loss=150297.4007
[Energy DSM] Epoch 14/50  loss=150235.3114
[Energy DSM] Epoch 15/50  loss=150119.5253
[Energy DSM] Epoch 16/50  loss=150175.1326
[Energy DSM] Epoch 17/50  loss=150188.6476
[Energy DSM] Epoch 18/50  loss=150069.0403
[Energy DSM] Epoch 19/50  loss=150251.4403
[Energy DSM] Epoch 20/50  loss=150472.1582
[Energy DSM] Epoch 21/50  loss=150214.9158
[Energy DSM] Epoch 22/50  loss=150368.3868
[Energy DSM] Epoch 23/50  loss=150360.4161
[Ene

ValueError: Unknown norm: l0