# In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit
from mne.decoding import CSP

# Absolute path of the current working directory.
current_directory = os.path.abspath(os.curdir)

# Climb four directory levels up from the working directory and put that
# location on the module search path before the project import below.
project_root = os.path.abspath(
    os.path.join(current_directory, os.pardir, os.pardir, os.pardir, os.pardir)
)
sys.path.append(project_root)

print("ROOT:", project_root)
from Early_predict_UQ.data.make_dataset import make_data



def early_pred(probabilities, predict, numTimesBelowThreshold, patience, confidence_type, threshold):
    """Update the early-prediction state with one window's class probabilities.

    Parameters
    ----------
    probabilities : numpy.ndarray
        Class probabilities for a single sample; flattened before use.
    predict : bool
        Whether a prediction has already been triggered. When true, the
        counter is left untouched and the flag is returned unchanged.
    numTimesBelowThreshold : int
        Running count of windows whose confidence was *above* ``threshold``
        (the name is a historical misnomer, kept for compatibility). The
        count is cumulative: it is not reset here when confidence drops.
    patience : int or float
        Number of confident windows required before a prediction is triggered.
    confidence_type : str
        ``'highest_prob'`` uses the largest probability as the confidence.
        Any other value uses ``1 - 1 / (1 + p1 + (p1 - p2))`` where ``p1`` and
        ``p2`` are the largest and second-largest probabilities.
    threshold : float
        Confidence must be strictly greater than this to count.

    Returns
    -------
    tuple
        ``(predict, confidence, numTimesBelowThreshold)`` after the update.

    Raises
    ------
    IndexError
        If ``probabilities`` is empty, or holds a single value while
        ``confidence_type`` is not ``'highest_prob'``.
    """
    sorted_probs = sorted(probabilities.flatten(), reverse=True)
    best = sorted_probs[0]
    if confidence_type == 'highest_prob':
        confidence = best
    else:
        # Reward both a high top probability and a wide margin to the runner-up.
        runner_up = sorted_probs[1]
        confidence = 1 - (1 / (1 + (best + (best - runner_up))))

    if confidence > threshold and not predict:
        print("confidence:", confidence)
        numTimesBelowThreshold += 1
        # ``>=`` rather than ``==`` so that a non-integer patience, or a
        # counter that is already past it, still triggers the prediction.
        if numTimesBelowThreshold >= patience:
            predict = True
    return predict, confidence, numTimesBelowThreshold

def run_sliding_classification(subjects, threshold, patience, confidence_type, initial_window_length, expansion_rate, sfreq):
    """Evaluate CSP + LDA early prediction over a sequence of windows.

    For every subject the data is split twice (``ShuffleSplit`` with 20 % held
    out). CSP and LDA are fitted on the full-length training epochs. Each test
    epoch is then classified on successive windows until ``early_pred``
    signals a prediction, or the windows run out, in which case the last
    window is used.

    Parameters
    ----------
    subjects : sequence
        Subject identifiers; each one is passed to ``make_data`` as ``[id]``.
    threshold, patience, confidence_type
        Forwarded unchanged to ``early_pred``.
    initial_window_length : int
        Length of the first window, in samples.
    expansion_rate : int
        Step, in samples, by which each successive window both starts later
        and grows longer.
    sfreq : float
        Sampling frequency used to convert samples to seconds.

    Returns
    -------
    tuple
        ``(accuracy, mean_prediction_time, epochs, labels)``: the mean score
        and mean prediction time (seconds) averaged within each subject and
        then across subjects, followed by the epochs object and label array
        of the last subject processed.

    Raises
    ------
    ValueError
        If ``subjects`` is empty, or if the epochs are not longer than
        ``initial_window_length`` so that no window can be placed.
    """
    if len(subjects) == 0:
        raise ValueError("subjects must contain at least one subject")

    subject_accuracies = []
    subject_prediction_times = []

    for person in subjects:
        print("Person %d" % (person))
        # The labels returned by make_data are superseded by the event codes.
        # NOTE(review): presumably an mne.Epochs object - confirm in make_data.
        epochs, _ = make_data([person])
        labels = epochs.events[:, -1] - 4  # shift event codes down by 4
        epochs_data = epochs.get_data(copy=False)

        # Window n starts at sample n * expansion_rate.
        w_start = np.arange(0, epochs_data.shape[2] - initial_window_length, expansion_rate)
        if len(w_start) == 0:
            raise ValueError(
                "epochs must be longer than initial_window_length to place a window"
            )

        cv = ShuffleSplit(2, test_size=0.2, random_state=42)
        lda = LinearDiscriminantAnalysis()
        csp = CSP(n_components=5, reg=None, log=True, norm_trace=False)

        # One entry per (split, test epoch) pair for this subject.
        scores = []
        predict_times = []

        for train_idx, test_idx in cv.split(epochs_data):
            y_train, y_test = labels[train_idx], labels[test_idx]
            X_train = csp.fit_transform(epochs_data[train_idx], y_train)
            lda.fit(X_train, y_train)
            # Fancy indexing copies, so slice the test epochs out only once.
            X_test = epochs_data[test_idx]

            for epoch, true_label in zip(X_test, y_test):
                predict = False
                confident_count = 0
                for n, window_start in enumerate(w_start):
                    # NOTE(review): the window slides AND grows with n; slices
                    # running past the epoch end are truncated by numpy.
                    # Confirm this geometry is intended.
                    window_length = initial_window_length + n * expansion_rate
                    window = epoch[np.newaxis, :, window_start:(window_start + window_length)]
                    # Only the current epoch is transformed, not the whole test set.
                    features = csp.transform(window)
                    probabilities = lda.predict_proba(features).flatten()
                    predict, confidence, confident_count = early_pred(
                        probabilities, predict, confident_count, patience, confidence_type, threshold
                    )
                    if predict:
                        break

                # Early stop or not, score the last window that was evaluated.
                scores.append(lda.score(features, [true_label]))
                # Use the window's sample offset (not its index) so that the
                # sum is in samples before converting to seconds.
                predict_times.append(
                    (window_start + initial_window_length / 2.0) / sfreq + epochs.tmin
                )

        # Averaging per subject avoids stacking arrays whose length depends on
        # the subject's epoch count; the result matches the stacked mean when
        # all subjects have equally sized test sets.
        subject_accuracies.append(np.mean(scores))
        subject_prediction_times.append(np.mean(predict_times))

    accuracy = np.mean(subject_accuracies)
    mean_prediction_time = np.mean(subject_prediction_times)
    return accuracy, mean_prediction_time, epochs, labels

if __name__ == "__main__":
    # Swept by the grid search below:
    #   threshold - np.arange(0, 1, 0.2)    #hyperparameter
    #   patience  - np.arange(1, 36, 4)     #hyperparameter
    # Fixed inside run_sliding_classification for now:
    #   csp components      #hyperparameter
    #   cross validation    #hyperparameter
    subjects = [1, 2, 3, 4, 5, 6, 7, 8, 9]

    confidence_type = 'highest_prob'  # 'highest_prob' or the cost function    #hyperparameter
    sfreq = 250
    initial_window_length = int(sfreq * 0.5)  # hyperparameter
    expansion_rate = int(sfreq * 0.1)  # hyperparameter

    # The grid search lives inside the guard because it depends on the
    # settings above; importing this module must not start it. The result
    # names below are still module-level globals when run as a script.
    accuracy_array = []
    prediction_time_array = []

    # One row per threshold, one column per patience value.
    for threshold in np.arange(0, 1, 0.2):
        accuracy_row = []
        prediction_time_row = []
        for patience in np.arange(1, 36, 4):
            accuracy, mean_prediction_time, epochs, labels = run_sliding_classification(
                subjects, threshold, patience, confidence_type,
                initial_window_length, expansion_rate, sfreq
            )
            accuracy_row.append(accuracy)
            prediction_time_row.append(mean_prediction_time)
        accuracy_array.append(accuracy_row)
        prediction_time_array.append(prediction_time_row)

    print("accuracy_array: ", accuracy_array)
    print("prediction_time_array: ", prediction_time_array)

    # Convert the nested lists to 2-D arrays (threshold x patience).
    accuracy_array = np.array(accuracy_array)
    prediction_time_array = np.array(prediction_time_array)

    print("accuracy_array: ", accuracy_array)
    print("prediction_time_array: ", prediction_time_array)
