In [9]:
import numpy as np                                      
import matplotlib.pyplot as plt                         
import matplotlib.patches as patches
import seaborn as sns
import scipy.signal as signal 
from scipy.io import loadmat
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import StratifiedKFold
from pathlib import Path
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
import pickle

In [10]:
# --- Filter configuration ----------------------------------------------------
fs = 512                  # Hz; sampling rate of the recordings
dt = 1000. / fs           # ms; time between consecutive samples
sdt = dt                  # ms; step used to turn times into sample indices (kept unrounded)
hp = 1                    # Hz; lower edge of the pass band (high-pass cutoff)
lp = 24.0                 # Hz; upper edge of the pass band (low-pass cutoff)
num_taps = 31             # number of FIR filter coefficients (taps)

# Band-pass FIR design. Passing `fs` lets the cutoffs be given in Hz rather than
# as fractions of the Nyquist frequency. An FIR filter has no feedback, so its
# denominator `a` is simply 1.
a = 1
b = signal.firwin(numtaps=num_taps, cutoff=[hp, lp], pass_zero='bandpass', fs=fs)

# --- ERP-related windows (ms within an epoch) ---------------------------------
epoch_start, epoch_end = 0, 800
baseline_start, baseline_end = 0, 100
erp_start, erp_end = 200, 800


def _ms_to_index(ms):
    """Return the sample index nearest to a time of `ms` milliseconds."""
    return np.round(ms / sdt).astype(int)


# Translate the windows into sample indices once, so later cells can slice directly.
e_s, e_e = _ms_to_index(epoch_start), _ms_to_index(epoch_end)          # epoch
bl_s, bl_e = _ms_to_index(baseline_start), _ms_to_index(baseline_end)  # baseline
erp_s, erp_e = _ms_to_index(erp_start), _ms_to_index(erp_end)          # ERP component window

In [11]:
import util as myUtil
def load_file_from_bi2015a(filename):
    headerNames = pd.read_csv('./datasets/bi2015a/Header.csv', header=None)
    headerNames = np.array(headerNames.iloc[0]).flatten()
    if not Path(filename).exists():
        raise ValueError("File does not exist   " + filename)
    df = pd.read_csv(filename, header=None)
    df.columns = headerNames
    del headerNames
    #timestamps = df['Time'].values
    sample_rate = 512
    x = df.iloc[:, 1:33].values
    df = df.iloc[:, -2:]
    df['y'] = 0
    df.loc[df['Trigger'] == 1, 'y'] = -1
    df.loc[df['Target'] == 1, 'y'] = 1
    y = df.y.values
    del df
    # x = myUtil.resample_x(x, rate=sample_rate, target_rate=fs)
    # y = myUtil.resample_y(y, rate=sample_rate, target_rate=fs)
    x = x.T
    x = signal.filtfilt(b, a, x, axis=1)
    x = x[:, 2*fs:]
    y = y[2*fs:]
    
    return x, y

In [12]:
# Load every session of every subject, epoch it around the labelled events and
# cache one pickle pair (x{i}.pkl, y{i}.pkl) per subject in `out_dir`.
subject = range(1, 44)
session = [1, 2, 3]
X_train = None
y_train = np.array([])
out_dir = "./datasets/2015pre/"
out_path = Path(out_dir)
out_path.mkdir(parents=True, exist_ok=True)  # open(..., "wb") fails if the folder is missing

for i in subject:
    # Collect the sessions and concatenate once per subject instead of
    # re-copying an ever-growing array after every session.
    x_parts, y_parts = [], []
    for j in session:
        x, y = load_file_from_bi2015a(f'./datasets/bi2015a/subject_{i:02}_session_{j:02}.csv')
        x, y = myUtil.epoch_wrt_event_chanFirst(x, y, e_s, e_e)
        x_parts.append(x)
        y_parts.append(y)
    X_train = np.concatenate(x_parts)
    # float64, the dtype produced before by concatenating onto the empty float array.
    y_train = np.concatenate(y_parts).astype(np.float64)
    print(f"subject {i}", X_train.shape, y_train.shape)

    with open(out_path / f"x{i}.pkl", "wb") as f:
        pickle.dump(X_train, f)
    with open(out_path / f"y{i}.pkl", "wb") as f:
        pickle.dump(y_train, f)
    # Release this subject's arrays before loading the next one.
    X_train = None
    y_train = np.array([])
        

# print(X_train.shape, y_train.shape)

# linesI_no_nan = np.where(np.isnan(X_train).any(axis=1).any(axis=1) == False)[0]
# X_train = X_train[linesI_no_nan]
# y_train = y_train[linesI_no_nan]

# print(X_train.shape, y_train.shape)

subject 1 (4956, 32, 410) (4956,)
subject 2 (1512, 32, 410) (1512,)
subject 3 (1044, 32, 410) (1044,)
subject 4 (1440, 32, 410) (1440,)
subject 5 (1188, 32, 410) (1188,)
subject 6 (1584, 32, 410) (1584,)
subject 7 (1224, 32, 410) (1224,)
subject 8 (1512, 32, 410) (1512,)
subject 9 (1332, 32, 410) (1332,)
subject 10 (1116, 32, 410) (1116,)
subject 11 (1440, 32, 410) (1440,)
subject 12 (2232, 32, 410) (2232,)
subject 13 (1368, 32, 410) (1368,)
subject 14 (1404, 32, 410) (1404,)
subject 15 (1260, 32, 410) (1260,)
subject 16 (1260, 32, 410) (1260,)
subject 17 (1620, 32, 410) (1620,)
subject 18 (1620, 32, 410) (1620,)
subject 19 (1800, 32, 410) (1800,)
subject 20 (1620, 32, 410) (1620,)
subject 21 (1260, 32, 410) (1260,)
subject 22 (1584, 32, 410) (1584,)
subject 23 (1368, 32, 410) (1368,)
subject 24 (1296, 32, 410) (1296,)
subject 25 (2088, 32, 410) (2088,)
subject 26 (1116, 32, 410) (1116,)
subject 27 (2844, 32, 410) (2844,)
subject 28 (1800, 32, 410) (1800,)
subject 29 (1836, 32, 410) (1