In [1]:
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import mne
import seaborn as sns
import nolds
from scipy import stats
from scipy.signal import welch
from sklearn.model_selection import StratifiedKFold
from sklearn.feature_selection import SelectFromModel
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.svm import SVC
from sklearn.impute import KNNImputer
from sklearn.metrics import f1_score, balanced_accuracy_score, confusion_matrix, accuracy_score
from sklearn.preprocessing import QuantileTransformer, StandardScaler
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from tqdm import tqdm

In [2]:
# --- Notebook-wide configuration -------------------------------------------
# Seed value shared by the notebook.
SEED = 42
# Sampling rate of the recordings (presumably in Hz -- confirm against the
# acquisition setup).
SAMPLING_RATE = 256
# Machine-specific locations: change both folders to match your checkout.
# They are kept as plain strings (no trailing slash) because the loading code
# builds file paths from them by string concatenation.
data_folder = "D:/Repos/reading_comprehension_EEG/our_data"
labels_folder = "D:/Repos/reading_comprehension_EEG/our_data/labels"

In [3]:
# Identifiers of the participants to load. Each one is used to build the file
# names "events_<subject>.txt" and "<subject>_pre_processed_data.txt".
# SEED is intentionally not re-assigned here: it is already defined in the
# configuration cell above, and a second copy could silently diverge.
subjects = ['lea', 'finn', 'sarah', 'aurora', 'bjoern', 'derek']  # add all subjects here

In [4]:
# Load, for every subject, the event markers and the pre-processed recording.
# Result layout:
#   subj_data[subj] = {"labels": float ndarray,
#                      "timestamps": float ndarray,
#                      "data": DataFrame}
subj_data = {}
for subj in subjects:
    print(subj)

    # Event file is whitespace-delimited. sep=r"\s+" is the documented
    # replacement for the deprecated delim_whitespace=True.
    events = pd.read_csv(labels_folder + "/events_" + subj + ".txt", sep=r"\s+")
    # Drop rows whose "number" field holds the literal text "condition"
    # (presumably repeated header lines inside the file -- confirm).
    events = events[events["number"] != "condition"]

    # aurora is another format: whitespace-delimited instead of comma-separated.
    recording_sep = r"\s+" if subj == 'aurora' else ","
    df = pd.read_csv(data_folder + "/" + subj + "_pre_processed_data.txt", sep=recording_sep)

    subj_data[subj] = {
        # NOTE(review): labels are read from the "number" column and timestamps
        # from the "type" column; this looks like a header/column offset in the
        # event files -- confirm against the file format.
        "labels": events["number"].to_numpy().astype(float),
        "timestamps": events["type"].to_numpy().astype(float),
        "data": df,
    }

lea
finn
sarah
aurora
bjoern
derek


In [5]:
# Sanity check on every subject's label stream: the first label must be the
# 100 marker and the second label must not be. Abort on the first subject
# that violates this.
for x in subjects:
    if not (subj_data[x]['labels'][0] == 100 and subj_data[x]['labels'][1] != 100):
        raise Exception("Something wrong with labels for " + x)

In [6]:
def split_data(data, labels, timestamps, expected_texts=3):
    """Cut a recording into per-text lists of labelled segments.

    ``labels[i]`` (for ``i >= 1``) describes the span of ``data`` that starts
    at ``timestamps[i - 1]`` and ends at ``timestamps[i]``. Entry 0 only
    provides the initial start position; its label is never inspected.
    A label of 100 acts as a separator: the span it closes is discarded, the
    text collected so far is finished and a new one is started.

    Parameters
    ----------
    data : sliceable
        Object supporting ``data[a:b]`` (the notebook passes a pandas
        DataFrame, where this selects a range of rows).
    labels : sequence of numbers
        One marker value per event.
    timestamps : sequence of numbers
        One position per event, same length as ``labels``. Values are
        truncated with ``int()`` and used as slice bounds into ``data``.
    expected_texts : int or None, default 3
        Number of texts the recording must contain. Pass ``None`` to skip
        the check.

    Returns
    -------
    list of (x, y) tuples
        One tuple per text. ``x`` is the list of data slices and ``y`` the
        list of class ids (195 -> 1, 196 -> 2, any other label -> 0), in
        event order.

    Raises
    ------
    ValueError
        If ``labels`` and ``timestamps`` differ in length, are empty, or the
        number of texts found differs from ``expected_texts``.
    """
    if len(labels) != len(timestamps):
        raise ValueError(
            "labels and timestamps must have the same length, got "
            + str(len(labels)) + " and " + str(len(timestamps))
        )
    if len(labels) == 0:
        raise ValueError("labels and timestamps must not be empty")

    def to_true_label(label):
        # Map the raw marker value to the class id used downstream.
        if label == 195:
            return 1
        if label == 196:
            return 2
        return 0

    texts = []
    x, y = [], []
    start = timestamps[0]
    # Skip entry 0: it has no preceding timestamp, so it closes no span.
    for label, end in zip(labels[1:], timestamps[1:]):
        if label == 100:
            # Separator: finish the current text and begin an empty one.
            texts.append((x, y))
            x, y = [], []
        else:
            x.append(data[int(start):int(end)])
            y.append(to_true_label(label))
        # The next span always starts where this one ended, separators included.
        start = end
    # The last text is not followed by a separator, so flush it explicitly.
    texts.append((x, y))

    if expected_texts is not None and len(texts) != expected_texts:
        raise ValueError("Texts must be " + str(expected_texts) + ", not " + str(len(texts)))
    return texts

In [7]:
# Pool the segments (X*) and their class ids (y*) of all subjects, keeping the
# three texts returned by split_data separate: X1/y1 collect the first text of
# every subject, X2/y2 the second, X3/y3 the third.
X1, X2, X3 = [], [], []
y1, y2, y3 = [], [], []
for subj in subjects:
    print(subj)
    texts = split_data(subj_data[subj]['data'], subj_data[subj]['labels'], subj_data[subj]['timestamps'])
    # Extend each accumulator in place, in text order.
    for text_idx, (X_pool, y_pool) in enumerate(((X1, y1), (X2, y2), (X3, y3))):
        X_pool.extend(texts[text_idx][0])
        y_pool.extend(texts[text_idx][1])

lea
finn
sarah
aurora
bjoern
derek
