In [1]:
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import mne
import seaborn as sns
import nolds
from scipy import stats
from scipy.signal import welch
from sklearn.model_selection import StratifiedKFold
from sklearn.feature_selection import SelectFromModel
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.svm import SVC
from sklearn.impute import KNNImputer
from sklearn.metrics import f1_score, balanced_accuracy_score, confusion_matrix, accuracy_score
from sklearn.preprocessing import QuantileTransformer, StandardScaler
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from tqdm import tqdm

In [2]:
SAMPLING_RATE = 256
SEED = 42
labels_folder = "D:/Repos/reading_comprehension_EEG/our_data/labels"
data_folder = "D:/Repos/reading_comprehension_EEG/our_data"

In [3]:
SEED = 42
subjects = ['lea','finn','sarah', 'aurora', 'bjoern', 'derek', 'dimi', 'ronan']

In [4]:
def list_to_nested_numpy(lst):
    if isinstance(lst, list):
        return np.array([list_to_nested_numpy(item) for item in lst])
    else:
        return lst

In [5]:
subj_data = {}
for subj in subjects:
    print(subj)
    df = pd.read_csv(labels_folder+"/events_" + subj + ".txt", delim_whitespace=True)
    df = df[(df.number != "condition")]
    subj_data[subj] = {}
    subj_data[subj]["labels"] = df["number"].to_numpy().astype(float)
    subj_data[subj]["timestamps"] = df["type"].to_numpy().astype(float)
    if subj == 'aurora': # aurora is another format
        df = pd.read_csv(data_folder+"/" + subj + "_pre_processed_data.txt", delim_whitespace=True)
    else:
        df = pd.read_csv(data_folder+"/" + subj + "_pre_processed_data.txt", delim_whitespace=False)
    subj_data[subj]["data"] = df

lea
finn
sarah
aurora
bjoern
derek
dimi
ronan


In [6]:
print(subj_data['lea']['timestamps'])

[ 50844.     62614.625  70370.5    86809.875  96244.75  111613.375
 120614.25  125534.875 133507.75  144811.5   154439.25  159898.625
 212116.375 225500.5   235378.125 250900.375 265721.5   282789.5
 294260.125 302628.75  316782.875 336008.625 343165.    353894.25
 386458.375 400239.25  406785.25  419892.75  425830.    436434.5
 451762.75  463859.   ]


In [7]:
for x in subjects:
    if subj_data[x]['labels'][0] != 100:
        raise Exception("Something wrong with labels for " + x)

In [8]:
subj_data['lea']['data'].tail(10)

Unnamed: 0,Time,FP1,AF7,AF3,F1,F3,F5,F7,FT7,FC5,...,CP4,CP2,P2,P4,P6,P8,P10,PO8,PO4,O2
490230,1914961.0,-14.2626,-21.5872,-12.1773,-21.4252,-27.7623,-38.848,-9.1665,-39.2277,-31.2169,...,-7.4959,-10.5052,-7.9905,-1.5212,10.301,17.8715,12.1262,9.6484,0.3445,9.3032
490231,1914965.0,-12.8294,-21.8563,-17.7713,-23.7763,-28.8121,-32.8357,-3.5858,-35.3107,-32.8769,...,-8.0893,-11.9919,-8.2936,-1.8956,10.8764,17.9528,12.1024,9.5728,-0.3633,9.5258
490232,1914969.0,-6.914,-12.4566,-14.0455,-19.0623,-22.7313,-22.2133,7.7825,-24.7967,-27.1025,...,-4.5978,-9.0009,-4.7314,0.8325,13.6789,19.4881,13.2136,11.5986,1.0361,11.7327
490233,1914973.0,-2.1341,-4.5824,-5.8755,-12.5076,-15.1894,-13.7007,16.4557,-17.6653,-19.6339,...,-1.7404,-6.1776,-1.8143,2.5572,14.5142,18.604,12.8774,11.4908,1.1766,12.6361
490234,1914977.0,-3.7242,-7.0971,-3.0301,-11.2754,-13.6877,-13.4158,15.9852,-20.7935,-18.3197,...,-4.0619,-8.1852,-4.0551,-0.6734,9.8897,12.2785,8.2582,5.8749,-3.045,9.5219
490235,1914980.0,-9.8035,-17.0388,-8.746,-16.1411,-18.8107,-20.6804,8.4378,-30.516,-23.4937,...,-10.1869,-13.9156,-10.405,-7.4408,1.6816,3.2085,0.7237,-2.7314,-9.7797,3.9201
490236,1914984.0,-12.7859,-23.5844,-16.2914,-20.3241,-23.14,-27.7637,2.221,-35.7802,-27.1619,...,-13.311,-16.6934,-14.1044,-11.1779,-3.734,-1.7606,-4.2365,-7.1249,-12.65,1.0525
490237,1914988.0,-7.0342,-19.2117,-16.4715,-16.7578,-19.0342,-26.2726,4.4678,-28.5108,-21.6623,...,-8.5386,-11.565,-9.8327,-6.9721,-1.9565,1.549,-2.0529,-2.7778,-7.4266,4.3755
490238,1914992.0,4.2465,-7.4741,-8.221,-6.2167,-7.2919,-15.5059,13.5452,-12.1964,-8.6511,...,1.5839,-0.9063,0.6295,3.0963,4.8157,9.9836,5.6307,7.2299,3.266,11.9744
490239,1914996.0,11.94,1.5536,0.3865,3.1666,3.4148,-3.2525,20.9219,1.6156,2.3341,...,8.9047,7.0886,9.2479,11.1826,9.5938,15.8107,11.8401,14.5463,11.4808,17.7376


In [9]:
def get_bandwidths(data):
    data = data.drop('Time', axis='columns').values.transpose()
    # Define bandpass filter cutoff frequencies for each band
    freq_ranges = {'Delta': (0.5, 4),
                   'Theta': (4, 8),
                   'Alpha': (8, 12),
                   'Beta': (12, 30),
                   'Gamma': (30, 100)}
    bandwidths = []
    for i, channel_data in enumerate(data):
        for j, (band_name, (low_freq, high_freq)) in enumerate(freq_ranges.items()):
            # Apply bandpass filter
            bandwidths.append(mne.filter.filter_data(channel_data, SAMPLING_RATE, low_freq, high_freq,verbose=False))
    return bandwidths

In [10]:
for s in subjects:
    subj_data[s]['data'] = get_bandwidths(subj_data[s]['data'])

In [11]:
for s in subjects:
    subj_data[s]['data'] = list_to_nested_numpy(subj_data[s]['data'])

In [12]:
print(subj_data['aurora']['data'].shape)

(320, 621056)


In [13]:
def split_data(data, labels, timestamps):
    def to_true_label(label):
        if label == 100:
            raise Exception("Must skip labels with value 100!")
        if label == 195:
            return 1
        if label == 196:
            return 2
        return 0
    
    x = []
    y = []
    start = timestamps[0]
    for i, label in enumerate(labels):
        if i == 0: continue
        end = timestamps[i]
        if label != 100:
            x.append(data[:,int(start):int(end)])
            y.append(to_true_label(label))
        start = timestamps[i]
    return (x,y)

In [14]:
x = []
y = []
subj_indices = {}
for subj in subjects:
    subj_indices[subj] = {}
    subj_indices[subj]['s'] = len(x)
    x_s, y_s = split_data(subj_data[subj]['data'], subj_data[subj]['labels'], subj_data[subj]['timestamps'])
    x += x_s
    y += y_s
    subj_indices[subj]['e'] = len(x)-1

In [15]:
print(subj_indices)

{'lea': {'s': 0, 'e': 28}, 'finn': {'s': 29, 'e': 58}, 'sarah': {'s': 59, 'e': 87}, 'aurora': {'s': 88, 'e': 117}, 'bjoern': {'s': 118, 'e': 147}, 'derek': {'s': 148, 'e': 176}, 'dimi': {'s': 177, 'e': 206}, 'ronan': {'s': 207, 'e': 235}}


In [16]:
print(len(x))

236


In [17]:
def get_outlier_ratio(values, m):
    outliers = np.array(values)[abs(values - np.mean(values)) > m * np.std(values)]
    return len(outliers)/len(values)

def hjorth_mobility(signal):
    diff1 = np.diff(signal)
    var_signal = np.var(signal)
    var_diff1 = np.var(diff1)
    return np.sqrt(var_diff1 / var_signal)

def hjorth_complexity(signal):
    diff1 = np.diff(signal)
    diff1_mobility = hjorth_mobility(diff1)
    mobility = hjorth_mobility(signal)
    return diff1_mobility / mobility

def higuchi_fd(X, kmax=SAMPLING_RATE):
    """
    Compute Higuchi Fractal Dimension of a time series X.
    
    Parameters:
    X (array-like): The input time series data.
    kmax (int): The maximum interval size.
    
    Returns:
    float: The Higuchi Fractal Dimension.
    """
    L = []
    x = []
    N = len(X)
    for k in range(1, kmax + 1):
        Lk = 0
        for m in range(k):
            # Empty list x
            x = []
            for i in range(1, int(np.floor((N - m) / k))):
                x.append(X[m + i * k] - X[m + (i - 1) * k])
            # Compute the length of the curve
            Lk += np.sum(np.abs(x)) / ((N - 1) / ((N - m) / k) * k)
        L.append(np.log(Lk / (m + 1)))
    # Fit a line to the curve and return its slope
    return np.polyfit(np.log(np.arange(1, kmax + 1)), L, 1)[0]

def binarize_signal(signal):
    """
    Binarizes a time series signal by taking the median value of the entire signal.
    
    Parameters:
        signal (list or numpy array): The time series signal.
        
    Returns:
        numpy array: The binarized signal.
    """
    median_value = np.median(signal)
    binarized_signal = np.where(signal >= median_value, 1, 0)
    return binarized_signal


def lempel_ziv_complexity(signal):
    """
    Calculates the Lempel-Ziv Complexity (LZC) of a time series signal.
    
    Parameters:
        signal (list or numpy array): The time series signal.
        
    Returns:
        float: The Lempel-Ziv Complexity of the signal.
    """
    binarized_signal = binarize_signal(signal)
    unique_patterns = set()
    lzc = 0
    
    i = 0
    while i < len(binarized_signal):
        j = i + 1
        while j <= len(binarized_signal):
            pattern = tuple(binarized_signal[i:j])
            if pattern not in unique_patterns:
                unique_patterns.add(pattern)
                lzc += 1
                i = j - 1
                break
            j += 1
        i += 1
    
    return lzc

def extract_features(signal):
    #signal /= np.mean(signal)
    features = []

    features.append(hjorth_mobility(signal))
    features.append(hjorth_complexity(signal))
    features.append(lempel_ziv_complexity(signal))
    #features.append(higuchi_fd(signal, 100))

    # Time-domain features
    features.append(np.mean(signal))
    features.append(np.var(signal))
    features.append(stats.skew(signal))
    features.append(stats.kurtosis(signal))
    features.append(np.max(np.abs(signal)))  # Peak amplitude

    # Frequency-domain features
    f, psd = welch(signal)
    features.append(np.mean(psd))  # Mean power spectral density
    features.append(np.std(psd))
    features.append(np.max(psd))
    features.append(np.min(psd))
    features.append(stats.entropy(psd))  # Spectral entropy
    features.append(f[np.argmax(psd)])  # Spectral edge frequency

    # Statistical features
    features.append(np.median(signal))
    features.append(stats.iqr(signal))  # Interquartile range
    features.append(stats.kurtosis(np.diff(signal)))  # Kurtosis of first differences
    features.append(stats.entropy(np.abs(signal)))  # Signal entropy
    features.append(get_outlier_ratio(signal, 1.5))

    return features

In [18]:
features = []
for i in tqdm(range(len(x))):
    feature = []
    for bw in x[i]:
        feature += extract_features(bw)
    features.append(feature)
X = features

100%|██████████| 236/236 [16:23<00:00,  4.17s/it]


In [19]:
X = list_to_nested_numpy(X)
y = list_to_nested_numpy(y)

In [20]:
print(X.shape)

(236, 6080)


In [21]:
import collections

In [27]:
collections.Counter(y)

Counter({1: 91, 2: 23, 0: 122})

In [47]:

# Iterate over each fold
f1s = []
accs = []
balanced_accs = []
rands_f1 = []
rands_acc = []
balanced_rands_acc = []
splits = []
for (i, s) in enumerate(subjects):
    for j in range(i+1, len(subjects)):
        subj = subjects[i]
        subj2 = subjects[j]
        if subj != subj2:
            indices1 = subj_indices[subj]
            indices2 = subj_indices[subj2]
            if indices1['s'] > indices2['s']:
                indices2, indices1 = indices1, indices2
                subj, subj2 = subj2, subj
            splits.append(
                (
                    subj + ' and ' + subj2,
                    (
                        [x for x in range(indices1['s'])] + [x for x in range(indices1['e']+1,indices2['s'])] + [x for x in range(indices2['e']+1, len(X))],
                        [x for x in range(indices1['s'], indices1['e']+1)] + [x for x in range(indices2['s'], indices2['e']+1)]
                    )
                )
            )

for (i, (subjs, (train_ind, test_ind))) in enumerate(splits):

    # Split data into train and validation sets for this fold
    X_train, X_val = X[train_ind], X[test_ind]
    y_train, y_val = y[train_ind], y[test_ind]
    print(subjs + ':')
    print('train distribution:', collections.Counter(y_train))
    print('eval distribution:', collections.Counter(y_val))
    scaler = StandardScaler()
    imputer = KNNImputer(n_neighbors=50)
    X_train = imputer.fit_transform(X_train)
    X_val = imputer.transform(X_val)
    scaler.fit_transform(X_train)
    scaler.transform(X_val)

    sample_weight = compute_sample_weight(class_weight="balanced", y=y_train)
    # Initialize SVM classifier
    classifier = XGBClassifier(**{
                    'objective': 'multi:softmax',
                    'tree_method': 'auto',
                    'random_state': SEED,
                    'lambda': 44.95002293426431,
                    'alpha': 0.004154520917950766,
                    'colsample_bytree': 0.11602557907294059,
                    'colsample_bylevel': 0.6792114962764053,
                    'subsample': 0.9,
                    'learning_rate': 0.13725312872605422,
                    'n_estimators': 3099,
                    'max_depth': 23,
                    'min_child_weight': 1,
                })
    # Train the classifier on the training data
    classifier.fit(X_train, y_train, sample_weight=sample_weight)

    # Make predictions on the validation data
    y_pred = classifier.predict(X_val)

    # Evaluate the classifier
    rands = [random.randint(0,2) for _ in range(len(y_val))]
    f1 = f1_score(y_val, y_pred, average='micro')
    acc = accuracy_score(y_val, y_pred)
    bal_acc = balanced_accuracy_score(y_val, y_pred)
    f1s.append(f1)
    balanced_accs.append(bal_acc)
    accs.append(acc)
    print(f"F1: {f1:.2f}, acc: {acc: .2f}, balanced_acc: {bal_acc: .2f}")
    f1 = f1_score(y_val, rands, average='micro')
    acc = accuracy_score(y_val, rands)
    bal_acc = balanced_accuracy_score(y_val, rands)
    rands_f1.append(f1)
    rands_acc.append(acc)
    balanced_rands_acc.append(bal_acc)
    print(f"RANDOM: F1: {f1:.2f}, acc: {acc: .2f}, balanced_acc: {bal_acc: .2f}")
    print()

accs = list_to_nested_numpy(accs)
f1s = list_to_nested_numpy(f1s)
print("f1:", np.mean(f1s), np.std(f1s))
print("acc:", np.mean(accs), np.std(accs))
print("balanced_acc:", np.mean(balanced_accs), np.std(balanced_accs))
print("rands_f1:", np.mean(rands_f1), np.std(rands_f1))
print("rands_acc:", np.mean(rands_acc), np.std(rands_acc))
print("balanced_rands_acc:", np.mean(rands_acc), np.std(balanced_rands_acc))

lea and finn:
train distribution: Counter({0: 108, 1: 57, 2: 12})
eval distribution: Counter({1: 34, 0: 14, 2: 11})
F1: 0.37, acc:  0.37, balanced_acc:  0.41
RANDOM: F1: 0.36, acc:  0.36, balanced_acc:  0.40

lea and sarah:
train distribution: Counter({0: 93, 1: 70, 2: 15})
eval distribution: Counter({0: 29, 1: 21, 2: 8})
F1: 0.55, acc:  0.55, balanced_acc:  0.39
RANDOM: F1: 0.28, acc:  0.28, balanced_acc:  0.22

lea and aurora:
train distribution: Counter({0: 90, 1: 72, 2: 15})
eval distribution: Counter({0: 32, 1: 19, 2: 8})
F1: 0.44, acc:  0.44, balanced_acc:  0.37
RANDOM: F1: 0.36, acc:  0.36, balanced_acc:  0.42

lea and bjoern:
train distribution: Counter({0: 104, 1: 63, 2: 10})
eval distribution: Counter({1: 28, 0: 18, 2: 13})
F1: 0.41, acc:  0.41, balanced_acc:  0.38
RANDOM: F1: 0.29, acc:  0.29, balanced_acc:  0.32

lea and derek:
train distribution: Counter({0: 100, 1: 63, 2: 15})
eval distribution: Counter({1: 28, 0: 22, 2: 8})
F1: 0.40, acc:  0.40, balanced_acc:  0.34
RANDO



F1: 0.59, acc:  0.59, balanced_acc:  0.56
RANDOM: F1: 0.34, acc:  0.34, balanced_acc:  0.43

sarah and bjoern:
train distribution: Counter({0: 89, 1: 70, 2: 18})
eval distribution: Counter({0: 33, 1: 21, 2: 5})
F1: 0.58, acc:  0.58, balanced_acc:  0.37
RANDOM: F1: 0.37, acc:  0.37, balanced_acc:  0.33

sarah and derek:
train distribution: Counter({0: 85, 1: 70, 2: 23})
eval distribution: Counter({0: 37, 1: 21})




F1: 0.62, acc:  0.62, balanced_acc:  0.54
RANDOM: F1: 0.33, acc:  0.33, balanced_acc:  0.33

sarah and dimi:
train distribution: Counter({1: 80, 0: 74, 2: 23})
eval distribution: Counter({0: 48, 1: 11})
F1: 0.76, acc:  0.76, balanced_acc:  0.61
RANDOM: F1: 0.34, acc:  0.34, balanced_acc:  0.35

sarah and ronan:
train distribution: Counter({0: 91, 1: 71, 2: 16})
eval distribution: Counter({0: 31, 1: 20, 2: 7})




F1: 0.60, acc:  0.60, balanced_acc:  0.44
RANDOM: F1: 0.31, acc:  0.31, balanced_acc:  0.30

aurora and bjoern:
train distribution: Counter({0: 86, 1: 72, 2: 18})
eval distribution: Counter({0: 36, 1: 19, 2: 5})
F1: 0.53, acc:  0.53, balanced_acc:  0.39
RANDOM: F1: 0.42, acc:  0.42, balanced_acc:  0.41

aurora and derek:
train distribution: Counter({0: 82, 1: 72, 2: 23})
eval distribution: Counter({0: 40, 1: 19})




F1: 0.53, acc:  0.53, balanced_acc:  0.50
RANDOM: F1: 0.25, acc:  0.25, balanced_acc:  0.24

aurora and dimi:
train distribution: Counter({1: 82, 0: 71, 2: 23})
eval distribution: Counter({0: 51, 1: 9})
F1: 0.47, acc:  0.47, balanced_acc:  0.69
RANDOM: F1: 0.40, acc:  0.40, balanced_acc:  0.37

aurora and ronan:
train distribution: Counter({0: 88, 1: 73, 2: 16})
eval distribution: Counter({0: 34, 1: 18, 2: 7})




F1: 0.41, acc:  0.41, balanced_acc:  0.34
RANDOM: F1: 0.49, acc:  0.49, balanced_acc:  0.43

bjoern and derek:
train distribution: Counter({0: 96, 1: 63, 2: 18})
eval distribution: Counter({1: 28, 0: 26, 2: 5})
F1: 0.53, acc:  0.53, balanced_acc:  0.39
RANDOM: F1: 0.34, acc:  0.34, balanced_acc:  0.30

bjoern and dimi:
train distribution: Counter({0: 85, 1: 73, 2: 18})
eval distribution: Counter({0: 37, 1: 18, 2: 5})
F1: 0.53, acc:  0.53, balanced_acc:  0.38
RANDOM: F1: 0.22, acc:  0.22, balanced_acc:  0.23

bjoern and ronan:
train distribution: Counter({0: 102, 1: 64, 2: 11})
eval distribution: Counter({1: 27, 0: 20, 2: 12})
F1: 0.39, acc:  0.39, balanced_acc:  0.37
RANDOM: F1: 0.32, acc:  0.32, balanced_acc:  0.31

derek and dimi:
train distribution: Counter({0: 81, 1: 73, 2: 23})
eval distribution: Counter({0: 41, 1: 18})




F1: 0.56, acc:  0.56, balanced_acc:  0.56
RANDOM: F1: 0.31, acc:  0.31, balanced_acc:  0.28

derek and ronan:
train distribution: Counter({0: 98, 1: 64, 2: 16})
eval distribution: Counter({1: 27, 0: 24, 2: 7})
F1: 0.52, acc:  0.52, balanced_acc:  0.40
RANDOM: F1: 0.31, acc:  0.31, balanced_acc:  0.38

dimi and ronan:
train distribution: Counter({0: 87, 1: 74, 2: 16})
eval distribution: Counter({0: 35, 1: 17, 2: 7})
F1: 0.58, acc:  0.58, balanced_acc:  0.42
RANDOM: F1: 0.29, acc:  0.29, balanced_acc:  0.28

f1: 0.5168590492889148 0.10202420107298332
acc: 0.5168590492889148 0.10202420107298334
balanced_acc: 0.4312472392231991 0.085461270129565
rands_f1: 0.33517610141660403 0.05454262690661691
rands_acc: 0.33517610141660403 0.054542626906616896
balanced_rands_acc: 0.33517610141660403 0.0690111098049621


In [None]:
def plot_confusion_matrix(y_true, y_pred, classes):
    """
    This function prints and plots the confusion matrix.
    """
    cm = confusion_matrix(y_true, y_pred)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion matrix')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], fmt),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()