In [6]:
# ============================================================
# COMPARISON SCRIPT / NOTEBOOK (ENTIRE UPDATED CODE)
# - Loads all runs from results/runs.csv
# - Builds comparable pairs in two modes:
#     A) same_case_diff_time     (phi, lat_size, isolevel, post fixed)
#     B) same_phi_isolevel_diff_case (phi, isolevel, post fixed)
# - Computes:
#     - Spearman rank correlation on union of top-K features (K configurable)
#     - Top-K Jaccard overlap
#     - Mean centroid distance after cluster matching (Hungarian if SciPy present)
# - Saves results to results/comparisons_pairs.csv
# ============================================================

from __future__ import annotations

import json
import numpy as np
import pandas as pd
from pathlib import Path


# -------------------------
# SETTINGS
# -------------------------
RESULTS_ROOT = Path("../results/gmm_stability")
RUNS_CSV = RESULTS_ROOT / "runs.csv"

# Compute stability for these top-k values
TOPK_LIST = [5, 10, 20]  # edit as desired

# Pair-building modes
KEY_COLS_TIME = ["phi", "lat_size", "isolevel", "post"]     # same case, different timesteps
KEY_COLS_CASE = ["phi", "isolevel", "post"]                 # same phi/isolevel, different cases (lat_size)


# -------------------------
# LOAD ARTIFACTS
# -------------------------
def load_run_artifacts(run_id: str):
    run_dir = RESULTS_ROOT / f"run_{run_id}"
    cfg = json.loads((run_dir / "config.json").read_text(encoding="utf-8"))
    summary = pd.read_csv(run_dir / "summary_clusters.csv")
    imp_agg = pd.read_csv(run_dir / "importance_agg.csv")
    cent = pd.read_csv(run_dir / "centroids_scaled.csv")
    return cfg, summary, imp_agg, cent


# -------------------------
# METRICS
# -------------------------
def topk_jaccard(imp_a: pd.DataFrame, imp_b: pd.DataFrame, k: int = 20) -> float:
    """
    Jaccard overlap of top-k feature sets based on importance_weighted.
    """
    if imp_a.empty or imp_b.empty:
        return np.nan

    a = set(imp_a.sort_values("importance_weighted", ascending=False).head(k)["feature"])
    b = set(imp_b.sort_values("importance_weighted", ascending=False).head(k)["feature"])
    if not a and not b:
        return np.nan
    return float(len(a & b) / len(a | b))


def spearman_rank_corr_topk(imp_a: pd.DataFrame, imp_b: pd.DataFrame, k: int = 20) -> float:
    """
    Spearman correlation of feature ranks restricted to the union of top-k features
    from each run (based on importance_weighted). This is more interpretable than
    correlating all features when the tail is noisy.
    """
    if imp_a.empty or imp_b.empty:
        return np.nan

    a = imp_a.set_index("feature")["importance_weighted"]
    b = imp_b.set_index("feature")["importance_weighted"]

    feats_common = a.index.intersection(b.index)
    if len(feats_common) < 3:
        return np.nan

    a_common = a.loc[feats_common]
    b_common = b.loc[feats_common]

    top_a = set(a_common.sort_values(ascending=False).head(k).index)
    top_b = set(b_common.sort_values(ascending=False).head(k).index)
    feats = list(top_a | top_b)

    if len(feats) < 3:
        return np.nan

    ra = a_common.loc[feats].rank(ascending=False)
    rb = b_common.loc[feats].rank(ascending=False)
    return float(ra.corr(rb, method="spearman"))


def match_clusters_by_centroids(cent_a: pd.DataFrame, cent_b: pd.DataFrame) -> pd.DataFrame:
    """
    Match clusters between run A and B by minimizing total centroid distance.
    cent_* must have columns: cluster, z_<feature1>, z_<feature2>, ...
    Returns DataFrame: cluster_A, cluster_B, centroid_distance
    """
    A = cent_a.sort_values("cluster").reset_index(drop=True)
    B = cent_b.sort_values("cluster").reset_index(drop=True)

    Za = A.drop(columns=["cluster"]).to_numpy()
    Zb = B.drop(columns=["cluster"]).to_numpy()

    if Za.size == 0 or Zb.size == 0:
        return pd.DataFrame(columns=["cluster_A", "cluster_B", "centroid_distance"])

    # distance matrix
    D = np.linalg.norm(Za[:, None, :] - Zb[None, :, :], axis=2)

    try:
        from scipy.optimize import linear_sum_assignment  # type: ignore
        r, c = linear_sum_assignment(D)
        pairs = [(int(A.loc[i, "cluster"]), int(B.loc[j, "cluster"]), float(D[i, j])) for i, j in zip(r, c)]
    except Exception:
        # greedy fallback
        pairs = []
        used_j = set()
        for i in range(D.shape[0]):
            j = int(np.argmin([D[i, jj] if jj not in used_j else np.inf for jj in range(D.shape[1])]))
            used_j.add(j)
            pairs.append((int(A.loc[i, "cluster"]), int(B.loc[j, "cluster"]), float(D[i, j])))

    return pd.DataFrame(pairs, columns=["cluster_A", "cluster_B", "centroid_distance"])


# -------------------------
# BUILD PAIRS
# -------------------------
def build_pairs(runs_df: pd.DataFrame, key_cols: list[str], comparison_type: str) -> pd.DataFrame:
    pairs = []

    for _, g in runs_df.groupby(key_cols):
        ids = g["run_id"].tolist()
        if len(ids) < 2:
            continue

        for i in range(len(ids)):
            for j in range(i + 1, len(ids)):
                ra, rb = ids[i], ids[j]

                cfg_a, _, imp_a, cent_a = load_run_artifacts(ra)
                cfg_b, _, imp_b, cent_b = load_run_artifacts(rb)

                # centroid matching and its mean distance
                mapping = match_clusters_by_centroids(cent_a, cent_b)
                mean_centroid_dist = float(mapping["centroid_distance"].mean()) if not mapping.empty else np.nan

                row = {
                    "comparison_type": comparison_type,
                    "run_a": ra,
                    "run_b": rb,
                    "phi": cfg_a.get("PHI", cfg_a.get("phi")),
                    "lat_a": cfg_a.get("LAT_SIZE", cfg_a.get("lat_size")),
                    "lat_b": cfg_b.get("LAT_SIZE", cfg_b.get("lat_size")),
                    "time_a": cfg_a.get("TIME_STEP", cfg_a.get("time_step")),
                    "time_b": cfg_b.get("TIME_STEP", cfg_b.get("time_step")),
                    "isolevel": cfg_a.get("ISOLEVEL", cfg_a.get("isolevel")),
                    "post": cfg_a.get("POST", cfg_a.get("post")),
                    "mean_centroid_dist": mean_centroid_dist,
                }

                # add top-k metrics for each K
                for K in TOPK_LIST:
                    row[f"spearman_top{K}"] = spearman_rank_corr_topk(imp_a, imp_b, k=K)
                    row[f"jaccard_top{K}"] = topk_jaccard(imp_a, imp_b, k=K)

                pairs.append(row)

    return pd.DataFrame(pairs)


# -------------------------
# MAIN
# -------------------------
if not RUNS_CSV.exists():
    raise FileNotFoundError(f"Missing: {RUNS_CSV}. Run at least two experiments first.")

runs_df = pd.read_csv(RUNS_CSV)
print("Total runs:", len(runs_df))

pairs_time = build_pairs(runs_df, KEY_COLS_TIME, comparison_type="same_case_diff_time")
pairs_case = build_pairs(runs_df, KEY_COLS_CASE, comparison_type="same_phi_isolevel_diff_case")

pairs_all = pd.concat([pairs_time, pairs_case], ignore_index=True)
out_csv = RESULTS_ROOT / "comparisons_pairs.csv"
pairs_all.to_csv(out_csv, index=False)

print("Pairs (same case, different time):", len(pairs_time))
print("Pairs (same phi/isolevel, different case):", len(pairs_case))
print(f"Saved: {out_csv}")

if pairs_all.empty:
    print("No comparable pairs found yet. Add more runs for the same grouping key.")
else:
    # Show best pairs by Spearman for a chosen K
    K_show = TOPK_LIST[-1]  # show the largest K by default
    s_col = f"spearman_top{K_show}"
    j_col = f"jaccard_top{K_show}"

    print(f"\nTop comparisons by {s_col} (higher is better):")
    display(pairs_all.sort_values(s_col, ascending=False).head(10))

    print(f"\nTop comparisons by {j_col} (higher is better):")
    display(pairs_all.sort_values(j_col, ascending=False).head(10))

    print("\nTop comparisons by mean_centroid_dist (lower is better):")
    display(pairs_all.sort_values("mean_centroid_dist", ascending=True).head(10))


Total runs: 7
Pairs (same case, different time): 4
Pairs (same phi/isolevel, different case): 21
Saved: ..\results\gmm_stability\comparisons_pairs.csv

Top comparisons by spearman_top20 (higher is better):


Unnamed: 0,comparison_type,run_a,run_b,phi,lat_a,lat_b,time_a,time_b,isolevel,post,mean_centroid_dist,spearman_top5,jaccard_top5,spearman_top10,jaccard_top10,spearman_top20,jaccard_top20
24,same_phi_isolevel_diff_case,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,50,100,335,211,4.5,True,4.474157,0.5,0.428571,0.461538,0.538462,0.732355,0.818182
9,same_phi_isolevel_diff_case,20251216_150031_phi0.40_lat200_t230_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,200,100,230,211,4.5,True,1.019721,0.942857,0.666667,0.415385,0.428571,0.498462,0.6
20,same_phi_isolevel_diff_case,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,0.4,25,50,100,335,4.5,True,5.571431,0.238095,0.25,0.089286,0.333333,0.485217,0.666667
18,same_phi_isolevel_diff_case,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,25,100,238,211,4.5,True,3.306177,-0.428571,0.25,-0.027473,0.538462,0.420855,0.538462
8,same_phi_isolevel_diff_case,20251216_150031_phi0.40_lat200_t230_iso4.5_k5_...,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,0.4,200,50,230,335,4.5,True,4.395643,0.5,0.428571,0.375824,0.428571,0.403077,0.538462
17,same_phi_isolevel_diff_case,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,0.4,25,50,238,335,4.5,True,3.391447,-0.357143,0.25,-0.049451,0.538462,0.392137,0.538462
21,same_phi_isolevel_diff_case,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,25,100,100,211,4.5,True,3.007985,0.214286,0.25,0.017857,0.333333,0.316206,0.73913
6,same_phi_isolevel_diff_case,20251216_150031_phi0.40_lat200_t230_iso4.5_k5_...,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,0.4,200,25,230,100,4.5,True,3.133644,0.25,0.428571,0.207143,0.333333,0.283883,0.481481
1,same_case_diff_time,20251216_150226_phi0.40_lat025_t335_iso4.5_k5_...,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,0.4,25,25,335,100,4.5,True,3.424184,-0.333333,0.111111,-0.225,0.333333,0.277167,0.481481
11,same_phi_isolevel_diff_case,20251216_150226_phi0.40_lat025_t335_iso4.5_k5_...,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,0.4,25,25,335,100,4.5,True,3.424184,-0.333333,0.111111,-0.225,0.333333,0.277167,0.481481



Top comparisons by jaccard_top20 (higher is better):


Unnamed: 0,comparison_type,run_a,run_b,phi,lat_a,lat_b,time_a,time_b,isolevel,post,mean_centroid_dist,spearman_top5,jaccard_top5,spearman_top10,jaccard_top10,spearman_top20,jaccard_top20
24,same_phi_isolevel_diff_case,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,50,100,335,211,4.5,True,4.474157,0.5,0.428571,0.461538,0.538462,0.732355,0.818182
21,same_phi_isolevel_diff_case,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,25,100,100,211,4.5,True,3.007985,0.214286,0.25,0.017857,0.333333,0.316206,0.73913
20,same_phi_isolevel_diff_case,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,0.4,25,50,100,335,4.5,True,5.571431,0.238095,0.25,0.089286,0.333333,0.485217,0.666667
3,same_case_diff_time,20251216_151306_phi0.40_lat050_t100_iso4.5_k5_...,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,0.4,50,50,100,335,4.5,True,6.00505,-0.583333,0.111111,-0.481319,0.428571,0.078462,0.6
22,same_phi_isolevel_diff_case,20251216_151306_phi0.40_lat050_t100_iso4.5_k5_...,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,0.4,50,50,100,335,4.5,True,6.00505,-0.583333,0.111111,-0.481319,0.428571,0.078462,0.6
9,same_phi_isolevel_diff_case,20251216_150031_phi0.40_lat200_t230_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,200,100,230,211,4.5,True,1.019721,0.942857,0.666667,0.415385,0.428571,0.498462,0.6
2,same_case_diff_time,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,0.4,25,25,238,100,4.5,True,4.191072,-0.45,0.111111,-0.371429,0.333333,0.18,0.6
15,same_phi_isolevel_diff_case,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,0.4,25,25,238,100,4.5,True,4.191072,-0.45,0.111111,-0.371429,0.333333,0.18,0.6
8,same_phi_isolevel_diff_case,20251216_150031_phi0.40_lat200_t230_iso4.5_k5_...,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,0.4,200,50,230,335,4.5,True,4.395643,0.5,0.428571,0.375824,0.428571,0.403077,0.538462
17,same_phi_isolevel_diff_case,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,0.4,25,50,238,335,4.5,True,3.391447,-0.357143,0.25,-0.049451,0.538462,0.392137,0.538462



Top comparisons by mean_centroid_dist (lower is better):


Unnamed: 0,comparison_type,run_a,run_b,phi,lat_a,lat_b,time_a,time_b,isolevel,post,mean_centroid_dist,spearman_top5,jaccard_top5,spearman_top10,jaccard_top10,spearman_top20,jaccard_top20
9,same_phi_isolevel_diff_case,20251216_150031_phi0.40_lat200_t230_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,200,100,230,211,4.5,True,1.019721,0.942857,0.666667,0.415385,0.428571,0.498462,0.6
0,same_case_diff_time,20251216_150226_phi0.40_lat025_t335_iso4.5_k5_...,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,0.4,25,25,335,238,4.5,True,1.833785,-0.633333,0.111111,-0.346429,0.333333,0.115017,0.333333
10,same_phi_isolevel_diff_case,20251216_150226_phi0.40_lat025_t335_iso4.5_k5_...,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,0.4,25,25,335,238,4.5,True,1.833785,-0.633333,0.111111,-0.346429,0.333333,0.115017,0.333333
4,same_phi_isolevel_diff_case,20251216_150031_phi0.40_lat200_t230_iso4.5_k5_...,20251216_150226_phi0.40_lat025_t335_iso4.5_k5_...,0.4,200,25,230,335,4.5,True,2.616286,-0.404762,0.25,-0.041176,0.25,-0.030242,0.290323
14,same_phi_isolevel_diff_case,20251216_150226_phi0.40_lat025_t335_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,25,100,335,211,4.5,True,2.638795,-0.309524,0.25,-0.015385,0.428571,0.18906,0.538462
21,same_phi_isolevel_diff_case,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,25,100,100,211,4.5,True,3.007985,0.214286,0.25,0.017857,0.333333,0.316206,0.73913
5,same_phi_isolevel_diff_case,20251216_150031_phi0.40_lat200_t230_iso4.5_k5_...,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,0.4,200,25,230,238,4.5,True,3.128709,-0.45,0.111111,-0.082418,0.538462,0.252747,0.481481
6,same_phi_isolevel_diff_case,20251216_150031_phi0.40_lat200_t230_iso4.5_k5_...,20251216_150607_phi0.40_lat025_t100_iso4.5_k5_...,0.4,200,25,230,100,4.5,True,3.133644,0.25,0.428571,0.207143,0.333333,0.283883,0.481481
18,same_phi_isolevel_diff_case,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,20251216_151757_phi0.40_lat100_t211_iso4.5_k5_...,0.4,25,100,238,211,4.5,True,3.306177,-0.428571,0.25,-0.027473,0.538462,0.420855,0.538462
17,same_phi_isolevel_diff_case,20251216_150557_phi0.40_lat025_t238_iso4.5_k5_...,20251216_151517_phi0.40_lat050_t335_iso4.5_k5_...,0.4,25,50,238,335,4.5,True,3.391447,-0.357143,0.25,-0.049451,0.538462,0.392137,0.538462
