In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import animation
from IPython.display import HTML, display
import seaborn as sns

from tqdm import tqdm
from scipy.spatial.distance import pdist, cdist, squareform
from scipy.stats import pearsonr, spearmanr, entropy, rankdata
from sklearn.manifold import MDS
from scipy.ndimage import gaussian_filter1d

from sklearn.svm import LinearSVC
from sklearn.preprocessing import StandardScaler

import imageio.v2 as iio
from PIL import Image

RAND = 0    # seed for every np.random.default_rng(RAND) call in this notebook
ONSET = 50  # time index treated as stimulus onset; the windows below are offsets from it

# Analysis windows as (start, stop) offsets from ONSET
RESP_WINDOW = (50, 220)   # evoked-response window
BASE_WINDOW = (-50, 0)    # pre-stimulus baseline window

# Absolute slices into the PSTH time axis (used to score images in the RDM helpers)
RESP = slice(ONSET + RESP_WINDOW[0], ONSET + RESP_WINDOW[1])
BASE = slice(ONSET + BASE_WINDOW[0], ONSET + BASE_WINDOW[1])

In [None]:
DATA_DIR = '../../datasets/NNN/'

# NOTE(review): pd.read_pickle can execute arbitrary code on load — only use
# trusted, locally produced files here.
dat = pd.read_pickle(os.path.join(DATA_DIR, 'face_roi_data.pkl'))
# Double quotes outside so the nested 'roi' key also parses on Python < 3.12
print(f"Unique face ROIs: {list(dat['roi'].unique())}")

# Precomputed per-ROI RDMs (originally produced in pause.ipynb)
df = pd.read_pickle(os.path.join(DATA_DIR, 'face_rdms.pkl'))
ROI_LIST = ['MF1_8_F', 'Unknown_19_F', 'MF1_7_F', 'MF1_9_F']

# ROI name -> {column: value} for every other column of df
cache = {
    row['ROI']: {k: row[k] for k in df.columns if k != 'ROI'}
    for _, row in df.iterrows()
}

In [None]:
def geo_rdm(dat, roi, mode='top', step=5, k_max=200, metric='correlation',
            tstart=100, tend=400):
    """Compute one time-by-time RDM per image-set size for a single ROI.

    Images are ordered once (by evoked-minus-baseline response for
    ``mode='top'``, otherwise by a random permutation seeded with RAND), and
    for each size k the first k images of that ordering are used.

    Parameters
    ----------
    dat : pandas.DataFrame
        Unit table with at least 'p_value', 'roi' and 'img_psth' columns;
        only rows with p_value < 0.05 are used.
    roi : str
        ROI label matched against dat['roi'].
    mode : str
        'top' for response-ranked images; any other value shuffles them.
    step : int
        Spacing between consecutive image-set sizes (also the smallest size).
    k_max : int
        Largest size considered (capped at the number of images).
    metric : str
        scipy ``pdist`` metric comparing the rank-transformed RDVs of
        different time bins.
    tstart, tend : int
        Time-axis slice [tstart, tend) that is analysed. The defaults
        reproduce the previously hard-coded 100:400 window.

    Returns
    -------
    sizes : list of int
        Image-set sizes, ``range(step, min(k_max, n_images) + 1, step)``.
    rdvs : list of ndarray
        One (time, time) RDM per entry of ``sizes``.

    Raises
    ------
    ValueError
        If no significant row belongs to `roi`.
    """
    rng = np.random.default_rng(RAND)
    sig = dat[dat['p_value'] < 0.05]
    roi_df = sig[sig['roi'] == roi]
    if len(roi_df) == 0:
        raise ValueError(f"No data for ROI {roi}")
    X = np.stack(roi_df['img_psth'].to_numpy())  # (units, time, images)

    # Evoked-minus-baseline score per image, pooled over units and time bins
    scores = np.nanmean(X[:, RESP, :], axis=(0, 1)) - np.nanmean(X[:, BASE, :], axis=(0, 1))
    order = np.argsort(scores)[::-1] if mode == 'top' else rng.permutation(scores.size)

    # Evenly spaced image-set sizes. Alternative ramp (every size below 2*step, then every step):
    # sizes = list(range(1, 2*step)) + list(range(2*step, min(k_max, X.shape[2]) + 1, step))
    sizes = list(range(step, min(k_max, X.shape[2]) + 1, step))

    rdvs = []
    for k in tqdm(sizes):
        Ximg = X[:, tstart:tend, order[:k]]  # (units, time, images)
        # Correlation-distance RDV over images for every time bin -> (time, n_pairs)
        Xrdv = np.array([pdist(Ximg[:, t, :].T, metric='correlation')
                         for t in range(Ximg.shape[1])])
        # Rank-transform each RDV; with the default metric='correlation' the
        # comparison below is therefore a Spearman-style distance
        Xrank = np.apply_along_axis(rankdata, 1, Xrdv)
        rdvs.append(squareform(pdist(Xrank, metric=metric)))  # (time, time)
    return sizes, rdvs

def static_rdm(dat, roi, mode='top', scale=30, tstart=100, tend=400, metric='correlation'):
    """Compute a single time-by-time RDM for one ROI at a fixed image-set size.

    Parameters
    ----------
    dat : pandas.DataFrame
        Unit table with at least 'p_value', 'roi' and 'img_psth' columns;
        only rows with p_value < 0.05 are used.
    roi : str
        ROI label matched against dat['roi'].
    mode : str
        'top' ranks images by evoked-minus-baseline response (descending);
        any other value uses a random permutation seeded with RAND.
    scale : int
        Number of images taken from the front of that ordering.
    tstart, tend : int
        Time-axis slice [tstart, tend) applied to every PSTH.
    metric : str
        scipy ``pdist`` metric comparing the rank-transformed RDVs of
        different time bins.

    Returns
    -------
    R : ndarray, (time, time)
        Pairwise `metric` distances between rank-transformed per-bin RDVs.
    Xrdv : ndarray, (time, n_pairs)
        Raw (un-ranked) correlation-distance RDV of each time bin.

    Raises
    ------
    ValueError
        If no significant row belongs to `roi`.
    """
    significant = dat[dat['p_value'] < 0.05]
    roi_rows = significant[significant['roi'] == roi]
    if roi_rows.empty:
        raise ValueError(f"No data for ROI {roi}")
    psth = np.stack(roi_rows['img_psth'].to_numpy())  # (units, time, images)

    # Evoked response minus baseline per image, pooled over units and bins
    evoked = np.nanmean(psth[:, RESP, :], axis=(0, 1))
    baseline = np.nanmean(psth[:, BASE, :], axis=(0, 1))
    scores = evoked - baseline

    if mode == 'top':
        ranking = np.argsort(scores)[::-1]
    else:
        ranking = np.random.default_rng(RAND).permutation(scores.size)

    # Chosen images restricted to the requested time window: (units, time, images)
    window = psth[:, tstart:tend, ranking[:scale]]

    # One correlation-distance RDV over images per time bin -> (time, n_pairs)
    per_bin = [pdist(window[:, t, :].T, metric='correlation') for t in range(window.shape[1])]
    rdv = np.array(per_bin)

    # Rank each RDV; with the default metric='correlation' the time-by-time
    # comparison is then a Spearman-style distance
    ranked = np.apply_along_axis(rankdata, 1, rdv)
    rdm = squareform(pdist(ranked, metric=metric))
    return rdm, rdv

In [None]:
# Growing-image-set RDMs for one ROI, adding one top-ranked image per step
ROI = 'Unknown_19_F'
top_rdms = geo_rdm(dat, roi=ROI, mode='top', step=1)[1]

In [None]:
# Sliding-window analysis for a single ROI: compute a time-by-time RDM in each
# window, track the Euclidean norm of its upper triangle, then animate the RDMs.
for ROI in ['Unknown_19_F']:
    rdms = []
    norms = []
    window = 50   # samples per window
    stride = 10   # samples between consecutive window starts
    t0s = list(range(ONSET, 400, stride))  # window start indices

    for t0 in t0s:
        # Moving fixed-size window. To grow from onset instead, use
        # tstart=ONSET, tend=ONSET + t0.
        R, _ = static_rdm(dat, ROI, mode='top', scale=30,
                          tstart=t0, tend=t0 + window)
        triu_vals = R[np.triu_indices_from(R, k=1)]
        norms.append(np.sqrt(np.sum(triu_vals ** 2)))
        rdms.append(R)

    # Norm of each window's RDM, plotted against the real window start
    fig, ax = plt.subplots(1, 1)
    sns.lineplot(x=t0s, y=norms, ax=ax)
    ax.set_title(f'{ROI} | {window} window, {stride} stride')
    ax.set_xlabel('window start (msec)')
    ax.set_ylabel('Euclidean norm')
    plt.show()

    # Inline animation of the per-window RDMs on a fixed colour scale
    global_vmax = 1  # alt: np.nanmax([np.nanmax(R) for R in rdms])
    global_vmin = 0.0

    fig, ax = plt.subplots(figsize=(4, 4))
    sns.heatmap(rdms[0], vmin=global_vmin, vmax=global_vmax, ax=ax, cbar=False)

    def update(i):
        """Redraw frame i: the i-th RDM, its max off-diagonal value and its norm."""
        triu_vals = rdms[i][np.triu_indices_from(rdms[i], k=1)]
        vmax = np.nanmax(triu_vals)

        ax.clear()
        sns.heatmap(rdms[i], vmin=global_vmin, vmax=global_vmax, ax=ax, cbar=False)
        ax.set_title(f"Frame {i} | Max {vmax:0.2f} | Norm {norms[i]:0.2f}")
        ax.set_axis_off()
        return ax

    ani = animation.FuncAnimation(fig, update, frames=len(rdms), interval=200)
    plt.close(fig)  # stop the inline backend from also rendering a static copy
    # Inside a loop the last expression is not auto-displayed, so display explicitly
    display(HTML(ani.to_jshtml()))

In [None]:
# Per-ROI image-set size (indexed in ROI_LIST order) for each ordering.
# NOTE(review): presumably chosen from the scale minima plotted further down — confirm provenance.
SCALES = {
    'top': [25, 30, 60, 45],
    'shuff': [85, 30, 55, 35]
}

### SAVE AS A COOL GIF
# For every ROI: sliding-window RDM norms for the 'top' and shuffled orderings,
# rendered frame by frame (norm curves + current 'top' RDM) into one GIF.
window = 50   # samples per window
stride = 10   # samples between consecutive window starts
t0s = list(range(ONSET, 400, stride))  # window start indices (same for every ROI/mode)

global_vmax = 1.0  # fixed colour scale across frames; alt: np.nanmax([np.nanmax(R) for R in rdms])
global_vmin = 0.0

for iROI, ROI in enumerate(ROI_LIST):
    rdms = []                          # 'top'-ordering RDM per window -> GIF frames
    norms = {'top': [], 'shuff': []}   # upper-triangle Euclidean norm per window
    for MODE in ['top', 'shuff']:
        scale = SCALES[MODE][iROI]
        for t0 in t0s:
            # moving fixed-size window
            R, _ = static_rdm(dat, ROI, mode=MODE, scale=scale,
                              tstart=t0, tend=t0 + window)
            triu_vals = R[np.triu_indices_from(R, k=1)]
            norms[MODE].append(np.sqrt(np.sum(triu_vals ** 2)))
            if MODE == 'top':
                rdms.append(R)

    out_path = f"../../gifs/window_{ROI}_combined_ylim.gif"
    # Create the output directory so mimsave does not fail on a fresh checkout
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    frames = []
    triu = np.triu_indices_from(rdms[0], k=1)
    for i, R in enumerate(rdms):
        fig, (ax1, ax0) = plt.subplots(1, 2, figsize=(8, 3))

        # --- right: RDM heatmap for the current window ---
        sns.heatmap(R, vmin=global_vmin, vmax=global_vmax,
                    ax=ax0, square=True, cbar=False)
        ax0.set_title(f"VMax {np.nanmax(R[triu]):0.2f}")
        ax0.set_axis_off()

        # --- left: norm vs time for both orderings, red dot at the current window ---
        for mode in ['top', 'shuff']:
            sns.lineplot(x=t0s, y=norms[mode], ax=ax1, label=mode)
            ax1.scatter(t0s[i], norms[mode][i], color='red', s=60, zorder=5)
        ax1.set_title(f'{ROI} | window={window}, stride={stride}')
        ax1.set_ylim(top=30, bottom=0)
        ax1.set_xlabel('time (msec)')
        ax1.set_ylabel('Euclidean norm')
        ax1.grid(alpha=0.3)

        fig.tight_layout()
        fig.canvas.draw()
        frames.append(Image.fromarray(np.asarray(fig.canvas.buffer_rgba())))
        plt.close(fig)  # many figures are created in this loop

    iio.mimsave(out_path, frames, duration=1, loop=0)
    print("Saved:", out_path)

In [None]:
# Number of sliding windows computed per ordering for the last ROI processed
n_windows = len(norms['top'])
n_windows

In [None]:
### ABS DISSIM FOR A SINGLE TIME POINT
# Summarise the cached growing-image-set RDMs as a function of manifold scale:
# average RDMs over consecutive chunks of `step` sizes, reduce each chunk to a
# single number, smooth per (ROI, Mode), and locate the minimum of the 'top' curve.
step = 5
SHOW_SHUFFLED = False  # add a third panel with the shuffled-order RDM at the same scale

records = []
for _roi in ROI_LIST:
    roi_dict = cache[_roi]
    sizes = roi_dict['sizes_top']

    for mode in ['top', 'shuff']:
        mode_rdms = roi_dict[f'{mode}_rdms']
        triu = np.triu_indices_from(mode_rdms[0], k=1)

        # Recording starts at t = step: the chunk mode_rdms[0:step] is never
        # recorded itself (it would only be the "previous" chunk for a
        # difference-style metric).
        for t in range(step, len(mode_rdms), step):
            # Mean upper triangle over this chunk. Alternative: a single RDM, mode_rdms[t - 1][triu]
            R0 = np.mean(np.array([rdm[triu] for rdm in mode_rdms[t:t + step]]), axis=0)

            # Chunk summary: mean dissimilarity. Alternatives tried:
            # np.sqrt(np.sum(R0**2)), np.sqrt(np.sum(np.abs(R0))),
            # 1 - spearmanr(R0, previous_chunk).statistic
            diff = np.nanmean(R0)

            # NOTE(review): the chunk covers sizes[t:t+step] but is labelled
            # sizes[t - 1]; confirm this is the intended x position.
            records.append({'ROI': _roi, 'Scale': sizes[t - 1], 'Derivative': diff, 'Mode': mode})

# Build the frame once (avoids quadratic row-by-row growth and object dtypes)
diffs = pd.DataFrame(records, columns=['ROI', 'Scale', 'Derivative', 'Mode'])

# Smooth each (ROI, Mode) curve independently; grouping by Mode alone would
# run the filter across ROI boundaries.
diffs['diff_smooth'] = diffs.groupby(['ROI', 'Mode'])['Derivative'].transform(
    lambda v: gaussian_filter1d(v.to_numpy(dtype=float), sigma=1)
)

mins = {}
for r in ROI_LIST:
    ncols = 3 if SHOW_SHUFFLED else 2
    fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 3))
    d = diffs[diffs['ROI'] == r]

    # --- left: raw (faint) and smoothed curves per ordering ---
    ax = axes[0]
    sns.lineplot(data=d, x='Scale', y='Derivative', hue='Mode', alpha=0.5, ax=ax)
    sns.lineplot(data=d, x='Scale', y='diff_smooth', hue='Mode', ax=ax)

    handles, labels = ax.get_legend_handles_labels()
    labels = list(labels)
    for i, mode in enumerate(['top', 'shuff']):
        dm = d[d['Mode'] == mode]
        idx_min = dm['diff_smooth'].idxmin()
        x_min = int(dm.loc[idx_min, 'Scale'])  # int so it can be used as an image count
        y_min = dm.loc[idx_min, 'diff_smooth']

        if mode == 'top':
            mins[r] = (x_min, y_min)
        ax.scatter(x_min, y_min, color='red', s=60, zorder=5)  # smoothed minimum
        labels[i] = f'{mode} min @ {x_min}'

    ax.legend(handles[:2], labels[:2], frameon=False)
    ax.set_ylabel('Mean dissimilarity')
    ax.set_xlabel('Manifold scale')

    # --- middle: 'top'-ordering RDM at this ROI's minimising scale ---
    x_min, _ = mins[r]
    ax = axes[1]
    R, _ = static_rdm(dat, r, mode='top', scale=x_min,
                      tstart=ONSET + 50, tend=ONSET + 300)
    sns.heatmap(R, square=True, cbar=False, vmax=1, ax=ax)
    ax.set_axis_off()
    ax.set_title(f'{r} | Scale: {x_min}')

    # --- optional right: shuffled ordering at the same scale ---
    if SHOW_SHUFFLED:
        ax = axes[2]
        R, _ = static_rdm(dat, r, mode='shuff', scale=x_min,
                          tstart=ONSET + 50, tend=ONSET + 300)
        sns.heatmap(R, square=True, cbar=False, vmax=1, ax=ax)
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()

In [None]:
# Animate the growing-image-set RDMs (top_rdms) on a fixed colour scale so
# frames are directly comparable.
vmax = 1  # alt: np.nanmax([np.nanmax(R) for R in top_rdms])
vmin = 0.0

fig, ax = plt.subplots(figsize=(4, 4))
sns.heatmap(top_rdms[0], vmin=vmin, vmax=vmax, ax=ax, cbar=False)

def update(i):
    """Redraw the axes with the i-th RDM from top_rdms."""
    ax.clear()
    sns.heatmap(top_rdms[i], vmin=vmin, vmax=vmax, ax=ax, cbar=False)
    ax.set_title(f"Frame {i}")
    return ax

ani = animation.FuncAnimation(fig, update, frames=len(top_rdms), interval=200)
plt.close(fig)  # stop the inline backend from also rendering a static copy
HTML(ani.to_jshtml())