# In [None]:
import mne
import numpy as np
from scipy.signal import butter, filtfilt
from mne.decoding import CSP
from sklearn.neighbors import KNeighborsClassifier  # Import KNN classifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools
 
# Define the bandpass filter function
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter ``data`` along its last axis with a Butterworth filter.

    The filter is applied forward and backward (``filtfilt``), so the output
    has no phase shift relative to the input.

    Args:
        data: Array-like signal; filtering runs along the last axis.
        lowcut: Lower edge of the pass band, in the same unit as ``fs``.
        highcut: Upper edge of the pass band, in the same unit as ``fs``.
        fs: Sampling frequency of ``data``.
        order: Order passed to the Butterworth design.

    Returns:
        The filtered signal, with the same shape as ``data``.
    """
    # Band edges are expressed as fractions of the Nyquist frequency.
    half_rate = 0.5 * fs
    band_edges = [lowcut / half_rate, highcut / half_rate]

    numerator, denominator = butter(order, band_edges, btype='band')
    return filtfilt(numerator, denominator, data, axis=-1)
 
def _split_with_both_classes(X, y, test_size=0.3, random_state=120, max_attempts=50):
    """Split into train/test sets whose training part holds at least two classes.

    The first attempt uses ``random_state`` itself, so the usual result is the
    plain ``train_test_split`` outcome for that seed. Only when that training
    set contains a single class is the split retried, each time with the next
    seed, for at most ``max_attempts`` attempts in total.

    Returns:
        The ``[X_train, X_test, y_train, y_test]`` list from ``train_test_split``.

    Raises:
        ValueError: If ``y`` itself has fewer than two classes, or no attempt
            produced a training set with two classes.
    """
    if len(np.unique(y)) >= 2:
        for attempt in range(max_attempts):
            # Re-using one seed would reproduce the same split every time, so
            # each retry must move on to a different seed.
            split = train_test_split(
                X, y, test_size=test_size, random_state=random_state + attempt
            )
            if len(np.unique(split[2])) >= 2:
                return split
    raise ValueError("Could not find at least two classes in the training set.")


def _plot_confusion_matrix(cm, class_names):
    """Draw ``cm`` as an annotated heat map with one tick per entry of ``class_names``."""
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names)
    plt.yticks(tick_marks, class_names)

    # Switch the text colour on dark cells so the counts stay readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()


def classification(file_paths):
    """Train and evaluate a CSP + KNN classifier on pooled .gdf recordings.

    For every file the EOG channels are dropped, epochs from 3 s to 7 s
    relative to the '769' (left hand) and '770' (right hand) annotations are
    extracted and band-pass filtered to 8-30 Hz. The epochs of all files are
    pooled, split 70/30 into train and test sets, reduced to 4 CSP features,
    standardised and classified with a 5-nearest-neighbour model. A confusion
    matrix of the test predictions is displayed.

    Labels are 0 for left hand and 1 for right hand in every file.

    Side effects:
        Rebinds the module-level names ``X_combined``, ``y_combined``, ``raw``,
        ``X_train``, ``X_test``, ``X`` and ``y`` (``raw``, ``X`` and ``y``
        refer to the last file processed) and shows a matplotlib figure.

    Args:
        file_paths: Iterable of paths to .gdf files to pool.

    Returns:
        Accuracy of the classifier on the held-out test set.

    Raises:
        KeyError: If a file has no '769' or '770' annotation.
        ValueError: If ``file_paths`` is empty, or no train/test split with
            two classes in the training set could be found.
    """
    # Kept at module level so the intermediate data stays available for
    # interactive inspection after the call.
    global X_combined, y_combined, raw, X_train, X_test, X, y

    class_names = ['left_hand', 'right_hand']
    class_labels = [0, 1]
    eog_channels = ['EOG:ch01', 'EOG:ch02', 'EOG:ch03']

    filtered_per_file = []
    labels_per_file = []

    for file_path in file_paths:
        # Load the recording and discard the EOG channels.
        raw = mne.io.read_raw_gdf(file_path, preload=True, eog=eog_channels)
        raw.drop_channels(eog_channels)

        events, annotation_codes = mne.events_from_annotations(raw)
        event_id = {
            'left_hand': annotation_codes['769'],
            'right_hand': annotation_codes['770'],
        }

        epochs = mne.Epochs(raw, events, event_id, tmin=3, tmax=7, baseline=None, preload=True)

        X = epochs.get_data()  # EEG signals: (n_epochs, n_channels, n_times)

        # The integer code given to an annotation is assigned per file and
        # depends on which annotations that file contains, so translate it to
        # labels that mean the same thing in every file before pooling.
        y = np.where(
            epochs.events[:, -1] == event_id['left_hand'],
            class_labels[0],
            class_labels[1],
        )

        # The filter runs along the last (time) axis, so all epochs of the
        # file are filtered in a single call.
        X_filtered = butter_bandpass_filter(X, 8, 30, raw.info['sfreq'])

        filtered_per_file.append(X_filtered)
        labels_per_file.append(y)

    if not filtered_per_file:
        raise ValueError("No input files were given.")

    # Pool the epochs and labels of all files.
    X_combined = np.concatenate(filtered_per_file, axis=0)
    y_combined = np.concatenate(labels_per_file, axis=0)

    X_train, X_test, y_train, y_test = _split_with_both_classes(
        X_combined, y_combined, test_size=0.3, random_state=120
    )

    # Spatial filtering: fit CSP on the training epochs only.
    csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)
    csp.fit(X_train, y_train)
    X_train_csp = csp.transform(X_train)
    X_test_csp = csp.transform(X_test)

    # Standardise with statistics taken from the training features.
    scaler = StandardScaler()
    X_train_csp = scaler.fit_transform(X_train_csp)
    X_test_csp = scaler.transform(X_test_csp)

    knn = KNeighborsClassifier(n_neighbors=5)  # You can adjust the number of neighbors
    knn.fit(X_train_csp, y_train)

    accuracy = knn.score(X_test_csp, y_test)

    # Fixing the label order keeps the matrix 2x2 and aligned with the tick
    # names even when a class is missing from the test set.
    predictions = knn.predict(X_test_csp)
    cm = confusion_matrix(y_test, predictions, labels=class_labels)
    _plot_confusion_matrix(cm, class_names)

    return accuracy
 
# Recordings to evaluate, keyed by patient; the files listed under one key
# are handed to classification() together.
allData = {
    "patient_1": [
        'data/B0401T.gdf',
        'data/B0402T.gdf',
        'data/B0403T.gdf',
    ],
}

# One accuracy per entry of allData, in dictionary order.
accuracy_list = []
for data_files in allData.values():
    accuracy_list += [classification(data_files)]

print("Accuracy List:", accuracy_list)
 