In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import animation
from IPython.display import HTML
import seaborn as sns
from PIL import Image

from tqdm import tqdm
from scipy.spatial.distance import pdist, cdist, squareform
from scipy.stats import pearsonr, spearmanr, entropy, rankdata
from sklearn.manifold import MDS
from scipy.ndimage import gaussian_filter1d

RAND = 0    # default seed for shuffled image orders
ONSET = 50  # stimulus-onset index on the PSTH time axis

# Analysis windows are written relative to onset, then shifted to absolute
# time indices:
#   RESP: onset+50 .. onset+220  (response window)
#   BASE: onset-50 .. onset      (baseline window)
RESP, BASE = (slice(ONSET + lo, ONSET + hi) for lo, hi in ((50, 220), (-50, 0)))

In [None]:
# Paths to the recorded ROI data and the stimulus images.
DATA_DIR = '../../datasets/NNN/'
IMG_DIR = os.path.join(DATA_DIR, 'NSD1000_LOC')

# ROI family to analyse; selects both the data file and the per-ROI scale file.
_r = 'body'
dat = pd.read_pickle(os.path.join(DATA_DIR, f'{_r}_roi_data.pkl'))
# Inner quotes differ from the outer ones so this also parses on Python < 3.12,
# and the label reflects the ROI family actually loaded.
print(f"Unique {_r} ROIs: {list(dat['roi'].unique())}")

# NOTE: unpickling can execute arbitrary code -- only load trusted, locally produced files.
with open(os.path.join(DATA_DIR, f'{_r}_mins.pkl'), 'rb') as f:
    mins = pickle.load(f)

# Per-ROI image-set scale: the first element of each entry in `mins`.
SCALES = {k: v[0] for k, v in mins.items()}

In [None]:
def geo_rdm(dat, roi, mode='top', step=5, k_max=200, metric='correlation', random_state=RAND,
            tstart=100, tend=400):
    """Time-by-time RDM similarity for a growing set of images.

    For every image-set size ``k`` in ``range(step, min(k_max, n_images) + 1, step)``:

    1. take the first ``k`` images of the image ordering;
    2. at every timepoint, compute an RDV (correlation distance between the
       images' population vectors);
    3. rank-transform each RDV (Spearman);
    4. compare the ranked RDVs across time with ``metric``.

    Parameters
    ----------
    dat : DataFrame
        Must have 'p_value', 'roi' and 'img_psth' columns. Only rows with
        p_value < 0.05 are used.
    roi : str
        ROI label to select.
    mode : str
        'top' orders images by mean response-minus-baseline score, descending.
        Any other value uses a random permutation seeded by ``random_state``.
    step, k_max : int
        Spacing and upper bound of the image-set sizes.
    metric : str
        ``pdist`` metric used to compare the ranked RDVs across time.
    random_state : int
        Seed for the random permutation (ignored when ``mode == 'top'``).
    tstart, tend : int
        Time-index window. The defaults reproduce the previously hard-coded
        100:400 window.

    Returns
    -------
    sizes : list of int
        Image-set sizes that were evaluated.
    rdvs : list of ndarray
        One (time, time) matrix per size.

    Raises
    ------
    ValueError
        If there are no significant units for ``roi``.
    """
    rng = np.random.default_rng(random_state)
    sig = dat[dat['p_value'] < 0.05]
    df = sig[sig['roi'] == roi]
    if len(df) == 0:
        # Report the ROI that was requested (previously referenced the global `ROI`).
        raise ValueError(f"No data for ROI {roi}")
    X = np.stack(df['img_psth'].to_numpy())          # (units, time, images)

    scores = np.nanmean(X[:, RESP, :], axis=(0, 1)) - np.nanmean(X[:, BASE, :], axis=(0, 1))
    order = np.argsort(scores)[::-1] if mode == 'top' else rng.permutation(scores.size)

    # Evenly spaced image-set sizes. An alternative ramping schedule that was tried:
    # list(range(1, 2*step)) + list(range(2*step, min(k_max, n_images) + 1, step))
    sizes = list(range(step, min(k_max, X.shape[2]) + 1, step))

    X_win = X[:, tstart:tend, :]                      # time window is the same for every k
    rdvs = []
    for k in tqdm(sizes):
        Ximg = X_win[:, :, order[:k]]                 # (units, time, images)
        Xrdv = np.array([pdist(Ximg[:, t, :].T, metric='correlation')
                         for t in range(Ximg.shape[1])])
        # Spearman: rank-transform each timepoint's RDV before comparing across time.
        Xrank = np.apply_along_axis(rankdata, 1, Xrdv)
        rdvs.append(squareform(pdist(Xrank, metric=metric)))   # (time, time)
    return sizes, rdvs

def static_rdm(dat, roi, mode='top', scale=30, tstart=100, tend=400, metric='correlation', random_state=RAND):
    """Time-by-time RDM for a single image subset.

    Images are ordered either by mean response-minus-baseline score
    (``mode == 'top'``, descending) or by a random permutation seeded by
    ``random_state``. The first ``scale`` of them are kept.

    For every timepoint in ``tstart:tend`` an RDV is computed (correlation
    distance between image population vectors). The RDVs are then
    rank-transformed (Spearman) and compared across time with ``metric``.

    Returns
    -------
    R : ndarray
        (time, time) distance matrix between ranked RDVs.
    Xrdv : ndarray
        (time, n_pairs) un-ranked RDVs.

    Raises
    ------
    ValueError
        If there are no significant (p < 0.05) units for ``roi``.
    """
    rng = np.random.default_rng(random_state)
    units = dat[(dat['p_value'] < 0.05) & (dat['roi'] == roi)]
    if units.empty:
        raise ValueError(f"No data for ROI {roi}")
    psth = np.stack(units['img_psth'].to_numpy())     # (units, time, images)

    evoked = np.nanmean(psth[:, RESP, :], axis=(0, 1))
    baseline = np.nanmean(psth[:, BASE, :], axis=(0, 1))
    scores = evoked - baseline
    if mode == 'top':
        order = np.argsort(scores)[::-1]
    else:
        order = rng.permutation(scores.size)

    window = psth[:, tstart:tend, order[:scale]]       # (units, time, images)

    # One RDV per timepoint: iterate over the time axis directly.
    per_time = np.array([
        pdist(frame.T, metric='correlation')
        for frame in np.moveaxis(window, 1, 0)
    ])                                                  # (time, n_pairs)

    ranked = np.apply_along_axis(rankdata, 1, per_time)
    return squareform(pdist(ranked, metric=metric)), per_time

def local_scale(dat, roi, mode='top', scale=30, tstart=100, tend=400, metric='correlation', random_state=RAND):
    """Return the image ordering used by the RDM functions for ``roi``.

    ``mode == 'top'`` returns image indices sorted by mean response-minus-baseline
    score, descending. Any other mode returns a random permutation seeded by
    ``random_state``.

    ``scale``, ``tstart``, ``tend`` and ``metric`` are accepted so the call
    signature matches ``static_rdm``, but they are not used here.

    Raises
    ------
    ValueError
        If there are no significant (p < 0.05) units for ``roi``.
    """
    rng = np.random.default_rng(random_state)
    keep = (dat['p_value'] < 0.05) & (dat['roi'] == roi)
    units = dat[keep]
    if units.empty:
        raise ValueError(f"No data for ROI {roi}")
    psth = np.stack(units['img_psth'].to_numpy())     # (units, time, images)

    response = np.nanmean(psth[:, RESP, :], axis=(0, 1))
    baseline = np.nanmean(psth[:, BASE, :], axis=(0, 1))
    scores = response - baseline

    if mode != 'top':
        return rng.permutation(scores.size)
    return np.argsort(scores)[::-1]

def rdv(X):
    """Strict upper triangle (diagonal excluded) of square matrix ``X``, in row-major order."""
    rows, cols = np.triu_indices_from(X, k=1)
    return X[rows, cols]

def l2(X):
    """Euclidean (Frobenius for matrices) norm: square root of the sum of squared entries."""
    return np.sqrt(np.sum(X * X))

def ED1(R):
    """Participation ratio of the eigenvalues of ``-R**2 / 2``, without double-centering.

    Negative eigenvalues are clipped to zero before computing
    ``(sum lam)**2 / sum(lam**2)``. See ``ED2`` for the double-centred
    (classical-MDS) variant.
    """
    eigvals = np.linalg.eigvalsh(-0.5 * R**2)
    positive = np.maximum(eigvals, 0)
    return positive.sum()**2 / np.sum(positive**2)

def ED2(R):
    """Effective dimensionality of a distance matrix via classical MDS.

    The squared distances are double-centred, ``B = -1/2 * J @ R**2 @ J`` with
    ``J = I - 1/n``. The participation ratio ``(sum lam)**2 / sum(lam**2)`` is
    then taken over the non-negative eigenvalues of ``B``.

    Parameters
    ----------
    R : array_like, shape (n, n)
        Symmetric distance matrix.

    Returns
    -------
    float
        The participation ratio, or ``nan`` when ``B`` has no positive
        eigenvalues (e.g. an all-zero ``R``). That case used to produce nan via
        0/0 with a RuntimeWarning.

    Notes
    -----
    ``np.linalg.eigvalsh`` may raise ``np.linalg.LinAlgError`` when it fails to
    converge (callers in this notebook catch it). Presumably this happens when
    ``R`` contains NaN -- TODO confirm.
    """
    R = np.asarray(R, dtype=float)
    n = R.shape[0]
    J = np.eye(n) - np.ones((n, n)) / n            # centering matrix
    B = -0.5 * J @ (R**2) @ J                      # double-centred Gram matrix
    lam = np.clip(np.linalg.eigvalsh(B), 0, None)  # drop negative (non-Euclidean) parts
    denom = (lam**2).sum()
    if denom == 0:
        return np.nan
    return (lam.sum()**2) / denom


def entropy(V):
    """Normalise ``|V|`` so that its entries sum to 1.

    NOTE(review): despite its name this does NOT compute an entropy; it returns
    the L1-normalised magnitudes. Defining it also shadows
    ``scipy.stats.entropy`` imported at the top of the notebook. Consider
    renaming once all callers are known.
    """
    magnitudes = np.abs(V)
    total = magnitudes.sum()
    return magnitudes / total

def idx_to_fname(idx):
    """Map a 0-based image index to its stimulus filename.

    * 0-999     -> 'NNNN.bmp', numbered from 1 with zero padding to 4 digits.
    * 1000-1071 -> 'MFOBNNN.bmp', numbered from 1 (idx - 999) with zero padding to 3 digits.
    """
    return f'{idx + 1:04d}.bmp' if idx < 1000 else f'MFOB{idx - 999:03d}.bmp'

In [None]:
# Single-ROI analysis: the RDM of the ROI's top-`scale` images vs. the full image
# set vs. a null distribution of random same-size image subsets.
ROI = 'Unknown_19_F' # Unknown_19_F, MF1_7_F, MF1_8_F, MF1_9_F, AF3_18_F, MB1_3_B
MODE = 'top'
N_IMAGES = 1072    # full image set (idx_to_fname maps indices 0-1071)
N_SHUFFLES = 500   # number of random same-size subsets in the null distribution

# SCALES only holds ROIs of the family loaded above; fail early with a clear message.
if ROI not in SCALES:
    raise KeyError(f'ROI {ROI!r} not in SCALES; available: {sorted(SCALES)}')
scale = SCALES[ROI]

# 100 t0, 200 window for medial
# 150 t0, 300 window for anterior
t0 = 100
window = 200

Rt, _ = static_rdm(dat, ROI, mode=MODE, scale=scale,
                   tstart=t0, tend=t0 + window)

Ra, _ = static_rdm(dat, ROI, mode=MODE, scale=N_IMAGES,
                   tstart=t0, tend=t0 + window)

# Seed i gives a reproducible random subset per null sample.
Rss = [
    static_rdm(dat, ROI, mode='shuff', scale=scale,
               tstart=t0, tend=t0 + window, random_state=i)[0]
    for i in tqdm(range(N_SHUFFLES))
]

In [None]:
# Show the highest-scoring images (at most N_SHOW) for every ROI in SCALES.
# The loop variables are named so they do not overwrite the globals `_r` and
# `scale`, which later cells use for labels.
N_SHOW = 20
NCOLS = 5

for roi_name, roi_scale in SCALES.items():
    # local_scale ignores the time window, so none is passed.
    roi_order = local_scale(dat, roi_name, mode=MODE, scale=roi_scale)

    n = min(int(roi_scale), N_SHOW)
    nrows = max(1, int(np.ceil(n / NCOLS)))   # plt.subplots needs at least one row

    fig, axes = plt.subplots(nrows, NCOLS, figsize=(NCOLS + 1, nrows))
    axes = np.atleast_1d(axes).ravel()

    for ax, img_idx in zip(axes, roi_order[:n]):
        # Close the file handle once imshow has read the pixels.
        with Image.open(os.path.join(IMG_DIR, idx_to_fname(int(img_idx)))) as img:
            ax.imshow(img, cmap='gray')
        ax.set_title(f'image {img_idx}', fontsize=8)
        ax.axis('off')

    # hide unused axes
    for ax in axes[n:]:
        ax.axis('off')

    fig.suptitle(f'{roi_name} top images')
    plt.tight_layout()
    plt.show()
    plt.close(fig)   # avoid accumulating open figures across ROIs

In [None]:
# Side-by-side time-by-time RDMs: the ROI's top images vs. the full image set.
# The title reads the scale from SCALES[ROI] rather than the global `scale`,
# which an earlier loop cell may have overwritten.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

panels = [(Rt, f'Top {SCALES[ROI]}'), (Ra, 'Global Span')]
for ax, (R, title) in zip(axes, panels):
    sns.heatmap(R, vmax=1,
                cbar=False, square=True,
                ax=ax)
    ax.set_axis_off()
    ax.set_title(title)

In [None]:
# Effective dimensionality: top-`scale` images vs. random same-size subsets vs. all images.
# The subset size is read from SCALES[ROI] because an earlier loop cell may have
# overwritten the global `scale`.
top_scale = SCALES[ROI]

Vt = rdv(Rt)
Va = rdv(Ra)

# collect effective dimensionalities
Rs_EDs = np.array([ED2(Rs) for Rs in Rss])

# The shuffle bar uses subsets of the same size as the top set; it was
# previously hard-coded as 'S30'.
conditions = [f'T{top_scale}', f'S{top_scale}', 'All']

# build long-form dataframe so seaborn can handle stats + errorbars
rows = [{'cond': conditions[0], 'ed': ED2(Rt), 'kind': 'test'}]
rows += [{'cond': conditions[1], 'ed': v, 'kind': 'shuffle'} for v in Rs_EDs]
rows.append({'cond': conditions[2], 'ed': ED2(Ra), 'kind': 'all'})

df = pd.DataFrame(rows)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))

# bars (median + sd errorbars)
sns.barplot(
    data=df,
    x='cond',
    y='ed',
    order=conditions,
    estimator=np.median,
    errorbar='sd', errcolor='black',
    palette=['red', 'blue', 'gray'],
    ax=ax,
    zorder=1
)

# dots for the shuffle distribution (middle bar only)
sns.stripplot(
    data=df[df['cond'] == conditions[1]],
    x='cond',
    y='ed',
    order=conditions,
    color='black',
    size=2,
    jitter=0.2,
    alpha=0.1,
    ax=ax,
    zorder=2
)

ax.set_title(f'{ROI} dynamics ({t0}-{t0+window} msec)')
ax.set_xlabel('image manifold scale')
ax.set_ylabel('effective dimensionality')

sns.despine(fig, trim=True, offset=5)
plt.tight_layout()

In [None]:
# Repeat the effective-dimensionality comparison for every ROI in `mins`.
# NOTE: ROIs previously observed not to converge: 'MB1_8_B', 'AB3_18_B'.
# Loop-invariant settings, set once rather than on every iteration:
MODE = 'top'
# 100 t0, 200 window for medial
# 150 t0, 300 window for anterior
t0 = 100
window = 200
N_IMAGES = 1072    # full image set (idx_to_fname maps indices 0-1071)
N_SHUFFLES = 500   # number of random same-size subsets in the null distribution


def _ed2_or_nan(R, roi, label):
    """Return ED2(R); if the eigen-solver fails to converge, print a note and return nan."""
    try:
        return ED2(R)
    except np.linalg.LinAlgError:
        print(f'eigenvalues do not converge for ROI: {roi} ({label})')
        return np.nan


for ROI, SC in mins.items():
    scale = SC[0]

    Rt, _ = static_rdm(dat, ROI, mode=MODE, scale=scale,
                       tstart=t0, tend=t0 + window)

    Ra, _ = static_rdm(dat, ROI, mode=MODE, scale=N_IMAGES,
                       tstart=t0, tend=t0 + window)

    Rss = [
        static_rdm(dat, ROI, mode='shuff', scale=scale,
                   tstart=t0, tend=t0 + window, random_state=i)[0]
        for i in tqdm(range(N_SHUFFLES))
    ]

    Vt = rdv(Rt)
    Va = rdv(Ra)

    # Any non-convergent shuffle invalidates the whole null distribution, as before.
    # Only LinAlgError is caught now; the bare `except:` also swallowed
    # KeyboardInterrupt and unrelated bugs.
    try:
        Rs_EDs = np.array([ED2(Rs) for Rs in Rss])
    except np.linalg.LinAlgError:
        Rs_EDs = np.full(len(Rss), np.nan)
        print(f'eigenvalues do not converge for ROI: {ROI}')

    # The shuffle bar uses subsets of the same size as the top set; it was
    # previously hard-coded as 'S30'.
    conditions = [f'T{scale}', f'S{scale}', 'All']

    # build long-form dataframe so seaborn can handle stats + errorbars
    rows = [{'cond': conditions[0], 'ed': _ed2_or_nan(Rt, ROI, 'test'), 'kind': 'test'}]
    rows += [{'cond': conditions[1], 'ed': v, 'kind': 'shuffle'} for v in Rs_EDs]
    rows.append({'cond': conditions[2], 'ed': _ed2_or_nan(Ra, ROI, 'all'), 'kind': 'all'})

    df = pd.DataFrame(rows)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    # bars (median + sd errorbars)
    sns.barplot(
        data=df,
        x='cond',
        y='ed',
        order=conditions,
        estimator=np.median,
        errorbar='sd', errcolor='black',
        palette=['red', 'blue', 'gray'],
        ax=ax,
        zorder=1
    )

    # dots for the shuffle distribution (middle bar only)
    sns.stripplot(
        data=df[df['cond'] == conditions[1]],
        x='cond',
        y='ed',
        order=conditions,
        color='black',
        size=2,
        jitter=0.2,
        alpha=0.1,
        ax=ax,
        zorder=2
    )

    ax.set_title(f'{ROI} dynamics ({t0}-{t0+window} msec)')
    ax.set_xlabel('image manifold scale')
    ax.set_ylabel('effective dimensionality')

    sns.despine(fig, trim=True, offset=5)
    plt.tight_layout()
    plt.show()
    plt.close(fig)   # avoid accumulating open figures across ROIs