In [1]:
import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedGroupKFold, cross_validate


In [2]:
data = pd.read_csv("data/processed/windows_21_10_balanced_avg_energy.csv")
data

Unnamed: 0,EEG FP1-REF_a5_mean,EEG FP1-REF_a5_median,EEG FP1-REF_a5_variance,EEG FP1-REF_a5_std,EEG FP1-REF_a5_skew,EEG FP1-REF_a5_kurtosis,EEG FP1-REF_a5_rms,EEG FP1-REF_a5_energy,EEG FP1-REF_d5_mean,EEG FP1-REF_d5_median,...,theta,alpha,beta,gamma,alpha_beta_ratio,theta_beta_ratio,theta_alpha_beta_ratio,theta_alpha_beta_alpha_ratio,alpha_theta_ratio,theta_alpha_ratio
0,2.981401,3.060055,0.534266,0.730935,-0.252371,0.558448,3.069693,1601.912705,-0.001281,-0.012881,...,0.001994,0.000961,0.000460,0.000073,2.086576,4.331259,6.417835,2.079273,0.481748,2.075773
1,2.807923,2.814208,0.533092,0.730132,-0.339051,0.866265,2.901297,1430.979319,-0.003017,0.012290,...,0.001662,0.000641,0.000462,0.000066,1.387719,3.597425,4.985144,2.087827,0.385753,2.592329
2,2.709073,2.715860,0.386459,0.621658,0.188169,0.799525,2.779484,1313.340551,-0.007780,0.002819,...,0.002669,0.001337,0.000621,0.000090,2.152151,4.294149,6.446300,2.045048,0.501182,1.995283
3,2.201688,2.199628,0.380018,0.616456,0.500788,1.757253,2.286361,888.665778,0.017664,-0.000250,...,0.002148,0.000889,0.000514,0.000069,1.727822,4.176825,5.904647,2.164601,0.413669,2.417393
4,2.398083,2.438879,0.391035,0.625328,0.065136,0.587532,2.478273,1044.112085,-0.005077,-0.007493,...,0.003299,0.001523,0.000669,0.000082,2.275867,4.929678,7.205544,2.199584,0.461666,2.166066
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24837,2.806724,2.846410,0.290900,0.539351,0.293917,2.919968,2.858076,1388.661699,0.025440,0.019509,...,0.003056,0.002869,0.001686,0.000305,1.702429,1.813357,3.515786,1.300973,0.938827,1.065159
24838,2.874045,2.884819,0.346355,0.588519,0.131109,-0.333775,2.933682,1463.103063,0.011056,0.005731,...,0.003490,0.005281,0.002356,0.000194,2.240951,1.481144,3.722095,1.148458,1.512986,0.660945
24839,2.658772,2.628194,0.407355,0.638244,1.431019,5.818943,2.734305,1270.992304,0.042588,0.013024,...,0.003181,0.002110,0.001204,0.000408,1.752798,2.642184,4.394982,1.596551,0.663390,1.507409
24840,2.971984,2.975812,0.161001,0.401249,-0.047986,1.678566,2.998948,1528.927629,0.023938,0.002973,...,0.000637,0.001391,0.003218,0.002610,0.432168,0.197878,0.630046,0.439925,2.184012,0.457873


In [3]:
X = data.drop(columns=["label", "patient_id"])
y = data["label"].map({'bckg': 0, 'seiz': 1})
groups = data["patient_id"]

In [4]:
y_frame = y.to_frame()
y_frame[y_frame["label"] == 0].count(), y_frame[y_frame["label"] == 1].count()

(label    12421
 dtype: int64,
 label    12421
 dtype: int64)

In [5]:
def remove_highly_correlated_features(X, threshold=0.95):
    # Calculate the correlation matrix
    corr_matrix = X.corr().abs()
    
    # Select the upper triangle of the correlation matrix
    upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
    
    # Find features with correlation greater than the threshold
    to_drop = [column for column in upper.columns if any(upper[column] > threshold)]
    
    # Drop highly correlated features
    X_reduced = X.drop(columns=to_drop)
    
    return X_reduced, to_drop

X, _ = remove_highly_correlated_features(X, 0.95)
X

Unnamed: 0,EEG FP1-REF_a5_mean,EEG FP1-REF_a5_variance,EEG FP1-REF_a5_skew,EEG FP1-REF_a5_kurtosis,EEG FP1-REF_d5_mean,EEG FP1-REF_d5_median,EEG FP1-REF_d5_variance,EEG FP1-REF_d5_skew,EEG FP1-REF_d5_kurtosis,EEG FP1-REF_d4_mean,...,theta,alpha,beta,gamma,alpha_beta_ratio,theta_beta_ratio,theta_alpha_beta_ratio,theta_alpha_beta_alpha_ratio,alpha_theta_ratio,theta_alpha_ratio
0,2.981401,0.534266,-0.252371,0.558448,-0.001281,-0.012881,0.051580,0.006228,3.536087,0.003456,...,0.001994,0.000961,0.000460,0.000073,2.086576,4.331259,6.417835,2.079273,0.481748,2.075773
1,2.807923,0.533092,-0.339051,0.866265,-0.003017,0.012290,0.077742,0.274656,3.802841,-0.000826,...,0.001662,0.000641,0.000462,0.000066,1.387719,3.597425,4.985144,2.087827,0.385753,2.592329
2,2.709073,0.386459,0.188169,0.799525,-0.007780,0.002819,0.070420,-0.251142,3.234767,-0.000575,...,0.002669,0.001337,0.000621,0.000090,2.152151,4.294149,6.446300,2.045048,0.501182,1.995283
3,2.201688,0.380018,0.500788,1.757253,0.017664,-0.000250,0.053857,0.539557,2.679757,-0.000332,...,0.002148,0.000889,0.000514,0.000069,1.727822,4.176825,5.904647,2.164601,0.413669,2.417393
4,2.398083,0.391035,0.065136,0.587532,-0.005077,-0.007493,0.111201,0.127460,2.854489,0.004078,...,0.003299,0.001523,0.000669,0.000082,2.275867,4.929678,7.205544,2.199584,0.461666,2.166066
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24837,2.806724,0.290900,0.293917,2.919968,0.025440,0.019509,0.097798,0.245740,1.577897,0.002293,...,0.003056,0.002869,0.001686,0.000305,1.702429,1.813357,3.515786,1.300973,0.938827,1.065159
24838,2.874045,0.346355,0.131109,-0.333775,0.011056,0.005731,0.137951,-0.276828,0.961426,-0.007523,...,0.003490,0.005281,0.002356,0.000194,2.240951,1.481144,3.722095,1.148458,1.512986,0.660945
24839,2.658772,0.407355,1.431019,5.818943,0.042588,0.013024,0.100872,1.858001,9.379107,0.001540,...,0.003181,0.002110,0.001204,0.000408,1.752798,2.642184,4.394982,1.596551,0.663390,1.507409
24840,2.971984,0.161001,-0.047986,1.678566,0.023938,0.002973,0.028034,0.164423,0.457677,-0.036690,...,0.000637,0.001391,0.003218,0.002610,0.432168,0.197878,0.630046,0.439925,2.184012,0.457873


In [10]:
cv = StratifiedGroupKFold(n_splits=5)

scoring = {
    'accuracy': make_scorer(accuracy_score),
    'precision': make_scorer(precision_score, average='weighted'),
    'recall': make_scorer(recall_score, average='weighted'),
    'f1': make_scorer(f1_score, average='weighted'),
    'roc_auc': make_scorer(roc_auc_score, average='weighted', multi_class='ovr')
}

def cross_validate_pipeline(pipeline):
    results = cross_validate(pipeline, X, y, groups=groups, cv=cv, scoring=scoring, return_train_score=False)
    avg_results = {metric: np.mean(values) for metric, values in results.items()}

    return results, avg_results

In [11]:
rf_pipeline = Pipeline([
    ('scale', StandardScaler()),
    ('rf', RandomForestClassifier(n_estimators=400, random_state=42))
])

rf_results, rf_avg_results = cross_validate_pipeline(rf_pipeline)
rf_results, rf_avg_results

({'fit_time': array([216.16130328, 215.42270494, 220.80689049, 224.71254206,
         223.94002724]),
  'score_time': array([0.42522621, 0.43571401, 0.42062736, 0.4227562 , 0.43347239]),
  'test_accuracy': array([0.70177866, 0.73242454, 0.75900672, 0.69575957, 0.76374077]),
  'test_precision': array([0.71947987, 0.73350158, 0.76647439, 0.71575503, 0.77667472]),
  'test_recall': array([0.70177866, 0.73242454, 0.75900672, 0.69575957, 0.76374077]),
  'test_f1': array([0.69785213, 0.73201864, 0.75705641, 0.68753306, 0.7606742 ]),
  'test_roc_auc': array([0.70588863, 0.7321832 , 0.75827647, 0.69378576, 0.76291928])},
 {'fit_time': 220.20869359970092,
  'score_time': 0.42755923271179197,
  'test_accuracy': 0.7305420506904095,
  'test_precision': 0.7423771178844463,
  'test_recall': 0.7305420506904095,
  'test_f1': 0.7270268876570205,
  'test_roc_auc': 0.7306106680134278})

In [8]:
svm_pipeline = Pipeline([
     ('scale', StandardScaler()),
     ('svm', SVC(kernel='rbf', C=1, gamma=0.001, probability=True, random_state=42))
])

svm_results, svm_avg_results = cross_validate_pipeline(svm_pipeline) 
svm_results, svm_avg_results

({'fit_time': array([92.16159606, 91.01945186]),
  'score_time': array([26.7877748 , 26.85553145]),
  'test_accuracy': array([0.65510582, 0.71703584]),
  'test_precision': array([0.68149959, 0.72362189]),
  'test_recall': array([0.65510582, 0.71703584]),
  'test_f1': array([0.64166014, 0.71506641]),
  'test_roc_auc': array([0.65446859, 0.71731286])},
 {'fit_time': 91.59052395820618,
  'score_time': 26.821653127670288,
  'test_accuracy': 0.6860708308572001,
  'test_precision': 0.7025607413649252,
  'test_recall': 0.6860708308572001,
  'test_f1': 0.6783632737726832,
  'test_roc_auc': 0.685890728752975})

In [9]:
knn_pipeline = Pipeline([
     ('scale', StandardScaler()),
     ('knn', KNeighborsClassifier(n_neighbors=10))
])

knn_results, knn_avg_results = cross_validate_pipeline(knn_pipeline)
knn_results, knn_avg_results

({'fit_time': array([0.10662866, 0.10548067]),
  'score_time': array([0.69288206, 0.70506382]),
  'test_accuracy': array([0.59837451, 0.64115989]),
  'test_precision': array([0.63549262, 0.65460999]),
  'test_recall': array([0.59837451, 0.64115989]),
  'test_f1': array([0.56792637, 0.63355966]),
  'test_roc_auc': array([0.59750011, 0.64163718])},
 {'fit_time': 0.10605466365814209,
  'score_time': 0.6989729404449463,
  'test_accuracy': 0.6197671971773879,
  'test_precision': 0.6450513075445099,
  'test_recall': 0.6197671971773879,
  'test_f1': 0.6007430160404523,
  'test_roc_auc': 0.6195686457433487})