In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline

from scipy.stats import entropy

from mne.decoding import CSP


# Absolute path of the working directory the notebook is run from
current_directory = os.path.abspath('')

# Resolve the directory four levels above the working directory
project_root = os.path.abspath(
    os.path.join(current_directory, *(['..'] * 4))
)

# Make packages located under the project root importable
sys.path.append(project_root)

print("ROOT:", project_root)
from Early_predict_UQ.data.make_dataset import make_data
from Early_predict_UQ.data.plots import plot_accuracy_over_time_and_epochs, plot_confidence_over_time_and_epochs #, plot_cost_over_time_and_epochs

def early_pred(probabilities, predict, numTimesBelowThreshold, patience, confidence_type, threshold=None):
    """Update the early-stopping state for one sliding window.

    Parameters
    ----------
    probabilities : array-like
        Class probabilities for a single sample; flattened before use.
    predict : bool
        Whether the stop decision has already been taken. Once True the
        counter is no longer updated.
    numTimesBelowThreshold : int
        Number of windows seen so far whose confidence exceeded the
        threshold (the name is historical; it counts windows *above* it).
    patience : int
        Number of above-threshold windows required before stopping.
    confidence_type : str
        'highest_prob' uses the largest class probability. Any other value
        is treated as 'two_highest_difference'.
    threshold : float, optional
        Confidence level that has to be exceeded. When omitted, the
        module-level ``threshold`` variable is used, which is how the
        function behaved before this parameter existed.

    Returns
    -------
    tuple
        ``(predict, confidence, numTimesBelowThreshold)`` with the updated
        stop flag, the confidence of this window and the updated counter.
    """
    if threshold is None:
        # Backward compatibility: callers that pass only five arguments keep
        # reading the module-level hyper-parameter.
        threshold = globals()['threshold']

    # Probabilities in descending order, accepting lists as well as arrays.
    sorted_probs = np.sort(np.asarray(probabilities).ravel())[::-1]

    if confidence_type == 'highest_prob':
        confidence = sorted_probs[0]
    else:
        # confidence_type is 'two_highest_difference'
        # Based on the stopping rule described in DOI: 10.1109/TNNLS.2017.2764939
        # NOTE(review): this value gets smaller as the top probability and its
        # margin over the runner-up grow, yet the stop test below is `>`.
        # Confirm the comparison direction is intended for this mode.
        confidence = 1 / (1 + (sorted_probs[0] + (sorted_probs[0] - sorted_probs[1])))

    if confidence > threshold and not predict:
        numTimesBelowThreshold += 1
        # `>=` rather than `==` so a counter that starts above `patience`
        # cannot step past the stop condition.
        if numTimesBelowThreshold >= patience:
            predict = True
    return predict, confidence, numTimesBelowThreshold

# ---- Hyper-parameters ----
# Confidence level a window has to exceed (should become a list of values
# between 0 and 1 to loop over).
threshold = 0.5
# Number of above-threshold windows required before predicting early.
patience = 4

# ---- Subjects ----
current_person = 0
subjects = list(range(1, 10))  # all subjects

# ---- Result containers, filled while looping over the subjects ----
scores_across_subjects = []
prediction_time_across_subjects = []
confidence_across_subjects = []
# Per-subject means (one 1-D array per subject). Collected in lists so the
# stacking below does not depend on which subject id comes first.
subject_mean_scores = []
subject_mean_predict_times = []

for person in subjects:
    current_person += 1
    subject = [person]  # make_data is called with a one-element list

    # Preprocessed epochs
    epochs, labels = make_data(subject)

    # The labels returned above are replaced by the last column of the events
    # matrix minus 4 (presumably mapping the event codes to 0..3 - confirm).
    epochs_train = epochs.copy()
    labels = epochs.events[:, -1] - 4

    # (Might need to do cross session - session 1 as train, and session 2 as
    # test. See dataset_structure.ipynb)
    epochs_data = epochs.get_data(copy=False)
    epochs_data_train = epochs_train.get_data(copy=False)

    # Cross validation (CV) - number of splits and test size are potential
    # hyper-parameters.
    cv = ShuffleSplit(10, test_size=0.2, random_state=42)
    cv_split = cv.split(epochs_data_train)

    # Linear discriminant analysis (LDA) and Common Spatial Pattern (CSP)
    lda = LinearDiscriminantAnalysis()  # classifier
    csp = CSP(n_components=5, reg=None, log=True, norm_trace=False)  # n_components: potential hyper-parameter

    # Hyper-parameter: 'highest_prob' or 'two_highest_difference'
    confidence_type = 'highest_prob'

    # Share of the most frequent class, reported later as the chance level.
    class_balance = np.max([np.mean(labels == i) for i in range(4)])

    class_names = {
            1: "Left hand",
            2: "Right hand",
            3: "Both feet",
            4: "Tongue"
    }

    sfreq = 250  # Sampling frequency of 250 Hz as per the BCI competition dataset 2a

    # Sliding-window settings (both are hyper-parameters).
    w_length = int(sfreq * 0.5)  # window length in samples
    w_step = int(sfreq * 0.5)  # window step in samples

    # Start sample of every window, and the time at each window's centre.
    w_start = np.arange(0, epochs_data.shape[2] - w_length, w_step)
    if len(w_start) == 0:
        raise ValueError("Epochs are too short for the chosen window length")
    w_times = (w_start + w_length / 2.0) / sfreq + epochs.tmin

    scores_cv_splits = []
    predict_time_cv_splits = []
    currentcv = 0

    # Running classification across the signal
    for train_idx, test_idx in cv_split:
        currentcv += 1
        y_train, y_test = labels[train_idx], labels[test_idx]

        # Fit the spatial filters and the classifier on the full-length
        # training epochs.
        X_train = csp.fit_transform(epochs_data_train[train_idx], y_train)
        lda.fit(X_train, y_train)

        # CSP features of every test epoch for every window. Computed once
        # per split instead of once per (epoch, window) pair; the values are
        # the same because the fitted transform does not change in between.
        test_data = epochs_data[test_idx]
        window_features = [
            csp.transform(test_data[:, :, n:(n + w_length)]) for n in w_start
        ]

        scores_across_epochs = []
        predict_time_across_epochs = []

        for epoch_idx in range(len(test_idx)):  # for each test epoch
            current_epoch = epoch_idx + 1
            predict = False
            numTimesBelowThreshold = 0
            confidences_across_windows = []

            for current_window, (n, features) in enumerate(zip(w_start, window_features), start=1):
                print(f" Subject {current_person} CV {currentcv}, epoch: {current_epoch}, and window:{current_window}")

                X_test_epoch_window = features[epoch_idx]
                probabilities = lda.predict_proba(X_test_epoch_window.reshape(1, -1)).flatten()

                # predict becomes True once the stopping rule is satisfied
                predict, confidence, numTimesBelowThreshold = early_pred(probabilities, predict, numTimesBelowThreshold, patience, confidence_type)
                confidences_across_windows.append(confidence)
                if predict:
                    print("early prediction")
                    break  # predicting early
            else:
                # End of the signal reached: the last window is used.
                print("no early prediction in this epoch, numTimesBelowThreshold:", numTimesBelowThreshold)

            # `n` and `X_test_epoch_window` now refer to the window at which
            # the loop stopped (early stop) or to the last window.
            score = lda.score(X_test_epoch_window.reshape(1, -1), [y_test[epoch_idx]])
            predict_time = (n + w_length / 2.0) / sfreq + epochs.tmin
            scores_across_epochs.append(score)
            predict_time_across_epochs.append(predict_time)

        scores_cv_splits.append(scores_across_epochs)
        predict_time_cv_splits.append(predict_time_across_epochs)

    # One row per CV split, one column per test-epoch position.
    scores_cv_splits = np.array(scores_cv_splits)
    predict_time_cv_splits = np.array(predict_time_cv_splits)

    mean_scores_across_cv = np.mean(scores_cv_splits, axis=0)
    mean_predict_time_across_cv = np.mean(predict_time_cv_splits, axis=0)

    # Stack the per-subject means. The previous code rebuilt the prediction
    # time matrix from this subject's CV matrix, which dropped all earlier
    # subjects; both accumulators now grow the same way.
    subject_mean_scores.append(mean_scores_across_cv)
    subject_mean_predict_times.append(mean_predict_time_across_cv)
    scores_across_subjects = np.vstack(subject_mean_scores)
    prediction_time_across_subjects = np.vstack(subject_mean_predict_times)

# Accuracy: average over the first axis of the subject results, then over
# what remains.
mean_scores_across_subjects = np.asarray(scores_across_subjects).mean(axis=0)
accuracy = mean_scores_across_subjects.mean()
print(f"Classification accuracy: {accuracy:f} / Chance level: {class_balance:f}")

# Prediction time, averaged the same way. Both it and the full time are
# shifted by 2 before being compared.
mean_prediction_time_across_subjects = np.asarray(prediction_time_across_subjects).mean(axis=0)
mean_prediction_time = mean_prediction_time_across_subjects.mean() - 2
max_time = epochs.tmax - 2
time_fraction = mean_prediction_time / max_time
print(f"Mean prediction time: {mean_prediction_time:f} / full time: {max_time:f} /  Percentage of time: {time_fraction:f}")

''''
plt.plot(len(test_idx), mean_scores_across_cv, label="Score")
plt.xlabel("Time (s)")
plt.ylabel("Accuracy")
plt.axvline(2, linestyle="--", color="k", label="Onset")
plt.axhline(class_balance, linestyle="-", color="k", label="Chance")
plt.title("Classification accuracy over Time")
plt.legend()
plt.show()
'''
#print("scores_cv_splits (10X116?) shape:",scores_across_epochs
#print("predict_time_cv_splits (10X116?) shape: ", predict_time_across_epochs).shape

'''
To do - dynamic stop:
- sliding
    - make the for loops work and contain and provide the mean predict time and score correctly X
    - then advance to for all subjects X
    - modularize
    - then start the hyperparameter tuning to maximize classification accuracy, and minimize predict_time 
    - then loop across all threshold values 
    - make it take into account all the subjects
    - make it work using svm
    - provide the plots for all the subjects for each condition, let it just save the plots to a folder automatically (potentially also the values to make plots somewhere else)
    - nb: watch the memory and time usage for codespaces
- expanding:
    - make a new file, adjust to use expanding window
    - save its plots into another folder automatically
            
To do - static:
- make a new file and adjust the dynamic to just use a specific predict times using the cost function
- save the plots

to do - whole:
- already did that lol

'''



'''
print("len w-times: ", w_times)
print("len w-start: ", w_start)
print("len w-times[:numberOfNs]: ", w_times[:numberOfNs])
print("len w-start[:numberOfNs]: ", w_start[:numberOfNs])
print("number n's ", numberOfNs)
'''

''' ##Costs for each of the classes for each window
plt.plot(w_times, confidences, label='Cost')
plt.xlabel("Time (s)")
plt.ylabel("Cost")
plt.axvline(w_times[predict_time], linestyle="-", color="b", label="Stopping")
plt.axvline(2, linestyle="--", color="k", label="Onset")
plt.axhline(0.5, linestyle="-", color="k", label="Threshold")
plt.title("Cost over Time")
plt.legend()
plt.show()


##Accuracy for each window
plt.plot(w_times, score_this_window, label="Score")
plt.xlabel("Time (s)")
plt.ylabel("Accuracy")
plt.axvline(w_times[predict_time], linestyle="-", color="b", label="Stopping")
plt.axvline(2, linestyle="--", color="k", label="Onset")
plt.axhline(class_balance, linestyle="-", color="k", label="Chance")
plt.title("Classification accuracy over Time")
plt.legend()
plt.show()

plt.plot(w_times, confidence_this_window, label="Score")
plt.xlabel("Time (s)")
plt.ylabel("Confidence")
plt.ylim(0,1)
plt.axvline(w_times[predict_time], linestyle="-", color="b", label="Stopping")
plt.axvline(2, linestyle="--", color="k", label="Onset")
#plt.axhline(threshold, linestyle="-", color="k", label="Threshold")
plt.title("Model confindence over Time")
plt.legend()
plt.show()'''

'''
        
    #sanity check
    for epoch_idx in range(len(test_idx)): 
        current_n = 0
        current_epoch+=1
        predict = False 
        numTimesBelowThreshold = 0
        confidences_across_windows_full =[]
        probs_across_windows = []
        ##Earl pred
        for n in w_start:
            X_test_window = csp.transform(epochs_data[test_idx][:, :, n:(n + w_length)])
            X_test_epoch_window = X_test_window[epoch_idx+1]

            #Early prediction
            probabilities = lda.predict_proba([X_test_epoch_window])

            if len(probs_across_windows) == 0:
               probs_across_windows = probabilities
            else:
                probs_across_windows = np.vstack((probs_across_windows, probabilities))

            probabilities = np.array(probabilities)
            probabilities = probabilities.flatten()
            # predict becomes true to predict ealrly then go to the next epoch
            predict, confidence, numTimesBelowThreshold= early_pred(probabilities, predict, numTimesBelowThreshold, patience)
            confidences_across_windows_full.append(confidence)
            score = lda.score(X_test_epoch_window.reshape(1, -1), [y_test[epoch_idx+1]])
        plt.plot(w_times, confidences_across_windows_full, label='confidences_across_windows_full')
        plt.xlabel("Time (s)")
        plt.ylabel("confidence")
        plt.axvline(predict_time, linestyle="-", color="b", label="Stopping")
        plt.axvline(2, linestyle="--", color="k", label="Onset")
        plt.axhline(0.5, linestyle="-", color="k", label="Threshold")
        plt.title("Cost over Time")
        plt.legend()
        plt.show()
        y_test = y_test+4
        print("right label:", class_names[y_test[epoch_idx+1]])
            ##Probabiltiies for each of the classes for each window
        plt.plot(w_times, probs_across_windows, label=[class_names[label] for label in [1, 2, 3, 4]])
        plt.xlabel("Time (s)")
        plt.ylabel("Probabilities")
        plt.axvline(predict_time, linestyle="-", color="b", label="Stopping")
        plt.axvline(2, linestyle="--", color="k", label="Onset")
        plt.axhline(0.5, linestyle="-", color="k", label="Threshold")
        plt.title("Classification probabilities over Time")
        plt.legend()
        plt.show()
        break
    break
'''

ROOT: /workspaces/UQ_Early_prediction_MI_BCI


/workspaces/UQ_Early_prediction_MI_BCI/.conda/lib/python3.11/site-packages/moabb/pipelines/__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(
Choosing from all possible events


To use the get_shape_from_baseconcar, InputShapeSetterEEG, BraindecodeDatasetLoaderyou need to install `braindecode`.`pip install braindecode` or Please refer to `https://braindecode.org`.


KeyboardInterrupt: 