In [1]:
import numpy as np
import scipy.io as io
import random
import scipy

from sklearn.utils import shuffle
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV
from utils import *
from extract_eeg_features import

In [2]:
def get_eeg_freqs(task:str, sbj:int, freq_domain:str, merge:str):
    """
        Args: Task (NR vs. TSR), Test subject number, EEG frequency domain (theta, alpha, beta, gamma), Binning strategy
        Return: NumPy matrix of respective EEG features on word level
    """
    files = get_matfiles(task)
    data = io.loadmat(files[sbj], squeeze_me=True, struct_as_record=False)['sentenceData']
    n_words = sum([len(sent.word) for sent in data if not isinstance(sent.word, float)])
    n_electrodes = 105
    word2eeg = np.zeros((n_words, n_electrodes))
    
    if freq_domain == 'theta':
        fields = ['TRT_t1', 'TRT_t2']
    elif freq_domain == 'alpha':
        fields = ['TRT_a1', 'TRT_a2']
    elif freq_domain == 'beta':
        fields = ['TRT_b1', 'TRT_b2']    
    elif freq_domain == 'gamma':
        fields = ['TRT_g1', 'TRT_g2']
        
    fixated = 0
    for sent in data:
        # if there is no data, skip sentence (most probably due to technical issues)
        if isinstance(sent.word, float):
            continue
        else:
            for word in sent.word:
                # if there was no fixation, skip word
                if isinstance(word.nFixations, np.ndarray):
                    continue
                else:
                    if merge == 'avg':
                        eeg_freq = np.mean(np.vstack([getattr(word, field) if hasattr(word, field) and len(getattr(word, field)) 
                                                      > 0 else 0 for field in fields]), axis = 0)
                    elif merge == 'max':
                        eeg_freq = np.amax(np.vstack([getattr(word, field) if hasattr(word, field) and len(getattr(word, field))
                                                      > 0 else 0 for field in fields]), axis = 0)
                    else:
                        raise ValueError('Binning strategy must be one of {max-pool, average}')
                    eeg_freq[np.isnan(eeg_freq)] = 0
                    word2eeg[fixated] += eeg_freq
                    fixated += 1
    word2eeg = word2eeg[:fixated, :]
    return word2eeg

In [3]:
def mean_freq_per_sbj(task:str, freq_domain:str, merge:str):
    sbjs_to_skip = [6, 11] if task == 'task2' else [3, 7, 11]
    X = []
    for i in range(12):
        if i not in sbjs_to_skip:
            X.append(get_eeg_freqs(task, i, freq_domain, merge))
    X_mean = np.zeros((X[0].shape[0], 105))
    if task == 'task2':
        D_0, D_1, D_2, D_3, D_4, D_5, D_7, D_8, D_9, D_10 = X 
        for i, (sbj_0, sbj_1, sbj_2, sbj_3, sbj_4, sbj_5, sbj_7, sbj_8, sbj_9, sbj_10) in enumerate(zip(D_0, D_1, D_2, D_3, D_4, D_5, D_7, D_8, D_9, D_10)):
            X_mean[i] += np.mean((sbj_0, sbj_1, sbj_2, sbj_3, sbj_4, sbj_5, sbj_7, sbj_8, sbj_9, sbj_10), axis=0)
    elif task == 'task3':
        D_0, D_1, D_2, D_4, D_5, D_6, D_8, D_9, D_10 = X
        for i, (sbj_0, sbj_1, sbj_2, sbj_4, sbj_5, sbj_6, sbj_8, sbj_9, sbj_10) in enumerate(zip(D_0, D_1, D_2, D_4, D_5, D_6, D_8, D_9, D_10)):
             X_mean[i] += np.mean((sbj_0, sbj_1, sbj_2, sbj_4, sbj_5, sbj_6, sbj_8, sbj_9, sbj_10), axis=0)
    return X_mean

In [46]:
def clf_fit(X_train, X_test, y_train, y_test, clf, rnd_state=42):
    if clf == 'RandomForest':
        model = RandomForestClassifier(n_estimators=100, criterion='gini', bootstrap=False, random_state=rnd_state)
    elif clf == 'LogReg':
        model = LogisticRegressionCV(cv=5, max_iter=1000, random_state=rnd_state,)
    model.fit(X_train, y_train)
    y_hat = model.predict(X_test)
    print(model.score(X_test, y_test))
    if clf == 'LogReg':
        return model.coef_
    else:
        return model.feature_importances_

In [5]:
X_NR = mean_freq_per_sbj('task2', 'alpha', 'avg')
Y_NR = np.zeros((X_NR.shape[0], 1))

X_AR = mean_freq_per_sbj('task3', 'alpha', 'avg')
Y_AR = np.ones((X_AR.shape[0], 1))

X, y = np.vstack((X_NR, X_AR)), np.vstack((Y_NR, Y_AR))
X, y = shuffle(X, y, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
y_train, y_test = y_train.ravel(), y_test.ravel()

In [39]:
feature_weights_RF = clf_fit(X_train, X_test, y_train, y_test, 'RandomForest')

np.argsort(feature_weights_RF)[::-1]

0.8986657050126217


array([ 30,  93,  44,  19,  88,   5,  94,  99,  35,  72,  36,  45,  50,
        71,  89,  49,  98,  51, 100,   0,  40,  43,  21,  90,  95,  22,
        11,  91,   4,  24,   3,  10,  41,  82,  74,   9,  75,  54,  79,
        23,  33,  83,  66,  26,  47,  61, 103,  68,  15,  29,  64,  73,
        46,  69,  59,  18,  78,  32,  16, 101,  17,  58,  14,  60,  31,
        92,   2,  20,  80,  55,   8,  37,  12,  97,  48,  53,   1,  57,
        28,  27,  25,   7,  86,  96,  87,  52,   6,  84, 102,  38,  34,
        85,  76,  13,  39,  65,  77,  42,  81,  70,  63,  67,  56,  62,
       104], dtype=int64)

In [41]:
X_NR = mean_freq_per_sbj('task2', 'theta', 'avg')
Y_NR = np.zeros((X_NR.shape[0], 1))

X_AR = mean_freq_per_sbj('task3', 'theta', 'avg')
Y_AR = np.ones((X_AR.shape[0], 1))

X, y = np.vstack((X_NR, X_AR)), np.vstack((Y_NR, Y_AR))
X, y = shuffle(X, y, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
y_train, y_test = y_train.ravel(), y_test.ravel()

In [47]:
feature_weights_RF = clf_fit(X_train, X_test, y_train, y_test, 'RandomForest')

np.argsort(feature_weights_RF)[::-1]

0.9065993508835196


array([ 93,  30,  98,  99,  94,  88,  35,  44,  72, 100,  40,  43,  19,
         5,  90,   0,  21,   3,  89,  49,  71,  22,  36,   2,  23,  28,
        50,  29,  91,  51,  24,  41,  18,  32,  33,  60,  97,  31,  11,
       103,  45,  92,  15,   8,  66,  16,  38,   6,   7,  47,  77,  14,
        61,  48,  25,  78, 102,  10,  54,  84,  56,  95,  83,  20,  81,
        52,  42,  12,  80,  87,  34,  85,  59,   4,  13,  17,  68, 101,
        79,  96,  82,   1,  26,   9,  37,  64,  65,  86,  39,  46,  27,
        69,  74,  55,  53,  73,  76,  70,  67,  63,  57,  62,  58,  75,
       104], dtype=int64)

In [48]:
X_NR = mean_freq_per_sbj('task2', 'beta', 'avg')
Y_NR = np.zeros((X_NR.shape[0], 1))

X_AR = mean_freq_per_sbj('task3', 'beta', 'avg')
Y_AR = np.ones((X_AR.shape[0], 1))

X, y = np.vstack((X_NR, X_AR)), np.vstack((Y_NR, Y_AR))
X, y = shuffle(X, y, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
y_train, y_test = y_train.ravel(), y_test.ravel()

In [49]:
feature_weights_RF = clf_fit(X_train, X_test, y_train, y_test, 'RandomForest')

np.argsort(feature_weights_RF)[::-1]

0.9228272628921745


array([ 30,  19,  35,  99,  44,   5,  94,  51,  50,  40, 100,  90,  36,
        21,  89,  49,  43,  93,  22,   0,  88,  11,  10,  45,  20,   1,
        26,  66,  95,  18,  91,  79,   9,  15,  74,  61,  29, 101,  57,
        12, 103,  53,  32,  60,  17,  48,  54,  24,  98,  47,  13,  82,
        58,  16,  84,   6,  28,   2,  52,   8,  59,  87,  96,  67,  39,
        41,  31, 102,  80,  42,  46,  63,  72,  33,  25,   4,  86,  27,
        64,  14,  73,  68,  77,  23,  65,  55,  37,  76,  83,   7,  62,
        85,  92,  38,  81,  56,  75,  70,  97,  69,  71,  34,  78,   3,
       104], dtype=int64)

In [None]:
X_NR = mean_freq_per_sbj('task2', 'gamma', 'avg')
Y_NR = np.zeros((X_NR.shape[0], 1))

X_AR = mean_freq_per_sbj('task3', 'gamma', 'avg')
Y_AR = np.ones((X_AR.shape[0], 1))

X, y = np.vstack((X_NR, X_AR)), np.vstack((Y_NR, Y_AR))
X, y = shuffle(X, y, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
y_train, y_test = y_train.ravel(), y_test.ravel()

In [None]:
feature_weights_RF = fit_clf(X_train, X_test, y_train, y_test, 'RandomForest')

np.argsort(feature_weights_RF)[::-1]