In [1]:
import yaml
from pathlib import Path

import scanpy as sc

from scripts.EGGFM.train_energy import train_energy_model
from scripts.EGGFM.engine import EGGFMDiffusionEngine
from scripts.EGGFM.data_sources import AnnDataViewProvider
from scripts.EGGFM.utils import subsample_adata

# ---- quick-test settings: edit these before running the cells below ----
PARAMS_PATH = "configs/params.yml"   # YAML file that holds every config section read below
MAX_CELLS = 2000                     # forwarded to subsample_adata (original note: None = use all cells)
SUBSAMPLE_SEED = 0                   # forwarded to subsample_adata as its seed

# Parse the YAML once, then pull out the sections this notebook uses.
params = yaml.safe_load(Path(PARAMS_PATH).read_text())
spec = params["spec"]                # required: a missing "spec" section raises KeyError
diff_cfg, model_cfg, train_cfg = (
    params.get(section, {})          # optional sections fall back to an empty dict
    for section in ("eggfm_diffmap", "eggfm_model", "eggfm_train")
)

# Show the diffusion-map settings as the cell output.
diff_cfg


{'geometry_source': 'pca',
 'energy_source': 'hvg',
 'n_neighbors': 30,
 'n_comps': 30,
 'device': 'cuda',
 'hvp_batch_size': 1024,
 'eps_mode': 'median',
 'eps_value': 1.0,
 'eps_trunc': 'no',
 'distance_power': 1.0,
 't': 2.0,
 'norm_type': 'l2',
 'metric_mode': 'hessian_mixed',
 'metric_gamma': 0.2,
 'metric_lambda': 4.0,
 'energy_clip_abs': 3.0,
 'energy_batch_size': 2048,
 'hessian_mix_mode': 'none',
 'hessian_mix_alpha': 0.3,
 'hessian_beta': 0.2,
 'hessian_clip_std': 2.0,
 'hessian_use_neg': True}

In [2]:
# Read the AnnData file named by spec["ad_file"].
print("[notebook] loading paul15...", flush=True)
ad_path = spec.get("ad_file")
qc_ad = sc.read_h5ad(ad_path)

# (A prep_for_manifolds step used to sit here; it is disabled in this notebook.)

# Apply the quick-test cell cap and seed chosen in the config cell.
print("[notebook] subsampling (if requested)...", flush=True)
qc_ad = subsample_adata(qc_ad, max_cells=MAX_CELLS, seed=SUBSAMPLE_SEED)

# Display the resulting AnnData summary.
qc_ad


[notebook] loading paul15...
[notebook] subsampling (if requested)...
[subsample_adata] Subsampling 2000 / 2730 cells (seed=0)


AnnData object with n_obs × n_vars = 2000 × 2000
    obs: 'paul15_clusters', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'n_genes'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'diffmap_evals', 'eggfm_meta', 'hvg', 'iroot', 'log1p', 'neighbors', 'pca'
    obsm: 'X_diff_eggfm', 'X_diff_pca', 'X_diff_pca_x2', 'X_diffmap', 'X_eggfm', 'X_pca'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [3]:
# Train the energy model on qc_ad using the model and training config sections,
# then display the returned model as the cell output.
# NOTE(review): the recorded output of this cell prints loss=0.0000 for every
# epoch shown — confirm the model is actually learning before trusting any
# metric computed from it.
print("[notebook] training energy model...", flush=True)
energy_model = train_energy_model(qc_ad, model_cfg, train_cfg)
energy_model


[notebook] training energy model...
[Energy DSM] Epoch 1/50  loss=0.0000
[Energy DSM] Epoch 2/50  loss=0.0000
[Energy DSM] Epoch 3/50  loss=0.0000
[Energy DSM] Epoch 4/50  loss=0.0000
[Energy DSM] Epoch 5/50  loss=0.0000
[Energy DSM] Epoch 6/50  loss=0.0000
[Energy DSM] Epoch 7/50  loss=0.0000
[Energy DSM] Epoch 8/50  loss=0.0000
[Energy DSM] Epoch 9/50  loss=0.0000
[Energy DSM] Epoch 10/50  loss=0.0000
[Energy DSM] Epoch 11/50  loss=0.0000
[Energy DSM] Epoch 12/50  loss=0.0000
[Energy DSM] Epoch 13/50  loss=0.0000
[Energy DSM] Epoch 14/50  loss=0.0000
[Energy DSM] Epoch 15/50  loss=0.0000
[Energy DSM] Epoch 16/50  loss=0.0000
[Energy DSM] Epoch 17/50  loss=0.0000
[Energy DSM] Epoch 18/50  loss=0.0000
[Energy DSM] Epoch 19/50  loss=0.0000
[Energy DSM] Epoch 20/50  loss=0.0000
[Energy DSM] Epoch 21/50  loss=0.0000
[Energy DSM] Epoch 22/50  loss=0.0000
[Energy DSM] Epoch 23/50  loss=0.0000
[Energy DSM] Epoch 24/50  loss=0.0000
[Energy DSM] Epoch 25/50  loss=0.0000
[Energy DSM] Epoch 26/5

EnergyMLP(
  (net): Sequential(
    (0): Linear(in_features=2000, out_features=512, bias=True)
    (1): Softplus(beta=1.0, threshold=20.0)
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): Softplus(beta=1.0, threshold=20.0)
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): Softplus(beta=1.0, threshold=20.0)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): Softplus(beta=1.0, threshold=20.0)
    (8): Linear(in_features=512, out_features=2000, bias=True)
  )
)

In [4]:
# Choose which representations feed the geometry and the energy model,
# falling back to "pca" / "hvg" when the config does not name them.
geometry_source = diff_cfg.get("geometry_source", "pca")
energy_source = diff_cfg.get("energy_source", "hvg")
view_provider = AnnDataViewProvider(
    geometry_source=geometry_source,
    energy_source=energy_source,
)

# One engine is shared by every metric mode compared below.
engine = EGGFMDiffusionEngine(
    energy_model=energy_model,
    diff_cfg=diff_cfg,
    view_provider=view_provider,
)

# Metric modes to compare; each embedding is stored under .obsm["X_eggfm_<mode>"].
metric_modes = ["euclidean", "scm", "hessian_mixed"]

for metric_name in metric_modes:
    obsm_key = f"X_eggfm_{metric_name}"
    print(f"[notebook] Building embedding for metric_mode='{metric_name}'", flush=True)
    embedding = engine.build_embedding(qc_ad, metric_mode=metric_name)
    qc_ad.obsm[obsm_key] = embedding
    print(
        f"[notebook] Stored embedding in .obsm['{obsm_key}'] with shape {embedding.shape}",
        flush=True,
    )

# Display the AnnData summary, now including the new .obsm entries.
qc_ad


[notebook] Building embedding for metric_mode='euclidean'
[AnnDataViewProvider] using PCA for geometry with shape (2000, 50)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[DiffusionMap] using eps = 16.66 (power p=1.0)
[DiffusionMap] computing eigenvectors...
[DiffusionMap] finished. Embedding shape: (2000, 30)
[notebook] Stored embedding in .obsm['X_eggfm_euclidean'] with shape (2000, 30)
[notebook] Building embedding for metric_mode='scm'
[AnnDataViewProvider] using PCA for geometry with shape (2000, 50)
[EGGFM Engine] building kNN graph (euclidean in geometry space)...
[EGGFM Engine] total edges (directed): 60000
[EGGFM SCM] energy_source hvg
[EGGFM SCM] computing energies E(x) for all cells...
[EGGFM SCM] energy stats: raw_min=-61.6316, raw_max=86.8760, norm_min=-4.3058, norm_max=6.0227, clip=±3.0
[EGGFM SCM] metric G stats: min=0.3991, max=80.5421, mean=10.5434
[DiffusionMap] using eps = 41.03 (power p=1.0)
[Diffusio

AnnData object with n_obs × n_vars = 2000 × 2000
    obs: 'paul15_clusters', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'n_genes'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'diffmap_evals', 'eggfm_meta', 'hvg', 'iroot', 'log1p', 'neighbors', 'pca'
    obsm: 'X_diff_eggfm', 'X_diff_pca', 'X_diff_pca_x2', 'X_diffmap', 'X_eggfm', 'X_pca', 'X_eggfm_euclidean', 'X_eggfm_scm', 'X_eggfm_hessian_mixed'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [5]:
from sklearn.metrics import adjusted_rand_score
import numpy as np

# Ground-truth cluster column used for every ARI comparison below.
label_key = "paul15_clusters"

if label_key in qc_ad.obs:
    labels = qc_ad.obs[label_key].to_numpy()
else:
    # Fail with an explicit message rather than a bare pandas lookup error.
    raise KeyError(f"Label column '{label_key}' not found in qc_ad.obs")

def kmeans_ari(X, labels, n_clusters=None, random_state=0):
    """Cluster `X` with k-means and score the result against `labels` by ARI.

    Parameters
    ----------
    X : array-like
        Feature matrix handed to ``KMeans.fit``.
    labels : array-like
        Reference labels for the adjusted Rand index.
    n_clusters : int or None
        Number of k-means clusters; when None, the number of distinct values
        in `labels` is used.
    random_state : int
        Seed forwarded to ``KMeans``.

    Returns
    -------
    The adjusted Rand index between `labels` and the k-means assignments.
    """
    from sklearn.cluster import KMeans

    n_groups = len(np.unique(labels)) if n_clusters is None else n_clusters
    clusterer = KMeans(n_clusters=n_groups, random_state=random_state, n_init=10)
    clusterer.fit(X)
    return adjusted_rand_score(labels, clusterer.labels_)

# Score each stored embedding against the reference labels.
for mode in metric_modes:
    ari = kmeans_ari(qc_ad.obsm[f"X_eggfm_{mode}"], labels)
    print(f"metric_mode='{mode}': ARI={ari:.3f}")


metric_mode='euclidean': ARI=0.279
metric_mode='scm': ARI=0.160
metric_mode='hessian_mixed': ARI=0.316


In [6]:
# Make sure the output folder exists, then save the annotated AnnData there.
# The file name records how many cells the object holds.
Path("data/paul15").mkdir(parents=True, exist_ok=True)
out_path = f"data/paul15/paul15_eggfm_test_{qc_ad.n_obs}cells.h5ad"
print(f"[notebook] writing result to {out_path}", flush=True)
qc_ad.write_h5ad(out_path)


[notebook] writing result to data/paul15/paul15_eggfm_test_2000cells.h5ad


In [50]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
import numpy as np

def compute_ari_fixed(X, labels, k, random_state=0):
    """ARI of a seeded k-means clustering of the first `k` columns of `X`.

    The number of clusters equals the number of distinct values in `labels`;
    `random_state` is forwarded to ``KMeans`` so repeated calls agree.
    """
    n_label_groups = len(np.unique(labels))
    clusterer = KMeans(
        n_clusters=n_label_groups,
        n_init=10,
        random_state=random_state,
    )
    assignments = clusterer.fit_predict(X[:, :k])
    return adjusted_rand_score(labels, assignments)


In [65]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
import yaml
from pathlib import Path
import scanpy as sc
import numpy as np

def compute_ari(X, labels, k):
    """ARI of a k-means clustering (seed fixed at 0) on the first `k` columns of `X`.

    One cluster is requested per distinct value in `labels`.
    """
    leading_dims = X[:, :k]
    n_label_groups = len(np.unique(labels))
    # random_state=0 keeps the clustering deterministic across calls.
    clusterer = KMeans(n_clusters=n_label_groups, n_init=10, random_state=0)
    predicted = clusterer.fit_predict(leading_dims)
    return adjusted_rand_score(labels, predicted)

# `run_eggfm_dimred` is imported here because no earlier cell imports it;
# without this the cell raises NameError on a fresh kernel (Restart & Run All).
from scripts.EGGFM.eggfm import run_eggfm_dimred

params = yaml.safe_load(Path("configs/params.yml").read_text())
spec = params["spec"]
# Number of leading embedding dimensions used for ARI:
# "ari_n_dims" if present, else "n_pcs", else 10.
k = spec.get("ari_n_dims", spec.get("n_pcs", 10))

# Start from base, then run the same prep + EGGFM you normally do
base = sc.read_h5ad(spec.get("ad_file"))

# Work on a copy so `base` stays available unmodified.
qc_ad = base.copy()
# The second return value is not used here. It is bound to a named throwaway
# rather than `_`, because assigning `_` at notebook top level shadows
# IPython's "last output" variable.
qc_ad, _eggfm_aux = run_eggfm_dimred(qc_ad, params)

# IMPORTANT: labels and embedding both come from the same qc_ad, so rows align.
labels = qc_ad.obs[spec["ari_label_key"]].to_numpy()
# compute_ari slices to the first k columns again; that second slice is a no-op.
X_eggfm = qc_ad.obsm["X_eggfm"][:, :k]

ari_eggfm = compute_ari(X_eggfm, labels, k)
print("EGGFM base ARI:", ari_eggfm)


[Energy DSM] Epoch 1/50  loss=150245.6147
[Energy DSM] Epoch 2/50  loss=150399.4842
[Energy DSM] Epoch 3/50  loss=150265.1546
[Energy DSM] Epoch 4/50  loss=150380.9172
[Energy DSM] Epoch 5/50  loss=150254.9333
[Energy DSM] Epoch 6/50  loss=150428.7766
[Energy DSM] Epoch 7/50  loss=150221.3392
[Energy DSM] Epoch 8/50  loss=150294.7751
[Energy DSM] Epoch 9/50  loss=150237.9370
[Energy DSM] Epoch 10/50  loss=150286.7927
[Energy DSM] Epoch 11/50  loss=150306.7429
[Energy DSM] Epoch 12/50  loss=150510.6755
[Energy DSM] Epoch 13/50  loss=150279.7128
[Energy DSM] Epoch 14/50  loss=150223.8476
[Energy DSM] Epoch 15/50  loss=150369.9927
[Energy DSM] Epoch 16/50  loss=150257.2894
[Energy DSM] Epoch 17/50  loss=150457.4125
[Energy DSM] Epoch 18/50  loss=150205.5853
[Energy DSM] Epoch 19/50  loss=150416.5275
[Energy DSM] Epoch 20/50  loss=150406.5407
[Energy DSM] Epoch 21/50  loss=150366.2652
[Energy DSM] Epoch 22/50  loss=150275.8447
[Energy DSM] Epoch 23/50  loss=150116.6652
[Energy DSM] Epoch 2

In [None]:
# Recipe: a Hessian-mixed base layer (t=2) followed by ONE Euclidean layer (t=1).
# NOTE(review): `run_admr_layers` is not defined or imported in the visible part
# of this notebook, and `engine`, `labels` and `k` come from earlier cells —
# confirm they are in scope (and refer to the same cells) before running.
layer_plan = [("hessian_mixed", 2.0), ("euclidean", 1.0)]   # (metric, diffusion time)
metric_sequence = [metric for metric, _t in layer_plan]
t_sequence = [t_val for _metric, t_val in layer_plan]

qc_ad_admr, layer_embs, log_admr = run_admr_layers(
    ad_prep=qc_ad,
    engine=engine,
    n_layers=len(metric_sequence),
    metric_sequence=metric_sequence,
    t_sequence=t_sequence,
    base_geometry_source="pca",
    store_prefix="X_admr_hess_eucl_t1",
    labels=labels,
    label_key="paul15_clusters",
    n_clusters=None,
    k_overlap=30,
)

# Layer 0 is the Hessian-mixed base; layer 1 is the Euclidean layer on top.
X_admr0, X_admr1 = layer_embs[0], layer_embs[1]

# Score both layers with the seeded k-means ARI helper.
ari_admr0, ari_admr1 = (
    compute_ari_fixed(layer_emb, labels, k) for layer_emb in (X_admr0, X_admr1)
)

print(f"ADMR layer 0 (Hessian base) ARI: {ari_admr0:.4f}")
print(f"ADMR layer 1 (Hess→Eucl t=1) ARI: {ari_admr1:.4f}")


In [10]:
# Sync the local out/admr_logs directory to the same path under the GCS bucket
# using `gsutil rsync`.
# NOTE(review): the recorded run of this cell failed with
# "arg (out/admr_logs) does not name a directory" — make sure out/admr_logs
# exists relative to the kernel's working directory before running.
!BUCKET="gs://medit-uml-prod-uscentral1-8e7a" && \
 gsutil -m rsync -r out/admr_logs "${BUCKET}/out/admr_logs"

CommandException: arg (out/admr_logs) does not name a directory, bucket, or bucket subdir.
If there is an object with the same path, please add a trailing
slash to specify the directory.


In [77]:
import numpy as np
import scanpy as sc
import yaml
from pathlib import Path

from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

import pandas as pd

from scripts.EGGFM.eggfm import run_eggfm_dimred  # adjust path if needed


In [78]:
def subset_anndata(ad: sc.AnnData, n_cells: int, random_state: int = 0) -> sc.AnnData:
    """Return a copy of `ad` restricted to a random selection of observations.

    At most `n_cells` rows are drawn without replacement; if `ad` has fewer
    rows than that, all of them are drawn. `random_state` seeds the NumPy
    generator, so the same seed selects the same rows in the same order.
    """
    total = ad.n_obs
    n_take = min(n_cells, total)
    generator = np.random.default_rng(random_state)
    chosen = generator.choice(np.arange(total), size=n_take, replace=False)
    subset = ad[chosen]
    return subset.copy()


def compute_ari_fixed(X, labels, k, random_state: int = 0) -> float:
    """Seeded k-means ARI on the first `k` columns of `X`.

    One cluster is requested per distinct value in `labels`, and
    `random_state` is passed to ``KMeans`` so the score is reproducible.
    """
    n_label_groups = len(np.unique(labels))
    clusterer = KMeans(n_clusters=n_label_groups, n_init=10, random_state=random_state)
    predicted = clusterer.fit_predict(X[:, :k])
    return adjusted_rand_score(labels, predicted)


In [79]:
import copy
def run_metric_ablation_subset(
    base_ad: sc.AnnData,
    base_params: dict,
    *,
    exp_name: str,
    metric_mode: str,
    hessian_mix_mode: str | None,
    t: float,
    n_runs: int = 5,
    max_dm_layers: int = 6,
    n_cells_sample: int = 1000,
    base_seed: int = 0,
) -> pd.DataFrame:
    """
    Run a check_variance-style EGGFM+DM ablation on random subsets of cells.

    For each of `n_runs` runs (seed = `base_seed + run`) the function:
      1. draws a subset of `base_ad` with `subset_anndata`,
      2. calls `run_eggfm_dimred` on it (expected to populate
         ``obsm["X_eggfm"]``) and scores that embedding by ARI,
      3. stacks `max_dm_layers` scanpy diffusion-map layers, each built on the
         previous layer's embedding, and scores each one by ARI.

    Parameters
    ----------
    base_ad : AnnData to subsample from; it is not modified.
    base_params : config dict; deep-copied, so the caller's dict is untouched.
        Must contain "eggfm_diffmap" and "spec" (with "ari_label_key").
    exp_name : label printed in the logs and stored in every record.
    metric_mode : written to ``eggfm_diffmap["metric_mode"]``.
    hessian_mix_mode : written to ``eggfm_diffmap["hessian_mix_mode"]`` only
        when `metric_mode` is "hessian_mixed"; must not be None in that case.
    t : written (as float) to ``eggfm_diffmap["t"]``.
    n_runs, max_dm_layers, n_cells_sample, base_seed : loop controls, see above.

    Returns
    -------
    DataFrame with one row per (run, layer). The columns are always present,
    even when no run was executed (e.g. ``n_runs=0``).

    Raises
    ------
    ValueError
        If `metric_mode` is "hessian_mixed" and `hessian_mix_mode` is None.
    """
    print(f"\n========== {exp_name} ==========")
    print(f"metric_mode={metric_mode}, hessian_mix_mode={hessian_mix_mode}, t={t}")

    # Private copy: the overrides below must not leak into the caller's config.
    params = copy.deepcopy(base_params)
    diff_cfg = params["eggfm_diffmap"]

    # set metric config
    diff_cfg["metric_mode"] = metric_mode
    diff_cfg["t"] = float(t)
    if metric_mode == "hessian_mixed":
        if hessian_mix_mode is None:
            raise ValueError("hessian_mix_mode must be set for hessian_mixed")
        diff_cfg["hessian_mix_mode"] = hessian_mix_mode

    spec = params["spec"]
    # Number of leading dims used for ARI: "ari_n_dims", else "n_pcs", else 10.
    k = spec.get("ari_n_dims", spec.get("n_pcs", 10))
    label_key = spec["ari_label_key"]

    # Fixed column order, so the result has the same schema even with zero rows.
    record_columns = [
        "exp_name",
        "run",
        "layer",
        "layer_label",
        "metric_mode",
        "hessian_mix_mode",
        "t",
        "n_cells",
        "ari",
    ]
    records: list[dict] = []

    def _add_record(run: int, layer: int, layer_label: str, n_cells: int, ari: float) -> None:
        """Append one (run, layer) result row with the experiment-wide fields filled in."""
        records.append(
            dict(
                exp_name=exp_name,
                run=run,
                layer=layer,
                layer_label=layer_label,
                metric_mode=metric_mode,
                hessian_mix_mode=hessian_mix_mode,
                t=t,
                n_cells=n_cells,
                ari=ari,
            )
        )

    for run in range(n_runs):
        run_seed = base_seed + run
        print(f"\n=== {exp_name} | Run {run+1}/{n_runs} (seed={run_seed}) ===", flush=True)

        # 1) subset
        ad_sub = subset_anndata(base_ad, n_cells_sample, random_state=run_seed)

        # 2) EGGFM: train DSM + build X_eggfm (second return value is unused)
        ad_sub, _ = run_eggfm_dimred(ad_sub, params)

        labels = ad_sub.obs[label_key].to_numpy()

        # layer 0 = EGGFM base; compute_ari_fixed itself keeps only the first k dims
        ari0 = compute_ari_fixed(ad_sub.obsm["X_eggfm"], labels, k, random_state=run_seed)
        print(f"  EGGFM base ARI: {ari0:.4f}")
        _add_record(run, 0, "EGGFM", ad_sub.n_obs, ari0)

        # 3) stack Euclidean DM layers: each layer's neighbours are computed on
        #    the previous layer's embedding (n_neighbors is fixed at 30 here).
        rep_key = "X_eggfm"
        for layer in range(1, max_dm_layers + 1):
            sc.pp.neighbors(ad_sub, n_neighbors=30, use_rep=rep_key)
            sc.tl.diffmap(ad_sub, n_comps=k)
            # Copy: "X_diffmap" is overwritten by the next layer's diffmap call.
            X_dm = ad_sub.obsm["X_diffmap"][:, :k].copy()

            ari = compute_ari_fixed(X_dm, labels, k, random_state=run_seed)
            print(f"  EGGFM DM{layer} ARI: {ari:.4f}")

            layer_label = f"EGGFM_DM{layer}"
            _add_record(run, layer, layer_label, ad_sub.n_obs, ari)

            # Store this layer so the next iteration can use it as `use_rep`.
            rep_key = layer_label
            ad_sub.obsm[rep_key] = X_dm

    df = pd.DataFrame.from_records(records, columns=record_columns)

    # quick summary
    print(f"\n--- Summary: {exp_name} ---")
    if df.empty:
        # No run was executed; the groupby below would have nothing to show.
        print("  (no runs executed; nothing to summarise)")
        return df

    # sort=False keeps first-appearance order, i.e. layer order (EGGFM, DM1, DM2, ...),
    # which a lexicographic sort would break once there are 10 or more layers.
    summary = df.groupby("layer_label", sort=False)["ari"].agg(["mean", "std"])
    display(summary)

    return df


In [None]:
# Inline configuration for the ablation cells below (defined here, not read from disk).
params = dict(
    seed=7,
    hvg_n_top_genes=2000,

    spec=dict(
        n_pcs=20,
        dcol_max_cells=3000,
        # Must match a column name in .obs (alternative used before: "Cell type annotation").
        ari_label_key="paul15_clusters",
        ari_n_dims=10,                      # how many dims to use for ARI per embedding
        # Input file (alternative used before: "data/prep/qc.h5ad").
        ad_file="data/paul15/paul15.h5ad",
    ),

    qc=dict(
        min_cells=500,
        min_genes=200,
        max_pct_mt=15,
    ),

    eggfm_model=dict(
        hidden_dims=[512, 512, 512, 512],
    ),

    eggfm_train=dict(
        batch_size=2048,
        num_epochs=50,
        lr=1.0e-4,
        sigma=0.1,
        device="cuda",
        latent_space="hvg",
    ),

    eggfm_diffmap=dict(
        geometry_source="pca",              # "pca" or "hvg"
        energy_source="hvg",                # where SCM/Hessian read energies
        metric_mode="scm",                  # "euclidean", "scm", or "hessian_mixed"
        n_neighbors=30,
        n_comps=30,
        device="cuda",
        hvp_batch_size=1024,
        eps_mode="median",
        eps_value=1.0,
        eps_trunc="no",
        distance_power=1.0,
        t=2.0,
        norm_type="linf",

        # SCM hyperparameters
        metric_gamma=0.2,
        metric_lambda=4.0,
        energy_clip_abs=3.0,
        energy_batch_size=2048,

        # Hessian mixing hyperparameters
        hessian_mix_mode="none",            # "additive" | "multiplicative" | "none"
        hessian_mix_alpha=0.3,
        hessian_beta=0.2,
        hessian_clip_std=2.0,
        hessian_use_neg=True,
    ),
)
# Read the base AnnData from the path named in the inline config's "spec" section.
spec = params["spec"]
base = sc.read_h5ad(spec["ad_file"])

# One entry per ablation; each dict is splatted into run_metric_ablation_subset.
# Columns: exp_name, metric_mode, hessian_mix_mode, t
metric_experiments = [
    {
        "exp_name": exp_name,
        "metric_mode": metric_mode,
        "hessian_mix_mode": hessian_mix_mode,
        "t": t_value,
    }
    for exp_name, metric_mode, hessian_mix_mode, t_value in (
        ("euclidean_t2", "euclidean", None, 2.0),
        ("hess_none_t2", "hessian_mixed", "none", 2.0),            # hypothesis: t=2 best for none
        ("hess_mult_t1", "hessian_mixed", "multiplicative", 1.0),  # hypothesis: t=1 best for multiplicative
        ("scm_t2", "scm", None, 2.0),
        ("scm_t1", "scm", None, 1.0),
    )
]


In [None]:
# Run every configured experiment with the same loop settings and collect
# one DataFrame per experiment.
all_results = [
    run_metric_ablation_subset(
        base_ad=base,
        base_params=params,
        n_runs=5,             # repeats per experiment
        max_dm_layers=6,      # diffusion-map layers stacked on the EGGFM base
        n_cells_sample=1000,  # cells drawn per run
        base_seed=7,          # run i uses seed 7 + i
        **cfg,
    )
    for cfg in metric_experiments
]

# Stack all per-experiment frames into one table with a fresh index.
results_df = pd.concat(all_results, ignore_index=True)


In [None]:
# Mean and standard deviation of ARI for every (experiment, layer) pair,
# flattened to plain columns and ordered by experiment then layer label.
ari_by_exp_layer = results_df.groupby(["exp_name", "layer_label"])["ari"]
summary_all = ari_by_exp_layer.agg(["mean", "std"]).reset_index()
summary_all = summary_all.sort_values(["exp_name", "layer_label"])
display(summary_all)
