In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit
from mne.decoding import CSP

current_directory = os.path.abspath('')

project_root = os.path.abspath(os.path.join(current_directory, '..', '..', '..', '..'))

sys.path.append(project_root)

print("ROOT:", project_root)
from Early_predict_UQ.data.make_dataset import make_data



def early_pred(probabilities, predict, numTimesBelowThreshold, patience, confidence_type, threshold):
    """Update the early-stopping state with the class probabilities of one window.

    A confidence value is derived from ``probabilities``. While no decision
    has been taken yet, each call whose confidence exceeds ``threshold``
    increments the counter; when the counter equals ``patience`` the decision
    flag is switched on.

    Args:
        probabilities: Class probabilities of a single sample. Any array-like
            is accepted and flattened. The margin-based confidence needs at
            least two entries.
        predict: Current decision flag. Once it is True the counter is no
            longer modified.
        numTimesBelowThreshold: Running count of calls whose confidence was
            ABOVE ``threshold`` (the name is historical and misleading). It is
            never reset when a window falls below the threshold, so
            ``patience`` counts confident windows cumulatively, not
            consecutively.
        patience: Counter value at which the decision flag is set.
        confidence_type: ``'highest_prob'`` uses the largest probability as
            confidence; any other value uses
            ``1 - 1 / (1 + best + (best - runner_up))``.
        threshold: Confidence must be strictly greater than this to count.

    Returns:
        Tuple ``(predict, confidence, numTimesBelowThreshold)`` with the
        updated flag, the confidence of this call and the updated counter.
    """
    # np.asarray lets plain lists through; descending order puts the most
    # probable class first.
    ranked = np.sort(np.asarray(probabilities).ravel())[::-1]
    best = ranked[0]
    if confidence_type == 'highest_prob':
        confidence = best
    else:
        runner_up = ranked[1]
        confidence = 1 - (1 / (1 + (best + (best - runner_up))))
    if confidence > threshold and not predict:
        print("confindence:", confidence)
        numTimesBelowThreshold += 1
        if numTimesBelowThreshold == patience:
            predict = True
    return predict, confidence, numTimesBelowThreshold

def run_sliding_classification(subjects, threshold, patience, confidence_type, w_length, w_step, sfreq):
    """Evaluate early prediction with CSP + LDA over sliding windows.

    For every subject the data is loaded with ``make_data``, split twice with
    ``ShuffleSplit`` (20 % test, fixed seed), and a CSP + LDA pipeline is fit
    on the full-length training epochs. Each test epoch is then scanned window
    by window; ``early_pred`` decides when to stop. The window at which the
    scan stops (or the last window if it never stops) is the one that is
    scored.

    Args:
        subjects: Iterable of subject ids, each passed to ``make_data`` as a
            one-element list.
        threshold: Confidence threshold forwarded to ``early_pred``.
        patience: Patience forwarded to ``early_pred``.
        confidence_type: Confidence measure forwarded to ``early_pred``.
        w_length: Window length in samples.
        w_step: Step between window starts in samples.
        sfreq: Sampling frequency used to convert sample indices to seconds.

    Returns:
        Tuple ``(accuracy, mean_prediction_time, epochs, labels)``. The first
        two are averaged over epochs and CV splits within a subject and then
        over subjects (each subject weighs equally). ``epochs`` and ``labels``
        belong to the last subject processed.

    Raises:
        ValueError: If ``subjects`` is empty or ``w_length`` leaves no window
            to evaluate.
    """
    subjects = list(subjects)
    if not subjects:
        raise ValueError("subjects must contain at least one subject id")

    subject_accuracies = []
    subject_prediction_times = []
    for person in subjects:
        print("Person %d" % (person))
        epochs, labels = make_data([person])
        # The labels returned by make_data are replaced by the event codes.
        # NOTE(review): the offset of 4 presumably maps event ids to 0-based
        # class labels - confirm against make_data.
        labels = epochs.events[:, -1] - 4
        epochs_data = epochs.get_data(copy=False)

        w_start = np.arange(0, epochs_data.shape[2] - w_length, w_step)
        if w_start.size == 0:
            raise ValueError(
                "w_length=%d leaves no sliding window in epochs of %d samples"
                % (w_length, epochs_data.shape[2])
            )

        cv = ShuffleSplit(2, test_size=0.2, random_state=42)
        lda = LinearDiscriminantAnalysis()
        csp = CSP(n_components=5, reg=None, log=True, norm_trace=False)

        scores = []
        predict_times = []
        for train_idx, test_idx in cv.split(epochs_data):
            y_train, y_test = labels[train_idx], labels[test_idx]
            X_train = csp.fit_transform(epochs_data[train_idx], y_train)
            lda.fit(X_train, y_train)

            # CSP features of every window are the same for all test epochs
            # of this split, so compute them once instead of once per epoch.
            test_data = epochs_data[test_idx]
            windowed_features = [
                csp.transform(test_data[:, :, n:(n + w_length)]) for n in w_start
            ]

            for epoch_idx, true_label in enumerate(y_test):
                predict = False
                numTimesBelowThreshold = 0
                for n, features in zip(w_start, windowed_features):
                    x_window = features[epoch_idx].reshape(1, -1)
                    probabilities = lda.predict_proba(x_window).flatten()
                    predict, _, numTimesBelowThreshold = early_pred(
                        probabilities, predict, numTimesBelowThreshold, patience, confidence_type, threshold
                    )
                    if predict:
                        break
                # Whether the scan stopped early or ran out of windows, ``n``
                # and ``x_window`` refer to the last window that was looked at.
                scores.append(lda.score(x_window, [true_label]))
                # Centre of the window, in seconds on the epoch time axis.
                predict_times.append((n + w_length / 2.0) / sfreq + epochs.tmin)

        # Scalar per-subject means keep the aggregation valid even when
        # subjects have a different number of test epochs.
        subject_accuracies.append(np.mean(scores))
        subject_prediction_times.append(np.mean(predict_times))

    accuracy = np.mean(subject_accuracies)
    mean_prediction_time = np.mean(subject_prediction_times)
    return accuracy, mean_prediction_time, epochs, labels

if __name__ == "__main__":
    # Grid search over the confidence threshold and the patience of the
    # early-stopping rule. Everything lives inside the guard so the sweep
    # always sees the configuration defined right above it.
    subjects = [1, 2, 3, 4, 5, 6, 7, 8, 9]

    # 'highest_prob', or any other value for the margin-based cost function.
    # This is a hyperparameter - maybe compare the two in different files.
    confidence_type = 'highest_prob'
    sfreq = 250
    w_length = int(sfreq * 0.5)
    w_step = int(sfreq * 0.5)

    # TODO: the number of CSP components and the cross-validation settings
    # are further hyperparameters that are not swept here.

    threshold_values = np.arange(0, 1, 0.2)
    # NOTE(review): patience can only trigger if it does not exceed the
    # number of windows per epoch - confirm the upper end of this range.
    patience_values = np.arange(1, 36, 4)

    # The full sweep might be too intensive for the kernel; consider tuning
    # the other hyperparameters first and only then looping over patience
    # and threshold with the best values.
    accuracy_array = []
    prediction_time_array = []
    for threshold in threshold_values:
        accuracy_row = []
        prediction_time_row = []
        for patience in patience_values:
            accuracy, mean_prediction_time, epochs, labels = run_sliding_classification(
                subjects, threshold, patience, confidence_type, w_length, w_step, sfreq
            )
            accuracy_row.append(accuracy)
            prediction_time_row.append(mean_prediction_time)
        accuracy_array.append(accuracy_row)
        prediction_time_array.append(prediction_time_row)

    # Rows follow threshold_values, columns follow patience_values.
    accuracy_array = np.array(accuracy_array)
    prediction_time_array = np.array(prediction_time_array)

    print("accuracy_array: ", accuracy_array)
    print("prediction_time_array: ",  prediction_time_array)


<h2>Plotting and evaluation</h2>

In [None]:

import pandas as pd

# Surface plots of accuracy and prediction time over the swept grid.
accuracy_array = np.asarray(accuracy_array)
prediction_time_array = np.asarray(prediction_time_array)

# Same grids as the sweep: rows are thresholds, columns are patience values.
threshold_axis = np.arange(0, 1, 0.2)
patience_axis = np.arange(1, 36, 4)

accuracy_df = pd.DataFrame(accuracy_array, index=threshold_axis, columns=patience_axis)
prediction_time_df = pd.DataFrame(prediction_time_array, index=threshold_axis, columns=patience_axis)

# plot_surface expects 2-D coordinate arrays with the same shape as the
# values; the 1-D column/index vectors cannot be broadcast against them.
patience_grid, threshold_grid = np.meshgrid(
    accuracy_df.columns.to_numpy(), accuracy_df.index.to_numpy()
)

fig = plt.figure()

# Plotting accuracy
ax = fig.add_subplot(121, projection='3d')
ax.plot_surface(patience_grid, threshold_grid, accuracy_df.values, cmap='viridis')
ax.set_xlabel('Patience')
ax.set_ylabel('Threshold')
ax.set_zlabel('Accuracy')
ax.set_title('Accuracy vs Threshold and Patience')

# Plotting prediction time
ax = fig.add_subplot(122, projection='3d')
ax.plot_surface(patience_grid, threshold_grid, prediction_time_df.values, cmap='viridis')
ax.set_xlabel('Patience')
ax.set_ylabel('Threshold')
ax.set_zlabel('Prediction Time')
ax.set_title('Prediction Time vs Threshold and Patience')

plt.show()

In [None]:

# Chance level: share of the most frequent class. A formality as classes are
# balanced. Only the labels of the last processed subject are available here.
class_counts = np.unique(labels, return_counts=True)[1]
class_balance = class_counts.max() / np.size(labels)

# One value per threshold, averaged over all patience settings
# (rows of the sweep results are thresholds, columns are patience values).
threshold_values = np.arange(0, 1, 0.2)
accuracy_across_threshold_values = np.asarray(accuracy_array).mean(axis=1)
mean_prediction_time_across_threshold_values = np.asarray(prediction_time_array).mean(axis=1)

plt.figure()
plt.plot(threshold_values, accuracy_across_threshold_values, label="Accuracy")
plt.axhline(class_balance, linestyle="-", color="k", label="Chance")
plt.xlabel("confidence threshold")
plt.ylabel("classification accuracy")
plt.title("Classification score accross threshold values")
plt.legend(loc="lower right")
plt.show()

print("Classification accuracy: %f / Chance level: %f" % (np.mean(accuracy_across_threshold_values), class_balance))

# Subtracting time before the cue at 2s. Both the prediction time and the
# full epoch duration are shifted so their ratio is meaningful, and nothing
# is modified in place, so re-running the cell gives the same numbers.
cue_onset = 2
time_after_cue = np.mean(mean_prediction_time_across_threshold_values) - cue_onset
max_time = epochs.tmax - cue_onset
print("Mean prediction time: %f / Full time: %f / Percentage of time: %f" % (time_after_cue, max_time, 100 * time_after_cue / max_time))
