In [1]:
import pyxdf
import matplotlib.pyplot as plt
import numpy as np
import math

from scipy import signal
from scipy import linalg

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

dr = '../Recordings/Speller/'
# file_1 = 'run_03.xdf'
# file_2 = 'run_04.xdf'
file_3 = 'run_05.xdf'

# file 'run_03' contains data for spelling 'QUICK'
# file 'run_04' contains data for spelling 'P3EEG'

In [2]:
def speller_file_to_dict(data):
    run = {
        'data': [],
        'time_stamps': [],
        'aux': [],
        'aux_time_stamps': [],
        'markers': [],
        'marker_time_stamps': [],
        'sample_rate': {},
    }
    for stream in data:
        if stream['info']['name'][0] == 'eeg_data':
            run['data'] = stream['time_series']
            run['time_stamps'] = stream['time_stamps']
            run['sample_rate']['eeg'] = float(stream['info']['nominal_srate'][0])
        elif stream['info']['name'][0] == 'aux_data':
            run['aux'] = stream['time_series']
            run['aux_time_stamps'] = stream['time_stamps']
            run['sample_rate']['aux'] = float(stream['info']['nominal_srate'][0])
        elif stream['info']['name'][0] == 'P300_Speller_Markers':
            run['markers'] = stream['time_series']
            run['marker_time_stamps'] = stream['time_stamps']
        else:
            print('Warning unmatched stream name')
            
    return run

data = {
    'run': [],
}

for file_nm in [file_3]:
    imported_data, _ = pyxdf.load_xdf(dr+file_nm)
    run = speller_file_to_dict(imported_data)
    data['run'].append(run)
    
# print(data)

In [3]:
data

{'run': [{'data': array([[32900.27027133, 26717.59894104, 34745.80910752, ...,
            9154.29105214, 42706.44995051, 37145.6712062 ],
          [32906.50640804, 26729.19949641, 34762.66232284, ...,
            9163.63408132, 42751.19814291, 37147.63815971],
          [32907.4451813 , 26704.70198449, 34735.34849111, ...,
            9144.45628458, 42642.61336835, 37142.47490674],
          ...,
          [25536.55511577, 23783.01605976, 37537.20671382, ...,
            5046.77802882, 40079.60588689, 27427.11036528],
          [25534.38699655, 23803.75847861, 37562.19596412, ...,
            5068.21335175, 40191.81164405, 27421.63418789],
          [25542.97006642, 23820.32112125, 37588.21339467, ...,
            5088.73225316, 40280.21279338, 27417.61087389]]),
   'time_stamps': array([500790.56983438, 500790.57383438, 500790.57783439, ...,
          501442.6979396 , 501442.7019396 , 501442.7059396 ]),
   'aux': array([[ 416.,   38., 1016.],
          [ 416.,   37., 1019.],
       

In [4]:
def plot_time_series(X, Y, labels, xlabel=None, ylabel=None, title=None, fig=None, ax=None):
    if fig is None and ax is None:
        fig, ax = plt.subplots(figsize = (10, 3), dpi = 90)

    for i in range(len(Y.T)):
        y = Y[:,i]
        ax.plot(X, y, label=labels[i])
        
    xlabel = xlabel if xlabel is not None else 'Time'
    ylabel = ylabel if ylabel is not None else 'Amplitude'
    title = title if title is not None else 'Time Series'
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    _ = ax.legend()
    
    return fig, ax

def plot_epochs(ch_epochs, xlabel='Times (ms)', ylabel='Epochs', title='Epoch Comparison', tmin=None, tmax=None):
    fig, ax = plt.subplots(figsize = (10, 3), dpi = 90)

    if tmin == None or tmax == None:
        tmin = 0
        tmax = ch_epochs.shape[-1]
    extent = [tmin, tmax, 0, len(ch_epochs)]
    
    plot = ax.matshow(ch_epochs, interpolation='nearest', origin='lower', aspect='auto', extent=extent)

    ax.xaxis.set_ticks_position("bottom")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.axis('auto')
    ax.axis('tight')
    ax.axvline(0, color='k', linewidth=1, linestyle='--')

    cbar = fig.colorbar(plot)
    cbar.ax.set_ylabel(r'Amplitude ($\mu$V)', rotation=270, labelpad=12)
    
    return fig, ax

def add_markers_to_plot(markers, time_stamps, fig, ax):
    for timestamp, marker in zip(time_stamps, markers):
#             if (marker[0] == '1'):
            if '[' in marker[0]:
                ax.axvline(x=timestamp, color='C1')

In [17]:
# =====================
# Re Reference all data
# =====================
def re_reference(session_data, new_ref_ch_idx):
    # returns data with the new reference channel removed... because its zeros
    new_ref = session_data[:,new_ref_ch_idx]
    re_referenced = session_data - new_ref[:,None]
    re_referenced = np.delete(re_referenced, new_ref_ch_idx, axis=1)
    
    return re_referenced

# =====================
# Epochs and windowing
# =====================
def ms_to_samples(duration, rate=250):
    return int((duration/1000.0) * rate)


def get_epoch(data, data_times, stim_time, pre_stim_ms, post_stim_ms, sample_rate):
    # great for grabbing around a stimulus event time.
    
    # TODO: is there a faster way?
    i = (np.abs(data_times - stim_time)).argmin()
    start = i - ms_to_samples(pre_stim_ms, sample_rate)
    stop = i + ms_to_samples(post_stim_ms, sample_rate) + 1
    epoch = data[start:stop]
    
    return epoch


# ==========
# Epoch Prep
# ==========
def dc_offset(epoch_data, sample_rate):
    # epoch_data should be 2D, (samples, channels)
    offset = np.mean(epoch_data, 0)

    return epoch_data - offset


def filter_eeg(epoch_data, sample_rate, f_range):
    # epoch_data should be 2D, (samples, channels)
    nyq = 0.5 * sample_rate
    low_limit = f_range[0] / nyq
    high_limit = f_range[1] / nyq
    
    sos = signal.butter(2, [low_limit, high_limit], btype='bandpass', output='sos')
    sig_filt = signal.sosfiltfilt(sos, epoch_data, axis=0)    
    
    return np.array(sig_filt)

def baseline_center(epoch_data, baseline_duration_ms, sample_rate):
    # epoch_data should be 2D, (samples, channels)
    baseline_samples = ms_to_samples(baseline_duration_ms, sample_rate)
    pre_stimulus = epoch_data[:baseline_samples]
    baseline = np.mean(pre_stimulus, 0)
    
    return epoch_data - baseline


def prepare_epochs(epochs, sample_rate, baseline_duration=100):
    # epochs should be 3D, (event, samples, channels)
    # baseline duration, in ms
    
    prepped = []
    for epoch in epochs:
        temp_data = dc_offset(epoch, sample_rate)
#         temp_data = baseline_center(temp_data, baseline_duration, sample_rate)
        prepped.append(temp_data)
        
    return np.array(prepped)

# ================
# Speller Specific
# ================
ltr_key = ['ABCDEF',
           'GHIJKL',
           'MNOPQR',
           'STUVWX',
           'YZ1234',
           '56789_']

answer_key = {
    'Q': [3, 11],
    'U': [4, 9],
    'I': [2, 9],
    'C': [1, 9],
    'K': [2, 11],
    'P': [3, 10],
    '3': [5, 11],
    'E': [1, 11],
    'G': [2, 7],    
}

ltr_key = np.array(ltr_key)
t = answer_key['Q']
ltr_key[t[0]-1][t[1]-1-len(ltr_key)]

'Q'

In [7]:
# Use mastoid as ref, then avg-ref the second mastoid
# Place the ground at AFz
channels = ['ref', 'FCz', 'Fz', 'Cz', 'Pz', 'Oz', 'CPz', 'Fp1']
aux_chs = ['aux ch_1', 'aux ch_2', 'aux ch_3']

# assigning data... a bit redundant but the hope is to help make the pipelining easier
run_1 = data['run'][0]
# run_2 = data['run'][1]

r1_eeg = run_1['data']
r1_eeg_ts = run_1['time_stamps']
r1_aux = run_1['aux']
r1_aux_ts = run_1['aux_time_stamps']
r1_markers = run_1['markers']
r1_markers_ts = run_1['marker_time_stamps']

# r2_eeg = run_2['data']
# r2_eeg_ts = run_2['time_stamps']
# r2_aux = run_2['aux']
# r2_aux_ts = run_2['aux_time_stamps']
# r2_markers = run_2['markers']
# r2_markers_ts = run_2['marker_time_stamps']

# new_ref_ch_idx = 0
# r1_eeg = re_reference(r1_eeg, new_ref_ch_idx)
# r2_eeg = re_reference(r2_eeg, new_ref_ch_idx)
# del channels[new_ref_ch_idx]
# r1_eeg = r1_eeg[:,new_ref_ch_idx+1:]
# r2_eeg = r2_eeg[:,new_ref_ch_idx+1:]
# print(r1_eeg.shape)

r1_eeg = filter_eeg(r1_eeg, 250, [1.,20.] )
# r2_eeg = filter_eeg(r2_eeg, 250, [1.,20.] )

# remove Fp1 channel
# r1_eeg = np.delete(r1_eeg, -1, axis=1)
# r2_eeg = np.delete(r2_eeg, -1, axis=1)
# del channels[-1]

# convert times to int milliseconds
r1_eeg_ts = (r1_eeg_ts*1000).astype(int)
# r2_eeg_ts = (r2_eeg_ts*1000).astype(int)
r1_aux_ts = (r1_aux_ts*1000).astype(int)
# r2_aux_ts = (r2_aux_ts*1000).astype(int)
r1_markers_ts = (r1_markers_ts*1000).astype(int)
# r2_markers_ts = (r2_markers_ts*1000).astype(int)

print(channels)
pz = channels.index('Pz')
fp1 = channels.index('Fp1')
oz = channels.index('Oz')

['ref', 'FCz', 'Fz', 'Cz', 'Pz', 'Oz', 'CPz', 'Fp1']


In [8]:
# HANDLE PHOTOSENSORS
# -------------------
# scaling for dim display
# clean and prep aux channel 0010010
# get associated markers and time stamps
# get trials of markers and time stamps

def get_ts_for_on_rise_threshold(data, data_ts, threshold):
    thr = (data > threshold) + 0
    rise = thr[1:] > thr[:-1]
    rise = np.insert(rise, 0, False) # accounts for lost position above
    return data_ts[rise]

def divide_markers(markers):
    # rc = row and column
    marker_collection = []
    target_collection = []
    y_collection = []
    for marker_arr in markers:
        marker = marker_arr[0]
        if '[' in marker:
            lbl = marker.split('[')[1].split(']')[0]
            target_collection.append(lbl)
            marker_collection.append([])
            y_collection.append([])
        elif marker.isnumeric():
            marker = int(marker)
            marker_collection[-1].append(marker)
            key = answer_key[target_collection[-1]]
            if marker in key:
                y_collection[-1].append(1)
            else:
                y_collection[-1].append(0)
                
    
    return target_collection, np.array(marker_collection), np.array(y_collection)

def sync_sensor_to_markers(sensor_ts, stim_markers, targets):
    len_a = len(sensor_ts.flatten())
    len_b = len(stim_markers.flatten())
    len_c = len(targets.flatten())
    assert len_a == len_b == len_c, 'Invalid Operation' # dimensions must match
    
    synced = np.vstack((sensor_ts.flatten(),
                        stim_markers.flatten(),
                        targets.flatten()))
    return synced.T

# scaling
# max_run1 = max(r1_aux[:,1])
# max_run2 = max(r2_aux[:,1])
# scaler = max_run2/max_run1
# r1_aux[:,1] = r1_aux[:,1] * scaler


PHOTO_SENSOR_THRESHOLD = 70 # this is based on a visual plot.
r1_photosensor_onsets = get_ts_for_on_rise_threshold(r1_aux[:,1], r1_aux_ts, PHOTO_SENSOR_THRESHOLD)
# r2_photosensor_onsets = get_ts_for_on_rise_threshold(r2_aux[:,1], r2_aux_ts, PHOTO_SENSOR_THRESHOLD)

# TODO: automate this!
# There are 2 test sensor blinks that we can get rid of and there is a
# graphical blink at the very beginning of run 2 that needs to be removed.
r1_photosensor_onsets = r1_photosensor_onsets[2:]
# r2_photosensor_onsets = r2_photosensor_onsets[3:]

r1_target_ltr, r1_stim_markers, r1_y = divide_markers(r1_markers)
# r2_target_ltr, r2_stim_markers, r2_y = divide_markers(r2_markers)

r1_events = sync_sensor_to_markers(r1_photosensor_onsets, r1_stim_markers, r1_y)
# r2_events = sync_sensor_to_markers(r2_photosensor_onsets, r2_stim_markers, r2_y)

print(r1_events[:5])
print(r1_events.shape)

KeyError: 'W'

In [None]:
# TRUNCATE and NORMALIZE
def get_session_mask(eeg_ts, markers_ts, events):
    first_photosensor_time = events[0,0]
    first_marker_time = markers_ts[1] # skip the first...its too early and not useful
    last_photosensor_time = events[-1,0]
    last_marker_time = markers_ts[-1]
    session_win_start = min(first_photosensor_time, first_marker_time) - 1000
    session_win_stop = max(last_photosensor_time, last_marker_time) + 1000
    print(f'session: {session_win_start} to {session_win_stop}')
    mask = (eeg_ts >= session_win_start) & (eeg_ts <= session_win_stop)
    return mask

def get_ts_mask(eeg_ts, start, stop, pre=0, post=0):
    win_start = start - pre
    win_stop = stop + post
    print(f'session: {win_start} to {win_stop}')
    mask = (eeg_ts >= win_start) & (eeg_ts <= win_stop)
    return mask
    
r1_mask = get_session_mask(r1_eeg_ts, r1_markers_ts, r1_events)
r2_mask = get_session_mask(r2_eeg_ts, r2_markers_ts, r2_events)

r1_eeg = r1_eeg[r1_mask]
r2_eeg = r2_eeg[r2_mask]
r1_eeg_ts = r1_eeg_ts[r1_mask]
r2_eeg_ts = r2_eeg_ts[r2_mask]

# normalize
def zero_and_normalize(eeg_data):
    eeg_data = eeg_data - np.mean(eeg_data, 0)
    eeg_data = eeg_data / np.std(eeg_data, 0)
    return eeg_data

r1_eeg = zero_and_normalize(r1_eeg)
r2_eeg = zero_and_normalize(r2_eeg)

In [None]:
# Epoch durations in ms
PRE_STIM = 100
POST_STIM = 800

SAMPLE_RATE = 250.0 # Hz

# Make Epochs
r1_epochs = []
for time in r1_photosensor_onsets:
    # get epoch based on time stamp
    epoch = get_epoch(r1_eeg, r1_eeg_ts, time, PRE_STIM, POST_STIM, SAMPLE_RATE)
    r1_epochs.append(epoch)
    
r2_epochs = []
for time in r2_photosensor_onsets:
    # get epoch based on time stamp
    epoch = get_epoch(r2_eeg, r2_eeg_ts, time, PRE_STIM, POST_STIM, SAMPLE_RATE)
    r2_epochs.append(epoch)
    

# Prep Epochs
r1_epochs = prepare_epochs(r1_epochs, SAMPLE_RATE)
r2_epochs = prepare_epochs(r2_epochs, SAMPLE_RATE)
r1_epochs_ts = r1_photosensor_onsets
r2_epochs_ts = r2_photosensor_onsets

In [None]:
epoch_runs_ts = r1_epochs_ts.reshape(r1_y.shape)

r1_runs = []
r1_runs_ts = []
for run in r1_epochs_ts.reshape(r1_y.shape):
    start_time = run[0]
    stop_time = run[-1] + POST_STIM
    mask = get_ts_mask(r1_eeg_ts, start_time, stop_time, pre=PRE_STIM, post=POST_STIM)
    run = r1_eeg[mask]
    run_ts = r1_eeg_ts[mask]
    r1_runs.append(run)
    r1_runs_ts.append(run_ts)
    
r2_runs = []
r2_runs_ts = []
for run in r2_epochs_ts.reshape(r2_y.shape):
    start_time = run[0]
    stop_time = run[-1] + POST_STIM
    mask = get_ts_mask(r2_eeg_ts, start_time, stop_time, pre=PRE_STIM, post=POST_STIM)
    run = r2_eeg[mask]
    run_ts = r2_eeg_ts[mask]
    r2_runs.append(run)
    r2_runs_ts.append(run_ts)    

In [None]:
y = r1_events[:,2]
target_epochs = r1_epochs[y == 1]
non_target_epochs = r1_epochs[y == 0]

x_for_plot = [(i*(1000/SAMPLE_RATE))-PRE_STIM for i in range(r2_epochs.shape[1])]

target_epoch_avg = np.mean(target_epochs, axis=0)
non_target_epoch_avg = np.mean(non_target_epochs, axis=0)

g_labels = [ch + ' Target' for ch in channels]
fig, ax = plot_time_series(
    x_for_plot,
    target_epoch_avg,
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Target Average'
)
_ = ax.axvspan(300, 500, color='orange', alpha=0.1)

g_labels = [ch + ' Non-Target' for ch in channels]
fig, ax = plot_time_series(
    x_for_plot,
    non_target_epoch_avg,
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Non-Target Average'
)
_ = ax.axvspan(300, 500, color='orange', alpha=0.1)

g_labels = channels
fig, ax = plot_time_series(
    r1_eeg_ts,
    r1_eeg,
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Full Recording'
)
# _ = ax.axvspan(300, 500, color='orange', alpha=0.1)

In [None]:
g_labels = ['Fp1']
fig, ax = plot_time_series(
    r1_eeg_ts,
    r1_eeg[:,fp1, None], # None == np.newaxis, adds axis
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Session 1: Fp1 Full Recording'
)

g_labels = ['Fp1']
fig, ax = plot_time_series(
    r2_eeg_ts,
    r2_eeg[:,fp1, None], # None == np.newaxis, adds axis
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Session 2: Fp1 Full Recording'
)

In [None]:
def get_blinks(data, data_ts, ch, blink_dev_thresh=2.0):
    ch_data = data[:,ch]
    ch_mean = ch_data.mean()
    ch_std = ch_data.std()
    ch_dev = np.abs((ch_data - ch_mean) / ch_std)
    blink_mask = ch_dev > blink_dev_thresh
    blink_int = blink_mask.astype(int)
    blinks = []
    for i, val in enumerate(blink_int):
        if i == 0:
            continue

        prev = blink_int[i-1]
        if val != prev:
            if val == 1:
                blinks.append([data_ts[i]])
            else:
                blinks[-1].append(data_ts[i])
                
    return np.array(blinks), blink_mask

r1_blinks, _ = get_blinks(r1_eeg, r1_eeg_ts, fp1, 1.5)
r2_blinks, _ = get_blinks(r2_eeg, r2_eeg_ts, fp1, 1.5)

In [None]:
g_labels = ['Fp1']
fig, ax = plot_time_series(
    r1_eeg_ts,
    r1_eeg[:,fp1, None], # None == np.newaxis, adds axis
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Blink Artifacts: session 1'
)
for blink in r1_blinks:
    _ = ax.axvspan(blink[0], blink[1], color='orange', alpha=0.5)
    
g_labels = ['Fp1']
fig, ax = plot_time_series(
    r2_eeg_ts,
    r2_eeg[:,fp1, None], # None == np.newaxis, adds axis
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Blink Artifacts: session 2'
)
for blink in r2_blinks:
    _ = ax.axvspan(blink[0], blink[1], color='orange', alpha=0.5)

In [None]:
def get_bad_epochs_indices(epochs_ts, blinks):    
    count = 0
    bad_epoch_i = []
    for i, time in enumerate(epochs_ts):
        e_start = time - PRE_STIM
        e_stop = time + POST_STIM
        a = (blinks[:,0] >= e_start) & (blinks[:,0] < e_stop)
        overlap = np.count_nonzero(a)
        if overlap == 0:
            b = (blinks[:,1] > e_start) & (blinks[:,1] <= e_stop)
            overlap = np.count_nonzero(b)

        if overlap > 0:
            bad_epoch_i.append(i)
            count += 1
            
    print(f'{count} contaminated epochs!')
    return np.array(bad_epoch_i)

r1_bad_epochs = get_bad_epochs_indices(r1_epochs_ts, r1_blinks)
print(r1_bad_epochs)
print(r1_events[r1_bad_epochs, 2])

r2_bad_epochs = get_bad_epochs_indices(r2_epochs_ts, r2_blinks)
print(r2_bad_epochs)
print(r2_events[r2_bad_epochs, 2])

In [None]:
b_nontarget = r1_bad_epochs[15]
b_target = r1_bad_epochs[16]
print(b_nontarget, b_target)
g_target = 19

x_for_plot = [(i*(1000/SAMPLE_RATE))-PRE_STIM for i in range(r2_epochs.shape[1])]

g_labels = channels
fig, ax = plot_time_series(
    x_for_plot,
    r1_epochs[b_nontarget],
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Non Target Artifact'
)
_ = ax.axvspan(300, 500, color='orange', alpha=0.1)

g_labels = channels
fig, ax = plot_time_series(
    x_for_plot,
    r1_epochs[b_target],
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Target Artifact'
)
_ = ax.axvspan(300, 500, color='orange', alpha=0.1)

g_labels = channels
fig, ax = plot_time_series(
    x_for_plot,
    r1_epochs[g_target],
    labels = g_labels,
    xlabel='Time (ms)',
    ylabel=r'Amplitude ($\mu$V)',
    title=f'Target NO Artifact'
)
_ = ax.axvspan(300, 500, color='orange', alpha=0.1)

In [None]:
def window_and_combine_features(epochs, sub_samples_per_ch, time_window):
    # epochs is expected to be 3D, (epoch, sample, channel)
    
    t_start = ms_to_samples(time_window[0], 250)
    t_len = ms_to_samples(time_window[1] - time_window[0], 250)
    
    window_len = np.round(t_len/sub_samples_per_ch).astype(int)
    windowed = np.zeros(( len(epochs), epochs.shape[-1], sub_samples_per_ch))
    
    for i in range(sub_samples_per_ch):
        win_start = i * window_len + t_start
        win_stop = win_start + window_len
        win_avg = np.mean(epochs[:,win_start:win_stop,:], axis=1) # (epochs x ch)
        windowed[:,:,i] = win_avg

    windowed = windowed.reshape((len(epochs), epochs.shape[-1] * sub_samples_per_ch))

    return windowed

def balance_zeros(data, y):
    # truncates the zeros down to the same size as ones.
    # assumes more zeros than ones
    zeros = np.sum(y == 0)
    ones = np.sum(y == 1)
    step = int(zeros/ones)
    print(f'{100*zeros/y.shape[0]:.2f}% zeros')
    sort_order = np.argsort(y, kind='stable')
    half_zeros = [i for i in range(0,zeros,step)]
    sort_order = np.concatenate((sort_order[half_zeros], sort_order[-ones:]), axis=0)
    orig_order = np.sort(sort_order)
    y_part = y[orig_order]
    data_part = data[orig_order]
    
    return data_part, y_part


r1_ep_mask = np.ones(len(r1_epochs), np.bool)
r1_ep_mask[r1_bad_epochs] = 0

r1_epochs_clean = r1_epochs[r1_ep_mask]
r1_y_clean = r1_y.flatten()[r1_ep_mask]
r1_events_clean = r1_events[r1_ep_mask]

r2_ep_mask = np.ones(len(r2_epochs), np.bool)
r2_ep_mask[r2_bad_epochs] = 0

r2_epochs_clean = r2_epochs[r2_ep_mask]
r2_y_clean = r2_y.flatten()[r2_ep_mask]
r2_events_clean = r2_events[r2_ep_mask]

In [None]:
# session 1
X, y = balance_zeros(r1_epochs_clean[:,:,:-1], r1_y_clean) # remove fp1
print(X.shape)
X_windowed = window_and_combine_features(X, 7, (200, 600))
print(X_windowed.shape)

X_train, X_test, y_train, y_test = train_test_split(X_windowed, y, test_size=0.20, random_state=35)
LDA_clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto').fit(X_train, y_train)
print('Simple 20% test split accuracy:', LDA_clf.score(X_test, y_test))

In [None]:
# session 2
X, y = balance_zeros(r2_epochs_clean[:,:,:-1], r2_y_clean) # remove fp1
print(X.shape)
X_windowed = window_and_combine_features(X, 7, (200, 600))
print(X_windowed.shape)

X_train, X_test, y_train, y_test = train_test_split(X_windowed, y, test_size=0.20, random_state=40)
LDA_clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto').fit(X_train, y_train)
print('Simple 20% test split accuracy:', LDA_clf.score(X_test, y_test))