# BLAES Units Multiunit Activity (MUA) Analysis

This notebook tests whether multiunit activity (MUA) is modulated from pre- to post-stimulation, using per-channel threshold crossings. The threshold crossings are computed in the separate notebook `MUAPrepro.ipynb`.

---

> *Contact: Justin Campbell (justin.campbell@hsc.utah.edu)*  
> *Version: 6/6/2025*

## 1. Import Libraries

In [None]:
# Import necessary libraries
import os
import glob
import warnings
import numpy as np
import pandas as pd
from scipy.stats import wilcoxon, norm, fisher_exact
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.legend_handler import HandlerBase

# Display params (Jupyter/IPython only: the '%' lines are magics, not plain Python)
# Render matplotlib figures inline in the notebook
%matplotlib inline
# Use high-DPI ('retina') output for inline figures
%config InlineBackend.figure_format='retina'
# Show floats in DataFrame displays with two decimals
pd.set_option('display.float_format', lambda x: '%.2f' % x)
# NOTE(review): this silences ALL warnings for the rest of the session, including any
# raised by scipy/seaborn in the analysis cells below; narrow it to specific warning
# categories if unexpected behaviour needs to surface.
warnings.filterwarnings('ignore')

## 2. Data Processing

In [None]:
# Define paths. Set the BLAES_RESULTS_DIR environment variable to point the notebook
# at another copy of the results without editing this cell.
results_dir = os.environ.get(
    'BLAES_RESULTS_DIR',
    '/Users/justincampbell/Library/CloudStorage/GoogleDrive-u0815766@gcloud.utah.edu/My Drive/Research Projects/BLAESUnits/Results')

# Find all files named 'MUACounts.csv' in the results directory (recursively).
# Sorted so the row order -- and therefore the session order consumed by the seeded
# permutation tests below -- does not depend on the filesystem's listing order.
muacounts_files = sorted(glob.glob(os.path.join(results_dir, '**', 'MUACounts.csv'), recursive=True))
if not muacounts_files:
    raise FileNotFoundError('No MUACounts.csv files found under %s' % results_dir)

# Load and concatenate all MUACounts.csv files into a single DataFrame (fresh 0..N-1 index)
mua_df = pd.concat([pd.read_csv(file, index_col=0) for file in muacounts_files],
                   axis=0, ignore_index=True)

In [None]:
# Derive region labels from the channel name: drop its first two characters and its
# last character (assumes a single trailing character, e.g. one channel digit -- labels
# that do not fit this scheme end up unmapped and are reported below)
raw_region_labels = mua_df['Chan'].str[2:-1]

# Simplify labels using a mapping
region_mapping = {
    'HIP': 'HIP',
    'HCA': 'HIP',
    'AHIP': 'HIP',
    'AMY': 'AMY',
    'OFC': 'OFC',
    'ACC': 'ACC',
    'VCG': 'ACC',
    'DAC': 'ACC',
    'VCC': 'ACC',
    'DCG': 'ACC',
    'FUG': 'FUG'}

mua_df['Region'] = raw_region_labels.map(region_mapping)

# Labels missing from region_mapping become NaN and are NOT removed by the FUG filter
# below (NaN != 'FUG'), so report them: add them to the mapping or exclude them explicitly.
unmapped_labels = sorted(map(str, raw_region_labels[mua_df['Region'].isna() & raw_region_labels.notna()].unique()))
if unmapped_labels:
    print('Unmapped region labels (Region = NaN, still included): %s' % ', '.join(unmapped_labels))

# Exclude FUG region (too few units/chans)
mua_df = mua_df[mua_df['Region'] != 'FUG']

## 3. Statistical Testing

In [None]:
def _modulation_direction(pre, post):
    '''
    Direction of the post- vs. pre-stim change in mean threshold crossings.

    Inputs:
    - pre, post (array-like): per-trial crossings for one channel.

    Outputs:
    - 'Inc' if the post-stim mean is larger, 'Dec' if it is smaller, np.nan if equal.
    '''
    pre_mean = np.mean(pre)
    post_mean = np.mean(post)
    if pre_mean > post_mean:
        return 'Dec'
    if pre_mean < post_mean:
        return 'Inc'
    return np.nan


def _sign_flip_null(pre, post, n_perms, rng):
    '''
    Permutation null for the paired Wilcoxon statistic, built by sign flipping.

    On each permutation the Pre and Post values of every trial are swapped with
    probability 0.5 (i.e. the sign of that trial's difference is flipped), which is
    the exchangeability assumption of a paired test under H0. Both epochs' values
    therefore mix across labels, unlike shuffling trial pairing within each epoch.

    Inputs:
    - pre, post (ndarray): per-trial crossings for one channel (same length).
    - n_perms (int): number of permutations.
    - rng (np.random.Generator): source of randomness.

    Outputs:
    - list of float: one Wilcoxon statistic per permutation.
    '''
    swaps = rng.random((n_perms, pre.size)) < 0.5
    null_stats = []
    for swap in swaps:
        stat_perm, _ = wilcoxon(np.where(swap, post, pre), np.where(swap, pre, post))
        null_stats.append(float(stat_perm))
    return null_stats


def runFRContrast_MUA(mua_cond_df, n_perms=1000, seed=0, alpha=0.05):
    '''
    This function runs a Wilcoxon signed-rank test on the MUA threshold crossings for each channel for pre-/post- (ISI vs. Post) stim. It then compares the observed statistic with a sign-flip permutation distribution (Pre/Post swapped within trials) and derives a two-sided p-value by z-scoring the observed statistic against that distribution.

    Inputs:
    - mua_cond_df (DataFrame): A DataFrame of the mua threshold crossings for each channel, generated by 'MUAPrepro.ipynb'. Must contain the columns 'pID', 'Chan', 'Region', 'Condition', 'Pre_Spikes' and 'Post_Spikes'.
    - n_perms (int): number of permutations per channel (default 1000).
    - seed (int): seed of a local random generator, so results neither depend on nor alter NumPy's global random state (default 0).
    - alpha (float): significance threshold for Res_Real / Res_Perm (default 0.05).

    Outputs:
    - muastatsDF (DataFrame): one row per channel (sessions in order of appearance, then channels in order of appearance) with columns
      'Chan', 'pID', 'Region', 'Condition' (Region/Condition taken from the channel's first row),
      'Stat_Real' / 'P_Real' (empirical Wilcoxon statistic and p-value; NaN if the test raised),
      'Res_Real' (P_Real < alpha; NaN if the test raised),
      'Direction' ('Inc' / 'Dec' for higher / lower mean post-stim crossings; NaN if the means are equal or the test raised),
      'Stat_Perm' (list of permutation statistics; empty if the empirical statistic is not finite),
      'P_Perm' (z-score p-value; NaN when undefined) and 'Res_Perm' (P_Perm < alpha; False when P_Perm is NaN).
    '''
    rng = np.random.default_rng(seed)
    rows = []

    # Loop through each session, then each channel within the session
    for pID in mua_cond_df['pID'].unique():
        pIDChans = mua_cond_df[mua_cond_df['pID'] == pID]
        for chan in pIDChans['Chan'].unique():
            chanDF = pIDChans[pIDChans['Chan'] == chan]
            pre = chanDF['Pre_Spikes'].to_numpy(dtype=float)
            post = chanDF['Post_Spikes'].to_numpy(dtype=float)

            # Empirical paired test. wilcoxon raises ValueError on degenerate input
            # (e.g. every Pre-Post difference is zero); record NaNs for that channel.
            try:
                stat, p = wilcoxon(pre, post)
                res = p < alpha
                direction = _modulation_direction(pre, post)
            except ValueError:
                stat, p, res, direction = np.nan, np.nan, np.nan, np.nan

            # Permutation testing (only meaningful when the empirical statistic exists)
            if np.isfinite(stat):
                statPerms = _sign_flip_null(pre, post, n_perms, rng)
                null_sd = np.std(statPerms) if statPerms else 0.0
                if null_sd > 0:
                    # Compute p-value using z-score method (from permutation distribution)
                    zContrast = (stat - np.mean(statPerms)) / null_sd
                    pPerm = norm.sf(abs(zContrast)) * 2
                else:
                    # Degenerate null (all permutation stats identical): z is undefined,
                    # so do not report a (spurious) p = 0
                    pPerm = np.nan
            else:
                statPerms = []
                pPerm = np.nan
            resPerm = bool(pPerm < alpha)  # NaN compares False

            rows.append({
                'Chan': chan,
                'pID': pID,
                'Region': chanDF['Region'].iloc[0],
                'Condition': chanDF['Condition'].iloc[0],
                'Stat_Real': stat,
                'P_Real': p,
                'Res_Real': res,
                'Direction': direction,
                'Stat_Perm': statPerms,
                'P_Perm': pPerm,
                'Res_Perm': resPerm})

    # Create a DataFrame with the results (fixed column order, also when empty)
    muastatsDF = pd.DataFrame(rows, columns=[
        'Chan', 'pID', 'Region', 'Condition', 'Stat_Real', 'P_Real', 'Res_Real',
        'Direction', 'Stat_Perm', 'P_Perm', 'Res_Perm'])

    return muastatsDF

In [None]:
# Run the contrast analysis for the 'Stim' condition
mua_stim = mua_df[mua_df['Condition'] == 'Stim']
mua_stats_stim = runFRContrast_MUA(mua_stim)

# NoStim (sham) trials are currently not analysed. To include them, run
# runFRContrast_MUA on mua_df[mua_df['Condition'] == 'NoStim'] and pd.concat the
# result with mua_stats_stim before exporting.
mua_stats = mua_stats_stim.copy()

# Export the results to a CSV file (create the Group folder on a fresh results tree,
# otherwise to_csv fails)
group_dir = os.path.join(results_dir, 'Group')
os.makedirs(group_dir, exist_ok=True)
mua_stats.to_csv(os.path.join(group_dir, 'MUAStatsStim.csv'), index=False)

### 3.1 Summarize Results

In [None]:
# Number of unique pID-Chan combinations (MUA channels). ngroups counts each channel
# once even if it has several rows (e.g. when Stim and NoStim results are concatenated);
# summing group sizes would count rows instead.
n_chans = mua_stats.groupby(['pID', 'Chan']).ngroups

print('Number of MUA channels: %i' %n_chans)

In [None]:
# Percent of significant channels (permutation test) in the Stim condition
stim_res_perm = mua_stats.loc[mua_stats['Condition'] == 'Stim', 'Res_Perm']
stim_sig = stim_res_perm.value_counts(normalize=True)

# Count significant channels directly: back-computing proportion * n_chans and
# formatting with %i can truncate to k - 1 after float rounding (e.g. 0.29 * 100).
# .get avoids a KeyError when no channel is significant.
n_stim_sig = int(stim_res_perm.eq(True).sum())
pct_stim_sig = stim_sig.get(True, 0.0) * 100

# Print the percentage of significant channels
print('Sig channels (Stim): %i (%.1f%%)' % (n_stim_sig, pct_stim_sig))

In [None]:
# Direction of modulation among significant (permutation test) Stim channels
sig_stim = mua_stats[(mua_stats['Condition'] == 'Stim') & (mua_stats['Res_Perm'] == True)]
stim_inc = int((sig_stim['Direction'] == 'Inc').sum())
stim_dec = int((sig_stim['Direction'] == 'Dec').sum())

# Print proportion of significant channels that are increasing or decreasing
# (the ratio is undefined when no significant channel has a direction)
n_sig_directional = stim_inc + stim_dec
if n_sig_directional > 0:
    print('Sig channels inc/dec (Stim): %.1f%% / %.1f%%' % ((stim_inc / n_sig_directional) * 100, (stim_dec / n_sig_directional) * 100))
else:
    print('Sig channels inc/dec (Stim): no significant channels with a direction')

## 4. Visualization

In [None]:
# Regions ordered from most to fewest channels; nRegions counts distinct labels as
# returned by unique() (so a NaN label, if any, is counted too)
regionCounts = mua_stats['Region'].value_counts()
plotOrder = regionCounts.index
nRegions = len(mua_stats['Region'].unique())

# 90's Anime Color Palette (https://colormagic.app/palette/671d8e4ae9810edeaccc9737)
saturation = 0.9
regionPal = ['#ff6e61', '#ffb84d', '#6d9dc5', '#5e4b8b']
pairPal = [regionPal[idx] for idx in (0, 2)]

# Preview the two (desaturated) colours used for the Inc/Dec bars
previewColors = sns.color_palette(pairPal, n_colors=len(pairPal), desat=saturation)
sns.palplot(previewColors, size = .75)

### 4.1 Direction of Modulation in MUA

In [None]:
fig, ax = plt.subplots(1, 1, figsize = (1.5, 3))
plotPal = pairPal

# Parse data, convert to percentage of all MUA channels
stim_inc_pct = (stim_inc / n_chans) * 100
stim_dec_pct = (stim_dec / n_chans) * 100

# Create barplot. x/y are given explicitly: seaborn >= 0.13 treats a bare list of
# scalars passed as `data` as one flat vector and would draw a single bar (their mean).
sns.barplot(x = ['Inc', 'Dec'], y = [stim_inc_pct, stim_dec_pct], palette = plotPal, saturation = saturation, ax = ax)

# Figure aesthetics
plt.setp(ax.patches, linewidth = 0.5, edgecolor = 'k')
ax.set_ylabel('% Modulated Chans', fontsize = 'large', labelpad = 10)
# Keep the 0-16% axis, but extend it in steps of 4 if a bar would otherwise be clipped
ytop = max(16, 4 * np.ceil(max(stim_inc_pct, stim_dec_pct) / 4))
yticks = np.arange(0, ytop + .1, 4)
ax.set_yticks(yticks)
ax.set_ylim(0, yticks[-1])
ax.set_xticks([0, 1], ['Inc', 'Dec'])
ax.set_xlabel('Direction\n(Post- vs. Pre-)', fontsize = 'large', labelpad = 10)
sns.despine(ax = ax, top = True, right = True)

# Export & Display (create the Figures folder on a fresh results tree)
fig_dir = os.path.join(results_dir, 'Group', 'Figures')
os.makedirs(fig_dir, exist_ok = True)
fig.savefig(os.path.join(fig_dir, 'MUASummary.pdf'), dpi = 1200, bbox_inches = 'tight')
plt.show()

# Test for significance using fisher exact test (Inc vs. Dec counts out of all channels).
# NOTE(review): every channel contributes to both columns, so the columns are not the
# independent samples Fisher's test assumes; a binomial test of Inc vs. Dec among the
# significant channels may be more appropriate.
fisher_table = np.array([[stim_inc, stim_dec], [n_chans - stim_inc, n_chans - stim_dec]])
odds, p = fisher_exact(fisher_table)
print("Odds Ratio (Inc/Dec): %.2f, p = %.3f" %(odds, p))

### 4.2 Modulation Direction x Region

In [None]:
# Parse data: significant Stim channels, split by direction of modulation
stim_mask = mua_stats['Condition'] == 'Stim'
sig_mask = stim_mask & (mua_stats['Res_Perm'] == True)
sig_inc = mua_stats[sig_mask & (mua_stats['Direction'] == 'Inc')]
sig_dec = mua_stats[sig_mask & (mua_stats['Direction'] == 'Dec')]

region_inc = sig_inc['Region'].value_counts()
region_dec = sig_dec['Region'].value_counts()

# Scale values by n_chans: bars show % of ALL MUA channels (not % within each region)
region_inc_pct = ((region_inc / n_chans) * 100).reindex(plotOrder, fill_value=0)
region_dec_pct = ((region_dec / n_chans) * 100).reindex(plotOrder, fill_value=0)
plotDF = (pd.DataFrame({'Inc': region_inc_pct, 'Dec': region_dec_pct})
          .rename_axis('Region')
          .reset_index()
          .melt(id_vars='Region', var_name='Direction', value_name='Percent'))

fig, ax = plt.subplots(1, 1, figsize = (4, 3))
plotPal = pairPal

sns.barplot(x = 'Region', y = 'Percent', data = plotDF, order = plotOrder, hue = 'Direction', palette = [plotPal[0], plotPal[1]], saturation = saturation, ax = ax)

# Figure aesthetics. Legend swatches are desaturated like the bars so the colours match.
handles = [Rectangle((0, 0), 1, 1, color = sns.desaturate(color, saturation)) for color in plotPal[:2]]
ax.legend(title = 'Direction', handles = handles, labels = ['Inc', 'Dec'], title_fontsize = 'medium', fontsize = 'small', bbox_to_anchor = (1, 1))
plt.setp(ax.patches, linewidth = 0.5, edgecolor = 'k')
ax.set_ylabel('% Modulated Chans', fontsize = 'large', labelpad = 10)
# Keep the 0-6% axis, but extend it in steps of 2 if a bar would otherwise be clipped
ytop = max(6, 2 * np.ceil(plotDF['Percent'].max() / 2))
yticks = np.arange(0, ytop + 1, 2)
ax.set_yticks(yticks)
ax.set_ylim(0, yticks[-1])
sns.despine(ax = ax, top = True, right = True)
ax.set_xlabel('Region', fontsize = 'large', labelpad = 10)

# Export & Display (create the Figures folder on a fresh results tree)
fig_dir = os.path.join(results_dir, 'Group', 'Figures')
os.makedirs(fig_dir, exist_ok = True)
fig.savefig(os.path.join(fig_dir, 'MUARegion.pdf'), dpi = 1200, bbox_inches = 'tight')
plt.show()

# For each region, run a Fisher's exact test to compare the proportion of inc/dec channels.
# The "not modulated" row uses the region's own channel count, not n_chans (all regions),
# so each test has the correct sample size.
# NOTE(review): as in the summary test, each channel contributes to both columns, so the
# columns are not independent samples.
stim_stats = mua_stats[stim_mask]
for region in plotOrder:
    n_region_inc = int((sig_inc['Region'] == region).sum())
    n_region_dec = int((sig_dec['Region'] == region).sum())
    n_region = stim_stats[stim_stats['Region'] == region].groupby(['pID', 'Chan']).ngroups

    # Create contingency table
    fisher_table = np.array([[n_region_inc, n_region_dec], [n_region - n_region_inc, n_region - n_region_dec]])

    # Run Fisher's exact test
    odds, p = fisher_exact(fisher_table)

    # Print results
    print("Region: %s, Odds Ratio (Inc/Dec): %.2f, p = %.3f" %(region, odds, p))