In [1]:
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import mne
import seaborn as sns
import nolds
from scipy import stats
from scipy.signal import welch
from sklearn.model_selection import StratifiedKFold
from sklearn.feature_selection import SelectFromModel
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.svm import SVC
from sklearn.impute import KNNImputer
from sklearn.metrics import f1_score, balanced_accuracy_score, confusion_matrix, accuracy_score
from sklearn.preprocessing import QuantileTransformer, StandardScaler
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from tqdm import tqdm

In [2]:
SAMPLING_RATE = 256
SEED = 42
labels_folder = "D:/Repos/reading_comprehension_EEG/our_data/labels"
data_folder = "D:/Repos/reading_comprehension_EEG/our_data"

In [3]:
SEED = 42
subjects = ['lea','finn','sarah', 'aurora', 'bjoern', 'derek', 'dimi', 'ronan']

In [4]:
def list_to_nested_numpy(lst):
    if isinstance(lst, list):
        return np.array([list_to_nested_numpy(item) for item in lst])
    else:
        return lst

In [5]:
subj_data = {}
for subj in subjects:
    print(subj)
    df = pd.read_csv(labels_folder+"/events_" + subj + ".txt", delim_whitespace=True)
    df = df[(df.number != "condition")]
    subj_data[subj] = {}
    subj_data[subj]["labels"] = df["number"].to_numpy().astype(float)
    subj_data[subj]["timestamps"] = df["type"].to_numpy().astype(float)
    if subj == 'aurora': # aurora is another format
        df = pd.read_csv(data_folder+"/" + subj + "_pre_processed_data.txt", delim_whitespace=True)
    else:
        df = pd.read_csv(data_folder+"/" + subj + "_pre_processed_data.txt", delim_whitespace=False)
    subj_data[subj]["data"] = df

lea
finn
sarah
aurora
bjoern
derek
dimi
ronan


In [6]:
print(subj_data['aurora']['labels'])

[100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111. 112. 100.
 101. 102. 103. 195. 195. 106. 195. 108. 195. 110. 111. 100. 101. 102.
 103. 104. 105. 195. 107.]


In [7]:
for x in subjects:
    if subj_data[x]['labels'][0] != 100 or subj_data[x]['labels'][1] == 100:
        raise Exception("Something wrong with labels for " + x)

In [8]:
subj_data['lea']['data'][10:20]

Unnamed: 0,Time,FP1,AF7,AF3,F1,F3,F5,F7,FT7,FC5,...,CP4,CP2,P2,P4,P6,P8,P10,PO8,PO4,O2
10,39.0625,31.411,32.0673,22.3352,25.0382,23.5137,35.112,29.3931,22.0631,23.7075,...,27.3832,24.2128,19.9883,21.8572,40.9739,33.9711,26.3451,25.9169,24.9449,27.6925
11,42.9688,33.7772,34.1328,24.6102,27.0248,24.8361,35.6534,32.2594,22.5779,25.0118,...,28.8048,25.0436,20.5898,23.3805,42.628,35.6002,28.7708,30.272,26.4836,29.2275
12,46.875,43.6999,42.6404,34.5724,36.4731,33.9406,42.2067,42.7711,33.0325,34.5852,...,39.8251,36.8006,33.0467,35.3716,52.3633,42.8707,42.8806,43.4715,38.6105,41.5642
13,50.7812,47.7583,45.0808,38.0574,39.5245,37.3788,43.492,47.8495,39.8416,39.3467,...,45.8703,44.1305,41.8964,42.8888,55.8803,43.9029,53.7123,49.9642,46.667,50.3401
14,54.6875,40.277,37.2935,29.6998,30.8815,29.6544,35.6481,42.0118,35.8378,33.3848,...,39.9302,39.132,38.2728,38.107,45.9031,35.0073,52.2798,42.6238,42.3187,46.5558
15,58.5938,29.7054,28.2113,18.7481,19.7947,19.3899,27.0491,32.9142,27.6715,24.2342,...,30.108,29.6817,29.2025,28.4947,30.7703,25.1686,44.2137,30.1952,32.2482,35.792
16,62.5,27.8484,28.9108,17.4641,18.5662,18.356,27.7724,31.2602,26.2626,22.8085,...,28.5257,28.429,27.4537,26.3551,24.3644,24.4255,40.5826,25.8036,28.6333,30.061
17,66.4062,35.2774,38.1997,25.8704,27.5399,26.8358,36.3582,36.7867,32.3949,29.622,...,36.6044,37.2835,35.9167,33.9725,30.8759,31.8548,44.2875,31.9227,34.749,34.0319
18,70.3125,41.7216,44.8413,32.8964,35.6618,34.1918,41.8751,39.8603,36.6848,35.2766,...,43.9096,45.5853,44.4238,41.368,41.2876,37.224,48.0081,38.6357,41.7653,40.7277
19,74.2188,39.0086,41.1422,30.0883,33.9687,32.0526,37.3685,34.2636,32.0706,32.8132,...,41.3217,43.4583,42.825,39.1924,45.1635,34.2447,45.0698,37.0833,40.6852,41.9102


In [9]:
def split_data(data, labels, timestamps):
    def to_true_label(label):
        if label == 100:
            raise Exception("Must skip labels with value 100!")
        if label == 195:
            return 1
        if label == 196:
            return 2
        return 0
    
    texts = []
    x = []
    y = []
    start = timestamps[0]
    for i, label in enumerate(labels):
        if i == 0: continue
        end = timestamps[i]
        if label != 100:
            x.append(data[int(start):int(end)])
            y.append(to_true_label(label))
        else:
            texts.append((x,y))
            x = []
            y = []
        start = timestamps[i]
    texts.append((x,y))
    if len(texts) != 3:
        raise Exception("Texts must be 3, not " + str(len(texts)))
    return texts

In [10]:
X1 = []
X2 = []
X3 = []
y1 = []
y2 = []
y3 = []
for subj in subjects:
    print(subj)
    texts = split_data(subj_data[subj]['data'], subj_data[subj]['labels'], subj_data[subj]['timestamps'])
    X1 += texts[0][0]
    y1 += texts[0][1]
    X2 += texts[1][0]
    y2 += texts[1][1]
    X3 += texts[2][0]
    y3 += texts[2][1]

lea
finn
sarah
aurora
bjoern
derek
dimi
ronan


In [11]:
texts = [X1, X2, X3]
labels = [y1, y2, y3]

In [12]:
def get_bandwidths(data):
    data = data.drop('Time', axis='columns').values.transpose()
    # Define bandpass filter cutoff frequencies for each band
    freq_ranges = {'Theta': (4, 8),
                   'Alpha': (8, 12),
                   'Beta': (12, 30),
                   'Gamma': (30, 100)}
    bandwidths = []
    for i, channel_data in enumerate(data):
        for j, (band_name, (low_freq, high_freq)) in enumerate(freq_ranges.items()):
            # Apply bandpass filter
            bandwidths.append(mne.filter.filter_data(channel_data, SAMPLING_RATE, low_freq, high_freq,verbose=False))
    return bandwidths

In [13]:
X = []
for x in texts:
    text = []
    for i in tqdm(range(len(x))):
        text.append(get_bandwidths(x[i]))
    X.append(text)

  0%|          | 0/92 [00:00<?, ?it/s]

100%|██████████| 92/92 [00:35<00:00,  2.60it/s]
100%|██████████| 88/88 [00:32<00:00,  2.67it/s]
100%|██████████| 56/56 [00:21<00:00,  2.61it/s]


In [14]:
texts = X

In [15]:
def get_outlier_ratio(values, m):
    outliers = np.array(values)[abs(values - np.mean(values)) > m * np.std(values)]
    return len(outliers)/len(values)

def hjorth_mobility(signal):
    diff1 = np.diff(signal)
    var_signal = np.var(signal)
    var_diff1 = np.var(diff1)
    return np.sqrt(var_diff1 / var_signal)

def hjorth_complexity(signal):
    diff1 = np.diff(signal)
    diff1_mobility = hjorth_mobility(diff1)
    mobility = hjorth_mobility(signal)
    return diff1_mobility / mobility

def higuchi_fd(X, kmax=SAMPLING_RATE):
    """
    Compute Higuchi Fractal Dimension of a time series X.
    
    Parameters:
    X (array-like): The input time series data.
    kmax (int): The maximum interval size.
    
    Returns:
    float: The Higuchi Fractal Dimension.
    """
    L = []
    x = []
    N = len(X)
    for k in range(1, kmax + 1):
        Lk = 0
        for m in range(k):
            # Empty list x
            x = []
            for i in range(1, int(np.floor((N - m) / k))):
                x.append(X[m + i * k] - X[m + (i - 1) * k])
            # Compute the length of the curve
            Lk += np.sum(np.abs(x)) / ((N - 1) / ((N - m) / k) * k)
        L.append(np.log(Lk / (m + 1)))
    # Fit a line to the curve and return its slope
    return np.polyfit(np.log(np.arange(1, kmax + 1)), L, 1)[0]

def binarize_signal(signal):
    """
    Binarizes a time series signal by taking the median value of the entire signal.
    
    Parameters:
        signal (list or numpy array): The time series signal.
        
    Returns:
        numpy array: The binarized signal.
    """
    median_value = np.median(signal)
    binarized_signal = np.where(signal >= median_value, 1, 0)
    return binarized_signal


def lempel_ziv_complexity(signal):
    """
    Calculates the Lempel-Ziv Complexity (LZC) of a time series signal.
    
    Parameters:
        signal (list or numpy array): The time series signal.
        
    Returns:
        float: The Lempel-Ziv Complexity of the signal.
    """
    binarized_signal = binarize_signal(signal)
    unique_patterns = set()
    lzc = 0
    
    i = 0
    while i < len(binarized_signal):
        j = i + 1
        while j <= len(binarized_signal):
            pattern = tuple(binarized_signal[i:j])
            if pattern not in unique_patterns:
                unique_patterns.add(pattern)
                lzc += 1
                i = j - 1
                break
            j += 1
        i += 1
    
    return lzc

def extract_features(signal):
    #signal /= np.mean(signal)
    features = []

    features.append(hjorth_mobility(signal))
    features.append(hjorth_complexity(signal))
    features.append(lempel_ziv_complexity(signal))
    #features.append(higuchi_fd(signal, 100))

    # Time-domain features
    features.append(np.mean(signal))
    features.append(np.var(signal))
    features.append(stats.skew(signal))
    features.append(stats.kurtosis(signal))
    features.append(np.max(np.abs(signal)))  # Peak amplitude

    # Frequency-domain features
    f, psd = welch(signal)
    features.append(np.mean(psd))  # Mean power spectral density
    features.append(np.std(psd))
    features.append(np.max(psd))
    features.append(np.min(psd))
    features.append(stats.entropy(psd))  # Spectral entropy
    features.append(f[np.argmax(psd)])  # Spectral edge frequency

    # Statistical features
    features.append(np.median(signal))
    features.append(stats.iqr(signal))  # Interquartile range
    features.append(stats.kurtosis(np.diff(signal)))  # Kurtosis of first differences
    features.append(stats.entropy(np.abs(signal)))  # Signal entropy
    features.append(get_outlier_ratio(signal, 1.5))

    return features

In [16]:
X = []
for t in texts:
    features = []
    for i in tqdm(range(len(t))):
        feature = []
        for bw in t[i]:
            feature += extract_features(bw)
        features.append(feature)
    X.append(features)
print(len(X))

100%|██████████| 92/92 [04:02<00:00,  2.64s/it]
100%|██████████| 88/88 [04:32<00:00,  3.09s/it]
100%|██████████| 56/56 [04:43<00:00,  5.06s/it]

3





In [17]:
texts = X

In [18]:
import collections

In [19]:

# Iterate over each fold
f1s = []
accs = []
balanced_accs = []
rands_f1 = []
rands_acc = []
balanced_rands_acc = []
for i in range(3):
    X_val = texts[i]
    y_val = labels[i]
    X_train = np.vstack([texts[j] for j in range(3) if j != i])
    y_train = np.hstack([labels[j] for j in range(3) if j != i])
    print("TEXT " + str(i) + ":")
    print("Train distribution: " + str(collections.Counter(y_train)))
    print("Test distribution: " + str(collections.Counter(y_val)))
    scaler = StandardScaler()
    imputer = KNNImputer(n_neighbors=50)
    X_train = imputer.fit_transform(X_train)
    X_val = imputer.transform(X_val)
    scaler.fit_transform(X_train)
    scaler.transform(X_val)

    sample_weight = compute_sample_weight(class_weight="balanced", y=y_train)
    # Initialize SVM classifier
    classifier = XGBClassifier(**{
                    'objective': 'multi:softmax',
                    'tree_method': 'auto',
                    'random_state': SEED,
                    'lambda': 44.95002293426431,
                    'alpha': 0.004154520917950766,
                    'colsample_bytree': 0.11602557907294059,
                    'colsample_bylevel': 0.6792114962764053,
                    'subsample': 0.9,
                    'learning_rate': 0.13725312872605422,
                    'n_estimators': 3099,
                    'max_depth': 23,
                    'min_child_weight': 1,
                })
    # Train the classifier on the training data
    classifier.fit(X_train, y_train, sample_weight=sample_weight)

    # Make predictions on the validation data
    y_pred = classifier.predict(X_val)

    # Evaluate the classifier
    rands = [random.randint(0,2) for _ in range(len(y_val))]
    f1 = f1_score(y_val, y_pred, average='micro')
    acc = accuracy_score(y_val, y_pred)
    bal_acc = balanced_accuracy_score(y_val, y_pred)
    f1s.append(f1)
    balanced_accs.append(bal_acc)
    accs.append(acc)
    print(f"F1: {f1:.2f}, acc: {acc: .2f}, balanced_acc: {bal_acc: .2f}")
    f1 = f1_score(y_val, rands, average='micro')
    acc = accuracy_score(y_val, rands)
    bal_acc = balanced_accuracy_score(y_val, rands)
    rands_f1.append(f1)
    rands_acc.append(acc)
    balanced_rands_acc.append(bal_acc)
    print(f"RANDOM: F1: {f1:.2f}, acc: {acc: .2f}, balanced_acc: {bal_acc: .2f}")
    print()

accs = list_to_nested_numpy(accs)
f1s = list_to_nested_numpy(f1s)
print("f1:", np.mean(f1s), np.std(f1s))
print("acc:", np.mean(accs), np.std(accs))
print("balanced_acc:", np.mean(balanced_accs), np.std(balanced_accs))
print("rands_f1:", np.mean(rands_f1), np.std(rands_f1))
print("rands_acc:", np.mean(rands_acc), np.std(rands_acc))
print("balanced_rands_acc:", np.mean(rands_acc), np.std(balanced_rands_acc))

TEXT 0:
Train distribution: Counter({1: 66, 0: 65, 2: 13})
Test distribution: Counter({0: 57, 1: 25, 2: 10})
F1: 0.59, acc:  0.59, balanced_acc:  0.43
RANDOM: F1: 0.29, acc:  0.29, balanced_acc:  0.21

TEXT 1:
Train distribution: Counter({0: 83, 1: 50, 2: 15})
Test distribution: Counter({1: 41, 0: 39, 2: 8})
F1: 0.55, acc:  0.55, balanced_acc:  0.44
RANDOM: F1: 0.35, acc:  0.35, balanced_acc:  0.39

TEXT 2:
Train distribution: Counter({0: 96, 1: 66, 2: 18})
Test distribution: Counter({0: 26, 1: 25, 2: 5})
F1: 0.66, acc:  0.66, balanced_acc:  0.48
RANDOM: F1: 0.32, acc:  0.32, balanced_acc:  0.34

f1: 0.5977084509693205 0.047664836219963806
acc: 0.5977084509693205 0.047664836219963806
balanced_acc: 0.44863231405593407 0.023508797229436888
rands_f1: 0.3223931865236213 0.024012429859267142
rands_acc: 0.3223931865236213 0.024012429859267142
balanced_rands_acc: 0.3223931865236213 0.07631153490684266


In [20]:
def plot_confusion_matrix(y_true, y_pred, classes):
    """
    This function prints and plots the confusion matrix.
    """
    cm = confusion_matrix(y_true, y_pred)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion matrix')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], fmt),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()