In [None]:
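"""
Toy connectome-like demo using graphph.h0h1_connectome

Pipeline (numbers match the section headers in the code):
  0) Simulate 3 classes × 3 subjects of connectome-like distance matrices and
     save them to toy_connectome_demo/raw_D with index.csv + meta.json
  1) Build H0/H1 features via build_features_for_sharded
  2) Take one subject, compute VR persistence + barcode plot
  3) Run edgewise MLE (run_and_save_all_mle)
  4) Run hierarchical NUTS (run_and_save_all_hierarchical)
  5) Confusion matrix from the HMC posterior-mean Λ (save_confusion_from_feats)
  Finally: latent post-processing (latent-coordinate panels, per-ROI violin
  plots, FDR-based ROI selection).
"""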

from pathlib import Path
import json
import csv

import numpy as np
import matplotlib.pyplot as plt

import graphph.h0h1_connectome as gpc

# ---------------------------------------------------------------------
# 0. Simulate toy "connectome-like" distance shards on disk
# ---------------------------------------------------------------------
# All outputs of this demo live under ./toy_connectome_demo.
ROOT = Path("toy_connectome_demo").resolve()
RAW_ROOT = ROOT / "raw_D"          # per-subject distance shards + index.csv + meta.json
FEAT_ROOT = ROOT / "feat_h0h1"     # H0/H1 feature files
RUN_MLE_ROOT = ROOT / "runs_mle"   # edgewise MLE outputs
RUN_HMC_ROOT = ROOT / "runs_hmc"   # hierarchical NUTS outputs

for p in [RAW_ROOT, FEAT_ROOT, RUN_MLE_ROOT, RUN_HMC_ROOT]:
    p.mkdir(parents=True, exist_ok=True)

# Seed the simulation RNG so the toy data (and hence every downstream
# result printed by this demo) is reproducible. Set to None for fresh noise.
SIM_SEED = 2025
rng = np.random.default_rng(SIM_SEED)

n_roi = 10             # number of ROIs (graph nodes)
n_classes = 3
n_subj_per_class = 3

class_ids = [f"Group{i}" for i in range(1, n_classes + 1)]

# Lay the ROIs out evenly on the unit circle in 2D; coords has shape (n_roi, 2).
theta = np.linspace(0, 2 * np.pi, n_roi, endpoint=False)
coords = np.vstack([np.cos(theta), np.sin(theta)]).T


def base_connectivity(coords, length_scale=0.7):
    """Return a symmetric connectivity matrix whose weights decay with distance.

    For every unordered ROI pair (i, j) the weight is
    ``exp(-||coords[i] - coords[j]|| / length_scale)``, so closer ROIs are
    more strongly connected. The diagonal is left at zero.

    Parameters
    ----------
    coords : ndarray
        ROI coordinates, one row per ROI.
    length_scale : float
        Decay length of the exponential kernel.

    Returns
    -------
    ndarray, shape (n, n)
        Float matrix with a zero diagonal, where n = coords.shape[0].
    """
    n_nodes = coords.shape[0]
    weights = np.zeros((n_nodes, n_nodes), dtype=float)
    # Visit the strict upper triangle row by row and mirror each weight.
    upper_rows, upper_cols = np.triu_indices(n_nodes, k=1)
    for a, b in zip(upper_rows, upper_cols):
        dist = np.linalg.norm(coords[a] - coords[b])
        weights[a, b] = weights[b, a] = np.exp(-dist / length_scale)
    return weights


# -----------------------------
# Class-level connectivity templates
# -----------------------------
W_base = base_connectivity(coords, length_scale=0.7)

# Three rough "modules" of neighbouring ROIs on the circle. With n_roi = 10
# they partition the ROIs: A = {0, 1, 2}, B = {3, 4, 5}, C = {6, 7, 8, 9}.
clusterA = np.array([0, 1, 2])
clusterB = np.array([3, 4, 5])
clusterC = np.array([6, 7, 8, 9])

W_templates = {}

# Group1: stronger within-module edges in A and B (modular pattern).
W1 = W_base.copy()
W1[np.ix_(clusterA, clusterA)] *= 3.5
W1[np.ix_(clusterB, clusterB)] *= 3.5
W_templates["Group1"] = W1

# Group2: weaker A–A and B–B edges, but stronger A–B cross-talk.
W2 = W_base.copy()
W2[np.ix_(clusterA, clusterA)] *= 0.3
W2[np.ix_(clusterB, clusterB)] *= 0.3
W2[np.ix_(clusterA, clusterB)] *= 3.0
W2[np.ix_(clusterB, clusterA)] *= 3.0
W_templates["Group2"] = W2

# Group3: strong C module plus stronger long-range edges.
W3 = W_base.copy()
W3[np.ix_(clusterC, clusterC)] *= 3.0
# "Far" edges: ROI pairs whose Euclidean distance exceeds 1.6 (on the unit
# circle the largest possible distance is the diameter, 2.0).
dist_mat = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
far_mask = dist_mat > 1.6
W3[far_mask] *= 2.0
W_templates["Group3"] = W3

# -----------------------------
# Convert templates to distance and add subject-level noise
# -----------------------------
subj_noise_sd = 0.01  # within-class variability (sd of the multiplicative noise)
CAP_MARGIN = 0.5      # cap = eps_max + CAP_MARGIN, per subject and globally

index_rows = []
global_eps_max = 0.0

for cid in class_ids:
    class_dir = RAW_ROOT / cid
    class_dir.mkdir(parents=True, exist_ok=True)

    W_template = W_templates[cid]

    # Laplacian-Eigenmaps distance from the class template weights.
    # min_weight=None disables the fiber-count edge threshold used on real
    # data: the toy weights are not fiber counts.
    D_template, le_info = gpc.weights_to_le_distance(
        W_template,
        k=5,              # same as the real pipeline
        min_weight=None,
        sym="avg",
        transform="log1p",
        rescale=True,
    )
    # NOTE(review): D_template is expected to be symmetric with a zero
    # diagonal; confirm against weights_to_le_distance.

    for k_subj in range(n_subj_per_class):
        # Multiplicative noise around the class template. Symmetrising it and
        # zeroing its diagonal keeps D symmetric (given a symmetric template)
        # and leaves the template's diagonal unchanged.
        noise = rng.normal(loc=0.0, scale=subj_noise_sd, size=(n_roi, n_roi))
        noise = 0.5 * (noise + noise.T)
        np.fill_diagonal(noise, 0.0)

        # Distances must stay non-negative.
        D = np.clip(D_template * (1.0 + noise), 0.0, None)

        # Largest finite distance (eps_max) and the cap stored with the shard.
        finite = np.isfinite(D)
        subj_eps_max = float(np.max(D[finite]))
        subj_cap = subj_eps_max + CAP_MARGIN
        global_eps_max = max(global_eps_max, subj_eps_max)

        rel_path = f"{cid}/subj_{k_subj:03d}.npz"
        shard_path = RAW_ROOT / rel_path

        np.savez_compressed(
            shard_path,
            D=D.astype(np.float32),
            eps_max=np.array(subj_eps_max, dtype=np.float32),
            cap=np.array(subj_cap, dtype=np.float32),
        )

        index_rows.append(
            {
                "cid": cid,
                "file": rel_path,
                "subject": f"{cid}_subj_{k_subj:03d}",
            }
        )

# Write the index.csv expected by build_features_for_sharded
# (one row per shard: class id, path relative to RAW_ROOT, subject id).
index_path = RAW_ROOT / "index.csv"
with index_path.open("w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["cid", "file", "subject"])
    writer.writeheader()
    writer.writerows(index_rows)

# Optional: meta.json with global VR distance info across all subjects.
meta = {
    "distance_info": {
        "vr": {
            "eps_max": float(global_eps_max),
            "cap": float(global_eps_max + CAP_MARGIN),
        }
    },
    "n_roi": n_roi,
    "class_ids": class_ids,
}
(RAW_ROOT / "meta.json").write_text(json.dumps(meta, indent=2))

print(f"[ok] wrote raw shards + index.csv to {RAW_ROOT}")


# ---------------------------------------------------------------------
# 1. Build H0/H1 features from raw shards
# ---------------------------------------------------------------------
# Collect the builder options in one place so the run configuration is
# easy to inspect or tweak before calling it.
feature_build_kwargs = dict(
    base_dir=str(RAW_ROOT),
    out_dir=str(FEAT_ROOT),
    divide_by_two=True,   # VR radius-time convention, as in the real pipeline
    prefer_filled=False,  # keep +inf semantics if any are present
    prefer_vr=False,      # the shards only store D (no D_vr)
    overwrite=True,
)
gpc.build_features_for_sharded(**feature_build_kwargs)

print(f"[ok] feature files written to {FEAT_ROOT}")


# ---------------------------------------------------------------------
# 2. Pick one subject and visualize VR persistence (barcode)
# ---------------------------------------------------------------------
# Use the first (sorted) Group1 shard written in step 0.
group1_shards = sorted((RAW_ROOT / "Group1").glob("subj_*.npz"))
example_npz = group1_shards[0]
with np.load(example_npz) as shard:
    D_example = shard["D"]

# VR persistence of this subject's distance matrix, up to dimension 2.
st, persistence, max_edge = gpc.vr_persistence_from_distance(D_example, max_dim=2)
print(f"[info] example VR: max_edge = {max_edge:.4f}")

# A single stacked barcode figure covering dimensions 0, 1 and 2.
gpc.plot_all_dims_barcode(persistence, use_latex=False)
plt.suptitle("Toy connectome: VR barcodes for one subject (Group1)")
plt.show()


# ---------------------------------------------------------------------
# 3. MLE run (edgewise, no prior/Jacobian) using run_and_save_all_mle
# ---------------------------------------------------------------------
MLE_m = 5  # same m as in the real FitConfig

mle_results = gpc.run_and_save_all_mle(
    feat_dir=str(FEAT_ROOT),
    out_root=str(RUN_MLE_ROOT),
    m=MLE_m,
    use_h1=True,
    init="zeros",      # alternative: "nuts_mean" for warm starts
    maxiter=1000,      # reduced iteration budget for the toy demo
    gtol=1e-6,
)
# The function returns (class Λ's, confusion matrix, accuracy).
class_Lambda_mle, C_mle, acc_mle = mle_results

print("\n[ok] finished MLE on toy data")
print("  class_Lambda_mle type:", type(class_Lambda_mle))
print("  C_mle type:", type(C_mle))
print("  acc_mle:", acc_mle)


# ---------------------------------------------------------------------
# 4. Hierarchical NUTS run using run_and_save_all_hierarchical
# ---------------------------------------------------------------------
# The call below relies on the function's own defaults for initialisation.
# If your version of run_and_save_all_hierarchical requires warm starts
# (lambda0_by_class / bar_lambda0_init), uncomment the relevant lines and
# derive them from the MLE estimates in class_Lambda_mle.

fitcfg = gpc.FitConfig(
    m=MLE_m,
    kappa=6.0,
    kappa0=1.0,
    alpha=0.2,
    num_warmup=500,
    num_samples=2000,   # fewer than the real run
    num_chains=1,
    target_accept=0.8,
    dense_mass=False,
    seed=2025,
)

# OPTIONAL warm starts from MLE (uncomment if your function expects them).
# NOTE(review): np.mean(..., axis=0) assumes class_Lambda_mle is array-like
# stacked over classes; if it is a dict keyed by class id, stack its values.
# mle_Lambdas = class_Lambda_mle
# bar_lambda0 = np.mean(mle_Lambdas, axis=0)

hmc_out = gpc.run_and_save_all_hierarchical(
    feat_dir=str(FEAT_ROOT),
    out_root=str(RUN_HMC_ROOT),
    fitcfg=fitcfg,
    class_ids=None,        # use all classes found in index.csv
    use_h1=True,
    # If your version *requires* these, uncomment and adapt:
    # lambda0_by_class=mle_Lambdas,
    # bar_lambda0_init=bar_lambda0,
    # phi_diag_indices_bar=[(0, 0), (10, 0)],
    # phi_diag_indices_each=[(0, 0), (10, 0)],
    max_lag=50,
    thin_every_for_diag=1,
    save_per_class_phi=True,
)

print("\n[ok] finished hierarchical HMC on toy data")
print("  hmc_out type:", type(hmc_out))
# Only mapping-like results have keys; check for that explicitly instead of
# swallowing every exception, so genuine errors are not hidden.
if callable(getattr(hmc_out, "keys", None)):
    print("  hmc_out keys:", list(hmc_out.keys()))


In [None]:
# ---------------------------------------------------------------------
# 5. Confusion matrix from HMC posterior mean Λ
# ---------------------------------------------------------------------
from pathlib import Path
import numpy as np

# Build a class id → posterior-mean Λ mapping, first from the object returned
# by run_and_save_all_hierarchical, otherwise from per-class NPZ files on disk.
class_Lambda_hmc = {}

if isinstance(hmc_out, dict):
    if "class_Lambda_bar" in hmc_out:
        # e.g. {"Group1": Λ_bar1, "Group2": Λ_bar2, ...}
        class_Lambda_hmc = hmc_out["class_Lambda_bar"]
    elif "class_Lambda" in hmc_out:
        class_Lambda_hmc = hmc_out["class_Lambda"]
    elif (
        hmc_out
        and set(hmc_out) <= set(class_ids)
        and all(isinstance(v, np.ndarray) for v in hmc_out.values())
    ):
        # The function itself returned {cid: Λ_bar}. Requiring the keys to be
        # class ids avoids mistaking e.g. a dict of array-valued diagnostics
        # for a class → Λ mapping.
        class_Lambda_hmc = hmc_out


def _load_lambda_npz(npz_path, key_order):
    """Return the first array among ``key_order`` stored in ``npz_path``.

    Raises KeyError if the file contains none of the keys.
    """
    with np.load(npz_path) as z:
        for key in key_order:
            if key in z.files:
                # Indexing an NpzFile reads the member into a standalone array,
                # so it stays valid after the file is closed.
                return z[key]
    raise KeyError(f"No Lambda_* in {npz_path}")


# Fall back to per-class NPZ files on disk. Each candidate file is paired with
# the dataset names to try, in order of preference.
# NOTE(review): adjust file names / keys to whatever
# run_and_save_all_hierarchical actually writes.
if not class_Lambda_hmc:
    class_Lambda_hmc = {}
    for cid in class_ids:
        hmc_class_dir = RUN_HMC_ROOT / cid
        candidates = [
            (hmc_class_dir / "Lambda_bar.npz", ("Lambda_bar", "Lambda_hat")),
            (hmc_class_dir / "Lambda_hat.npz", ("Lambda_hat", "Lambda_bar")),
        ]
        for npz_path, key_order in candidates:
            if npz_path.exists():
                class_Lambda_hmc[cid] = _load_lambda_npz(npz_path, key_order)
                break
        else:
            raise FileNotFoundError(
                f"Could not find posterior Λ file for class {cid} "
                f"in {RUN_HMC_ROOT/cid}"
            )

print("\n[info] HMC posterior mean Λ loaded for classes:")
for cid, Lam in class_Lambda_hmc.items():
    print(f"  {cid}: Λ shape = {Lam.shape}")

# A class without a Λ would silently drop out of the confusion matrix below.
missing_classes = [c for c in class_ids if c not in class_Lambda_hmc]
if missing_classes:
    print(f"[warn] no posterior Λ for classes: {missing_classes}")

# Score the toy features with the HMC posterior-mean Λ's (rather than the
# MLE ones) and save the resulting confusion matrix as PNG + CSV.
confusion_result = gpc.save_confusion_from_feats(
    feat_dir=str(FEAT_ROOT),
    class_Lambda=class_Lambda_hmc,
    out_png=RUN_HMC_ROOT / "confusion_hmc.png",
    out_csv=RUN_HMC_ROOT / "confusion_hmc.csv",
)
C_hmc, cls_hmc, acc_hmc = confusion_result

print("\n[ok] HMC-based confusion")
print("classes order:", cls_hmc)
print("confusion matrix:\n", C_hmc)
print("HMC accuracy:", acc_hmc)


# Latent post-processing: latent-coordinate panels, per-ROI violin plots, FDR ROI selection

In [None]:
from pathlib import Path
from graphph.latent_postproc import (
    plot_latent_coords_panels,
    latent_violin_from_samples,
    fdr_select_rois_from_latentdiff,
    replot_latent_violin_from_npz,
)

RUN_DIR = Path(RUN_HMC_ROOT)

# Reference group: used to align latent coordinates and as the reference
# label for the per-ROI latent differences.
REF_GROUP = "Group2"
# Latent rank; it also appears in the file names written and read below.
LATENT_RANK = 5

DIFFS_DIR = "tabs/fdr_latent_diffs"
VIOLIN_DIR = "figs/roi_violin"
LATENT_COORDS_DIR = "figs/latent_coords"

# Create every output directory up front. latent_violin_from_samples writes
# its per-draw diffs into DIFFS_DIR, so that directory must exist before the
# first loop below runs.
for out_dir in (DIFFS_DIR, VIOLIN_DIR, LATENT_COORDS_DIR):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

# All unordered class pairs (A, B) with A before B in class_ids,
# e.g. (Group1, Group2), (Group1, Group3), (Group2, Group3).
pairs = [
    (a, b)
    for i, a in enumerate(class_ids)
    for b in class_ids[i + 1:]
]

# Latent coordinate panels for all groups, aligned to the reference group.
plot_latent_coords_panels(
    RUN_DIR,
    groups=list(class_ids),
    align_to=REF_GROUP,
    out_dir=LATENT_COORDS_DIR,
    x_stretch=1.01,
)

# Save per-draw latent diffs for each pair (consumed by the FDR step below).
for A, B in pairs:
    latent_violin_from_samples(
        RUN_DIR,
        pair=(A, B),
        rank=LATENT_RANK,
        ref_label=REF_GROUP,
        # NOTE(review): this toy demo never writes data/node_labels.csv;
        # confirm latent_violin_from_samples tolerates a missing file, or
        # generate toy ROI labels first.
        node_csv="data/node_labels.csv",
        top_k=None,
        sort=False,
        color="C0",
        label_map={cid: cid for cid in class_ids},
        save_draw_diffs_dir=DIFFS_DIR,
    )

for A, B in pairs:
    print(f"\n=== FDR + violin for {A} vs {B} ===")

    # .npz with per-draw latent diffs, written by latent_violin_from_samples.
    # NOTE(review): assumes its naming scheme
    # latentdiff_draws_rank{rank}_{A}_vs_{B}.npz — confirm if rank changes.
    npz_path = f"{DIFFS_DIR}/latentdiff_draws_rank{LATENT_RANK}_{A}_vs_{B}.npz"

    # Pair-specific FDR CSV.
    fdr_csv = f"{DIFFS_DIR}/fdr_selected_rois_rank{LATENT_RANK}_{A}_vs_{B}.csv"

    # 1) FDR-based ROI selection for this pair.
    fdr_out = fdr_select_rois_from_latentdiff(
        npz_path,
        beta=0.9,
        out_csv=fdr_csv,
    )

    # 2) Manuscript-style violin replot annotated with this pair's FDR picks.
    png_path = f"{VIOLIN_DIR}/violin_rank{LATENT_RANK}_{A}_vs_{B}_manuscript.png"

    replot_latent_violin_from_npz(
        npz_path=npz_path,
        top_k=None,              # all ROIs
        sort=False,
        fig_width=6.0,
        fig_height=3.4,
        fdr_csv=fdr_csv,         # pair-specific FDR file
        annotate_fdr_points=True,
        png_path=png_path,
    )
