# CCEP Analyses

Loads preprocessed data (created in `CCEPPrepro.ipynb`), removes bad channels, z-scores the post-stimulation response, and extracts CCEP features.


---
> Justin Campbell & Krista Wahlstrom  
> Version: 3/05/2024

## 1. Setup

In [1]:
# Import libraries
import os
import mne
import sys
import glob
import scipy.io
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.signal import iirnotch, lfilter, sosfilt, butter

# Notebook settings
%matplotlib inline
%config InlineBackend.figure_format='retina'

## 2. Data Cleaning

### 2.1 Load Preprocessed Data

In [2]:
# Session to analyze
# pID: session identifier; a 'UIC' prefix selects the Utah dataset, any other prefix selects WashU
pID = 'UIC202208'
# stimPair: label of the stimulation pair; joined with pID ('<pID>_<stimPair>') to name the Prepro and Results folders
stimPair = 'RAMG2-RAMG3'

In [3]:
# Define paths and find datafiles
#rootDir = '/Users/justincampbell/Library/CloudStorage/Box-Box/INMANLab/BCI2000/BLAES Aim 2.1/CCEPs/' # Justin's path
rootDir = '/Users/inmanlab/Library/CloudStorage/Box-Box/INMANLab/BCI2000/BLAES Aim 2.1/CCEPs/' #Krista's path

# IDs prefixed with 'UIC' are Utah recordings; everything else is treated as WashU
fileType = 'Utah' if pID.startswith('UIC') else 'WashU'
datapath = os.path.join(rootDir, 'Data', fileType + '_Data')

# Both the preprocessed input folder and the results folder are named '<pID>_<stimPair>'
sessionName = pID + '_' + stimPair
preproDataPath = os.path.join(datapath, 'Prepro', sessionName)
savepath = os.path.join(rootDir, 'Results', sessionName)

# Create results folder if it doesn't exist. makedirs also creates a missing
# 'Results' parent and does not fail if the folder already exists.
os.makedirs(savepath, exist_ok=True)

# Set Fs: the sampling rate (Hz) is fixed per recording site
fs = {'Utah': 1000, 'WashU': 2000}[fileType]

try:
    # Load preprocessed data
    data = np.load(os.path.join(preproDataPath, 'PreproData.npy'))
    events = np.load(os.path.join(preproDataPath, 'Events.npy'))
    chans = pd.read_csv(os.path.join(preproDataPath, 'ChanLabels.csv'), index_col=0)['Chan'].to_list()
    bad_chans = pd.read_csv(os.path.join(preproDataPath, 'DroppedChans.csv'), index_col=0)['Dropped Chans'].to_list()
except Exception:
    # Announce the failure, then re-raise so the traceback shows the actual cause.
    # 'except Exception' lets KeyboardInterrupt/SystemExit pass through untouched.
    print('Error loading data')
    raise

#### 2.1.1 Remove Bad Channels
The bad channels to remove are listed in `DroppedChans.csv`.

In [4]:
# get indices of bad channels in chans
# Scanning chans (rather than calling chans.index for each bad channel) skips
# labels that are not present, so re-running this cell after the channels were
# already removed is a no-op instead of a ValueError. It also catches every
# occurrence of a label, keeping data and chans aligned.
bad_chan_set = set(bad_chans)
bad_chan_inds = [idx for idx, chan in enumerate(chans) if chan in bad_chan_set]

# remove bad channels from data (axis 1 is indexed in the same order as chans)
data = np.delete(data, bad_chan_inds, axis=1)

# remove bad channels from chans
chans = [chan for chan in chans if chan not in bad_chan_set]

In [5]:
# Report the dimensions of the cleaned dataset: trial count, retained channels,
# and epoch duration in seconds (samples along axis 2 divided by fs)
summary = (
    ('# Trials: %i', len(events)),
    ('# Chans: %i', len(chans)),
    ('Epoch Length: %.2fs', data.shape[2]/fs),
)
for template, value in summary:
    print(template % value)

# Trials: 27
# Chans: 86
Epoch Length: 1.80s


### 2.2 Z-Score Post-Stim Response

\begin{equation}
Z_{trial} = \frac{x_{trial} - \mu_{pre}}{\sigma_{pre}}
\end{equation}

In [6]:
# Separate data into pre and post stim (the epoch is split at its midpoint along axis 2)
midpoint = data.shape[2] // 2
preStim = data[:, :, :midpoint]
postStim = data[:, :, midpoint:]

# Define baseline (in pre-stim data): a window of fs/2 samples (500 ms) that
# ends fs/100 samples (10 ms) before the end of the pre-stim segment
baselineEnd = int(fs / 100)
baselineStart = int((data.shape[2] / 2) - baselineEnd - (fs/2))
if baselineStart < 0:
    # A negative start index would wrap around and silently select the wrong samples
    raise ValueError('Pre-stim segment (%i samples) is too short for the 500 ms baseline window' % midpoint)
baseline = preStim[:, :, baselineStart:-baselineEnd]

# Get mean and SD of baseline (over time) -> one value per (trial, channel)
baselineMean = np.mean(baseline, axis=2)
baselineSD = np.std(baseline, axis=2)

# Normalize data: broadcast each (trial, channel) baseline statistic across time.
# The result is stored as float64, matching the previous np.zeros-based buffer.
# NOTE(review): a channel with a perfectly flat baseline has SD == 0 and yields inf/nan here.
postStimZ = ((postStim - baselineMean[:, :, np.newaxis]) / baselineSD[:, :, np.newaxis]).astype(np.float64, copy=False)

# Get trial-averaged responses
postStimAvg = np.mean(postStimZ, axis=0)

# Calculate SEM (across trials)
# NOTE(review): np.std defaults to ddof=0 (population SD); confirm whether the sample SD (ddof=1) is intended.
postStimSEM = np.std(postStimZ, axis=0) / np.sqrt(postStimZ.shape[0])

## 3. Feature Extraction

### 3.1 Amplitude-Based Features

In [7]:
# Define windows of interest (in ms; converted below to sample offsets into the post-stim data)
N1Window = [10, 50]
N2Window = [50, 400]

def _msToSamples(ms):
    '''Convert a duration in milliseconds to a whole number of samples at fs.'''
    return int(ms * fs / 1000)

# Sample bounds of each analysis window
N1Start, N1End = map(_msToSamples, N1Window)
N2Start, N2End = map(_msToSamples, N2Window)
ccepStart = _msToSamples(10)
midAmpEnd = _msToSamples(100)

In [8]:
def plotCCEP(chanIdx, showFeatures = True, showPlot = False, export = False):
    '''
    This function plots the trial-averaged, z-scored CCEP waveform (with a +/- SEM band) for a given
    channel index. It also calculates the following features from the absolute value of the waveform:
    - N1: The peak amplitude between 10 and 50ms (N1Start:N1End)
    - N2: The peak amplitude between 50 and 400ms (N2Start:N2End)
    - Overall Peak: The overall peak amplitude from ccepStart to the end of the window
    - Mid Amp: The peak amplitude between 10 and 100ms (ccepStart:midAmpEnd)
    - AUC: The trapezoidal area under the absolute waveform from ccepStart onward
      (computed with unit sample spacing, so its magnitude scales with fs)

    Inputs:
    - chanIdx: The index of the channel to plot (row of postStimAvg / entry of chans)
    - showFeatures: If True, mark each peak with a vertical line and list the features in the legend;
      if False, list the features in a text box beside the axes (default = True)
    - showPlot: If True, display the figure; if False, close it after the optional export (default = False)
    - export: Whether or not to save the plot to <savepath>/Magnitude/<channel>.png (default = False)

    Outputs:
    - None. The figure is displayed and/or written to disk as requested.
    '''

    trace = postStimAvg[chanIdx, :]
    sem = postStimSEM[chanIdx, :]
    rectified = np.abs(trace)

    def findPeak(start, stop = None):
        # Largest absolute deflection in rectified[start:stop] -> (sample index, magnitude, latency in ms)
        idx = np.argmax(rectified[start:stop]) + start
        return idx, rectified[idx], np.round((idx * 1000) / fs, 2)

    def featureLabel(name, val, lat):
        # e.g. 'N1: 3.21 (24.0ms)'
        return name + ': ' + str(round(val, 2)) + ' (' + str(lat) + 'ms)'

    # Get features: (display name, line color, window start, window stop) per peak
    peakSpecs = [
        ('N1', 'r', N1Start, N1End),
        ('N2', 'purple', N2Start, N2End),
        ('Overall', 'g', ccepStart, None),
        ('Mid Amp', 'orange', ccepStart, midAmpEnd),
    ]
    peaks = []
    for name, color, start, stop in peakSpecs:
        idx, val, lat = findPeak(start, stop)
        peaks.append((idx, color, featureLabel(name, val, lat)))

    # Calculate AUC. np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever is available.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    ccepAUCVal = trapezoid(rectified[ccepStart:])
    aucLabel = 'AUC: ' + str(round(ccepAUCVal, 2))

    # Create time vector (in samples)
    time = np.arange(0, trace.shape[0])

    # Plotting
    fig, ax = plt.subplots(figsize=(8, 4))
    plt.plot(time, trace, color = 'blue', zorder = 10, lw = 2)
    plt.fill_between(time, trace + sem, trace - sem, color = 'blue', alpha = 0.125)

    if showFeatures:
        for idx, color, label in peaks:
            plt.axvline(idx, color = color, lw = 2, linestyle = ':', label = label, zorder = 20)
        # Drawn at x = -1 (left of the visible range) purely to add the AUC entry to the legend
        plt.axvline(-1, color='k', lw = 2, linestyle = ':', label = aucLabel)
        plt.legend(title = 'Features', title_fontsize = 'small', fontsize = 'x-small', bbox_to_anchor=(1.3, 1))
    else:
        featureStr = 'Features:\n' + '\n'.join([label for _, _, label in peaks] + [aucLabel])
        plt.text(1.05, .75, featureStr, fontsize = 'x-small', bbox = {'boxstyle': 'round', 'ec': (.5, 0.5, 0.5), 'fc': (1., 1., 1.)}, transform=ax.transAxes)

    # Figure aesthetics
    # Shade the first 10 ms, which are excluded from feature extraction. The x-axis is in
    # samples, so the span must end at ccepStart rather than at a fixed 10 samples.
    plt.axvspan(0, ccepStart, color='r', alpha=0.1)
    plt.text(1.05, 1.025, pID + '\n' + stimPair, fontsize= 'x-small', verticalalignment='center', transform=ax.transAxes)

    # Place a tick every 100 ms at its sample position, instead of relabelling whatever
    # ticks the auto-locator happened to choose
    tickMs = np.arange(0, 1000, 100)
    plt.xticks(tickMs * fs / 1000, labels = [str(ms) for ms in tickMs])
    plt.xlim(0, int(900 * fs / 1000))
    plt.title(chans[chanIdx])
    sns.despine(top=True, right=True)
    plt.xlabel('Time (ms)')
    plt.ylabel('Z-Scored (Amplitude)')

    # Export
    if export:
        magDir = os.path.join(savepath, 'Magnitude')
        os.makedirs(magDir, exist_ok=True)
        plt.savefig(os.path.join(magDir, chans[chanIdx] + '.png'), dpi=1200, bbox_inches='tight')
    if showPlot:
        plt.show()
    else:
        plt.close(fig)


In [None]:
def createFeatureDF(export = False):
    '''
    This function creates a dataframe of features for all channels in the CCEP data.

    For every channel in chans, the features are taken from the absolute value of the
    trial-averaged waveform (postStimAvg):
    - N1 / N2 / Overall Peak / Mid Amp: peak magnitude within the respective sample window
    - *_Lat: latency of that peak in ms (sample index * 1000 / fs, rounded to 2 decimals)
    - AUC: trapezoidal area from ccepStart onward (unit sample spacing)
    Each row also carries the channel label, pID and stimPair.

    Inputs:
    - export: Whether or not to export the dataframe to a CSV file
      (<savepath>/Magnitude/FeatureDF.csv) (default = False)

    Outputs:
    - featureDF: A dataframe of features for all channels in the CCEP data (one row per
      channel; empty, but with all columns, when chans is empty)
    '''

    columns = ['N1', 'N1_Lat', 'N2', 'N2_Lat', 'Overall Peak', 'Overall Peak_Lat',
               'Mid Amp', 'Mid Amp_Lat', 'AUC', 'Chan', 'pID', 'StimPair']

    # (column name, window start, window stop) for each peak feature
    peakWindows = [
        ('N1', N1Start, N1End),
        ('N2', N2Start, N2End),
        ('Overall Peak', ccepStart, None),
        ('Mid Amp', ccepStart, midAmpEnd),
    ]

    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever is available
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    rows = []
    for chanIdx, chanName in enumerate(chans):
        rectified = np.abs(postStimAvg[chanIdx, :])

        # Get features
        row = {}
        for name, start, stop in peakWindows:
            peakIdx = np.argmax(rectified[start:stop]) + start # Index of the max value inside the window
            row[name] = rectified[peakIdx]
            row[name + '_Lat'] = np.round((peakIdx * 1000) / fs, 2) # Convert samples to milliseconds

        row['AUC'] = trapezoid(rectified[ccepStart:]) # Calculate AUC
        row['Chan'] = chanName
        row['pID'] = pID
        row['StimPair'] = stimPair
        rows.append(row)

    # Build the table in one step; passing columns keeps the schema even with zero channels
    featureDF = pd.DataFrame(rows, columns=columns)

    if export:
        magDir = os.path.join(savepath, 'Magnitude')
        os.makedirs(magDir, exist_ok=True)
        featureDF.to_csv(os.path.join(magDir, 'FeatureDF.csv'), index=False)

    return featureDF

#### 3.1.1 Plot CCEPs & Features

In [9]:
# Render one feature-annotated figure per retained channel and save it to disk
# (figures are closed rather than shown, since showPlot = False)
for chanIdx, _ in enumerate(chans):
    plotCCEP(chanIdx, showFeatures = True, showPlot = False, export = True)

# for single-channel plotting:
# chanIdx = 0
# plotCCEP(chanIdx, showFeatures = True, showPlot = True, export = False)

#### 3.1.2 Create `FeatureDF` File

In [None]:
# Build the per-channel feature table and write it to <savepath>/Magnitude/FeatureDF.csv.
# As the last expression of the notebook cell, the returned DataFrame is also shown as the cell output.
createFeatureDF(export = True)

### 3.2 Spectral Features
- HFA (70 - 150Hz)