In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the per-unit face-ROI table.
# NOTE(review): pd.read_pickle can execute arbitrary code embedded in the file --
# only load pickles that come from a trusted source.
DATA_DIR = '../../datasets/NNN/'
dat = pd.read_pickle(os.path.join(DATA_DIR, 'face_roi_data.pkl'))

# Double quotes inside the f-string: reusing the outer quote character inside a
# replacement field is a SyntaxError before Python 3.12 (PEP 701).
print(f"Unique face ROIs: {list(dat['roi'].unique())}")

In [None]:
# Select the significant units of one face ROI and build the staggered DMD
# snapshot matrices X_ (time steps 0..T-2) and Y_ (time steps 1..T-1).
ROI = 'Unknown_19_F'
PVAL = 0.05  # units with p_value strictly below this threshold are kept

# load in per-image psth for units that are both significant and in the chosen ROI
sig = dat[dat['p_value'] < PVAL]
df = sig[sig['roi'] == ROI]
if df.empty:
    # np.stack on an empty sequence would otherwise fail with an opaque
    # "need at least one array to stack" error
    raise ValueError(f'No units with p_value < {PVAL} in ROI {ROI!r}')

X = np.stack(df['img_psth'].to_numpy())
if X.ndim != 3:
    # the axis=1 / axis=2 reductions below assume (units, time points, images)
    raise ValueError(
        f'Expected stacked PSTHs of shape (units, time points, images), got {X.shape}'
    )
print('Loaded unit-level data for each image. Shape:', X.shape, '---> (units, time points, images)')

# center the data per unit and per image: subtract each PSTH's mean over time (axis 1)
X_centered = X - X.mean(axis=1, keepdims=True)
print('Centered shape:', X_centered.shape)

# average across images -> one centered time series per unit, shape (units, time points).
# Alternative: concatenate each image's time series with
#   X_centered.transpose(0, 2, 1).reshape(X_centered.shape[0], -1)
# (a plain reshape(n, -1) would interleave images at every time step instead).
# DMD would then also treat each image boundary as a time step.
Xsub = np.mean(X_centered, axis=2)
print('Final data shape:', Xsub.shape)

# stagger for DMD: column k of Y_ is the snapshot one step after column k of X_
if Xsub.shape[1] < 2:
    raise ValueError('DMD needs at least two time points')
X_ = Xsub[:, :-1]
Y_ = Xsub[:, 1:]

In [None]:
# Exact DMD via a rank-truncated economy SVD of the snapshot matrix X_.
ENERGY_THRESHOLD = 0.95  # fraction of squared-singular-value "energy" the truncation keeps

# Econ SVD: X_ = U @ diag(S) @ Vh
U, S, Vh = np.linalg.svd(X_, full_matrices=False)
print('SVD shapes:', U.shape, S.shape, Vh.shape)
if not np.any(S > 0):
    # empty or all-zero data would give NaN energy and a LinAlgError in eig below
    raise ValueError('X_ is empty or all zeros; DMD is undefined')

# smallest r whose leading singular values reach ENERGY_THRESHOLD of the total energy;
# clipped to len(S) so floating-point round-off in the cumulative sum cannot overshoot
energy = (S**2).cumsum() / (S**2).sum()
r = min(int(np.searchsorted(energy, ENERGY_THRESHOLD)) + 1, S.size)

# Truncate
U_r = U[:, :r]          # (m, r)
S_r = S[:r]             # (r,)
V_r = Vh[:r, :].T       # (n, r)  <-- columns of V
print('Truncated:', U_r.shape, S_r.shape, V_r.shape)

# V_r S_r^{-1}: broadcasting divides column j of V_r by S_r[j]; shared by Atilde and Phi
V_r_scaled = V_r / S_r

# Reduced operator Atilde = U_r^T Y V_r S_r^{-1}  (r x r)
Atilde = U_r.T @ Y_ @ V_r_scaled

# Eigenvalues and exact DMD modes (columns of Phi)
eigvals, W = np.linalg.eig(Atilde)
Phi = Y_ @ V_r_scaled @ W

# plot eigenvalues against the unit circle: for the discrete-time map,
# |λ| > 1 marks growing modes and |λ| < 1 decaying ones
th = np.linspace(0, 2*np.pi, 400)
fig, ax = plt.subplots()
ax.scatter(eigvals.real, eigvals.imag, s=20)
ax.plot(np.cos(th), np.sin(th), 'k--', lw=1)
ax.axhline(0, lw=0.5)
ax.axvline(0, lw=0.5)
ax.set_aspect('equal', 'box')
ax.set(xlabel='Re(λ)', ylabel='Im(λ)', title='DMD eigenvalues')
plt.show()

# visualize the modes. Only the real part is drawn. For real-valued data, complex-conjugate
# eigenvalue pairs give conjugate modes, so those pairs appear as identical columns.
fig, ax = plt.subplots()
sns.heatmap(np.real(Phi), ax=ax)
ax.set(xlabel='Mode', ylabel='Neuron', title='Spatial pattern of DMD modes')
plt.show()

In [None]:
# Oscillation frequency of each DMD mode: angle(λ) / (2π dt).
# np.angle is used rather than np.imag(np.log(λ)). np.linalg.eig returns a *real* array
# when every eigenvalue is real, and np.log of a negative real float is nan.
# np.angle correctly maps such eigenvalues to the Nyquist frequency 1 / (2 dt).
dt = 1e-3  # sampling interval between PSTH bins -- TODO confirm it matches the data's bin width
freqs = np.angle(eigvals) / (2*np.pi*dt)

fig, ax = plt.subplots()
ax.stem(freqs)
ax.set(xlabel='Mode index', ylabel='Frequency (Hz, if dt is in seconds)',
       title='DMD mode frequencies')
plt.show()
