In [None]:
import os, sys, pickle
if '..' not in sys.path:
    sys.path.insert(0, '..')

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, cdist, squareform
from scipy.stats import rankdata
from IPython.display import display, clear_output

import utils_txt as tut

In [None]:
# --- Analysis configuration --------------------------------------------------
ONSET = 50
# Bounds that are added to ONSET to build the default response window.
RESP_REL = (100, 220)
# Default time-axis slice used by `time_avg_rdm`; evaluates to slice(150, 270).
RESP = slice(ONSET + RESP_REL[0], ONSET + RESP_REL[1])
# Default `random_state` of `time_avg_rdm`.
RAND = 0

CAT = 'face'
DATA_DIR = '../../datasets/NNN/'
# NOTE: unpickling can execute arbitrary code -- only load trusted files here.
dat = pd.read_pickle(os.path.join(DATA_DIR, f'{CAT}_roi_data.pkl'))
# Outer double quotes: reusing the enclosing quote character inside an
# f-string replacement field is a SyntaxError before Python 3.12.
print(f"Unique {CAT} ROIs: {list(dat['roi'].unique())}")

SAVE_DIR = '../../../buckets/manifold-dynamics/time-averaged/'
# exist_ok=True replaces the separate exists() check (and the race it implies).
os.makedirs(SAVE_DIR, exist_ok=True)

In [None]:
def time_avg_rdm(dat, roi, window=RESP, metric='correlation', random_state=RAND):
    """Compute a time-averaged representational dissimilarity matrix for one ROI.

    Rows of ``dat`` with ``p_value < 0.05`` whose ``roi`` column equals ``roi``
    are kept and their ``img_psth`` arrays are stacked. Responses are then
    averaged (NaN-aware) over ``window`` on the time axis, and pairwise
    distances are computed between the image columns from index 1000 onward.

    Parameters
    ----------
    dat : pandas.DataFrame
        Must provide ``p_value``, ``roi`` and ``img_psth`` columns. Each
        ``img_psth`` entry is indexed as (time, images).
    roi : object
        Value matched against the ``roi`` column.
    window : slice
        Used directly as the index on the time axis (``X[:, window, 1000:]``).
        A ``(start, stop)`` tuple is NOT converted to a slice here.
    metric : str
        Forwarded to ``scipy.spatial.distance.pdist``.
    random_state : optional
        Accepted for interface compatibility; nothing in this function is
        stochastic, so it is currently unused.

    Returns
    -------
    R : numpy.ndarray
        Square image-by-image dissimilarity matrix (``squareform`` of ``Xrdv``).
    Xrdv : numpy.ndarray
        Condensed distance vector as returned by ``pdist``.

    Raises
    ------
    ValueError
        If no row passes the significance filter for ``roi``.
    """
    sig = dat[dat['p_value'] < 0.05]
    df = sig[sig['roi'] == roi]
    if len(df) == 0:
        # Report the argument, not the notebook-level ``ROI`` global, which may
        # be undefined or may name a different ROI than the one requested.
        raise ValueError(f"No data for ROI {roi}")
    X = np.stack(df['img_psth'].to_numpy())          # (units, time, images)

    # average unit responses over the time window -> (units, images)
    Xw = np.nanmean(X[:, window, 1000:], axis=1)
    # pdist compares rows, so transpose to get one row per image
    Xrdv = pdist(Xw.T, metric=metric)
    R = squareform(Xrdv)

    return R, Xrdv

In [None]:
# Time-averaged RDM of a single ROI, drawn as a tick-less heatmap with
# 'Faces' / 'Non-faces' captions above the two column groups, then saved as PNG.
ROI = "MF1_8_F" # MF1_8_F, Unknown_19_F
start = 130
WIN_LEN = 120
WIN = (start, start + WIN_LEN)
# Column index where the 'Non-faces' caption group begins.
# NOTE(review): assumes the first 24 images are faces -- confirm image ordering.
N_FACES = 24

R, _ = tut.time_avg_rdm(dat, roi=ROI, images='localizer', window=WIN)

# Blank the diagonal so self-dissimilarities are not drawn.
np.fill_diagonal(R, np.nan)
customp = sns.color_palette('PiYG')

fig, ax = plt.subplots(figsize=(2, 2))
sns.heatmap(R, ax=ax, square=True,
            cmap=customp, cbar=False)
ax.set_xticks([])
ax.set_yticks([])

n = R.shape[1]

# centers of the two caption groups in data coordinates
x_faces = N_FACES / 2
x_nonfaces = (N_FACES + n) / 2

# get_ylim() returns (bottom, top); step 2% of the axis span past the top edge
# so the captions sit just outside the heatmap.
ymin, ymax = ax.get_ylim()
y_text = ymax + 0.02 * (ymax - ymin)

ax.text(x_faces, y_text, 'Faces',
        ha='center', va='bottom', style='italic')

ax.text(x_nonfaces, y_text, 'Non-faces',
        ha='center', va='bottom', style='italic')

ax.set_xlabel('Image')
ax.set_ylabel('Image')

plt.savefig(os.path.join(SAVE_DIR, f'{ROI}_{int(WIN[0])}-{int(WIN[1])}.png'), format='png', dpi=300, transparent=True, bbox_inches='tight')
plt.show()

In [None]:
# Time-by-time comparison of one ROI's representational geometry: one RDV
# (condensed image-by-image distance vector) per time point, then Spearman
# correlation distance between the RDVs of every pair of time points.
ROI = "MF1_8_F"
WIN = (150, 270)          # not used by the computation in this cell
metric = 'correlation'    # image-by-image dissimilarity used for each RDV
start, stop = 0, 450

sig = dat[dat['p_value'] < 0.05]
df = sig[sig['roi'] == ROI]
if len(df) == 0:
    raise ValueError(f"No data for ROI {ROI}")
X = np.stack(df['img_psth'].to_numpy())          # (units, time, images)

# restrict to desired time window (and to image columns from index 1000 on)
Ximg = X[:, start:stop, 1000:]                    # (units, time, images)

# time-by-RDV (one RDV per timepoint)
Xrdv = np.array([
    pdist(Ximg[:, t, :].T, metric=metric)
    for t in range(Ximg.shape[1])
])  # (time, n_pairs)

# Spearman: rank-transform each RDV, then take correlation distance across
# time. The second metric is fixed to 'correlation' on purpose -- Pearson on
# ranks is what makes this Spearman, independent of the RDV `metric` above.
Xrank = rankdata(Xrdv, axis=1)
R = squareform(pdist(Xrank, metric='correlation'))      # (time, time)

fig, ax = plt.subplots(figsize=(6, 5))

sns.heatmap(R, ax=ax, square=True, cbar=False)
ax.set_title('Time x Time')
ax.set_axis_off()

plt.show()

In [None]:
# Sliding-window animation: recompute the time-averaged RDM for successive
# windows and redraw it on a single axes, one frame per window.
ROI = "Unknown_19_F"

# window params
start, stop = 0, 450
window_len = 120
step = 5

# +1 on the range bound so the last window that still fits,
# [stop - window_len, stop), is included instead of being dropped.
windows = [(t, t + window_len) for t in range(start, stop - window_len + 1, step)]

# one figure, redraw onto same axes
fig, ax = plt.subplots(figsize=(6, 5))

for t0, t1 in windows:
    R, Xrdv = tut.time_avg_rdm(dat, roi=ROI, images='localizer', window=(t0, t1))
    np.fill_diagonal(R, np.nan)   # hide self-dissimilarities

    ax.clear()
    # `customp` is the palette created in the single-window heatmap cell.
    sns.heatmap(R, ax=ax, square=True,
                cmap=customp, cbar=False)

    ax.set_title(f"{ROI}  window: [{t0}, {t1}), vmax: {np.nanmax(R):.03f}, vmin: {np.nanmin(R):.03f}")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # replace the previous frame in the cell output rather than appending
    clear_output(wait=True)
    display(fig)
    plt.pause(0.1)  # controls how long each frame is shown

# release the figure once the animation is done
plt.close(fig)