In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

from tqdm import tqdm
from scipy.spatial.distance import pdist, cdist, squareform

RAND = 0     # seed handed to np.random.default_rng wherever images are shuffled
ONSET = 50   # offset added to the relative window bounds to get PSTH time indices
             # (presumably the stimulus-onset sample — TODO confirm units)

# Analysis windows are given relative to ONSET as (start, stop) and shifted
# into absolute indices along the PSTH time axis.
_RESP_REL, _BASE_REL = (50, 220), (-50, 0)
RESP = slice(*(ONSET + t for t in _RESP_REL))   # response window
BASE = slice(*(ONSET + t for t in _BASE_REL))   # baseline window

In [None]:
DATA_DIR = '../../datasets/NNN/'
# NOTE: unpickling can execute arbitrary code — only load files from a trusted source.
dat = pd.read_pickle(os.path.join(DATA_DIR, 'face_roi_data.pkl'))
# Outer double quotes: reusing the same quote character inside an f-string
# expression is a SyntaxError before Python 3.12.
print(f"Unique face ROIs: {list(dat['roi'].unique())}")

In [None]:
def geo_rdm(dat, roi, mode='top', step=5, k_max=200, metric='correlation'):
    """Compute time-by-time second-order RDMs for growing image sets of one ROI.

    Rows of ``dat`` with ``p_value < 0.05`` and ``roi == roi`` are stacked into
    an array ``X`` (laid out as (units, time, images) per the inline comment
    below). Images are ranked by their mean response in the ``RESP`` window
    minus the mean in the ``BASE`` window (NaN-aware average over units and
    time), or shuffled with seed ``RAND`` when ``mode`` is anything but 'top'.

    For every image-set size ``k`` the first ``k`` images of that ordering are
    used: at each time point the pairwise correlation distances between image
    response patterns form a representational dissimilarity vector (RDV), and
    the RDVs of all time points are then compared with ``metric``.

    Parameters
    ----------
    dat : pandas.DataFrame
        Must provide the columns 'p_value', 'roi' and 'img_psth'.
    roi : str
        ROI label to select.
    mode : str
        'top' ranks images by descending score; any other value uses a
        seeded random permutation.
    step : int
        Increment between consecutive image-set sizes (also the first size).
    k_max : int
        Largest image-set size considered (capped at the number of images).
    metric : str
        ``pdist`` metric used to compare RDVs across time points. The
        first-order (image-by-image) distance is always 'correlation'.

    Returns
    -------
    sizes : list of int
        Image-set sizes that were evaluated.
    rdvs : list of numpy.ndarray
        One square (time, time) matrix per entry of ``sizes``.

    Raises
    ------
    ValueError
        If no significant unit belongs to ``roi``.

    Notes
    -----
    NOTE(review): for k < 3 each RDV has fewer than two entries, so a
    correlation distance between RDVs is not well defined (expect NaNs or an
    error) — confirm this is intended when ``step`` < 3.
    """
    sig = dat[dat['p_value'] < 0.05]
    df = sig[sig['roi'] == roi]
    if len(df) == 0:
        # The original message referenced an undefined name (ROI), which
        # turned this ValueError into a NameError.
        raise ValueError(f"No data for ROI {roi}")
    X = np.stack(df['img_psth'].to_numpy())          # (units, time, images)

    scores = np.nanmean(X[:, RESP, :], axis=(0, 1)) - np.nanmean(X[:, BASE, :], axis=(0, 1))
    if mode == 'top':
        order = np.argsort(scores)[::-1]
    else:
        # The generator is only needed for the shuffled control; a fixed seed
        # keeps the permutation reproducible across calls.
        order = np.random.default_rng(RAND).permutation(scores.size)

    # Image-set sizes for which an RDM is computed: step, 2*step, ..., up to
    # min(k_max, number of images).
    sizes = list(range(step, min(k_max, X.shape[2]) + 1, step))
    # Alternative "ramping" schedule (single-image steps first, then `step`):
    # sizes = [k for k in range(1, 2*step)] + [k for k in range(2*step, min(k_max, X.shape[2])+1, step)]

    n_time = X.shape[1]
    rdvs = []
    for k in tqdm(sizes):
        Ximg = X[:, :, order[:k]]                    # (units, time, k images)
        # First order: one RDV (condensed image-by-image distances) per time point.
        Xrdv = np.array([pdist(Ximg[:, t, :].T, metric='correlation') for t in range(n_time)])
        # Second order: dissimilarity between the RDVs of every pair of time points.
        rdvs.append(squareform(pdist(Xrdv, metric=metric)))   # (time, time)
    return sizes, rdvs

In [None]:
# RDMs for ROI MF1_8_F, growing the top-ranked image set one image at a time.
_roi, _mode, _step = 'MF1_8_F', 'top', 1

sizes, rdms_top = geo_rdm(dat, roi=_roi, mode=_mode, step=_step)

In [None]:
# Show one "page" of RDMs from rdms_top, one heatmap per panel.
fig, axes = plt.subplots(10, 5, figsize=(20, 30), sharex=True, sharey=True)
axes = axes.ravel()

multiplier = 0   # page index: panel i shows rdms_top[i + axes.size * multiplier]
for i, ax in enumerate(axes):
    rdm_idx = i + axes.size * multiplier
    # Guard against pages that run past the end of the list instead of
    # raising IndexError; such panels are simply left blank.
    if rdm_idx < len(rdms_top):
        sns.heatmap(rdms_top[rdm_idx], cbar=False, square=True, ax=ax)
        ax.set_title(f'{rdm_idx}')   # title is the list index, not the image-set size
    ax.set_axis_off()
plt.show()

In [None]:
# Same analysis for ROI Unknown_19_F (top-ranked images, one image per step).
_roi, _mode, _step = 'Unknown_19_F', 'top', 1

sizes2, rdms_top2 = geo_rdm(dat, roi=_roi, mode=_mode, step=_step)

In [None]:
# Show one "page" of RDMs from rdms_top2, one heatmap per panel.
fig, axes = plt.subplots(10, 10, figsize=(10, 10), sharex=True, sharey=True)
axes = axes.ravel()

multiplier = 0   # page index: panel i shows rdms_top2[i + axes.size * multiplier]
for i, ax in enumerate(axes):
    rdm_idx = i + axes.size * multiplier
    # Guard against pages that run past the end of the list instead of
    # raising IndexError; such panels are simply left blank.
    if rdm_idx < len(rdms_top2):
        sns.heatmap(rdms_top2[rdm_idx], cbar=False, square=True, ax=ax)
        ax.set_title(f'{rdm_idx}')   # title is the list index, not the image-set size
    ax.set_axis_off()
plt.show()

In [None]:
# Display the 100 highest-ranked images for one ROI, using the same ranking
# rule as geo_rdm (mean RESP-window response minus mean BASE-window response).
roi = 'Unknown_19_F'
mode = 'top'

rng = np.random.default_rng(RAND)
sig = dat[dat['p_value'] < 0.05]
df = sig[sig['roi'] == roi]
if len(df) == 0:
    # The original message referenced an undefined name (ROI), which turned
    # this ValueError into a NameError.
    raise ValueError(f"No data for ROI {roi}")
X = np.stack(df['img_psth'].to_numpy())          # (units, time, images)

scores = np.nanmean(X[:, RESP, :], axis=(0, 1)) - np.nanmean(X[:, BASE, :], axis=(0, 1))
order = np.argsort(scores)[::-1] if mode == 'top' else rng.permutation(scores.size)


IMG_DIR = '../../datasets/NNN/NSD1000_LOC/'

fig, axes = plt.subplots(10, 10, figsize=(10, 10))
axes = axes.ravel()
for ax in axes:
    # Hide every frame up front so panels stay blank if fewer images exist.
    ax.set_axis_off()
for idx, imgid in enumerate(order[:axes.size]):
    # NOTE(review): `order` holds 0-based positions along the image axis of X,
    # while the file names look like 1-based ids (0001..1000, MFOB001..) —
    # confirm whether an offset of 1 is needed here.
    if imgid > 1000:
        subpath = f'MFOB{(imgid-1000):03d}.bmp'
    else:
        subpath = f'{imgid:04d}.bmp'
    img_pth = os.path.join(IMG_DIR, subpath)
    axes[idx].imshow(mpimg.imread(img_pth))
plt.show()

In [None]:
# Zero-padded (4 digits) index of the top-ranked image.
print(format(order[0], '04d'))