In [47]:
import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedGroupKFold, cross_validate


In [48]:
data = pd.read_csv("data/processed/windows_21_10_balanced_all_powers.csv")
data

Unnamed: 0,EEG FP1-REF_a5_mean,EEG FP1-REF_a5_median,EEG FP1-REF_a5_variance,EEG FP1-REF_a5_std,EEG FP1-REF_a5_skew,EEG FP1-REF_a5_kurtosis,EEG FP1-REF_a5_rms,EEG FP1-REF_d5_mean,EEG FP1-REF_d5_median,EEG FP1-REF_d5_variance,...,EEG A2-REF_d3_kurtosis,EEG A2-REF_d3_rms,EEG A2-REF_delta,EEG A2-REF_theta,EEG A2-REF_alpha,EEG A2-REF_beta,EEG A2-REF_gamma,patient_id,asymmetry,label
0,1.718636,1.635653,0.616050,0.784889,2.946016,12.457156,1.889381,0.011480,0.010945,0.020342,...,70.424802,0.050666,0.010359,0.001685,0.000634,0.000310,0.000025,aaaaasgd,0.713307,seiz
1,2.494826,2.533521,0.765934,0.875177,-0.202479,-0.331677,2.643878,0.000950,0.010180,0.062455,...,2.000887,0.048877,0.020027,0.003068,0.000918,0.000516,0.000035,aaaaasgd,0.197368,seiz
2,2.659979,2.719828,0.745499,0.863423,-0.219236,-0.033053,2.796603,-0.004994,0.005068,0.073703,...,0.478473,0.045530,0.017030,0.003126,0.001091,0.000431,0.000034,aaaaasgd,-0.140446,seiz
3,2.779673,2.738822,0.558120,0.747074,0.107827,0.835777,2.878316,0.011395,0.002752,0.058272,...,0.550823,0.036712,0.011885,0.001920,0.000601,0.000296,0.000033,aaaaasgd,0.156857,seiz
4,2.796707,2.801029,0.166170,0.407640,0.225649,10.275338,2.826259,0.001994,0.000375,0.002450,...,6.102927,0.042822,0.006700,0.000556,0.000353,0.000309,0.000206,aaaaasgd,0.219149,seiz
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24837,2.087338,2.057246,0.300920,0.548562,1.018727,6.144104,2.158217,0.010415,-0.008610,0.032719,...,15.217666,0.024455,0.004541,0.000222,0.000124,0.000123,0.000044,aaaaanlx,0.242795,bckg
24838,2.212204,2.233945,0.432998,0.658026,0.339207,0.753278,2.307996,-0.005381,0.002326,0.038419,...,9.693751,0.033043,0.005203,0.000232,0.000138,0.000225,0.000034,aaaaanlx,0.086373,bckg
24839,3.291857,3.315987,0.493708,0.702643,-0.027024,-0.155202,3.366011,-0.004997,0.002195,0.041901,...,0.120727,0.067806,0.014783,0.004227,0.002204,0.001024,0.000213,aaaaapzi,0.054461,bckg
24840,2.723730,2.809229,0.674350,0.821188,-0.052223,-0.041784,2.844830,0.021821,0.004650,0.090626,...,0.579384,0.054626,0.017173,0.003892,0.001355,0.000835,0.000051,aaaaapzi,0.097686,bckg


In [49]:
X = data.drop(columns=["label", "patient_id"])
y = data["label"].map({'bckg': 0, 'seiz': 1})
groups = data["patient_id"]

In [50]:
y_frame = y.to_frame()
y_frame[y_frame["label"] == 0].count(), y_frame[y_frame["label"] == 1].count()

(label    12421
 dtype: int64,
 label    12421
 dtype: int64)

In [52]:
def remove_highly_correlated_features(X, threshold=0.95):
    # Calculate the correlation matrix
    corr_matrix = X.corr().abs()
    
    # Select the upper triangle of the correlation matrix
    upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
    
    # Find features with correlation greater than the threshold
    to_drop = [column for column in upper.columns if any(upper[column] > threshold)]
    
    # Drop highly correlated features
    X_reduced = X.drop(columns=to_drop)
    
    return X_reduced, to_drop

X, _ = remove_highly_correlated_features(X, 0.95)
X

Unnamed: 0,EEG FP1-REF_a5_mean,EEG FP1-REF_a5_variance,EEG FP1-REF_a5_skew,EEG FP1-REF_a5_kurtosis,EEG FP1-REF_d5_mean,EEG FP1-REF_d5_median,EEG FP1-REF_d5_variance,EEG FP1-REF_d5_skew,EEG FP1-REF_d5_kurtosis,EEG FP1-REF_d4_mean,...,EEG A2-REF_d4_kurtosis,EEG A2-REF_d3_mean,EEG A2-REF_d3_median,EEG A2-REF_d3_variance,EEG A2-REF_d3_skew,EEG A2-REF_d3_kurtosis,EEG A2-REF_delta,EEG A2-REF_beta,EEG A2-REF_gamma,asymmetry
0,1.718636,0.616050,2.946016,12.457156,0.011480,0.010945,0.020342,-0.332376,8.929693,-0.000509,...,11.229017,-0.001337,-0.000658,0.002565,-4.325576,70.424802,0.010359,0.000310,0.000025,0.713307
1,2.494826,0.765934,-0.202479,-0.331677,0.000950,0.010180,0.062455,-0.532189,1.642493,-0.000337,...,0.181988,-0.000921,-0.001643,0.002388,0.045291,2.000887,0.020027,0.000516,0.000035,0.197368
2,2.659979,0.745499,-0.219236,-0.033053,-0.004994,0.005068,0.073703,-0.380172,2.070333,-0.002392,...,0.407191,-0.000881,-0.001022,0.002072,-0.078900,0.478473,0.017030,0.000431,0.000034,-0.140446
3,2.779673,0.558120,0.107827,0.835777,0.011395,0.002752,0.058272,0.544020,1.174393,0.000788,...,0.206455,-0.001171,-0.001367,0.001346,0.075182,0.550823,0.011885,0.000296,0.000033,0.156857
4,2.796707,0.166170,0.225649,10.275338,0.001994,0.000375,0.002450,0.252442,0.407498,0.002452,...,0.065170,0.000550,0.000464,0.001833,0.140030,6.102927,0.006700,0.000309,0.000206,0.219149
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24837,2.087338,0.300920,1.018727,6.144104,0.010415,-0.008610,0.032719,1.324522,5.272375,-0.003507,...,14.242037,0.000126,-0.000318,0.000598,0.886296,15.217666,0.004541,0.000123,0.000044,0.242795
24838,2.212204,0.432998,0.339207,0.753278,-0.005381,0.002326,0.038419,0.062841,1.219399,0.005062,...,6.820851,0.001054,0.000362,0.001091,-0.124650,9.693751,0.005203,0.000225,0.000034,0.086373
24839,3.291857,0.493708,-0.027024,-0.155202,-0.004997,0.002195,0.041901,0.075979,2.132897,-0.001970,...,-0.176125,-0.001188,0.001952,0.004596,-0.064375,0.120727,0.014783,0.001024,0.000213,0.054461
24840,2.723730,0.674350,-0.052223,-0.041784,0.021821,0.004650,0.090626,0.549200,2.662391,0.008586,...,0.265363,-0.000759,0.000075,0.002983,-0.118337,0.579384,0.017173,0.000835,0.000051,0.097686


In [53]:
cv = StratifiedGroupKFold(n_splits=2)

scoring = {
    'accuracy': make_scorer(accuracy_score),
    'precision': make_scorer(precision_score, average='weighted'),
    'recall': make_scorer(recall_score, average='weighted'),
    'f1': make_scorer(f1_score, average='weighted'),
    'roc_auc': make_scorer(roc_auc_score, average='weighted', multi_class='ovr')
}

def cross_validate_pipeline(pipeline):
    results = cross_validate(pipeline, X, y, groups=groups, cv=cv, scoring=scoring, return_train_score=False)
    avg_results = {metric: np.mean(values) for metric, values in results.items()}

    return results, avg_results

In [57]:
rf_pipeline = Pipeline([
    ('scale', StandardScaler()),
    ('rf', RandomForestClassifier(n_estimators=400, random_state=42))
])

rf_results, rf_avg_results = cross_validate_pipeline(rf_pipeline)
rf_results, rf_avg_results

({'fit_time': array([144.48175621, 128.05691552]),
  'score_time': array([0.82275534, 0.85885286]),
  'test_accuracy': array([0.68840697, 0.67454575]),
  'test_precision': array([0.7031015 , 0.70404446]),
  'test_recall': array([0.68840697, 0.67454575]),
  'test_f1': array([0.68261905, 0.66241419]),
  'test_roc_auc': array([0.68831995, 0.67466776])},
 {'fit_time': 136.26933586597443,
  'score_time': 0.8408041000366211,
  'test_accuracy': 0.6814763561998243,
  'test_precision': 0.7035729774910713,
  'test_recall': 0.6814763561998243,
  'test_f1': 0.6725166181059813,
  'test_roc_auc': 0.6814938532974513})

In [55]:
svm_pipeline = Pipeline([
     ('scale', StandardScaler()),
     ('svm', SVC(kernel='rbf', C=1, gamma=0.001, probability=True, random_state=42))
])

svm_results, svm_avg_results = cross_validate_pipeline(svm_pipeline) 
svm_results, svm_avg_results

({'fit_time': array([ 94.37809467, 114.96145773]),
  'score_time': array([22.24068308, 31.04618216]),
  'test_accuracy': array([0.68671396, 0.70220293]),
  'test_precision': array([0.69800568, 0.71895961]),
  'test_recall': array([0.68671396, 0.70220293]),
  'test_f1': array([0.6821394, 0.6964389]),
  'test_roc_auc': array([0.68663667, 0.70229164])},
 {'fit_time': 104.66977620124817,
  'score_time': 26.6434326171875,
  'test_accuracy': 0.6944584448765911,
  'test_precision': 0.7084826481640003,
  'test_recall': 0.6944584448765911,
  'test_f1': 0.6892891507879642,
  'test_roc_auc': 0.6944641577553108})

In [56]:
knn_pipeline = Pipeline([
     ('scale', StandardScaler()),
     ('knn', KNeighborsClassifier(n_neighbors=10))
])

knn_results, knn_avg_results = cross_validate_pipeline(knn_pipeline)
knn_results, knn_avg_results

({'fit_time': array([0.11489701, 0.11488914]),
  'score_time': array([1.81672502, 0.73486686]),
  'test_accuracy': array([0.62294421, 0.59543335]),
  'test_precision': array([0.66193572, 0.6088158 ]),
  'test_recall': array([0.62294421, 0.59543335]),
  'test_f1': array([0.59865895, 0.58273401]),
  'test_roc_auc': array([0.62278562, 0.5955456 ])},
 {'fit_time': 0.11489307880401611,
  'score_time': 1.2757959365844727,
  'test_accuracy': 0.609188780478876,
  'test_precision': 0.6353757568523786,
  'test_recall': 0.609188780478876,
  'test_f1': 0.5906964801471435,
  'test_roc_auc': 0.6091656137170739})