# FBCCA OFFLINE ANALYSIS 

Filter-bank canonical correlation analysis (FBCCA) applied to our EEG recordings. This work is adapted from the following repository: https://github.com/eugeneALU/CECNL_RealTimeBCI/

In [4]:
import sys
import os
import zipfile
import numpy as np
import scipy
import pandas as pd
import numpy as np
import warnings
import itertools

from sklearn.cross_decomposition import CCA

from scipy.stats import pearsonr
from scipy.signal import butter, filtfilt, iirnotch

## Helper Functions

In [330]:
def ingest_eeg(csvname, flicker_freq, n_ch=8, total_trial_len=1114, data_dir=None):
    """Load one recording CSV and epoch it into a 4-D array of trials.

    Every row whose ``Frequency`` value is neither NaN nor 0 is treated as the
    onset of one trial of that flicker frequency. A trial is the
    ``total_trial_len`` rows starting at its onset, restricted to the ``n_ch``
    columns that follow the first column.

    Args:
        csvname (str): recording name, without the ``.csv`` extension.
        flicker_freq: ignored; the frequencies are read from the CSV itself.
            Kept only so existing calls keep working.
        n_ch (int): number of channel columns (columns ``1 .. n_ch``).
        total_trial_len (int): number of rows (samples) per trial.
        data_dir (str or None): directory holding the CSV. Defaults to the
            ``data`` folder inside the parent of the current working directory.

    Returns:
        tuple: ``(eeg, flicker_freq)``. ``eeg`` has shape
        ``(num_classes, n_ch, total_trial_len, num_trials)`` and class ``i``
        corresponds to ``flicker_freq[i]``, the sorted array of distinct
        frequencies. ``num_trials`` is the smallest trial count over all
        frequencies; surplus trials of other frequencies are dropped.

    Raises:
        ValueError: if the CSV contains no trial onsets, has fewer than
            ``n_ch`` channel columns, or a trial would run past the last row.
    """
    if data_dir is None:
        data_dir = os.path.join(os.path.split(os.getcwd())[0], 'data')
    path = os.path.join(data_dir, csvname + '.csv')
    df = pd.read_csv(path)

    # Map each flicker frequency to the (positional) rows where its trials start.
    onsets = {}
    for row, freq_point in enumerate(df['Frequency']):
        if np.isnan(freq_point) or freq_point == 0:
            continue
        onsets.setdefault(freq_point, []).append(row)

    if not onsets:
        raise ValueError(f'No trial onsets found in {path!r}')

    freqs = np.sort(np.array(list(onsets)))

    # Use the same number of trials for every class so the array is rectangular.
    num_trials = min(len(onsets[freq]) for freq in freqs)

    # Column 0 is skipped (not a channel); the next n_ch columns are channels.
    channels = df.iloc[:, 1:1 + n_ch].to_numpy()
    if channels.shape[1] != n_ch:
        raise ValueError(
            f'{path!r} has {channels.shape[1]} channel columns, expected {n_ch}')

    n_rows = len(df)
    eeg = np.zeros((len(freqs), n_ch, total_trial_len, num_trials))
    for cls, freq in enumerate(freqs):
        for trial, start in enumerate(onsets[freq][:num_trials]):
            end = start + total_trial_len
            if end > n_rows:
                raise ValueError(
                    f'Trial {trial} of {freq} Hz starts at row {start} and needs '
                    f'{total_trial_len} rows, but {path!r} has only {n_rows} rows')
            # Rows are samples and columns are channels; transpose to (channel, sample).
            eeg[cls, :, :, trial] = channels[start:end].T

    return eeg, freqs

In [332]:
def iir_notch_filter(data, f0, Q, fs, num_harmonics=1):
    '''
    Return data with a narrow band around f0 (and optionally its harmonics)
    removed by a zero-phase IIR notch filter.

    Each notch is a second-order ``scipy.signal.iirnotch`` filter applied
    forward and backward with ``filtfilt`` along the last axis of ``data``.

    Args:
        data (numpy.ndarray): array of samples; filtered along its last axis.
        f0 (float): frequency to eliminate (Hz).
        Q (float): quality factor; each notch's -3 dB width is its centre
            frequency divided by Q.
        fs (float): sampling rate (Hz).
        num_harmonics (int): number of multiples of f0 to notch
            (f0, 2*f0, ..., num_harmonics*f0). Harmonics at or above the
            Nyquist frequency fs/2 are skipped. The default, 1, notches only f0.
    Returns:
        (numpy.ndarray): data with powerline interference removed, same shape
        as ``data``.
    Raises:
        ValueError: if num_harmonics < 1 (or from iirnotch for an invalid f0).
    '''
    if num_harmonics < 1:
        raise ValueError('num_harmonics must be >= 1')

    y = data
    nyquist = fs / 2
    for k in range(1, num_harmonics + 1):
        notch_freq = k * f0
        # The fundamental is always attempted (as before); only extra
        # harmonics that cannot be represented at this sampling rate are skipped.
        if k > 1 and notch_freq >= nyquist:
            break
        b, a = iirnotch(notch_freq, Q, fs)
        y = filtfilt(b, a, y)
    return y

In [333]:
def get_filtered_eeg(eeg, quality, sample_rate, notch_freq=60):
    """Notch-filter every trial of an epoched EEG array.

    Args:
        eeg (numpy.ndarray): 4-D array indexed as
            ``(class, channel, sample, trial)``; filtering runs along the
            sample axis (axis 2).
        quality (float): quality factor ``Q`` of the notch.
        sample_rate (float): sampling rate (Hz).
        notch_freq (float): frequency to remove (Hz). Defaults to 60, the
            value previously hard-coded.

    Returns:
        numpy.ndarray: float array with the same shape as ``eeg``.

    Raises:
        ValueError: if ``eeg`` is not 4-D.
    """
    eeg = np.asarray(eeg, dtype=float)
    if eeg.ndim != 4:
        raise ValueError(
            f'eeg must be (class, channel, sample, trial), got shape {eeg.shape}')

    # iir_notch_filter filters along the last axis, so move the sample axis
    # there, filter every (class, channel, trial) signal in one call, and move
    # it back. This replaces a Python loop over all three other axes.
    samples_last = np.moveaxis(eeg, 2, -1)
    filtered = iir_notch_filter(samples_last, notch_freq, quality, sample_rate)
    return np.ascontiguousarray(np.moveaxis(filtered, -1, 2))

In [334]:
def filterbank(eeg, fs, idx_fb):
    """Band-pass ``eeg`` with sub-band ``idx_fb`` of the FBCCA filter bank.

    Sub-band ``idx_fb`` has passband ``passband[idx_fb]``-90 Hz and stopband
    edges ``stopband[idx_fb]`` Hz and 100 Hz. It is realised as a Chebyshev
    type I filter (0.5 dB ripple) applied forward and backward, so the output
    has zero phase shift. The 100 Hz stop edge must lie below the Nyquist
    frequency ``fs / 2``, i.e. ``fs`` should exceed 200 Hz.

    Args:
        eeg (numpy.ndarray): ``(channels, samples)`` or
            ``(channels, samples, trials)``.
        fs (float): sampling rate (Hz).
        idx_fb (int or None): sub-band index in ``0..9``. ``None`` issues a
            warning and uses 0.

    Returns:
        numpy.ndarray: filtered data with the same shape as ``eeg``.

    Raises:
        ValueError: if ``idx_fb`` is outside ``0..9`` or ``eeg`` is not 2-D/3-D.
    """
    if idx_fb is None:
        warnings.warn('stats:filterbank:MissingInput '
                      + 'Missing filter index. Default value (idx_fb = 0) will be used.')
        idx_fb = 0
    elif idx_fb < 0 or 9 < idx_fb:
        raise ValueError('stats:filterbank:InvalidInput '
                         + 'The number of sub-bands must be 0 <= idx_fb <= 9.')

    if eeg.ndim not in (2, 3):
        raise ValueError(
            f'eeg must be (channels, samples[, trials]), got shape {eeg.shape}')

    nyq = fs / 2  # Nyquist frequency

    passband = [6, 14, 22, 30, 38, 46, 54, 62, 70, 78]
    stopband = [4, 10, 16, 24, 32, 40, 48, 56, 64, 72]
    wp = [passband[idx_fb] / nyq, 90 / nyq]
    ws = [stopband[idx_fb] / nyq, 100 / nyq]
    # Lowest order with <= 3 dB passband loss and >= 40 dB stopband attenuation.
    order, wn = scipy.signal.cheb1ord(wp, ws, 3, 40)
    b, a = scipy.signal.cheby1(order, 0.5, wn, 'bandpass')

    # Samples are axis 1 in both accepted layouts, so one call filters every
    # channel (and trial). The explicit padlen reproduces MATLAB's filtfilt.
    return scipy.signal.filtfilt(b, a, eeg, axis=1, padtype='odd',
                                 padlen=3 * (max(len(b), len(a)) - 1))

In [335]:
def cca_reference(list_freqs, fs, num_smpls, num_harms=3):
    """Build sine/cosine reference signals for CCA.

    Args:
        list_freqs: stimulus frequencies (Hz).
        fs: sampling rate (Hz).
        num_smpls: number of samples per reference signal.
        num_harms: number of harmonics per frequency.

    Returns:
        numpy.ndarray of shape ``(len(list_freqs), 2 * num_harms, num_smpls)``.
        For frequency ``f`` and harmonic ``h`` (1-based), row ``2*(h-1)`` holds
        ``sin(2*pi*t*h*f)`` and row ``2*(h-1)+1`` holds ``cos(2*pi*t*h*f)``,
        where ``t = [1, 2, ..., num_smpls] / fs``.
    """
    t = np.arange(1, num_smpls + 1) / fs
    refs = np.zeros((len(list_freqs), 2 * num_harms, num_smpls))

    for row, freq in enumerate(list_freqs):
        for harm in range(1, num_harms + 1):
            phase = 2 * np.pi * t * harm * freq
            base = 2 * (harm - 1)
            refs[row, base] = np.sin(phase)
            refs[row, base + 1] = np.cos(phase)

    return refs


In [336]:
def fbcca(eeg, list_freqs, fs, num_harms, num_fbs):
    """Classify every epoch of ``eeg`` with filter-bank CCA.

    Each epoch is passed through sub-bands ``0 .. num_fbs-1`` of
    ``filterbank``. In every sub-band, a one-component CCA is fitted against
    the sine/cosine references (from ``cca_reference``) of every candidate
    frequency in ``list_freqs``. The Pearson correlation of the canonical
    variates is recorded. Correlations are summed over sub-bands with
    weights ``n**-1.25 + 0.25`` (``n`` = 1-based sub-band number), and the
    candidate with the largest weighted sum is the prediction.

    Args:
        eeg (numpy.ndarray): ``(num_targs, num_chan, num_smpls, num_trials)``;
            ``eeg[t, :, :, k]`` is trial ``k`` of class ``t``.
        list_freqs (sequence of float): candidate stimulus frequencies (Hz).
        fs (float): sampling rate (Hz).
        num_harms (int): harmonics per reference signal.
        num_fbs (int): number of filter-bank sub-bands (``filterbank``
            accepts indices 0..9, so at most 10).

    Returns:
        numpy.ndarray: float array ``(num_targs, num_trials)`` whose entry
        ``[t, k]`` is the index into ``list_freqs`` predicted for
        ``eeg[t, :, :, k]``.
    """
    fb_coefs = np.power(np.arange(1, num_fbs + 1), (-1.25)) + 0.25

    num_targs, _, num_smpls, num_trials = eeg.shape
    # Candidates come from list_freqs, which need not match the number of
    # EEG classes (num_targs).
    num_freqs = len(list_freqs)

    y_ref = cca_reference(list_freqs, fs, num_smpls, num_harms)
    cca = CCA(n_components=1)

    results = np.zeros((num_targs, num_trials))

    for trial in range(num_trials):
        for targ_i in range(num_targs):
            # No squeeze: keep the epoch 2-D even for single-channel data.
            epoch = eeg[targ_i, :, :, trial]  # (num_chan, num_smpls)

            r = np.zeros((num_fbs, num_freqs))
            for fb_i in range(num_fbs):
                testdata = filterbank(epoch, fs, fb_i)
                for class_i in range(num_freqs):
                    # sklearn expects (observations, features): samples are
                    # observations, so transpose both sets.
                    test_c, ref_c = cca.fit_transform(testdata.T, y_ref[class_i].T)
                    r[fb_i, class_i], _ = pearsonr(test_c[:, 0], ref_c[:, 0])

            rho = fb_coefs @ r  # weighted sum of correlations over sub-bands
            results[targ_i, trial] = np.argmax(rho)

    return results

## Parameters

In [337]:
# Fixed acquisition / spectral parameters for this analysis.
FFT_PARAMS = dict(
    resolution=0.2930,
    start_frequency=0.0,
    end_frequency=35.0,
    sampling_rate=250,
)

# Placeholder argument for ingest_eeg, which reads the frequencies from the CSV.
flicker_freq = []

num_stims = 4  # number of stimuli present in the data; change to match the recording
num_harms = 3  # harmonics included in each CCA reference signal
num_fbs = 5    # number of filter-bank sub-bands used by FBCCA
Q = 100        # quality factor of the powerline notch filter

sample_rate = FFT_PARAMS['sampling_rate']

## Data Ingestion

In [349]:
# REPLACE CSV NAMES WITH THE RECORDINGS OF INTEREST
# (this cell and the validation cell below expect exactly three recordings)

#csvnames = ['174_2022_159090', '174_2022_445753', '174_2022_538724'] #Bryan

#csvnames = ['174_2022_040508', '174_2022_123780', '174_2022_729377'] #Chris

#csvnames = ['173_2022_515272'] #Avery


csvnames = ['174_2022_159090', '174_2022_445753', '174_2022_538724']

if len(csvnames) != 3:
    raise ValueError(f'Expected exactly 3 recordings, got {len(csvnames)}')

print(csvnames[0])
eeg1, flicker_freq1 = ingest_eeg(str(csvnames[0]), flicker_freq)
eeg2, flicker_freq2 = ingest_eeg(str(csvnames[1]), flicker_freq)
eeg3, flicker_freq3 = ingest_eeg(str(csvnames[2]), flicker_freq)

# Stacking along the trials axis is only meaningful if class i is the same
# flicker frequency in every recording.
for name, other in ((csvnames[1], flicker_freq2), (csvnames[2], flicker_freq3)):
    if not np.array_equal(flicker_freq1, other):
        raise ValueError(f'{name} has flicker frequencies {list(other)}, '
                         f'but {csvnames[0]} has {list(flicker_freq1)}')

# combines epoched data across all csvs along trials axis
eeg = np.concatenate((eeg1, eeg2, eeg3), axis=3)

eeg.shape

174_2022_159090


(4, 8, 1114, 15)

In [355]:
# Validate that every recording yields the same sorted flicker frequencies, so
# class i denotes the same stimulus in eeg1, eeg2 and eeg3.
list_freq1 = list(flicker_freq1)
list_freq2 = list(flicker_freq2)
list_freq3 = list(flicker_freq3)
print(list_freq1)
print(list_freq2)
print(list_freq3)
if not (list_freq1 == list_freq2 == list_freq3):
    raise ValueError('Recordings disagree on flicker frequencies; '
                     'class order would not match after concatenation.')

[10.25, 11.75, 12.75, 14.75]
[10.25, 11.75, 12.75, 14.75]
[10.25, 11.75, 12.75, 14.75]


## Powerline Removal

In [356]:
#wrapper function for EEG data filtering with 4th order notch
filtered_data = get_filtered_eeg(eeg, Q, sample_rate)
filtered_data.shape #(classes, channels, # of samples, # of trials)

(4, 8, 1114, 15)

## FBCCA Execution and Evaluation

In [357]:
# Predict a class index for every (class, trial) epoch, using the validated
# frequency list as the candidate set.
results = fbcca(
    filtered_data,
    list_freqs=list_freq1,
    fs=sample_rate,
    num_harms=num_harms,
    num_fbs=num_fbs,
)

In [358]:
"""
example of a perfect result for a 5 trial dataset:

[[0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1.]
 [2. 2. 2. 2. 2.]
 [3. 3. 3. 3. 3.]]
 
"""
print(results)
results.shape

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [3. 3. 3. 3. 0. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]


(4, 15)