In [1]:
# Mount Google Drive at /content/drive so files under it can be read and written.
# force_remount=True asks Colab to remount even when the drive is already mounted,
# which keeps this cell safe to re-run.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
#  Aggregate across ROIs & subjects (Algonauts)
# ----------------------------------------------------
# Scans the encoding/RSA outputs, builds a per-subject/ROI summary table, and saves plots.

import os, json
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ---------- PATHS ----------
# RSA_BASE   : input root, scanned as <RSA_BASE>/<roi>/<subject>/ for the summary JSON files.
# GROUP_BASE : output root for the group-level CSV table.
# PLOTS_DIR  : output folder for every figure saved by this cell.
RSA_BASE = Path("/content/drive/MyDrive/algonauts_outputs/encoding_rsa_random_clip")
GROUP_BASE = Path("/content/drive/MyDrive/algonauts_outputs/group_clip")
PLOTS_DIR = GROUP_BASE / "plots"
# Only the output folders are created here; RSA_BASE is expected to exist already.
GROUP_BASE.mkdir(parents=True, exist_ok=True)
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# ---------- DEBUG ----------
# When DEBUG is True, only the subjects/ROIs listed below are kept and the
# bar plots use fewer bootstrap resamples by default.
DEBUG = False          # set True to run fast while testing
DEBUG_SUBJECTS = ["subj01"]
DEBUG_ROIS     = ["FFA"]

# ------ ORDER ------
# ROIs named here come first, in this order, wherever ROIs are listed (default of
# ordered_rois); any ROI not listed follows in alphabetical order.
ROI_ORDER = ["EBA","FFA","PPA"]

# ----- HELPERS --------
def load_json(p: Path):
    """Best-effort read of a JSON file.

    Parameters
    ----------
    p : Path
        Location of the JSON file.

    Returns
    -------
    The decoded JSON payload, or ``{}`` when the file is missing, unreadable,
    or not valid JSON.

    Notes
    -----
    A missing file is treated as a normal "not computed yet" case and stays
    silent. Any other failure still yields ``{}`` (so callers keep working)
    but emits a ``RuntimeWarning`` naming the file, so a corrupt summary is
    no longer indistinguishable from an absent one.
    """
    import warnings  # local import keeps the notebook's import block untouched

    try:
        # Open directly instead of exists()+open(): avoids the check/use race.
        with open(p, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses;
        # OSError covers permission problems, directories, flaky mounts, etc.
        warnings.warn(
            f"Could not load JSON from {p}: {exc!r}; treating it as empty.",
            RuntimeWarning,
            stacklevel=2,
        )
        return {}

def discover_rois_and_subjects(rsa_root: Path):
    """List the ROI and subject folder names found under ``rsa_root``.

    The expected layout is ``<rsa_root>/<roi>/<subject>/``. Only directories
    are considered at both levels; plain files are ignored.

    Parameters
    ----------
    rsa_root : Path
        Root folder whose immediate sub-directories are ROIs.

    Returns
    -------
    (rois, subjects) : tuple[list[str], list[str]]
        Alphabetically sorted ROI names, and the sorted union of subject
        folder names seen under any ROI. Both lists are empty when
        ``rsa_root`` does not exist or is not a directory, so the caller can
        report "no results" itself instead of hitting ``FileNotFoundError``
        from ``iterdir()``.
    """
    if not rsa_root.is_dir():
        return [], []

    roi_dirs = [d for d in rsa_root.iterdir() if d.is_dir()]
    # A subject counts once even if it appears under several ROIs.
    subjects = {s.name for r in roi_dirs for s in r.iterdir() if s.is_dir()}
    return sorted(d.name for d in roi_dirs), sorted(subjects)

def bootstrap_ci(x, stat_fn=np.median, n_boot=5000, ci=95, seed=42):
    """Point estimate of ``stat_fn`` plus a percentile-bootstrap confidence interval.

    Parameters
    ----------
    x : array-like
        Sample values. NaN entries are dropped; ``None`` entries are treated
        as NaN (the input is coerced to float first).
    stat_fn : callable
        Statistic applied to the sample and to every resample.
    n_boot : int
        Number of bootstrap resamples (each the same size as the cleaned sample).
    ci : float
        Confidence level in percent; the interval spans the
        ``(100-ci)/2`` and ``100-(100-ci)/2`` percentiles of the resampled statistics.
    seed : int
        Seed for ``np.random.default_rng`` so results are reproducible.

    Returns
    -------
    (stat, (low, high))
        ``stat_fn`` of the cleaned sample and the interval bounds.
        ``(nan, (nan, nan))`` when no non-NaN values remain.
    """
    # dtype=float turns None / object-dtype entries into NaN; without it
    # np.isnan raises TypeError on object arrays (e.g. an all-null column).
    x = np.asarray(x, dtype=float)
    x = x[~np.isnan(x)]
    if len(x) == 0:
        return np.nan, (np.nan, np.nan)

    rng = np.random.default_rng(seed)
    # One rng.choice call per resample, in order, keeps the draw sequence stable.
    stats = [stat_fn(rng.choice(x, size=len(x), replace=True)) for _ in range(n_boot)]

    tail = (100 - ci) / 2
    low = np.percentile(stats, tail)
    high = np.percentile(stats, 100 - tail)
    return stat_fn(x), (low, high)

def ordered_rois(cols, preferred=ROI_ORDER):
    """Return ROI names with the preferred ones first.

    Names from ``preferred`` that occur in ``cols`` come first, in the order
    given by ``preferred``; every other name from ``cols`` follows, sorted
    alphabetically.
    """
    available = list(cols)

    head = []
    for name in preferred:
        if name in available:
            head.append(name)

    tail = [name for name in available if name not in head]
    tail.sort()
    return head + tail

# ------ COLLECT -------
# Visit every ROI x subject folder and keep one summary row for each pair that
# has both an encoding summary and an RSA summary.
rois_found, subs_found = discover_rois_and_subjects(RSA_BASE)
if DEBUG:
    # Fast path while testing: restrict the scan to the debug subset.
    rois_found = [r for r in rois_found if r in DEBUG_ROIS]
    subs_found = [s for s in subs_found if s in DEBUG_SUBJECTS]

all_rows = []
for roi in rois_found:
    for subj in subs_found:
        sdir = RSA_BASE / roi / subj
        if sdir.exists():
            enc = load_json(sdir / "encoding_summary.json")
            rsa = load_json(sdir / "rsa_summary.json")
            # Pairs with an empty/missing summary on either side are skipped.
            if enc and rsa:
                all_rows.append({
                    "subject": subj,
                    "roi": roi,
                    # Prefer the encoding summary's model name, then the RSA one.
                    "feature_model": enc.get("feature_model", rsa.get("feature_model", "NA")),
                    # Encoding metrics are copied under their own key names.
                    **{
                        key: enc.get(key, np.nan)
                        for key in ("n_images", "n_voxels", "r2_mean", "r2_median", "r2_top10_mean")
                    },
                    "rsa_rho": rsa.get("rho_spearman", np.nan),
                    "rsa_p": rsa.get("p_value_perm", np.nan),
                })

df = pd.DataFrame(all_rows)
if df.empty:
    raise RuntimeError("No RSA results found. Check RSA_BASE path.")

# Write the long-format summary table next to the group outputs.
group_csv = GROUP_BASE / "group_summary_day3.csv"
df.to_csv(group_csv, index=False)
print("Saved table ->", group_csv)

# ---- HEATMAPS ----
def pivot_metric(df, metric):
    """Reshape the long summary table into a subject × ROI grid of ``metric``.

    Rows are subjects (sorted alphabetically), columns are ROIs (ordered by
    ``ordered_rois``); duplicate subject/ROI entries are combined by median.
    """
    wide = df.pivot_table(values=metric, index="subject", columns="roi", aggfunc="median")
    roi_columns = ordered_rois(wide.columns)
    return wide.reindex(columns=roi_columns).sort_index()

def save_heatmap(array_df, title, out_png):
    """Render a DataFrame as a heatmap image and write it to ``out_png``.

    Columns label the x axis, the index labels the y axis, and the figure
    grows with the table's shape. The figure is closed after saving.
    """
    n_rows, n_cols = array_df.shape
    fig, ax = plt.subplots(figsize=(max(6, 0.6 * n_cols), max(3, 0.5 * n_rows)))

    image = ax.imshow(array_df.values, aspect="auto")
    fig.colorbar(image, ax=ax)

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(array_df.columns, rotation=45, ha="right")
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(array_df.index)
    ax.set_title(title)

    fig.tight_layout()
    fig.savefig(out_png, dpi=200, bbox_inches="tight")
    plt.close(fig)

# One subject × ROI table per metric, each rendered as its own heatmap.
hm_r2 = pivot_metric(df, metric="r2_median")
hm_rho = pivot_metric(df, metric="rsa_rho")
for _table, _title, _png_name in (
    (hm_r2, "R² (median) per Subject × ROI", "heatmap_r2_median.png"),
    (hm_rho, "RSA ρ per Subject × ROI", "heatmap_rsa_rho.png"),
):
    save_heatmap(_table, _title, PLOTS_DIR / _png_name)
print("Saved heatmaps ->", PLOTS_DIR)

# ---------- BAR PLOTS WITH BOOTSTRAP CI ----------
def roi_bar_with_ci(df, metric, title, out_png, n_boot=5000 if not DEBUG else 1000):
    """Bar chart of the per-ROI median of ``metric`` with bootstrap error bars.

    For every ROI (ordered by ``ordered_rois``) the values of ``metric`` across
    rows are summarised with ``bootstrap_ci`` using the median; the bar height
    is the median and the asymmetric error bars span the returned interval.
    The figure is written to ``out_png`` and then closed.

    The default for ``n_boot`` is fixed when this function is defined, based on
    the value of ``DEBUG`` at that moment.
    """
    roi_names = ordered_rois(sorted(df["roi"].unique()))

    summaries = []
    for name in roi_names:
        roi_values = df.loc[df["roi"] == name, metric].values
        center, (ci_low, ci_high) = bootstrap_ci(roi_values, stat_fn=np.median, n_boot=n_boot)
        summaries.append((center, ci_low, ci_high))

    meds = np.array([entry[0] for entry in summaries])
    lows = np.array([entry[1] for entry in summaries])
    highs = np.array([entry[2] for entry in summaries])
    # Error bars are distances from the bar top to the interval bounds.
    err_low = meds - lows
    err_high = highs - meds

    fig, ax = plt.subplots(figsize=(max(6, 0.6 * len(roi_names)), 4))
    xs = np.arange(len(roi_names))
    ax.bar(xs, meds, yerr=[err_low, err_high], capsize=3)
    ax.set_xticks(xs)
    ax.set_xticklabels(roi_names, rotation=45, ha="right")
    ax.set_title(title)
    ax.set_ylabel(metric)
    fig.tight_layout()
    fig.savefig(out_png, dpi=200, bbox_inches="tight")
    plt.close(fig)

# Per-ROI bar charts (median across subjects with bootstrap interval), one per metric.
for _metric, _title, _png_name in (
    ("r2_median", "ROI median R² (bootstrap 95% CI)", "bars_roi_median_r2.png"),
    ("rsa_rho", "ROI median RSA ρ (bootstrap 95% CI)", "bars_roi_median_rsa.png"),
):
    roi_bar_with_ci(df, _metric, _title, PLOTS_DIR / _png_name)

# ----- SCATTER: ROI-wise median R² vs median RSA ρ -------
# One point per ROI: median across subjects of each metric, rows in display order.
agg = (
    df.groupby("roi")
    .agg(
        r2_median_roi=("r2_median", "median"),
        rsa_rho_roi=("rsa_rho", "median"),
    )
    .pipe(lambda per_roi: per_roi.reindex(ordered_rois(per_roi.index.tolist())))
    .reset_index()
)

scatter_fig, scatter_ax = plt.subplots(figsize=(6, 5))
scatter_ax.scatter(agg["r2_median_roi"].values, agg["rsa_rho_roi"].values, s=40)
# Label every point with its ROI name, nudged slightly off the marker.
for _roi, _x, _y in zip(agg["roi"], agg["r2_median_roi"], agg["rsa_rho_roi"]):
    scatter_ax.annotate(_roi, (_x, _y), xytext=(3, 3), textcoords="offset points")
scatter_ax.set_xlabel("ROI median R²")
scatter_ax.set_ylabel("ROI median RSA ρ")
scatter_ax.set_title("ROI-wise alignment: encoding vs RSA")
scatter_fig.tight_layout()
scatter_fig.savefig(PLOTS_DIR / "scatter_roi_r2_vs_rsa.png", dpi=200, bbox_inches="tight")
plt.close(scatter_fig)

print("Saved Group figures ->", PLOTS_DIR)
print("Done.")


Saved table -> /content/drive/MyDrive/algonauts_outputs/group_clip/group_summary_day3.csv
Saved heatmaps -> /content/drive/MyDrive/algonauts_outputs/group_clip/plots
Saved Group figures -> /content/drive/MyDrive/algonauts_outputs/group_clip/plots
Done.
