In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

from gsn.perform_gsn import perform_gsn

from scipy.ndimage import gaussian_filter1d, uniform_filter1d

# raw: (units, time, images, trials)
def psth_gaussian(raw, sigma=3):
    """Smooth raw rasters with a Gaussian kernel along the time axis.

    Parameters
    ----------
    raw : array_like
        Raster array; by this notebook's convention (units, time, images,
        trials), so axis 1 is time.
    sigma : float, default 3
        Standard deviation of the Gaussian kernel, in time bins.

    Returns
    -------
    numpy.ndarray
        Same shape as `raw`. Integer/boolean input is promoted to float64
        before smoothing; floating input keeps its dtype.

    Notes
    -----
    NaNs are passed to scipy unchanged (unlike ``psth_window``, which zeroes
    them), so a NaN can spread into neighbouring bins of the smoothed output.
    """
    raw = np.asarray(raw)
    # scipy.ndimage allocates its output with the *input* dtype, so smoothing
    # an integer spike raster would silently truncate the result.
    if not np.issubdtype(raw.dtype, np.inexact):
        raw = raw.astype(np.float64)
    # smooth along the time axis
    return gaussian_filter1d(raw, sigma=sigma, axis=1)

def psth_window(raw, window_ms=50):
    """Boxcar (moving-average) smooth raw rasters along the time axis.

    Parameters
    ----------
    raw : array_like
        Raster array, by this notebook's convention (units, time, images,
        trials); axis 1 is time.
    window_ms : int, default 50
        Width of the moving-average window in time bins. The name assumes
        1 ms bins — TODO confirm the raster bin width.

    Returns
    -------
    numpy.ndarray
        Same shape as `raw`. NaNs are treated as 0 before averaging, and bins
        near either edge are averaged against zero padding
        (``mode='constant'``). Integer/boolean input is promoted to float64;
        floating input keeps its dtype.
    """
    # convert NaN to 0 so windowing behaves properly
    X = np.nan_to_num(raw, nan=0.0)
    # scipy.ndimage returns the input dtype by default, which would truncate
    # the window averages of an integer spike raster.
    if not np.issubdtype(X.dtype, np.inexact):
        X = X.astype(np.float64)
    # sliding window average over time axis
    return uniform_filter1d(X, size=window_ms, axis=1, mode='constant')

def summary(x):
    """Print the shape and NaN-ignoring median, mean, minimum and maximum of `x`.

    Parameters
    ----------
    x : array_like
        Numeric array (converted with ``np.asarray``); NaNs are skipped by
        every statistic.

    Returns
    -------
    None
        Results are printed only.
    """
    x = np.asarray(x)
    if x.size == 0:
        # np.nanmin / np.nanmax raise on zero-size input; report instead.
        print(f'Summary statistics {x.shape}: empty array\n')
        return
    print(f'''Summary statistics {x.shape}:
    \tMedian: {np.nanmedian(x)}
    \tMean: {np.nanmean(x)}
    \tMinimum: {np.nanmin(x)}
    \tMaximum: {np.nanmax(x)}\n''')

def segregate(df, method='truncate'):
    """Stack per-unit raster arrays into one 4-D array per ROI.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a 'roi' column (grouping key) and a 'raster' column holding
        one 3-D array per unit, presumably (time, images, trials). All units
        of a ROI must share the first two dimensions so they can be stacked.
    method : {'truncate', 'pad'}, default 'truncate'
        How to reconcile differing trial counts between units of a ROI.
        'truncate': for every unit keep its first ``min_T`` trials that
        contain no NaN anywhere, where ``min_T`` is the smallest such count
        in the ROI (may be 0). 'pad': append all-NaN trials so every unit
        reaches the ROI's largest trial count.

    Returns
    -------
    dict
        ROI label -> array of shape (units, time, images, trials). Each
        ROI's shape is also printed.

    Raises
    ------
    ValueError
        If `method` is not 'truncate' or 'pad', or if units within a ROI
        cannot be stacked because their shapes differ.
    """
    if method not in ('truncate', 'pad'):
        raise ValueError(f"method must be 'truncate' or 'pad', got {method!r}")

    roi_arrays = {}
    for roi, df_roi in df.groupby('roi'):
        arrays = df_roi['raster'].to_list()
        if method == 'truncate':
            # a trial is valid only if it has no NaN at any (time, image)
            valid_masks = [~np.isnan(a).any(axis=(0, 1)) for a in arrays]
            min_T = min(int(vm.sum()) for vm in valid_masks)
            # first min_T valid trials of every unit
            total = [a[:, :, np.flatnonzero(vm)[:min_T]]
                     for a, vm in zip(arrays, valid_masks)]
        else:  # 'pad'
            # max number of trials *within this ROI*
            max_T = max(a.shape[2] for a in arrays)
            total = []
            for a in arrays:
                missing = max_T - a.shape[2]
                if missing > 0:
                    # NaN padding needs a floating dtype (np.pad fails on ints)
                    if not np.issubdtype(a.dtype, np.inexact):
                        a = a.astype(np.float64)
                    a = np.pad(a, ((0, 0), (0, 0), (0, missing)),
                               constant_values=np.nan)
                total.append(a)

        # np.stack fails loudly if units disagree on (time, images, trials)
        x_roi = np.stack(total)   # (units, time, images, trials)
        roi_arrays[roi] = x_roi
        print(roi, x_roi.shape)

    return roi_arrays

In [None]:
# Raster table for one batch of units; segregate() below needs its 'roi' and
# 'raster' columns.
# NOTE(review): pd.read_pickle can execute arbitrary code while loading —
# only point this at files from a trusted source.
RASTER_PATH = os.path.join('..', '..', 'datasets', 'NNN', 'raw_raster_data_batch000.pkl')
ras_df = pd.read_pickle(RASTER_PATH)
print(f'Successfully loaded data for {len(ras_df)} units.')

# 'truncate': keep only NaN-free trials, cut to the per-ROI minimum count
roi_arrays = segregate(ras_df, method='truncate')

In [None]:
### CLUSTER UNITS USING K-MEANS
# Each unit is summarised by its mean time course (averaged over images and
# trials); K-means then groups units whose summaries are similar.
ROI = 'Unknown_19_F'
k = 3  # number of clusters

x = roi_arrays[ROI]                    # (units, time, images, trials)
unit_tc = np.nanmean(x, axis=(2, 3))   # (units, time)

# NOTE(review): time courses are clustered un-normalised, so clusters may track
# overall firing magnitude as much as temporal shape — confirm this is intended.
kmeans = KMeans(n_clusters=k, n_init='auto', random_state=0)
labels = kmeans.fit_predict(unit_tc)

# cluster id -> that cluster's units from x, all time/image/trial axes kept
cluster_data = {}
for c in range(k):
    idx = np.flatnonzero(labels == c)
    cluster_data[c] = x[idx]
    print(f"Cluster {c}: {cluster_data[c].shape}")

In [None]:
### VISUALIZE CLUSTER TIME COURSES
CLUSTER_ID = 2      # which K-means cluster to inspect (must be < k)
uidx = 8            # unit index *within* that cluster
N_UNITS_PLOT = 50   # cap on lines drawn in the cluster-wide plot


def _peak_normalize(arr, axis):
    """Return `arr` scaled so that its maximum along `axis` equals 1.

    Slices whose maximum is not positive (e.g. a silent trial or unit with
    peak 0) are left unscaled instead of being divided by zero into NaN.
    """
    peak = np.nanmax(arr, axis=axis, keepdims=True)
    return arr / np.where(peak > 0, peak, 1.0)


# Average over images (axis 2); nanmean matches the clustering cell
img_avg = np.nanmean(cluster_data[CLUSTER_ID], axis=2)   # (units, time, trials)

# first for a single unit: one line per trial, each trial peaking at 1
unit_tc = _peak_normalize(img_avg[uidx], axis=0)         # (time, trials)
n_trials = unit_tc.shape[-1]

fig, ax = plt.subplots(1, 1)
for trial in range(n_trials):
    sns.lineplot(unit_tc[:, trial], label=f'trial {trial}', alpha=0.5, ax=ax)
ax.legend()
ax.set_title(f'Single unit data across {n_trials} trials')
ax.set_xlabel('Time bin')
ax.set_ylabel('Peak-normalized response')
plt.show()

# now for the entire cluster: trial-averaged time course per unit
trial_avg = np.nanmean(img_avg, axis=2)                  # (units, time)
unit_norm = _peak_normalize(trial_avg, axis=1)           # each unit's own max over time is 1

fig, ax = plt.subplots(1, 1)
for unit in unit_norm[:N_UNITS_PLOT]:
    # Lines are deliberately unlabeled: a legend of this many units is unreadable.
    sns.lineplot(unit, alpha=0.5, ax=ax)
legend = ax.get_legend()
if legend is not None:
    legend.remove()
ax.set_title(f'Each line is a unit, averaged across {n_trials} trials')
ax.set_xlabel('Time bin')
ax.set_ylabel('Peak-normalized response')
plt.show()