In [None]:
import os, sys, pickle
if '..' not in sys.path:
    sys.path.insert(0, '..')

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Configuration: seed, analysis windows, data/output locations.
# Windows are given as offsets relative to stimulus onset and converted to
# slices along the PSTH time axis.
# NOTE(review): offsets are presumably in ms with 1-ms bins (the plot labels the
# time-index axis "msec") -- confirm against the dataset.
RAND = 0              # seed for the random generator used when sampling units
ONSET = 50            # index of stimulus onset along the PSTH time axis
RESP_MS = (50, 220)   # response window, offsets relative to ONSET
BASE_MS = (-50, 0)    # baseline window, offsets relative to ONSET
RESP = slice(ONSET + RESP_MS[0], ONSET + RESP_MS[1])
BASE = slice(ONSET + BASE_MS[0], ONSET + BASE_MS[1])

CAT = 'face'
DATA_DIR = './../../datasets/NNN/'
# pd.read_pickle can execute arbitrary code on load -- only point this at trusted files.
dat = pd.read_pickle(os.path.join(DATA_DIR, f'{CAT}_roi_data.pkl'))
# Inner quotes differ from the f-string's outer quotes so this also parses before Python 3.12.
print(f'Unique {CAT} ROIs: {list(dat["roi"].unique())}')

SAVE_DIR = './../../../buckets/manifold-dynamics/raw-data'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(SAVE_DIR, exist_ok=True)

In [None]:
# Select significant units in one ROI, baseline-correct their PSTHs, and build
# long-form (unit_id, time, response) tables for seaborn.
rng = np.random.default_rng(RAND)  # seed from the config cell (same value as before)

ROI = 'MF1_8_F'
P_THRESH = 0.05  # per-unit significance cut on the p_value column
N_SAMPLE = 10    # number of example units drawn for the individual traces

sig = dat[dat['p_value'] < P_THRESH]
df = sig[sig['roi'] == ROI]
assert len(df) > 0, f'no units with p_value < {P_THRESH} in ROI {ROI}'


def psth_to_long(mat):
    """Flatten a 2-D (row, column) array into a long-form DataFrame.

    Columns: ``unit_id`` (row index), ``time`` (column index) and
    ``response`` (the array value), in row-major order.
    """
    n_rows, n_cols = mat.shape
    return pd.DataFrame({
        'unit_id': np.repeat(np.arange(n_rows), n_cols),
        'time': np.tile(np.arange(n_cols), n_rows),
        'response': mat.reshape(-1),
    })


# All units stacked along a new first axis; axis 1 is time (it is sliced with BASE).
# NOTE(review): axis 2 is presumably the image/stimulus axis (column name img_psth) -- confirm.
Y = np.stack(df['img_psth'].to_numpy())
# Per-unit, per-column baseline: NaN-tolerant mean over the pre-onset BASE window.
baseline = np.nanmean(Y[:, BASE, :], axis=1, keepdims=True)
Y_bca = np.mean(Y - baseline, axis=2)  # baseline-corrected, then averaged over axis 2

# Draw distinct units for the example traces (replace=False: no unit plotted twice).
unit_idx = rng.choice(len(df), size=min(N_SAMPLE, len(df)), replace=False)
X = Y[unit_idx]
Xavg = np.mean(X, axis=2)
# Reuse the all-unit correction so the sampled traces share the BASE window with the mean.
Xbc = Y_bca[unit_idx]

n_units, n_time = Xbc.shape
hdf = psth_to_long(Xbc)    # sampled units
adf = psth_to_long(Y_bca)  # all significant units (averaged in the plot)

In [None]:
# Baseline-corrected PSTHs: sampled units in colour, mean +/- SE across all
# significant units in black, stimulus onset as a dashed red line.
n_hue = hdf['unit_id'].nunique()
# One distinct colour per sampled unit; the default 'husl' list has only 6 colours
# and seaborn would cycle it, giving several units the same colour.
customp = sns.color_palette('husl', n_colors=n_hue)
fig, ax = plt.subplots(1, 1, figsize=(4, 3))

sns.lineplot(hdf, x='time', y='response', hue='unit_id', alpha=0.5, palette=customp, legend=False, ax=ax)
# mean + error on top
sns.lineplot(adf, x='time', y='response', color='black', errorbar='se', ax=ax)

ax.set_xlim(right=500)
ax.set_title(ROI)

ax.set_ylabel('Response (Hz)')
# NOTE(review): x is the PSTH sample index; labelled msec on the assumption of 1-ms bins.
ax.set_xlabel('Time (msec)')

# Stimulus onset. axvline spans the full axes height without feeding back into
# y-autoscaling, unlike vlines drawn between the current y-limits.
ax.axvline(x=ONSET, color='red', linestyle='--')

out_path = os.path.join(SAVE_DIR, f'{ROI}_response.png')
sns.despine(fig=fig, trim=True, offset=5)
fig.savefig(out_path, dpi=300, format='png', transparent=True, bbox_inches='tight')
plt.show()

In [None]:
# Unweighted average over axis 1 of the first row's img_psth entry in `df`.
avg_psth = np.average(df['img_psth'].iloc[0], axis=1)