In [None]:
import os, fsspec, pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster.hierarchy import linkage, leaves_list
from scipy.spatial.distance import pdist

import manifold_dynamics.tuning_utils as tut
import manifold_dynamics.paths as pth
# s3 filesystem handle used below to list and open processed rasters
fs = fsspec.filesystem(protocol='s3')

In [None]:
### inspect one single-session raster on s3 alongside the local face ROI table
datadir = os.path.join(pth.PROCESSED, 'single-session-raster')
# keep files whose second-to-last dot-separated component contains 'F'
fnames = list(filter(lambda p: 'F' in p.split('.')[-2], fs.ls(datadir)))

fnm = fnames[1]  # index 1: believed to be a small file, quick to pull
with fs.open(os.path.join('s3://', fnm), 'rb') as f:
    out = np.load(f)
    print(out.shape)
    avg = np.nanmean(out, axis=3)  # NaN-aware mean over axis 3
    print(avg.shape)

x = pd.read_pickle('../../datasets/NNN/face_roi_data.pkl')

this_roi = x.loc[x['roi'] == 'Unknown_19_F']
print(len(this_roi))

# one row per entry of this ROI, stacked into a single array
raster = np.stack(this_roi['img_psth'].tolist())
print(raster.shape)

In [None]:
DATA_DIR = '../../datasets/NNN/'
IMG_DIR = os.path.join(DATA_DIR, 'NSD1000_LOC')

# ROI category: selects {CAT}_roi_data.pkl and {CAT}_mins.pkl under DATA_DIR
CAT = 'body'
dat = pd.read_pickle(os.path.join(DATA_DIR, f'{CAT}_roi_data.pkl'))
ROI_LIST = list(dat['roi'].unique())
print(f'Unique {CAT} ROIs: {ROI_LIST}')

# NOTE: pickle.load can execute arbitrary code -- only load trusted local files
with open(os.path.join(DATA_DIR, f'{CAT}_mins.pkl'), 'rb') as f:
    mins = pickle.load(f)

# the first element of each mins entry is used as that ROI's scale (SCALES[ROI] below)
SCALES = {roi: entry[0] for roi, entry in mins.items()}

In [None]:
# face options: Unknown_19_F, MF1_7_F, MF1_8_F, MF1_9_F, AF3_18_F | body option: MB1_3_B
# SCALES only holds ROIs of the loaded category (CAT), so pick an ROI from that category;
# a face ROI with CAT='body' would fail at SCALES[ROI].
DEFAULT_ROI = {'face': 'Unknown_19_F', 'body': 'MB1_3_B'}
ROI = DEFAULT_ROI.get(CAT, ROI_LIST[0])
assert ROI in SCALES, f'{ROI!r} has no entry in {CAT}_mins.pkl'
MODE = 'top'
scale = SCALES[ROI]
vmin, vmax = 0, 1
di = {-1: 'shuff', 1072: 'all'}  # panel titles for the special scale values

Rt, _ = tut.static_rdm(dat, ROI, mode=MODE, scale=scale, tstart=0, tend=450)
Rs, _ = tut.static_rdm(dat, ROI, mode=MODE, scale=-1, tstart=0, tend=450)
# the scale=1072 ('all') panel is disabled; append it to R and titles to re-enable
R = [Rt, Rs]
titles = [f'{MODE} {scale}', di[-1]]

# one panel per RDM (no empty trailing axes), cropped to time points 100:300
fig, axes = plt.subplots(1, len(R), figsize=(2 * len(R), 2), squeeze=False)
for ax, Rx, title in zip(axes[0], R, titles):
    sns.heatmap(Rx[100:300, 100:300], ax=ax, square=True, vmin=vmin, vmax=vmax, cbar=False)
    ax.set_title(title)
plt.show()

# strict upper triangle of the top-k RDM as a flat vector
triu = np.triu_indices_from(Rt, k=1)
Rv = Rt[triu]
print(Rv.shape)

# recorded output: (101025,) == 450 * 449 / 2, i.e. a 450x450 RDM

In [None]:
### MDS, EIGENSPECTRA PLOTS (NOT VERY USEFUL)

from sklearn.manifold import MDS

for k, Rx in enumerate(R):
    rx = Rx[100:350, 100:350]  # time points 100..349
    m = rx.shape[0]
    t = np.arange(m)  # time index within the crop, used for colouring

    mds = MDS(dissimilarity='precomputed', n_components=3, random_state=0)
    T_embed = mds.fit_transform(rx)  # (m, 3): (250, 3) for this crop

    # 2d view of the MDS embedding, coloured by time
    fig, ax = plt.subplots()
    sns.scatterplot(x=T_embed[:, 0], y=T_embed[:, 1], hue=t,
                    palette='viridis', legend=False, ax=ax)
    ax.set_title(f'temporal trajectory (R[{k}])')
    plt.show()

    # 3d trajectory: black line through time, points coloured by time
    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(T_embed[:, 0], T_embed[:, 1], T_embed[:, 2], color='black', linewidth=1)
    ax.scatter(T_embed[:, 0], T_embed[:, 1], T_embed[:, 2], c=t, cmap='viridis', s=15)
    plt.show()

    # classical-MDS eigenspectrum: the Gram matrix needs double-centering,
    # B = -1/2 * J D^2 J with J = I - 11^T/m; without J the spectrum is not MDS's
    J = np.eye(m) - np.full((m, m), 1.0 / m)
    B = -0.5 * J @ (rx ** 2) @ J
    lam = np.linalg.eigvalsh(B)[::-1]  # descending
    fig, ax = plt.subplots()
    ax.plot(lam[5:20])  # eigenvalues ranked 6-20 (the five largest are left out)
    ax.set_title('Eigenspectra')
    plt.show()

    # first off-diagonal: dissimilarity between consecutive time points
    fig, ax = plt.subplots()
    ax.plot(np.diag(rx, k=1))
    ax.set_title('diag k=1')
    plt.show()

In [None]:
### PLOT CONSECUTIVE 'TOP-k' RDMS
# ['MB1_3_B', 'MB2_20_B', 'Unknown_20_B', 'Unknown_6_B', 'MB2_21_B', 'Unknown_9_B', 'Unknown_23_B', 'MB1_8_B', 'AB3_18_B', 'MB3_12_B', 'AB3_12_B', 'AB3_17_B', 'Unknown_27_B']
ROI = 'MB1_8_B'  # other options: Unknown_19_F, MF1_7_F, MF1_8_F, MF1_9_F, AF3_18_F, MB1_3_B, MB2_20_B, MB3_12_B, AB3_18_B, Unknown_6_B
scale = 20  # images per bin (SCALES[ROI] would give the ROI-specific value)
tstart, tend = 100, 350
vmin, vmax = 0, 1

rng = np.random.default_rng(0)
scores = tut.landscape(dat, ROI)
order = np.argsort(scores)[::-1]  # indices sorted by score, highest first

# three consecutive bins of the ranking
bins = [
    ('first',  order[:scale]),
    ('second', order[scale:2 * scale]),
    ('third',  order[2 * scale:3 * scale]),
]

# matched-size random control drawn only from indices outside the top bins
pool = order[3 * scale:]
rand_idx = rng.choice(pool, size=scale, replace=False)
bins.append(('random', rand_idx))

# RDM + ED2 per bin; named `Rb` (not `R`) so the global list R used by the MDS cell survives
results = []
for name, idx in bins:
    Rb, _ = tut.specific_static_rdm(dat, ROI, idx, tstart=tstart, tend=tend)
    results.append((name, idx, Rb, tut.ED2(Rb)))

# one panel per bin; x ticks every 50 samples, labelled tstart..tend
xticks = np.arange(0, tend - tstart + 1, 50)
xticklabels = np.arange(tstart, tend + 1, 50)
fig, axes = plt.subplots(1, len(results), figsize=(9, 2.5), constrained_layout=True)
for ax, (name, idx, Rb, ed) in zip(axes, results):
    sns.heatmap(Rb, ax=ax, square=True, vmin=vmin, vmax=vmax, cbar=False)
    ax.set_title(f'{name} {scale}\ned={ed:.2f}')
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels)
    ax.set_yticks([])

fig.suptitle(f'{ROI} | win ({tstart}-{tend})', y=1.1)
plt.show()

In [None]:
idx = rand_idx  # alternative: order[:scale] for the top bin
windows = [
    ('first',  slice(100, 150)),
    ('second', slice(150, 250)),
    ('third', slice(300, 350)),
]

# NOTE: reorder_by_rowpattern is defined in the LAST cell of this notebook.
# On a fresh kernel (Restart & Run All) run/move that cell above this one,
# otherwise this cell raises NameError.

# time-averaged RDM per window; track one colour range covering every panel
results = []
vmin, vmax = np.inf, -np.inf
for name, win in windows:
    Rw, _ = tut.time_avg_rdm(dat, ROI, window=win, images=idx)  # `Rw` keeps global R intact
    triu = np.triu_indices_from(Rw, k=1)
    # widen (union), not narrow: an intersection of ranges can end with vmin > vmax
    vmin = min(vmin, np.nanmin(Rw[triu]))
    vmax = max(vmax, np.nanmax(Rw[triu]))
    results.append({'name': name, 'win': win, 'R': Rw, 'ed': tut.ED2(Rw)})
W = len(results)

# one clustering order per window, each used as a reference for all windows
for r in results:
    _, r['order'], _ = reorder_by_rowpattern(r['R'])

# row i: every window's RDM reordered by window i's clustering
fig, axes = plt.subplots(W, W, figsize=(2.2 * W, 2.2 * W),
                         constrained_layout=True, squeeze=False)
for i, ref in enumerate(results):
    oi = ref['order']
    ref_win = ref['win']
    for j, tgt in enumerate(results):
        ax = axes[i, j]
        tgt_win = tgt['win']
        sns.heatmap(tgt['R'][np.ix_(oi, oi)], ax=ax, square=True, cbar=False,
                    vmin=vmin, vmax=vmax)

        # annotate edges only (otherwise it's a wall of text)
        if i == 0:
            ax.set_title(f"{tgt['name']}\n{tgt_win.start}:{tgt_win.stop}", fontsize=10)
        if j == 0:
            ax.set_ylabel(f"ref: {ref['name']}\n{ref_win.start}:{ref_win.stop}", fontsize=10)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('')
fig.suptitle(f'{ROI}')
plt.show()

In [None]:
def reorder_by_rowpattern(M, *, method='average', metric='correlation', tie_break='stable'):
    '''
    Cluster the rows of a square matrix by their off-diagonal pattern and
    return a deterministic leaf order.

    Rows are compared column-by-column, so position k means column k for
    every row. The diagonal entry (a row's similarity to itself) is masked
    out and, like any non-finite entry, replaced by the mean of that row's
    finite off-diagonal values (0.0 if the row has none).

    Parameters
    ----------
    M : array_like, shape (n, n)
        square matrix (e.g. an RDM); it is not modified.
    method : str
        linkage method passed to scipy.cluster.hierarchy.linkage.
    metric : str
        distance metric passed to scipy.spatial.distance.pdist.
    tie_break : {'stable', None}
        'stable' adds a tiny deterministic ramp (< 1e-12 * n**2) to the row
        patterns before computing distances. Besides nudging exact ties, it
        gives constant rows non-zero variance, which the 'correlation' metric
        needs to return finite distances. None disables it.

    Returns
    -------
    M_re : ndarray
        M with rows and columns permuted by ``order``.
    order : ndarray of int
        leaf order from leaves_list (optimal_ordering=True).
    Z : ndarray
        linkage matrix, shape (n - 1, 4); empty (0, 4) when n < 2.

    Raises
    ------
    ValueError
        if M is not a 2-d square matrix, or tie_break is not 'stable'/None.
    '''
    M = np.asarray(M)
    # check rank before indexing shape so 0-d / 1-d input gets the ValueError
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError(f'expected square matrix; got {M.shape}')
    if tie_break not in ('stable', None):
        raise ValueError(f"tie_break must be 'stable' or None; got {tie_break!r}")
    n = M.shape[0]

    if n < 2:
        # linkage needs at least two observations; any order is trivially optimal
        order = np.arange(n)
        return M[np.ix_(order, order)], order, np.empty((0, 4))

    # full-width copy with the diagonal treated as missing (keeps columns aligned)
    X = np.array(M, dtype=float)
    np.fill_diagonal(X, np.nan)

    # impute NaN/inf with the mean of the row's *finite* entries; a plain nanmean
    # would propagate inf and leave all-NaN rows NaN, which crashes linkage
    finite = np.isfinite(X)
    n_finite = finite.sum(axis=1, keepdims=True)
    row_sum = np.where(finite, X, 0.0).sum(axis=1, keepdims=True)
    row_mean = np.divide(row_sum, n_finite, out=np.zeros_like(row_sum), where=n_finite > 0)
    X = np.where(finite, X, row_mean)

    if tie_break == 'stable':
        # deterministic jitter: identical every run
        eps = 1e-12
        X = X + eps * np.arange(X.size, dtype=float).reshape(X.shape)

    Z = linkage(pdist(X, metric=metric), method=method, optimal_ordering=True)
    order = leaves_list(Z)
    M_re = M[np.ix_(order, order)]
    return M_re, order, Z