In [None]:
import os, math, numpy as np, pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from scipy.stats import zscore

# --- Analysis time window (ms) ---------------------------------------------
# The PSTH is binned at BIN_MS; the analysed window begins ONSET_MS into the
# PSTH and spans WINDOW_MS. t_start / t_end are the matching bin indices
# (end-exclusive) used to slice the time axis.
BIN_MS = 1
ONSET_MS = 0
WINDOW_MS = 350
t_start = int(ONSET_MS / BIN_MS)
t_end = t_start + int(WINDOW_MS / BIN_MS)

# --- Global image index layout ---------------------------------------------
# Base images come first; the LOCALIZER_LEN localizer images start at
# LOCALIZER_START. GLOBAL_END is end-exclusive (1072 -> indices 0..1071).
GLOBAL_START = 0
LOCALIZER_START = 1_000
LOCALIZER_LEN = 72
GLOBAL_END = LOCALIZER_START + LOCALIZER_LEN

# --- k-means model selection -----------------------------------------------
RANDOM_STATE = 0
N_INIT = 10
K_RANGE = [k for k in range(2, 9)]  # candidate cluster counts 2..8
EXAMPLES_PER_CLUSTER = 8            # thumbnails drawn per cluster row

# --- ROI selection ---------------------------------------------------------
# Keep only ROIs whose label ends in this suffix ('F', 'O' or 'B');
# any other value (e.g. None) keeps every ROI.
ROI_SUFFIX_FILTER = 'F'

# --- Image directories and filename patterns -------------------------------
# Base set:      global 0..999    -> 0001.bmp..1000.bmp   (file idx = global + 1)
# Localizer set: global 1000..1071 -> MFOB001.bmp..MFOB072.bmp
#                (file idx = global - 1000 + 1)
IMAGE_DIR_BASE = "../../datasets/NNN/NSD1000_LOC"
IMAGE_DIR_LOC = "../../datasets/NNN/NSD1000_LOC"
BASE_PATTERN = "{idx:04d}.bmp"
LOC_PATTERN = "MFOB{idx:03d}.bmp"

# Pickled unit table (a file path, despite the name).
DATA_DIR = '../../datasets/NNN/unit_data_full.pkl'

In [None]:
# Load the unit table; later cells read its 'roi' and 'img_psth' columns
# (presumably one row per recorded unit -- confirm against the dataset).
# NOTE(review): pandas.read_pickle unpickles arbitrary Python objects and can
# execute code on load -- only point DATA_DIR at a trusted file.
dat = pd.read_pickle(DATA_DIR)

In [None]:
def path_for_global_image(global_idx):
    """Return the filesystem path of the image with a given global index.

    Global indices ``[0, LOCALIZER_START)`` map to the base set: ``BASE_PATTERN``
    in ``IMAGE_DIR_BASE`` with file number ``global_idx + 1``. Indices
    ``[LOCALIZER_START, LOCALIZER_START + LOCALIZER_LEN)`` map to the localizer
    set: ``LOC_PATTERN`` in ``IMAGE_DIR_LOC`` with file number
    ``global_idx - LOCALIZER_START + 1``.

    Both ranges are derived from the config constants rather than hard-coded,
    so changing the layout constants keeps this mapping consistent.

    Raises
    ------
    ValueError
        If the index lies outside both ranges.
    """
    gi = int(global_idx)
    if 0 <= gi < LOCALIZER_START:
        # Base files are numbered from 1.
        return os.path.join(IMAGE_DIR_BASE, BASE_PATTERN.format(idx=gi + 1))
    if LOCALIZER_START <= gi < LOCALIZER_START + LOCALIZER_LEN:
        # Localizer files are numbered from 1 relative to LOCALIZER_START.
        local_idx = gi - LOCALIZER_START + 1
        return os.path.join(IMAGE_DIR_LOC, LOC_PATTERN.format(idx=local_idx))
    raise ValueError(f"Global image index out of range: {gi}")

def roi_suffix(roi_label):
    """Classify an ROI label by its last character.

    Returns 'B', 'O' or 'F' when ``roi_label`` is a string ending in that
    upper-case letter. Any other input yields "Unknown": other endings, the
    empty string, None/NaN and non-string values.
    """
    if isinstance(roi_label, str) and roi_label.endswith(("B", "O", "F")):
        return roi_label[-1]
    return "Unknown"

def compute_image_time_traces_for_roi_all(dat_roi, t_start=None, t_end=None,
                                          g_start=GLOBAL_START, g_end=GLOBAL_END):
    """Average per-image time courses across the units of one ROI.

    Each row of ``dat_roi`` contributes its ``img_psth`` entry, which must be a
    2-D array indexed ``[time_bin, global_image]`` (per the slicing below;
    presumably PSTHs -- confirm against the dataset). Rows whose ``img_psth``
    is missing or not 2-D are ignored. So are units whose array does not fully
    cover the requested time and image window.

    Parameters
    ----------
    dat_roi : pandas.DataFrame
        Units belonging to one ROI.
    t_start, t_end : int or None
        Time-bin window ``[t_start, t_end)``. ``None`` defaults to 0 and to the
        time length of the first unit with a 2-D ``img_psth``, respectively.
    g_start, g_end : int
        Global image window ``[g_start, g_end)``.

    Returns
    -------
    X_img : numpy.ndarray, shape (g_end - g_start, t_end - t_start)
        Mean over contributing units, one row per image.
    global_img_indices : list of int
        Global image index of each row of ``X_img``.
    count : int
        Number of units averaged.

    Raises
    ------
    ValueError
        If ``dat_roi`` is empty, no unit has a 2-D ``img_psth``, either window
        is empty, or no unit covers the full window.
    """
    if len(dat_roi) == 0:
        raise ValueError("dat_roi has no rows.")

    column = dat_roi.get('img_psth')
    raw_values = [] if column is None else list(column)
    # np.asarray(None) is a 0-d object array, so the ndim test also drops missing entries.
    psths = [A for A in map(np.asarray, raw_values) if A.ndim == 2]
    if not psths:
        raise ValueError("No valid img_psth found.")

    if t_start is None:
        t_start = 0
    if t_end is None:
        t_end = psths[0].shape[0]
    if t_end <= t_start:
        raise ValueError(f"t_end must be > t_start (got {t_start}, {t_end}).")
    if g_end <= g_start:
        raise ValueError(f"g_end must be > g_start (got {g_start}, {g_end}).")

    n_images = int(g_end - g_start)
    T = int(t_end - t_start)
    accum = np.zeros((n_images, T), dtype=float)
    count = 0
    for A in psths:
        B = A[t_start:t_end, g_start:g_end]  # (T, n_images) when the unit covers the window
        # Check both axes. If the time axis were shorter than the window, the +=
        # below would raise a broadcasting error and abort the whole ROI.
        if B.shape != (T, n_images):
            continue
        accum += B.T
        count += 1

    if count == 0:
        raise ValueError("No valid/compatible units in ROI.")
    X_img = accum / float(count)
    global_img_indices = list(range(g_start, g_end))
    return X_img, global_img_indices, count

def choose_k_and_cluster(X, k_range, random_state=0, n_init=10):
    """Run k-means for each candidate k and keep the labeling with the best silhouette.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Samples to cluster.
    k_range : iterable of int
        Candidate cluster counts. Values outside ``2 <= k <= n_samples`` are ignored.
    random_state, n_init :
        Passed through to ``KMeans``.

    Returns
    -------
    labels : numpy.ndarray of int, shape (n_samples,)
        Labels for the selected k.
    best_k : int
        The k with the highest finite silhouette score. If no score is finite,
        the first valid k in ``k_range`` order is used. If no candidate is valid,
        the result is 1 and every sample is labeled 0.
    inertias : dict
        ``{k: KMeans inertia}`` for each k tried (``{1: 0.0}`` in the degenerate case).
    sils : list of (k, score)
        Silhouette score per k tried, NaN where it is undefined.
    """
    n_samples = X.shape[0]
    # KMeans needs k <= n_samples; k == 1 has no silhouette. When n_samples < 2
    # this list is necessarily empty, so it also covers that case.
    valid_ks = [k for k in k_range if 1 < k <= n_samples]
    if not valid_ks:
        return np.zeros(n_samples, dtype=int), 1, {1: 0.0}, [(1, np.nan)]

    inertias, sils, labels_by_k = {}, [], {}
    for k in valid_ks:
        km = KMeans(n_clusters=k, random_state=random_state, n_init=n_init)
        lab = km.fit_predict(X)
        inertias[k] = km.inertia_
        n_labels = len(np.unique(lab))
        # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1 and
        # raises otherwise. For example, at k == n_samples every point is its own
        # cluster; k-means can also return fewer distinct labels than k.
        sil = silhouette_score(X, lab) if 2 <= n_labels < n_samples else np.nan
        sils.append((k, sil))
        labels_by_k[k] = lab

    finite = [(k, s) for k, s in sils if np.isfinite(s)]
    # max() keeps the first k on ties, i.e. the earliest in k_range order.
    best_k = max(finite, key=lambda kv: kv[1])[0] if finite else valid_ks[0]
    return labels_by_k[best_k], best_k, inertias, sils

def plot_all_clusters_one_figure(
    X_img_z, labels, global_img_indices,
    title_prefix="", examples_per_cluster=8, onset_ms=None, bin_ms=1
):
    """Overlay every cluster's mean time course on one axes, plus a thumbnail row per cluster.

    Parameters
    ----------
    X_img_z : numpy.ndarray, shape (n_images, T)
        Per-image time courses; row i belongs to cluster ``labels[i]``.
    labels : array-like, shape (n_images,)
        Cluster label of each row of ``X_img_z``.
    global_img_indices : sequence of int
        Global image index of each row, used to locate thumbnails via
        ``path_for_global_image``.
    title_prefix : str
        Prefix of the top-panel title (e.g. the ROI name).
    examples_per_cluster : int
        Thumbnail columns per cluster row (at least 1).
    onset_ms : float or None
        If given, a dashed vertical line is drawn at this x position. It is on
        the plotted axis, where x = 0 is the first column of ``X_img_z``.
    bin_ms : float
        Width of one time bin; x = column index * bin_ms.

    Each thumbnail is framed in the color of its cluster's mean trace. Example
    images come from a ``RANDOM_STATE``-seeded permutation of each cluster's
    members, so the selection is reproducible. Images that cannot be loaded
    are shown as a "missing" placeholder.
    """
    labels = np.asarray(labels)
    clusters = sorted(np.unique(labels))
    t = np.arange(X_img_z.shape[1]) * bin_ms
    members = {c: np.flatnonzero(labels == c) for c in clusters}

    default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    color_for = {c: default_colors[i % len(default_colors)] for i, c in enumerate(clusters)}

    nrows_img = len(clusters)
    ncols_img = max(1, int(examples_per_cluster))

    fig_height = 3.2 + 0.9 * nrows_img
    fig = plt.figure(figsize=(14, fig_height))
    gs = fig.add_gridspec(nrows=1 + nrows_img, ncols=ncols_img,
                          height_ratios=[2.2] + [1.0] * nrows_img,
                          hspace=0.45, wspace=0.06)

    # Top axes: one mean trace per cluster.
    ax_top = fig.add_subplot(gs[0, :])
    for c in clusters:
        idxs = members[c]
        ax_top.plot(t, X_img_z[idxs].mean(axis=0), lw=2, color=color_for[c],
                    label=f"Cluster {c} (n={len(idxs)})")
    if onset_ms is not None:
        ax_top.axvline(onset_ms, ls="--", lw=1, color="k", alpha=0.6)
    ax_top.set_title(f"{title_prefix} — all clusters (all images)")
    ax_top.set_xlabel("Time (ms)")
    ax_top.set_ylabel("Z-scored response")
    ax_top.legend(frameon=False, ncol=min(len(clusters), 5), fontsize=9)

    # Bottom thumbnail rows, one per cluster.
    for row_i, c in enumerate(clusters, start=1):
        # A fresh, identically seeded generator per cluster means each row's
        # sample does not depend on how many clusters precede it.
        rng = np.random.default_rng(RANDOM_STATE)
        take = rng.permutation(members[c])[:ncols_img]
        for col_j in range(ncols_img):
            ax_im = fig.add_subplot(gs[row_i, col_j])
            if col_j >= len(take):
                ax_im.axis("off")  # unused slot: leave blank, no frame
                continue
            gi = int(global_img_indices[take[col_j]])
            try:
                # Context manager closes the file handle Image.open keeps open.
                with Image.open(path_for_global_image(gi)) as img:
                    ax_im.imshow(np.asarray(img.convert("RGB")))
            except Exception:
                # Deliberate best-effort: a missing/unreadable file or an
                # out-of-range index becomes a placeholder, not a crash.
                ax_im.text(0.5, 0.5, f"missing\n{gi}", ha="center", va="center",
                           fontsize=8, transform=ax_im.transAxes)
            # Hide ticks but keep the spines. axis("off") would also hide the
            # spines, and with them the colored frame.
            ax_im.set_xticks([])
            ax_im.set_yticks([])
            for spine in ax_im.spines.values():
                spine.set_edgecolor(color_for[c])
                spine.set_linewidth(2)

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; this is called once per ROI in a loop

In [None]:
# Prepare ROI list: fill missing ROI labels and tag each unit with its ROI
# suffix ('B' / 'O' / 'F', or 'Unknown' -- see roi_suffix).
dat_valid = dat.copy()
dat_valid['roi'] = dat_valid['roi'].fillna('Unknown')
dat_valid['sel'] = dat_valid['roi'].map(roi_suffix)

if ROI_SUFFIX_FILTER in {'F', 'O', 'B'}:
    rois = sorted(dat_valid.loc[dat_valid['sel'] == ROI_SUFFIX_FILTER, 'roi'].unique())
else:
    rois = sorted(dat_valid['roi'].unique())
print(f"Found {len(rois)} ROIs")

# The plotted time axis is 0 at the first analysed bin (t_start), not at raw
# PSTH time 0. Express the onset marker relative to that bin so the dashed
# line lands on ONSET_MS for any window start.
onset_on_plot_ms = ONSET_MS - t_start * BIN_MS

for roi_name in rois:
    dat_roi = dat_valid.loc[dat_valid['roi'] == roi_name]
    try:
        X_img, global_img_indices, n_units = compute_image_time_traces_for_roi_all(
            dat_roi, t_start=t_start, t_end=t_end, g_start=GLOBAL_START, g_end=GLOBAL_END
        )
    except Exception as e:
        print(f"Skipping ROI {roi_name}: {e}")
        continue

    # z-score each image's time course over time (axis=1). Constant traces give
    # NaN/inf, which are zeroed so KMeans receives finite input.
    X_img_z = zscore(X_img, axis=1)
    X_img_z = np.nan_to_num(X_img_z, nan=0.0, posinf=0.0, neginf=0.0)

    labels, best_k, inertias, sils = choose_k_and_cluster(
        X_img_z, K_RANGE, random_state=RANDOM_STATE, n_init=N_INIT
    )
    print(f"ROI {roi_name}: images={X_img.shape[0]}, units={n_units}, best_k={best_k}")

    plot_all_clusters_one_figure(
        X_img_z, labels, global_img_indices,
        title_prefix=f"{roi_name}", examples_per_cluster=EXAMPLES_PER_CLUSTER,
        onset_ms=onset_on_plot_ms, bin_ms=BIN_MS
    )