In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle

import manifold_dynamics.utils_standard as sut

In [None]:
# Paths and dataset selection for this notebook.
DATA_DIR = '../../datasets/NNN/'
IMG_DIR = os.path.join(DATA_DIR, 'NSD1000_LOC')
SAVE_DIR = './../../../buckets/manifold-dynamics/raw-data'
# exist_ok=True is idempotent on re-run and avoids the exists()/makedirs() race.
os.makedirs(SAVE_DIR, exist_ok=True)

CAT = 'face'
# NOTE: pd.read_pickle can execute arbitrary code -- only load trusted files.
dat = pd.read_pickle(os.path.join(DATA_DIR, f'{CAT}_roi_data.pkl'))
ROI_LIST = list(dat['roi'].unique())
print(f'Unique {CAT} ROIs: {ROI_LIST}')

In [None]:
# Select significant units from one ROI, compute baseline-subtracted mean
# responses, and sample units/images into a long-format table for plotting.
roi = 'MF1_8_F'
n_units = 50           # max number of distinct units to sample
img_ids = [4, 29, 70]  # indices into the image axis *after* IMG_OFFSET
IMG_OFFSET = 1000      # first image column of X used in this analysis

# baseline + response windows (slices along X's second axis)
base = slice(0, 50)
win = slice(100, 270)

sig = dat[dat['p_value'] < 0.05]
df = sig[sig['roi'] == roi]
if df.empty:
    raise ValueError(f'No significant units (p < 0.05) found for ROI {roi!r}')
# Stack per-unit PSTHs; presumably (units, time, images) -- confirm.
X = np.stack(df['img_psth'].to_numpy())

# time-averaged and baseline subtracted
Xavg = (X[:, win, IMG_OFFSET:].mean(axis=1)
        - X[:, base, IMG_OFFSET:].mean(axis=1))
print(Xavg.shape)

# Sample *distinct* units: Generator.choice defaults to replace=True, which
# could repeat a unit and double-count it in the strip plot and SE bars.
rng = np.random.default_rng(0)
unit_idx = rng.choice(len(Xavg), size=min(n_units, len(Xavg)), replace=False)

Xsub = Xavg[np.ix_(unit_idx, img_ids)]
# Long format, one row per (unit, image). Row-major flattening of Xsub pairs
# repeat(unit_idx) with tile(img_ids).
hdf = pd.DataFrame({
    'unit_id': np.repeat(unit_idx, len(img_ids)),
    'image_id': np.tile(img_ids, len(unit_idx)),
    'response': Xsub.reshape(-1),
})

In [None]:
# Mean (+/- SE) baseline-subtracted response of the sampled units to each
# selected image, with individual units overlaid as faint dots.
SAVE_FIG = False  # set True to write the figure to SAVE_DIR

fig, ax = plt.subplots(figsize=(3, 3))
# One color per image, indexed by position in img_ids (the image-panel cell
# indexes `customp` the same way). Bars are therefore drawn in img_ids order
# rather than seaborn's default sorted order, so colors and tick labels match.
customp = sns.color_palette('CMRmap', len(img_ids))
sns.barplot(hdf, x='image_id', y='response', order=img_ids,
            hue='image_id', hue_order=img_ids, palette=customp, legend=False,
            errorbar='se', err_kws={'color': 'black', 'lw': 1}, ax=ax)
sns.stripplot(hdf, x='image_id', y='response', order=img_ids, color='black',
              alpha=0.1, marker='.', ax=ax)

ax.set_ylim(bottom=-5, top=15)
sns.despine(fig=fig, trim=True, offset=5)

# Categorical positions 0..len(img_ids)-1 follow `order`, i.e. img_ids order.
ax.set_xticks(range(len(img_ids)))
ax.set_xticklabels([f'{i:02d}' for i in img_ids])
ax.set_xlabel('Image ID')
# NOTE(review): label says 100--220 msec, but the response window `win` is
# slice(100, 270) along the time axis -- confirm the time base and align them.
ax.set_ylabel('Avg. resp.\n(100--220 msec)')

out_path = os.path.join(SAVE_DIR, f'{roi}_img_avg_bars.png')
if SAVE_FIG:
    fig.savefig(out_path, dpi=300, bbox_inches='tight', transparent=True)

In [None]:


# Show each selected image in its own panel, framed in the color assigned to
# it by `customp` (indexed by position in img_ids).
n_imgs = len(img_ids)
fig, axes = plt.subplots(1, n_imgs, figsize=(2 * n_imgs, 2), squeeze=False)

for i, (ax, img_id) in enumerate(zip(axes[0], img_ids)):
    col = customp[i]

    # img_ids index the image axis after its first 1000 columns; add 1000
    # back to get the ID passed to sut.load_image (assumed same ID space --
    # confirm against the response-matrix slicing).
    sut.load_image(img_id + 1000, ax=ax)

    # Colored border drawn in axes coordinates so it hugs the whole panel.
    frame = Rectangle((0, 0), 1, 1, transform=ax.transAxes, fill=False,
                      edgecolor=col, linewidth=4, clip_on=False)
    ax.add_patch(frame)

fig.tight_layout()
out_path = os.path.join(SAVE_DIR, f'{roi}_img_avg_resp.png')
# fig.savefig(out_path, dpi=300, bbox_inches='tight', transparent=True)
plt.show()