# BLAES Units E/I Neuron Classification

This notebook contains code for classifying neuron cell type based on properties of the waveform.

---

> *Contact: Justin Campbell (justin.campbell@hsc.utah.edu)*  
> *Version: 05/27/2025*

## 1. Import Libraries

In [None]:
import os 
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec as gs
import seaborn as sns
from matplotlib.patches import Rectangle
from sklearn.cluster import KMeans
from scipy.stats import chi2_contingency, mannwhitneyu

%config InlineBackend.figure_format='retina'
pd.set_option('display.float_format', lambda x: '%.2f' % x)

## 2. Load & Organize Data

In [None]:
# Define paths
# The project root defaults to the original Google Drive location. Set the
# BLAES_PROJ_DIR environment variable to run the notebook on another machine
# without editing this cell.
projDir = os.environ.get(
    'BLAES_PROJ_DIR',
    '/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits/',
)
# Per-session folders and the pooled 'Group' folder are read from here.
resultsDir = os.path.join(projDir, 'Results')

In [None]:
# Build wf_dict = {session: {'<Channel>-<Unit>': DataFrame of that unit's waveforms}}.
# Every sub-folder of resultsDir except 'Group' is treated as a session folder.
wf_dict = {}
sessions = [
    entry for entry in os.listdir(resultsDir)
    if entry != 'Group' and os.path.isdir(os.path.join(resultsDir, entry))
]
for session in sessions:
    session_dir = os.path.join(resultsDir, session)
    wfs = pd.read_csv(os.path.join(session_dir, 'Waveforms.csv'), index_col = 0)
    events = pd.read_csv(os.path.join(session_dir, 'Events.csv'), index_col = 0)

    # Unit key = channel name and unit number joined by a dash
    events['Chan-Unit'] = events['Channel'].astype(str) + '-' + events['Unit'].astype(str)
    units = events['Chan-Unit'].unique()

    # Rows of `wfs` are selected by the index labels of the matching rows in `events`
    # NOTE(review): this presumes both CSVs share the same row index - confirm
    session_dict = {
        unit: wfs.loc[events.index[events['Chan-Unit'] == unit]].reset_index(drop = True)
        for unit in units
    }
    wf_dict[session] = session_dict


# Load the pooled spike statistics and keep only rows flagged as valid
spike_stats = pd.read_csv(os.path.join(resultsDir, 'Group', 'SpikeStats.csv'), index_col = 0)
spike_stats = spike_stats.loc[spike_stats['Valid'].eq(True)]

## 3. Extract Waveform Features

In [None]:
def WFFeats(session, unit, wfs, export = False, show = False, fs = 30000, stats = None, export_path = '/Users/justincampbell/Desktop/TEST.pdf'):
    '''Calculate and (optionally) plot waveform features for a given session and unit.

    All timing features are measured on the mean waveform using sample
    *positions* (0, 1, 2, ...) rather than the column labels of `wfs`.
    Column labels read from a CSV header are strings, and comparing such
    labels with `<` / `>` is lexicographic ('10' < '9'), which would place
    samples on the wrong side of the peak.

    - Parameters:
        session (str): Session identifier (e.g., 'BJH-2023-01-01'). Its first three
            characters select the x-axis tick layout of the figure ('BJH' or 'UIC').
        unit (str): Unit identifier (e.g., 'LAMY4-1').
        wfs (pd.DataFrame): Waveforms for the specified unit, one row per spike and
            one column per sample, with columns in time order.
        export (bool): If True, save the figure to `export_path`.
        show (bool): If True, display the figure.
        fs (float): Sampling rate in Hz used to convert sample counts to ms.
        stats (pd.DataFrame or None): Table with 'pID', 'Unit' and 'Valid' columns used
            to look up the unit's validity flag. None falls back to the notebook-level
            `spike_stats` table.
        export_path (str): Destination of the exported figure.

    - Returns:
        pd.DataFrame: Single-row DataFrame with pID, Unit, Peak and Trough (column labels
        of the mean-waveform maximum and minimum), Valley-to-Peak time in ms (VP),
        Peak Half Width in ms (PeakHW), number of spikes (NSpikes), and the 'Valid'
        flag from `stats` (described in the original notes as meeting a FR threshold).

    - Raises:
        ValueError: If `wfs` contains no usable samples.
        IndexError: If the peak is the first or last sample (no half-width can be
            measured), or if `stats` has no row for this session and unit.
    '''

    if stats is None:
        stats = spike_stats  # notebook-level table loaded in section 2

    f_type = session[0:3]

    # Mean waveform as a plain array so every comparison below is positional
    wf_avg = wfs.mean(axis = 0)
    avg_vals = wf_avg.to_numpy(dtype = float)
    if avg_vals.size == 0 or np.isnan(avg_vals).all():
        raise ValueError('No waveform samples available for %s, %s' % (session, unit))

    peak_pos = int(np.nanargmax(avg_vals))
    trough_pos = int(np.nanargmin(avg_vals))
    peak = wf_avg.index[peak_pos]      # labels are kept for the returned table
    trough = wf_avg.index[trough_pos]
    peak_val = avg_vals[peak_pos]
    trough_val = avg_vals[trough_pos]

    # Half width: on each side of the peak, take the sample whose amplitude is
    # closest to half the peak amplitude (earliest sample wins a tie)
    half_peak = peak_val / 2
    dist = np.abs(avg_vals - half_peak)
    before, after = dist[:peak_pos], dist[peak_pos + 1:]
    if before.size == 0 or after.size == 0 or np.isnan(before).all() or np.isnan(after).all():
        raise IndexError('Peak of %s, %s lies at the edge of the waveform; half width is undefined' % (session, unit))
    hp_pos1 = int(np.nanargmin(before))
    hp_pos2 = peak_pos + 1 + int(np.nanargmin(after))

    vp = abs(trough_pos - peak_pos) / fs * 1000 # convert to ms
    phw = (hp_pos2 - hp_pos1) / fs * 1000 # convert to ms

    match = stats.loc[(stats['pID'] == session) & (stats['Unit'] == unit), 'Valid']
    if match.empty:
        raise IndexError('No spike-stats row found for %s, %s' % (session, unit))
    validity = match.values[0]

    if show or export:

        # Setup figure
        fig, ax = plt.subplots(figsize = (3, 2))
        feat_colors = sns.color_palette('Set1', 4)
        x = np.arange(avg_vals.size)

        # Plot waveforms (all spikes in one call: each column of the transposed array is one line)
        ax.plot(x, wfs.to_numpy().T, color = '#e4e4e4', alpha = 0.25, lw = 1)
        ax.plot(x, avg_vals, color = 'k', linewidth = 4)

        # Plot features; the artists are kept so the legend does not depend on drawing order
        h_peak = ax.scatter(peak_pos, peak_val, color = feat_colors[0], s = 30, label = 'Peak', zorder = 100)
        h_trough = ax.scatter(trough_pos, trough_val, color = feat_colors[1], s = 30, label = 'Trough', zorder = 100)
        ax.plot([trough_pos, peak_pos], [peak_val, peak_val], color = feat_colors[3], linestyle = '--')
        h_vp, = ax.plot([trough_pos, trough_pos], [trough_val, peak_val], color = feat_colors[3], linestyle = '--', label = 'Valley-to-Peak')
        h_phw, = ax.plot([hp_pos1, hp_pos2], [avg_vals[hp_pos1], avg_vals[hp_pos2]], color = feat_colors[2], linestyle = '--', marker = 'o', markersize = 6, label = 'Peak Half Width')

        # Figure aesthetics
        ax.set_title(session + ', ' + unit, pad = 15, fontsize = 'large')
        ax.set_xlabel('Time (ms)', fontsize = 'large')
        ax.set_ylabel('Voltage ($\\mu$V)', fontsize = 'large')
        # Tick positions are in samples, tick labels in ms
        if f_type == 'BJH':
            tick_range = np.arange(0, 31, 15)
            time = tick_range / fs * 1000
            ax.set_xticks(tick_range, time)
            ax.set_xlim(0, 31)
        elif f_type == 'UIC':
            tick_range = np.arange(0, 46, 15)
            time = tick_range / fs * 1000
            ax.set_xticks(tick_range, time)
            ax.set_xlim(0, 48)
        sns.despine(ax = ax, top = True, right = True)

        # Four feature artists plus two proxy patches that carry the spike count and validity text
        handles = [h_peak, h_trough, h_vp, h_phw]
        handles.append(Rectangle((0,0),1,1,fc="#e4e4e4", fill=True, edgecolor='none', linewidth=0))
        handles.append(Rectangle((0,0),1,1,fc="w", fill=False, edgecolor='none', linewidth=0))
        labels = ['Peak: %.1f $\\mu$V' %peak_val, 'Trough: %.1f $\\mu$V' %trough_val, 'V-P: %.2f ms' %vp, 'P$_{HW}$: %.2f ms' %phw, 'Spikes: %i' %wfs.shape[0], 'FR Valid: %s' %validity]
        ax.legend(handles, labels, title = 'WF Features', title_fontsize = 'small', fontsize = 'x-small', bbox_to_anchor = (1.05, 1))

        # Export & Display
        if export:
            fig.savefig(export_path, dpi = 1200, bbox_inches = 'tight')
        if show:
            plt.show()
        else:
            plt.close(fig)

    return pd.DataFrame({'pID': session, 'Unit': unit, 'Peak': peak, 'Trough': trough, 'VP': vp, 'PeakHW': phw, 'NSpikes': wfs.shape[0], 'Valid': validity}, index = [0])

In [None]:
# Extract waveform features for every (session, unit) pair listed in spike_stats
# and stack the single-row results into one table.
feature_dfs = []

for session in pd.unique(spike_stats['pID']):
    session_units = spike_stats.loc[spike_stats['pID'] == session, 'Unit'].unique()
    for unit in session_units:
        wfs = wf_dict[session][unit]
        feature_dfs.append(WFFeats(session, unit, wfs, export = False, show = False))

feature_df = pd.concat(feature_dfs, ignore_index = True)

## 4. K-Means Clustering
Methods based on [Peyrache et al. 2011, *PNAS*](https://www.pnas.org/doi/full/10.1073/pnas.1109895109).

In [None]:
# Do k-means clustering using the PHW and VP features
# n_init is pinned because scikit-learn changed its default between releases,
# which would otherwise make the fit depend on the installed version.
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0)
raw_labels = kmeans.fit_predict(feature_df[['PeakHW', 'VP']])

# k-means numbers its clusters arbitrarily, but the cells below hard-code
# cluster 0 as 'E' and cluster 1 as 'I'. Renumber so that cluster 0 is always
# the one with the larger valley-to-peak centroid (column 1 of the centres).
# NOTE(review): this assumes the usual convention that putative excitatory
# units have the broader waveform - confirm against the cited method.
broad_first = np.argsort(-kmeans.cluster_centers_[:, 1])
relabel = np.empty_like(broad_first)
relabel[broad_first] = np.arange(broad_first.size)

# Keep the fitted estimator consistent with the renumbered labels, since a
# later cell reads kmeans.cluster_centers_ row 0 as E and row 1 as I.
kmeans.cluster_centers_ = kmeans.cluster_centers_[broad_first]
kmeans.labels_ = relabel[raw_labels]
feature_df['Cluster'] = kmeans.labels_

In [None]:
# Attach each unit's cluster label to its row of spike statistics
cluster_cols = feature_df[['pID', 'Unit', 'Cluster']]
spike_stats_clust = spike_stats.merge(cluster_cols, how='left', on=['pID', 'Unit'])

# Split by cluster (0 is treated as E, 1 as I) and express each as a share of all units
n_units = spike_stats_clust.shape[0]
df_e = spike_stats_clust.loc[spike_stats_clust['Cluster'] == 0]
df_i = spike_stats_clust.loc[spike_stats_clust['Cluster'] == 1]
e_pct = df_e.shape[0] / n_units * 100
i_pct = df_i.shape[0] / n_units * 100

In [None]:
# Report the size of each cluster and its share of all units
print(f'Excitatory units: {len(df_e)} ({e_pct:.1f}%)')
print(f'Inhibitory units: {len(df_i)} ({i_pct:.1f}%)')

In [None]:
# Tabulate the fitted cluster centres; columns follow the feature order used
# for fitting (PeakHW, VP), and row 0 / row 1 are labelled E / I
centroids = kmeans.cluster_centers_
centroid_df = pd.DataFrame(centroids, columns=['PeakHW', 'VP']).assign(Cluster=['E', 'I'])
centroid_df

## 5. Visualization

In [None]:
# Regions ordered by number of units (value_counts sorts in descending order)
region_counts = spike_stats_clust['Region'].value_counts()
plotOrder = region_counts.index

# Cluster colours used by every figure below: index 0 -> E, index 1 -> I
palette = ['#6d9dc5', '#ff6e61']
sns.palplot(palette, size=0.75)
plt.show()

In [None]:
# Three-panel summary figure on a 3x2 grid:
#   ax1 (top-left)          schematic of how VP and PHW are measured
#   ax2 (lower two-thirds)  mean +/- SD waveform of one example unit per cluster
#   ax3 (right column)      PeakHW vs. VP for every unit, coloured by cluster
fig = plt.figure(figsize=(8, 3))
# Named `grid` so the GridSpec class imported as `gs` at the top is not overwritten
grid = fig.add_gridspec(3, 2, width_ratios=[1, 2], height_ratios=[1, 1, 1])
ax1 = fig.add_subplot(grid[0, 0])
ax2 = fig.add_subplot(grid[1:, 0])
ax3 = fig.add_subplot(grid[:, 1])
fig.subplots_adjust(wspace=0.5, hspace=0.25)


# Example units, one per cluster colour (palette[0], palette[1])
ex_session = ['UIC20221701', 'UIC20230701']
ex_units = ['mLOFC3-1', 'mLHC7-1']

for i in range(2):
    session = ex_session[i]
    unit = ex_units[i]
    f_type = session[0:3]
    wfs = wf_dict[session][unit]
    wf_avg = wfs.mean(axis=0)
    # Long format so seaborn can draw the mean line with an SD band
    melted_wfs = wfs.melt(var_name='Time', value_name='Voltage')
    sns.lineplot(data=melted_wfs, x='Time', y='Voltage', ax=ax2, lw=4, color = palette[i], errorbar = 'sd')

# Schematic (ax1): drawn from the mean waveform of the LAST example unit above.
# Sample positions are used instead of column labels, because string labels
# from a CSV header compare lexicographically ('10' < '9') and would put the
# half-width markers on the wrong side of the peak.
schem_vals = wf_avg.to_numpy(dtype=float)
schem_x = np.arange(schem_vals.size)
peak = int(np.nanargmax(schem_vals))
trough = int(np.nanargmin(schem_vals))
half_peak = schem_vals[peak] / 2
half_dist = np.abs(schem_vals - half_peak)
hp_idx1 = int(np.nanargmin(half_dist[:peak]))
hp_idx2 = peak + 1 + int(np.nanargmin(half_dist[peak + 1:]))

ax1.plot(schem_x, schem_vals, color = '#7E7E7E', linewidth = 3)
ax1.scatter(peak, schem_vals[peak], color = 'k', s = 10, label = 'Peak', zorder = 100)
ax1.scatter(trough, schem_vals[trough], color = 'k', s = 10, label = 'Trough', zorder = 100)
ax1.plot([trough, peak], [schem_vals[peak], schem_vals[peak]], color = 'k', linestyle = ':', lw = 1)
ax1.plot([trough, trough], [schem_vals[trough], schem_vals[peak]], color = 'k', linestyle = ':', label = 'Valley-to-Peak', lw = 1)
ax1.plot([hp_idx1, hp_idx2], [schem_vals[hp_idx1], schem_vals[hp_idx2]], color = 'k', linestyle = ':', lw = 1, marker = 'o', markersize = 3, label = 'Peak Half Width')

# Add text annotations for PHW and VP (positions are in axes coordinates)
ax1.text(0.45, 0.35, 'PHW', transform=ax1.transAxes, fontsize='small', ha='center')
ax1.text(0.25, 0.725, 'VP', transform=ax1.transAxes, fontsize='small', ha='center')

# Scatter of all units (ax3); the two example units are re-drawn with a black outline
sns.scatterplot(data=feature_df, x='PeakHW', y='VP', ax=ax3, hue = 'Cluster', palette = palette, s=50, alpha = 0.75)
ex1_df = feature_df[(feature_df['pID'] == ex_session[0]) & (feature_df['Unit'] == ex_units[0])]
ex2_df = feature_df[(feature_df['pID'] == ex_session[1]) & (feature_df['Unit'] == ex_units[1])]
sns.scatterplot(data=ex1_df, x='PeakHW', y='VP', ax=ax3, color=palette[0], s=50, edgecolor='k', linewidth=1)
sns.scatterplot(data=ex2_df, x='PeakHW', y='VP', ax=ax3, color=palette[1], s=50, edgecolor='k', linewidth=1)

# Aesthetics
sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
ax1.set_ylim(-100, 100)
ax1.set_xticks([])
ax1.set_yticks([])

sns.despine(ax = ax2, top = True, right = True)
ax2.set_xlim(0, 45)
ax2.set_ylim(-100, 50)
ax2.set_yticks(np.arange(-100, 51, 50), ['-100', '-50', '0', '50'], fontsize='medium')
ax2.set_xlabel('Time (ms)', fontsize='large', labelpad=10)
ax2.set_ylabel('Voltage ($\\mu$V)', fontsize='large', labelpad=10)
# Tick positions are in samples, labels in ms at 30 kHz; f_type is that of the last example session
if f_type == 'BJH':
    tick_range = np.arange(0, 31, 15)
    time = tick_range / 30000 * 1000
    ax2.set_xticks(tick_range, time)
elif f_type == 'UIC':
    tick_range = np.arange(0, 46, 15)
    time = tick_range / 30000 * 1000
    ax2.set_xticks(tick_range, time)

sns.despine(ax = ax3, top = True, right = True)
ax3.set_xlim(0, 0.9)
ax3.set_ylim(0, 0.9)
# Explicit tick list: a float-step arange is not guaranteed to yield exactly four ticks
feat_ticks = [0, 0.3, 0.6, 0.9]
ax3.set_xticks(feat_ticks, ['0', '0.3', '0.6', '0.9'], fontsize='medium')
ax3.set_yticks(feat_ticks, ['0', '0.3', '0.6', '0.9'], fontsize='medium')
ax3.set_xlabel('Peak Half Width (ms)', fontsize='large', labelpad=10)
ax3.set_ylabel('Valley-to-Peak (ms)', fontsize='large', labelpad=10)
# Replace the automatic hue labels with the cluster names and their share of units
handles, labels = ax3.get_legend_handles_labels()
labels = ['E (%.1f%%)' % (e_pct), 'I (%.1f%%)' % (i_pct)]
for handle in handles:
    handle.set_alpha(1)
ax3.legend(title='Cluster', handles = handles, labels = labels, title_fontsize='medium', fontsize='small', loc = 'upper left')


# Make sure the output folder exists before saving
os.makedirs(os.path.join(resultsDir, 'Group', 'Figures'), exist_ok = True)
fig.savefig((os.path.join(resultsDir, 'Group', 'Figures', 'EIWFs.pdf')), dpi = 1200, bbox_inches = 'tight')
plt.show()

In [None]:
# Created separately to paste over the scatter plot above (cannot use jointplot with axes)
# Everything below acts on the jointplot's own joint axes. Reusing `ax3` here
# would style, and pull legend handles from, the previous cell's figure.
# Relies on ex_session, ex_units, palette, e_pct and i_pct from earlier cells.
g = sns.jointplot(data=feature_df, x='PeakHW', y='VP', hue='Cluster', palette=palette, alpha=0.75, height=5, marginal_kws={'common_norm': False}, s = 50)
joint_ax = g.ax_joint

ex1_df = feature_df[(feature_df['pID'] == ex_session[0]) & (feature_df['Unit'] == ex_units[0])]
ex2_df = feature_df[(feature_df['pID'] == ex_session[1]) & (feature_df['Unit'] == ex_units[1])]

# Re-draw the two example units with a black outline
sns.scatterplot(data=ex1_df, x='PeakHW', y='VP', ax=joint_ax, color=palette[0], s=50, edgecolor='k', linewidth=1)
sns.scatterplot(data=ex2_df, x='PeakHW', y='VP', ax=joint_ax, color=palette[1], s=50, edgecolor='k', linewidth=1)

# resize figure
fig = joint_ax.figure
fig.set_size_inches(4.5, 3)

sns.despine(ax = joint_ax, top = True, right = True)
joint_ax.set_xlim(0, 0.9)
joint_ax.set_ylim(0, 0.9)
# Explicit tick list: a float-step arange is not guaranteed to yield exactly four ticks
joint_ticks = [0, 0.3, 0.6, 0.9]
joint_ax.set_xticks(joint_ticks, ['0', '0.3', '0.6', '0.9'], fontsize='medium')
joint_ax.set_yticks(joint_ticks, ['0', '0.3', '0.6', '0.9'], fontsize='medium')
joint_ax.set_xlabel('Peak Half Width (ms)', fontsize='large', labelpad=10)
joint_ax.set_ylabel('Valley-to-Peak (ms)', fontsize='large', labelpad=10)
# Replace the automatic hue labels with the cluster names and their share of units
handles, labels = joint_ax.get_legend_handles_labels()
labels = ['E (%.1f%%)' % (e_pct), 'I (%.1f%%)' % (i_pct)]
for handle in handles:
    handle.set_alpha(1)
joint_ax.legend(title='Cluster', handles = handles, labels = labels, title_fontsize='medium', fontsize='small', loc = 'upper left')

# Make sure the output folder exists before saving
os.makedirs(os.path.join(resultsDir, 'Group', 'Figures'), exist_ok = True)
fig.savefig((os.path.join(resultsDir, 'Group', 'Figures', 'EIWScatterV2.pdf')), dpi = 1200, bbox_inches = 'tight')
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize = (1, 2))

# Parse data, convert to percentage
# NOTE(review): presumes 'StimSig' is a boolean/0-1 flag so its mean is a fraction - confirm
mod_e_pct = 100 * df_e['StimSig'].mean()
mod_i_pct = 100 * df_i['StimSig'].mean()

# Create barplot
# x, y and hue are given explicitly: passing `palette` without `hue` is
# deprecated in seaborn, and a bare list leaves the bar positions implicit.
cluster_names = ['E', 'I']
sns.barplot(x = cluster_names, y = [mod_e_pct, mod_i_pct], hue = cluster_names, palette = palette, saturation = 0.8, legend = False, ax = ax)

# Figure aesthetics
plt.setp(ax.patches, linewidth = 0.5, edgecolor = 'k')
ax.set_ylabel('% Modulated Units', fontsize = 'large', labelpad = 10)
tick_spacing = 9
yticks = np.arange(0, 36 + .1, tick_spacing)
ax.set_yticks(yticks)
ax.set_xticks([0, 1], ['E', 'I'])
ax.set_ylim(0, yticks[-1])
ax.set_xlabel('Cluster', fontsize = 'large', labelpad = 10)
sns.despine(ax = ax, top = True, right = True)

# Make sure the output folder exists before saving
os.makedirs(os.path.join(resultsDir, 'Group', 'Figures'), exist_ok = True)
fig.savefig((os.path.join(resultsDir, 'Group', 'Figures', 'EIModulation.pdf')), dpi = 1200, bbox_inches = 'tight')
plt.show()

# Do a chi-square test for the two clusters (cluster x StimSig contingency table).
# scipy applies Yates' continuity correction by default when the table has one degree of freedom.
contingency_table = pd.crosstab(spike_stats_clust['Cluster'], spike_stats_clust['StimSig'])
chi2, p, dof, expected = chi2_contingency(contingency_table)
print(f"Chi-square test results:\nChi2: {chi2:.2f}, p-value: {p:.4f}, DF: {dof}")

# Print percentage of modulated units in each cluster (values computed above)
print(f"Percentage of modulated units in E cluster: {mod_e_pct:.1f}%")
print(f"Percentage of modulated units in I cluster: {mod_i_pct:.1f}%")

In [None]:
fig, ax = plt.subplots(1, 1, figsize = (2.5, 2))

# Overlay: all units are drawn first in the E colour, then the cluster-1 (I)
# units on top, so the E-coloured part left visible is the non-I remainder.
sns.countplot(x = 'Region', data = spike_stats_clust, order = plotOrder, color = palette[0], ax = ax)
sns.countplot(x = 'Region', data = spike_stats_clust[spike_stats_clust['Cluster'] == 1], order = plotOrder, color = palette[1], ax = ax)

# Figure aesthetics
plt.setp(ax.patches, linewidth = 0.5, edgecolor = 'k')
# Legend handles are named explicitly (first bars drawn = E layer, last bars
# drawn = I layer) instead of relying on the implicit order of artists.
ax.legend(handles = [ax.containers[0], ax.containers[-1]], labels = ['E', 'I'], title = 'Cluster', title_fontsize = 'medium', fontsize = 'small', bbox_to_anchor = (1, 1))
ax.set_xlabel('Region', fontsize = 'large', labelpad = 10)
ax.set_ylabel('# Units', fontsize = 'large', labelpad = 10)
ax.set_yticks(np.arange(0, 91, 30), ['0', '30', '60', '90'], fontsize='medium')
ax.set_ylim(0, 90)
sns.despine(ax = ax, top = True, right = True)

# Make sure the output folder exists before saving
os.makedirs(os.path.join(resultsDir, 'Group', 'Figures'), exist_ok = True)
fig.savefig((os.path.join(resultsDir, 'Group', 'Figures', 'EIRegion.pdf')), dpi = 1200, bbox_inches = 'tight')
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(1, 2))

# Box plot of FR_ISI (labelled as baseline firing rate) for each cluster
box_kws = dict(palette=palette, saturation=0.8, legend=False, linecolor='k', width=0.75, fliersize=2.5)
sns.boxplot(data=spike_stats_clust, x='Cluster', y='FR_ISI', hue='Cluster', ax=ax, **box_kws)

# Aesthetics
sns.despine(top=True, right=True)
ax.set_xlabel('Cluster', fontsize='large', labelpad=10)
ax.set_ylabel('Baseline FR (Hz)', fontsize='large', labelpad=10)
ax.set_xticks([0, 1], ['E', 'I'], fontsize='medium')
ax.set_yticks(np.arange(0, 21, 5), ['0', '5', '10', '15', '20'], fontsize='medium')
# The y-axis is capped at 20, so any value above that is not visible
ax.set_ylim(0, 20)

fr_fig_path = os.path.join(resultsDir, 'Group', 'Figures', 'EIFR.pdf')
plt.savefig(fr_fig_path, dpi=1200, bbox_inches='tight')
plt.show()

# Do a Mann-Whitney U test for the two clusters (`e` holds the U statistic)
e, p_value = mannwhitneyu(df_e['FR_ISI'], df_i['FR_ISI'])
print(f"Mann-Whitney U test results:\nU-statistic: {e:.2f}, p-value: {p_value:.4f}")