In [None]:
import os
import re
import h5py
import random
import numpy as np
import pandas as pd
import scipy.io
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import pdist, squareform
from sklearn.manifold import MDS
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

import utils

In [None]:
# Dataset root directory used by the loading steps below.
datadir = '../datasets/NNN/'
fnames = utils.fnames(datadir)

# Per-unit table; later cells read its 'roi', 'session', 'unit_type' and
# 'img_psth' columns.
# NOTE(review): unpickling can execute arbitrary code - only load this file
# from a trusted source.
unit_data_path = os.path.join(datadir, 'unit_data_full.pkl')
raster_data = pd.read_pickle(unit_data_path)

In [None]:
# Time-by-time representational dissimilarity matrix (RDM) for one ROI,
# restricted to the images whose HED annotation mentions every target term.
labels = pd.read_csv(os.path.join(datadir, 'shared1000_HED.tsv'), sep='\t')
targets = ['animal']

# A row is selected when its lower-cased HED_short string contains every
# target substring; non-string entries (e.g. NaN) never match.
mask = labels["HED_short"].str.lower().apply(
    lambda x: all(t in x for t in targets) if isinstance(x, str) else False
)
# Positional row numbers of the matching labels, used below to index the last
# axis of the PSTH stack.
# NOTE(review): assumes row i of the TSV describes image i of 'img_psth' - confirm.
to_idx = np.flatnonzero(mask.to_numpy()).tolist()

metric = 'correlation'
roi = 'MF1_9_F' #'MF1_9_F'
roi_data = raster_data[(raster_data['roi']==roi)]
stacked = np.stack(roi_data['img_psth'])

substacked = stacked[:, :, to_idx]
# average over units
averaged = np.mean(substacked, axis=0)

fig, ax = plt.subplots(1, 1)

# pdist compares the rows of `averaged`, so the RDM is square over the first
# remaining axis (the one other cells of this notebook label 'time').
rdm = squareform(pdist(averaged, metric=metric))
# Draw on the axes created above instead of relying on the implicit current axes.
sns.heatmap(rdm, cmap=sns.color_palette('Greys_r', as_cmap=True), vmax=1.5, ax=ax)
# Outer double quotes: nesting the same quote type inside an f-string is a
# SyntaxError on Python versions before 3.12.
ax.set_title(f"{roi}: {'.'.join(targets)}")

# Reference lines at bin 50 on both axes.
# NOTE(review): presumably stimulus onset - confirm.
ax.axvline(x=50, color='red', linestyle='--', linewidth=1)
ax.axhline(y=50, color='red', linestyle='--', linewidth=1)
plt.show()

In [None]:
# Distinct values of the 'session' column, in order of first appearance.
raster_data.loc[:, 'session'].unique()

In [None]:
# Mean response over time across all units with unit_type == 1
# (named "single units" by this notebook).
single_units = raster_data[raster_data['unit_type']==1]
stacked = np.stack(single_units['img_psth'])
# Average over the last axis (images), leaving one row per unit and one
# column per time bin.
mean_per_unit = stacked.mean(axis=2)
n_units, n_time = mean_per_unit.shape

df_long = pd.DataFrame(mean_per_unit).melt(var_name='time', value_name='response')
# melt() stacks the frame column by column: all units for time 0, then all
# units for time 1, and so on. The unit id therefore cycles fastest, which is
# np.tile, not np.repeat (np.repeat would label the first n_time rows "unit 0").
df_long['unit'] = np.tile(np.arange(n_units), n_time)

fig, ax = plt.subplots(1, 1)
# errorbar='sd' draws the standard deviation across units at each time bin.
sns.lineplot(data=df_long, x='time', y='response', errorbar='sd', color='black', ax=ax)
ax.set_title('Single unit average time trace')
fig.tight_layout()
plt.show()

In [None]:
# Average PSTH for one ROI: mean over units, then mean +/- SD across images.
roi = 'MF1_8_F'  # alternative: 'MF1_9_F'
roi_data = raster_data[raster_data['roi'] == roi]

stacked = np.stack(roi_data['img_psth'])
# Collapse the unit axis; the result is time x image (original note: shape (450, 1072)).
averaged = stacked.mean(axis=0)
df = pd.DataFrame(averaged)

# melt() emits one full column (all time bins of one image) after another, so
# the time index simply repeats once per image column.
time_index = np.tile(np.arange(df.shape[0]), df.shape[1])
df_long = df.melt(var_name='image', value_name='response').assign(time=time_index)

fig, ax = plt.subplots(1, 1)
sns.lineplot(
    data=df_long,
    x='time',
    y='response',
    errorbar='sd',
    color='black',
    ax=ax,
)
# Reference line at bin 50.
# NOTE(review): presumably stimulus onset - confirm.
ax.axvline(x=50, color='red', linestyle='--', linewidth=1)
ax.set_title('Average PSTH')
fig.tight_layout()
plt.show()

In [None]:
# Indices in 1000..1071 that are not covered by any of the four listed ranges.
np.setdiff1d(
    np.arange(1000, 1072),
    np.r_[1000:1024, 1025:1031, 1043:1049, 1051:1062],
)

In [None]:
# For each image subset: plot the time-by-time RDM of the unit-averaged
# responses, then embed that RDM in 3D with MDS and draw the trajectory.
#
# Each entry is an array of indices into the last axis of the PSTH stack.
# NOTE(review): 'human faces' runs up to index 1024 while 'all faces' stops at
# 1023 - confirm which boundary is intended.
img_sets = {'all images': np.arange(1000,1072), 
           'all faces': np.arange(1000,1024),
           'monkey faces':  np.concatenate([np.arange(1000,1006), np.arange(1009,1016)]),
           'human faces': np.concatenate([np.arange(1006,1009), np.arange(1016,1025)])}

# img_sets = {'all images': np.arange(1000,1072), 
#            'all nonfaces': np.arange(1025,1072),
#            'monkey bodies': np.concatenate([np.arange(1026,1031), np.arange(1043,1049)]),
#            'animal bodies': np.concatenate([np.arange(1026,1031), np.arange(1043,1049), np.arange(1051,1062)]),
#            'all objects': np.setdiff1d(np.arange(1000, 1072), np.concatenate([np.arange(1000,1024), np.arange(1025,1031), np.arange(1043,1049), np.arange(1051,1062)]))}

metric = 'correlation'
roi = 'MF1_8_F' #'MF1_9_F'
roi_data = raster_data[(raster_data['roi']==roi)]
stacked = np.stack(roi_data['img_psth'])

# The colormap does not depend on the loop variable, so build it once.
rdm_cmap = sns.color_palette('Greys_r', as_cmap=True)

for k,v in img_sets.items():
    substacked = stacked[:, :, v]
    # average over units
    averaged = np.mean(substacked, axis=0)

    fig, ax = plt.subplots(1, 1)

    # plot rdm (pairwise distance between the rows of `averaged`)
    rdm = squareform(pdist(averaged, metric=metric))
    # Bind the heatmap to this figure's axes explicitly.
    sns.heatmap(rdm, cmap=rdm_cmap, vmax=1.5, ax=ax)
    ax.set_title(f'{k}')

    # Reference lines at bin 50.
    # NOTE(review): presumably stimulus onset - confirm.
    ax.axvline(x=50, color='red', linestyle='--', linewidth=1)
    ax.axhline(y=50, color='red', linestyle='--', linewidth=1)
    plt.show()

    # Embed into 3D and plot trajectory
    mds = MDS(n_components=3, dissimilarity='precomputed', random_state=0)
    coords = mds.fit_transform(rdm)               # [time, 3]
    T = coords.shape[0]
    ts = np.arange(T)

    # 3D plot
    fig = plt.figure(figsize=(6,5))
    ax = fig.add_subplot(111, projection='3d')

    # Draw trajectory line
    ax.plot(coords[:,0], coords[:,1], coords[:,2], 
            linewidth=1)

    # Scatter with time-coded colors
    sc = ax.scatter(coords[:,0], coords[:,1], coords[:,2],
                    c=ts, cmap='viridis', s=12)

    # Highlight time bins 90..139 with red rings, but only when every one of
    # them exists. `.all` must be *called*: the bare bound method is always
    # truthy, which silently disabled this bounds check.
    t_mark = np.arange(90, 140)
    if (t_mark < T).all():
        ax.scatter(coords[t_mark,0], coords[t_mark,1], coords[t_mark,2],
                   s=80, edgecolor='red', facecolor='none')

    ax.set_title(f"3D trajectory – {k}")
    ax.set_xlabel("MDS-1"); ax.set_ylabel("MDS-2"); ax.set_zlabel("MDS-3")
    cbar = plt.colorbar(sc, ax=ax, pad=0.1)
    cbar.set_label("Time index")

    plt.tight_layout()
    plt.show()

In [None]:
# For each image subset (img_sets, defined in an earlier cell): project the
# unit-averaged responses onto their first three principal components and
# draw the resulting 3D trajectory.
metric = 'correlation'
roi = 'MF1_8_F' #'MF1_9_F'
roi_data = raster_data[(raster_data['roi']==roi)]
stacked = np.stack(roi_data['img_psth'])

for k,v in img_sets.items():
    substacked = stacked[:, :, v]
    # average over units
    X = np.mean(substacked, axis=0)

    # Demean each row and scale it to unit L2 norm; the small epsilon avoids
    # dividing by zero for an all-constant row.
    Xz = X - X.mean(axis=1, keepdims=True)
    Xz /= (np.linalg.norm(Xz, axis=1, keepdims=True) + 1e-8)

    # Fit once and reuse the same model for both the coordinates and the
    # explained-variance report (the original fitted the model twice).
    pcaZ = PCA(n_components=3, random_state=0)
    coords = pcaZ.fit_transform(Xz)

    # Previously read from `pca`, a name that only exists in commented-out
    # code, so this cell raised NameError on a fresh kernel.
    evr = pcaZ.explained_variance_ratio_
    cum = evr.cumsum()
    print(f"{k} - After demeaning+unit-norm EVR: {evr[:3]} | cumulative: {np.sum(evr[:3])}")
    print(f"{k} — EVR per comp: {np.round(evr, 3)} | cumulative: {np.round(cum, 3)}")

    T = coords.shape[0]
    ts = np.arange(T)

    # 3D plot
    fig = plt.figure(figsize=(6,5))
    ax = fig.add_subplot(111, projection='3d')

    # Draw trajectory line
    ax.plot(coords[:,0], coords[:,1], coords[:,2], 
            linewidth=1)

    # Scatter with time-coded colors
    sc = ax.scatter(coords[:,0], coords[:,1], coords[:,2],
                    c=ts, cmap='viridis', s=12)

    # Highlight time bins 90..139 with red rings, but only when every one of
    # them exists. `.all` must be *called*: the bare bound method is always
    # truthy, which silently disabled this bounds check.
    t_mark = np.arange(90, 140)
    if (t_mark < T).all():
        ax.scatter(coords[t_mark,0], coords[t_mark,1], coords[t_mark,2],
                   s=80, edgecolor='red', facecolor='none')

    ax.set_title(f"3D trajectory – {k}")
    ax.set_xlabel("PC-1"); ax.set_ylabel("PC-2"); ax.set_zlabel("PC-3")
    cbar = plt.colorbar(sc, ax=ax, pad=0.1)
    cbar.set_label("Time index")

    plt.tight_layout()
    plt.show()

In [None]:
from IPython.display import clear_output
import time

# Inline "animation": for each of the first 200 time bins, show the
# image-by-image RDM (images 1000..1023, compared across the ROI's units),
# then clear the cell output before drawing the next frame.
metric = 'correlation'
roi = 'MF1_7_F'  # alternative: 'MF1_9_F'
roi_data = raster_data[raster_data['roi'] == roi]
stacked = np.stack(roi_data['img_psth'])

grey_cmap = sns.color_palette('Greys_r', as_cmap=True)
red_dashed = dict(color='red', linestyle='--', linewidth=1)

for t_idx in range(200):
    fig, ax = plt.subplots(1, 1)

    # units x images at this time bin; transposed so pdist compares images.
    responses = stacked[:, t_idx, 1000:1024]
    rdm = squareform(pdist(responses.T, metric=metric))
    sns.heatmap(rdm, cmap=grey_cmap, vmax=0.75, ax=ax)
    ax.set_title(f't={t_idx}')

    # Dashed separators at positions 15 and 24 on both axes.
    for edge in (15, 24):
        ax.axvline(x=edge, **red_dashed)
        ax.axhline(y=edge, **red_dashed)

    plt.show()
    time.sleep(0.01)
    clear_output(wait=True)

In [None]:
# Render the image-by-image RDM (images 1000..1023) for each of the first 200
# time bins and save every frame as a PNG under `savedir`.
savedir = '../gifs/localizer_set/'
# exist_ok=True replaces the check-then-create pattern, so there is no window
# in which the directory can appear between the test and the makedirs call.
os.makedirs(savedir, exist_ok=True)

metric = 'correlation'

roi = 'MF1_7_F' #'MF1_9_F'
roi_data = raster_data[(raster_data['roi']==roi)]
stacked = np.stack(roi_data['img_psth'])

# The colormap does not depend on the frame, so build it once.
rdm_cmap = sns.color_palette('Greys_r', as_cmap=True)

for i in tqdm(range(200)):
    fig, ax = plt.subplots()
    # units x images at time bin i; transposed so pdist compares images.
    substacked = stacked[:, i, 1000:1024]
    rdm = squareform(pdist(substacked.T, metric=metric))
    sns.heatmap(rdm, cmap=rdm_cmap, vmax=0.75, ax=ax)
    ax.set_title(f't={i}')
    fig.tight_layout()
    fig.savefig(os.path.join(savedir,f'frame_{i:03d}.png'), dpi=300, transparent=False, bbox_inches='tight')
    # Close each figure explicitly so 200 open figures do not accumulate in memory.
    plt.close(fig)

print('done')