# In [None]:
import mne
import numpy as np
from scipy.signal import butter, filtfilt
from mne.decoding import CSP
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Define the bandpass filter functions
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass ``data`` with a Butterworth filter along its last axis.

    The filter is designed with ``scipy.signal.butter`` and applied with
    ``scipy.signal.filtfilt`` (one forward and one backward pass).

    Args:
        data: Array-like signal; filtering runs over the last axis.
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        fs: Sampling frequency of ``data``.
        order: Order passed to ``butter`` (default 5).

    Returns:
        The filtered signal, same shape as ``data``.
    """
    # butter() expects band edges as fractions of the Nyquist frequency.
    half_rate = fs * 0.5
    band_edges = [lowcut / half_rate, highcut / half_rate]
    numerator, denominator = butter(order, band_edges, btype='band')
    return filtfilt(numerator, denominator, data, axis=-1)

def _select_event_ids(annotation_ids):
    """Choose which annotation codes to epoch on.

    Args:
        annotation_ids: Mapping of annotation description (str) to integer
            event code, as returned by ``mne.events_from_annotations``.

    Returns:
        ``{'left_hand': code, 'right_hand': code}`` when a ``'769'``
        annotation is present, ``{'Cue': code}`` when ``'781'`` is present
        instead, or ``annotation_ids`` itself when neither is present.

    Raises:
        ValueError: If ``'769'`` is present but ``'770'`` is not.
    """
    if '769' in annotation_ids:
        # Read the codes from the mapping instead of inferring them from the
        # iteration position, and do not assume right hand is "left + 1".
        # NOTE(review): '770' is taken to be the right-hand cue (BCI
        # Competition IV 2b convention) - confirm against the dataset docs.
        if '770' not in annotation_ids:
            raise ValueError(
                "Found left-hand cue '769' but no right-hand cue '770' "
                "in the annotations."
            )
        return {
            'left_hand': annotation_ids['769'],
            'right_hand': annotation_ids['770'],
        }
    if '781' in annotation_ids:
        return {'Cue': annotation_ids['781']}
    return annotation_ids


def _split_with_two_classes(X, y, test_size=0.3, seed=120, max_attempts=100):
    """Split into train/test so that the training labels hold >= 2 classes.

    The first attempt uses ``seed``; each retry uses the next integer seed so
    that a retry can actually produce a different split.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` from ``train_test_split``.

    Raises:
        ValueError: If ``y`` has fewer than two classes, or no attempt
            produced a training set with two classes.
    """
    if len(np.unique(y)) >= 2:
        for attempt in range(max_attempts):
            parts = train_test_split(
                X, y, test_size=test_size, random_state=seed + attempt
            )
            if len(np.unique(parts[2])) >= 2:
                return parts
    raise ValueError("Could not find at least two classes in the training set.")


def classification(file_paths):
    """Train and score a CSP + k-NN classifier on epochs pooled from GDF files.

    For every file: load it, drop the three EOG channels, epoch from 3 s to
    7 s after each selected event, and band-pass the epochs to 8-30 Hz. The
    pooled epochs are split 70/30, CSP (4 components) is fitted on the
    training part only, features are standardised, and a 5-nearest-neighbour
    classifier is scored on the held-out part.

    Args:
        file_paths: Iterable of paths to ``.gdf`` recordings to pool.

    Returns:
        Accuracy of the classifier on the held-out 30 % of the epochs.

    Raises:
        ValueError: If ``file_paths`` is empty, if a file has a ``'769'``
            annotation without ``'770'``, or if no split yields two classes
            in the training set.
    """
    # These names stay module-global because the original cell exposed them
    # for interactive inspection after the call; ``count`` is not used here.
    global X_combined, y_combined, raw, X_train, X_test, X, y, count
    X_combined = []
    y_combined = []
    count = 1
    eog_channels = ['EOG:ch01', 'EOG:ch02', 'EOG:ch03']

    for file_path in file_paths:
        # Load the recording and keep only the non-EOG channels.
        raw = mne.io.read_raw_gdf(file_path, preload=True, eog=eog_channels)
        raw.drop_channels(eog_channels)

        events, annotation_ids = mne.events_from_annotations(raw)

        # NOTE(review): epochs selected through 'Cue' all share one label, and
        # event codes are assigned per file, so that label may coincide with
        # a left/right code from another file - verify pooling is intended.
        event_id = _select_event_ids(annotation_ids)
        if event_id is not annotation_ids:
            print('event_iddddd : ', event_id)

        epochs = mne.Epochs(raw, events, event_id, tmin=3, tmax=7, baseline=None, preload=True)
        print(epochs)

        X = epochs.get_data()  # EEG signals: (n_epochs, n_channels, n_times)
        y = epochs.events[:, -1]  # Labels

        # The filter runs along the last (time) axis, so the whole epoch
        # array can be filtered in one call instead of epoch by epoch.
        X_filtered = butter_bandpass_filter(X, 8, 30, raw.info['sfreq'])

        X_combined.append(X_filtered)
        y_combined.append(y)

    if not X_combined:
        raise ValueError("file_paths must contain at least one GDF file.")

    X_combined = np.concatenate(X_combined, axis=0)
    y_combined = np.concatenate(y_combined, axis=0)

    # Bounded retry with changing seeds; re-splitting with one fixed seed
    # would reproduce the same split forever.
    X_train, X_test, y_train, y_test = _split_with_two_classes(X_combined, y_combined)

    # Fit CSP on the training data only, then project both parts.
    csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)
    csp.fit(X_train, y_train)
    X_train_csp = csp.transform(X_train)
    X_test_csp = csp.transform(X_test)

    # Scaling statistics also come from the training part only.
    scaler = StandardScaler()
    X_train_csp = scaler.fit_transform(X_train_csp)
    X_test_csp = scaler.transform(X_test_csp)

    knn = KNeighborsClassifier(n_neighbors=5)  # You can adjust the number of neighbors
    knn.fit(X_train_csp, y_train)

    accuracy = knn.score(X_test_csp, y_test)
    return accuracy

# Recordings to pool for each patient; insertion order is the bar order below.
allData = {
    "patient_1": ['data/B0101T.gdf', 'data/B0102T.gdf', 'data/B0103T.gdf', 'data/B0104E.gdf', 'data/B0105E.gdf'],
    "patient_2": ['data/B0201T.gdf', 'data/B0202T.gdf', 'data/B0203T.gdf', 'data/B0204E.gdf', 'data/B0205E.gdf'],
    "patient_3": ['data/B0301T.gdf', 'data/B0302T.gdf', 'data/B0303T.gdf', 'data/B0304E.gdf', 'data/B0305E.gdf'],
    "patient_4": ['data/B0401T.gdf', 'data/B0402T.gdf', 'data/B0403T.gdf', 'data/B0404E.gdf', 'data/B0405E.gdf'],
    "patient_5": ['data/B0501T.gdf', 'data/B0502T.gdf', 'data/B0503T.gdf', 'data/B0504E.gdf', 'data/B0505E.gdf'],
    "patient_6": ['data/B0601T.gdf', 'data/B0602T.gdf', 'data/B0603T.gdf', 'data/B0604E.gdf', 'data/B0605E.gdf'],
    "patient_7": ['data/B0701T.gdf', 'data/B0702T.gdf', 'data/B0703T.gdf', 'data/B0704E.gdf', 'data/B0705E.gdf'],
    "patient_8": ['data/B0801T.gdf', 'data/B0802T.gdf', 'data/B0803T.gdf', 'data/B0804E.gdf', 'data/B0805E.gdf'],
    "patient_9": ['data/B0901T.gdf', 'data/B0902T.gdf', 'data/B0903T.gdf', 'data/B0904E.gdf', 'data/B0905E.gdf']
}

# Score the patients one after another; p counts upward from 1 alongside.
p = 1
accuracy_list = []
for patient_files in allData.values():
    patient_accuracy = classification(patient_files)
    accuracy_list.append(patient_accuracy)
    p += 1

mean_accuracy = np.mean(accuracy_list)

# One bar per patient, with the mean drawn as a dashed red line.
plt.figure(figsize=(10, 5))
axes = plt.gca()
axes.bar(allData.keys(), accuracy_list)
axes.axhline(y=mean_accuracy, color='r', linestyle='--', label=f'Mean Accuracy: {mean_accuracy:.4f}')
axes.set_xlabel('Patients')
axes.set_ylabel('Accuracy')
axes.set_title('Classification Accuracy for Each Patient')

# Print each accuracy just above its bar.
for index, value in enumerate(accuracy_list):
    axes.text(index, value + 0.01, f'{value:.4f}', ha="center")

axes.legend(loc='center', bbox_to_anchor=(0.5, 0.5))
plt.show()
