In [None]:
import os
import pickle

import fsspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.manifold import MDS

import manifold_dynamics.paths as pth
import manifold_dynamics.tuning_utils as tut
# S3 filesystem handle; used below to list and open the single-session rasters.
fs = fsspec.filesystem(protocol='s3')

In [None]:
### Inspect one single-session raster on S3 and the face ROI table (read-only; nothing is written).
datadir = os.path.join(pth.PROCESSED, 'single-session-raster')

# Keep entries whose *file-name stem* contains 'F'. Matching on the basename stem
# (rather than on `path.split('.')[-2]`, which includes the directory part) keeps
# directory names from producing false matches and avoids an IndexError on
# entries whose name contains no '.'.
# NOTE(review): 'F' anywhere in the stem also matches names like 'MF1_...';
# confirm whether the intended filter is an '_F' suffix.
fnames = [f for f in fs.ls(datadir)
          if 'F' in os.path.splitext(os.path.basename(f))[0]]
if len(fnames) < 2:
    raise FileNotFoundError(
        f'expected at least 2 matching rasters under {datadir}, found {len(fnames)}')

fnm = fnames[1]  # believed to be a particularly small file; NOTE(review): select by name, not position
# Paths returned by fs.ls can be handed straight back to fs.open; no URL building needed.
with fs.open(fnm, 'rb') as f:
    out = np.load(f)
print(out.shape)
avg = np.nanmean(out, axis=3)  # NaN-ignoring mean over axis 3
print(avg.shape)

x = pd.read_pickle('../../datasets/NNN/face_roi_data.pkl')

this_roi = x[x['roi'] == 'Unknown_19_F']
print(len(this_roi))
if this_roi.empty:
    # np.stack on an empty selection would fail with an unhelpful
    # "need at least one array to stack" error.
    raise ValueError("no rows with roi == 'Unknown_19_F' in face_roi_data.pkl")

raster = np.stack(this_roi['img_psth'])
print(raster.shape)

In [None]:
# Dataset paths and per-category ROI data. All files are resolved relative to DATA_DIR.
DATA_DIR = '../../datasets/NNN/'
IMG_DIR = os.path.join(DATA_DIR, 'NSD1000_LOC')

CAT = 'face'
dat = pd.read_pickle(os.path.join(DATA_DIR, f'{CAT}_roi_data.pkl'))
ROI_LIST = list(dat['roi'].unique())
print(f'Unique {CAT} ROIs: {ROI_LIST}')

# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
with open(os.path.join(DATA_DIR, f'{CAT}_mins.pkl'), 'rb') as f:
    mins = pickle.load(f)

# Per-ROI scale, looked up as SCALES[ROI] below: the first element of each
# entry stored in {CAT}_mins.pkl.
SCALES = {roi: entry[0] for roi, entry in mins.items()}

In [None]:
ROI = 'Unknown_19_F'  # other ROIs of interest: MF1_7_F, MF1_8_F, MF1_9_F, AF3_18_F, MB1_3_B
MODE = 'top'
scale = SCALES[ROI]
vmin, vmax = 0, 1  # shared heatmap color limits for the RDM plots below
di = {-1: 'shuff', 1072: 'all'}  # labels for the special `scale` values used below

# Rt: RDM at this ROI's own `scale`; Rs: RDM at scale=-1 ('shuff' per `di`).
# Both use the window tstart=0 .. tend=450.
Rt, _ = tut.static_rdm(dat, ROI, mode=MODE, scale=scale,
                       tstart=0, tend=450)
Rs, _ = tut.static_rdm(dat, ROI, mode=MODE, scale=-1,
                       tstart=0, tend=450)

# Optional third condition at scale=1072 ('all' per `di`); disabled.
# To include it, uncomment the next line and append Ra to R.
# Ra, _ = tut.static_rdm(dat, ROI, mode=MODE, scale=1072,
#                        tstart=0, tend=450)

# The plotting cells below iterate over R, so it must be defined here.
R = [Rt, Rs]
# R = [Rt, Rs, Ra]

In [None]:
# One heatmap per condition in R, restricted to rows/cols 100:300.
# The panel count follows len(R) instead of a hard-coded 3; squeeze=False keeps
# `axes` 2-D so a single condition also works.
n_cond = len(R)
fig, axes = plt.subplots(1, n_cond, figsize=(2 * n_cond, 2), squeeze=False)
for ax, Rx in zip(axes[0], R):
    sns.heatmap(Rx[100:300, 100:300], ax=ax, square=True, vmin=vmin, vmax=vmax, cbar=False)
plt.show()

# Strict upper triangle (k=1, diagonal excluded) of Rt as a flat vector.
triu = np.triu_indices_from(Rt, k=1)
Rv = Rt[triu]
print(Rv.shape)  # observed (101025,) = 450 * 449 / 2, i.e. Rt is 450 x 450

In [None]:
# For each condition's RDM (window 100:300): a 3-D metric-MDS embedding of the
# time points, the classical-MDS eigenspectrum, and the adjacent-bin dissimilarity.
# (MDS is imported in the top import cell.)
for k, Rx in enumerate(R):
    rx = Rx[100:300, 100:300]
    n_t = rx.shape[0]  # 200 bins for the 100:300 window

    mds = MDS(dissimilarity='precomputed', n_components=3, random_state=0)
    T_embed = mds.fit_transform(rx)  # (n_t, 3)

    # 2-D view of the embedding
    fig, ax = plt.subplots()
    sns.scatterplot(x=T_embed[:, 0], y=T_embed[:, 1], ax=ax)
    ax.set_title('temporal trajectory')
    ax.set(xlabel='MDS 1', ylabel='MDS 2')
    plt.show()

    # 3-D trajectory, colored by time bin
    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(T_embed[:, 0], T_embed[:, 1], T_embed[:, 2], linewidth=2)
    sc = ax.scatter(
        T_embed[:, 0],
        T_embed[:, 1],
        T_embed[:, 2],
        c=np.arange(n_t),
        cmap='viridis',
        s=15
    )
    fig.colorbar(sc, ax=ax, label='time bin (from window start)')
    plt.show()

    # Classical (Torgerson) MDS Gram matrix: B = -1/2 * J D^2 J with
    # J = I - 11^T / n. The double centering is required for B's eigenvalues to
    # be the MDS spectrum; -1/2 * D^2 alone is not a Gram matrix.
    J = np.eye(n_t) - np.full((n_t, n_t), 1.0 / n_t)
    B = -0.5 * J @ (rx ** 2) @ J
    lam = np.linalg.eigvalsh(B)[::-1]  # eigvalsh returns ascending; flip to descending
    ranks = np.arange(lam.size)
    # The leading 5 eigenvalues are skipped, as before.
    # NOTE(review): presumably to zoom in on the tail; confirm intent.
    fig, ax = plt.subplots()
    ax.plot(ranks[5:20], lam[5:20])
    ax.set(title=f'classical-MDS eigenspectrum (R[{k}])',
           xlabel='eigenvalue rank (0-based)', ylabel='eigenvalue')
    plt.show()

    # Dissimilarity between consecutive time bins (first super-diagonal)
    fig, ax = plt.subplots()
    ax.plot(np.diag(rx, k=1))
    ax.set(title=f'adjacent-bin dissimilarity (R[{k}])',
           xlabel='time bin t', ylabel='RDM[t, t+1]')
    plt.show()

In [None]:
ROI = 'MF1_9_F'  # other ROIs of interest: Unknown_19_F, MF1_7_F, MF1_8_F, AF3_18_F, MB1_3_B
MODE = 'top'
scale = SCALES[ROI]

# Rank stimuli by tut.landscape score, highest first.
# NOTE(review): np.argsort places NaN last, so after [::-1] any NaN scores
# would rank first; confirm landscape() never returns NaN.
scores = tut.landscape(dat, ROI)
order = np.argsort(scores)[::-1]

# Random control: `scale` DISTINCT stimuli drawn from outside the top `scale`.
# replace=False avoids repeated stimuli (which would add duplicate rows/columns
# to the RDM). The generator is seeded so the control set is reproducible.
# Requires len(order) - scale >= scale.
rng = np.random.default_rng(0)
rand = rng.choice(order[scale:], size=scale, replace=False)

# Consecutive score-ranked bands plus the random control, each plotted with its
# tut.ED2 value. Uses vmin/vmax from the static-RDM cell above. A separate name
# (R_sub) avoids overwriting the Rt used by the heatmap cell.
subsets = {
    'First': order[:scale],
    'Second': order[scale:2 * scale],
    'Third': order[2 * scale:3 * scale],
    'Random': rand,
}
for label, idx in subsets.items():
    R_sub, _ = tut.specific_static_rdm(dat, ROI, idx, tstart=100, tend=350)
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    sns.heatmap(R_sub, ax=ax, square=True, vmin=vmin, vmax=vmax, cbar=False)
    ed = tut.ED2(R_sub)
    ax.set_title(f'{label} {scale}, {ed:.02f}')