In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

from scipy.stats import wasserstein_distance
from sklearn.metrics.pairwise import cosine_distances
from sklearn.manifold import MDS
from scipy.spatial.distance import pdist, squareform, cdist

### from https://humaticlabs.com/blog/mrdmd-python/
import numpy as np
import matplotlib.pyplot as plt
from numpy import dot, multiply, diag, power
from numpy import pi, exp, sin, cos
from numpy.linalg import inv, eig, pinv, solve
from scipy.linalg import svd, svdvals
from math import floor, ceil # python 3.x

In [None]:
DATA_DIR = '../../datasets/NNN/'
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
dat = pd.read_pickle(os.path.join(DATA_DIR, 'face_roi_data.pkl'))
# Outer double quotes: reusing the same quote type inside an f-string
# replacement field is a SyntaxError before Python 3.12 (PEP 701).
print(f"Unique face ROIs: {list(dat['roi'].unique())}")

In [None]:
ROI = 'Unknown_19_F'
PVAL = 0.05

# keep significant units of the chosen ROI and stack their per-image PSTHs
sig = dat.loc[dat['p_value'] < PVAL]
df = sig.loc[sig['roi'] == ROI]
X = np.stack(list(df['img_psth']))
print('Loaded unit-level data for each image. Shape:', X.shape, '---> (units, time points, images)')

# subtract each unit's mean over time, separately for every image
X_centered = X - np.mean(X, axis=1, keepdims=True)
print('Centered shape:', X_centered.shape)

# Other choices explored for Xsub: each image's time series concatenated,
# or the average across images (np.mean(X_centered, axis=2)).
# Current choice: localizer image responses only (image index 1000 onwards).
Xsub = X_centered[:, :, 1000:]
print('Final data shape:', Xsub.shape)

In [None]:
# every unit's response to one localizer image (rows: units, columns: time)
img_idx = 1010
_x = X[..., img_idx]
fig, ax = plt.subplots(figsize=(10, 5))
sns.heatmap(_x, cbar=False, square=True, ax=ax)
ax.set(title=f'Unit responses for image {img_idx}')
print(_x.shape)

In [None]:
# Global image IDs (positions along X's image axis) of each localizer category.
img_sets = {
    "all_faces":  np.arange(1000, 1024),
    "all_bodies": np.concatenate([
        np.arange(1025, 1031),
        np.arange(1042, 1048),
        np.arange(1049, 1061)
    ]),
    "all_objects": np.concatenate([
        np.arange(1024,1025),
        np.arange(1031, 1042),
        np.arange(1048,1049),
        np.arange(1061, 1072)
    ])
}

IMG_DIR = '../../datasets/NNN/NSD1000_LOC/'
MAX_THUMBS = 33  # maximum thumbnails drawn per category row

# Show each category's stimuli; global ID g is stored as MFOB{g - 999:03d}.bmp.
for cat, indices in img_sets.items():
    if len(indices) == 0:
        continue
    n_cols = min(len(indices), MAX_THUMBS)
    # squeeze=False keeps `axes` a 2-D array even when n_cols == 1
    fig, axes = plt.subplots(1, n_cols, figsize=(20, 5), squeeze=False)
    fig.suptitle(f'{cat} (n={len(indices)})')
    for ax, imgid in zip(axes.ravel(), indices):
        img_pth = os.path.join(IMG_DIR, f'MFOB{(imgid-1000+1):03d}.bmp')
        ax.imshow(mpimg.imread(img_pth))
        ax.set_title(imgid)
        ax.set_axis_off()
    plt.show()

# Per-category local indices. The localizer images follow the 1000 NSD images
# along X's image axis, so valid local indices are 0 .. X.shape[2] - 1001; the
# bound must be the localizer count, not the total number of images.
n_loc = X.shape[2] - 1000
cat_local = {
    cat: (ids - 1000)[(ids - 1000 >= 0) & (ids - 1000 < n_loc)]
    for cat, ids in img_sets.items()
}

# Unit-averaged time course of every image in each category (one line per image).
for cat, indices in img_sets.items():
    Xs = X_centered[:, :, indices].mean(axis=0)   # (time, images in category)
    fig, ax = plt.subplots(1, 1)
    sns.lineplot(data=Xs, ax=ax, legend=False)
    ax.set_title(cat)
    ax.set_xlabel('time bin'); ax.set_ylabel('mean centered response')
    plt.show()

In [None]:
def dmd_exact(X_, Y_, r):
    """Exact DMD returning eigenvalues, modes and the intermediate SVD cache.

    This function used to be called ``dmd`` as well, but the second ``dmd``
    defined later in this cell (signature ``dmd(X, Y, truncate=None)``,
    returning ``(mu, Phi)``) replaced that binding, so this version could never
    be called. It now has its own name so both versions can coexist.

    Parameters
    ----------
    X_ : ndarray, shape (n_features, n_snapshots)
        "Past" snapshots (e.g. ``X[:, :-1]``).
    Y_ : ndarray, shape (n_features, n_snapshots)
        "Future" snapshots, one step ahead of ``X_`` (e.g. ``X[:, 1:]``).
    r : int
        Rank kept from the SVD of ``X_``.

    Returns
    -------
    lam : ndarray, shape (r,)
        DMD eigenvalues.
    Phi : ndarray, shape (n_features, r)
        Exact DMD modes.
    (U_r, S_r, V_r, Atilde) : tuple
        Truncated SVD factors of ``X_`` and the reduced operator.
    """
    U, S, Vt = np.linalg.svd(X_, full_matrices=False)
    # conjugate transposes keep the formulas correct for complex input
    U_r, S_r, V_r = U[:, :r], S[:r], Vt[:r, :].conj().T
    V_over_S = V_r * (1.0 / S_r)          # V_r @ diag(1/S_r), column-wise
    Atilde = (U_r.conj().T @ Y_) @ V_over_S
    lam, W = np.linalg.eig(Atilde)
    Phi = Y_ @ V_over_S @ W
    return lam, Phi, (U_r, S_r, V_r, Atilde)

def delay_embed(Z, L):
    """Delay-embed one image's time series by stacking L shifted copies.

    Parameters
    ----------
    Z : ndarray, shape (k, T)
        Multichannel time series (e.g. PCA scores) for a single image.
    L : int
        Number of lags (window length).

    Returns
    -------
    ndarray, shape (k*L, T-L+1)
        Row block ``s`` is ``Z`` shifted forward by ``s`` samples, so column
        ``j`` equals ``[Z[:, j]; Z[:, j+1]; ...; Z[:, j+L-1]]``. Windows never
        cross image boundaries because only one image is passed in.
    """
    n_samples = Z.shape[1]
    width = n_samples - L + 1
    lagged = [Z[:, lag:lag + width] for lag in range(L)]
    return np.concatenate(lagged, axis=0)

def mode_amplitudes(Phi, W_img):
    """Least-squares mode coefficients for one image's first snapshot.

    Solves ``Phi @ b ≈ W_img[:, 0]`` (the first embedded time point) and
    returns ``b``, one coefficient per column of ``Phi``.
    """
    first_snapshot = W_img[:, 0]
    coeffs, _residuals, _rank, _sv = np.linalg.lstsq(Phi, first_snapshot, rcond=None)
    return coeffs

def dmd(X, Y, truncate=None):
    '''
    Exact DMD, version created by Robert Taylor.
    For more info: https://humaticlabs.com/blog/dmd-python/

    Parameters
    ----------
    X, Y : ndarray, shape (n_features, n_snapshots)
        Snapshot matrices, with each column of Y one step ahead of X.
    truncate : int or None
        Rank r kept from the SVD of X (this function's "r"). None keeps every
        singular value; 0 short-circuits and returns empty outputs.

    Returns
    -------
    mu : ndarray, shape (r,)
        DMD eigenvalues.
    Phi : ndarray, shape (n_features, r)
        DMD modes.
    '''
    if truncate == 0:
        # nothing to fit: no eigenvalues and a zero-width mode matrix
        return np.array([], dtype='complex'), np.zeros([X.shape[0], 0], dtype='complex')

    U_full, sig_full, Vh_full = svd(X, False)  # economy SVD of the input
    rank = len(sig_full) if truncate is None else truncate
    U = U_full[:, :rank]
    Sig_inv = inv(diag(sig_full)[:rank, :rank])
    V = Vh_full.conj().T[:, :rank]

    # reduced operator A~ = U^H Y V Sig^-1 and its eigendecomposition
    Atil = dot(dot(dot(U.conj().T, Y), V), Sig_inv)
    mu, W = eig(Atil)

    # exact DMD modes Phi = Y V Sig^-1 W
    Phi = dot(dot(dot(Y, V), Sig_inv), W)
    return mu, Phi

In [None]:
# ALL IMAGE (NSD) DMD
Xsub = X_centered[:, :, :1000]          # (units, time, images)
n_units = Xsub.shape[0]
# Build snapshot pairs *within* each image. With axes (units, time, images), a
# plain reshape(n_units, -1) orders columns as t * n_images + i, i.e. it
# interleaves images at each time point, so "consecutive" columns would be
# different images rather than consecutive time points. Moving the image axis
# before time keeps each image's series contiguous, and staggering before
# flattening avoids snapshot pairs that straddle two images.
Xi_ = Xsub[:, :-1, :].transpose(0, 2, 1).reshape(n_units, -1)
Yi_ = Xsub[:, 1:, :].transpose(0, 2, 1).reshape(n_units, -1)

# dmd(X, Y, truncate) returns (eigenvalues, modes)
lam, Phi = dmd(Xi_, Yi_, 20)

# plot DMD eigenvalues in the complex plane
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
theta = np.linspace(0, 2 * np.pi, 200)

ax = axes[0]
ax.plot(np.cos(theta), np.sin(theta), c="k", lw=1)   # full unit circle
ax.scatter(lam.real, lam.imag, color='black')
ax.set_aspect("equal")
ax.set_title('wrt unit circle')

ax = axes[1]
ax.scatter(lam.real, lam.imag, color='black')
ax.set_title('DMD on all 1000 NSD images')

for ax in axes:
    ax.set_xlabel('Re(λ)'); ax.set_ylabel('Im(λ)')
plt.show()

In [None]:
# ALL IMAGE (LOCALIZER) DMD
Xsub = X_centered[:, :, 1000:]          # (units, time, images)
n_units = Xsub.shape[0]
# Build snapshot pairs *within* each image. With axes (units, time, images), a
# plain reshape(n_units, -1) orders columns as t * n_images + i, i.e. it
# interleaves images at each time point, so "consecutive" columns would be
# different images rather than consecutive time points. Moving the image axis
# before time keeps each image's series contiguous, and staggering before
# flattening avoids snapshot pairs that straddle two images.
Xi_ = Xsub[:, :-1, :].transpose(0, 2, 1).reshape(n_units, -1)
Yi_ = Xsub[:, 1:, :].transpose(0, 2, 1).reshape(n_units, -1)

# dmd(X, Y, truncate) returns (eigenvalues, modes)
lam, Phi = dmd(Xi_, Yi_, 20)

# plot DMD eigenvalues in the complex plane
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
theta = np.linspace(0, 2 * np.pi, 200)

ax = axes[0]
ax.plot(np.cos(theta), np.sin(theta), c="k", lw=1)   # full unit circle
ax.scatter(lam.real, lam.imag, color='black')
ax.set_aspect("equal")
ax.set_title('wrt unit circle')

ax = axes[1]
ax.scatter(lam.real, lam.imag, color='black')
ax.set_title('DMD on all LOCALIZER images')

for ax in axes:
    ax.set_xlabel('Re(λ)'); ax.set_ylabel('Im(λ)')
plt.show()

In [None]:
### TUTORIAL FROM https://humaticlabs.com/blog/mrdmd-python/
def svht(X, sv=None):
    """Optimal singular-value hard threshold when the noise level is unknown.

    Uses the polynomial approximation of omega(beta) from Gavish & Donoho's
    optimal hard threshold, scaled by the median singular value.

    Parameters
    ----------
    X : ndarray, 2-D
        Data matrix; only its shape is used when ``sv`` is supplied.
    sv : array-like or None
        Precomputed singular values of ``X`` (squeezed before use).

    Returns
    -------
    float
        Threshold tau; singular values above it are kept.
    """
    short_side, long_side = sorted(X.shape)
    aspect = short_side / long_side      # beta in (0, 1]
    singular_values = svdvals(X) if sv is None else sv
    singular_values = np.squeeze(singular_values)
    omega = 0.56 * aspect**3 - 0.95 * aspect**2 + 1.82 * aspect + 1.43
    return np.median(singular_values) * omega

# single localizer image (global index 1010): stagger its time series for DMD
Xsub = X_centered[:, :, 1010]
Xi_, Yi_ = Xsub[:, :-1], Xsub[:, 1:]

# reduced rank = number of singular values above the optimal hard threshold
sv = svdvals(Xi_)
tau = svht(Xi_, sv=sv)
r = int(np.count_nonzero(sv > tau))

fig, ax = plt.subplots(1, 1)
sns.scatterplot(x=np.arange(sv.size), y=sv, ax=ax)
ax.axhline(tau, color='red', linestyle='dashed', linewidth=0.75)
ax.set_title(f'Optimal rank for rank reduction: r={r}')
plt.show()

In [None]:
# DMD reconstruction of a single image's response.
# Uses Xi_, Yi_ and r from the SVHT cell above.
mu, Phi = dmd(Xi_, Yi_, r)

# time index of the snapshots being reconstructed: Xi_ has T - 1 columns,
# so use its width rather than the full PSTH length
t = np.arange(Xi_.shape[1])

# 1) initial amplitudes b: least-squares solution of Phi @ b ≈ Xi_[:, 0]
b, *_ = np.linalg.lstsq(Phi, Xi_[:, 0], rcond=None)  # (r,)

# 2) time dynamics: Psi[k, t] = b[k] * mu[k] ** t
Psi = b[:, None] * np.power(mu[:, None], t[None, :])  # (r, T - 1)

# 3) reconstruction (keep the real part of the complex product)
Xhat = (Phi @ Psi).real

# shared color scale so the two panels are directly comparable
vmin, vmax = Xi_.min(), Xi_.max()
fig, axes = plt.subplots(2, 1, figsize=(12, 10))
for ax, mat, title in zip(axes, (Xi_, Xhat), ('Original', 'Reconstructed')):
    sns.heatmap(mat, vmin=vmin, vmax=vmax, cbar=False, square=True, ax=ax)
    ax.set_axis_off()
    ax.set_title(title)
plt.show()

In [None]:
def mrdmd(D, level=0, bin_num=0, offset=0, max_levels=7, max_cycles=2, do_svht=True):
    """
    Compute the multi-resolution DMD on the dataset `D`, returning a list of nodes
    in the hierarchy. Each node represents a particular "time bin" (window in time) at
    a particular "level" of the recursion (time scale). The node is an object consisting
    of the various data structures generated by the DMD at its corresponding level and
    time bin. The `level`, `bin_num`, and `offset` parameters are for record keeping
    during the recursion and should not be modified unless you know what you are doing.
    The `max_levels` parameter controls the maximum number of levels. The `max_cycles`
    parameter controls the maximum number of mode oscillations in any given time scale
    that qualify as "slow". The `do_svht` parameter indicates whether or not to perform
    optimal singular value hard thresholding.

    Parameters
    ----------
    D : ndarray, shape (n_features, n_time)
        Data for the current time bin; columns are snapshots.

    Returns
    -------
    list
        This bin's node, followed by all nodes of its left half-bin and then
        all nodes of its right half-bin (pre-order). Each node has the
        attributes level, bin_num, bin_size, start, stop, step, rho, r, n,
        mu, Phi, Psi and b_opt. An empty list is returned when the bin has
        fewer than ``8 * max_cycles`` columns.

    Notes
    -----
    Relies on the notebook-level ``svht`` helper and on a ``dmd(X, Y, r)``
    that returns ``(mu, Phi)``.

    More info here: https://humaticlabs.com/blog/mrdmd-python/
    """

    # 4 times nyquist limit to capture cycles: 8 samples per allowed cycle
    # (Nyquist needs 2 per cycle)
    nyq = 8 * max_cycles

    # time bin size; bins too short to resolve max_cycles are not decomposed
    bin_size = D.shape[1]
    if bin_size < nyq:
        return []

    # extract subsamples so the bin is analysed with between nyq and 2*nyq
    # snapshots
    step = floor(bin_size / nyq) # max step size to capture cycles
    _D = D[:,::step]
    X = _D[:,:-1]
    Y = _D[:,1:]

    # determine rank-reduction (threshold computed on the full subsampled _D)
    if do_svht:
        _sv = svdvals(_D)
        tau = svht(_D, sv=_sv)
        r = sum(_sv > tau)
    else:
        r = min(X.shape)

    # compute dmd
    mu,Phi = dmd(X, Y, r)

    # frequency cutoff (oscillations per original, non-subsampled timestep)
    rho = max_cycles / bin_size

    # consolidate slow eigenvalues (as boolean mask).
    # log(mu) / step is the continuous-time exponent per original sample and
    # dividing by 2*pi turns its imaginary part into cycles per sample. The
    # magnitude also includes the real (growth/decay) part, so fast-growing or
    # fast-decaying modes are treated as fast too. If mu has a real dtype,
    # np.log of a negative entry is nan and the comparison is False.
    slow = (np.abs(np.log(mu) / (2 * pi * step))) <= rho
    n = sum(slow) # number of slow modes

    # extract slow modes (perhaps empty)
    mu = mu[slow]
    Phi = Phi[:,slow]

    if n > 0:

        # vars for the objective function for D (before subsampling).
        # power(mu, 1/step) converts eigenvalues back to the original sampling;
        # row j of Vand is mu_j**t for t = 0 .. bin_size-1. P and q form the
        # normal equations of min_b ||D - Phi diag(b) Vand||_F.
        Vand = np.vander(power(mu, 1/step), bin_size, True)
        P = multiply(dot(Phi.conj().T, Phi), np.conj(dot(Vand, Vand.conj().T)))
        q = np.conj(diag(dot(dot(Vand, D.conj().T), Phi)))

        # find optimal b solution
        b_opt = solve(P, q).squeeze()

        # time evolution: Psi[j, t] = b_opt[j] * Vand[j, t]
        Psi = (Vand.T * b_opt).T

    else:

        # zero time evolution
        b_opt = np.array([], dtype='complex')
        Psi = np.zeros([0, bin_size], dtype='complex')

    # dmd reconstruction of the slow modes over the whole bin
    D_dmd = dot(Phi, Psi)

    # remove influence of slow modes so deeper levels only see the residual
    # (D may become complex from here on)
    D = D - D_dmd

    # record keeping
    node = type('Node', (object,), {})()
    node.level = level            # level of recursion
    node.bin_num = bin_num        # time bin number
    node.bin_size = bin_size      # time bin size
    node.start = offset           # starting index
    node.stop = offset + bin_size # stopping index
    node.step = step              # step size
    node.rho = rho                # frequency cutoff
    node.r = r                    # rank-reduction
    node.n = n                    # number of extracted modes
    node.mu = mu                  # extracted eigenvalues
    node.Phi = Phi                # extracted DMD modes
    node.Psi = Psi                # extracted time evolution
    node.b_opt = b_opt            # extracted optimal b vector
    nodes = [node]

    # split data into two and do recursion (left half gets the extra column
    # when bin_size is odd)
    if level < max_levels:
        split = ceil(bin_size / 2) # where to split
        nodes += mrdmd(
            D[:,:split],
            level=level+1,
            bin_num=2*bin_num,
            offset=offset,
            max_levels=max_levels,
            max_cycles=max_cycles,
            do_svht=do_svht
            )
        nodes += mrdmd(
            D[:,split:],
            level=level+1,
            bin_num=2*bin_num+1,
            offset=offset+split,
            max_levels=max_levels,
            max_cycles=max_cycles,
            do_svht=do_svht
            )

    return nodes

def stitch(nodes, level):
    """Assemble the modes and time evolution of every bin at one mrDMD level.

    Parameters
    ----------
    nodes : list
        Nodes returned by ``mrdmd`` (each with level, bin_num, start, stop,
        Phi and Psi attributes). Must be non-empty.
    level : int
        Recursion level to stitch.

    Returns
    -------
    Phi : ndarray, shape (n_features, n_modes)
        Modes of all bins at ``level``, ordered by bin number.
    Psi : ndarray, shape (n_modes, n_time)
        Time evolution, zero outside each mode's own bin. Columns are relative
        to the earliest ``start`` among ``nodes``, so hierarchies built with a
        non-zero ``offset`` are handled correctly. When no node exists at
        ``level``, zero-width arrays are returned instead of raising.
    """
    # time extent covered by the whole hierarchy
    start = min(nd.start for nd in nodes)
    stop = max(nd.stop for nd in nodes)
    n_time = stop - start

    level_nodes = sorted(
        (nd for nd in nodes if nd.level == level), key=lambda nd: nd.bin_num
    )
    if not level_nodes:
        n_features = nodes[0].Phi.shape[0]
        return (np.zeros([n_features, 0], dtype='complex'),
                np.zeros([0, n_time], dtype='complex'))

    # stack DMD modes of all bins side by side
    Phi = np.hstack([nd.Phi for nd in level_nodes])

    # place each bin's time evolution in its own rows and time window
    n_modes = sum(nd.Psi.shape[0] for nd in level_nodes)
    Psi = np.zeros([n_modes, n_time], dtype='complex')
    row = 0
    for nd in level_nodes:
        k = nd.Psi.shape[0]
        Psi[row:row + k, nd.start - start:nd.stop - start] = nd.Psi
        row += k

    return Phi, Psi

In [None]:
# mrDMD of a single localizer image (global index 1010)
Xsub = X_centered[:, :, 1010]
D = Xsub
nodes = mrdmd(D)

levels = sorted({nd.level for nd in nodes})

# reconstruct each level exactly once
level_recons = {}
for lvl in levels:
    Phi_l, Psi_l = stitch(nodes, lvl)
    level_recons[lvl] = Phi_l @ Psi_l
D_hat_full = sum(level_recons.values())

# left: cumulative reconstruction up to this level; right: this level alone
D_iter = np.zeros(D.shape)
for lvl in levels:
    level_real = level_recons[lvl].real
    D_iter = D_iter + level_real
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    sns.heatmap(D_iter, cbar=False, square=True, ax=axes[0])
    axes[0].set_title(f'Cumulative reconstruction, levels 0-{lvl}')
    sns.heatmap(level_real, cbar=False, square=True, ax=axes[1])
    axes[1].set_title(f'Level {lvl} contribution')
    for ax in axes:
        ax.set_axis_off()
    plt.show()

# one unit: data vs full mrDMD reconstruction
unit = 0
fig, ax = plt.subplots()
ax.plot(D[unit], label='data')
ax.plot(D_hat_full[unit].real, '--', label='MRDMD recon')
ax.set_xlabel('time bin'); ax.set_ylabel('centered response')
ax.set_title(f'Unit {unit}: data vs MRDMD reconstruction')
ax.legend()
plt.show()

In [None]:
# slow-mode eigenvalues extracted at each mrDMD level
fig, ax = plt.subplots(1, 1)

# full unit circle
theta = np.linspace(0, 2 * np.pi, 200)
ax.plot(np.cos(theta), np.sin(theta), c="k")

for l in levels:
    eigs = np.concatenate([nd.mu for nd in nodes if nd.level == l])
    ax.scatter(eigs.real, eigs.imag, alpha=0.25, label=f'level {l}')

ax.set_aspect("equal")
ax.set_xlabel('Re(μ)'); ax.set_ylabel('Im(μ)')
ax.set_title('wrt unit circle')
ax.legend()
plt.show()

In [None]:
# eigenvalues kept by the first level-1 bin
ns = list(filter(lambda nd: nd.level == 1, nodes))
ns[0].mu

In [None]:
# =================================== SHARED BASIS ===========================
# One DMD fitted on all localizer images in a shared, delay-embedded PCA
# space; images are then compared by how strongly they use each shared mode.
k = 40  # PCA components
L = 7   # delay-embedding window size
r = 20  # DMD modes

Xc = X_centered[:, :, 1000:]                  # localizer images only
Xpool = Xc.reshape(Xc.shape[0], -1)           # (units, time*images); column order is irrelevant for PCA

# dim reduction using PCA (left singular vectors of the pooled data)
U, S, Vt = np.linalg.svd(Xpool, full_matrices=False)
P = U[:, :k].T                                # (k, units)

# per-image delay embedding in PCA space, each (k*L, T_eff)
W_list = [delay_embed(P @ Xc[:, :, i], L) for i in range(Xc.shape[2])]
Wpool = np.concatenate(W_list, axis=1)        # (k*L, sum T_eff)

# Stagger *within* each image so no snapshot pair straddles two images.
X_ = np.concatenate([Wi[:, :-1] for Wi in W_list], axis=1)
Y_ = np.concatenate([Wi[:, 1:] for Wi in W_list], axis=1)

# shared-basis DMD; dmd(X, Y, truncate) returns (eigenvalues, modes)
lam, Phi = dmd(X_, Y_, r)

# shared mode amplitudes per image, from each image's first embedded snapshot
B = np.stack([mode_amplitudes(Phi, Wi) for Wi in W_list], axis=1)  # (r, n_images)

F = np.abs(B).T                               # (images, r)
D = squareform(pdist(F, metric='correlation'))
fig, ax = plt.subplots()
sns.heatmap(D, ax=ax)
ax.set_title('Correlation distance between images (shared-mode amplitudes)')
ax.set_xlabel('localizer image'); ax.set_ylabel('localizer image')
plt.show()
# category centroids in amplitude space
def centroid(F, indices):
    """Median feature vector of the rows of ``F`` selected by ``indices``.

    Parameters
    ----------
    F : ndarray, shape (n_images, n_features)
        One row per image (e.g. shared-mode amplitudes).
    indices : array-like of int
        Row indices (local image indices) belonging to the category.

    Returns
    -------
    median : ndarray, shape (n_features,)
        Median over the selected rows (NaN if no row is selected).
    members : ndarray of int
        Sorted row indices that were actually selected.
    """
    # mask over F's own rows instead of a hard-coded image count
    mask = np.isin(np.arange(F.shape[0]), indices)
    return np.median(F[mask], axis=0), np.where(mask)[0]

# per-category median amplitude vectors and the rows of F belonging to each
centroids = {}
members = {}
for name, ids in cat_local.items():
    centroids[name], members[name] = centroid(F, ids)

# Category-wise DMD in the shared basis. W_list is ordered by local localizer
# index (0 .. n_images-1), which is exactly what cat_local holds, so index it
# directly. Snapshot pairs are formed within each image.
cat_lams = {}
for name, ids in cat_local.items():
    if len(ids) == 0:
        continue
    Xc_ = np.concatenate([W_list[i][:, :-1] for i in ids], axis=1)
    Yc_ = np.concatenate([W_list[i][:, 1:] for i in ids], axis=1)
    lam_c, _ = dmd(Xc_, Yc_, r)
    cat_lams[name] = lam_c

In [None]:
# 1) Category-wise DMD spectra in the shared basis
fig, ax = plt.subplots(figsize=(6, 6))
theta = np.linspace(0, 2*np.pi, 400)
ax.plot(np.cos(theta), np.sin(theta), linestyle="--", linewidth=1)  # unit circle
# loop variable deliberately not named `lam`, so the global shared-basis
# eigenvalues are not overwritten by the last category
for name, cat_lam in cat_lams.items():
    ax.scatter(cat_lam.real, cat_lam.imag, label=name)
ax.axhline(0, linewidth=0.5); ax.axvline(0, linewidth=0.5)
ax.set_aspect("equal", "box"); ax.set_xlabel("Re(λ)"); ax.set_ylabel("Im(λ)")
ax.set_title("Category-wise DMD spectra")
ax.legend()
plt.show()

# 2) Category structure in shared amplitude space: images embedded by MDS on
#    cosine distances; each category's mean MDS position is marked with an X
D = cosine_distances(F)
Z = MDS(n_components=2, dissimilarity="precomputed", random_state=0).fit_transform(D)
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(Z[:, 0], Z[:, 1], alpha=0.6)
for name, idxs in members.items():
    cx, cy = Z[idxs].mean(axis=0)
    ax.scatter([cx], [cy], marker="X", s=120, label=name)
ax.set_title("Images embedded by shared-mode usage (MDS)")
ax.set_xlabel("dim 1"); ax.set_ylabel("dim 2")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
all_ids = np.concatenate(list(img_sets.values()))
# X holds the 1000 NSD images first, followed by the localizer images with
# global IDs 1000 onwards, so local (localizer-relative) index = id - 1000 and
# valid local indices are below the number of localizer images.
n_loc = X.shape[2] - 1000
local_idx = all_ids - 1000
local_idx = local_idx[(local_idx >= 0) & (local_idx < n_loc)]
local_idx

In [None]:
# SVD / DMD OF THE SHARED-BASIS SNAPSHOT PAIRS
# X_ and Y_ come from the shared-basis cell: delay-embedded PCA snapshots whose
# rows are (PCA component, lag) pairs, not raw neuron responses.
U, S, Vh = np.linalg.svd(X_, full_matrices=False)
print('SVD shapes:', U.shape, S.shape, Vh.shape)

# choose r that explains 95% of energy
energy = (S**2).cumsum() / (S**2).sum()
r = np.searchsorted(energy, 0.95) + 1

# Truncate
U_r = U[:, :r]          # (m, r)
S_r = S[:r]             # (r,)
V_r = Vh[:r, :].T       # (n, r)  <-- columns of V

print('Truncated:', U_r.shape, S_r.shape, V_r.shape)

# Atilde = U_r^T Y V_r S_r^{-1}
Atilde = (U_r.T @ Y_) @ (V_r * (1.0 / S_r))   # column-wise divide by S_r

# Eigenvalues and exact DMD modes
eigvals, W = np.linalg.eig(Atilde)
Phi = Y_ @ (V_r * (1.0 / S_r)) @ W

# plot on unit circle
fig, ax = plt.subplots()
ax.scatter(eigvals.real, eigvals.imag, s=20)
th = np.linspace(0, 2*np.pi, 400)
ax.plot(np.cos(th), np.sin(th), 'k--', lw=1)
ax.axhline(0, lw=0.5); ax.axvline(0, lw=0.5)
ax.set_aspect('equal', 'box')
ax.set_xlabel('Re(λ)'); ax.set_ylabel('Im(λ)')
ax.set_title('DMD eigenvalues')
plt.show()

# Visualize the modes in neuron space. Phi's rows are delay-embedded PCA
# coordinates; delay_embed stacks the unshifted (lag-0) block first, so the
# first P.shape[0] rows are PCA coefficients, and P.T (P = leading left
# singular vectors, transposed) maps them back onto the units.
n_pc = P.shape[0]
Phi_units = P.T @ Phi[:n_pc]
fig, ax = plt.subplots()
sns.heatmap(np.real(Phi_units), ax=ax)
ax.set_xlabel('Mode'); ax.set_ylabel('Neuron')
ax.set_title('Spatial pattern of DMD modes (lag-0 block, back-projected)')
plt.show()

In [None]:
# PER IMAGE DMD (localizer images)
r = 20          # DMD rank
k = 50          # PCA rank used to denoise each image before DMD
dt = 1/1000     # seconds per time bin -- assumes 1 ms PSTH bins; TODO confirm

# localizer images sit after the 1000 NSD images along the image axis
n_loc = X_centered.shape[2] - 1000

LAM, FREQ, DECAY = [], [], []
for imgid in range(n_loc):
    Xi = X_centered[:, :, 1000 + imgid]          # (units, time)

    # rank-k PCA reconstruction of this image's response
    U, S, Vt = np.linalg.svd(Xi, full_matrices=False)
    Xiproj = U[:, :k] @ np.diag(S[:k]) @ Vt[:k]

    # DMD with r modes on consecutive time points; returns (eigenvalues, modes)
    lam, Phi = dmd(Xiproj[:, :-1], Xiproj[:, 1:], r)

    # continuous-time exponents. Cast to complex so a negative real eigenvalue
    # gives a finite log (imaginary part pi) instead of nan.
    omega = np.log(lam.astype(complex)) / dt
    freq = np.imag(omega) / (2*np.pi)
    decay = np.real(omega)

    LAM.append(lam)
    FREQ.append(freq)
    DECAY.append(decay)


In [None]:
# VISUALISE MODE SIMILARITY ACROSS IMAGES
# Pairwise 1-D Wasserstein distance between images' DMD frequency sets
# (unweighted for now; normalized |b| amplitudes could be passed as weights).
I = len(LAM)
D = np.zeros((I, I))
sorted_freqs = [np.sort(f) for f in FREQ]
for row in range(I):
    for col in range(row + 1, I):
        dist = wasserstein_distance(sorted_freqs[row], sorted_freqs[col],
                                    u_weights=None, v_weights=None)
        D[row, col] = dist
        D[col, row] = dist

fig, ax = plt.subplots()
sns.heatmap(D, cmap="viridis", square=True, ax=ax)
ax.set_title("Spectral distance between images")
plt.show()

# 2-D embedding of the distance matrix (category labels could be added here)
embed = MDS(n_components=2, dissimilarity='precomputed', random_state=0).fit_transform(D)
fig, ax = plt.subplots()
ax.scatter(embed[:, 0], embed[:, 1])
ax.set_title("Images embedded by DMD spectral similarity")
plt.show()