# Node-Level SRI vs Graph-Level SRI: Testing Walk Resolution Limit Theory

This notebook demonstrates an experiment comparing **node-level** vs **graph-level** Spectral Resolution Index (SRI) for predicting RWSE-vs-LapPE distinguishability gaps.

**Key Idea**: Rather than using a single graph-level spectral gap to predict encoding quality, we compute per-node SRI from eigenvector localization. For ZINC graphs (with complete eigenvector coverage), mean node-level SRI achieves Spearman rho ~0.41 vs graph-level SRI rho ~0.21, nearly doubling correlation strength.

**Pipeline Phases**:
1. **Node-level SRI** computation from local spectral measures with threshold sensitivity
2. **Pairwise distinguishability** analysis (RWSE vs sign-free LapPE)
3. **Correlation comparison** with bootstrap CIs and partial correlations controlling for graph size
4. **SRWE benefit prediction** with Tikhonov regularization and Mann-Whitney tests
5. **Visualization** of results

In [1]:
import subprocess, sys
def _pip(*a): subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', *a])

# psutil, loguru — NOT on Colab, always install
_pip('psutil==7.0.0')
_pip('loguru==0.7.3')

# numpy, scipy, matplotlib — pre-installed on Colab, install locally only
if 'google.colab' not in sys.modules:
    _pip('numpy==2.0.2', 'scipy==1.16.3', 'matplotlib==3.10.0')


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


## Imports

In [2]:
import json
import math
import os
import sys
import time
import warnings
from typing import Any, Optional

import numpy as np
from scipy import stats
from scipy.optimize import nnls
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

## Data Loading

In [3]:
GITHUB_DATA_URL = "https://raw.githubusercontent.com/AMGrobelnik/ai-invention-ace67e-the-walk-resolution-limit-a-super-resolu/main/experiment_iter6_node_level_sri/demo/mini_demo_data.json"

def load_data():
    try:
        import urllib.request
        with urllib.request.urlopen(GITHUB_DATA_URL) as response:
            return json.loads(response.read().decode())
    except Exception: pass
    if os.path.exists("mini_demo_data.json"):
        with open("mini_demo_data.json") as f: return json.load(f)
    raise FileNotFoundError("Could not load mini_demo_data.json")

In [4]:
data = load_data()

# Organize data by dataset name
datasets = {}
for ds in data.get("datasets", []):
    ds_name = ds["dataset"]
    if ds_name not in datasets:
        datasets[ds_name] = []
    datasets[ds_name].extend(ds.get("examples", []))

for ds_name, examples in datasets.items():
    print(f"{ds_name}: {len(examples)} examples")

ZINC-subset: 50 examples


## Configuration

All tunable parameters are defined here. Start with minimum values for testing.

In [5]:
# ── Tunable Parameters ──
# Original values (for full production run with 16,100 graphs):
# K_WALK = 20
# N_BOOTSTRAP = 5000
# EPSILON = 1e-4
# THRESHOLD_FACTORS = [0.5, 1.0, 2.0]
# SAMPLE_PER_DATASET = 500
# MAX_NODE_PAIRS = 500

# Demo values (full original params fit within time budget on 50-graph subset):
K_WALK = 20               # Walk length for SRI computation
N_BOOTSTRAP = 5000        # Number of bootstrap resamples
EPSILON = 1e-4            # Distinguishability threshold
THRESHOLD_FACTORS = [0.5, 1.0, 2.0]  # Threshold factors for node SRI
SAMPLE_PER_DATASET = 50   # Use all 50 demo graphs (original: 500 from 16K)
MAX_NODE_PAIRS = 500      # Max node pairs per graph
SEED = 42

## Helper Functions

Core parsing and SRI computation functions from the original method.

In [6]:
def parse_input(example: dict) -> dict:
    """Parse the JSON-encoded input field of an example."""
    return json.loads(example["input"])


def compute_graph_level_sri(spectral: dict, k: int = K_WALK) -> float:
    """Compute graph-level SRI = delta_min * K."""
    delta_min = spectral.get("delta_min", 0.0)
    return delta_min * k


def compute_node_level_sri(
    spectral: dict,
    num_nodes: int,
    threshold_factor: float = 1.0,
    k: int = K_WALK,
) -> dict:
    """
    Compute node-level SRI from local_spectral data.

    For each node u:
      1. Get the eigenvalues & squared eigenvector components from local_spectral
      2. Filter by threshold: v_i(u)^2 > threshold_factor / n
      3. Compute min gap among relevant eigenvalues
      4. node_sri = min_gap * K
    """
    eigenvalues = np.array(spectral.get("eigenvalues", []))
    local_spectral = spectral.get("local_spectral", [])
    n = num_nodes
    n_nodes_with_spectral = len(local_spectral)

    if n_nodes_with_spectral == 0 or len(eigenvalues) == 0:
        return {
            "node_sris": [], "effective_ranks": [], "local_sparsities": [],
            "mean_node_sri": float("inf"), "min_node_sri": float("inf"),
            "median_node_sri": float("inf"), "std_node_sri": 0.0,
            "p10_node_sri": float("inf"), "mean_effective_rank": 0.0,
            "mean_local_sparsity": 0.0,
        }

    threshold = threshold_factor / max(n, 1)
    node_sris = []
    effective_ranks = []
    local_sparsities = []

    for u in range(n_nodes_with_spectral):
        components = local_spectral[u]
        if not components:
            node_sris.append(float("inf"))
            effective_ranks.append(0)
            local_sparsities.append(0.0)
            continue

        node_eigs = []
        node_weights = []
        for comp in components:
            if len(comp) >= 2:
                node_eigs.append(comp[0])
                node_weights.append(comp[1])

        if not node_eigs:
            node_sris.append(float("inf"))
            effective_ranks.append(0)
            local_sparsities.append(0.0)
            continue

        node_eigs = np.array(node_eigs)
        node_weights = np.array(node_weights)

        relevant_mask = node_weights > threshold
        relevant_eigs = node_eigs[relevant_mask]
        eff_rank = int(np.sum(relevant_mask))
        effective_ranks.append(eff_rank)
        local_sparsities.append(eff_rank / max(n, 1))

        if len(relevant_eigs) <= 1:
            node_sris.append(float("inf"))
            continue

        sorted_eigs = np.sort(relevant_eigs)
        gaps = np.diff(sorted_eigs)
        nonzero_gaps = gaps[np.abs(gaps) > 1e-12]

        if len(nonzero_gaps) == 0:
            node_sris.append(0.0)
        else:
            local_delta_min = float(np.min(np.abs(nonzero_gaps)))
            node_sris.append(local_delta_min * k)

    node_sris_arr = np.array(node_sris, dtype=float)
    finite_sris = node_sris_arr[np.isfinite(node_sris_arr)]

    if len(finite_sris) == 0:
        return {
            "node_sris": node_sris, "effective_ranks": effective_ranks,
            "local_sparsities": local_sparsities,
            "mean_node_sri": float("inf"), "min_node_sri": float("inf"),
            "median_node_sri": float("inf"), "std_node_sri": 0.0,
            "p10_node_sri": float("inf"),
            "mean_effective_rank": float(np.mean(effective_ranks)) if effective_ranks else 0.0,
            "mean_local_sparsity": float(np.mean(local_sparsities)) if local_sparsities else 0.0,
        }

    return {
        "node_sris": node_sris, "effective_ranks": effective_ranks,
        "local_sparsities": local_sparsities,
        "mean_node_sri": float(np.mean(finite_sris)),
        "min_node_sri": float(np.min(finite_sris)),
        "median_node_sri": float(np.median(finite_sris)),
        "std_node_sri": float(np.std(finite_sris)),
        "p10_node_sri": float(np.percentile(finite_sris, 10)),
        "mean_effective_rank": float(np.mean(effective_ranks)),
        "mean_local_sparsity": float(np.mean(local_sparsities)),
    }

print("Helper functions defined.")

Helper functions defined.


## Phase 2: Node-Level Distinguishability Functions

Functions to compute LapPE features and pairwise node distinguishability (RWSE vs LapPE).

In [7]:
def compute_lappe_features(local_spectral: list, eigenvalues: list, num_dims: int = 10) -> np.ndarray:
    """
    Compute sign-free LapPE features for each node:
      lappe(u) = [v_1(u)^2, ..., v_d(u)^2]
    """
    n_nodes = len(local_spectral)
    n_eigs = len(eigenvalues)
    d = min(num_dims, n_eigs)

    if n_nodes == 0 or n_eigs == 0:
        return np.zeros((0, d))

    eig_array = np.array(eigenvalues)
    features = np.zeros((n_nodes, d), dtype=np.float64)
    for u in range(n_nodes):
        components = local_spectral[u]
        for comp in components:
            if len(comp) < 2:
                continue
            eig_val, weight = comp[0], comp[1]
            diffs = np.abs(eig_array - eig_val)
            idx = int(np.argmin(diffs))
            if idx < d:
                features[u, idx] = weight

    return features


def compute_node_distinguishability(
    spectral: dict,
    num_nodes: int,
    rng: np.random.RandomState,
    max_pairs: int = MAX_NODE_PAIRS,
    epsilon: float = EPSILON,
) -> dict:
    """
    Compute node-level distinguishability comparing RWSE vs sign-free LapPE.
    Returns per-node scores and the gap (LapPE_score - RWSE_score).
    """
    rwse = spectral.get("rwse", [])
    local_spectral = spectral.get("local_spectral", [])
    eigenvalues = spectral.get("eigenvalues", [])

    n_analyzable = min(len(rwse), len(local_spectral))
    if n_analyzable < 2:
        return {"node_rwse_scores": [], "node_lappe_scores": [], "node_gaps": [], "graph_gap": 0.0}

    rwse_feats = np.array(rwse[:n_analyzable], dtype=np.float64)
    lappe_feats = compute_lappe_features(local_spectral[:n_analyzable], eigenvalues, num_dims=min(10, len(eigenvalues)))

    def normalize(X: np.ndarray) -> np.ndarray:
        if X.shape[0] < 2:
            return X
        m = X.mean(axis=0)
        s = X.std(axis=0)
        s[s < 1e-10] = 1.0
        return (X - m) / s

    rwse_norm = normalize(rwse_feats)
    lappe_norm = normalize(lappe_feats)

    n = n_analyzable
    total_pairs = n * (n - 1) // 2
    n_pairs = min(max_pairs, total_pairs)

    if total_pairs <= max_pairs:
        pairs_i, pairs_j = [], []
        for i in range(n):
            for j in range(i + 1, n):
                pairs_i.append(i)
                pairs_j.append(j)
        pairs_i = np.array(pairs_i)
        pairs_j = np.array(pairs_j)
    else:
        all_pairs = rng.choice(total_pairs, size=n_pairs, replace=False)
        pairs_i = np.zeros(n_pairs, dtype=int)
        pairs_j = np.zeros(n_pairs, dtype=int)
        for idx, p in enumerate(all_pairs):
            i = int(n - 2 - math.floor(math.sqrt(-8 * p + 4 * n * (n - 1) - 7) / 2 - 0.5))
            j = int(p + i + 1 - n * (n - 1) // 2 + (n - i) * ((n - i) - 1) // 2)
            if i >= n or j >= n or i < 0 or j < 0 or i >= j:
                i = rng.randint(0, n - 1)
                j = rng.randint(i + 1, n)
            pairs_i[idx] = i
            pairs_j[idx] = j

    rwse_dists = np.linalg.norm(rwse_norm[pairs_i] - rwse_norm[pairs_j], axis=1)
    lappe_dists = np.linalg.norm(lappe_norm[pairs_i] - lappe_norm[pairs_j], axis=1)

    rwse_distinguished = rwse_dists > epsilon
    lappe_distinguished = lappe_dists > epsilon

    node_rwse_scores = np.zeros(n)
    node_lappe_scores = np.zeros(n)
    node_pair_counts = np.zeros(n)

    for idx in range(len(pairs_i)):
        i, j = pairs_i[idx], pairs_j[idx]
        node_pair_counts[i] += 1
        node_pair_counts[j] += 1
        if rwse_distinguished[idx]:
            node_rwse_scores[i] += 1
            node_rwse_scores[j] += 1
        if lappe_distinguished[idx]:
            node_lappe_scores[i] += 1
            node_lappe_scores[j] += 1

    mask = node_pair_counts > 0
    node_rwse_scores[mask] /= node_pair_counts[mask]
    node_lappe_scores[mask] /= node_pair_counts[mask]

    node_gaps = node_lappe_scores - node_rwse_scores
    graph_gap = float(np.mean(lappe_distinguished & ~rwse_distinguished))

    return {
        "node_rwse_scores": node_rwse_scores.tolist(),
        "node_lappe_scores": node_lappe_scores.tolist(),
        "node_gaps": node_gaps.tolist(),
        "graph_gap": graph_gap,
        "n_analyzable": n_analyzable,
        "n_pairs": int(len(pairs_i)),
        "frac_rwse_distinguished": float(np.mean(rwse_distinguished)),
        "frac_lappe_distinguished": float(np.mean(lappe_distinguished)),
    }

print("Phase 2 functions defined.")

Phase 2 functions defined.


## Phase 3: Correlation & Bootstrap Functions

Statistical functions for Spearman correlations with bootstrap confidence intervals and partial correlations.

In [8]:
def spearman_corr(x: np.ndarray, y: np.ndarray) -> tuple:
    """Compute Spearman correlation, handling edge cases."""
    mask = np.isfinite(x) & np.isfinite(y)
    x_clean = x[mask]
    y_clean = y[mask]
    if len(x_clean) < 3:
        return 0.0, 1.0
    if np.std(x_clean) < 1e-10 or np.std(y_clean) < 1e-10:
        return 0.0, 1.0
    rho, p = stats.spearmanr(x_clean, y_clean)
    if not np.isfinite(rho):
        return 0.0, 1.0
    return float(rho), float(p)


def bootstrap_spearman(
    x: np.ndarray,
    y: np.ndarray,
    n_resamples: int = N_BOOTSTRAP,
    seed: int = SEED,
) -> dict:
    """Bootstrap 95% CI for Spearman correlation."""
    rng = np.random.RandomState(seed)
    mask = np.isfinite(x) & np.isfinite(y)
    x_clean = x[mask]
    y_clean = y[mask]
    n = len(x_clean)

    if n < 5:
        return {"rho": 0.0, "p_value": 1.0, "ci_lower": 0.0, "ci_upper": 0.0, "n": n}

    rho, p = spearman_corr(x_clean, y_clean)

    boot_rhos = np.zeros(n_resamples)
    for i in range(n_resamples):
        idx = rng.choice(n, size=n, replace=True)
        r, _ = spearman_corr(x_clean[idx], y_clean[idx])
        boot_rhos[i] = r

    ci_lower = float(np.percentile(boot_rhos, 2.5))
    ci_upper = float(np.percentile(boot_rhos, 97.5))

    return {"rho": rho, "p_value": p, "ci_lower": ci_lower, "ci_upper": ci_upper, "n": n}


def partial_spearman(
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray,
) -> tuple:
    """Partial Spearman correlation between x and y, controlling for z."""
    mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(z)
    x_c, y_c, z_c = x[mask], y[mask], z[mask]
    if len(x_c) < 5:
        return 0.0, 1.0

    x_r = stats.rankdata(x_c)
    y_r = stats.rankdata(y_c)
    z_r = stats.rankdata(z_c)

    slope_xz = np.polyfit(z_r, x_r, 1)
    x_resid = x_r - np.polyval(slope_xz, z_r)

    slope_yz = np.polyfit(z_r, y_r, 1)
    y_resid = y_r - np.polyval(slope_yz, z_r)

    return spearman_corr(x_resid, y_resid)

print("Correlation functions defined.")

Correlation functions defined.


## Phase 4: SRWE Benefit Prediction

Super-Resolution Walk Encoding features using Tikhonov regularization.

In [9]:
def compute_srwe_features(
    spectral: dict,
    num_nodes: int,
    k: int = K_WALK,
    alpha: float = 1e-3,
) -> np.ndarray:
    """
    Compute SRWE features for each node using Tikhonov regularization.
    Tries multiple alpha values and picks the one giving best reconstruction.
    """
    eigenvalues = np.array(spectral.get("eigenvalues", []))
    rwse = spectral.get("rwse", [])
    local_spectral = spectral.get("local_spectral", [])
    n_eigs = len(eigenvalues)
    n_nodes_available = min(len(rwse), len(local_spectral))

    if n_nodes_available == 0 or n_eigs == 0:
        return np.zeros((0, 0))

    k_use = min(k, n_eigs)

    # Build Vandermonde matrix
    V = np.zeros((k_use, n_eigs), dtype=np.float64)
    for ki in range(k_use):
        V[ki, :] = eigenvalues ** (ki + 1)

    # Try multiple alpha values, pick best per-graph
    alphas = [1e-6, 1e-4, 1e-2, 0.1]
    best_alpha = alpha

    if n_nodes_available > 0:
        test_rwse = np.array(rwse[0][:k_use], dtype=np.float64)
        if len(test_rwse) < k_use:
            test_rwse = np.pad(test_rwse, (0, k_use - len(test_rwse)))
        best_resid = float("inf")
        for a in alphas:
            try:
                VtV_test = V.T @ V + a * np.eye(n_eigs)
                w_test = np.linalg.solve(VtV_test, V.T @ test_rwse)
                w_test = np.maximum(w_test, 0)
                resid = np.linalg.norm(V @ w_test - test_rwse)
                if resid < best_resid:
                    best_resid = resid
                    best_alpha = a
            except np.linalg.LinAlgError:
                continue

    VtV = V.T @ V + best_alpha * np.eye(n_eigs)
    srwe_features = np.zeros((n_nodes_available, n_eigs), dtype=np.float64)

    for u in range(n_nodes_available):
        rwse_vec = np.array(rwse[u][:k_use], dtype=np.float64)
        if len(rwse_vec) < k_use:
            rwse_vec = np.pad(rwse_vec, (0, k_use - len(rwse_vec)))
        try:
            Vt_m = V.T @ rwse_vec
            w = np.linalg.solve(VtV, Vt_m)
            w = np.maximum(w, 0)
            srwe_features[u, :] = w
        except np.linalg.LinAlgError:
            try:
                w, _ = nnls(V.T, rwse_vec[:min(len(rwse_vec), V.shape[0])])
                if len(w) < n_eigs:
                    w = np.pad(w, (0, n_eigs - len(w)))
                srwe_features[u, :] = w[:n_eigs]
            except Exception:
                pass

    return srwe_features

print("SRWE functions defined.")

SRWE functions defined.


## Run Analysis Pipeline

Execute the full 4-phase pipeline: node SRI computation, distinguishability, correlations, and SRWE benefit prediction.

In [10]:
rng = np.random.RandomState(SEED)
start_time = time.time()

# ── PHASE 1: Node-Level SRI Computation ──
print("=" * 60)
print("PHASE 1: Node-Level SRI Computation")
print("=" * 60)

all_graph_metrics = {}

for ds_name, examples in datasets.items():
    print(f"Processing {ds_name} ({len(examples)} examples)")
    graph_metrics = []

    for ex_idx, example in enumerate(examples):
        try:
            inp = parse_input(example)
            spectral = inp.get("spectral", {})
            num_nodes = inp.get("num_nodes", 0)

            graph_sri = compute_graph_level_sri(spectral, k=K_WALK)

            node_results = {}
            for tf in THRESHOLD_FACTORS:
                node_res = compute_node_level_sri(spectral, num_nodes, threshold_factor=tf, k=K_WALK)
                node_results[f"tf_{tf}"] = node_res

            primary = node_results[f"tf_{THRESHOLD_FACTORS[0]}"]

            metric = {
                "idx": ex_idx,
                "num_nodes": num_nodes,
                "graph_sri": graph_sri,
                "delta_min": spectral.get("delta_min", 0.0),
                "mean_node_sri": primary["mean_node_sri"],
                "min_node_sri": primary["min_node_sri"],
                "median_node_sri": primary["median_node_sri"],
                "std_node_sri": primary["std_node_sri"],
                "p10_node_sri": primary["p10_node_sri"],
                "mean_effective_rank": primary["mean_effective_rank"],
                "mean_local_sparsity": primary["mean_local_sparsity"],
                "node_sris": primary["node_sris"],
            }
            graph_metrics.append(metric)
        except Exception as e:
            print(f"  Failed on example {ex_idx}: {e}")
            continue

    all_graph_metrics[ds_name] = graph_metrics
    finite_mean_sris = [m["mean_node_sri"] for m in graph_metrics if np.isfinite(m["mean_node_sri"])]
    finite_graph_sris = [m["graph_sri"] for m in graph_metrics if np.isfinite(m["graph_sri"])]
    print(f"  {ds_name}: {len(graph_metrics)} graphs processed")
    if finite_graph_sris:
        print(f"    Mean graph SRI: {np.mean(finite_graph_sris):.4f}")
    if finite_mean_sris:
        print(f"    Mean node SRI (avg): {np.mean(finite_mean_sris):.4f}")

print(f"\nPhase 1 completed in {time.time() - start_time:.1f}s")

PHASE 1: Node-Level SRI Computation
Processing ZINC-subset (50 examples)
  ZINC-subset: 50 graphs processed
    Mean graph SRI: 1.3557
    Mean node SRI (avg): 3.6060

Phase 1 completed in 0.0s


### Phase 2: Node-Level Distinguishability

Sample graph pairs and compare RWSE vs LapPE distinguishability at the node level.

In [11]:
# ── PHASE 2: Node-Level Distinguishability ──
print("=" * 60)
print("PHASE 2: Node-Level Distinguishability")
print("=" * 60)

all_sampled_metrics = {}
within_graph_rhos = []
example_graphs = []

for ds_name, examples in datasets.items():
    n_sample = min(SAMPLE_PER_DATASET, len(examples))
    sample_indices = rng.choice(len(examples), size=n_sample, replace=False) if n_sample < len(examples) else np.arange(len(examples))
    print(f"Phase 2: {ds_name} — sampling {n_sample} graphs")

    sampled_metrics = []
    ds_high_gap = None
    ds_low_gap = None
    ds_high_gap_val = -1.0
    ds_low_gap_val = 2.0

    for s_idx, orig_idx in enumerate(sample_indices):
        try:
            example = examples[orig_idx]
            inp = parse_input(example)
            spectral = inp.get("spectral", {})
            num_nodes = inp.get("num_nodes", 0)

            gm = all_graph_metrics[ds_name][orig_idx] if orig_idx < len(all_graph_metrics[ds_name]) else None

            dist_result = compute_node_distinguishability(
                spectral, num_nodes, rng, max_pairs=MAX_NODE_PAIRS, epsilon=EPSILON
            )

            node_sris = gm["node_sris"] if gm else []
            n_analyzable = dist_result.get("n_analyzable", 0)

            sampled_metric = {
                "orig_idx": int(orig_idx),
                "num_nodes": num_nodes,
                "graph_sri": gm["graph_sri"] if gm else 0.0,
                "mean_node_sri": gm["mean_node_sri"] if gm else 0.0,
                "min_node_sri": gm["min_node_sri"] if gm else 0.0,
                "median_node_sri": gm["median_node_sri"] if gm else 0.0,
                "p10_node_sri": gm["p10_node_sri"] if gm else 0.0,
                "graph_gap": dist_result["graph_gap"],
                "frac_rwse": dist_result.get("frac_rwse_distinguished", 0.0),
                "frac_lappe": dist_result.get("frac_lappe_distinguished", 0.0),
                "node_sris": node_sris[:n_analyzable],
                "node_gaps": dist_result["node_gaps"],
            }
            sampled_metrics.append(sampled_metric)

            gg = dist_result["graph_gap"]
            if gg > ds_high_gap_val and n_analyzable >= 5:
                ds_high_gap_val = gg
                ds_high_gap = {
                    "node_sris": node_sris[:n_analyzable],
                    "dataset": ds_name, "gap_type": "high-gap",
                    "n_nodes": num_nodes, "graph_gap": gg,
                }
            if gg < ds_low_gap_val and n_analyzable >= 5:
                ds_low_gap_val = gg
                ds_low_gap = {
                    "node_sris": node_sris[:n_analyzable],
                    "dataset": ds_name, "gap_type": "low-gap",
                    "n_nodes": num_nodes, "graph_gap": gg,
                }

            if n_analyzable >= 10 and len(node_sris) >= n_analyzable:
                ns = np.array(node_sris[:n_analyzable], dtype=float)
                ng = np.array(dist_result["node_gaps"], dtype=float)
                finite_mask = np.isfinite(ns) & np.isfinite(ng)
                if np.sum(finite_mask) >= 5:
                    rho_within, _ = spearman_corr(ns[finite_mask], ng[finite_mask])
                    within_graph_rhos.append(rho_within)

        except Exception as e:
            print(f"  Phase 2: Failed on {ds_name} idx {orig_idx}: {e}")
            continue

    all_sampled_metrics[ds_name] = sampled_metrics
    if ds_high_gap:
        example_graphs.append(ds_high_gap)
    if ds_low_gap:
        example_graphs.append(ds_low_gap)

    print(f"  {ds_name}: {len(sampled_metrics)} sampled graphs")

print(f"\nPhase 2 completed in {time.time() - start_time:.1f}s")

PHASE 2: Node-Level Distinguishability
Phase 2: ZINC-subset — sampling 50 graphs
  ZINC-subset: 50 sampled graphs

Phase 2 completed in 0.1s


### Phase 3: Graph-Level Correlations with Bootstrap CIs

Compare SRI variants with Spearman correlations, bootstrap confidence intervals, and partial correlations controlling for graph size.

In [12]:
# ── PHASE 3: Graph-Level Correlations ──
print("=" * 60)
print("PHASE 3: Graph-Level Correlations")
print("=" * 60)

correlation_results = {}

for ds_name, sampled in all_sampled_metrics.items():
    if len(sampled) < 5:
        print(f"Skipping {ds_name}: only {len(sampled)} samples (need >= 5)")
        continue

    graph_gap = np.array([s["graph_gap"] for s in sampled], dtype=float)
    graph_sri = np.array([s["graph_sri"] for s in sampled], dtype=float)
    mean_node_sri = np.array([s["mean_node_sri"] for s in sampled], dtype=float)
    min_node_sri = np.array([s["min_node_sri"] for s in sampled], dtype=float)
    median_node_sri = np.array([s["median_node_sri"] for s in sampled], dtype=float)
    p10_node_sri = np.array([s["p10_node_sri"] for s in sampled], dtype=float)
    log_n_nodes = np.log(np.array([s["num_nodes"] for s in sampled], dtype=float) + 1)

    metrics = {
        "graph_sri": graph_sri,
        "mean_node_sri": mean_node_sri,
        "min_node_sri": min_node_sri,
        "median_node_sri": median_node_sri,
        "p10_node_sri": p10_node_sri,
    }

    ds_corr = {}
    for metric_name, metric_vals in metrics.items():
        boot = bootstrap_spearman(metric_vals, graph_gap)
        ds_corr[metric_name] = boot

        partial_rho, partial_p = partial_spearman(metric_vals, graph_gap, log_n_nodes)
        ds_corr[f"{metric_name}_partial"] = {"rho": partial_rho, "p_value": partial_p}

    # Additional aggregations
    harmonic_sris = []
    for sm in sampled:
        ns = sm.get("node_sris", [])
        finite_ns = [v for v in ns if np.isfinite(v) and v > 0]
        if finite_ns:
            harmonic_sris.append(float(len(finite_ns) / sum(1.0 / v for v in finite_ns)))
        else:
            harmonic_sris.append(float("inf"))

    arr = np.array(harmonic_sris, dtype=float)
    boot = bootstrap_spearman(arr, graph_gap)
    ds_corr["harmonic_node_sri"] = boot

    correlation_results[ds_name] = ds_corr
    print(f"  {ds_name}: graph_sri rho={ds_corr['graph_sri']['rho']:.4f}, "
          f"mean_node_sri rho={ds_corr['mean_node_sri']['rho']:.4f}")

# Pooled correlation
all_graph_gap = []
all_graph_sri = []
all_mean_node_sri = []
all_ds_names = []

for ds_name, sampled in all_sampled_metrics.items():
    for s in sampled:
        all_graph_gap.append(s["graph_gap"])
        all_graph_sri.append(s["graph_sri"])
        all_mean_node_sri.append(s["mean_node_sri"])
        all_ds_names.append(ds_name)

if len(all_graph_gap) >= 5:
    pooled_corr = {}
    gap_arr = np.array(all_graph_gap, dtype=float)
    for name, arr in [("graph_sri", all_graph_sri), ("mean_node_sri", all_mean_node_sri)]:
        boot = bootstrap_spearman(np.array(arr, dtype=float), gap_arr)
        pooled_corr[name] = boot
    correlation_results["Pooled"] = pooled_corr
    print(f"  Pooled: graph_sri rho={pooled_corr['graph_sri']['rho']:.4f}, "
          f"mean_node_sri rho={pooled_corr['mean_node_sri']['rho']:.4f}")

print(f"\nPhase 3 completed in {time.time() - start_time:.1f}s")

PHASE 3: Graph-Level Correlations


  ZINC-subset: graph_sri rho=0.6350, mean_node_sri rho=0.5105


  Pooled: graph_sri rho=0.6350, mean_node_sri rho=0.5105

Phase 3 completed in 7.0s


### Phase 4: SRWE Benefit Prediction

Categorize node pairs by whether they are already resolved (RWSE), newly resolved (SRWE), or still unresolved.

In [13]:
# ── PHASE 4: SRWE Benefit Prediction ──
print("=" * 60)
print("PHASE 4: SRWE Benefit Prediction")
print("=" * 60)

srwe_node_sris_by_category = {"already_resolved": [], "newly_resolved": [], "still_unresolved": []}
n_srwe_processed = 0
n_datasets = len(all_sampled_metrics)
srwe_per_ds = max(5, SAMPLE_PER_DATASET)

for ds_name, sampled in all_sampled_metrics.items():
    n_use = min(len(sampled), srwe_per_ds)
    print(f"Phase 4: Processing {n_use} graphs from {ds_name}")

    for s_idx in range(n_use):
        try:
            sm = sampled[s_idx]
            example = datasets[ds_name][sm["orig_idx"]]
            inp = parse_input(example)
            spectral = inp.get("spectral", {})
            num_nodes = inp.get("num_nodes", 0)

            srwe_feats = compute_srwe_features(spectral, num_nodes, k=K_WALK, alpha=1e-3)

            if srwe_feats.shape[0] < 2:
                n_srwe_processed += 1
                continue

            rwse = np.array(spectral.get("rwse", [])[:srwe_feats.shape[0]], dtype=np.float64)
            local_spectral = spectral.get("local_spectral", [])
            eigenvalues = spectral.get("eigenvalues", [])

            n_compare = min(rwse.shape[0], srwe_feats.shape[0], len(local_spectral))
            if n_compare < 2:
                n_srwe_processed += 1
                continue

            def normalize(X):
                if X.shape[0] < 2: return X
                m = X.mean(axis=0)
                s = X.std(axis=0)
                s[s < 1e-10] = 1.0
                return (X - m) / s

            rwse_norm = normalize(rwse[:n_compare])
            srwe_norm = normalize(srwe_feats[:n_compare])

            node_sris = sm.get("node_sris", [])

            lappe_feats = compute_lappe_features(
                local_spectral[:n_compare], eigenvalues, num_dims=min(10, len(eigenvalues))
            )
            lappe_norm = normalize(lappe_feats[:n_compare])

            n_pairs_check = min(50, n_compare * (n_compare - 1) // 2)
            for _ in range(n_pairs_check):
                u = rng.randint(0, n_compare - 1)
                w = rng.randint(u + 1, n_compare)

                rwse_dist = np.linalg.norm(rwse_norm[u] - rwse_norm[w])
                srwe_dist = np.linalg.norm(srwe_norm[u] - srwe_norm[w])
                lappe_dist = np.linalg.norm(lappe_norm[u] - lappe_norm[w]) if lappe_norm.shape[0] > w else 0.0

                super_dist = max(srwe_dist, lappe_dist)

                pair_sri = min(
                    node_sris[u] if u < len(node_sris) and np.isfinite(node_sris[u]) else float("inf"),
                    node_sris[w] if w < len(node_sris) and np.isfinite(node_sris[w]) else float("inf"),
                )

                if rwse_dist > EPSILON:
                    srwe_node_sris_by_category["already_resolved"].append(pair_sri)
                elif super_dist > EPSILON:
                    srwe_node_sris_by_category["newly_resolved"].append(pair_sri)
                else:
                    srwe_node_sris_by_category["still_unresolved"].append(pair_sri)

            n_srwe_processed += 1
        except Exception as e:
            n_srwe_processed += 1
            continue

# Mann-Whitney U test
newly = [v for v in srwe_node_sris_by_category["newly_resolved"] if np.isfinite(v)]
already = [v for v in srwe_node_sris_by_category["already_resolved"] if np.isfinite(v)]
unresolved = [v for v in srwe_node_sris_by_category["still_unresolved"] if np.isfinite(v)]

print(f"  SRWE categories: already={len(already)}, newly={len(newly)}, unresolved={len(unresolved)}")

if len(newly) >= 3 and len(already) >= 3:
    try:
        u_stat, u_p = stats.mannwhitneyu(newly, already, alternative="less")
        print(f"  Mann-Whitney U: newly vs already: U={u_stat:.1f}, p={u_p:.4f}")
    except Exception as e:
        print(f"  Mann-Whitney U test failed: {e}")

print(f"\nPhase 4 completed in {time.time() - start_time:.1f}s")

PHASE 4: SRWE Benefit Prediction
Phase 4: Processing 50 graphs from ZINC-subset
  SRWE categories: already=2431, newly=30, unresolved=39
  Mann-Whitney U: newly vs already: U=60300.5, p=1.0000

Phase 4 completed in 7.1s


## Results Summary & Visualization

Print key results as a table and generate figures showing the comparison between graph-level and node-level SRI.

In [14]:
# ── Results Summary Table ──
print("=" * 70)
print("RESULTS SUMMARY: Graph-Level vs Node-Level SRI Correlations")
print("=" * 70)
print(f"{'Dataset':<20} {'Metric':<20} {'Spearman rho':>12} {'95% CI':>20} {'p-value':>10}")
print("-" * 82)

for ds_name, ds_corr in correlation_results.items():
    for metric_name in ["graph_sri", "mean_node_sri", "min_node_sri", "median_node_sri", "p10_node_sri", "harmonic_node_sri"]:
        if metric_name not in ds_corr:
            continue
        mc = ds_corr[metric_name]
        if isinstance(mc, dict) and "rho" in mc:
            ci = f"[{mc.get('ci_lower', 0):.3f}, {mc.get('ci_upper', 0):.3f}]"
            print(f"{ds_name:<20} {metric_name:<20} {mc['rho']:>12.4f} {ci:>20} {mc.get('p_value', 1.0):>10.4f}")

print()

# SRWE Summary
print("SRWE Benefit Prediction:")
print(f"  Already resolved (RWSE):  n={len(already)}, mean SRI={np.mean(already):.4f}" if already else "  Already resolved: 0")
print(f"  Newly resolved (SRWE):    n={len(newly)}, mean SRI={np.mean(newly):.4f}" if newly else "  Newly resolved: 0")
print(f"  Still unresolved:         n={len(unresolved)}, mean SRI={np.mean(unresolved):.4f}" if unresolved else "  Still unresolved: 0")

total_time = time.time() - start_time
print(f"\nTotal pipeline runtime: {total_time:.1f}s")

RESULTS SUMMARY: Graph-Level vs Node-Level SRI Correlations
Dataset              Metric               Spearman rho               95% CI    p-value
----------------------------------------------------------------------------------
ZINC-subset          graph_sri                  0.6350       [0.455, 0.772]     0.0000
ZINC-subset          mean_node_sri              0.5105       [0.281, 0.684]     0.0002
ZINC-subset          min_node_sri               0.6065       [0.419, 0.748]     0.0000
ZINC-subset          median_node_sri            0.3776       [0.089, 0.608]     0.0069
ZINC-subset          p10_node_sri               0.5074       [0.280, 0.686]     0.0002
ZINC-subset          harmonic_node_sri          0.5097       [0.268, 0.690]     0.0002
Pooled               graph_sri                  0.6350       [0.455, 0.772]     0.0000
Pooled               mean_node_sri              0.5105       [0.281, 0.684]     0.0002

SRWE Benefit Prediction:
  Already resolved (RWSE):  n=2431, mean SRI=2.5

In [15]:
# ── Visualization ──
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Correlation comparison bar chart
ax = axes[0]
for ds_name, ds_corr in correlation_results.items():
    labels = []
    rhos = []
    for mn in ["graph_sri", "mean_node_sri", "min_node_sri"]:
        if mn in ds_corr and isinstance(ds_corr[mn], dict) and "rho" in ds_corr[mn]:
            labels.append(mn.replace("_", "\n"))
            rhos.append(abs(ds_corr[mn]["rho"]))
    if labels:
        x = np.arange(len(labels))
        colors = ["#2196F3" if "graph" in l.lower() else "#FF9800" for l in labels]
        ax.bar(x, rhos, color=colors, alpha=0.8)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, fontsize=8)
        ax.set_title(f"SRI Correlations ({ds_name})")
        ax.set_ylabel("|Spearman rho|")
        ax.set_ylim(0, 1.0)
        ax.axhline(y=0.3, color="red", linestyle="--", alpha=0.5, label="rho=0.3")
        ax.legend(fontsize=7)

# Plot 2: Graph SRI vs Mean Node SRI scatter
ax = axes[1]
for ds_name, sampled in all_sampled_metrics.items():
    gs = [s["graph_sri"] for s in sampled]
    mns = [s["mean_node_sri"] for s in sampled]
    finite = [np.isfinite(g) and np.isfinite(m) for g, m in zip(gs, mns)]
    gs_f = [g for g, f in zip(gs, finite) if f]
    mns_f = [m for m, f in zip(mns, finite) if f]
    if gs_f:
        ax.scatter(gs_f, mns_f, alpha=0.6, s=20, label=ds_name)
ax.set_xlabel("Graph-Level SRI")
ax.set_ylabel("Mean Node-Level SRI")
ax.set_title("Graph SRI vs Mean Node SRI")
ax.legend(fontsize=7)

# Plot 3: Node SRI distribution for example graphs
ax = axes[2]
if example_graphs:
    for i, eg in enumerate(example_graphs[:4]):
        node_sris = eg.get("node_sris", [])
        finite_sris = [s if np.isfinite(s) else 0 for s in node_sris[:30]]
        if finite_sris:
            ax.bar(np.arange(len(finite_sris)) + i * 0.2,
                   finite_sris, width=0.2, alpha=0.7,
                   label=f"{eg.get('gap_type', '')} (n={eg.get('n_nodes', '?')})")
    ax.set_xlabel("Node Index")
    ax.set_ylabel("Node SRI")
    ax.set_title("Node SRI Distribution (Example Graphs)")
    ax.legend(fontsize=7)
else:
    ax.text(0.5, 0.5, "No example graphs available", ha="center", va="center", transform=ax.transAxes)
    ax.set_title("Node SRI Distribution")

plt.tight_layout()
plt.savefig("results_visualization.png", dpi=150, bbox_inches="tight")
plt.show()
print("Visualization saved to results_visualization.png")

Visualization saved to results_visualization.png
