In [None]:
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
import scipy.stats as stt

In [None]:
# Path to the pickled per-unit table (despite its name, this is a file path).
# NOTE(review): unpickling can execute arbitrary code — only load trusted,
# locally produced files here.
DATA_DIR = '../../datasets/NNN/face_roi_data.pkl'
dat = pd.read_pickle(DATA_DIR)
# Use double quotes inside the f-string: reusing the enclosing quote character
# is a SyntaxError on Python < 3.12 (PEP 701).
print(f'Unique face ROIs: {list(dat["roi"].unique())}')

In [None]:
# Select the units of one ROI that pass the significance threshold, stack their
# PSTHs, and center each unit's response across all (time, image) samples.
ROI = 'MF1_9_F'
p_val = 0.05  # keep units whose p_value is strictly below this threshold
roi_mask = (dat['roi'] == ROI) & (dat['p_value'] < p_val)
roi_dat = dat[roi_mask].reset_index(drop=True)
# np.stack fails with an opaque "need at least one array to stack" on an empty
# selection (unknown ROI name or no significant units), so fail early instead.
if roi_dat.empty:
    raise ValueError(f'No units found for ROI {ROI!r} with p_value < {p_val}')
# NOTE(review): assumes each 'img_psth' entry is a (time points, images) array,
# giving X = (units, time points, images) as the print below states — confirm.
X = np.stack(roi_dat['img_psth'])
# Optional preprocessing steps (currently disabled):
# X = X[:, :, 1000:]
# X = stt.zscore(X, axis=1)
# X = np.nan_to_num(X)
print(f'ROI data size: {X.shape} (units, time points, images)')

U, T, I = X.shape
# C-order reshape of (U, T, I) -> (U, T*I) puts sample t*I + i in column
# t*I + i; transposing makes samples the rows and units the features.
X_flat = X.reshape(U, T * I).T  # (samples, features)
# Subtract each unit's mean over all samples.
X_centered = X_flat - X_flat.mean(axis=0, keepdims=True)
print(f'Centered data size: {X_centered.shape} (time points*images, units)')

# Rows are time-major (row = t*I + i), so this recovers the (time, image) grid.
X_3d = X_centered.reshape(T, I, U)
print(f'Centered 3d shape: {X_3d.shape} (time points, images, units)')

In [None]:
# Line colour for each image group that gets plotted; groups without an entry
# here (e.g. 'all images') are skipped by the plotting loop.
group_colors = {
    'all faces': 'red',
    'monkey faces': 'pink',
    'human faces': 'darkred',
    'all nonfaces': 'blue',
    'all objects': 'green',
    'monkey bodies': 'purple',
    'animal bodies': 'black',
}

# Indices (along the image axis) of each stimulus category.
# Faces span 1000-1024 (13 monkey + 12 human = 25) and non-faces span
# 1025-1071 (47), together covering the 72 images 1000-1071.
# 'all faces' previously stopped at 1023, which dropped human face 1024 from
# the face set and let it leak into 'all objects'.
# NOTE(review): image 1025 is a non-face that is neither an object nor a body
# under these definitions — confirm its category.
img_sets = {
    'all images': np.arange(1000, 1072),
    'all faces': np.arange(1000, 1025),
    'monkey faces': np.concatenate([np.arange(1000, 1006), np.arange(1009, 1016)]),
    'human faces': np.concatenate([np.arange(1006, 1009), np.arange(1016, 1025)]),
    'all nonfaces': np.arange(1025, 1072),
    'all objects': np.setdiff1d(
        np.arange(1000, 1072),
        np.concatenate([np.arange(1000, 1025),   # faces
                        np.arange(1025, 1031),   # 1025 + monkey/animal bodies
                        np.arange(1043, 1049),   # bodies
                        np.arange(1051, 1062)]), # animal bodies
    ),
    'monkey bodies': np.concatenate([np.arange(1026, 1031), np.arange(1043, 1049)]),
    'animal bodies': np.concatenate([np.arange(1026, 1031), np.arange(1043, 1049), np.arange(1051, 1062)]),
}

# Sanity checks on the category definitions.
assert np.array_equal(np.union1d(img_sets['monkey faces'], img_sets['human faces']),
                      img_sets['all faces'])
assert np.intersect1d(img_sets['all faces'], img_sets['all nonfaces']).size == 0
assert np.intersect1d(img_sets['all objects'], img_sets['all faces']).size == 0

# For every coloured image group, average the centered response over the
# group's images, then plot the across-unit mean time course with ±1 SE bands.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
for name, idxs in img_sets.items():
    if name not in group_colors:
        continue  # groups without a colour (e.g. 'all images') are not drawn
    img_resp = X_3d[:, idxs, :]                # (time, images, units)
    mean_over_imgs = img_resp.mean(axis=1)     # (time, units)
    df = pd.DataFrame(mean_over_imgs)
    # melt stacks the unit columns one after another, so the time index cycles
    # 0..T-1 once per unit.
    df_melt = df.melt(var_name="unit", value_name="response")
    df_melt["time"] = np.tile(np.arange(df.shape[0]), df.shape[1])
    # Each time point has one value per unit; errorbar="se" gives the SE across units.
    sns.lineplot(x="time", y="response", data=df_melt,
                 color=group_colors[name], label=name, errorbar="se", ax=ax)

# "±" was previously mis-encoded as "Â±" (UTF-8 bytes decoded as Latin-1).
ax.set_title("Mean response ± SE per group")
ax.set_xlabel("Time"); ax.set_ylabel("Response (a.u.)")
plt.tight_layout(); plt.show()