In [1]:
import pickle
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

In [2]:
def load_in_data(path):
    """Read and return the object stored in the pickle file at ``path``.

    Args:
        path: Location of the pickle file; it is opened in binary read mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    Note:
        Unpickling can execute arbitrary code, so only point this at files
        that come from a trusted source.
    """
    with open(path, mode='rb') as pickle_file:
        return pickle.load(pickle_file)

In [3]:
def get_selected_features(SearchFeatures, include_segments=2):
    """Keep only the trailing ``include_segments`` EEG segments of each sentence.

    Args:
        SearchFeatures: Indexable pair where ``SearchFeatures[0]`` holds the
            per-sentence EEG segment sequences and ``SearchFeatures[1]`` holds
            the matching per-sentence id sequences. The number of sentences
            processed is ``len(SearchFeatures[1])``.
        include_segments: How many segments to keep from the end of each
            sentence. Must be at least 1.

    Returns:
        A list with one entry per sentence, each being the last
        ``include_segments`` items of that sentence's EEG sequence (or the
        whole sequence when it is shorter than that).

    Raises:
        ValueError: If ``include_segments`` is smaller than 1.
    """
    # seq[-0:] yields the *whole* sequence and a negative count would drop
    # leading items instead, so anything below 1 cannot mean "the last n".
    if include_segments < 1:
        raise ValueError("include_segments must be at least 1")

    eeg_by_sentence = SearchFeatures[0]
    # The id sequences only determine how many sentences there are; the ids
    # themselves are not part of the returned value.
    sentence_count = len(SearchFeatures[1])
    return [
        eeg_by_sentence[index][-include_segments:]
        for index in range(sentence_count)
    ]

In [4]:
def get_search_x_y(SearchFeatures, label):
    """Flatten per-sentence segments into samples that all share ``label``.

    Args:
        SearchFeatures: Sequence of sentences, each an iterable of segments.
        label: Value recorded as the target for every segment.

    Returns:
        Tuple ``(samples, targets)``: the segments of all sentences in their
        original order, and an equally long list filled with ``label``.
    """
    samples = [segment for sentence in SearchFeatures for segment in sentence]
    targets = [label for _ in samples]
    return samples, targets

In [5]:
def combine_data(NeedToSearch_X, CorrectSearch_X, IncorrectSearch_X, NeedToSearch_Y, CorrectSearch_Y, IncorrectSearch_Y):
    """Concatenate the three sample groups, capping the two search groups.

    Every NeedToSearch sample is kept. The CorrectSearch and IncorrectSearch
    groups are each truncated to at most ``len(NeedToSearch_X)`` entries
    before being appended, so neither can outnumber the NeedToSearch group.

    Returns:
        Tuple ``(X_data, Y_data)`` of new lists ordered NeedToSearch,
        CorrectSearch, IncorrectSearch.
    """
    cap = len(NeedToSearch_X)

    features = NeedToSearch_X + CorrectSearch_X[:cap]
    features = features + IncorrectSearch_X[:cap]

    labels = NeedToSearch_Y + CorrectSearch_Y[:cap]
    labels = labels + IncorrectSearch_Y[:cap]

    return features, labels

In [72]:
def get_all_subject_x_y(data, include_segments=2):
    """Pool samples and labels from every subject into two flat lists.

    Args:
        data: Mapping whose values each unpack into three feature groups, in
            the order (NeedToSearch, CorrectSearch, IncorrectSearch).
        include_segments: Number of trailing segments kept per sentence;
            forwarded to ``get_selected_features``. Must be 2 or more.

    Returns:
        Tuple ``(X_data_all, Y_data_all)``. NeedToSearch samples are labelled
        0; CorrectSearch and IncorrectSearch samples are both labelled 1.

    Raises:
        ValueError: If ``include_segments`` is below 2.
    """
    if include_segments < 2:
        raise ValueError("include_segments must be greater than 1")

    pooled_x, pooled_y = [], []

    for subject in data.values():
        need_raw, correct_raw, incorrect_raw = subject

        # Trim every group to the trailing segments of each sentence.
        need_sel, correct_sel, incorrect_sel = (
            get_selected_features(group, include_segments)
            for group in (need_raw, correct_raw, incorrect_raw)
        )

        # Both search outcomes share label 1; only NeedToSearch is label 0.
        need_x, need_y = get_search_x_y(need_sel, label=0)
        correct_x, correct_y = get_search_x_y(correct_sel, label=1)
        incorrect_x, incorrect_y = get_search_x_y(incorrect_sel, label=1)

        subject_x, subject_y = combine_data(
            need_x, correct_x, incorrect_x, need_y, correct_y, incorrect_y
        )
        pooled_x.extend(subject_x)
        pooled_y.extend(subject_y)

    return pooled_x, pooled_y

In [73]:

def get_metrics(model, X_Test, Y_Test):
    """Score ``model`` on the given samples, print the metrics, return accuracy.

    Precision, recall and F1 are computed with ``average='weighted'``. All
    four numbers are printed on one line in the order precision, recall, F1,
    accuracy; only the accuracy is returned.

    Args:
        model: Fitted estimator exposing ``predict``.
        X_Test: Samples to predict on.
        Y_Test: True labels for ``X_Test``.

    Returns:
        The accuracy of the model's predictions against ``Y_Test``.
    """
    predictions = model.predict(X_Test)
    precision, recall, f1, _support = precision_recall_fscore_support(
        Y_Test, predictions, average='weighted'
    )
    accuracy = accuracy_score(Y_Test, predictions)
    print(precision, recall, f1, accuracy)
    return accuracy

In [74]:
# Location of the participant feature pickle that is handed to load_in_data.
# NOTE(review): this is an absolute, user-specific Windows path, so the notebook
# will not run on another machine without editing it.
path = r"C:\Users\gxb18167\PycharmProjects\SIGIR_EEG_GAN\Development\Information-Need\Data\stat_features\Participant_Features.pkl"

In [75]:
# Unpickle the participant features. get_all_subject_x_y treats the result as a
# mapping whose values unpack into (NeedToSearch, CorrectSearch, IncorrectSearch)
# feature groups.
data = load_in_data(path)

In [76]:
# Pool samples and labels across all subjects, keeping the last 5 segments of
# each sentence. Labels: 0 = NeedToSearch, 1 = CorrectSearch or IncorrectSearch.
X, Y = get_all_subject_x_y(data, include_segments=5)

In [77]:
# NOTE(review): these imports sit mid-notebook; consider moving them up to the
# first import cell so all dependencies are visible in one place.
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Hold out 20% of the pooled samples for testing; random_state=1 makes the
# split repeatable.
# NOTE(review): the split is made over individual segments pooled from all
# subjects, so segments from the same sentence/subject can end up in both the
# train and test sets - confirm this is acceptable for the evaluation.
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=1)

In [78]:
# Fit a random forest (max_depth=20, random_state=0 for repeatability) on the
# training split.
clf = RandomForestClassifier(max_depth=20, random_state=0)
clf.fit(X_train, y_train)

In [79]:
# Evaluate on the held-out split: get_metrics prints precision, recall, F1 and
# accuracy, and the returned accuracy is displayed as the cell result.
get_metrics(clf, X_test, y_test)

0.6470810758724445 0.6557077625570776 0.5721182352022074 0.6557077625570776


0.6557077625570776

216