In [None]:
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
import scipy.stats as stt
from sklearn.decomposition import PCA

In [None]:
# Path to the pickled face-ROI table (a single file, despite the `_DIR` suffix).
DATA_DIR = '../../datasets/NNN/face_roi_data.pkl'
# NOTE(review): unpickling can execute arbitrary code — only load files from a trusted source.
dat = pd.read_pickle(DATA_DIR)

# Outer double quotes let the inner ['roi'] key parse on Python < 3.12 too
# (reusing the same quote character inside an f-string is only legal from 3.12).
# The "face ROIs" label now lists the `roi` column; monkeys get their own line.
print(f"Unique face ROIs: {list(dat['roi'].unique())}")
print(f"Unique monkeys: {list(dat['monkey'].unique())}")

In [None]:
# Select the significant rows of one face ROI and look for temporal patterns
# shared across all (unit, image) responses via PCA over the time axis.
ROI = 'MF1_7_F'
P_THRESHOLD = 0.05   # rows are kept only if their p_value is below this cutoff

roi_dat = dat[(dat['roi'] == ROI) & (dat['p_value'] < P_THRESHOLD)].reset_index(drop=True)
if roi_dat.empty:
    # np.stack on an empty selection fails with an unhelpful
    # "need at least one array to stack" message; fail early and say why.
    raise ValueError(
        f"No rows with roi == {ROI!r} and p_value < {P_THRESHOLD}; nothing to decompose."
    )

# Assumes every `img_psth` entry is a (time, image) array of identical shape,
# so the stacked X is (unit, time, image) — TODO confirm against the dataset.
X = np.stack(roi_dat['img_psth'])
# Optional preprocessing, currently disabled:
# X = X[:, :, 1000:]
# X = stt.zscore(X, axis=1)
# X = np.nan_to_num(X)

# Find patterns shared across *time points*
U, T, I = X.shape
Y = X.transpose(0, 2, 1).reshape(U * I, T)          # (samples=unit×image, features=time)
Yc = Y - Y.mean(axis=0, keepdims=True)              # remove each time feature's mean over samples

pca = PCA().fit(Yc)

# Scree-style plot: cumulative percentage of variance captured by the first n time-PCs.
plt.figure(figsize=(10, 5))
plt.plot(np.cumsum(pca.explained_variance_ratio_) * 100, linewidth=2)
plt.xlabel("Number of time-PCs")
plt.ylabel("Cumulative explained variance (%)")
plt.title("Explained variance by number of time PCs")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Plot the leading time-PCs, one small panel per component.
n_cols = 4
# Never ask for more components than PCA produced; otherwise C[i] would
# raise IndexError when fewer than 24 components are available.
k = min(24, pca.components_.shape[0])
C = pca.components_[:k]          # (k, T)

# Enough rows to hold k panels (ceil division); 24 components -> the usual 6x4 grid.
n_rows = (k + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 6),
                         sharex=True, sharey=True, squeeze=False)
axes = axes.ravel()
for i in range(k):
    axes[i].plot(C[i])
    axes[i].set_title(f"PC{i+1}")
# Hide any panels left over in the last row so the figure shows no empty axes.
for spare_axis in axes[k:]:
    spare_axis.set_visible(False)
plt.tight_layout(); plt.show()

In [None]:
# 2) How well do the first n PCs reconstruct the data (overall and per time point)?
# Short aliases for the fitted PCA attributes used by the cells below.
V = pca.components_                      # one row per component: (ncomp, T)
T = V.shape[1]                           # number of time features
evr = pca.explained_variance_ratio_      # per-component fraction of explained variance

def recon_stats(n, Y_centered=None, components=None, var_ratio=None):
    """Measure how well the first ``n`` principal components reconstruct the data.

    Parameters
    ----------
    n : int
        Number of leading components used for the reconstruction.
    Y_centered : ndarray of shape (samples, T), optional
        Centered data matrix. Defaults to the notebook-level ``Yc``.
    components : ndarray of shape (ncomp, T), optional
        Component matrix, one row per component. Defaults to the notebook-level ``V``.
    var_ratio : ndarray of shape (ncomp,), optional
        Explained-variance ratio per component. Defaults to the notebook-level ``evr``.

    Returns
    -------
    r2_overall : float
        Sum of the first ``n`` explained-variance ratios.
    r2_time : ndarray of shape (T,)
        Per-time-point R² of the rank-``n`` reconstruction. Time points whose
        total sum of squares is zero have an undefined R² and are reported as NaN.
    """
    # Fall back to the notebook globals so existing `recon_stats(n)` calls keep working.
    if Y_centered is None:
        Y_centered = Yc
    if components is None:
        components = V
    if var_ratio is None:
        var_ratio = evr

    basis = components[:n]                     # (n, T)
    scores = Y_centered @ basis.T              # scores (samples × n)
    recon = scores @ basis                     # recon (samples × T)

    ss_tot = np.sum(Y_centered**2, axis=0)     # per-time total sum of squares
    ss_res = np.sum((Y_centered - recon)**2, axis=0)

    # Guarded division: a time point with zero total variance would otherwise
    # trigger a divide-by-zero RuntimeWarning; leave those entries as NaN.
    frac_unexplained = np.full(ss_tot.shape, np.nan, dtype=float)
    np.divide(ss_res, ss_tot, out=frac_unexplained, where=ss_tot > 0)
    r2_time = 1 - frac_unexplained             # per-time R^2

    r2_overall = np.sum(var_ratio[:n])         # cumulative variance explained
    return r2_overall, r2_time

# Overlay the per-time R² curves obtained with 1, 2, ..., MAX_PCS components.
MAX_PCS = 10

fig, ax = plt.subplots(1, 1)
for r in range(1, MAX_PCS + 1):
    # recon_stats returns (overall R², per-time R²); only the per-time curve is drawn.
    r2_time = recon_stats(r)[1]
    ax.plot(r2_time, label=r)

# The figure holds one curve per component count, so the title names the whole
# range rather than the last value left in the loop variable.
ax.set_title(f'Per-time R² with 1–{MAX_PCS} PCs')
ax.legend(title='# PCs')
ax.set_xlabel('Time')
ax.set_ylabel('R²')
fig.tight_layout()
plt.show()
