In [3]:
!rm -rf /content/drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


In [7]:
# ===========================
# Model comparison + assumptions + group stats (R² & RSA)
# ===========================
import pandas as pd, numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from scipy.stats import wilcoxon, ttest_rel, shapiro, friedmanchisquare, ttest_1samp
from itertools import combinations

# ---- paths ----
BASE = Path("/content/drive/MyDrive/algonauts_outputs")
clip_csv   = BASE / "group_clip"    / "group_summary_day3.csv"
resnet_csv = BASE / "group_resnet"  / "group_summary_day3.csv"
COMPARE_BASE = BASE / "group_compare"
PLOTS_DIR = COMPARE_BASE / "plots"
COMPARE_BASE.mkdir(parents=True, exist_ok=True)
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# ---- load ----
df_c = pd.read_csv(clip_csv)
df_r = pd.read_csv(resnet_csv)

# We only need common subjects/ROIs present in both CSVs
key = ["subject","roi"]
cols = ["r2_median","rsa_rho"]
m = df_c[key+cols].merge(df_r[key+cols], on=key, suffixes=("_clip","_resnet"))

# Convenience: ROI order
ROI_ORDER = ["EBA","FFA","PPA"]
m["roi"] = pd.Categorical(m["roi"], categories=ROI_ORDER, ordered=True)
m = m.sort_values(["roi","subject"])

# ===========================
# Helpers
# ===========================
def shapiro_normality(x, min_n=3):
    """Return (is_normal, pval). For tiny n, return (False, nan) to be safe."""
    x = np.asarray(x, float)
    x = x[np.isfinite(x)]
    if x.size < min_n:
        return (False, np.nan)
    try:
        stat, p = shapiro(x)
        return (p > 0.05, p)
    except Exception:
        return (False, np.nan)

def cohens_dz(diffs):
    """Effect size for paired t-test."""
    d = np.asarray(diffs, float)
    d = d[np.isfinite(d)]
    if d.size < 2 or np.std(d, ddof=1) == 0:
        return np.nan
    return np.mean(d) / np.std(d, ddof=1)

def rank_biserial_effect(diffs):
    """Rank-biserial effect for paired nonparametric difference."""
    from scipy.stats import rankdata
    d = np.asarray(diffs, float)
    d = d[np.isfinite(d)]
    d = d[d != 0]  # ignore ties as Wilcoxon does
    n = d.size
    if n == 0:
        return np.nan
    ranks = rankdata(np.abs(d))
    r_pos = np.sum(ranks[d > 0])
    r_neg = np.sum(ranks[d < 0])
    denom = n * (n + 1) / 2.0
    return (r_pos - r_neg) / denom  # in [-1, 1]

def fdr_bh(pvals, alpha=0.05):
    """Return (qvals, significant_mask) using Benjamini–Hochberg."""
    p = np.asarray(pvals, dtype=float)
    mask = np.isfinite(p)
    pe = p[mask]
    n = pe.size
    if n == 0:
        return np.full_like(p, np.nan), np.zeros_like(p, bool)
    order = np.argsort(pe)
    ranked = pe[order]
    q_eff = np.empty(n, float)
    prev = 1.0
    for i in range(n-1, -1, -1):
        q_i = ranked[i] * n / (i+1)
        prev = min(prev, q_i)
        q_eff[order[i]] = prev
    q = np.full_like(p, np.nan); q[mask] = q_eff
    sig = q < alpha
    return q, sig

def bootstrap_ci(x, stat_fn=np.median, n_boot=5000, ci=95, seed=42):
    """Simple bootstrap CI for a statistic (default median)."""
    x = np.asarray(x); x = x[np.isfinite(x)]
    if x.size == 0:
        return np.nan, (np.nan, np.nan)
    rng = np.random.default_rng(seed)
    stats = []
    for _ in range(n_boot):
        bs = rng.choice(x, size=x.size, replace=True)
        stats.append(stat_fn(bs))
    low = np.percentile(stats, (100-ci)/2)
    high = np.percentile(stats, 100 - (100-ci)/2)
    return stat_fn(x), (low, high)

# ===========================
# 1) Paired model comparison (CLIP − ResNet) per ROI
#    with assumptions → pick t-test vs Wilcoxon; effect sizes; BH-FDR
# ===========================
# Compute deltas
m["delta_r2"]  = m["r2_median_clip"] - m["r2_median_resnet"]
m["delta_rho"] = m["rsa_rho_clip"]   - m["rsa_rho_resnet"]

rows = []
for roi, g in m.groupby("roi"):
    # diffs across subjects
    d_r2  = g["delta_r2"].dropna().values
    d_rho = g["delta_rho"].dropna().values

    # --- R²: assumptions on paired differences ---
    norm_r2, pnorm_r2 = shapiro_normality(d_r2)
    if norm_r2 and len(d_r2) >= 2:
        tstat, p_r2 = ttest_rel(g["r2_median_clip"], g["r2_median_resnet"])
        eff_r2 = cohens_dz(d_r2)
        test_r2 = "paired t-test"
    else:
        w = wilcoxon(d_r2) if len(d_r2) >= 2 else None
        p_r2 = (w.pvalue if w else np.nan)
        eff_r2 = rank_biserial_effect(d_r2)
        test_r2 = "Wilcoxon"

    # --- RSA ρ: assumptions on paired differences ---
    norm_rho, pnorm_rho = shapiro_normality(d_rho)
    if norm_rho and len(d_rho) >= 2:
        tstat, p_rho = ttest_rel(g["rsa_rho_clip"], g["rsa_rho_resnet"])
        eff_rho = cohens_dz(d_rho)
        test_rho = "paired t-test"
    else:
        w = wilcoxon(d_rho) if len(d_rho) >= 2 else None
        p_rho = (w.pvalue if w else np.nan)
        eff_rho = rank_biserial_effect(d_rho)
        test_rho = "Wilcoxon"

    rows.append({
        "roi": str(roi),
        "n_subjects": int(g["subject"].nunique()),
        "delta_r2_mean":    float(np.nanmean(d_r2))  if d_r2.size>0 else np.nan,
        "delta_r2_median":  float(np.nanmedian(d_r2))if d_r2.size>0 else np.nan,
        "model_test_r2": test_r2,
        "normality_p_r2": pnorm_r2,
        "p_model_r2": p_r2,
        "effect_r2": eff_r2,            # Cohen's dz if t-test, rank-biserial if Wilcoxon
        "prop_subj_better_r2":  float((d_r2>0).mean())  if d_r2.size>0 else np.nan,

        "delta_rho_mean":   float(np.nanmean(d_rho)) if d_rho.size>0 else np.nan,
        "delta_rho_median": float(np.nanmedian(d_rho))if d_rho.size>0 else np.nan,
        "model_test_rho": test_rho,
        "normality_p_rho": pnorm_rho,
        "p_model_rho": p_rho,
        "effect_rho": eff_rho,          # Cohen's dz or rank-biserial
        "prop_subj_better_rho": float((d_rho>0).mean()) if d_rho.size>0 else np.nan,
    })

res_model = pd.DataFrame(rows).sort_values("roi").reset_index(drop=True)

# BH-FDR across ROIs for each metric
res_model["q_model_r2"],  _ = fdr_bh(res_model["p_model_r2"].values,  alpha=0.05)
res_model["q_model_rho"], _ = fdr_bh(res_model["p_model_rho"].values, alpha=0.05)

# Save
res_model_path = COMPARE_BASE / "model_compare_clip_vs_resnet_with_assumptions.csv"
res_model.to_csv(res_model_path, index=False)
print("Saved model compare table:", res_model_path)
display(res_model)

# ===========================
# 2) Across-subject tests vs 0 (are R² / RSA > 0?) per ROI × model
#    with assumptions
# ===========================
def one_sample_vs_zero(x):
    x = np.asarray(x, float); x = x[np.isfinite(x)]
    norm, pnorm = shapiro_normality(x)
    if norm and x.size >= 2:
        t, p = ttest_1samp(x, popmean=0.0)
        test = "one-sample t-test"
    else:
        # Wilcoxon signed-rank against 0
        try:
            w = wilcoxon(x, zero_method='wilcox', alternative='greater')
            p = w.pvalue
            test = "Wilcoxon > 0"
        except Exception:
            p = np.nan; test = "Wilcoxon > 0"
    return test, p, pnorm

rows = []
for model_name, df in [("CLIP", df_c), ("ResNet", df_r)]:
    for metric, col in [("R²", "r2_median"), ("RSA ρ", "rsa_rho")]:
        for roi in ROI_ORDER:
            vals = df.loc[df["roi"]==roi, col].values
            test, p, pnorm = one_sample_vs_zero(vals)
            rows.append({
                "model": model_name,
                "roi": roi,
                "metric": metric,
                "n_subjects": int(np.isfinite(vals).sum()),
                "normality_p": pnorm,
                "p_vs_zero": p,
            })

res_vs0 = pd.DataFrame(rows)
# FDR separately for each metric
for metric in ["R²","RSA ρ"]:
    mask = (res_vs0["metric"]==metric)
    q, _ = fdr_bh(res_vs0.loc[mask,"p_vs_zero"].values, alpha=0.05)
    res_vs0.loc[mask,"q_vs_zero"] = q

res_vs0_path = COMPARE_BASE / "across_subjects_vs_zero.csv"
res_vs0.to_csv(res_vs0_path, index=False)
print("Saved across-subjects (vs 0):", res_vs0_path)
display(res_vs0)

# ===========================
# 3) ROI effect within each model:
#    rm-ANOVA if normal across ROIs, else Friedman.
#    Then post-hoc paired tests (t or Wilcoxon) + BH-FDR.
# ===========================
def roi_matrix(df, value_col):
    """Return subj×roi matrix in ROI_ORDER; drop rows with any NaN to keep paired design."""
    tmp = df.pivot_table(index="subject", columns="roi", values=value_col)
    tmp = tmp.reindex(columns=ROI_ORDER)
    tmp = tmp.dropna(axis=0, how="any")
    return tmp

rows_global, rows_posthoc = [], []

for model_name, df in [("CLIP", df_c), ("ResNet", df_r)]:
    for metric, col in [("R²","r2_median"), ("RSA ρ","rsa_rho")]:
        mat = roi_matrix(df, col)  # subj×3
        if mat.shape[0] < 2:
            rows_global.append({"model":model_name,"metric":metric,"test":"NA","p_global":np.nan,"n_subjects":mat.shape[0]})
            continue

        # Normality check per ROI
        normals = []
        for roi in ROI_ORDER:
            is_norm, pnorm = shapiro_normality(mat[roi].values)
            normals.append(is_norm)

        if all(normals):
            # Parametric path would be rm-ANOVA; to avoid extra deps, use Friedman as robust default.
            # If you prefer rm-ANOVA, plug statsmodels AnovaRM here.
            stat, p_global = friedmanchisquare(mat["EBA"], mat["FFA"], mat["PPA"])
            test_used = "Friedman (used as robust default)"
        else:
            stat, p_global = friedmanchisquare(mat["EBA"], mat["FFA"], mat["PPA"])
            test_used = "Friedman"

        rows_global.append({
            "model": model_name,
            "metric": metric,
            "test": test_used,
            "p_global": p_global,
            "n_subjects": mat.shape[0]
        })

        # Post-hoc pairwise with assumptions + FDR later
        for (a,b) in combinations(ROI_ORDER, 2):
            diffs = mat[a].values - mat[b].values
            norm_ab, _ = shapiro_normality(diffs)
            if norm_ab:
                t, p = ttest_rel(mat[a].values, mat[b].values)
                test = "paired t-test"
                eff = cohens_dz(diffs)
            else:
                w = wilcoxon(diffs)
                p = w.pvalue
                test = "Wilcoxon"
                eff = rank_biserial_effect(diffs)
            rows_posthoc.append({
                "model": model_name, "metric": metric, "pair": f"{a} vs {b}",
                "test": test, "p_pair": p, "effect": eff, "n_subjects": mat.shape[0]
            })

res_global = pd.DataFrame(rows_global)
res_posthoc = pd.DataFrame(rows_posthoc)

# FDR for post-hoc within each (model, metric)
res_posthoc["q_pair"] = np.nan
for (model_name, metric), g in res_posthoc.groupby(["model","metric"]):
    q, _ = fdr_bh(g["p_pair"].values, alpha=0.05)
    res_posthoc.loc[g.index, "q_pair"] = q

# Save
res_global_path  = COMPARE_BASE / "roi_global_tests.csv"
res_posthoc_path = COMPARE_BASE / "roi_posthoc_tests.csv"
res_global.to_csv(res_global_path, index=False)
res_posthoc.to_csv(res_posthoc_path, index=False)
print("Saved ROI global/post-hoc tables:", res_global_path, "and", res_posthoc_path)
display(res_global.head())
display(res_posthoc.head())

# ===========================
# 4) Heatmaps of CLIP − ResNet deltas per subject × ROI (kept from your code)
# ===========================
def save_heatmap(pivot_df, title, out_png):
    plt.figure(figsize=(max(6, 0.6*pivot_df.shape[1]), max(3, 0.5*pivot_df.shape[0])))
    plt.imshow(pivot_df.values, aspect="auto")
    plt.colorbar(label="CLIP − ResNet")
    plt.xticks(range(pivot_df.shape[1]), pivot_df.columns, rotation=45, ha="right")
    plt.yticks(range(pivot_df.shape[0]), pivot_df.index)
    plt.title(title); plt.tight_layout()
    plt.savefig(out_png, dpi=200, bbox_inches="tight"); plt.close()

pivot_r2  = m.pivot_table(index="subject", columns="roi", values="delta_r2",  aggfunc="median").sort_index()
pivot_rho = m.pivot_table(index="subject", columns="roi", values="delta_rho", aggfunc="median").sort_index()
save_heatmap(pivot_r2,  "ΔR² (CLIP − ResNet) per Subject × ROI", PLOTS_DIR / "heatmap_delta_r2.png")
save_heatmap(pivot_rho, "ΔRSA ρ (CLIP − ResNet) per Subject × ROI", PLOTS_DIR / "heatmap_delta_rho.png")
print("Saved plots ->", PLOTS_DIR)

# ===========================
# 5) Bar plots with bootstrap CI + FDR stars (kept, wired to new q-values)
# ===========================
def bar_with_ci_and_stars(delta_col, q_series, title, out_png):
    rois = ROI_ORDER
    meds, lo, hi, stars = [], [], [], []
    for r in rois:
        vals = m.loc[m["roi"]==r, delta_col].values
        med, (l,h) = bootstrap_ci(vals, stat_fn=np.median, n_boot=5000)
        meds.append(med); lo.append(l); hi.append(h)
        # star if q<.05 for this ROI
        q = q_series.loc[q_series["roi"]==r, "q"].values
        stars.append(("*" if (len(q)>0 and np.isfinite(q[0]) and q[0]<0.05) else ""))

    meds, lo, hi = np.array(meds), np.array(lo), np.array(hi)
    err_low = meds - lo; err_high = hi - meds

    plt.figure(figsize=(6, 4))
    xs = np.arange(len(rois))
    plt.bar(xs, meds, yerr=[err_low, err_high], capsize=3)
    plt.xticks(xs, rois, rotation=0)
    for i, s in enumerate(stars):
        if s:
            plt.text(xs[i], meds[i] + (err_high[i] if np.isfinite(err_high[i]) else 0) + 0.005,
                     s, ha="center", va="bottom", fontsize=12)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200, bbox_inches="tight"); plt.close()

# Build small series for q-values per ROI from res_model
q_r2  = res_model[["roi","q_model_r2"]].rename(columns={"q_model_r2":"q"})
q_rho = res_model[["roi","q_model_rho"]].rename(columns={"q_model_rho":"q"})

bar_with_ci_and_stars(
    delta_col="delta_r2",
    q_series=q_r2,
    title="ΔR² (CLIP − ResNet) per ROI — medians, 95% CI (★ q<.05)",
    out_png=PLOTS_DIR / "bars_delta_r2_FDR.png"
)

bar_with_ci_and_stars(
    delta_col="delta_rho",
    q_series=q_rho,
    title="ΔRSA ρ (CLIP − ResNet) per ROI — medians, 95% CI (★ q<.05)",
    out_png=PLOTS_DIR / "bars_delta_rho_FDR.png"
)

print("Saved FDR bar plots ->", PLOTS_DIR)

# ===========================
# 6) ROI medians table (unchanged, handy for paper)
# ===========================
roi_meds = (
    pd.concat([
        df_c.assign(model="CLIP")[["roi","subject","r2_median","rsa_rho","model"]],
        df_r.assign(model="ResNet")[["roi","subject","r2_median","rsa_rho","model"]],
    ])
    .groupby(["model","roi"])
    .agg(r2_median=("r2_median","median"), rsa_rho=("rsa_rho","median"))
    .reset_index()
    .sort_values(["roi","model"])
)
roi_meds.to_csv(COMPARE_BASE / "roi_medians_by_model.csv", index=False)
display(roi_meds)

# ===========================
# 7) Tiny textual summary hints (fill with your actual numbers)
# ===========================
# Example of how you might phrase results in the README once you inspect the tables:
# - "R² was significantly > 0 in all ROIs for both models (FDR-corrected)."
# - "Global ROI effect significant (Friedman) for RSA in CLIP; post-hoc EBA > PPA (q<.05)."
# - "Model differences (CLIP−ResNet) not significant after FDR in R²; RSA trended higher for CLIP in EBA."


Saved model compare table: /content/drive/MyDrive/algonauts_outputs/group_compare/model_compare_clip_vs_resnet_with_assumptions.csv


  for roi, g in m.groupby("roi"):


Unnamed: 0,roi,n_subjects,delta_r2_mean,delta_r2_median,model_test_r2,normality_p_r2,p_model_r2,effect_r2,prop_subj_better_r2,delta_rho_mean,delta_rho_median,model_test_rho,normality_p_rho,p_model_rho,effect_rho,prop_subj_better_rho,q_model_r2,q_model_rho
0,EBA,8,0.025963,0.026325,paired t-test,0.204213,2.6e-05,3.43852,1.0,-0.02651,-0.024851,paired t-test,0.372158,0.000371,-2.258746,0.0,3.9e-05,0.000557
1,FFA,8,0.023465,0.023377,paired t-test,0.661648,5.2e-05,3.083424,1.0,-0.017151,-0.019209,paired t-test,0.634284,0.012093,-1.187759,0.125,5.2e-05,0.012093
2,PPA,8,0.020092,0.019619,paired t-test,0.657775,1.5e-05,3.719466,1.0,-0.053422,-0.051636,paired t-test,0.623882,4e-06,-4.50479,0.0,3.9e-05,1.3e-05


Saved across-subjects (vs 0): /content/drive/MyDrive/algonauts_outputs/group_compare/across_subjects_vs_zero.csv


Unnamed: 0,model,roi,metric,n_subjects,normality_p,p_vs_zero,q_vs_zero
0,CLIP,EBA,R²,8,0.210432,5.161135e-05,8.559441e-05
1,CLIP,FFA,R²,8,0.689184,2.686719e-05,8.559441e-05
2,CLIP,PPA,R²,8,0.135134,7.132867e-05,8.559441e-05
3,CLIP,EBA,RSA ρ,8,0.319583,7.714582e-07,9.257498e-07
4,CLIP,FFA,RSA ρ,8,0.873653,4.398067e-07,8.280243e-07
5,CLIP,PPA,RSA ρ,8,0.211141,2.15479e-07,8.280243e-07
6,ResNet,EBA,R²,8,0.286877,6.182052e-05,8.559441e-05
7,ResNet,FFA,R²,8,0.712626,3.753095e-05,8.559441e-05
8,ResNet,PPA,R²,8,0.085507,9.147389e-05,9.147389e-05
9,ResNet,EBA,RSA ρ,8,0.213002,5.520162e-07,8.280243e-07


Saved ROI global/post-hoc tables: /content/drive/MyDrive/algonauts_outputs/group_compare/roi_global_tests.csv and /content/drive/MyDrive/algonauts_outputs/group_compare/roi_posthoc_tests.csv


Unnamed: 0,model,metric,test,p_global,n_subjects
0,CLIP,R²,Friedman (used as robust default),0.020754,8
1,CLIP,RSA ρ,Friedman (used as robust default),0.011109,8
2,ResNet,R²,Friedman (used as robust default),0.020754,8
3,ResNet,RSA ρ,Friedman (used as robust default),0.07244,8


Unnamed: 0,model,metric,pair,test,p_pair,effect,n_subjects,q_pair
0,CLIP,R²,EBA vs FFA,paired t-test,0.19429,0.507515,8,0.19429
1,CLIP,R²,EBA vs PPA,paired t-test,0.056763,-0.805567,8,0.085144
2,CLIP,R²,FFA vs PPA,paired t-test,0.034656,-0.924574,8,0.085144
3,CLIP,RSA ρ,EBA vs FFA,paired t-test,0.040507,0.886762,8,0.040507
4,CLIP,RSA ρ,EBA vs PPA,paired t-test,0.002501,1.624384,8,0.007502


  pivot_r2  = m.pivot_table(index="subject", columns="roi", values="delta_r2",  aggfunc="median").sort_index()
  pivot_rho = m.pivot_table(index="subject", columns="roi", values="delta_rho", aggfunc="median").sort_index()


Saved plots -> /content/drive/MyDrive/algonauts_outputs/group_compare/plots
Saved FDR bar plots -> /content/drive/MyDrive/algonauts_outputs/group_compare/plots


Unnamed: 0,model,roi,r2_median,rsa_rho
0,CLIP,EBA,0.202472,0.198516
3,ResNet,EBA,0.178431,0.232025
1,CLIP,FFA,0.184821,0.175794
4,ResNet,FFA,0.169054,0.191782
2,CLIP,PPA,0.217915,0.144652
5,ResNet,PPA,0.200271,0.194571


In [9]:
# ==========================================
# Key result plots for Algonauts model comparison
# ==========================================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# -----------------------
# Paths and constants
# -----------------------
BASE = Path("/content/drive/MyDrive/algonauts_outputs")
CLIP_CSV   = BASE / "group_clip"   / "group_summary_day3.csv"
RESNET_CSV = BASE / "group_resnet" / "group_summary_day3.csv"
RES_MODEL  = BASE / "group_compare" / "model_compare_clip_vs_resnet_with_assumptions.csv"

OUT_DIR = BASE / "group_compare" / "plots_key"
OUT_DIR.mkdir(parents=True, exist_ok=True)

ROI_ORDER = ["EBA", "FFA", "PPA"]  # consistent ordering

# -----------------------
# Load and prepare data
# -----------------------
df_c = pd.read_csv(CLIP_CSV)
df_r = pd.read_csv(RESNET_CSV)
res_model = pd.read_csv(RES_MODEL)  # has q_model_r2, q_model_rho

# restrict to common subjects/ROIs
key  = ["subject","roi"]
cols = ["r2_median","rsa_rho"]
m = (
    df_c[key+cols]
    .merge(df_r[key+cols], on=key, suffixes=("_clip","_resnet"))
    .copy()
)
m = m[m["roi"].isin(ROI_ORDER)].copy()
m["roi"] = pd.Categorical(m["roi"], categories=ROI_ORDER, ordered=True)
m = m.sort_values(["roi","subject"])

# per-subject deltas
m["delta_r2"]  = m["r2_median_clip"] - m["r2_median_resnet"]
m["delta_rho"] = m["rsa_rho_clip"]   - m["rsa_rho_resnet"]

# ROI medians per model (for the summary bars)
roi_meds = (
    pd.concat([
        df_c.assign(model="CLIP")[["roi","subject","r2_median","rsa_rho","model"]],
        df_r.assign(model="ResNet")[["roi","subject","r2_median","rsa_rho","model"]],
    ])
    .groupby(["model","roi"])
    .agg(r2_median=("r2_median","median"), rsa_rho=("rsa_rho","median"))
    .reset_index()
)
roi_meds = roi_meds[roi_meds["roi"].isin(ROI_ORDER)].copy()
roi_meds["roi"] = pd.Categorical(roi_meds["roi"], categories=ROI_ORDER, ordered=True)
roi_meds = roi_meds.sort_values(["roi","model"])

# small helper: turn q-value into stars
def stars_from_q(q):
    if not np.isfinite(q): return ""
    if q < 0.001: return "***"
    if q < 0.01:  return "**"
    if q < 0.05:  return "*"
    return ""

# -----------------------
# Fig 1: ROI medians with CLIP vs ResNet + FDR stars (model difference within ROI)
# -----------------------
def plot_roi_medians_with_stars(roi_meds, res_model, out_png):
    rois = ROI_ORDER
    # gather bar heights
    clip_r2 = []; resnet_r2 = []
    clip_rho = []; resnet_rho = []
    for r in rois:
        clip_row   = roi_meds[(roi_meds["model"]=="CLIP") & (roi_meds["roi"]==r)]
        resnet_row = roi_meds[(roi_meds["model"]=="ResNet") & (roi_meds["roi"]==r)]
        clip_r2.append(  clip_row["r2_median"].iloc[0]  )
        resnet_r2.append(resnet_row["r2_median"].iloc[0])
        clip_rho.append( clip_row["rsa_rho"].iloc[0]   )
        resnet_rho.append(resnet_row["rsa_rho"].iloc[0])

    # grab q-values for model difference within each ROI
    q_r2  = [ float(res_model.loc[res_model["roi"]==r, "q_model_r2"].values[0])  for r in rois ]
    q_rho = [ float(res_model.loc[res_model["roi"]==r, "q_model_rho"].values[0]) for r in rois ]
    star_r2  = [stars_from_q(q) for q in q_r2]
    star_rho = [stars_from_q(q) for q in q_rho]

    xs = np.arange(len(rois))
    w = 0.35

    fig, axes = plt.subplots(2, 1, figsize=(7, 6), sharex=True)

    # --- Top: R^2 ---
    ax = axes[0]
    ax.bar(xs - w/2, clip_r2, width=w, label="CLIP")
    ax.bar(xs + w/2, resnet_r2, width=w, label="ResNet")
    ax.set_ylabel("R² (median across subjects)")
    ax.set_title("Encoding performance by ROI")
    ax.axhline(0, linewidth=1, color="k")
    # add stars above the higher bar per pair
    for i, s in enumerate(star_r2):
        if s:
            y_top = max(clip_r2[i], resnet_r2[i])
            ax.text(xs[i], y_top + 0.005, s, ha="center", va="bottom", fontsize=12)
    ax.legend(loc="upper right")

    # --- Bottom: RSA rho ---
    ax = axes[1]
    ax.bar(xs - w/2, clip_rho, width=w, label="CLIP")
    ax.bar(xs + w/2, resnet_rho, width=w, label="ResNet")
    ax.set_ylabel("RSA ρ (median across subjects)")
    ax.set_title("Representational similarity by ROI")
    ax.axhline(0, linewidth=1, color="k")
    # stars
    for i, s in enumerate(star_rho):
        if s:
            y_top = max(clip_rho[i], resnet_rho[i])
            ax.text(xs[i], y_top + 0.005, s, ha="center", va="bottom", fontsize=12)

    ax.set_xticks(xs); ax.set_xticklabels(rois)
    fig.tight_layout()
    fig.savefig(out_png, dpi=200, bbox_inches="tight")
    plt.close(fig)

plot_roi_medians_with_stars(
    roi_meds=roi_meds,
    res_model=res_model,
    out_png=OUT_DIR / "fig1_roi_medians_clip_vs_resnet.png"
)

# -----------------------
# Fig 2: Δ (CLIP − ResNet) bars with 95% bootstrap CI + FDR stars
# -----------------------
def bootstrap_ci(x, n_boot=5000, ci=95, seed=123):
    x = np.asarray(x, float); x = x[np.isfinite(x)]
    if x.size == 0: return np.nan, (np.nan, np.nan)
    rng = np.random.default_rng(seed)
    meds = []
    for _ in range(n_boot):
        bs = rng.choice(x, size=x.size, replace=True)
        meds.append(np.median(bs))
    low = np.percentile(meds, (100-ci)/2)
    high = np.percentile(meds, 100 - (100-ci)/2)
    return float(np.median(x)), (float(low), float(high))

def plot_delta_bars(m, res_model, out_png):
    rois = ROI_ORDER
    meds_r2, lo_r2, hi_r2, stars_r2 = [], [], [], []
    meds_rho, lo_rho, hi_rho, stars_rho = [], [], [], []

    for r in rois:
        vals_r2  = m.loc[m["roi"]==r, "delta_r2"].values
        vals_rho = m.loc[m["roi"]==r, "delta_rho"].values
        med, (l,h) = bootstrap_ci(vals_r2);  meds_r2.append(med);  lo_r2.append(l);  hi_r2.append(h)
        med, (l,h) = bootstrap_ci(vals_rho); meds_rho.append(med); lo_rho.append(l); hi_rho.append(h)

        q_r2  = float(res_model.loc[res_model["roi"]==r, "q_model_r2"].values[0])
        q_rho = float(res_model.loc[res_model["roi"]==r, "q_model_rho"].values[0])
        stars_r2.append(stars_from_q(q_r2))
        stars_rho.append(stars_from_q(q_rho))

    meds_r2, lo_r2, hi_r2 = np.array(meds_r2), np.array(lo_r2), np.array(hi_r2)
    meds_rho, lo_rho, hi_rho = np.array(meds_rho), np.array(lo_rho), np.array(hi_rho)
    err_low_r2  = meds_r2  - lo_r2
    err_high_r2 = hi_r2    - meds_r2
    err_low_rho = meds_rho - lo_rho
    err_high_rho= hi_rho   - meds_rho

    fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=False)

    # ΔR²
    ax = axes[0]
    xs = np.arange(len(rois))
    ax.bar(xs, meds_r2, yerr=[err_low_r2, err_high_r2], capsize=3)
    ax.axhline(0, color="k", linewidth=1)
    ax.set_xticks(xs); ax.set_xticklabels(rois)
    ax.set_title("ΔR² (CLIP − ResNet)")
    ax.set_ylabel("Median ΔR² (with 95% CI)")
    for i, s in enumerate(stars_r2):
        if s:
            ax.text(xs[i], meds_r2[i] + (err_high_r2[i] if np.isfinite(err_high_r2[i]) else 0) + 0.005,
                    s, ha="center", va="bottom", fontsize=12)

    # ΔRSA ρ
    ax = axes[1]
    ax.bar(xs, meds_rho, yerr=[err_low_rho, err_high_rho], capsize=3)
    ax.axhline(0, color="k", linewidth=1)
    ax.set_xticks(xs); ax.set_xticklabels(rois)
    ax.set_title("ΔRSA ρ (CLIP − ResNet)")
    ax.set_ylabel("Median Δρ (with 95% CI)")
    for i, s in enumerate(stars_rho):
        if s:
            ax.text(xs[i], meds_rho[i] + (err_high_rho[i] if np.isfinite(err_high_rho[i]) else 0) + 0.005,
                    s, ha="center", va="bottom", fontsize=12)

    fig.tight_layout()
    fig.savefig(out_png, dpi=200, bbox_inches="tight")
    plt.close(fig)

plot_delta_bars(m, res_model, OUT_DIR / "fig2_deltas_with_ci_and_stars.png")

# -----------------------
# Fig 3: Paired scatter vs identity line (per subject)
# -----------------------
def plot_paired_scatter(m, out_png):
    # One panel for R², one for RSA; color/marker by ROI
    rois = ROI_ORDER
    markers = {"EBA":"o", "FFA":"s", "PPA":"^"}
    fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)

    # R²
    ax = axes[0]
    # set joint limits based on data
    r2_all = np.concatenate([m["r2_median_clip"].values, m["r2_median_resnet"].values])
    lo, hi = float(np.nanmin(r2_all)), float(np.nanmax(r2_all))
    pad = (hi - lo) * 0.05
    ax.plot([lo-pad, hi+pad], [lo-pad, hi+pad], linewidth=1)  # identity
    for r in rois:
        g = m[m["roi"]==r]
        ax.scatter(g["r2_median_resnet"], g["r2_median_clip"], marker=markers[r], label=r)
    ax.set_xlabel("ResNet R²")
    ax.set_ylabel("CLIP R²")
    ax.set_title("Paired subject points (R²)")
    ax.legend(frameon=False, ncol=3)

    # RSA
    ax = axes[1]
    rho_all = np.concatenate([m["rsa_rho_clip"].values, m["rsa_rho_resnet"].values])
    lo, hi = float(np.nanmin(rho_all)), float(np.nanmax(rho_all))
    pad = (hi - lo) * 0.05
    ax.plot([lo-pad, hi+pad], [lo-pad, hi+pad], linewidth=1)  # identity
    for r in rois:
        g = m[m["roi"]==r]
        ax.scatter(g["rsa_rho_resnet"], g["rsa_rho_clip"], marker=markers[r], label=r)
    ax.set_xlabel("ResNet RSA ρ")
    ax.set_ylabel("CLIP RSA ρ")
    ax.set_title("Paired subject points (RSA)")
    # legend already shown

    fig.tight_layout()
    fig.savefig(out_png, dpi=200, bbox_inches="tight")
    plt.close(fig)

plot_paired_scatter(m, OUT_DIR / "fig3_paired_scatter_identity.png")

# -----------------------
# Fig 4: Heatmaps of per-subject deltas
# -----------------------
def save_heatmap(pivot_df, title, cbar_label, out_png):
    plt.figure(figsize=(6, max(3, 0.4*pivot_df.shape[0])))
    im = plt.imshow(pivot_df.values, aspect="auto")
    cbar = plt.colorbar(im); cbar.set_label(cbar_label)
    plt.xticks(range(pivot_df.shape[1]), pivot_df.columns)
    plt.yticks(range(pivot_df.shape[0]), pivot_df.index)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200, bbox_inches="tight")
    plt.close()

pivot_r2  = m.pivot_table(
    index="subject", columns="roi", values="delta_r2",
    aggfunc="median", observed=False   # explicitly set
).sort_index()

pivot_rho = m.pivot_table(
    index="subject", columns="roi", values="delta_rho",
    aggfunc="median", observed=False   # explicitly set
).sort_index()

save_heatmap(pivot_r2,  "ΔR² (CLIP − ResNet) per Subject × ROI", "ΔR²", OUT_DIR / "fig4_heatmap_delta_r2.png")
save_heatmap(pivot_rho, "ΔRSA ρ (CLIP − ResNet) per Subject × ROI", "ΔRSA ρ", OUT_DIR / "fig4_heatmap_delta_rho.png")

print("Saved all key plots to:", OUT_DIR)


Saved all key plots to: /content/drive/MyDrive/algonauts_outputs/group_compare/plots_key
