# In [1]:
import os, time, random, warnings, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx

from typing import Dict, List, Tuple
from collections import defaultdict, Counter

from sklearn.metrics import (
    silhouette_score, adjusted_rand_score,
    confusion_matrix, silhouette_samples
)
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import rbf_kernel, pairwise_distances
from sklearn.datasets import make_blobs

from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
from scipy.sparse.linalg import spilu, LinearOperator, cg, minres, gmres, bicgstab
from scipy.sparse import csc_matrix
from scipy.linalg import solve_triangular
from scipy.io import mmread
from scipy.stats import wasserstein_distance
import matplotlib.patches as patches


# Suppress every FutureWarning and UserWarning process-wide (presumably to keep
# the console summaries printed by the experiments readable).
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# ========= Global, fair solve settings =========
# Default number of timed repeats per solve in solve_iterative (the median is reported).
REPEATS_PER_SOLVE = 7
# Default for solve_iterative: add preconditioner setup time to every timed repeat.
INCLUDE_SETUP = True
# NOTE(review): not referenced in the code visible here; solve_iterative's own
# default is the literal 'gmres'.
SOLVER_DEFAULT = "gmres"

# ========= Artifact config =========
# Output root; overridable via the ARTIFACT_ROOT environment variable (default: the CWD).
ARTIFACT_ROOT = os.path.abspath(os.getenv("ARTIFACT_ROOT", "."))
FIG_DIR   = os.path.join(ARTIFACT_ROOT, "figs")
TABLE_DIR = os.path.join(ARTIFACT_ROOT, "tables")
RESULTS_DIR = os.path.join(ARTIFACT_ROOT, "results")
# Image format and resolution used by _savefig_stem.
FIG_EXT = "png"
FIG_DPI = 300

def set_seed(seed=42):
    """Seed both Python's ``random`` module and NumPy's legacy global RNG.

    Args:
        seed: Any value accepted by ``int()``; it is coerced before seeding.
    """
    seed_int = int(seed)
    random.seed(seed_int)
    np.random.seed(seed_int)

def _ensure_dirs():
    """Create the figure, table and results output directories if they are missing."""
    for directory in (FIG_DIR, TABLE_DIR, RESULTS_DIR):
        os.makedirs(directory, exist_ok=True)

def _latex_table(df, path: str, caption: str = "", label: str = ""):
    """Write ``df`` as a LaTeX table and return the path that was written.

    Relative paths are resolved against the artifact tree: a bare file name
    goes into ``TABLE_DIR``, and a relative path with a directory component is
    joined onto ``ARTIFACT_ROOT``. Absolute paths are used unchanged. The
    parent directory is created if needed.

    Args:
        df: DataFrame to export (exported from a copy; the caller's object is untouched).
        path: Output file path.
        caption: LaTeX caption forwarded to ``DataFrame.to_latex``.
        label: LaTeX label forwarded to ``DataFrame.to_latex``.

    Returns:
        The resolved output path.
    """
    if not os.path.isabs(path):
        parent = os.path.dirname(path)
        base_dir = TABLE_DIR if parent in ("", ".", None) else ARTIFACT_ROOT
        path = os.path.join(base_dir, path)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    df.copy().to_latex(path, index=False, escape=True, caption=caption, label=label)
    print(f"[saved tex] {os.path.abspath(path)}")
    return path

def _savefig_stem(stem: str, tight: bool = True) -> str:
    """Save the current matplotlib figure as ``FIG_DIR/<stem>.<FIG_EXT>`` and close it.

    Args:
        stem: File name without extension.
        tight: If True, try ``plt.tight_layout()`` first; any layout error is
            ignored so that saving still happens.

    Returns:
        Path of the written image file.
    """
    _ensure_dirs()
    target = os.path.join(FIG_DIR, f"{stem}.{FIG_EXT}")
    if tight:
        try:
            plt.tight_layout()
        except Exception:
            # Best-effort: tight layout is cosmetic and must not block saving.
            pass
    plt.savefig(target, dpi=FIG_DPI, bbox_inches="tight")
    plt.close()
    print(f"[saved fig] {os.path.abspath(target)}")
    return target

def _console_title(title: str):
    bar = "=" * len(title)
    print(f"\n{bar}\n{title}\n{bar}")

def _console_subtitle(title: str):
    bar = "-" * len(title)
    print(f"\n{title}\n{bar}")

def _console_table(df: pd.DataFrame, title: str | None = None, round_ndigits: int | None = 4, max_rows: int = 50):
    if title:
        _console_subtitle(title)
    if round_ndigits is not None:
        df = df.copy()
        num_cols = df.select_dtypes(include=[np.number]).columns
        df[num_cols] = df[num_cols].round(round_ndigits)
    with pd.option_context('display.max_rows', max_rows,
                           'display.max_columns', None,
                           'display.width', 140):
        print(df.to_string(index=False))

def _console_kv(pairs: dict[str, object], title: str | None = None, pad: int = 24):
    if title:
        _console_subtitle(title)
    for k, v in pairs.items():
        print(f"{k:<{pad}} : {v}")

def _regress_loglog(x: np.ndarray, y: np.ndarray) -> tuple[float, float, float]:
    x = np.asarray(x, float); y = np.asarray(y, float)
    mask = (x > 0) & (y > 0)
    lx, ly = np.log(x[mask]), np.log(y[mask])
    a, b = np.polyfit(lx, ly, 1)
    yhat = a*lx + b
    ss_res = np.sum((ly - yhat)**2)
    ss_tot = np.sum((ly - np.mean(ly))**2)
    r2 = 1.0 - ss_res/max(ss_tot, 1e-12)
    return float(a), float(b), float(r2)

def _nx_seed(x) -> int:
    return int(int(x) % int(1e9+7))

# ========= Preconditioners / solvers =========
class IterationCounter:
    """Callable that counts its invocations; passed to Krylov solvers as ``callback``."""

    def __init__(self):
        self.count = 0

    def __call__(self, x=None):
        # Whatever the solver passes (iterate or residual) is ignored.
        self.count += 1

def _build_preconditioner(A, pre_type):
    n = A.shape[0]
    if pre_type == 'None':
        return None
    if pre_type == 'Jacobi':
        diag = np.diag(A).copy()
        diag[np.abs(diag) < 1e-12] = 1.0
        M_inv = np.diag(1.0 / diag)
        return LinearOperator((n, n), matvec=lambda v: M_inv @ v)
    if pre_type.startswith('ILU'):
        drop, fill = 1e-5, 5
        if '(' in pre_type:
            inside = pre_type[pre_type.find('(')+1:pre_type.find(')')]
            kv = {k.strip(): v.strip() for k, v in (item.split('=') for item in inside.split(','))}
            drop = float(kv.get('d', kv.get('drop', drop)))
            fill = float(kv.get('f', kv.get('fill', fill)))
        try:
            ilu = spilu(csc_matrix(A), drop_tol=drop, fill_factor=fill)
            return LinearOperator((n, n), matvec=ilu.solve)
        except RuntimeError:
            return None
    if pre_type == 'SSOR':
        D = np.diag(np.diag(A))
        L = np.tril(A, k=-1)
        omega = 1.0
        def ssor_matvec(v):
            y = solve_triangular(D + omega * L, v, lower=True, check_finite=False)
            return solve_triangular(D.T + omega * L.T, (D @ y), lower=False, check_finite=False).T
        return LinearOperator((n, n), matvec=ssor_matvec)
    raise ValueError(pre_type)

def _compat_solve_attempts(solve_fn, A, b, base_kwargs, tol):
    attempts = [
        {"rtol": float(tol), "atol": 1e-12},
        {"rtol": float(tol)},
        {"tol":  float(tol)},
        {},
    ]
    last_err = None
    for extra in attempts:
        try:
            return solve_fn(A, b, **{**base_kwargs, **extra})
        except TypeError as e:
            last_err = e
            continue
    if last_err is not None:
        raise last_err

    return solve_fn(A, b, **base_kwargs)

def solve_iterative(
    A, b,
    pre_type='None',
    solver='gmres',
    maxiter_factor=2,
    repeats=REPEATS_PER_SOLVE,
    include_setup=INCLUDE_SETUP,
    tol=1e-8,
    M=None,
    t_setup=None,
):
    """Time a (preconditioned) Krylov solve of ``A x = b`` over several repeats.

    Args:
        A, b: System matrix and right-hand side.
        pre_type: Preconditioner name for ``_build_preconditioner``; used only
            when ``M`` is None (its construction is timed as setup).
        solver: One of 'cg', 'minres', 'gmres', 'bicgstab'.
        maxiter_factor: ``maxiter = int(n * maxiter_factor)``.
        repeats: Number of timed solves.
        include_setup: Add the setup time to every repeat's wall time.
        tol: Relative tolerance forwarded via ``_compat_solve_attempts``.
        M: Optional prebuilt preconditioner.
        t_setup: Setup time to attribute to a prebuilt ``M`` (default 0).

    Returns:
        ``(median_iterations, median_seconds, all_converged, solver)``. A repeat
        counts as converged only if the solver reports ``info == 0``; otherwise
        (including a caught ValueError/RuntimeError/LinAlgError/TypeError) it is
        charged ``n * maxiter_factor`` iterations. GMRES uses ``restart = min(n, 30)``.
    """
    n = A.shape[0]

    if M is not None:
        setup_seconds = 0.0 if t_setup is None else float(t_setup)
    else:
        start = time.perf_counter()
        M = _build_preconditioner(A, pre_type)
        setup_seconds = time.perf_counter() - start

    solve_fn = {'cg': cg, 'minres': minres, 'gmres': gmres, 'bicgstab': bicgstab}[solver]

    iterations, seconds, converged = [], [], []
    for _ in range(int(repeats)):
        counter = IterationCounter()
        start = time.perf_counter()

        call_kwargs = {'M': M, 'callback': counter, 'maxiter': int(n * maxiter_factor)}
        if solver == 'gmres':
            call_kwargs['restart'] = min(n, 30)

        try:
            _, info = _compat_solve_attempts(solve_fn, A, b, call_kwargs, tol)
        except (ValueError, RuntimeError, np.linalg.LinAlgError, TypeError):
            info = -1

        elapsed = time.perf_counter() - start
        if include_setup:
            elapsed += setup_seconds

        ok = (info == 0)
        iterations.append(float(counter.count if ok else n * maxiter_factor))
        seconds.append(float(elapsed))
        converged.append(bool(ok))

    return float(np.median(iterations)), float(np.median(seconds)), bool(np.all(converged)), solver

def solve_with_preconditioner_cg(A, b, pre_type='None', maxiter_factor=2):
    """Return the CG iteration count for ``A x = b`` under a simple preconditioner.

    ``pre_type`` may be ``'Jacobi'`` or ``'ILU'``; any other value (including
    ``'None'``) runs unpreconditioned CG. CG is called without a tolerance
    keyword, so SciPy's default tolerance applies. Non-convergence (nonzero
    exit code) or an ILU ``RuntimeError`` is reported as ``n * maxiter_factor``.
    """
    n = A.shape[0]
    cap = n * maxiter_factor
    precond = None
    if pre_type == 'ILU':
        try:
            factor = spilu(csc_matrix(A), drop_tol=1e-5, fill_factor=5)
            precond = LinearOperator((n, n), matvec=factor.solve)
        except RuntimeError:
            return cap
    elif pre_type == 'Jacobi':
        d = np.diag(A).copy()
        d[np.abs(d) < 1e-12] = 1.0
        inv_d = np.diag(1.0 / d)
        precond = LinearOperator((n, n), matvec=lambda v: inv_d @ v)
    counter = IterationCounter()
    _, exit_code = cg(A, b, M=precond, callback=counter, maxiter=cap)
    return counter.count if exit_code == 0 else cap

def find_best_preconditioner_cg(A, b):
    """Pick the preconditioner ('None', 'Jacobi', 'ILU') that needs the fewest CG iterations.

    Returns:
        ``(best_name, {name: iterations})``. Ties go to the name that comes
        first in the order 'None', 'Jacobi', 'ILU'.
    """
    iterations = {}
    for name in ('None', 'Jacobi', 'ILU'):
        iterations[name] = solve_with_preconditioner_cg(A, b, pre_type=name)
    best = min(iterations, key=iterations.get)
    return best, iterations

def is_spd(A, tol=1e-12):
    """Return whether ``A`` is numerically symmetric positive definite.

    Symmetry is checked with ``np.allclose(A, A.T)`` using ``tol`` as both the
    relative and the absolute tolerance. Definiteness requires the smallest
    ``eigvalsh`` eigenvalue to exceed ``tol``. An eigensolver failure gives False.
    """
    if not np.allclose(A, A.T, rtol=tol, atol=tol):
        return False
    try:
        smallest = np.min(np.linalg.eigvalsh(A))
    except np.linalg.LinAlgError:
        return False
    return smallest > tol

def solve_iterative_auto(A, b, pre_type='None', **kwargs):
    """Choose a Krylov method from A's structure, then delegate to ``solve_iterative``.

    SPD (per ``is_spd``) -> 'cg'; symmetric by default ``np.allclose`` but not
    SPD -> 'minres'; otherwise -> 'gmres'. Extra ``kwargs`` are forwarded.
    """
    if not np.allclose(A, A.T):
        solver = 'gmres'
    elif is_spd(A):
        solver = 'cg'
    else:
        solver = 'minres'
    return solve_iterative(A, b, pre_type=pre_type, solver=solver, **kwargs)

# ========= Generators =========
def generate_covariance_matrix(n: int, n_samples: int, n_correlated_features: int,
                               random_state: int | None = None) -> np.ndarray:
    rng = np.random.default_rng(random_state)
    data = rng.standard_normal((n_samples, n))
    if n_correlated_features > 1:
        source = rng.standard_normal(n_samples) * 2
        for i in range(n_correlated_features):
            noise = rng.standard_normal(n_samples) * 0.5
            data[:, i] += source + noise
    return np.cov(data, rowvar=False)

def generate_kernel_matrix(n: int, n_clusters: int, cluster_std: float, gamma: float,
                           random_state: int | None = None) -> np.ndarray:
    """RBF kernel matrix (n x n) of ``n`` points sampled from Gaussian blobs in 10 dimensions.

    Args:
        n: Number of points.
        n_clusters: Number of blob centers.
        cluster_std: Standard deviation of each blob.
        gamma: RBF kernel width parameter.
        random_state: Seed for ``make_blobs`` (coerced to int when given).
    """
    seed = None if random_state is None else int(random_state)
    points, _labels = make_blobs(n_samples=n, centers=n_clusters, cluster_std=cluster_std,
                                 n_features=10, random_state=seed)
    return rbf_kernel(points, gamma=gamma)

def generate_goe_matrix(n: int, random_state: int | None = None) -> np.ndarray:
    rng = np.random.default_rng(random_state)
    A = rng.standard_normal((n, n))
    return (A + A.T) / np.sqrt(2 * n)

def generate_adjacency_matrix(n: int, m_edges: int, random_state: int | None = None) -> np.ndarray:
    """Dense adjacency matrix of a networkx Barabási–Albert graph.

    ``random_state=None`` is treated as seed 0, so the output is always deterministic.
    """
    seed = _nx_seed(random_state if random_state is not None else 0)
    graph = nx.barabasi_albert_graph(n, m_edges, seed=seed)
    return nx.to_numpy_array(graph)

def generate_er_adjacency_matrix(n: int, p: float, random_state: int | None = None) -> np.ndarray:
    """Dense adjacency matrix of a networkx Erdős–Rényi G(n, p) graph.

    ``random_state=None`` is treated as seed 0, so the output is always deterministic.
    """
    seed = _nx_seed(random_state if random_state is not None else 0)
    graph = nx.erdos_renyi_graph(n, p, seed=seed)
    return nx.to_numpy_array(graph)

# ========= ASF/CSF & helpers =========
def _affine_scale_to_unit_interval(A: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    n = A.shape[0]

    # If A is Hermitian or Symmetry --> Eigenvalues are always real numbers
    if np.allclose(A, A.T, atol=1e-12):
        vals = np.linalg.eigvalsh(A)
        real_ok = True

    # Else -->  Eigenvalues can be complex numbers
    else:
        vals = np.linalg.eigvals(A)
        imag_max = float(np.max(np.abs(np.imag(vals))))
        mag_max = float(np.max(np.abs(vals))) + eps
        real_ok = (imag_max < 1e-8) or (imag_max/mag_max < 1e-8)
        if real_ok:
          vals = np.real(vals)

    if real_ok:
        lam_min, lam_max = float(np.min(vals)), float(np.max(vals))
        c = 0.5*(lam_max + lam_min)
        d = max(0.5*(lam_max - lam_min), eps)
        return (A - c*np.eye(n, dtype = A.dtype)) / d
    else:
        radius = float(np.max(np.abs(vals)))
        return A / max(radius, eps)


def hutch_chebyshev_moments(A: np.ndarray, K_max: int, p_vectors: int = 64,
                            scale_to_unit: bool = True, probe: str = "rademacher",
                            eta_damp: float = 0.06, seed: int | None = None,
                            return_se: bool = True) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
    rng = np.random.default_rng(seed)
    A_unit = _affine_scale_to_unit_interval(A) if scale_to_unit else A
    n = A_unit.shape[0]

    def sample_vec():
        if probe == "rademacher":
            v = rng.integers(0, 2, size=n, dtype=np.int8) * 2 - 1
            return v.astype(float)
        elif probe == "gaussian":
            return rng.standard_normal(n)
        else:
            raise ValueError("probe must be 'rademacher' or 'gaussian'")

    sums = np.zeros(K_max, dtype=float)
    sums_sq = np.zeros(K_max, dtype=float)

    for _ in range(p_vectors):
        z = sample_vec()
        y_prev = z
        acc0 = float(z @ y_prev)
        y_curr = A_unit @ z
        acc1 = float(z @ y_curr)
        acc = [acc0, acc1]
        for k in range(2, K_max):
            y_next = 2 * (A_unit @ y_curr) - y_prev
            acc.append(float(z @ y_next))
            y_prev, y_curr = y_curr, y_next
        acc = np.asarray(acc, float)
        sums += acc; sums_sq += acc**2

    s = sums / p_vectors
    s[0] = float(n)
    s[1] = float(np.trace(A_unit))

    if eta_damp != 0.0:
        k = np.arange(K_max, dtype=float)
        s *= np.exp(-eta_damp * k)

    if not return_se:
        return s
    var = (sums_sq / p_vectors - (sums / p_vectors)**2) / max(p_vectors - 1, 1)
    se = np.sqrt(np.maximum(var, 0.0) / p_vectors)
    return s, se

def adaptive_spectral_fingerprint(A: np.ndarray, K_min: int=1, K_max: int=100,
                                  eta_damp: float=0.06, hankel_eps: float=1e-3,
                                  energy_tau: float=1e-3, stable_window: int=2,
                                  similarity_invariant: bool=True) -> tuple[int, np.ndarray]:
    """Adaptive-length Chebyshev spectral fingerprint with exact traces (ASF).

    Computes damped moments ``d_k = exp(-eta_damp * k) * tr T_k(A_unit)`` via the
    dense matrix recurrence ``T_{k+1} = 2 A_unit T_k - T_{k-1}`` (one n x n
    matrix product per moment). ``A_unit`` is A affinely mapped to spectrum
    [-1, 1] when ``similarity_invariant`` is True, otherwise a copy of A.
    Generation stops at the first ``k >= K_min`` (k starts at 2) where either

      * the square Hankel matrix ``H[i, j] = d_{i+j}`` built from the moments so
        far (only once there are at least 4) has ``sigma_min / sigma_max < hankel_eps``, or
      * the energy ratio ``d_k^2 / sum_{j<=k} d_j^2`` has been below
        ``energy_tau`` for ``stable_window`` consecutive moments.

    Args:
        A: Square matrix.
        K_min: Earliest moment index at which stopping is allowed.
        K_max: Largest moment index computed.
        eta_damp: Exponential damping rate per moment index.
        hankel_eps: Threshold on the Hankel singular-value ratio.
        energy_tau: Threshold on the per-moment energy ratio.
        stable_window: Consecutive low-energy moments required to stop.
        similarity_invariant: Apply the affine spectral scaling first.

    Returns:
        ``(K_star, fingerprint)``. ``fingerprint`` holds the first ``K_star``
        damped moments (indices ``0..K_star-1``), normalized to unit L2 norm
        (norm floored at 1e-12). On an early stop at index k, ``K_star = k``,
        so the moment that triggered the stop is excluded. Without an early
        stop ``K_star = K_max`` (and, if K_max >= 2, the computed moment
        ``K_max`` is likewise excluded).
    """
    n = A.shape[0]
    A_unit = _affine_scale_to_unit_interval(A) if similarity_invariant else A.copy()
    # d_0 = n and d_1 = tr(A_unit) are known without the recurrence.
    s0 = float(n)
    s1 = float(np.trace(A_unit))
    damped = [np.exp(-eta_damp*0)*s0, np.exp(-eta_damp*1)*s1]
    energy_total = damped[0]**2 + damped[1]**2
    # Recurrence state: T_prev = T_{k-2}, T_curr = T_{k-1} (dense n x n matrices).
    T_prev = np.eye(n, dtype=A_unit.dtype)
    T_curr = A_unit.copy()
    small_energy_hits = 0
    K_star = K_max

    for k in range(2, K_max+1):
        T_next = 2 * A_unit @ T_curr - T_prev
        sk = float(np.trace(T_next))
        dk = np.exp(-eta_damp * k) * sk
        damped.append(dk)
        # Cumulative energy includes the current moment, so energy_ratio <= 1.
        energy_total += dk**2

        energy_ratio = (dk**2) / max(energy_total, 1e-12)
        # Count consecutive low-energy moments; any high-energy moment resets the count.
        if energy_ratio < energy_tau: small_energy_hits += 1
        else: small_energy_hits = 0

        L = len(damped)
        if L < 4:
            # Too few moments for a meaningful Hankel test; 1.0 never triggers a stop.
            hankel_ratio = 1.0
        else:
            # Square Hankel matrix over damped[0 .. 2*nH-2]; a tiny singular-value
            # ratio means it is numerically rank-deficient.
            nH = (L+1)//2
            Hh = np.empty((nH, nH), float)
            for i in range(nH):
                for j in range(nH):
                    Hh[i, j] = damped[i+j]
            svals = np.linalg.svd(Hh, compute_uv=False)
            smax = max(svals[0], 1e-16); smin = max(svals[-1], 1e-16)
            hankel_ratio = smin / smax

        if k >= K_min and (hankel_ratio < hankel_eps or small_energy_hits >= stable_window):
            # K_star = k keeps moments 0..k-1 (the triggering moment k is dropped).
            K_star = k; break
        T_prev = T_curr; T_curr = T_next

    final = np.array(damped[:K_star], dtype=float)
    fingerprint = final / max(np.linalg.norm(final), 1e-12)
    return K_star, fingerprint

def adaptive_spectral_fingerprint_with_hutchinson(
    A: np.ndarray, K_min: int=1, K_max: int=100, eta_damp: float=0.06,
    hankel_eps: float=1e-3, energy_tau: float=1e-3, stable_window: int=2,
    similarity_invariant: bool=True, p_vectors: int=64, seed: int | None=None,
    probe: str="rademacher", se_guard: float=2.0) -> tuple[int, np.ndarray]:
    """Adaptive-length Chebyshev fingerprint from Hutchinson moment estimates (ASF-H).

    Obtains ``K_max`` damped moment estimates ``s`` and their standard errors
    ``se`` from ``hutch_chebyshev_moments``. Moments ``0 .. max(2, K_min)-1``
    are taken without any stopping test. For each later index k a stop signal
    is raised when either

      * the energy ratio ``s_k^2 / sum_{j<=k} s_j^2`` is below
        ``energy_tau * (1 + se_guard * se_k / |s_k|)``; a larger relative
        standard error raises this threshold, making a stop more likely, or
      * the square Hankel matrix of the moments so far (once there are at
        least 4) has ``sigma_min / sigma_max < hankel_eps``.

    After ``stable_window`` consecutive stop signals ending at index k,
    ``K_star = k + 1`` (the triggering moment is kept). Unlike
    ``adaptive_spectral_fingerprint``, a Hankel signal also needs
    ``stable_window`` consecutive hits here.

    Args:
        A: Square matrix.
        K_min: Minimum number of warm-up moments before stopping tests begin.
        K_max: Number of moments estimated (upper bound on K_star).
        eta_damp: Damping rate forwarded to ``hutch_chebyshev_moments``.
        hankel_eps: Threshold on the Hankel singular-value ratio.
        energy_tau: Base threshold on the per-moment energy ratio.
        stable_window: Consecutive stop signals required to stop.
        similarity_invariant: Forwarded as ``scale_to_unit``.
        p_vectors, seed, probe: Forwarded to ``hutch_chebyshev_moments``.
        se_guard: Weight of the relative standard error in the energy threshold.

    Returns:
        ``(K_star, phi)``: ``phi`` holds the first ``K_star`` moments normalized
        to unit L2 norm (norm floored at 1e-12); ``K_star = K_max`` if no stop occurs.
    """
    s, se = hutch_chebyshev_moments(
        A, K_max=K_max, p_vectors=p_vectors, scale_to_unit=similarity_invariant,
        probe=probe, eta_damp=eta_damp, seed=seed, return_se=True
    )
    energy_total = 0.0; vals=[]; hits=0; K_star=K_max
    # Warm-up: always keep the first max(2, K_min) moments (capped at K_max).
    warm_until = max(2, K_min)
    for k in range(min(warm_until, K_max)):
        energy_total += s[k]**2; vals.append(s[k])
    for k in range(warm_until, K_max):
        energy_total += s[k]**2; vals.append(s[k])
        er = (s[k]**2) / max(energy_total, 1e-12)
        # Relative standard error; the `se is None` branch is defensive, since se
        # is requested above via return_se=True.
        rel = 0.0 if se is None else se[k] / (abs(s[k]) + 1e-12)
        energy_stop = er < (energy_tau * (1.0 + se_guard * rel))

        L = len(vals)
        if L >= 4:
            # Square Hankel matrix over vals[0 .. 2*nH-2].
            nH = (L + 1) // 2
            Hh = np.empty((nH, nH), float)
            for i in range(nH):
                for j in range(nH):
                    Hh[i, j] = vals[i + j]
            svals = np.linalg.svd(Hh, compute_uv=False)
            h_ratio = svals[-1] / max(svals[0], 1e-16)
            hankel_stop = (h_ratio < hankel_eps)
        else:
            hankel_stop = False

        # NOTE(review): (k + 1) >= K_min is always true here, because k >= max(2, K_min).
        if (k + 1) >= K_min and (energy_stop or hankel_stop):
            hits += 1
        else:
            hits = 0
        if hits >= stable_window:
            # Keep moments 0..k, including the one that completed the window.
            K_star = k + 1; break

    phi = np.array(vals[:K_star], float)
    phi /= max(np.linalg.norm(phi), 1e-12)
    return K_star, phi

def csf_fingerprint(A: np.ndarray, K: int, eta_damp: float=0.06, scale_to_unit: bool=True) -> np.ndarray:
    """Fixed-length Chebyshev spectral fingerprint with exact traces (CSF-K).

    Returns the ``K`` damped moments ``exp(-eta_damp * k) * tr T_k(A_unit)`` for
    ``k = 0..K-1``, normalized to unit L2 norm (norm floored at 1e-12).
    ``A_unit`` is A mapped to spectrum [-1, 1] when ``scale_to_unit`` is True.
    Higher moments use the dense recurrence ``T_{k+1} = 2 A T_k - T_{k-1}``.
    """
    assert K >= 1
    A_unit = _affine_scale_to_unit_interval(A) if scale_to_unit else A
    n = A_unit.shape[0]
    moments = [np.exp(-eta_damp * 0) * float(n)]
    if K > 1:
        moments.append(np.exp(-eta_damp * 1) * float(np.trace(A_unit)))
        T_km1, T_k = np.eye(n, dtype=A_unit.dtype), A_unit.copy()
        for k in range(2, K):
            T_kp1 = 2 * A_unit @ T_k - T_km1
            moments.append(np.exp(-eta_damp * k) * float(np.trace(T_kp1)))
            T_km1, T_k = T_k, T_kp1
    vec = np.array(moments, dtype=float)
    return vec / max(np.linalg.norm(vec), 1e-12)

def csf_fingerprint_with_hutchinson(
    A: np.ndarray, K: int, p_vectors: int=64, eta_damp: float=0.06, scale_to_unit: bool=True,
    seed: int | None=None, probe: str="rademacher") -> np.ndarray:
    """Fixed-length Chebyshev fingerprint from Hutchinson trace estimates.

    Same layout as ``csf_fingerprint``: K damped moments normalized to unit L2
    norm (norm floored at 1e-12). The moments come from
    ``hutch_chebyshev_moments`` with ``p_vectors`` probes.
    """
    assert K >= 1
    moments = hutch_chebyshev_moments(
        A, K_max=K, p_vectors=p_vectors, scale_to_unit=scale_to_unit,
        probe=probe, eta_damp=eta_damp, seed=seed, return_se=False,
    )
    norm = max(np.linalg.norm(moments), 1e-12)
    return moments / norm

# ========= Feature builders & utils =========
# Display labels for each method, keyed by the `mode` strings accepted by
# build_feature plus the 'Fro' and 'SpecNorm' baselines computed directly in E1.
# The E1 plot groups methods by the label text before the first " (", so labels
# that share that prefix (e.g. both "CSF-K ..." entries) share a colour.
METHOD_LABELS = {
    "ASF": "ASF (Adaptive, exact trace)",
    "ASF-Hutch": "ASF-H (Adaptive, Hutchinson)",
    "CSF": "CSF-K (Fixed-K Chebyshev)",
    "CSF-Hutch": "CSF-K (Hutchinson)",
    "Fro": "Baseline: Frobenius",
    "SpecNorm": "Baseline: Spectral Norm (||·||₂)",
    "EigenTop": "Baseline: Top-m Eigenvalues",
    "RawMoment": "Baseline: Raw Power Moments",
    "HeatTrace": "Baseline: Heat-Trace @ T"
}

def eigen_top_features(A: np.ndarray, m: int = 16, similarity_invariant: bool = True) -> np.ndarray:
    """Top-``m`` eigenvalue magnitudes of A as a unit-norm feature vector.

    The eigenvalues (of A mapped to spectrum [-1, 1] when ``similarity_invariant``
    is True) are sorted by modulus ``|lambda|`` in descending order. The first
    ``m`` are kept, zero-padded when A has fewer than ``m`` eigenvalues, and the
    vector is normalized to unit L2 norm (norm floored at 1e-12).
    """
    A_unit = _affine_scale_to_unit_interval(A) if similarity_invariant else A.copy()
    if np.allclose(A_unit, A_unit.T, atol=1e-10):
        vals = np.linalg.eigvalsh(A_unit)
    else:
        # Rank possibly complex eigenvalues by modulus. Taking .real first would
        # discard the imaginary part (e.g. a rotation's eigenvalues +-i would map to 0).
        vals = np.linalg.eigvals(A_unit)
    mags = np.sort(np.abs(vals))[::-1]
    f = mags[:m] if len(mags) >= m else np.pad(mags, (0, m - len(mags)))
    return f / max(np.linalg.norm(f), 1e-12)

def raw_power_moments(A: np.ndarray, K: int = 10, similarity_invariant: bool = True) -> np.ndarray:
    """Unit-norm vector of raw power traces ``[n, tr A, tr A^2, ..., tr A^{K-1}]``.

    A is first mapped to spectrum [-1, 1] when ``similarity_invariant`` is True.
    The norm used for normalization is floored at 1e-12.
    """
    A_unit = _affine_scale_to_unit_interval(A) if similarity_invariant else A.copy()
    traces = [float(A_unit.shape[0])]
    power = None
    for _ in range(1, K):
        power = A_unit.copy() if power is None else power @ A_unit
        traces.append(float(np.trace(power)))
    vec = np.array(traces, dtype=float)
    return vec / max(np.linalg.norm(vec), 1e-12)

def heat_trace_signature(A: np.ndarray, T: List[float] = [0.1, 0.3, 1.0, 3.0, 10.0],
                         similarity_invariant: bool = True) -> np.ndarray:
    A_unit = _affine_scale_to_unit_interval(A) if similarity_invariant else A.copy()
    As = (A_unit + A_unit.T) / 2.0
    evals = np.linalg.eigvalsh(As)
    feats = [float(np.sum(np.exp(-t * evals))) for t in T]
    v = np.array(feats, dtype=float)
    return v / max(np.linalg.norm(v), 1e-12)

def build_feature(A: np.ndarray, mode: str, **kwargs) -> Tuple[np.ndarray, int | None]:
    """Compute the feature vector of matrix ``A`` for method ``mode``.

    The adaptive modes ('ASF', 'ASF-Hutch') forward every keyword argument to
    their fingerprint function and also return the selected order ``K*``. The
    fixed modes ('CSF', 'CSF-Hutch', 'EigenTop', 'RawMoment', 'HeatTrace') pick
    out the keyword arguments they understand, ignore the rest, and return
    ``None`` as the second element.

    Returns:
        ``(feature_vector, K_star_or_None)``.

    Raises:
        ValueError: If ``mode`` is not recognized.
    """
    if mode in ("ASF", "ASF-Hutch"):
        adaptive = (adaptive_spectral_fingerprint if mode == "ASF"
                    else adaptive_spectral_fingerprint_with_hutchinson)
        K_star, phi = adaptive(A, **kwargs)
        return phi, K_star

    sim = kwargs.get("similarity_invariant", True)
    eta = kwargs.get("eta_damp", 0.06)

    if mode == "CSF":
        feature = csf_fingerprint(A, K=kwargs.get("K", 5), eta_damp=eta, scale_to_unit=sim)
    elif mode == "CSF-Hutch":
        feature = csf_fingerprint_with_hutchinson(
            A, K=kwargs.get("K", 5), p_vectors=kwargs.get("p_vectors", 64),
            eta_damp=eta, scale_to_unit=sim, seed=kwargs.get("seed", None),
            probe=kwargs.get("probe", "rademacher"))
    elif mode == "EigenTop":
        feature = eigen_top_features(A, m=kwargs.get("m", 16), similarity_invariant=sim)
    elif mode == "RawMoment":
        feature = raw_power_moments(A, K=kwargs.get("K", 10), similarity_invariant=sim)
    elif mode == "HeatTrace":
        feature = heat_trace_signature(A, T=kwargs.get("T", [0.1, 0.3, 1.0, 3.0, 10.0]),
                                       similarity_invariant=sim)
    else:
        raise ValueError(f"Unknown mode: {mode}")
    return feature, None

def _align_len(v: np.ndarray, L: int) -> np.ndarray:
    if v.shape[0] == L: return v
    if v.shape[0] > L:  return v[:L]
    out = np.zeros(L, dtype=v.dtype); out[:v.shape[0]] = v; return out

def _stack_fingerprints(fplist: list[np.ndarray], mode: str = "pad", p: float = 0.9) -> np.ndarray:
    Ls = [len(v) for v in fplist]
    if mode == "pad":
        L = max(Ls); M = np.zeros((len(fplist), L), dtype=float)
        for i, v in enumerate(fplist): M[i, :len(v)] = v
        return M
    elif mode == "truncate_min":
        L = min(Ls); return np.array([v[:L] for v in fplist], dtype=float)
    elif mode == "truncate_p":
        assert 0 < p <= 1.0
        L = int(np.quantile(Ls, p)); L = max(1, L)
        return np.array([v[:L] if len(v) >= L else np.pad(v, (0, L-len(v))) for v in fplist], dtype=float)
    else:
        raise ValueError("mode must be one of {'pad','truncate_min','truncate_p'}")

def _pairwise_euclidean(vectors: List[np.ndarray], align_mode: str = "pad", p: float = 0.9) -> np.ndarray:
    """Pairwise Euclidean distance matrix between feature vectors.

    Equal-length vectors are stacked directly; otherwise they are first aligned
    with ``_stack_fingerprints(mode=align_mode, p=p)``.
    """
    same_length = len({len(v) for v in vectors}) <= 1
    X = np.vstack(vectors) if same_length else _stack_fingerprints(vectors, mode=align_mode, p=p)
    return pairwise_distances(X, metric='euclidean')

def _dist_euclid_aligned(a: np.ndarray, b: np.ndarray) -> float:
    L = max(len(a), len(b))
    aa = _align_len(a, L); bb = _align_len(b, L)
    return float(np.linalg.norm(aa - bb))

def get_nnz(A, tol=1e-9):
    """Count the entries of ``A`` whose absolute value is strictly greater than ``tol``."""
    significant = np.abs(A) > tol
    return np.count_nonzero(significant)

# ========= Data makers for E1/E2 =========
def _make_4fam(n_size=80, N_samples=20, seed_offset=0):
    """Build the 4-family benchmark: covariance, RBF kernel, GOE and BA adjacency matrices.

    Sample ``i`` of every family uses ``random_state = seed_offset + i``.

    Returns:
        ``(matrices, labels, names)``: ``N_samples`` matrices per family in the
        order above, integer labels 0..3, and the four family names.
    """
    seeds = [seed_offset + i for i in range(N_samples)]
    families = [
        [generate_covariance_matrix(n=n_size, n_samples=200, n_correlated_features=5, random_state=s) for s in seeds],
        [generate_kernel_matrix(n=n_size, n_clusters=3, cluster_std=1.5, gamma=0.05, random_state=s) for s in seeds],
        [generate_goe_matrix(n=n_size, random_state=s) for s in seeds],
        [generate_adjacency_matrix(n=n_size, m_edges=3, random_state=s) for s in seeds],
    ]
    matrices = [M for family in families for M in family]
    labels = [label for label, family in enumerate(families) for _ in family]
    names = ['Covariance', 'Kernel', 'GOE', 'Adjacency-BA']
    return matrices, labels, names

def _make_more_relatives(fam: str, n: int, count: int, rng: np.random.Generator):
    """Generate ``count`` extra matrices of family ``fam`` with seeds drawn from ``rng``.

    Seeds come from ``rng.integers(10_000, 99_999, size=count)``. The generator
    parameters differ slightly from those used in ``_make_4fam``.

    Raises:
        ValueError: For an unknown ``fam``, but only when ``count > 0``, because
            the family is dispatched per generated sample.
    """
    makers = {
        "Adjacency": lambda s: generate_adjacency_matrix(n=n, m_edges=5, random_state=s),
        "Covariance": lambda s: generate_covariance_matrix(n=n, n_samples=240, n_correlated_features=6, random_state=s),
        "Kernel": lambda s: generate_kernel_matrix(n=n, n_clusters=3, cluster_std=1.6, gamma=0.06, random_state=s),
        "GOE": lambda s: generate_goe_matrix(n=n, random_state=s),
    }
    relatives = []
    for drawn in rng.integers(10_000, 99_999, size=count):
        maker = makers.get(fam)
        if maker is None:
            raise ValueError(fam)
        relatives.append(maker(int(drawn)))
    return relatives

def _make_5fam(n_size=80, N_samples=20, seed_offset=0):
    """Extend the 4-family benchmark with Erdős–Rényi adjacency matrices (label 4).

    Returns:
        ``(matrices, labels, names)`` as in ``_make_4fam``, with the ER samples
        (seeds ``seed_offset + i``) appended under the name 'Adjacency-ER'.
    """
    base_matrices, base_labels, base_names = _make_4fam(n_size, N_samples, seed_offset)
    # Intended to match the BA family's density (m=3): the expected ER edge
    # count p * n(n-1)/2 equals 3n.
    p_er = (n_size * 3) / (n_size * (n_size - 1) / 2)
    er = [generate_er_adjacency_matrix(n=n_size, p=p_er, random_state=seed_offset + i)
          for i in range(N_samples)]
    return base_matrices + er, base_labels + [4] * N_samples, base_names + ['Adjacency-ER']

# ========= E0 =========
def _E0_summary_stats(df: pd.DataFrame) -> pd.DataFrame:
    """Mean, median and IQR of E0 fingerprint distances per (mode, transform)."""
    stat_rows = []
    for mode in ["scaled", "no-scale"]:
        for k in ["perm", "diag_sim", "gen_sim", "alpha"]:
            sub = df[(df["transform"] == k) & (df["mode"] == mode)]["dist"].values
            q1, q3 = np.quantile(sub, [0.25, 0.75])
            stat_rows.append({"mode": mode, "transform": k, "mean": np.mean(sub),
                              "median": np.median(sub), "IQR": q3 - q1})
    return pd.DataFrame(stat_rows)


def _E0_plot_box_strip(df: pd.DataFrame, rng: np.random.Generator) -> None:
    """Box + jittered strip plot of log10 E0 distances, one panel per scaling mode.

    Jitter is drawn from ``rng`` so the figure is reproducible for a fixed seed.
    """
    order = ["perm", "diag_sim", "gen_sim", "alpha"]
    df_plot = df.copy()
    # Clip exact zeros so log10 stays finite.
    df_plot["log10dist"] = np.log10(np.clip(df_plot["dist"].values, 1e-18, None))

    fig, axes = plt.subplots(1, 2, figsize=(9.6, 3.8), sharey=True)
    for ax, mode in zip(axes, ["scaled", "no-scale"]):
        sub = df_plot[df_plot["mode"] == mode]
        for i, tr in enumerate(order):
            x = sub[sub["transform"] == tr]["log10dist"].values
            # Box: distribution summary.
            ax.boxplot(x, positions=[i], vert=False, widths=0.6,
                       showfliers=False, medianprops=dict(linewidth=1.4))
            # Strip: individual trials, vertically jittered.
            jitter = (rng.random(len(x)) - 0.5) * 0.25
            ax.plot(x, i + jitter, "o", ms=2.3, alpha=0.5)

        ax.set_title(f"E0: {mode}")
        ax.set_xlabel(r"$\log_{10}\, \|\Phi(A)-\Phi(T(A))\|_2$")
        ax.set_yticks(range(len(order)))
        ax.set_yticklabels(order)
        ax.axvline(-16, ls=":", lw=1.0)  # numerical floor guide

    fig.tight_layout()
    # The filename stays "E0_invariance_hist" (although this is a box/strip plot)
    # in case LaTeX sources already reference it.
    _savefig_stem("E0_invariance_hist")


def run_E0_invariance_and_scaling(n_trials: int = 64, n: int = 100, seed: int = 7):
    """E0: test ASF invariance under permutation, similarity and scaling transforms.

    For each of ``n_trials`` random symmetric n x n matrices A, the exact ASF of
    A is compared (Euclidean distance, zero-padded to equal length) with the ASF of
    P A P^T (permutation), D A D^{-1} (positive diagonal similarity), S A S^{-1}
    (general similarity) and alpha * A (alpha log-uniform in [1e-2, 1e2]), both
    with ("scaled") and without ("no-scale") the affine spectral scaling.
    Writes ``E0_invariance_stats.tex``, prints a console summary and saves a
    box/strip figure. All randomness, including the plot jitter, comes from
    ``np.random.default_rng(seed)``.
    """
    _ensure_dirs()
    rng = np.random.default_rng(seed)
    rows = []

    for t in range(n_trials):
        A = rng.standard_normal((n, n)); A = (A + A.T) / 2
        P = rng.permutation(np.eye(n))
        D = np.diag(rng.random(n) + 0.1)
        S = rng.standard_normal((n, n))
        # Redraw until S is invertible (required for S A S^{-1}).
        while np.linalg.matrix_rank(S) < n:
            S = rng.standard_normal((n, n))
        alpha = 10.0 ** rng.uniform(-2, 2)

        kinds = {
            "perm": P @ A @ P.T,
            "diag_sim": D @ A @ np.linalg.inv(D),
            "gen_sim": S @ A @ np.linalg.inv(S),
            "alpha": alpha * A
        }

        # With the affine spectral scaling.
        _, phi_A = adaptive_spectral_fingerprint(A, similarity_invariant=True)
        for kname, Tmat in kinds.items():
            _, phi_T = adaptive_spectral_fingerprint(Tmat, similarity_invariant=True)
            rows.append({"trial": t, "transform": kname,
                         "dist": _dist_euclid_aligned(phi_A, phi_T), "mode": "scaled"})

        # Without scaling, as a control.
        _, phi_A_raw = adaptive_spectral_fingerprint(A, similarity_invariant=False)
        for kname, Tmat in kinds.items():
            _, phi_T_raw = adaptive_spectral_fingerprint(Tmat, similarity_invariant=False)
            rows.append({"trial": t, "transform": kname,
                         "dist": _dist_euclid_aligned(phi_A_raw, phi_T_raw), "mode": "no-scale"})

    df = pd.DataFrame(rows)
    stats = _E0_summary_stats(df)

    _latex_table(stats.round(3), "tables/E0_invariance_stats.tex",
                 caption="Invariance distances with/without similarity scaling.",
                 label="tab:E0_invariance")

    _console_title("E0: Invariance & Scaling — Console Summary")
    for mode in ["scaled", "no-scale"]:
        _console_table(
            stats[stats["mode"] == mode].drop(columns=["mode"]).sort_values("transform"),
            title=f"Mode = {mode}"
        )

    # Reuse the seeded generator (after all trials) so the jitter is reproducible.
    _E0_plot_box_strip(df, rng)

# Helper for the E1 trade-off plot: extract the runtime-vs-quality Pareto front.
def _pareto_front(df_in, ycol):
    pts = df_in.sort_values("runtime_sec")
    keep, best = [], -np.inf
    for _, r in pts.iterrows():
        if r[ycol] >= best:
            keep.append(r); best = r[ycol]
    return pd.DataFrame(keep)

# ========= E1 =========
def _E1_cluster_scores(dist, true_labels, n_clusters):
    """Average-linkage clustering on a precomputed distance matrix; return (ARI, silhouette)."""
    pred = AgglomerativeClustering(n_clusters=n_clusters, linkage='average',
                                   metric='precomputed').fit(dist).labels_
    ari = adjusted_rand_score(true_labels, pred)
    sil = silhouette_score(dist, pred, metric="precomputed")
    return ari, sil


def _E1_plot_tradeoff(df: pd.DataFrame) -> None:
    """Scatter ARI and Silhouette against runtime (log-x) and save as ``E1_tradeoff``.

    Colour encodes the method family (label text before the first " ("), point
    size grows with K/m, and a dashed line traces the Pareto front of the
    ASF / ASF-H / CSF-K methods.
    """
    base_names = df["method"].str.split(r" \(", n=1, expand=True)[0]
    df_plot = df.assign(method_base=base_names)

    # Keys must equal the *base* names computed above (label with any " (...)"
    # suffix removed); e.g. the spectral-norm label reduces to "Baseline: Spectral Norm".
    palette = {
        "ASF": "C0", "ASF-H": "C1", "CSF-K": "C2",
        "Baseline: Top-m Eigenvalues": "C3",
        "Baseline: Raw Power Moments": "C4",
        "Baseline: Heat-Trace @ T": "C5",
        "Baseline: Frobenius": "C6",
        "Baseline: Spectral Norm": "C7",
    }

    def _size_by_k(val):
        try:
            k = float(val)
        except (TypeError, ValueError):
            return 44  # non-numeric K/m ("" or "-")
        return 36 + 10*np.sqrt(max(k, 1))

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    for ax, ycol, ttl in [(axes[0], "ARI", "E1: ARI vs Runtime (log-x)"),
                          (axes[1], "Silhouette", "E1: Silhouette vs Runtime (log-x)")]:
        for _, r in df_plot.iterrows():
            mb = r["method_base"]
            ax.scatter(r["runtime_sec"], r[ycol],
                       s=_size_by_k(r["K_or_m"]),
                       edgecolor="white", linewidth=0.7,
                       c=palette.get(mb, "0.5"), alpha=0.95)
            if mb == "CSF-K" and r["K_or_m"] in (3, 5):
                ax.text(r["runtime_sec"]*1.03, r[ycol], f"K={int(r['K_or_m'])}", fontsize=8)

        # Pareto front over our methods only.
        ours = df_plot[df_plot["method_base"].isin(["ASF", "ASF-H", "CSF-K"])]
        pf = _pareto_front(ours, ycol)
        if not pf.empty:
            ax.plot(pf["runtime_sec"], pf[ycol], "--", lw=1.2, color="k", alpha=0.6)

        ax.set_xscale("log")
        ax.set_xlabel("Runtime (s, log)")
        ax.set_ylabel(ycol)
        ax.grid(True, which="both", ls=":", lw=0.6)
        ax.set_title(ttl)

    # Lean legend: the untimed Frobenius / spectral-norm baselines are omitted.
    legend_order = ["ASF", "ASF-H", "CSF-K", "Baseline: Top-m Eigenvalues",
                    "Baseline: Raw Power Moments", "Baseline: Heat-Trace @ T"]
    handles = [plt.Line2D([0], [0], marker="o", ls="", color=palette.get(n, "0.5"), label=n)
               for n in legend_order if n in df_plot["method_base"].values]
    fig.legend(handles=handles, loc="lower center", ncol=3, frameon=False, bbox_to_anchor=(0.5, -0.02))
    fig.tight_layout()
    # Filename kept in case LaTeX sources already reference it.
    _savefig_stem("E1_tradeoff")


def run_E1_four_family_Ksweep():
    """E1: cluster the 4-family benchmark with each feature method; report quality and runtime.

    Methods: CSF-K for K in {1, 3, 5, 10, 50}, ASF, ASF-Hutch (100 probes),
    top-16 eigenvalues, raw moments (K=10) and heat trace. Each is followed by
    average-linkage agglomerative clustering (4 clusters) on Euclidean feature
    distances. Frobenius and spectral-norm baselines are added with runtime NaN
    (not timed). ``runtime_sec`` covers feature extraction, distances,
    clustering and scoring.

    Writes ``RESULTS_DIR/E1_ksweep.csv`` and ``tbl_E1_ksweep.tex``, prints
    console summaries and saves the ARI/Silhouette-vs-runtime figure.
    """
    _ensure_dirs()
    matrices, true_labels, _ = _make_4fam()
    methods = [("CSF", {"K": K}) for K in [1, 3, 5, 10, 50]]
    methods += [
        ("ASF", {}),
        ("ASF-Hutch", {"p_vectors": 100}),
        ("EigenTop", {"m": 16}),
        ("RawMoment", {"K": 10}),
        ("HeatTrace", {"T": [0.1, 0.3, 1.0, 3.0, 10.0]}),
    ]

    rows = []
    for mode, kw in methods:
        t0 = time.perf_counter()
        feats = [build_feature(M, mode=mode, similarity_invariant=True, **kw)[0] for M in matrices]
        dist = _pairwise_euclidean(feats)
        ari, sil = _E1_cluster_scores(dist, true_labels, n_clusters=4)
        dt = time.perf_counter() - t0
        rows.append({
            "method": METHOD_LABELS.get(mode, mode),
            "K_or_m": kw.get("K", kw.get("m", "")),
            "ARI": ari, "Silhouette": sil, "runtime_sec": dt
        })

    # Baseline 1: Frobenius distance (Euclidean distance of flattened matrices).
    mats = np.array([m.flatten() for m in matrices])
    dist_fro = pairwise_distances(mats)
    ari, sil = _E1_cluster_scores(dist_fro, true_labels, n_clusters=4)
    rows.append({"method": METHOD_LABELS["Fro"], "K_or_m": "-",
                 "ARI": ari, "Silhouette": sil, "runtime_sec": np.nan})

    # Baseline 2: absolute difference of spectral norms.
    spec_norms = np.array([np.linalg.norm(m, 2) for m in matrices]).reshape(-1, 1)
    dist_spec = pairwise_distances(spec_norms)
    ari, sil = _E1_cluster_scores(dist_spec, true_labels, n_clusters=4)
    rows.append({"method": METHOD_LABELS["SpecNorm"], "K_or_m": "-",
                 "ARI": ari, "Silhouette": sil, "runtime_sec": np.nan})

    df = pd.DataFrame(rows).sort_values(["method", "K_or_m"])
    df.to_csv(os.path.join(RESULTS_DIR, "E1_ksweep.csv"), index=False)
    _latex_table(df.round(4), "tables/tbl_E1_ksweep.tex",
                 caption="E1: 4-family K-sweep + adaptive results.", label="tab:E1")

    _console_title("E1: 4-Family K-Sweep — Console Summary")
    _console_table(df.sort_values("ARI", ascending=False).reset_index(drop=True),
                   title="Sorted by ARI (desc)")
    _console_table(df.sort_values("Silhouette", ascending=False).reset_index(drop=True),
                   title="Sorted by Silhouette (desc)")

    _E1_plot_tradeoff(df)

# ========= E2 =========
def run_E2_five_family_ba_vs_er():
    """E2: cluster the 5-family matrix collection with several fingerprint methods.

    For every ``(mode, kwargs)`` in ``eval_set`` a feature vector is built per
    matrix with ``build_feature(..., similarity_invariant=True)``; pairwise
    Euclidean distances are clustered with average-linkage agglomerative
    clustering, and ARI, overall silhouette and per-true-domain mean
    silhouette are recorded.  For ASF the confusion matrix is also printed and
    saved as a heatmap.

    Side effects: writes ``RESULTS_DIR/E2_5fam.csv``, the LaTeX table
    ``tables/tbl_E2_5fam.tex`` and the figures ``E2_confusion`` and
    ``E2_ARI_bar`` (via ``_savefig_stem``); prints console summaries.
    """
    _ensure_dirs()
    matrices, true_labels, domain_names = _make_5fam()
    # NOTE(review): hard-coded cluster count; presumably equals
    # len(domain_names) returned by _make_5fam — confirm.
    N_domains = 5
    eval_set = [
        ("ASF", {}),
        ("CSF", {"K":5}),
        ("EigenTop", {"m":16}),
        ("HeatTrace", {"T":[0.1,0.3,1.0,3.0,10.0]})
    ]
    out_rows = []
    _console_title("E2: 5-Family (BA vs ER) — Console Summary")
    for mode, kw in eval_set:
        feats = []
        for M in matrices:
            f, _ = build_feature(M, mode=mode, similarity_invariant=True, **kw)
            feats.append(f)
        dist = _pairwise_euclidean(feats)
        clustering = AgglomerativeClustering(n_clusters=N_domains, linkage='average', metric='precomputed')
        pred = clustering.fit(dist).labels_
        ari = adjusted_rand_score(true_labels, pred)
        sil = silhouette_score(dist, pred, metric="precomputed")
        # Rows = true domains, columns = predicted cluster ids.  Cluster ids are
        # not aligned to domains here, so only the row structure is meaningful.
        cm = confusion_matrix(true_labels, pred, labels=list(range(N_domains)))
        sil_s = silhouette_samples(dist, pred, metric="precomputed")
        # Mean per-sample silhouette grouped by TRUE domain.
        # Assumes true_labels are integer domain indices 0..N_domains-1 aligned
        # with domain_names — TODO confirm against _make_5fam.
        sil_by_dom = [float(np.mean(sil_s[np.array(true_labels)==i])) for i in range(N_domains)]
        out_rows.append({
            "method": METHOD_LABELS.get(mode, mode),
            "ARI": ari, "Silhouette": sil,
            **{f"Silhouette[{domain_names[i]}]": sil_by_dom[i] for i in range(N_domains)}
        })
        _console_kv({
            "Method": METHOD_LABELS.get(mode, mode),
            "ARI": f"{ari:.4f}",
            "Silhouette": f"{sil:.4f}"
        })
        if mode == "ASF":
            # Only the ASF confusion matrix is reported/plotted.
            df_cm = pd.DataFrame(cm,
                                 index=[f"True:{nm}" for nm in domain_names],
                                 columns=[f"Pred:C{j}" for j in range(N_domains)])
            _console_table(df_cm, title="ASF Confusion Matrix (counts)", round_ndigits=None)
            plt.figure(figsize=(6.5,5.2))
            sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                        xticklabels=[f"C{j}" for j in range(N_domains)],
                        yticklabels=domain_names)
            plt.xlabel("Predicted Cluster"); plt.ylabel("True Domain")
            plt.title("E2: Confusion matrix (ASF)")
            plt.tight_layout()
            _savefig_stem("E2_confusion")

    df = pd.DataFrame(out_rows)
    df.to_csv(os.path.join(RESULTS_DIR, "E2_5fam.csv"), index=False)
    _latex_table(df.round(4), "tables/tbl_E2_5fam.tex",
                 caption="E2: 5-family (BA vs ER) results with method breakdown.",
                 label="tab:E2")
    _console_table(df.sort_values("ARI", ascending=False).reset_index(drop=True),
                   title="Per-Method Scores (sorted by ARI)")
    # Summary bar chart of ARI per method (in eval_set order).
    plt.figure(figsize=(7.5,4.8))
    plt.bar(df["method"], df["ARI"])
    plt.ylabel("ARI"); plt.xticks(rotation=20, ha="right")
    plt.title("E2: ARI by method (5 families)")
    plt.tight_layout()
    _savefig_stem("E2_ARI_bar")

# --- [C] Adaptive vs Fixed-K framing -----------------------------------------
def summarize_Kstar_and_reco_default(kstars: np.ndarray | None = None):
    """Summarize the distribution of adaptive K* values and print a default recommendation.

    Args:
        kstars: precomputed K* values.  When ``None`` or empty, K* is recomputed
            for every matrix of ``_make_4fam()`` with
            ``adaptive_spectral_fingerprint(M, K_max=32, similarity_invariant=True)``.

    Side effects: writes ``tables/Kstar_summary_repro.tex`` (mean / median /
    IQR, rounded to 2 decimals), saves the ``Kstar_hist_repro`` histogram with
    a dashed reference line at K=5, and prints the summary plus a
    recommendation.  Returns ``None``.
    """
    if kstars is None or len(kstars) == 0:
        matrices, _, _ = _make_4fam()
        kstars = []
        for M in matrices:
            K_star, _ = adaptive_spectral_fingerprint(M, K_max=32, similarity_invariant=True)
            kstars.append(K_star)
        kstars = np.array(kstars, dtype=float)

    summary = pd.DataFrame({
        "mean":[float(np.mean(kstars))],
        "median":[float(np.median(kstars))],
        # Interquartile range (75th minus 25th percentile).
        "IQR":[float(np.quantile(kstars,0.75)-np.quantile(kstars,0.25))]
    })
    _latex_table(summary.round(2), "tables/Kstar_summary_repro.tex",
                 caption="Distribution of adaptive $K^*$ (reproduced).", label="tab:Kstar_repro")
    _console_table(summary.round(2), title="K* distribution (mean/median/IQR)")

    plt.figure(figsize=(6.8,4.2))
    plt.hist(kstars, bins=20, alpha=0.85)
    # Reference line at the fixed CSF default K=5.
    plt.axvline(5, linestyle="--")
    plt.title("Adaptive $K^*$ histogram (dashed: K=5 default)")
    plt.xlabel("K*"); plt.ylabel("count"); plt.tight_layout()
    _savefig_stem("Kstar_hist_repro")

    _console_subtitle("Recommendation")
    # NOTE(review): this message is a fixed string; it is not derived from the
    # kstars statistics computed above.
    print("Default: CSF-K=5 (best Pareto). Use ASF to upper-bound K when domain requires.")

# --- [B1] Eigenvalue histogram + Wasserstein distance -------------------------
from scipy.stats import wasserstein_distance

def eigen_histogram_feature(A: np.ndarray, bins: int = 64, similarity_invariant: bool = True) -> tuple[np.ndarray, np.ndarray]:
    """Return an L2-normalized histogram of the eigenvalues of sym(A) and its bin centers.

    The matrix is optionally rescaled with ``_affine_scale_to_unit_interval``
    (otherwise copied), symmetrized as ``(A + A.T) / 2`` and its eigenvalues
    are binned over the fixed range ``[-1, 1]``; eigenvalues outside that
    range are not counted.

    Returns:
        ``(hist, centers)`` — both float arrays of length ``bins``; ``hist``
        has unit Euclidean norm (norm floored at 1e-12).
    """
    if similarity_invariant:
        scaled = _affine_scale_to_unit_interval(A)
    else:
        scaled = A.copy()
    sym = (scaled + scaled.T) / 2.0
    spectrum = np.linalg.eigvalsh(sym)
    density, edges = np.histogram(spectrum, bins=bins, range=(-1.0, 1.0), density=True)
    mids = (edges[:-1] + edges[1:]) / 2.0
    scale = max(np.linalg.norm(density), 1e-12)
    normalized = density / scale
    return normalized.astype(float), mids.astype(float)

def pairwise_wasserstein_dists(feats: List[tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Symmetric matrix of 1-D Wasserstein distances between weighted histograms.

    Each feature is ``(weights, support)``.  When two supports differ, both
    histograms are linearly interpolated (zero outside their support) onto
    the shorter of the two supports before comparing.

    Returns:
        ``(n, n)`` float array with a zero diagonal.
    """
    count = len(feats)
    out = np.zeros((count, count), dtype=float)
    for a in range(count):
        ha, xa = feats[a]
        for b in range(a + 1, count):
            hb, xb = feats[b]
            if np.array_equal(xa, xb):
                dist = wasserstein_distance(xa, xb, u_weights=ha, v_weights=hb)
            else:
                grid = xa if len(xa) <= len(xb) else xb
                wa = np.interp(grid, xa, ha, left=0.0, right=0.0)
                wb = np.interp(grid, xb, hb, left=0.0, right=0.0)
                dist = wasserstein_distance(grid, grid, u_weights=wa, v_weights=wb)
            out[a, b] = out[b, a] = float(dist)
    return out

def run_E2b_eigenhist_wemd(matrices: List[np.ndarray], labels: List[int], bins: int = 64):
    """Cluster matrices by eigenvalue histograms compared with the 1-D Wasserstein distance.

    Uses ``eigen_histogram_feature`` (similarity-invariant) per matrix and
    average-linkage agglomerative clustering with one cluster per distinct label.

    Returns:
        ``(ari, silhouette, D)`` where ``D`` is the pairwise distance matrix.
    """
    histograms = []
    for mat in matrices:
        histograms.append(eigen_histogram_feature(mat, bins=bins, similarity_invariant=True))
    dist = pairwise_wasserstein_dists(histograms)
    model = AgglomerativeClustering(n_clusters=len(set(labels)), linkage='average', metric='precomputed')
    assignment = model.fit(dist).labels_
    return (adjusted_rand_score(labels, assignment),
            silhouette_score(dist, assignment, metric="precomputed"),
            dist)

# --- [B2] Heat-kernel distance (spectral trace vector; L2 in trace-space) ----
def heat_kernel_trace_vector(A: np.ndarray, T: List[float] = [0.05,0.1,0.2,0.5,1.0], similarity_invariant: bool=True) -> np.ndarray:
    A_unit = _affine_scale_to_unit_interval(A) if similarity_invariant else A.copy()
    As = (A_unit + A_unit.T)/2.0
    evals = np.linalg.eigvalsh(As)
    v = np.array([np.sum(np.exp(-t*evals)) for t in T], dtype=float)
    return v / max(np.linalg.norm(v), 1e-12)

def run_E2b_heatkernel_L2(matrices: List[np.ndarray], labels: List[int], T: List[float] | None = None):
    """Cluster matrices by Euclidean distance between their heat-trace vectors.

    Args:
        matrices: square matrices to cluster.
        labels: ground-truth label per matrix; the number of distinct labels
            sets the number of clusters.
        T: diffusion times passed to ``heat_kernel_trace_vector``; ``None``
            selects ``[0.05, 0.1, 0.2, 0.5, 1.0]``.

    Returns:
        ``(ari, silhouette, D)`` where ``D`` is the pairwise Euclidean distance
        matrix of the (similarity-invariant) trace vectors.
    """
    # None sentinel instead of a shared mutable list default.  Resolved here so
    # the explicit list is always forwarded to heat_kernel_trace_vector.
    if T is None:
        T = [0.05, 0.1, 0.2, 0.5, 1.0]
    feats = [heat_kernel_trace_vector(M, T=T, similarity_invariant=True) for M in matrices]
    X = np.vstack(feats)
    D = pairwise_distances(X, metric='euclidean')
    pred = AgglomerativeClustering(n_clusters=len(set(labels)), linkage='average', metric='precomputed').fit(D).labels_
    ari = adjusted_rand_score(labels, pred)
    sil = silhouette_score(D, pred, metric="precomputed")
    return ari, sil, D

# --- [B3] Lightweight WL-subtree features (graphs only) ----------------------
def _is_graph_adjacency(A: np.ndarray, tol: float = 1e-8) -> bool:
    """Heuristically decide whether ``A`` looks like an undirected graph adjacency matrix.

    Requires a 2-D square matrix that is symmetric (within ``atol=tol``) and
    has no self-loops, i.e. every diagonal entry has magnitude ``<= tol``.
    Edge weights may be arbitrary (including negative).
    """
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        return False
    if not np.allclose(A, A.T, atol=tol):
        return False
    # Self-loops of either sign disqualify the matrix; checking only
    # `diag > tol` used to accept matrices with negative diagonals.
    if np.any(np.abs(np.diag(A)) > tol):
        return False
    return True

def wl_subtree_feature(A: np.ndarray, h: int = 3) -> Counter:
    """Return Weisfeiler-Lehman subtree label counts for the graph with adjacency ``A``.

    Nodes ``i`` and ``j`` are adjacent when ``|A[i, j]| > 1e-12``.  Initial
    node labels are degrees; each of the ``h`` refinement rounds relabels a
    node with ``hash((own_label, sorted_neighbour_labels))``.  The returned
    Counter accumulates the labels of all rounds ``0..h``.

    Refined labels are Python ``hash`` values of integer tuples.  Integer and
    tuple hashing is not randomized by PYTHONHASHSEED, so the features are
    reproducible for a given Python version.
    """
    adjacent = np.abs(A) > 1e-12
    # Neighbour lists depend only on A: compute them once rather than
    # re-running np.where for every node in every refinement round.
    nbr_lists = [np.flatnonzero(adjacent[i]) for i in range(A.shape[0])]
    labels = [int(len(nbrs)) for nbrs in nbr_lists]
    feat = Counter(labels)
    # WL refinement: every new label is built from the previous round's labels.
    for _ in range(h):
        labels = [hash((labels[i], tuple(sorted(labels[j] for j in nbrs))))
                  for i, nbrs in enumerate(nbr_lists)]
        feat.update(labels)
    return feat

def _wl_align_distance_dicts(F: List[Counter]) -> np.ndarray:
    """Euclidean distances between WL label-count vectors on a shared vocabulary.

    The union of all labels (sorted) defines the columns.  Each count vector
    is L2-normalized (norm floored at 1e-12) before the pairwise distances are
    computed.
    """
    keys = sorted(set().union(*(counts.keys() for counts in F)))
    column_of = {key: pos for pos, key in enumerate(keys)}
    X = np.zeros((len(F), len(column_of)), dtype=float)
    for row, counts in enumerate(F):
        for key, cnt in counts.items():
            X[row, column_of[key]] = float(cnt)
    row_norms = np.linalg.norm(X, axis=1, keepdims=True)
    X = X / np.maximum(row_norms, 1e-12)
    return pairwise_distances(X, metric='euclidean')

def run_E2b_wl_graphs_only(matrices: List[np.ndarray], labels: List[int], h: int = 3):
    """Cluster only the graph-adjacency matrices by WL-subtree features.

    Matrices rejected by ``_is_graph_adjacency`` are skipped.  Features are
    L2-normalized WL label counts (``h`` refinement rounds) compared with the
    Euclidean distance.  They are clustered with average linkage into as many
    clusters as there are distinct labels among the kept matrices.

    Returns:
        ``(ari, silhouette, D)``.  ``(nan, nan, None)`` when fewer than two
        graphs are found.  The silhouette is ``nan`` when the predicted
        clustering has fewer than 2 or more than ``n - 1`` clusters, where
        scikit-learn's silhouette is undefined.
    """
    idx = [i for i, M in enumerate(matrices) if _is_graph_adjacency(M)]
    # Clustering and scoring need at least two samples; extend the existing
    # "no result" convention to the single-graph case.
    if len(idx) < 2:
        return np.nan, np.nan, None
    F = [wl_subtree_feature(matrices[i], h=h) for i in idx]
    D = _wl_align_distance_dicts(F)

    lbl = [labels[i] for i in idx]
    pred = AgglomerativeClustering(n_clusters=len(set(lbl)), linkage='average', metric='precomputed').fit(D).labels_
    ari = adjusted_rand_score(lbl, pred)
    n_pred = len(set(pred))
    if 2 <= n_pred <= len(idx) - 1:
        sil = silhouette_score(D, pred, metric="precomputed")
    else:
        sil = np.nan
    return ari, sil, D

# --- [B*] One-shot strong baselines scoreboard -------------------------------
def run_E2b_strong_baselines_all():
    """Scoreboard: strong spectral/graph baselines vs. ASF and CSF-K=5 on the 5-family set.

    Each method clusters a precomputed distance matrix with average-linkage
    agglomerative clustering and is scored with ARI and silhouette.
    ``runtime_sec`` is wall-clock time measured around each method's feature
    extraction, distance computation, clustering and scoring.  ``Dim`` is the
    feature dimension; for ASF it is the median adaptive K*.

    Side effects: writes ``RESULTS_DIR/E2b_strong_baselines.csv`` and the LaTeX
    table ``tables/tbl_E2b_strong_baselines.tex``, and prints a console scoreboard.
    """
    _ensure_dirs()
    matrices, true_labels, _ = _make_5fam()
    # Number of ground-truth families.  Used for every method so ASF/CSF are
    # clustered with the same cluster count as the baselines (previously hard-coded 5).
    n_clusters = len(set(true_labels))
    rows = []

    # --- Eigen-Hist + Wasserstein (needs eigendecomp)
    bins = 64
    t0 = time.perf_counter()
    ari, sil, _ = run_E2b_eigenhist_wemd(matrices, true_labels, bins=bins)
    dt = time.perf_counter() - t0
    rows.append({
        "method":"EigenHist+Wasserstein",
        "ARI": ari, "Silhouette": sil,
        "Dim": bins, "runtime_sec": dt, "Eigendecomp": "Yes"
    })

    # --- Heat-kernel (trace vector, needs eigendecomp)
    T = [0.05,0.1,0.2,0.5,1.0]
    t0 = time.perf_counter()
    ari, sil, _ = run_E2b_heatkernel_L2(matrices, true_labels, T=T)
    dt = time.perf_counter() - t0
    rows.append({
        "method":"HeatKernel(L2 on traces)",
        "ARI": ari, "Silhouette": sil,
        "Dim": len(T), "runtime_sec": dt, "Eigendecomp": "Yes"
    })

    # --- WL-subtree (graphs only; no eigendecomp)
    # Only matrices passing _is_graph_adjacency are scored; the filtering itself
    # is outside the timed region.
    idx = [i for i,M in enumerate(matrices) if _is_graph_adjacency(M)]
    if idx:
        t0 = time.perf_counter()
        F = [wl_subtree_feature(matrices[i], h=3) for i in idx]
        D = _wl_align_distance_dicts(F)
        lbl = [true_labels[i] for i in idx]
        pred = AgglomerativeClustering(n_clusters=len(set(lbl)), linkage='average', metric='precomputed').fit(D).labels_
        ari = adjusted_rand_score(lbl, pred)
        sil = silhouette_score(D, pred, metric="precomputed")
        dt = time.perf_counter() - t0
        vocab = set()
        for c in F:
            vocab |= set(c.keys())
        rows.append({
            "method":"WL-Subtree (graphs)",
            "ARI": ari, "Silhouette": sil,
            "Dim": len(vocab), "runtime_sec": dt, "Eigendecomp": "No"
        })
    else:
        rows.append({
            "method":"WL-Subtree (graphs)",
            "ARI": float("nan"), "Silhouette": float("nan"),
            "Dim": float("nan"), "runtime_sec": float("nan"), "Eigendecomp": "No"
        })

    # --- ASF (ours) — collect K* stats; no eigendecomp (Chebyshev recursion)
    t0 = time.perf_counter()
    feats_asf, kstars = [], []
    for M in matrices:
        f, kstar = build_feature(M, "ASF", similarity_invariant=True)
        feats_asf.append(f)
        if kstar is not None:
            kstars.append(kstar)
    D_asf = _pairwise_euclidean(feats_asf)
    pred = AgglomerativeClustering(n_clusters=n_clusters, linkage='average', metric='precomputed').fit(D_asf).labels_
    ari = adjusted_rand_score(true_labels, pred)
    sil = silhouette_score(D_asf, pred, metric="precomputed")
    dt = time.perf_counter() - t0
    # Reported dimension = median adaptive K* (None if no K* was returned).
    dim_asf = int(np.median(kstars)) if len(kstars) else None
    rows.append({
        "method":"ASF (ours)",
        "ARI": ari, "Silhouette": sil,
        "Dim": dim_asf, "runtime_sec": dt, "Eigendecomp": "No"
    })

    # --- CSF-K=5 (ours) — fixed tiny dimension; no eigendecomp
    t0 = time.perf_counter()
    feats_csf = [build_feature(M, "CSF", similarity_invariant=True, K=5)[0] for M in matrices]
    D_csf = _pairwise_euclidean(feats_csf)
    pred = AgglomerativeClustering(n_clusters=n_clusters, linkage='average', metric='precomputed').fit(D_csf).labels_
    ari = adjusted_rand_score(true_labels, pred)
    sil = silhouette_score(D_csf, pred, metric="precomputed")
    dt = time.perf_counter() - t0
    rows.append({
        "method":"CSF-K=5 (ours)",
        "ARI": ari, "Silhouette": sil,
        "Dim": 5, "runtime_sec": dt, "Eigendecomp": "No"
    })

    df = pd.DataFrame(rows).sort_values(["ARI","Silhouette"], ascending=False)
    df.to_csv(os.path.join(RESULTS_DIR, "E2b_strong_baselines.csv"), index=False)
    _latex_table(df.round(4), "tables/tbl_E2b_strong_baselines.tex",
                 caption="Strong baselines vs ours on 5 families (BA vs ER incl.). "
                         "Dim: feature dimension; Eigendecomp: whether full eigen-decomposition is required.",
                 label="tab:E2b_strong")

    _console_title("E2b: Strong baselines — Scoreboard")
    _console_table(df.reset_index(drop=True), title="Accuracy / Dim / Time / Eigendecomp", round_ndigits=4)

# ========= E3 =========
# --- [A] SuiteSparse mini-benchmark (real data) -------------------------------

def load_mtx(path: str) -> np.ndarray:
    """Read a Matrix Market file and return it as a dense float64 ndarray.

    ``scipy.io.mmread`` may return a sparse matrix, which is densified with
    ``toarray()``.  The dense result needs O(rows * cols) memory.
    """
    raw = mmread(path)
    if hasattr(raw, "toarray"):
        dense = raw.toarray()
    else:
        dense = np.array(raw, dtype=float)
    return np.array(dense, dtype=float)

def run_E3_suitesparse_realdata(mtx_paths: List[str], labels: List[str] | None = None):
    """E3 (real data): cluster SuiteSparse matrices with Hutchinson-based CSF/ASF features.

    Args:
        mtx_paths: paths to Matrix Market files, loaded densely via ``load_mtx``.
        labels: ground-truth group of each matrix, in the same order as
            ``mtx_paths``.  When ``None``, every matrix forms its own group.
            ARI is then uninformative and the silhouette is undefined
            (reported as NaN).

    Side effects: writes ``RESULTS_DIR/E3_suitesparse.csv``, the LaTeX table
    ``tables/tbl_E3_suitesparse.tex`` and the ``E3_ARI_bar`` figure, and
    prints a console table.
    """
    _ensure_dirs()
    matrices = [load_mtx(p) for p in mtx_paths]
    if labels is None:
        labels = list(range(len(matrices)))
    n_clusters = len(set(labels))

    methods = [
        ("CSF-Hutch", {"K":5, "p_vectors":100, "eta_damp":0.06, "similarity_invariant": True}),
        ("ASF-Hutch", {"K_min":1, "K_max":32, "p_vectors":100, "eta_damp":0.06,
                      "similarity_invariant": True, "seed":123})
    ]
    rows = []
    for mode, kw in methods:
        # Copy once per method so setdefault never mutates the method spec.
        # build_feature receives **call_kw, so it cannot mutate this dict either.
        call_kw = {**kw}
        call_kw.setdefault("similarity_invariant", True)
        feats = [build_feature(M, mode=mode, **call_kw)[0] for M in matrices]
        dist = _pairwise_euclidean(feats)
        clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage='average', metric='precomputed')
        pred = clustering.fit(dist).labels_
        ari = adjusted_rand_score(labels, pred)
        # silhouette_score requires 2 <= n_labels <= n_samples - 1.  With the
        # default labels (one group per matrix) it would always raise.
        n_pred = len(set(pred))
        if 2 <= n_pred <= len(matrices) - 1:
            sil = silhouette_score(dist, pred, metric="precomputed")
        else:
            sil = float("nan")
        rows.append({"method": METHOD_LABELS.get(mode, mode), "ARI": ari, "Silhouette": sil})

    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(RESULTS_DIR, "E3_suitesparse.csv"), index=False)
    _latex_table(df.round(4), "tables/tbl_E3_suitesparse.tex",
                 caption="E3 (Real): SuiteSparse mini-benchmark with Hutchinson.",
                 label="tab:E3_real")

    _console_title("E3 (Real): SuiteSparse mini-benchmark")
    _console_table(df.sort_values("ARI", ascending=False).reset_index(drop=True),
                   title="Per-Method Scores", round_ndigits=4)
    plt.figure(figsize=(6.5,4.8))
    plt.bar(df["method"], df["ARI"]); plt.ylabel("ARI"); plt.xticks(rotation=20, ha="right")
    plt.title("E3 (Real): ARI by method")
    plt.tight_layout(); _savefig_stem("E3_ARI_bar")

# ========= E4 =========
def run_E4_hutchinson_pvectors():
    """E4: ablation of CSF-Hutch over the number of Hutchinson probe vectors ``p`` and ``K``.

    For each ``K`` in ``[3, 5]`` and ``p`` in ``[10, 50, 100, 200, 500, 1000]``,
    fingerprints of the 4-family matrices are built with
    ``csf_fingerprint_with_hutchinson`` (Rademacher probes, ``eta_damp=0.06``)
    and clustered with average-linkage agglomerative clustering.  ARI,
    silhouette and wall-clock time are recorded.

    Side effects: writes ``RESULTS_DIR/E4_hutch.csv``, the LaTeX table
    ``tables/tbl_E4_hutch.tex`` and the ``E4_hutch_tradeoff`` figure, and
    prints one pivot table per metric.
    """
    _ensure_dirs()
    matrices, true_labels, _ = _make_4fam()
    p_list = [10, 50, 100, 200, 500, 1000]
    K_list = [3, 5]
    rows = []
    _console_title("E4: Hutchinson p-vectors — Console Summary")
    for K in K_list:
        for p in p_list:
            # runtime_sec covers feature extraction, distances, clustering and scoring.
            t0 = time.perf_counter()
            feats = [csf_fingerprint_with_hutchinson(M, K=K, p_vectors=p, eta_damp=0.06,
                                                     scale_to_unit=True, probe="rademacher") for M in matrices]
            dist = _pairwise_euclidean(feats)
            # NOTE(review): n_clusters=4 presumably matches the number of families in _make_4fam — confirm.
            pred = AgglomerativeClustering(n_clusters=4, linkage='average', metric='precomputed').fit(dist).labels_
            ari = adjusted_rand_score(true_labels, pred)
            sil = silhouette_score(dist, pred, metric='precomputed')
            dt = time.perf_counter() - t0
            rows.append({"K": K, "p": p, "ARI": ari, "Silhouette": sil, "runtime_sec": dt})
    df = pd.DataFrame(rows).sort_values(["K","p"])
    df.to_csv(os.path.join(RESULTS_DIR, "E4_hutch.csv"), index=False)
    _latex_table(df.round(4), "tables/tbl_E4_hutch.tex",
                 caption="E4: CSF-Hutch ablation over p-vectors and K.",
                 label="tab:E4_hutch")
    for metric in ["ARI", "Silhouette", "runtime_sec"]:
        piv = df.pivot(index="p", columns="K", values=metric).sort_index()
        _console_table(piv, title=f"{metric} vs p (columns=K)", round_ndigits=4)

    # Scores on the left axis, runtime on a twin right axis.
    plt.figure(figsize=(9,6))
    ax1 = plt.gca()
    for K in K_list:
        sub = df[df["K"]==K]
        ax1.plot(sub["p"], sub["ARI"], "o-", label=f"K={K} (ARI)")
        ax1.plot(sub["p"], sub["Silhouette"], "s--", label=f"K={K} (Sil)")
    ax1.set_xscale("log"); ax1.set_xlabel("p-vectors (log)")
    ax1.set_ylabel("Score"); ax1.grid(True, alpha=0.4)
    ax2 = ax1.twinx()
    for K in K_list:
        sub = df[df["K"]==K]
        ax2.plot(sub["p"], sub["runtime_sec"], "d-.", label=f"K={K} (Time)")
    ax2.set_ylabel("Runtime (s)")
    # Merge handles from both y-axes into a single legend.
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    plt.legend(lines+lines2, labels+labels2, loc="best")
    plt.title("E4: Hutchinson p-vectors Ablation (CSF-Hutch)")
    # NOTE(review): unlike the other figures there is no plt.tight_layout() here;
    # the right-hand "Runtime (s)" label may be clipped unless _savefig_stem
    # saves with a tight bounding box — confirm.
    _savefig_stem("E4_hutch_tradeoff")

# ========= E5 =========
def run_E5_noise_stability_loglog(n: int = 100, eps_range: tuple[float,float]=(-10, -1), num: int = 10, seed: int = 123):
    """E5: fingerprint stability under additive noise, with a log-log fit.

    A GOE matrix ``A`` of size ``n`` is perturbed with Gaussian noise scaled so
    that ``||noise||_F = eps * ||A||_F``.  The distance between the adaptive
    fingerprints of ``A`` and ``A + noise`` is measured for each ``eps``.

    Args:
        n: matrix size passed to ``generate_goe_matrix``.
        eps_range: base-10 exponents ``(lo, hi)`` for ``np.logspace``.
        num: number of epsilon values.
        seed: seeds both the GOE generator and the noise RNG.

    Side effects: writes ``RESULTS_DIR/E5_noise.csv``, the LaTeX table
    ``tables/tbl_E5_noise_slope.tex`` and the ``E5_noise_loglog`` figure, and
    prints the regression summary.
    """
    _ensure_dirs()
    rng = np.random.default_rng(seed)
    A = generate_goe_matrix(n=n, random_state=seed)
    _, phi0 = adaptive_spectral_fingerprint(A, similarity_invariant=True)
    epsilons = np.logspace(eps_range[0], eps_range[1], num)
    dists = []
    for eps in epsilons:
        # NOTE(review): the noise is a dense, non-symmetrized Gaussian matrix, so
        # A_pert is generally not symmetric even when A is — confirm intended.
        noise = rng.standard_normal(A.shape)
        noise = noise / np.linalg.norm(noise, 'fro') * np.linalg.norm(A, 'fro') * eps
        A_pert = A + noise
        _, phi1 = adaptive_spectral_fingerprint(A_pert, similarity_invariant=True)
        d = _dist_euclid_aligned(phi0, phi1); dists.append(d)
    df = pd.DataFrame({"epsilon": epsilons, "dist": dists})
    df.to_csv(os.path.join(RESULTS_DIR, "E5_noise.csv"), index=False)
    # NOTE(review): a zero distance would make a log-log regression ill-defined;
    # depends on how _regress_loglog handles it — confirm.
    slope, intercept, r2 = _regress_loglog(epsilons, dists)
    _latex_table(pd.DataFrame([{"slope": slope, "R2": r2}]).round(4),
                 "tables/tbl_E5_noise_slope.tex",
                 caption="E5: log–log regression of perturbation vs fingerprint distance.",
                 label="tab:E5_noise")
    _console_title("E5: Noise Stability — Console Summary")
    _console_kv({"log-log slope": f"{slope:.4f}", "R^2": f"{r2:.4f}"},
                title="Regression")
    _console_table(df.head(10), title="Sample of (epsilon, dist)")
    plt.figure(figsize=(7.5,5.3))
    plt.loglog(epsilons, dists, 'o-', label='data')
    xx = np.logspace(eps_range[0], eps_range[1], 200)
    # The fitted line uses exp(intercept), which assumes _regress_loglog fits in
    # natural-log space — TODO confirm (log10 would need 10**intercept).
    yy = np.exp(intercept) * (xx ** slope)
    plt.loglog(xx, yy, '-', label=f'fit: slope={slope:.2f}, R²={r2:.3f}')
    plt.xlabel(r'$\epsilon$'); plt.ylabel(r'$\|\Phi(A)-\Phi(A+\Delta A)\|_2$')
    plt.title("E5: Noise Stability (log–log)")
    plt.grid(True, which="both", alpha=0.4)
    plt.legend()
    _savefig_stem("E5_noise_loglog")

# =============================================================================
# E6+: Adversarial Research — Diversified Traps & Queries (Auto-Solver)
# =============================================================================

# --- distance helpers for fingerprints ---
def _align_len(v, L):
    """Truncate ``v`` to its first ``L`` entries, or right-pad it with zeros up to length ``L``."""
    deficit = L - len(v)
    if deficit > 0:
        return np.pad(v, (0, deficit))
    return v[:L]

def _phi_distance(a, b, use_cosine=True, weight_power=1.0):
    """Weighted distance between two fingerprints of possibly different length.

    Both vectors are aligned to the longer length ``L`` (zero-padding the
    shorter).  Position ``i`` (0-based) gets weight ``(i + 1) ** weight_power``.

    Returns ``1 - weighted cosine similarity`` when ``use_cosine`` is true;
    otherwise the weighted RMS difference normalized by the weight sum.
    """
    length = max(len(a), len(b))
    pa = _align_len(a, length)
    pb = _align_len(b, length)
    weights = (np.arange(length) + 1.0) ** float(weight_power)
    if not use_cosine:
        # weighted L2
        sq_err = np.sum(weights * (pa - pb) ** 2)
        return float(math.sqrt(sq_err / (np.sum(weights) + 1e-12)))
    cross = np.sum(weights * pa * pb)
    energy = np.sum(weights * pa * pa) * np.sum(weights * pb * pb)
    return 1.0 - float(cross / (math.sqrt(energy) + 1e-12))

# --- recommenders ---
def recommend_from_db_topk_weighted(phi_query, db_entries, k=5, use_cosine=True, weight_power=1.0):
    """Majority vote of ``best_precon`` over the ``k`` fingerprint-nearest DB entries.

    Distances come from ``_phi_distance``.  Each DB entry is a dict with at
    least ``'phi'`` and ``'best_precon'``.  Returns the string ``'None'`` for
    an empty DB.  Vote ties go to the label seen first, i.e. the one carried
    by the nearer neighbour (``Counter.most_common`` orders equal counts by
    first occurrence).
    """
    if not db_entries:
        return 'None'
    k_eff = min(int(max(1, k)), len(db_entries))
    dists = [_phi_distance(phi_query, entry['phi'], use_cosine=use_cosine, weight_power=weight_power)
             for entry in db_entries]
    nearest = np.argsort(dists)[:k_eff]
    ballot = Counter(db_entries[j]['best_precon'] for j in nearest)
    winner, _ = ballot.most_common(1)[0]
    return winner

def recommend_from_db_topk_weighted_fro(A_query, db_entries, k=5):
    """Majority vote of ``best_precon`` over the ``k`` entries nearest to ``A_query`` in Frobenius norm.

    This is the raw-matrix counterpart of ``recommend_from_db_topk_weighted``.
    Despite the name, votes are unweighted.  Returns ``'None'`` for an empty
    DB; vote ties go to the label of the nearer neighbour.
    """
    if not db_entries:
        return 'None'
    n_keep = min(int(max(1, k)), len(db_entries))
    gaps = [np.linalg.norm(A_query - entry['matrix'], 'fro') for entry in db_entries]
    chosen = np.argsort(gaps)[:n_keep]
    tally = Counter(db_entries[pos]['best_precon'] for pos in chosen)
    top_label, _ = tally.most_common(1)[0]
    return top_label

# --- traps ---
def build_trap(A_query, B_base):
    """Return the Frobenius-orthogonal projection of ``A_query`` onto ``span{B_base}``.

    The coefficient is ``<A_query, B_base>_F / ||B_base||_F^2``, with a 1e-18
    guard in the denominator so a zero ``B_base`` yields a zero trap.
    """
    numer = np.sum(A_query * B_base)
    denom = np.sum(B_base * B_base) + 1e-18
    coeff = numer / denom
    return coeff * B_base

def build_trap_blend(Aq, Bbase, eps=0.15):
    """Blend ``Aq`` towards a distractor while preserving ``Aq``'s Frobenius norm.

    ``Bbase`` is first rescaled to the Frobenius norm of ``Aq``.  The convex
    combination ``(1 - eps) * Aq + eps * Bb`` is then rescaled back to that
    same norm, so the result stays close to ``Aq`` for small ``eps``.

    Args:
        Aq: query matrix.
        Bbase: distractor matrix of the same shape.
        eps: blend weight of the distractor.

    Returns:
        A new array with ``||C||_F ≈ ||Aq||_F``.  An all-zero ``Aq`` gives an
        all-zero result.
    """
    # Compute ||Aq||_F once.  The 1e-12 only guards divisions; it no longer
    # leaks into the target norm.
    aq_norm = np.linalg.norm(Aq, 'fro')
    Bb = Bbase * (aq_norm / (np.linalg.norm(Bbase, 'fro') + 1e-12))
    C = (1.0 - eps) * Aq + eps * Bb
    # Re-normalize scale to match Aq
    return C * (aq_norm / (np.linalg.norm(C, 'fro') + 1e-12))

def _fair_trap_multi(Aq, rel_list, Atrap, tol_fro=0.05):
    """Check that a trap is at least as close to ``Aq`` (Frobenius) as its nearest relative.

    Returns true when ``||Aq - Atrap||_F <= min_r ||Aq - A_r||_F * (1 + tol_fro)``.
    Returns ``True`` when there are no relatives to compare against.
    """
    rel_dists = [np.linalg.norm(Aq - rel, 'fro') for rel in rel_list]
    if not rel_dists:
        return True
    budget = np.min(rel_dists) * (1.0 + float(tol_fro))
    trap_dist = np.linalg.norm(Aq - Atrap, 'fro')
    return trap_dist <= budget

# --- oracle for best preconditioner (using auto-solver defined earlier) ---
def _oracle_best_precon_auto(A, b, tol=1e-8):
    """Brute-force oracle: the preconditioner needing the fewest solver iterations.

    Solves ``A x = b`` with ``solve_iterative_auto`` for each of ``'None'``,
    ``'Jacobi'`` and ``'ILU'`` (in that order).  A non-converged run counts
    as ``inf`` iterations.  If every run fails, ``'None'`` is returned,
    because ``min`` keeps the first of equal keys.

    Returns:
        ``(best_name, {name: iterations}, best_iterations)``.
    """
    iters_by_precon = {}
    for name in ('None', 'Jacobi', 'ILU'):
        n_iter, _, converged, _ = solve_iterative_auto(A, b, pre_type=name, tol=tol)
        iters_by_precon[name] = n_iter if converged else float('inf')
    winner = min(iters_by_precon, key=iters_by_precon.get)
    return winner, iters_by_precon, iters_by_precon[winner]

# --- Probe-and-Switch: fast preflight to avoid trap mis-picks ---
def _choose_solver_name(A, tol=1e-12):
    """Pick a Krylov solver from the matrix structure.

    Non-symmetric (per ``np.allclose(A, A.T)``) matrices use ``'gmres'``.
    Symmetric matrices use ``'cg'`` when ``is_spd(A, tol=tol)`` holds, and
    ``'minres'`` otherwise.
    """
    if not np.allclose(A, A.T):
        return 'gmres'
    if is_spd(A, tol=tol):
        return 'cg'
    return 'minres'

def _run_once_probe(A, b, pre_type='None', tol=1e-4, maxiter=30):
    """Run one short, low-accuracy preconditioned solve and report how far it got.

    ``_probe_and_switch`` uses this as a cheap preflight.  The solver comes
    from ``_choose_solver_name`` (cg / minres / gmres); gmres restarts every
    ``min(n, 30)`` iterations.

    Returns:
        ``(rel_res, info)``.  ``rel_res`` is ``||A x - b|| / (||b|| + 1e-12)``,
        or ``inf`` when it cannot be computed or is not finite.  ``info`` is
        the flag from ``_compat_solve_attempts``; callers treat ``0`` as
        converged.  Any exception while building the preconditioner or
        solving yields ``(inf, 1)``.
    """
    n = A.shape[0]
    try:
        # Building the preconditioner can fail on its own (e.g. an incomplete
        # factorization of a difficult matrix; NOTE(review): depends on
        # _build_preconditioner).  Score that as a failed probe instead of
        # aborting the whole probe-and-switch pass.
        M = _build_preconditioner(A, pre_type)
    except Exception:
        return float('inf'), 1
    solver = _choose_solver_name(A)
    solve_fn = {'cg': cg, 'minres': minres, 'gmres': gmres, 'bicgstab': bicgstab}[solver]

    base_kwargs = {'M': M, 'maxiter': int(maxiter)}
    if solver == 'gmres':
        base_kwargs['restart'] = min(n, 30)

    try:
        x, info = _compat_solve_attempts(solve_fn, A, b, base_kwargs, tol)
    except Exception:
        return float('inf'), 1

    try:
        r = A @ x - b
        rel_res = float(np.linalg.norm(r) / (np.linalg.norm(b) + 1e-12))
    except Exception:
        rel_res = float('inf')
    # A NaN residual would make the tuple ranking in _probe_and_switch
    # ill-defined; treat any non-finite residual as "no progress".
    if not math.isfinite(rel_res):
        rel_res = float('inf')
    return rel_res, info

def _probe_and_switch(A, b, candidates, probe_tol=1e-4, probe_iters=30):
    """Pick the candidate preconditioner whose short probe solve performs best.

    Each distinct candidate is probed once with ``_run_once_probe``;
    duplicates are removed and the first occurrence is kept.  Ranking:
    probes reporting convergence (``info == 0``) beat those that do not.
    Within each group a lower relative residual wins, and remaining ties go
    to the candidate listed first.

    Returns:
        The chosen preconditioner name, or ``'None'`` when ``candidates`` is empty.
    """
    if not candidates:
        return 'None'
    best_key, best_name = None, None
    for p in dict.fromkeys(candidates):
        rel_res, info = _run_once_probe(A, b, pre_type=p, tol=probe_tol, maxiter=probe_iters)
        # NaN compares false with everything and would corrupt the ranking.
        if math.isnan(rel_res):
            rel_res = math.inf
        key = (info == 0, -rel_res)  # converged first, then lower residual
        # Strict '>' keeps the earliest-listed candidate on ties, instead of
        # ordering tied candidates by their names.
        if best_key is None or key > best_key:
            best_key, best_name = key, p
    return best_name

def plot_E6_flow_diagram():
    """Draw the E6 pipeline schematic and save it as ``E6_flow_diagram``.

    The diagram shows four rounded boxes (query, relatives DB, fairness-checked
    trap, recommenders) connected by arrows, plus an output caption.
    """
    _ensure_dirs()
    fig, ax = plt.subplots(figsize=(8.0, 4.2))
    ax.axis("off")

    # (lower-left corner, width, height, label) — drawn in this order.
    boxes = [
        ((0.05, 0.55), 0.22, 0.25, "Query A_q\n(Φ_q via ASF/CSF)"),
        ((0.35, 0.72), 0.26, 0.18, "Relatives DB\n(best precon via oracle)"),
        ((0.35, 0.35), 0.26, 0.18, "Fairness-checked Trap\n(build_trap, _fair_trap_multi)"),
        ((0.68, 0.55), 0.26, 0.25, "Recommenders\n• Φ-kNN\n• Fro-1NN / kNN\n(Probe-and-Switch)"),
    ]
    for (x0, y0), width, height, label in boxes:
        ax.add_patch(patches.FancyBboxPatch(
            (x0, y0), width, height,
            boxstyle="round,pad=0.02,rounding_size=0.05",
            linewidth=1.2, edgecolor="black", facecolor="white"))
        ax.text(x0 + width / 2, y0 + height / 2, label, ha="center", va="center")

    # (start, end) of each arrow.
    arrows = [
        ((0.27, 0.67), (0.35, 0.81)),  # A_q -> Relatives DB
        ((0.27, 0.67), (0.35, 0.44)),  # A_q -> Trap
        ((0.61, 0.81), (0.68, 0.67)),  # DB -> Reco
        ((0.61, 0.44), (0.68, 0.67)),  # Trap -> Reco
        ((0.27, 0.67), (0.68, 0.67)),  # Φ_q -> Reco
    ]
    for start, end in arrows:
        ax.annotate("", xy=end, xytext=start,
                    arrowprops=dict(arrowstyle="->", lw=1.2))

    ax.text(0.50, 0.08, "Output: chosen preconditioner\n(near-oracle iterations; trap-robust)",
            ha="center", va="center")
    _savefig_stem("E6_flow_diagram")

def run_E6_probe_ablation_grid(
    seeds=(2025,2026,2027),
    K_DIM_list=(3,5,10),
    k_vote_list=(1,3,5),
    weight_power_list=(0.0,0.5,1.5),
    fairness_tol_list=(0.02,0.10),
    use_cosine_list=(True, False),
):
    """Grid ablation of Probe-and-Switch (ON vs OFF) over the E6 configuration space.

    For every combination of the argument lists, ``run_E6_adversarial_plus`` is
    run twice with ``N_REL=5``: once with the probe enabled
    (``probe_iters=30``, ``probe_tol=1e-4``) and once disabled.  The positive
    part of ``extra_phy`` (regret over the oracle) is summarized as
    mean / median / p90; non-finite values are dropped, and zeros are
    reported when nothing remains.

    Rows are sorted by ``delta_p90 = off_p90 - on_p90``, then ``off_p90``,
    both descending.  Writes ``RESULTS_DIR/E6_probe_ablation_grid.csv`` and a
    LaTeX table of the top 12 rows.

    Returns:
        The full sorted DataFrame.
    """
    def _regret_stats(series):
        vals = np.maximum(pd.to_numeric(series, errors='coerce').values, 0)
        vals = vals[np.isfinite(vals)]
        if len(vals) == 0:
            return (0.0, 0.0, 0.0)
        return float(np.mean(vals)), float(np.median(vals)), float(np.nanpercentile(vals, 90))

    # Lazily enumerated in the same order as the equivalent nested loops.
    grid = (
        (sd, KD, kv, wp, ft, uc)
        for sd in seeds
        for KD in K_DIM_list
        for kv in k_vote_list
        for wp in weight_power_list
        for ft in fairness_tol_list
        for uc in use_cosine_list
    )
    rows = []
    for sd, KD, kv, wp, ft, uc in grid:
        shared = dict(random_seed=sd, K_DIM=KD, k_vote=kv, N_REL=5,
                      fairness_tol=ft, use_cosine=uc, weight_power=wp)
        df_on, _ = run_E6_adversarial_plus(**shared, probe_enable=True, probe_iters=30, probe_tol=1e-4)
        df_off, _ = run_E6_adversarial_plus(**shared, probe_enable=False)
        on_mean, on_med, on_p90 = _regret_stats(df_on["extra_phy"])
        off_mean, off_med, off_p90 = _regret_stats(df_off["extra_phy"])
        rows.append({
            "seed":sd,"K_DIM":KD,"k_vote":kv,"w_pow":wp,"fair_tol":ft,"cos":uc,
            "on_mean":on_mean,"on_med":on_med,"on_p90":on_p90,
            "off_mean":off_mean,"off_med":off_med,"off_p90":off_p90,
            "delta_p90": off_p90 - on_p90
        })

    out = pd.DataFrame(rows).sort_values(["delta_p90","off_p90"], ascending=False)
    out.to_csv(os.path.join(RESULTS_DIR, "E6_probe_ablation_grid.csv"), index=False)
    _console_title("E6: Probe ablation grid — top deltas")
    top = out.head(12)
    _console_table(top, title="Rows with largest p90 regret reduction (OFF→ON)", round_ndigits=2)
    _latex_table(top.round(2), "tables/tbl_E6_probe_ablation_grid_top.tex",
                 caption="Probe-and-Switch reduces p90 regret under harder settings.",
                 label="tab:E6_probe_grid_top")
    return out

# --- [D2] Probe-and-Switch ablation (on/off) ----------------------------------
def run_E6_probe_ablation_once(seed=2025, K_DIM=10, k_vote=5, N_REL=5):
    """Compare Probe-and-Switch ON vs OFF for a single E6 configuration.

    ``run_E6_adversarial_plus`` is run twice with identical settings, differing
    only in ``probe_enable``.  The positive part of the ``extra_phy`` column
    (recommendation iterations above the oracle) is summarized as
    mean / median / 90th percentile.

    Side effects: writes ``tables/tbl_E6_probe_ablation.tex`` and prints a table.

    Returns:
        DataFrame with one row per probe setting (``"ON"``, ``"OFF"``).
    """
    df_on, _ = run_E6_adversarial_plus(
        random_seed=seed, K_DIM=K_DIM, k_vote=k_vote, N_REL=N_REL,
        probe_enable=True, probe_iters=30, probe_tol=1e-4
    )
    df_off, _ = run_E6_adversarial_plus(
        random_seed=seed, K_DIM=K_DIM, k_vote=k_vote, N_REL=N_REL,
        probe_enable=False
    )
    # regret = (iters_reco - iters_oracle)_{+}
    # Same robust statistics as run_E6_probe_ablation_grid: coerce to numbers,
    # clip at zero and drop non-finite values (e.g. from failed solves).
    # Otherwise a single NaN/inf would turn every statistic into NaN/inf.
    def extra_pos(s):
        x = np.maximum(pd.to_numeric(s, errors='coerce').values, 0)
        x = x[np.isfinite(x)]
        if len(x) == 0:
            return 0.0, 0.0, 0.0
        return float(np.mean(x)), float(np.median(x)), float(np.nanpercentile(x, 90))
    on_mean, on_med, on_p90 = extra_pos(df_on["extra_phy"])
    off_mean, off_med, off_p90 = extra_pos(df_off["extra_phy"])

    rows = [
        {"probe":"ON",  "extra_mean":on_mean,  "extra_median":on_med,  "extra_p90":on_p90},
        {"probe":"OFF", "extra_mean":off_mean, "extra_median":off_med, "extra_p90":off_p90},
    ]
    out = pd.DataFrame(rows)
    _latex_table(out.round(2), "tables/tbl_E6_probe_ablation.tex",
                 caption="Probe-and-Switch ablation: regret statistics.",
                 label="tab:E6_probe_abl")
    _console_title("E6: Probe-and-Switch ablation")
    _console_table(out, title="Regret stats (iters over oracle)", round_ndigits=2)
    return out

# --- core E6 runner ---
def run_E6_adversarial_plus(
    random_seed: int,
    K_DIM: int,
    k_vote: int,
    N_REL: int,
    fairness_retries: int = 20,
    use_cosine: bool = True,
    weight_power: float = 1.5,
    regret_threshold: int = 40,
    fairness_tol: float = 0.02,
    probe_enable: bool = True,
    probe_iters: int = 30,
    probe_tol: float = 1e-4,
    tol: float = 1e-8,
):
    """Run one seed of the E6+ adversarial preconditioner-recommendation study.

    For each query family (Adjacency, Covariance, Kernel, GOE; n = 80), a small
    database is built from ``N_REL`` same-family "relatives" plus one "trap".
    The trap is blended (``build_trap_blend``) from a different distractor
    family. Every database entry is labelled with the preconditioner returned
    by ``_oracle_best_precon_auto``. Three recommenders are then compared
    against that oracle on the query:

    * ``phy``  – ``recommend_from_db_topk_weighted`` on spectral fingerprints,
      optionally replaced via ``_probe_and_switch`` (probe ON) or by a
      regret-threshold fallback that consults the oracle count (probe OFF);
    * ``fro1`` – label of the nearest database entry in Frobenius norm;
    * ``froK`` – ``recommend_from_db_topk_weighted_fro``.

    Args:
        random_seed: Seed of the local ``np.random.default_rng`` driving the
            right-hand sides, relative seeds and distractor seeds.
        K_DIM: ``K_max`` for ``adaptive_spectral_fingerprint``.
        k_vote: Neighbours used by the top-k recommenders.
        N_REL: Number of same-family relatives per query.
        fairness_retries: Trap candidates tried until one passes
            ``_fair_trap_multi``. At least one candidate is always generated;
            if none passes, the last one is used and a warning is printed.
        use_cosine: Forwarded to ``recommend_from_db_topk_weighted``.
        weight_power: Forwarded to ``recommend_from_db_topk_weighted``.
        regret_threshold: Extra-iteration threshold for the probe-OFF fallback.
        fairness_tol: ``tol_fro`` passed to ``_fair_trap_multi``.
        probe_enable: Enable Probe-and-Switch.
        probe_iters: Probe budget for Probe-and-Switch.
        probe_tol: Probe tolerance for Probe-and-Switch.
        tol: Solver tolerance for the oracle and all solves.

    Returns:
        ``(df, grp)``. ``df`` has one row per query family; ``grp`` is the
        per-family summary (success rates, median/mean/p90 extra iterations).

    Side effects:
        Writes ``<stem>.csv`` to ``RESULTS_DIR`` and ``<stem>_summary.tex``
        via ``_latex_table``. Settings that differ from their defaults are
        appended to the stem and the LaTeX label. Default-configuration runs
        keep their historical names, while variants (e.g. probe OFF) no
        longer overwrite them.
    """
    _console_title("E6+: Adversarial Research — Diversified Traps & Queries (Auto-Solver)")
    cfg = {
        "seed": random_seed, "K_DIM": K_DIM, "k_vote": k_vote, "N_REL": N_REL,
        "weight_power": weight_power, "use_cosine": use_cosine, "regret_threshold": regret_threshold,
        "fairness_tol": fairness_tol, "probe_enable": probe_enable,
    }
    _console_subtitle("E6+ Config")
    for key, val in cfg.items():
        print(f"{key:<24}: {val}")
    print()

    rng = np.random.default_rng(random_seed)
    n = 80
    b_query = rng.random(n)
    b_db_setup = rng.random(n)

    # Query matrices use fixed seeds (101..104), independent of random_seed.
    queries = [
        ("Adjacency",  generate_adjacency_matrix(n, 5, 101)),
        ("Covariance", generate_covariance_matrix(n, 220, 5, 102)),
        ("Kernel",     generate_kernel_matrix(n, 3, 1.7, 0.05, 103)),
        ("GOE",        generate_goe_matrix(n, 104)),
    ]
    # Each query family is attacked with a trap seeded from a different family.
    distractor_gens = {
        "Adjacency":  ("Kernel",     lambda: generate_kernel_matrix(n, 3, 2.0, 0.01, rng.integers(1000))),
        "Covariance": ("GOE",        lambda: generate_goe_matrix(n, rng.integers(1000))),
        "Kernel":     ("Adjacency",  lambda: generate_adjacency_matrix(n, 5, _nx_seed(rng.integers(1000)))),
        "GOE":        ("Covariance", lambda: generate_covariance_matrix(n, 200, 4, rng.integers(1000))),
    }

    rows = []
    for fam, A_q in queries:
        _console_subtitle(f"[Query] {fam}")
        _, phi_q = adaptive_spectral_fingerprint(A_q, K_max=K_DIM)

        # (A) relatives
        rel_seeds = rng.integers(1000, 5000, size=int(N_REL))
        if fam == "Adjacency":
            rel_mats = [generate_adjacency_matrix(n, 5, s) for s in rel_seeds]
        elif fam == "Covariance":
            rel_mats = [generate_covariance_matrix(n, 240, 6, s) for s in rel_seeds]
        elif fam == "Kernel":
            rel_mats = [generate_kernel_matrix(n, 3, 1.6, 0.06, s) for s in rel_seeds]
        else:  # GOE
            rel_mats = [generate_goe_matrix(n, s) for s in rel_seeds]

        db_entries = []
        for j, A_rel in enumerate(rel_mats):
            best_rel, _, _ = _oracle_best_precon_auto(A_rel, b_db_setup, tol=tol)
            db_entries.append({
                "name": f"{fam}-Relative-{j}", "matrix": A_rel, "best_precon": best_rel, "type": "relative"
            })

        # (B) fairness-checked trap from the distractor family. At least one
        # candidate is generated, so A_trap is never None even when
        # fairness_retries <= 0.
        dis_name, dis_gen = distractor_gens[fam]
        A_trap, trap_is_fair = None, False
        for _ in range(max(1, int(fairness_retries))):
            A_trap = build_trap_blend(A_q, dis_gen())
            if _fair_trap_multi(A_q, rel_mats, A_trap, tol_fro=fairness_tol):
                trap_is_fair = True
                break
        if not trap_is_fair:
            # Best-effort fallback: keep the last candidate, but say so.
            print(f"[warn] {fam}: no trap candidate passed the fairness check "
                  f"(tol_fro={fairness_tol}); using the last candidate.")
        best_trap, _, _ = _oracle_best_precon_auto(A_trap, b_db_setup, tol=tol)
        db_entries.append({"name": f"{dis_name}-Trap", "matrix": A_trap, "best_precon": best_trap, "type": "trap"})

        # (C) fingerprints for DB
        for e in db_entries:
            _, e["phi"] = adaptive_spectral_fingerprint(e["matrix"], K_max=K_DIM)

        # (D) Oracle on query
        oracle_name, _, oracle_iters = _oracle_best_precon_auto(A_q, b_query, tol=tol)

        # (E) Recommendations
        rec_phy  = recommend_from_db_topk_weighted(phi_q, db_entries, k=k_vote,
                                                   use_cosine=use_cosine, weight_power=weight_power)
        iters_phy, _, ok_phy, _ = solve_iterative_auto(A_q, b_query, pre_type=rec_phy, tol=tol)

        d_fro_all = [np.linalg.norm(A_q - e['matrix'], 'fro') for e in db_entries]
        rec_fro1  = db_entries[int(np.argmin(d_fro_all))]['best_precon']
        iters_fro1, _, ok_fro1, _ = solve_iterative_auto(A_q, b_query, pre_type=rec_fro1, tol=tol)

        rec_froK  = recommend_from_db_topk_weighted_fro(A_q, db_entries, k=k_vote)
        iters_froK, _, ok_froK, _ = solve_iterative_auto(A_q, b_query, pre_type=rec_froK, tol=tol)

        iters_none, _, ok_none, _ = solve_iterative_auto(A_q, b_query, pre_type='None', tol=tol)

        # (F) Probe-and-Switch to avoid spectral traps
        if probe_enable:
            cand_list = [rec_phy, rec_froK, rec_fro1]
            rec_best = _probe_and_switch(A_q, b_query, cand_list, probe_tol=probe_tol, probe_iters=probe_iters)
            if rec_best != rec_phy:
                rec_phy = rec_best
                iters_phy, _, ok_phy, _ = solve_iterative_auto(A_q, b_query, pre_type=rec_phy, tol=tol)
        else:
            # Regret-based fallback (note: consults the oracle iteration count).
            if np.isfinite(oracle_iters) and (rec_phy != rec_froK) and ((iters_phy - oracle_iters) > regret_threshold):
                rec_phy = rec_froK
                iters_phy, _, ok_phy, _ = solve_iterative_auto(A_q, b_query, pre_type=rec_phy, tol=tol)

        def _is_oracle_pick(p):
            """1 if preconditioner ``p`` matches the oracle's choice, else 0."""
            return 1 if p == oracle_name else 0

        rows.append({
            "seed": random_seed, "query_family": fam,
            "oracle_precon": oracle_name, "iters_oracle": oracle_iters,
            "reco_phy": rec_phy,   "iters_phy": iters_phy,   "success_phy":   _is_oracle_pick(rec_phy)   if ok_phy   else 0, "extra_phy":   iters_phy - oracle_iters,
            "reco_fro1": rec_fro1, "iters_fro1": iters_fro1, "success_fro1":  _is_oracle_pick(rec_fro1)  if ok_fro1  else 0, "extra_fro1":  iters_fro1 - oracle_iters,
            "reco_froK": rec_froK, "iters_froK": iters_froK, "success_froK":  _is_oracle_pick(rec_froK)  if ok_froK  else 0, "extra_froK":  iters_froK - oracle_iters,
            "iters_none": iters_none
        })

    # results frame
    df = pd.DataFrame(rows)

    # Settings that change results but are not in the legacy stem. Only
    # non-default values are appended, so runs with default settings keep
    # their historical filenames/labels while variants no longer overwrite them.
    variant_tags = [
        ("probe-off",                  not probe_enable),
        ("cos-off",                    not use_cosine),
        (f"wpow-{weight_power:g}",     weight_power != 1.5),
        (f"ftol-{fairness_tol:g}",     fairness_tol != 0.02),
        (f"retries-{fairness_retries}", fairness_retries != 20),
        (f"rthr-{regret_threshold}",   regret_threshold != 40),
        (f"piters-{probe_iters}",      probe_iters != 30),
        (f"ptol-{probe_tol:g}",        probe_tol != 1e-4),
        (f"tol-{tol:g}",               tol != 1e-8),
    ]
    suffix = "".join(f"_{tag}" for tag, differs in variant_tags if differs)
    f_stem = f"E6_plus_seed-{random_seed}_KDIM-{K_DIM}_kvote-{k_vote}_NREL-{N_REL}{suffix}"
    os.makedirs(RESULTS_DIR, exist_ok=True)
    df.to_csv(os.path.join(RESULTS_DIR, f"{f_stem}.csv"), index=False)

    # summary per family: success rate, median / mean / p90 of extra iterations
    def _p90(s):
        x = pd.to_numeric(s, errors='coerce')
        x = x[np.isfinite(x)]  # p90 over finite values only
        return float(np.nanpercentile(x, 90)) if len(x) else np.nan

    grp = df.groupby("query_family", as_index=False).agg(
        phy_succ=("success_phy","mean"),
        fro1_succ=("success_fro1","mean"),
        froK_succ=("success_froK","mean"),
        phy_extra_med=("extra_phy","median"),
        fro1_extra_med=("extra_fro1","median"),
        froK_extra_med=("extra_froK","median"),
        phy_extra_mean=("extra_phy","mean"),
        fro1_extra_mean=("extra_fro1","mean"),
        froK_extra_mean=("extra_froK","mean"),
        phy_extra_p90=("extra_phy", _p90),
        fro1_extra_p90=("extra_fro1", _p90),
        froK_extra_p90=("extra_froK", _p90),
    )

    _console_table(grp, title=f"Summary for seed={random_seed}_KDIM={K_DIM}_kvote={k_vote}_NREL={N_REL}", round_ndigits=3)
    _latex_table(grp.round(3), f"{f_stem}_summary.tex",
                 caption=f"E6+ Summary (seed={random_seed}, K={K_DIM}, k={k_vote}, Nrel={N_REL})",
                 label=f"tab:e6_summary_{random_seed}{suffix}")

    print(f"\nE6+ (seed={random_seed}_KDIM={K_DIM}_kvote={k_vote}_NREL={N_REL}) completed.")
    return df, grp


# --- convenience: multi-seed runner + final aggregation/latex ---
def run_E6_batch_and_summarize(
    seeds=(2025, 2026, 2027, 2028, 2029),
    K_DIM=10, k_vote=5, N_REL=5,
    fairness_retries=20, use_cosine=True, weight_power=1.5,
    regret_threshold=40, fairness_tol=0.02,
    probe_enable=True, probe_iters=30, probe_tol=1e-4,
    tol=1e-8,
):
    """Run ``run_E6_adversarial_plus`` for several seeds and aggregate the results.

    All keyword arguments except ``seeds`` are forwarded unchanged to every
    per-seed run. Results from all seeds are pooled, and twelve overall
    metrics are reported for the ``phy``, ``fro1`` and ``froK`` recommenders:
    success rate, and median / mean / p90 of extra iterations over the oracle.
    As in the per-seed summary, the p90 uses finite values only.

    Returns:
        ``(df_all, final_summary)``, or ``None`` if no seed produced rows.

    Side effects:
        Seeds the global RNGs (``set_seed(42)``), creates the artifact
        directories, prints console tables and writes
        ``E6_plus_final_summary.tex``.
    """
    set_seed(42)
    _ensure_dirs()

    all_dfs = []
    for sd in seeds:
        df_seed, _ = run_E6_adversarial_plus(
            random_seed=sd, K_DIM=K_DIM, k_vote=k_vote, N_REL=N_REL,
            fairness_retries=fairness_retries, use_cosine=use_cosine, weight_power=weight_power,
            regret_threshold=regret_threshold, fairness_tol=fairness_tol,
            probe_enable=probe_enable, probe_iters=probe_iters, probe_tol=probe_tol,
            tol=tol,
        )
        if not df_seed.empty:
            all_dfs.append(df_seed)

    if not all_dfs:
        print("\nE6+ batch did not produce data.")
        return None

    df_all = pd.concat(all_dfs, ignore_index=True)
    _console_title("E6+ Final Aggregated Results (All Seeds)")

    def _finite_p90(col):
        """90th percentile of ``df_all[col]`` over finite values (NaN if none)."""
        vals = pd.to_numeric(df_all[col], errors="coerce").to_numpy(dtype=float)
        vals = vals[np.isfinite(vals)]
        return float(np.percentile(vals, 90)) if vals.size else np.nan

    # (display name, column suffix) per recommender.
    methods = (("Phy", "phy"), ("Fro-1NN", "fro1"), ("Fro-kNN", "froK"))
    # (metric name, column prefix, reducer). Outer loop over stats keeps the
    # metric order: all success rates, then medians, means and p90s.
    stats = (
        ("Success Rate",         "success", lambda c: float(df_all[c].mean())),
        ("Extra Iters (Median)", "extra",   lambda c: float(df_all[c].median())),
        ("Extra Iters (Mean)",   "extra",   lambda c: float(df_all[c].mean())),
        ("Extra Iters (p90)",    "extra",   _finite_p90),
    )
    final_metrics = {
        f"{label} {stat_name}": reduce_fn(f"{prefix}_{col_suffix}")
        for stat_name, prefix, reduce_fn in stats
        for label, col_suffix in methods
    }
    final_summary = pd.DataFrame({"Metric": list(final_metrics.keys()),
                                  "Value":  list(final_metrics.values())})

    _console_table(final_summary, title="Overall Performance Metrics", round_ndigits=2)
    _latex_table(final_summary.round(2), "E6_plus_final_summary.tex",
                 caption="Final aggregated results for the E6+ adversarial experiment across all seeds.",
                 label="tab:e6_final_summary")

    return df_all, final_summary

# ========= SuiteSparse download helpers =========
def _download_file(url, dest, tries=3, timeout=60.0):
    """Download ``url`` to ``dest``, trying up to ``tries`` times; return True on success.

    This is a pure-Python replacement for the IPython-only ``!wget --tries=3``
    shell escape, so the pipeline also runs as a plain ``python`` script.
    """
    import http.client
    import shutil
    import urllib.request

    for attempt in range(1, int(tries) + 1):
        try:
            with urllib.request.urlopen(url, timeout=timeout) as resp, open(dest, "wb") as fh:
                shutil.copyfileobj(resp, fh)
            return True
        except (OSError, http.client.HTTPException) as err:
            # urllib.error.URLError/HTTPError and socket timeouts are OSError subclasses.
            print(f"    attempt {attempt}/{tries} failed: {err}")
    return False


def _extract_tar(tar_path, dest_dir):
    """Extract the tarball ``tar_path`` into ``dest_dir``; return True on success."""
    import tarfile

    try:
        with tarfile.open(tar_path, "r:*") as tar:
            if hasattr(tarfile, "data_filter"):
                # PEP 706 filter: refuse absolute paths and links escaping dest_dir.
                tar.extractall(dest_dir, filter="data")
            else:
                tar.extractall(dest_dir)
        return True
    except (tarfile.TarError, OSError) as err:
        print(f"!!! ERROR: could not extract {tar_path}: {err}")
        return False


def _fetch_suitesparse_matrices(matrices, data_dir):
    """Make each SuiteSparse matrix available as ``<data_dir>/<name>/<name>.mtx``.

    Args:
        matrices: Mapping ``"Group/Name" -> label``.
        data_dir: Directory that holds the extracted matrices.

    Returns:
        ``(paths, labels)`` for the matrices that are available locally, in
        mapping order. Matrices that cannot be fetched are reported and skipped.
    """
    os.makedirs(data_dir, exist_ok=True)
    paths, labels = [], []

    print("Downloading and extracting SuiteSparse matrices (robust mode)...")
    for path, label in matrices.items():
        group, name = path.split('/', 1)
        mtx_path = os.path.join(data_dir, name, f"{name}.mtx")

        if os.path.exists(mtx_path):
            print(f" -> Found existing {name}. Skipping download.")
            paths.append(mtx_path)
            labels.append(label)
            continue

        url = f"https://sparse.tamu.edu/MM/{group}/{name}.tar.gz"
        tar_path = os.path.join(data_dir, f"{name}.tar.gz")

        print(f" -> Downloading {name} from {url}")
        try:
            downloaded = _download_file(url, tar_path, tries=3)
            # Tiny files are error pages or truncated downloads, not archives.
            if downloaded and os.path.exists(tar_path) and os.path.getsize(tar_path) > 1024:
                print(f" -> Download successful. Extracting {name}...")
                _extract_tar(tar_path, data_dir)
                if os.path.exists(mtx_path):
                    paths.append(mtx_path)
                    labels.append(label)
                else:
                    print(f"!!! ERROR: Extraction failed for {name}, .mtx file not found.")
            else:
                print(f"!!! ERROR: Download failed for {name}. The file is empty or missing. Skipping.")
        finally:
            # Never leave (partial) tarballs behind.
            if os.path.exists(tar_path):
                os.remove(tar_path)

    return paths, labels


# ========= Main =========
def main():
    """Run experiments E0–E5; E3 runs only if enough SuiteSparse matrices are available."""
    print("Starting the main experimental pipeline (E0–E5)")
    run_E0_invariance_and_scaling(n_trials=64, n=100, seed=7)
    run_E1_four_family_Ksweep()
    run_E2_five_family_ba_vs_er()
    run_E2b_strong_baselines_all()
    summarize_Kstar_and_reco_default()
    # =========================================================================
    # E3: SuiteSparse Real Data Mini-Benchmark (Robust Version)
    # =========================================================================
    _console_title("E3: SuiteSparse Real Data Mini-Benchmark")

    suitesparse_matrices = {
        "HB/bcsstk01": "Structural",
        "HB/bcsstk06": "Structural",
        "HB/gr_30_30": "Graph",        # was Grund/gr_30_30 (wrong)
        "AG-Monien/netz4504": "Graph"  # was GHS_psdef/netz4504 (wrong)
    }

    data_dir = os.path.join(ARTIFACT_ROOT, "suitesparse_data")
    suitesparse_paths, suitesparse_labels = _fetch_suitesparse_matrices(suitesparse_matrices, data_dir)

    print(f"\nSuiteSparse data is ready. Proceeding with {len(suitesparse_paths)} matrices.\n")

    # Clustering needs at least 3 matrices spread over at least 2 label groups.
    if len(suitesparse_paths) >= 3 and len(set(suitesparse_labels)) >= 2:
        run_E3_suitesparse_realdata(mtx_paths=suitesparse_paths, labels=suitesparse_labels)
    else:
        print("E3 skipped: need at least 3 matrices and >=2 distinct labels "
              f"(got {len(suitesparse_paths)} matrices, {len(set(suitesparse_labels))} label groups).")
    run_E4_hutchinson_pvectors()
    run_E5_noise_stability_loglog()
    print("\nExperimental pipeline (E0–E5) completed successfully.")

if __name__ == "__main__":
    # Reproducible global RNG state and artifact folders before anything runs.
    set_seed(42)
    _ensure_dirs()

    # Stage 1: experiments E0–E5.
    main()

    # Stage 2: E6+ adversarial study with the headline configuration.
    e6_sizes = {"K_DIM": 10, "k_vote": 5, "N_REL": 5}
    e6_settings = {
        "use_cosine": True,
        "weight_power": 1.5,
        "regret_threshold": 40,
        "fairness_tol": 0.02,
        "probe_enable": True,
        "probe_iters": 30,
        "probe_tol": 1e-4,
    }
    run_E6_batch_and_summarize(
        seeds=(2025, 2026, 2027, 2028, 2029),
        **e6_sizes,
        **e6_settings,
    )
    plot_E6_flow_diagram()
    run_E6_probe_ablation_once(seed=2025, **e6_sizes)

    # Stage 3: probe ablation over a small single-seed hyper-parameter grid.
    grid_axes = {
        "K_DIM_list": (3, 5),
        "k_vote_list": (1, 5),
        "weight_power_list": (0.0, 1.5),
        "fairness_tol_list": (0.02, 0.10),
        "use_cosine_list": (True, False),
    }
    run_E6_probe_ablation_grid(seeds=(2025,), **grid_axes)



E6+: Adversarial Research — Diversified Traps & Queries (Auto-Solver)

E6+ Config
----------
seed                    : 2025
K_DIM                   : 10
k_vote                  : 5
N_REL                   : 5
weight_power            : 1.5
use_cosine              : True
regret_threshold        : 40


[Query] Adjacency
-----------------

[Query] Covariance
------------------

[Query] Kernel
--------------

[Query] GOE
-----------

Summary for seed=2025_KDIM=10_kvote=5_NREL=5
--------------------------------------------
query_family  phy_succ  fro1_succ  froK_succ  phy_extra_med  fro1_extra_med  froK_extra_med  phy_extra_mean  fro1_extra_mean  froK_extra_mean  phy_extra_p90  fro1_extra_p90  froK_extra_p90
   Adjacency       1.0        0.0        1.0            0.0            46.0             0.0             0.0             46.0              0.0            0.0            46.0             0.0
  Covariance       1.0        0.0        1.0            0.0            34.0             0.0       

KeyboardInterrupt: 