In [16]:
    import os
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # -----------------------
    # CONFIG
    # -----------------------
    INPUT_CSV = "../../../Spectra/results/head_summary.csv"
    OUT_DIR = "analysis_outputs"
    os.makedirs(OUT_DIR, exist_ok=True)
    
    R_THRESH = 0.5            # minimum |r(Q, entropy)| for a significant instance
    P_THRESH = 0.01           # p-value must be strictly below this
    CONSISTENCY_FRAC = 0.30   # 30% of samples
    
    # Columns this cell reads from the summary CSV; checked up front so a schema
    # mismatch fails before any output file is written.
    REQUIRED_COLS = {
        "sample_id", "layer", "head",
        "r_q_entropy", "p_q_entropy", "r_q_max_attn", "r_q_k_eff",
    }
    
    sns.set(style="whitegrid")
    
    # -----------------------
    # LOAD
    # -----------------------
    summary = pd.read_csv(INPUT_CSV)
    
    missing_cols = REQUIRED_COLS - set(summary.columns)
    if missing_cols:
        raise ValueError(
            f"{INPUT_CSV} is missing required columns: {sorted(missing_cols)}"
        )
    if summary.empty:
        # num_samples would be 0 and every fraction below would divide by zero.
        raise ValueError(f"{INPUT_CSV} contains no rows; nothing to analyse")
    
    num_samples = summary["sample_id"].nunique()
    print(f"Loaded summary for {num_samples} samples")
    print(f"Total rows (sample, layer, head): {len(summary)}")
    
    # Every (layer, head) pair present in the data, including heads that are
    # never significant. Used as the denominator/grid for sections 2, 5 and 6
    # and for the head count in the final report.
    all_heads = (
        summary[["layer", "head"]]
        .drop_duplicates()
        .sort_values(["layer", "head"])
        .reset_index(drop=True)
    )
    
    # -----------------------
    # 1. SIGNIFICANT HEAD INSTANCES
    # -----------------------
    # Single definition of "significant", reused by the layer profile below so
    # that all outputs agree. NaN correlations/p-values compare False and are
    # therefore never significant.
    is_significant = (
        (summary["p_q_entropy"] < P_THRESH)
        & (summary["r_q_entropy"].abs() >= R_THRESH)
    )
    sig = summary[is_significant].copy()
    
    sig.to_csv(os.path.join(OUT_DIR, "significant_heads.csv"), index=False)
    
    print(f"Significant (sample,layer,head) instances: {len(sig)}")
    
    # -----------------------
    # 2. CONSISTENT HEADS ACROSS SAMPLES
    # -----------------------
    # Count distinct samples per head among significant instances, then reindex
    # onto the full head grid so heads with zero significant samples appear with
    # num_samples == 0 instead of being silently dropped.
    sig_sample_counts = sig.groupby(["layer", "head"])["sample_id"].nunique()
    head_consistency = (
        sig_sample_counts
        .reindex(pd.MultiIndex.from_frame(all_heads), fill_value=0)
        .reset_index(name="num_samples")
    )
    
    head_consistency["fraction"] = head_consistency["num_samples"] / num_samples
    
    consistent_heads = head_consistency[
        head_consistency["fraction"] >= CONSISTENCY_FRAC
    ].sort_values("fraction", ascending=False)
    
    consistent_heads.to_csv(
        os.path.join(OUT_DIR, "consistent_heads.csv"), index=False
    )
    
    print(f"Consistent heads (≥{CONSISTENCY_FRAC*100:.0f}% samples): {len(consistent_heads)}")
    
    # -----------------------
    # 3. DISTRIBUTION OF |r(Q, entropy)|
    # -----------------------
    # Drop NaN correlations: a NaN in the input makes histogram range detection fail.
    abs_r_values = summary["r_q_entropy"].abs().dropna()
    
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(abs_r_values, bins=60, color="steelblue")
    # Label is derived from R_THRESH so it stays correct if the threshold is tuned.
    ax.axvline(R_THRESH, color="red", linestyle="--", label=f"|r| = {R_THRESH}")
    ax.set_xlabel("|r(Q, entropy)|")
    ax.set_ylabel("Count")
    ax.set_title("Distribution of |Correlation(Q norm, Attention Entropy)|")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, "r_entropy_distribution.png"), dpi=200)
    plt.close(fig)
    
    # -----------------------
    # 4. LAYER-WISE PROFILE
    # -----------------------
    # frac_significant uses the same |r| AND p definition as section 1.
    # frac_p_significant keeps the p-value-only fraction for reference.
    layer_inputs = summary[["layer"]].assign(
        abs_r=summary["r_q_entropy"].abs(),
        is_sig=is_significant,
        p_below=summary["p_q_entropy"] < P_THRESH,
    )
    layer_stats = (
        layer_inputs
        .groupby("layer")
        .agg(
            mean_abs_r=("abs_r", "mean"),
            # Population standard deviation (ddof=0).
            std_abs_r=("abs_r", lambda s: s.std(ddof=0)),
            frac_significant=("is_sig", "mean"),
            frac_p_significant=("p_below", "mean"),
        )
        .reset_index()
    )
    
    layer_stats.to_csv(
        os.path.join(OUT_DIR, "layer_statistics.csv"), index=False
    )
    
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(layer_stats["layer"], layer_stats["mean_abs_r"], marker="o")
    ax.set_xlabel("Layer")
    ax.set_ylabel("Mean |r(Q, entropy)|")
    ax.set_title("Layer-wise Mean |Correlation|")
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, "layer_mean_abs_r.png"), dpi=200)
    plt.close(fig)
    
    # -----------------------
    # 5. HEAD CONSISTENCY HISTOGRAM
    # -----------------------
    # head_consistency covers every head, so the zero bin counts heads that
    # were never significant in any sample.
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(head_consistency["fraction"], bins=30, color="darkorange")
    ax.set_xlabel("Fraction of Samples Head is Significant")
    ax.set_ylabel("Number of Heads")
    ax.set_title("Head Consistency Across Samples")
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, "head_consistency_hist.png"), dpi=200)
    plt.close(fig)
    
    # -----------------------
    # 6. HEAD RECURRENCE HEATMAP
    # -----------------------
    # fillna only matters when layers do not all share the same set of head ids.
    heatmap_data = (
        head_consistency
        .pivot(index="layer", columns="head", values="fraction")
        .fillna(0.0)
    )
    
    fig, ax = plt.subplots(figsize=(14, 8))
    sns.heatmap(
        heatmap_data,
        cmap="viridis",
        cbar_kws={"label": "Fraction of Samples"},
        ax=ax,
    )
    ax.set_xlabel("Head")
    ax.set_ylabel("Layer")
    ax.set_title(f"Head Recurrence Heatmap (|r| ≥ {R_THRESH}, p < {P_THRESH})")
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, "head_recurrence_heatmap.png"), dpi=200)
    plt.close(fig)
    
    # -----------------------
    # 7. CORRELATION COMPARISON (ENTROPY vs OTHERS)
    # -----------------------
    corr_means = summary[
        ["r_q_entropy", "r_q_max_attn", "r_q_k_eff"]
    ].abs().mean()
    
    fig, ax = plt.subplots(figsize=(6, 4))
    corr_means.plot(kind="bar", color=["steelblue", "darkgreen", "purple"], ax=ax)
    ax.set_ylabel("Mean |Correlation|")
    ax.set_title("Which Metric Tracks Query Norm Best?")
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, "metric_comparison.png"), dpi=200)
    plt.close(fig)
    
    # -----------------------
    # FINAL REPORT
    # -----------------------
    print("\n===== FINAL SUMMARY =====")
    print(f"Total samples: {num_samples}")
    # Counted from the data rather than assuming a fixed layers x heads shape.
    print(f"Total heads: {len(all_heads)}")
    print(f"Significant instances: {len(sig)}")
    print(f"Consistent heads: {len(consistent_heads)}")
    
    if len(consistent_heads) > 0:
        print("\nTop consistent heads:")
        print(consistent_heads.head(10))
    else:
        print("\nNo heads met consistency threshold — consider lowering to 20% for inspection.")

Loaded summary for 64 samples
Total rows (sample, layer, head): 65536
Significant (sample,layer,head) instances: 6696
Consistent heads (≥30% samples): 118

===== FINAL SUMMARY =====
Total samples: 64
Total heads: 1024
Significant instances: 6696
Consistent heads: 118

Top consistent heads:
     layer  head  num_samples  fraction
144      8     7           61  0.953125
186      9    30           56  0.875000
187     10     4           53  0.828125
301     14    27           53  0.828125
244     12    15           50  0.781250
275     13    25           50  0.781250
442     20    27           49  0.765625
173      9    16           49  0.765625
192     10    10           46  0.718750
154      8    21           46  0.718750
