In [None]:
import os
import re
import h5py
import random
import numpy as np
import pandas as pd
import scipy.io
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import pdist, squareform

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

import utils

# Directory holding the per-session .mat files and the derived per-unit pickles.
datadir = '../datasets/NNN/'
# Per-session file pairs; later cells read pair[0] and pair[1] of each entry.
fnames = utils.fnames(datadir)

# `data` is read by many later cells (plots, SI grouping, final save), so it has
# to be loaded here for the notebook to survive "Restart Kernel -> Run All".
# NOTE: unpickling executes arbitrary code - only load trusted, locally produced files.
data = pd.read_pickle(os.path.join(datadir, 'all_unit_data.pkl'))
raster_data = pd.read_pickle(os.path.join(datadir, 'unit_data_full.pkl'))

In [None]:
# Quick look at the first session only: print a few array sizes, then plot the
# mean PSTH and the SNR distribution of the single units (UnitType == 1).
for i, pair in enumerate(fnames):
    if i >= 1:
        break  # everything after the first file pair is ignored

    gus_fname, proc_fname = (os.path.join(datadir, name) for name in (pair[0], pair[1]))

    gus_data = utils.load_mat(gus_fname)
    proc_data = scipy.io.loadmat(proc_fname)

    unit_type = proc_data['UnitType'][0]
    print('unit types:', len(unit_type))
    mean_psth = proc_data['mean_psth']
    print('psth:', mean_psth.shape)
    print('avg firing rate to all images:', proc_data['response_basic'].shape)

    # boolean mask first, then the index tuple used to pick rows below
    is_single = unit_type == 1
    single_units = np.nonzero(is_single)
    snr = proc_data['snr'].T
    response = np.stack(gus_data['GoodUnitStrc']['response_matrix_img'])

    print(f'number of single units: {is_single.sum()}')

    sns.heatmap(mean_psth[single_units], vmax=20)
    plt.show()

    sns.boxplot(snr[single_units])
    plt.show()

In [None]:
# Box plot of 'snr_max' per monkey, split by unit type, over all units.
# (To restrict to one unit type instead: df = data[data['unit_type']==1])
df = data
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))

sns.boxplot(data=df, x='monkey', y='snr_max', hue='unit_type', fliersize=1, ax=ax)
ax.set_title('SNR max')
# clip the y-range so the boxes stay readable
ax.set_ylim(bottom=-1, top=50)

In [None]:
# Summary statistics (pandas `describe`) of `df`, one row per unit type.
df.groupby(by=['unit_type']).describe()

In [None]:
# Heatmap of 'avg_firing_rate' (one row per unit) for units with
# unit_type == 1 and snr > 1.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 8))

keep = (data['unit_type'] == 1) & (data['snr'] > 1)
df = data[keep]
dat = np.stack(df['avg_firing_rate'])

sns.heatmap(dat, vmax=10, ax=ax)

ax.set(ylabel='unit number', xlabel='image idx')
# no tick marks or tick labels on either axis
ax.set_xticks([], [])
ax.set_yticks([], [])


In [None]:
# Show the 'img_psth' heatmap of five randomly drawn units
# (unit_type == 1 and snr > 1), each titled with that unit's SNR.
df = data[(data['unit_type']==1) & (data['snr']>1)]

n_examples = 5
fig, axes = plt.subplots(n_examples, 1, figsize=(10, 20))

for i in range(n_examples):
    # randrange draws uniformly from 0 .. len(df)-1. The previous
    # round(random.random()*len(df)) could return len(df) (IndexError on iloc)
    # and picked the first row only half as often as the others.
    num = random.randrange(len(df))
    unit_row = df.iloc[num]
    dat = unit_row['img_psth']
    sns.heatmap(dat, square=True, vmax=10, ax=axes[i])
    axes[i].axis("off")
    # Title with the SNR of the unit that is actually plotted (row `num`, not
    # row `i`); double quotes outside keep the f-string valid before Python 3.12.
    axes[i].set_title(f"{unit_row['snr']}")

In [None]:
# Input layout this cell relies on: one row per unit with the columns
# 'F_SI', 'B_SI' and 'O_SI', e.g.
# df = pd.DataFrame({'F_SI':..., 'B_SI':..., 'O_SI':...})

# Step 1: reshape to long form (one row per unit and SI type) so seaborn can
# draw one line per unit.
df = data.copy()

si_columns = ['F_SI', 'B_SI', 'O_SI']
df_wide = df.reset_index()
df_long = pd.melt(df_wide, id_vars='index', value_vars=si_columns,
                  var_name='SI_type', value_name='SI_value')
# the former row label becomes the unit id
df_long = df_long.rename(columns={"index": "unit"})

# Step 2: grouping threshold - a unit is assigned to the SI type with the
# largest value, provided that value reaches the cutoff.
cutoff = 0.5

# Get for each unit the max SI value and its column
def assign_group(row, threshold=None):
    """Label one unit by its largest selectivity index.

    Parameters
    ----------
    row : pandas.Series
        One unit's record; must contain the labels 'F_SI', 'B_SI' and 'O_SI'.
    threshold : float, optional
        Minimum value the largest index must reach. When omitted, the
        notebook-level ``cutoff`` is used.

    Returns
    -------
    str
        'F_SI', 'B_SI' or 'O_SI' (the label of the largest value), or the
        string 'None' when the largest value is below the threshold or all
        three values are missing.
    """
    if threshold is None:
        threshold = cutoff
    vals = row[['F_SI', 'B_SI', 'O_SI']]
    maxval = vals.max()  # NaNs are skipped; the result is NaN only if all are missing
    # `NaN < threshold` is False, so without the explicit NaN test an all-missing
    # row would fall through to idxmax and yield a NaN label (or raise, depending
    # on the pandas version).
    if pd.isna(maxval) or maxval < threshold:
        return 'None'
    return vals.idxmax()  # will be 'F_SI', 'B_SI', or 'O_SI'

# Label every unit (row of `data`) with its selectivity group.
unit_groups = data.apply(assign_group, axis=1)
df['group'] = unit_groups

# Step 3: carry the label over to the long table by looking up each unit id.
df_long['group'] = df_long['unit'].map(unit_groups)

# Keep only units that were assigned to one of the SI groups.
has_group = df_long['group'].ne('None')
df_long_filtered = df_long[has_group]

In [None]:
# One figure per selectivity group: every unit of the group is drawn as a faint
# black line across its three SI values.
groups = ['F_SI', 'B_SI', 'O_SI']
for g in groups:
    sub_df = df_long_filtered[df_long_filtered['group'] == g]
    # sub_df is in long form (one row per unit per SI type), so the title
    # counts distinct units rather than rows.
    n_units = sub_df['unit'].nunique()

    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    # legend=False: one legend entry per unit is useless here and slow to build;
    # ax=ax: draw on this figure explicitly instead of relying on pyplot state.
    sns.lineplot(data=sub_df, x='SI_type', y='SI_value', sort=False,
                 hue='unit', marker='o', palette=['black'], alpha=0.05,
                 legend=False, ax=ax)
    ax.set_title(f"Group: {g}, {n_units} units")
    ax.set_xlabel('category')
    # every figure stands alone, so each one gets the y label
    ax.set_ylabel('Selectivity Index')
    plt.show()

In [None]:
# Select units with F_SI > 0.5 and unit_type == 1 that have raster data, then
# plot one example unit's PSTH and the average PSTH over all selected units.
unit = 'F_SI'
is_selective = raster_data[unit] > 0.5
is_type1 = raster_data['unit_type'] == 1
face_df = raster_data[is_selective & is_type1]
face_df = face_df[face_df['img_raster'].notna()]  # some sessions do not have raster data?

# example unit: positional row 13, first 1000 columns of its PSTH
psth = face_df.iloc[13]['img_psth']
sns.heatmap(psth[:, :1000].T, cmap=sns.color_palette(palette='Greys'))
plt.show()

# average over units (axis 0 of the stacked array)
x = np.stack(face_df['img_psth'].values)
avg_face = x.mean(axis=0)
sns.heatmap(avg_face[:, :1000].T, cmap=sns.color_palette(palette='Greys'))
plt.title('Face responses to NSD images')
plt.axis('off')
plt.show()

In [None]:
# Rows of M: the first 1000 columns of avg_face, transposed.
M = avg_face.T[:1000]

# Step 1: pairwise correlation distance between rows, then average-linkage
# hierarchical clustering (other metrics/methods: 'euclidean', 'cosine', 'ward', ...).
dists = pdist(M, 'correlation')
row_linkage = linkage(dists, method='average')

# Step 2: take the row order from the dendrogram leaves (nothing is drawn here).
dendro = dendrogram(row_linkage, no_plot=True)
row_order = dendro['leaves']

greys = sns.color_palette('Greys', as_cmap=True)
sns.heatmap(M[row_order], cmap=greys)
plt.title('Face responses to NSD images (clustered)')
plt.axis('off')
plt.show()

In [None]:
# Rank the rows of M by their mean value (largest first) and plot them in
# that order.
row_means = np.mean(M, axis=1)

# argsort is ascending, so flip it for a descending ranking
row_order = np.flip(np.argsort(row_means))

greys = sns.color_palette('Greys', as_cmap=True)
sns.heatmap(M[row_order], cmap=greys)
plt.title('NSD images (sorted by mean unit activity)')
plt.axis('off')
plt.show()

In [None]:
# Average the rasters over units, keep the first 1000 columns, and plot the
# rows ordered by their mean value (largest first).
all_raster = np.stack(face_df['img_raster'].values)
raster = all_raster.mean(axis=0)

R = raster[:, :1000].T
row_means = np.mean(R, axis=1)
row_order = np.flip(np.argsort(row_means))

greys = sns.color_palette('Greys', as_cmap=True)
sns.heatmap(R[row_order], cmap=greys)
plt.title('NSD images (Raster plot)')
plt.axis('off')

In [None]:
# Compare the selected units across three time windows. For each window:
#   left  - histogram of the spike count per image, averaged over units
#   right - unit-by-unit correlation-distance matrix of the per-image spike counts
all_raster = np.stack(face_df['img_raster'].values)
raster = np.mean(all_raster, axis=0)

R = raster.T[:1000]

# Windows are defined once, in the units used by the plot titles, and the array
# slices are derived from them so labels and data cannot drift apart.
# NOTE(review): the original slice/label pairs (100:170 <-> '50-120 msec', ...)
# imply axis 1 is time in 1 ms bins with a 50-sample offset - confirm against
# the preprocessing.
onset_idx = 50
n_images = 1000
windows_ms = [(50, 120), (150, 220), (250, 320)]

window_rasters = [
    all_raster[:, onset_idx + start:onset_idx + stop, :n_images]
    for start, stop in windows_ms
]
early_raster, late_raster, latelate_raster = window_rasters

time_windows = [
    (f'spikes from {start}-{stop} msec', window)
    for (start, stop), window in zip(windows_ms, window_rasters)
]

fig, axes = plt.subplots(len(time_windows), 2, figsize=(10, 12))

# The loop variable is deliberately not called `raster`, so the mean raster
# computed above is not overwritten.
for row, (title, window_raster) in enumerate(time_windows):
    # Left: histogram
    ax = axes[row, 0]
    sspikes = np.sum(window_raster, axis=1)   # sum over the window axis -> (units, images)
    mean_per_image = np.mean(sspikes, axis=0) # average over units, one value per image
    sns.histplot(mean_per_image, binwidth=0.01, ax=ax)
    ax.set_ylim(top=150)
    ax.set_xlim(right=1)
    ax.set_title(title)
    ax.set_xlabel('Mean spike count per image')
    ax.set_ylabel('Count')

    # Right: RDM between units (rows of sspikes)
    ax = axes[row, 1]
    rdm = squareform(pdist(sspikes, metric='correlation'))  # shape (units, units)
    sns.heatmap(rdm, ax=ax, cmap='mako', square=True)
    ax.set_title('RDM (%s)' % title)
    ax.set_xlabel('Unit')
    ax.set_ylabel('Unit')
    ax.axis('off')  # hides ticks and the axis labels set above

plt.tight_layout()
plt.show()

In [None]:
# Stricter selection: F_SI > 0.5, unit_type == 1, snr_max > 20, restricted to a
# fixed list of sessions, and only units that have raster data.
unit = 'F_SI'
sessions_of_interest = [21, 24, 25, 28, 29, 30, 36, 40]
selection = (
    (raster_data[unit] > 0.5)
    & (raster_data['unit_type'] == 1)
    & (raster_data['snr_max'] > 20)
    & (raster_data['session'].isin(sessions_of_interest))
)
face_df = raster_data[selection]
face_df = face_df[face_df['img_raster'].notna()]  # some sessions do not have raster data?

# cell output: sessions that survived the filter, and the number of units
face_df['session'].unique(), len(face_df)

In [None]:
# For one image, plot the normalised spike train of five randomly chosen units:
# cumulative sum (top row) and moving average (bottom row).
img_idx = 254
N = 50  # window for moving avg
marker_x = 50  # x position of the dashed red reference line (presumably stimulus onset - confirm)

random.seed(5)  # fixed seed so the same units are drawn on every run

example_units = random.sample(range(len(face_df)), 5)   # choose units you want to visualize

fig, axes = plt.subplots(2, 5, figsize=(18, 6), sharex='col', sharey='row')

for col, unit_idx in enumerate(example_units):
    # Column `img_idx` of this unit's raster, scaled so that its maximum is 1.
    spike_history = face_df.iloc[unit_idx]['img_raster'][:, img_idx]
    peak = spike_history.max()
    # A unit that never fired for this image has peak == 0; dividing would turn
    # the whole trace into NaN and leave both panels empty, so keep the zeros.
    if peak > 0:
        spike_history = spike_history / peak

    # --- Plot cumulative sum ---
    ax = axes[0, col]
    sns.lineplot(x=np.arange(len(spike_history)), y=np.cumsum(spike_history), ax=ax)
    # axvline spans the full axes height even after autoscaling / shared-y
    # updates, unlike vlines drawn from a snapshot of the current y-limits.
    ax.axvline(marker_x, color='red', linestyle='dashed')
    ax.set_title(f'Unit {unit_idx} - cumulative')
    ax.set_ylabel('spike count')

    # --- Plot moving average ---
    ax = axes[1, col]
    moving_avg = np.convolve(spike_history, np.ones(N)/N, mode='valid')
    sns.lineplot(x=np.arange(len(moving_avg)), y=moving_avg, ax=ax)
    ax.axvline(marker_x, color='red', linestyle='dashed')
    ax.set_title(f'Unit {unit_idx} - moving avg')
    ax.set_ylabel('firing rate')

plt.suptitle(f'Example units for image {img_idx}', fontsize=16)
plt.tight_layout()
plt.subplots_adjust(top=0.9)  # leave room for the suptitle
plt.show()

In [None]:
# Build a table with one row per unit (session, monkey, F_SI, B_SI, O_SI) from
# the 'Processed_ses*.mat' file of every session.
datadir = '../datasets/NNN/'
fnames = utils.fnames(datadir)

cols = ['session', 'monkey', 'F_SI', 'B_SI', 'O_SI']
si_names = ['F_SI', 'B_SI', 'O_SI']
# NOTE(review): the file pair at this position is skipped on purpose by the
# existing code; the reason is not recorded here - confirm it is still needed.
skip_idx = 28

# Rows are collected in a list and turned into a DataFrame once at the end;
# appending with df.loc[len(df)] is quadratic and leaves object-dtype columns.
records = []

total_units = 0
# tqdm wraps the iterable itself so the bar can show a total when fnames has a length.
for i, pair in enumerate(tqdm(fnames)):
    proc_fname = os.path.join(datadir, pair[1])
    # group(1) is used as the session number, group(3) as the monkey id
    m = re.match(r'Processed_ses(\d+)_(\d{6})_M(\d+)_(\d+)\.mat', os.path.basename(proc_fname))
    if i == skip_idx:
        print(f'skipping {proc_fname}...')
        continue
    if not m:
        print(f"Could not parse {proc_fname}")
        continue
    try:
        proc_data = scipy.io.loadmat(proc_fname)

        session_num = int(m.group(1))
        monkey = int(m.group(3))
        num_units = len(proc_data['UnitType'][0])

        si_values = {}
        for name in si_names:
            # ravel (not squeeze) keeps a 1-D array even for a single-unit session
            values = np.ravel(proc_data[name])
            if values.shape[0] != num_units:
                raise ValueError(
                    f"{name} has {values.shape[0]} values but UnitType lists {num_units} units"
                )
            si_values[name] = values

        # All checks passed: add this session's rows in one go so a failing
        # session never contributes partial data.
        records.extend(
            {
                'session': session_num,
                'monkey': monkey,
                'F_SI': si_values['F_SI'][unit_idx],
                'B_SI': si_values['B_SI'][unit_idx],
                'O_SI': si_values['O_SI'][unit_idx],
            }
            for unit_idx in range(num_units)
        )
        total_units += num_units

    except Exception as e:
        # Best effort: report the session and carry on with the remaining files.
        print(f"Error processing {proc_fname}: {e}")
        continue

df = pd.DataFrame(records, columns=cols)


In [None]:
# Copy the selectivity indices from `df` into `data` and persist the result.
# Column assignment aligns on the index, so `df` and `data` must describe the
# same units in the same rows. If they do not (for example because a session
# was skipped while building `df`), the assignment would silently write NaN or
# misaligned values and then overwrite the pickle - refuse to save instead.
# NOTE(review): this checks row labels only; it cannot prove that both tables
# list the units in the same order.
if not data.index.equals(df.index):
    raise ValueError(
        f"SI table ({len(df)} rows) does not line up with data ({len(data)} rows); "
        "not overwriting all_unit_data.pkl"
    )

for si_col in ['F_SI', 'B_SI', 'O_SI']:
    data[si_col] = df[si_col]
data.to_pickle('../datasets/NNN/all_unit_data.pkl')