# In [None]:
from glob import glob
import os
import mne
import numpy as np
from sklearn.model_selection import GroupKFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline  # Add this import
from sklearn.metrics import accuracy_score

def train_eeg_classifier(data_path, *, n_splits=5, n_jobs=12):
    """Train a random-forest classifier separating healthy from patient EEG.

    Every ``*.edf`` file directly inside *data_path* is one recording.
    Recordings whose file name starts with ``'h'`` are labelled healthy (0),
    those starting with ``'s'`` are labelled patient (1); any other file is
    ignored. Each recording is re-referenced with ``set_eeg_reference()``,
    band-pass filtered to 0.5-45 Hz and cut into 5 s epochs overlapping by
    1 s. One feature vector is computed per epoch, and hyper-parameters are
    tuned by grid search with ``GroupKFold``: every recording is its own
    group, so epochs of one recording never land in both a training and a
    validation fold.

    Args:
        data_path: Directory containing the ``.edf`` recordings.
        n_splits: Number of ``GroupKFold`` folds (needs at least that many
            recordings in total).
        n_jobs: Number of parallel jobs passed to ``GridSearchCV``.

    Returns:
        The fitted ``GridSearchCV`` object.

    Raises:
        ValueError: If no healthy or no patient recording is found.
    """
    # Sorted so that group indices (and hence the CV folds) do not depend on
    # the filesystem's directory-listing order.
    all_file_paths = sorted(glob(os.path.join(data_path, '*.edf')))

    # Classify on the file's own name, independent of the OS path separator
    # and of how deep data_path is; a prefix test keeps the classes disjoint.
    healthy_file_paths = [
        p for p in all_file_paths if os.path.basename(p).startswith('h')
    ]
    patient_file_paths = [
        p for p in all_file_paths if os.path.basename(p).startswith('s')
    ]
    if not healthy_file_paths or not patient_file_paths:
        raise ValueError(
            f"Expected both healthy ('h*.edf') and patient ('s*.edf') "
            f"recordings in {data_path!r}; found {len(healthy_file_paths)} "
            f"healthy and {len(patient_file_paths)} patient file(s)."
        )

    def read_data(file_path):
        """Load one EDF recording, preprocess it and return its epochs array."""
        data = mne.io.read_raw_edf(file_path, preload=True)
        data.set_eeg_reference()
        data.filter(l_freq=0.5, h_freq=45)
        epochs = mne.make_fixed_length_epochs(data, duration=5, overlap=1)
        # Epochs.get_data() returns (n_epochs, n_channels, n_times).
        return epochs.get_data()

    control_epochs_array = [read_data(p) for p in healthy_file_paths]
    patient_epochs_array = [read_data(p) for p in patient_file_paths]

    control_epoch_labels = [np.zeros(len(a)) for a in control_epochs_array]
    patient_epoch_labels = [np.ones(len(a)) for a in patient_epochs_array]

    data_list = control_epochs_array + patient_epochs_array
    label_list = control_epoch_labels + patient_epoch_labels
    # One group id per recording, repeated for each of its epochs.
    group_list = [[i] * len(arr) for i, arr in enumerate(data_list)]

    data_array = np.vstack(data_list)
    label_array = np.hstack(label_list)
    group_array = np.hstack(group_list)

    def extract_features(x):
        """Return the feature vector for one epoch ``x`` (channels x samples).

        Placeholder: the per-channel mean over time. NOTE(review): after the
        0.5 Hz high-pass the per-channel means are expected to be near zero,
        so this feature likely carries little information. Replace it with
        e.g. variance or band-power features.
        """
        return np.mean(x, axis=-1)

    features_array = np.array([extract_features(epoch) for epoch in data_array])

    # n_estimators / max_depth here are overridden by the grid below.
    clf = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)
    pipe = Pipeline([('scaler', StandardScaler()), ('clf', clf)])

    param_grid = {
        'clf__n_estimators': [50, 100, 200],
        'clf__max_depth': [None, 10, 20],
    }

    gscv = GridSearchCV(
        pipe, param_grid, cv=GroupKFold(n_splits=n_splits), n_jobs=n_jobs
    )
    gscv.fit(features_array, label_array, groups=group_array)

    return gscv

if __name__ == "__main__":
    # Script entry point: tune the classifier on the local 'sampledb'
    # directory and report the best mean cross-validation score found.
    trained_classifier = train_eeg_classifier('sampledb')
    best_score = trained_classifier.best_score_
    print("Best Cross-Validation Score:", best_score)