# Ablation Stability Analysis

This notebook checks the stability of recent ablation runs using:
- `R2` / `PCC`
- `pred mean/std`
- `agpe_graph_log` last-row `R_mean/R_std`
- whether any `PCC < 0` occurs
- whether the `pred mean/std` drift has converged


In [1]:
import csv
import numpy as np
from pathlib import Path
from openpyxl import load_workbook

# ===== Configuration =====
CASE_NAME = "skel_ref"   # options: grid / glat / skel_ref
LAST_N_RUNS = 3           # number of most recent ablation runs to analyze
RESULTS_ROOT = Path("results")


In [2]:
def find_case_dir(run_root: Path, case_name: str) -> Path | None:
    cands = sorted([d for d in run_root.glob(f"*_{case_name}") if d.is_dir()])
    if cands:
        return cands[0]
    return None


def read_metrics_from_summary(summary_xlsx: Path, case_name: str):
    """Look up the ``r2`` and ``pcc`` values of one case in a summary workbook.

    The active worksheet is read with its first row as the header. The first
    data row whose ``case`` column equals ``case_name`` supplies the result.

    Args:
        summary_xlsx: Path to the summary ``.xlsx`` file.
        case_name: Value to match against the ``case`` column.

    Returns:
        ``(r2, pcc)`` as floats. Either value is NaN when its column is
        absent or its cell is empty; both are NaN when the file is missing,
        the sheet has no data rows, there is no ``case`` column, or no row
        matches.
    """
    not_found = (np.nan, np.nan)

    if not summary_xlsx.exists():
        return not_found
    sheet = load_workbook(summary_xlsx, data_only=True).active
    if sheet.max_row < 2:
        return not_found

    # Header text -> zero-based column position.
    column_of = {str(cell.value): pos for pos, cell in enumerate(sheet[1])}
    if "case" not in column_of:
        return not_found
    case_col = column_of["case"]

    for row_number in range(2, sheet.max_row + 1):
        cells = [cell.value for cell in sheet[row_number]]
        raw_case = cells[case_col]
        label = "" if raw_case is None else str(raw_case)
        if label != case_name:
            continue

        r2_raw = cells[column_of["r2"]] if "r2" in column_of else np.nan
        pcc_raw = cells[column_of["pcc"]] if "pcc" in column_of else np.nan
        r2_value = np.nan if r2_raw is None else float(r2_raw)
        pcc_value = np.nan if pcc_raw is None else float(pcc_raw)
        return r2_value, pcc_value

    return not_found


def read_last_graph_stats(case_dir: Path):
    """Return ``(R_mean, R_std)`` from the last data row of the case's graph log.

    The first file (in sorted order) matching ``*_agpe_graph_log.csv`` inside
    ``case_dir`` is read as a header-led CSV.

    Args:
        case_dir: Directory holding the outputs of one case.

    Returns:
        A tuple of two floats. A value is NaN when the log file is missing,
        the log has no data rows, the column is absent, or the cell in the
        last row is missing, empty, or not numeric.
    """
    logs = sorted(case_dir.glob("*_agpe_graph_log.csv"))
    if not logs:
        return np.nan, np.nan

    # Stream the file and keep only the final row instead of loading them all.
    last = None
    with open(logs[0], "r", newline="", encoding="utf-8") as f:
        for last in csv.DictReader(f):
            pass
    if last is None:
        return np.nan, np.nan

    def _to_float(raw):
        # DictReader yields None for fields missing from a short row and ""
        # for empty cells; both mean "no value" here rather than an error.
        try:
            return float(raw)
        except (TypeError, ValueError):
            return np.nan

    return _to_float(last.get("R_mean")), _to_float(last.get("R_std"))


def read_pred_stats(case_dir: Path):
    """Compute the mean and standard deviation of a case's saved predictions.

    The first file (in sorted order) matching ``*_pred_AI.npy`` inside
    ``case_dir`` is loaded with ``np.load``.

    Args:
        case_dir: Directory holding the outputs of one case.

    Returns:
        ``(mean, std)`` over all elements of the loaded array as Python
        floats, or ``(nan, nan)`` when no prediction file exists.
    """
    first_pred = next(iter(sorted(case_dir.glob("*_pred_AI.npy"))), None)
    if first_pred is None:
        return np.nan, np.nan

    values = np.load(first_pred)
    mean_value = float(values.mean())
    std_value = float(values.std())
    return mean_value, std_value


In [3]:
def fmt(v):
    """Render one CSV field: floats with six decimals, NaN as ``nan``, else ``str(v)``."""
    if isinstance(v, float):
        return "nan" if np.isnan(v) else f"{v:.6f}"
    return str(v)


# Collect ablation run folders, oldest first by modification time. Plain files
# that happen to match the pattern are skipped so they cannot take up one of
# the "last N" slots.
run_roots = sorted(
    (p for p in RESULTS_ROOT.glob("ablation_*") if p.is_dir()),
    key=lambda p: p.stat().st_mtime,
)
# A bare [-0:] slice would select every run, so a non-positive count means none.
run_roots = run_roots[-LAST_N_RUNS:] if LAST_N_RUNS > 0 else []

# One record per run that actually contains the requested case.
rows = []
for rr in run_roots:
    case_dir = find_case_dir(rr, CASE_NAME)
    if case_dir is None:
        continue

    r2, pcc = read_metrics_from_summary(rr / "ablation_metrics_summary.xlsx", CASE_NAME)
    pred_mean, pred_std = read_pred_stats(case_dir)
    r_mean_last, r_std_last = read_last_graph_stats(case_dir)

    rows.append({
        "run_root": rr.name,
        "case": CASE_NAME,
        "r2": r2,
        "pcc": pcc,
        "pred_mean": pred_mean,
        "pred_std": pred_std,
        "R_mean_last": r_mean_last,
        "R_std_last": r_std_last,
    })

# Emit the collected records as CSV text: header line first, then one line per run.
print("run_root,case,r2,pcc,pred_mean,pred_std,R_mean_last,R_std_last")
for x in rows:
    print(",".join([
        x["run_root"], x["case"], fmt(x["r2"]), fmt(x["pcc"]), fmt(x["pred_mean"]),
        fmt(x["pred_std"]), fmt(x["R_mean_last"]), fmt(x["R_std_last"]),
    ]))

if not rows:
    print("[WARN] ??????????? CASE_NAME / LAST_N_RUNS?")


run_root,case,r2,pcc,pred_mean,pred_std,R_mean_last,R_std_last
ablation_20260301_164637_174640,skel_ref,0.320962,0.726795,-0.402788,0.939003,0.011659,0.063578
ablation_20260301_165442_813639,skel_ref,0.320962,0.726795,-0.402788,0.939003,0.011659,0.063578
ablation_20260301_170251_482171,skel_ref,0.320962,0.726795,-0.402788,0.939003,0.011659,0.063578


In [None]:
# Cross-run stability summary; spread is only defined for two or more runs.
if len(rows) < 2:
    print("????? < 2?????????")
else:
    # One float vector per metric, in the same order as `rows`.
    r2s, pccs, pred_means, pred_stds = (
        np.array([x[key] for x in rows], dtype=float)
        for key in ("r2", "pcc", "pred_mean", "pred_std")
    )

    # Sample standard deviation of R2 (ddof=1), skipping NaN entries.
    r2_std = float(np.nanstd(r2s, ddof=1))
    pcc_has_neg = bool((pccs < 0).any())

    # Drift = max - min across runs; the relative form divides by |mean|,
    # with a tiny epsilon so a zero mean cannot cause a division by zero.
    pm_mean = float(np.nanmean(pred_means))
    ps_mean = float(np.nanmean(pred_stds))
    pred_mean_drift = float(np.nanmax(pred_means) - np.nanmin(pred_means))
    pred_std_drift = float(np.nanmax(pred_stds) - np.nanmin(pred_stds))
    pred_mean_rel_drift = float(pred_mean_drift / (abs(pm_mean) + 1e-12))
    pred_std_rel_drift = float(pred_std_drift / (abs(ps_mean) + 1e-12))

    print(f"R2_std = {r2_std:.6f}")
    print(f"R2_std < 0.03 ? {'YES' if r2_std < 0.03 else 'NO'}")
    print(f"Any PCC < 0 ? {'YES' if pcc_has_neg else 'NO'}")
    print(f"pred_mean drift (abs / rel) = {pred_mean_drift:.6f} / {pred_mean_rel_drift:.2%}")
    print(f"pred_std  drift (abs / rel) = {pred_std_drift:.6f} / {pred_std_rel_drift:.2%}")

    # Converged: no negative PCC and both relative drifts below 5%.
    converged = all((
        not pcc_has_neg,
        pred_mean_rel_drift < 0.05,
        pred_std_rel_drift < 0.05,
    ))
    print(f"Pred mean/std drift ????? {'YES' if converged else 'NO'}")

R2_std = 0.000000
R2_std < 0.03 ? YES
Any PCC < 0 ? NO
pred_mean drift (abs / rel) = 0.000000 / 0.00%
pred_std  drift (abs / rel) = 0.000000 / 0.00%
Pred mean/std drift ????? YES
