# Epileptic Seizure Classification with Random Forest
This notebook classifies time-series EEG data for the detection of epileptic seizures with a Random Forest classifier, based on the preprocessed CHB-MIT Scalp EEG Database. <br>
The code is structured as follows:
1. [Imports](#1-imports)
2. [Load Dataset](#2-load-dataset)
3. [Split Dataset](#3-split-dataset)
4. [Define Space & Optimization Function](#4-define-space--optimization-function)
5. [Optimize Classifier](#5-optimize-classifier)
6. [Validate Results](#6-validate-results)
7. [Conclusion](#7-conclusion)

## 1. Imports
Import the required libraries. <br>
External packages can be installed from the `requirements.txt` file one directory above this notebook via `pip install -r ../requirements.txt`.

In [None]:
! pip install -r ../requirements.txt

In [1]:
# Import built-in libraries
import time

# Import datascience libraries
import numpy as np

# Import classifier, model-selection utilities & metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_validate, StratifiedKFold
from sklearn.metrics import f1_score, roc_auc_score, precision_score, recall_score, make_scorer, classification_report
from imblearn.metrics import geometric_mean_score

# Import model persistence library
import joblib

# Import visualization libraries
import plotly.graph_objects as go
from prettytable import PrettyTable

# Import optimization library
from hyperopt import fmin, hp, tpe, STATUS_OK, Trials, STATUS_STRINGS
from hyperopt.pyll import scope

## 2. Load Dataset
The preprocessed dataset created with the notebook `00_Preprocessing.ipynb` is loaded, and the NumPy arrays for the features and labels are extracted. <br>
To verify the class distribution of the dataset, the shapes of the arrays and the number of samples per class are printed.

In [12]:
dataset = np.load('../00_Data/Processed-Data/classification_dataset_mean.npz')
X = dataset["features"]
y = dataset["labels"]
channels = ['F8-T8', 'T7-FT9', 'F4-C4', 'C3-P3', 'P7-T7', 'P7-O1', 'T8-P8', 'FP1-F7', 'P8-O2', 'T7-P7', 'C4-P4', 'FT10-T8', 'P4-O2', 'F7-T7', 'CZ-PZ', 'FP2-F8', 'P3-O1', 'FP1-F3','FP2-F4', 'FZ-CZ', 'F3-C3', 'FT9-FT10', 'age', 'gender']

In [3]:
print("Shapes: \n X:", X.shape, "y:", y.shape)
print("Unique Values:", np.unique(y, return_counts=True))

Shapes: 
 X: (16999, 1000, 24) y: (16999, 1)
Unique Values: (array([0, 1], dtype=int8), array([10403,  6596]))
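The printed counts (10403 samples of class 0 vs. 6596 samples of class 1) can be turned into class proportions and an imbalance ratio. A minimal sketch, assuming a stand-in label array `y_demo` rebuilt from those counts (it is not loaded from the dataset file):

```python
import numpy as np

# Hypothetical stand-in for y, rebuilt from the printed class counts
y_demo = np.concatenate([np.zeros(10403, dtype=np.int8),
                         np.ones(6596, dtype=np.int8)]).reshape(-1, 1)

# Count the samples per class and convert the counts to proportions
classes, counts = np.unique(y_demo, return_counts=True)
proportions = counts / counts.sum()
for cls, count, share in zip(classes, counts, proportions):
    print(f"class {cls}: {count} samples ({share:.1%})")

# Ratio between the majority and the minority class
imbalance_ratio = counts.max() / counts.min()
print(f"imbalance ratio: {imbalance_ratio:.2f}")
```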


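To classify the time series with the Random Forest classifier, the three-dimensional array (samples × timesteps × features) must be reshaped into two dimensions by flattening the timestep and feature axes into a single feature vector per sample.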

In [4]:
n_samples, n_timesteps, n_features = X.shape
X_reshaped = np.reshape(X, (n_samples, (n_timesteps * n_features)))
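With the shapes above, each sample becomes a vector of 1000 × 24 = 24000 values. The following sketch uses a small hypothetical array `X_demo` with the same (samples, timesteps, features) layout to show where each value ends up under NumPy's default row-major (C) order:

```python
import numpy as np

# Small hypothetical array with the same layout as X: (samples, timesteps, features)
n_samples, n_timesteps, n_features = 2, 3, 4
X_demo = np.arange(n_samples * n_timesteps * n_features).reshape(
    n_samples, n_timesteps, n_features)

# Same operation as in the notebook: one row per sample
X_flat = np.reshape(X_demo, (n_samples, n_timesteps * n_features))

# In C order, feature f at timestep t lands in column t * n_features + f,
# so all channels of one timestep are adjacent in the flattened row
t, f = 2, 1
print(X_flat.shape, X_flat[0, t * n_features + f] == X_demo[0, t, f])
```

The forest treats every one of these columns as a separate feature, so the temporal ordering is only encoded implicitly through the column positions.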

## 3. Split Dataset
To validate the trained classifier, the dataset is split into a `train` and a `validation` subset. Because the hyperparameters are tuned with cross-validation on the `train` subset only, a separate `test` subset is not needed: the `validation` subset is used exclusively for the final evaluation. <br>
To preserve the class proportions within each split, the `stratify` option is set to the labels.

In [5]:
X_train, X_val, y_train, y_val = train_test_split(X_reshaped, y, test_size=0.3, shuffle=True, stratify=np.ravel(y), random_state=34)
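To check what `stratify` does, the sketch below repeats the split on hypothetical placeholder data (`X_demo`, `y_demo`) with the same class counts and the same split parameters, and compares the share of class 1 in each subset:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical labels with the class counts printed above; features are placeholders
y_demo = np.concatenate([np.zeros(10403, dtype=np.int8),
                         np.ones(6596, dtype=np.int8)]).reshape(-1, 1)
X_demo = np.zeros((len(y_demo), 1))

X_train_demo, X_val_demo, y_train_demo, y_val_demo = train_test_split(
    X_demo, y_demo, test_size=0.3, shuffle=True,
    stratify=np.ravel(y_demo), random_state=34)

# The share of class 1 is nearly identical in all subsets
for name, labels in [("full", y_demo), ("train", y_train_demo), ("val", y_val_demo)]:
    print(f"{name}: n={len(labels)}, class-1 share={labels.mean():.4f}")
```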

## 4. Define Space & Optimization Function
To get the best possible predictions, the hyperparameters of the classifier are optimized with the Bayesian optimization library `hyperopt`. <br>
First, the search space for each hyperparameter is defined and stored in a dictionary. <br>
The `objective()` function contains the definition, training and evaluation of the classifier, which is evaluated with a stratified five-fold cross-validation on the training subset. <br>
Finally, the averaged cross-validation metrics are returned; the negative macro F1-score serves as the loss that `hyperopt` minimizes.

In [14]:
max_features_values = ['sqrt','log2', None]

space={
    'n_estimators': scope.int(hp.quniform('n_estimators', 100, 600, 10)),
    'max_depth': hp.quniform('max_depth', 100, 400, 10),
    'min_samples_split' : hp.uniform ('min_samples_split', 0, 0.15),
    'min_samples_leaf': hp.uniform('min_samples_leaf', 0, 0.25),
    'max_features': hp.choice('max_features', max_features_values),
}

gm_scorer = make_scorer(geometric_mean_score, greater_is_better=True, average='macro') #Create Scorer for G-Mean
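`hp.quniform` is documented to return a value like `round(uniform(low, high) / q) * q`, which is a float. This is why `max_depth`, unlike `n_estimators` (wrapped in `scope.int`), arrives in `objective()` as a float and is converted with `int()` before it is passed to `RandomForestClassifier`. The NumPy sketch below mimics that formula; the helper `quniform_like` is hypothetical and not part of `hyperopt`:

```python
import numpy as np

rng = np.random.default_rng(0)

def quniform_like(low, high, q, rng):
    """Hypothetical helper mimicking the documented hp.quniform formula."""
    return float(np.round(rng.uniform(low, high) / q) * q)

# Floats that are multiples of 10, as for the max_depth search space
samples = [quniform_like(100, 400, 10, rng) for _ in range(5)]
print(samples)

# RandomForestClassifier expects an integer max_depth, hence int()
max_depths = [int(s) for s in samples]
print(max_depths)
```

Float values for `min_samples_split` and `min_samples_leaf`, on the other hand, are accepted directly: scikit-learn interprets them as fractions of the training samples.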

In [15]:
def objective(space):
    # X_train, y_train and gm_scorer are read from the notebook's global scope

    # Create classifier
    rf_classifier = RandomForestClassifier(
        n_estimators = int(space["n_estimators"]),
        max_depth = int(space["max_depth"]),
        min_samples_split = space["min_samples_split"],
        min_samples_leaf = space["min_samples_leaf"],
        max_features=space["max_features"],
        random_state=456,
        n_jobs=-1,
        verbose=0
    )

    # Stratified 5-fold cross-validation on the training subset
    splits = StratifiedKFold(n_splits=5, shuffle=True)
    cross_val = cross_validate(
        rf_classifier,
        X_train,
        np.ravel(y_train),
        cv=splits,
        scoring={
            'f1_macro': 'f1_macro',
            'f1_weighted': 'f1_weighted',
            'auc': 'roc_auc_ovr',
            'gmean': gm_scorer,
            'precision': 'precision_macro',
            'recall': 'recall_macro',
            'waccuracy': 'balanced_accuracy'
        }
    )
    try:
        # Average each metric over the folds, ignoring folds that returned NaN
        cv_f1_macro = np.nanmean(cross_val['test_f1_macro'])
        cv_f1_weighted = np.nanmean(cross_val['test_f1_weighted'])
        cv_auc = np.nanmean(cross_val['test_auc'])
        cv_gmean = np.nanmean(cross_val['test_gmean'])
        cv_precision = np.nanmean(cross_val['test_precision'])
        cv_recall = np.nanmean(cross_val['test_recall'])
        cv_acc_weighted = np.nanmean(cross_val['test_waccuracy'])
    except Exception as e:
        print(e)
        return {
            'loss': 1, 
            'status': STATUS_STRINGS[4],  # 'fail'
            'metrics': {
                'cv_f1_macro': -1,
                'cv_f1_weighted': -1,
                'cv_auc': -1,
                'cv_gmean': -1,
                'cv_precision': -1,
                'cv_recall': -1,
                'cv_acc_weighted': -1
            },
            'eval_time': time.time()
        }

    return {
        'loss': -cv_f1_macro, 
        'status': STATUS_OK, 
        'metrics': {
            'cv_f1_macro': cv_f1_macro,
            'cv_f1_weighted': cv_f1_weighted,
            'cv_auc': cv_auc,
            'cv_gmean': cv_gmean,
            'cv_precision': cv_precision,
            'cv_recall': cv_recall,
            'cv_acc_weighted': cv_acc_weighted
        },
        'eval_time': time.time()
    }
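`hyperopt` minimizes the returned `loss`, so returning the negative macro F1-score turns the search into a maximization of F1; this is also why the progress bar of the optimization reports negative values such as `best loss: -0.836...`. The sketch below uses hypothetical trial results (not the recorded ones) to show how the best trial is selected and how its loss maps back to an F1-score. It relies on `STATUS_OK` being the string `'ok'` and on failed trials being excluded from the selection:

```python
# Hypothetical trial results in the format returned by objective()
results = [
    {"loss": -0.38, "status": "ok"},
    {"loss": -0.84, "status": "ok"},
    {"loss": 1, "status": "fail"},   # STATUS_STRINGS[4] == 'fail'
    {"loss": -0.79, "status": "ok"},
]

# Only successful trials count; the smallest loss is the highest F1-score
successful = [r for r in results if r["status"] == "ok"]
best = min(successful, key=lambda r: r["loss"])
best_f1_macro = -best["loss"]
print(best_f1_macro)
```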

## 5. Optimize Classifier

In [16]:
trials = Trials()

# Long-running: the recorded run of 5 trials below took about 1.5 hours. Uncomment to rerun the optimization.

# best_param = fmin(
#     fn=objective,
#     space=space,
#     algo=tpe.suggest,
#     max_evals=5,
#     trials=trials
# )

  0%|          | 0/5 [00:00<?, ?trial/s, best loss=?]

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

 20%|██        | 1/5 [00:45<03:02, 45.62s/trial, best loss: -0.3796465179109455]

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

 40%|████      | 2/5 [02:27<03:55, 78.42s/trial, best loss: -0.8364854062803957]

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

 60%|██████    | 3/5 [03:06<02:01, 60.54s/trial, best loss: -0.8364854062803957]

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

 80%|████████  | 4/5 [03:38<00:49, 49.38s/trial, best loss: -0.8364854062803957]

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

100%|██████████| 5/5 [1:33:56<00:00, 1127.25s/trial, best loss: -0.8364854062803957]
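The repeated `Precision is ill-defined` warnings appear when, in a cross-validation fold, the classifier predicts no sample of one class, so that class's precision is 0/0. This is plausible here because scikit-learn interprets float values of `min_samples_leaf` as a fraction of the training samples, and the search space allows values up to 0.25, which forces very shallow trees. The sketch below reproduces the situation with hypothetical labels; passing `zero_division=0` makes the fallback value explicit and suppresses the warning. In `objective()`, the string scorer `'precision_macro'` could be replaced by `make_scorer(precision_score, average='macro', zero_division=0)` for the same effect.

```python
import numpy as np
from sklearn.metrics import precision_score

# Hypothetical fold in which a very shallow forest predicts only class 0
y_true_demo = np.array([0, 0, 0, 1, 1])
y_pred_demo = np.zeros_like(y_true_demo)

# Class 1 is never predicted, so its precision is 0/0; zero_division=0
# sets it to 0.0 explicitly instead of emitting a warning
per_class = precision_score(y_true_demo, y_pred_demo, average=None, zero_division=0)
macro = precision_score(y_true_demo, y_pred_demo, average="macro", zero_division=0)
print(per_class, macro)
```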


In [17]:
# print(best_param)

{'max_depth': 260.0, 'max_features': 0, 'min_samples_leaf': 0.0450100263227729, 'min_samples_split': 0.04862715588473516, 'n_estimators': 470.0}


## 6. Validate Results
To check that the model generalizes and has not overfitted the training data, it is evaluated on the `val` subset as a last step. This subset was seen by the random forest neither during training nor during hyperparameter optimization, so the results obtained on it serve as an estimate of the model's predictive performance on unseen data. Since the classes are imbalanced (10403 samples of class 0 vs. 6596 samples of class 1), accuracy is not used as the deciding metric.

The macro F1-score, the G-mean, the area under the ROC curve (AUC), as well as precision and recall are calculated in the following section.
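For two classes, the G-mean is the geometric mean of the recalls of both classes, i.e. the square root of sensitivity times specificity, so it drops sharply if either class is poorly recognized. A minimal from-scratch sketch, assuming binary 0/1 labels; the helper `binary_g_mean` is hypothetical. Since the sensitivity of one class is the specificity of the other, `imblearn`'s `geometric_mean_score(..., average="macro")` should yield the same value for binary labels:

```python
import numpy as np

def binary_g_mean(y_true, y_pred):
    """Hypothetical helper: G-mean = sqrt(sensitivity * specificity)."""
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    sensitivity = tp / (tp + fn)  # recall of class 1
    specificity = tn / (tn + fp)  # recall of class 0
    return float(np.sqrt(sensitivity * specificity))

# Hypothetical labels: recall 0.5 for class 1 and 0.75 for class 0
y_true_demo = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred_demo = [0, 0, 0, 1, 1, 1, 0, 0]
print(binary_g_mean(y_true_demo, y_pred_demo))
```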

In [18]:
# rf_classifier = RandomForestClassifier(
#     n_estimators = int(best_param['n_estimators']),
#     max_depth = int(best_param["max_depth"]),
#     min_samples_split = best_param["min_samples_split"],
#     min_samples_leaf = best_param["min_samples_leaf"],
#     max_features = max_features_values[best_param["max_features"]],
#     random_state = 456,
#     n_jobs = -1,
#     verbose = 2
# )

In [19]:
# rf_classifier.fit(
#     X=X_train,
#     y=np.ravel(y_train)
# )

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 10 concurrent workers.
[Parallel(n_jobs=-1)]: Done  21 tasks      | elapsed:    1.3s
[Parallel(n_jobs=-1)]: Done 142 tasks      | elapsed:    6.5s
[Parallel(n_jobs=-1)]: Done 345 tasks      | elapsed:   15.6s
[Parallel(n_jobs=-1)]: Done 470 out of 470 | elapsed:   21.1s finished


In [7]:
# joblib.dump(rf_classifier, "../99_Assets/01_Saved Models/00_random_forest_classifier.joblib")
rf_classifier = joblib.load("../99_Assets/01_Saved Models/00_random_forest_classifier.joblib")
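On case-sensitive file systems (as commonly used on Linux), paths such as `99_Assets` and `99_assets` point to different directories. A minimal sketch of a save/load round trip that defines the model path once and reuses it for both `joblib.dump` and `joblib.load`; the temporary directory and the small demo model are hypothetical stand-ins for the saved-models folder and the trained forest:

```python
import tempfile
from pathlib import Path

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Hypothetical directory standing in for the saved-models folder
model_dir = Path(tempfile.mkdtemp())
model_path = model_dir / "00_random_forest_classifier.joblib"

# Small hypothetical model instead of the tuned forest
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(60, 5))
y_demo = (X_demo[:, 0] > 0).astype(int)
clf_demo = RandomForestClassifier(n_estimators=10, random_state=456).fit(X_demo, y_demo)

# The same Path object is used for saving and loading, so they cannot diverge
joblib.dump(clf_demo, model_path)
clf_loaded = joblib.load(model_path)
print(np.array_equal(clf_loaded.predict(X_demo), clf_demo.predict(X_demo)))
```

A pickled scikit-learn model should also be loaded with the same scikit-learn version that was used to save it.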

In [8]:
y_val_pred = rf_classifier.predict(X_val)  # Predict class labels for the validation subset
y_val_pred_proba = rf_classifier.predict_proba(X_val)  # Predict class probabilities for the validation subset
f1 = f1_score(np.ravel(y_val), y_val_pred, average="macro")  # Macro-averaged F1-score
auc = roc_auc_score(np.ravel(y_val), y_val_pred_proba[:, 1])  # ROC AUC from the class-1 probabilities
gmean = geometric_mean_score(np.ravel(y_val), y_val_pred, average="macro")  # Macro-averaged G-mean
precision = precision_score(np.ravel(y_val), y_val_pred)  # Precision of the positive class (label 1)
recall = recall_score(np.ravel(y_val), y_val_pred)  # Recall of the positive class (label 1)
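Note that these metrics use different averaging: `f1_score` and `geometric_mean_score` are called with `average="macro"`, whereas `precision_score` and `recall_score` use their default `average="binary"` and therefore only describe the positive class (label 1). The sketch below illustrates the difference on hypothetical labels:

```python
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

# Hypothetical labels, not taken from the validation subset
y_true_demo = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_pred_demo = np.array([0, 0, 0, 1, 1, 1, 0, 0])

# Default average="binary": only the positive class (label 1) is scored
precision_pos = precision_score(y_true_demo, y_pred_demo)
recall_pos = recall_score(y_true_demo, y_pred_demo)

# average="macro": unweighted mean of the per-class F1-scores
f1_macro = f1_score(y_true_demo, y_pred_demo, average="macro")
print(precision_pos, recall_pos, f1_macro)
```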

[Parallel(n_jobs=10)]: Using backend ThreadingBackend with 10 concurrent workers.
[Parallel(n_jobs=10)]: Done  21 tasks      | elapsed:    0.0s
[Parallel(n_jobs=10)]: Done 142 tasks      | elapsed:    0.0s
[Parallel(n_jobs=10)]: Done 345 tasks      | elapsed:    0.1s
[Parallel(n_jobs=10)]: Done 470 out of 470 | elapsed:    0.1s finished
[Parallel(n_jobs=10)]: Using backend ThreadingBackend with 10 concurrent workers.
[Parallel(n_jobs=10)]: Done  21 tasks      | elapsed:    0.0s
[Parallel(n_jobs=10)]: Done 142 tasks      | elapsed:    0.0s
[Parallel(n_jobs=10)]: Done 345 tasks      | elapsed:    0.1s
[Parallel(n_jobs=10)]: Done 470 out of 470 | elapsed:    0.1s finished


In [9]:
data = [["F1-Score", "G-Mean", "AUC", "Precision", "Recall"], [f1, gmean, auc, precision, recall]]
table = PrettyTable(data[0])
table.add_rows(data[1:])
print(table)

+--------------------+--------------------+--------------------+--------------------+--------------------+
|      F1-Score      |       G-Mean       |        AUC         |     Precision      |       Recall       |
+--------------------+--------------------+--------------------+--------------------+--------------------+
| 0.8347862562765025 | 0.8275218859220146 | 0.9097902212254627 | 0.8463316002310803 | 0.7402728650833754 |
+--------------------+--------------------+--------------------+--------------------+--------------------+


## 7. Conclusion
The goal of this notebook was to demonstrate binary classification of time-series EEG data for epileptic seizure detection using a Random Forest classifier. First, the preprocessed dataset was loaded, the timestep and channel dimensions were flattened into a single feature vector per sample, and the data was split into a training and a validation subset. To build the best possible classification model, hyperparameter optimization was performed, for which the search space was defined first. The optimization used the Hyperopt library, which performs Bayesian optimization with the Tree-structured Parzen Estimator (TPE). The objective function defines, trains and evaluates the random forest classifier using stratified 5-fold cross-validation on the training subset. The sampled hyperparameters and the resulting metrics are returned to the minimization function and stored in a `Trials` object. After the optimization, the hyperparameters with the best result were extracted, and an optimized random forest classifier was created, trained on the training subset and evaluated on the previously unused validation subset.

In general, the random forest approach has proven to be usable. However, the model only achieved a macro F1-score of 0.83 and a recall of 0.74 for the positive class (label 1) on the validation subset, which limits its real-world applicability. A final comparison of all methodologies will be made in the notebook `04_Model Comparison.ipynb`.