# Feature-Based Multi-Layer Perceptron
Reference Paper: *Müller Jens et al. (2021) Coherent false seizure prediction in epilepsy, coincidence or providence?*

# Imports

In [7]:
import sys

# Put the parent directory on the module search path first, so that the
# project import below can be resolved from it. The membership check keeps
# the path free of duplicate entries when this cell is re-run.
if '..' not in sys.path:
    sys.path.append('..')

from models.FB_MLP import create_ptnt_mlp_ensemble

In [8]:
import pandas as pd
from models.load_data import load_features_and_labels

import numpy as np

from utils.io import pickle_path
import os

from feature_extraction.extract_features import Features
from config.paths import PATHS

In [9]:
# Both environment flags are set ahead of the TensorFlow import below, so
# they are already in place when the library loads.
# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Switch the oneDNN optimizations flag off; leaving it on produced warnings.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import tensorflow as tf
from tensorflow.keras.metrics import Recall, AUC
from tensorflow.keras.layers import Dense, Input, BatchNormalization
import keras
from keras import layers

# Loading Data

In [4]:
# Work with the first patient directory listed by the path configuration.
ptnt_dir = PATHS.patient_dirs()[0]
# Read that patient's pickled segments table and train/test split.
segs, split = (
    pd.read_pickle(pickle_path(table))
    for table in (ptnt_dir.segments_table, ptnt_dir.train_test_split)
)

In [19]:
x_train, y_train, x_test, y_test = load_features_and_labels(segs, split, Features.ORDERED_FEATURE_NAMES)

In [20]:
x_train

array([[2.38408076e-01, 1.30000000e+01, 1.50000000e+01, ...,
        2.16633944e+00, 1.78818256e+00, 4.43084117e-01],
       [3.49040270e-01, 5.00000000e+00, 1.00000000e+00, ...,
        9.62379027e+00, 1.23229121e+02, 1.06329102e+02],
       [2.29366248e-02, 4.00000000e+00, 5.00000000e+00, ...,
        4.06466539e+00, 4.48925093e+01, 3.29175888e+01],
       ...,
       [7.51664197e-02, 1.50000000e+01, 1.50000000e+01, ...,
        2.04105325e+00, 2.02587888e+00, 5.08150506e-01],
       [5.49219011e-01, 1.00000000e+00, 1.00000000e+00, ...,
        2.88138452e+01, 4.61651687e+02, 4.03978461e+02],
       [6.96928014e-02, 2.00000000e+00, 1.00000000e+00, ...,
        1.50485342e+01, 1.94110683e+02, 1.50016758e+02]],
      shape=(30051, 15))

In [21]:
y_train

array([0, 0, 0, ..., 0, 0, 0], shape=(30051,), dtype=int32)

In [22]:
np.unique_counts(y_train)

UniqueCountsResult(values=array([0, 1], dtype=int32), counts=array([28620,  1431]))

In [23]:
def create_mlp(n_features: int, name: str) -> tf.keras.models.Sequential:
    """Assemble the feature-based MLP (not compiled).

    The network is three ReLU ``Dense`` layers with 16, 8 and 4 units, each
    followed by a ``BatchNormalization`` layer, and a final single-unit
    sigmoid ``Dense`` layer named ``output``.

    Args:
        n_features: Length of the input feature vector.
        name: Name assigned to the Keras model.

    Returns:
        The assembled ``Sequential`` model; compiling is left to the caller.
    """
    hidden_units = (16, 8, 4)

    stack = [Input([n_features])]
    for idx, units in enumerate(hidden_units):
        # Layer names follow the pattern dense0/batch_norm0, dense1/batch_norm1, ...
        stack.append(Dense(units, activation='relu', name=f'dense{idx}'))
        stack.append(BatchNormalization(name=f'batch_norm{idx}'))
    stack.append(Dense(1, activation='sigmoid', name='output'))

    return tf.keras.models.Sequential(stack, name=name)


mlp_model = create_mlp(Features.N_FEATURES, 'FB-MLP')
mlp_model.summary()

In [24]:
# Plain SGD with binary cross-entropy. from_logits=False because the model's
# output layer already applies a sigmoid. Accuracy, recall and AUC are tracked.
mlp_model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.0001),
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                  metrics=["accuracy", Recall(name='recall'), AUC(name='AUC')])

In [25]:
# Calculate Class weights
total = len(y_train)
total

30051

In [26]:
counts = np.bincount(y_train)  # number of samples per class
counts

array([28620,  1431])

In [27]:
n_classes = len(counts)  # 2
n_classes

2

In [28]:
# Weight of a class = total / (n_classes * samples in that class),
# so the rarer class receives the larger weight.
class_weights = {label: total / (n_classes * counts[label]) for label in (0, 1)}
class_weights

{0: np.float64(0.525), 1: np.float64(10.5)}

In [29]:
# Sanity check: the per-sample weights add up to the number of samples,
# i.e. the average weight is 1.
# The weights are floating-point quotients, so compare with a tolerance
# instead of exact equality.
np.isclose(class_weights[0] * counts[0] + class_weights[1] * counts[1], len(y_train))

np.True_

In [30]:
# Alternative, longer run (left commented out): mlp_model.fit(x_train, y_train, epochs=500)
# Short 10-epoch run; class_weight applies the per-class weights computed above.
mlp_model.fit(x_train, y_train,
              epochs=10,
              batch_size=256,  # larger batch size, so that preictal examples are statistically in every batch
              class_weight=class_weights
              )

Epoch 1/10


2025-12-22 18:00:23.791337: I external/local_xla/xla/service/service.cc:163] XLA service 0x79dde4005e80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-12-22 18:00:23.791364: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA RTX A5000, Compute Capability 8.6
2025-12-22 18:00:23.855393: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-12-22 18:00:24.148077: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800


[1m 23/118[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 7ms/step - AUC: 0.4098 - accuracy: 0.3869 - loss: 1.1723 - recall: 0.5126

I0000 00:00:1766422827.081581 4136902 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 36ms/step - AUC: 0.4288 - accuracy: 0.3887 - loss: 1.1171 - recall: 0.5087
Epoch 2/10
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 6ms/step - AUC: 0.4256 - accuracy: 0.3820 - loss: 1.0930 - recall: 0.5171
Epoch 3/10
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - AUC: 0.4221 - accuracy: 0.3821 - loss: 1.0776 - recall: 0.5080
Epoch 4/10
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - AUC: 0.4217 - accuracy: 0.3830 - loss: 1.0570 - recall: 0.4934
Epoch 5/10
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - AUC: 0.4156 - accuracy: 0.3839 - loss: 1.0456 - recall: 0.4990
Epoch 6/10
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - AUC: 0.4213 - accuracy: 0.3856 - loss: 1.0032 - recall: 0.4969
Epoch 7/10
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - AUC: 0.4311 - ac

<keras.src.callbacks.history.History at 0x79e1f05c5d50>

In [31]:
mlp_model.evaluate(x_test, y_test)

[1m8803/8803[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 7ms/step - AUC: 0.3923 - accuracy: 0.9678 - loss: 0.4978 - recall: 0.0299


[0.4978410601615906,
 0.967819333076477,
 0.029881862923502922,
 0.3923249840736389]

## See Predictions

In [32]:
# Indices of the first three test samples.
samples_idx = list(range(3))

In [33]:
predictions = mlp_model.predict(x_test[samples_idx])
predictions

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 687ms/step


array([[0.2393611 ],
       [0.23282835],
       [0.18887137]], dtype=float32)

In [34]:
threshold = 0.5
# Threshold the whole prediction array in one vectorised comparison instead of
# a per-row Python loop. This also keeps the 2-D shape when there are no rows.
class_preds = (np.asarray(predictions) > threshold).astype(int)
class_preds

array([[0],
       [0],
       [0]])

In [35]:
# See which predictions are correct
class_preds == np.expand_dims(y_test[samples_idx], axis=-1)

array([[ True],
       [ True],
       [ True]])

# Ensemble of Models



In [None]:
# Build, compile and fit five separately constructed copies of the MLP on the
# same training data, collecting the fitted models in `models`.
models = []
for i in range(5):
    print(f"Creating model {i}")
    # Each member gets a unique, zero-padded name (FB-MLP_00, FB-MLP_01, ...).
    model = create_mlp(Features.N_FEATURES, f"FB-MLP_{i:02}")
    # Fresh optimizer, loss and metric objects are created for every member.
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.0001),
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                  metrics=["accuracy", Recall(name='recall'), AUC(name='AUC')])
    model.fit(x_train, y_train,
              epochs=5,
              batch_size=256,
              class_weight=class_weights,
              )
    models.append(model)

In [None]:
models

## Predict

In [47]:
# Take the first five test rows together with their labels.
samples_idx = list(range(5))
samples, labels = x_test[samples_idx], y_test[samples_idx]

In [None]:
all_preds = np.array([m.predict(samples) for m in models])

In [None]:
all_preds.shape

In [None]:
ensemble_probs = all_preds.mean(axis=0)
ensemble_probs

In [None]:
threshold = 0.6
# Threshold the averaged probabilities in one vectorised comparison instead of
# a per-row Python loop. This also keeps the 2-D shape when there are no rows.
class_preds = (np.asarray(ensemble_probs) > threshold).astype(int)
class_preds

In [None]:
# See which predictions are correct
class_preds == np.expand_dims(labels, axis=-1)

# Ensemble with native Keras

In [37]:
# Example
def get_model():
    """Return a toy Keras model: a 128-wide input fed into one ``Dense(1)`` layer."""
    feature_in = keras.Input(shape=(128,))
    return keras.Model(feature_in, layers.Dense(1)(feature_in))


# Three separate instances of the toy model.
model1 = get_model()
model2 = get_model()
model3 = get_model()

# Call all three on one shared input and average their outputs, which gives
# a single model wrapping the whole ensemble.
inputs = keras.Input(shape=(128,))
y1 = model1(inputs)
y2 = model2(inputs)
y3 = model3(inputs)
outputs = layers.average([y1, y2, y3])
ensemble_model = keras.Model(inputs=inputs, outputs=outputs)

In [57]:
# One input shared by every ensemble member.
input_layer = Input([Features.N_FEATURES], name='input')
# Symbolic outputs of the members when called on the shared input. They are
# kept under their own name so that the `models` list of trained models from
# the "Ensemble of Models" section is not overwritten with tensors.
member_outputs = []

for i in range(2):
    model = create_mlp(Features.N_FEATURES, f"FB-MLP_{i:02}")
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.0001),
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                  metrics=["accuracy", Recall(name='recall'), AUC(name='AUC')])
    model.fit(x_train, y_train,
              epochs=2,
              batch_size=256,
              class_weight=class_weights,
              )
    # Give each model the same input layer
    y = model(input_layer)
    member_outputs.append(y)

# The ensemble prediction is the element-wise mean of the member outputs.
output_layer = layers.average(member_outputs, name='average')
ensemble = keras.Model(inputs=input_layer, outputs=output_layer, name="ensemble")

Epoch 1/2
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 33ms/step - AUC: 0.4682 - accuracy: 0.6751 - loss: 0.9093 - recall: 0.2215
Epoch 2/2
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - AUC: 0.4823 - accuracy: 0.6778 - loss: 0.8938 - recall: 0.2418
Epoch 1/2
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 32ms/step - AUC: 0.5261 - accuracy: 0.5718 - loss: 0.7191 - recall: 0.4326
Epoch 2/2
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - AUC: 0.5271 - accuracy: 0.5631 - loss: 0.7184 - recall: 0.4549


In [58]:
ensemble.summary()

In [59]:
samples

array([[ 1.37164764e-01,  1.00000000e+01,  1.00000000e+00,
         8.10792294e+01,  2.41162981e+02,  5.13651929e+01,
         1.10179912e+01,  5.75790071e+00,  8.74588871e+00,
         2.85533892e+00,  3.33635548e+01,  7.25378134e+00,
         4.27025245e+00,  1.07270762e+02,  8.12468918e+01],
       [ 6.10213619e-02,  1.40000000e+01,  1.00000000e+00,
         1.27673102e+02,  3.40723322e+02,  9.43458589e+01,
         9.26661758e+00,  2.34515831e+00,  9.22781980e+00,
         4.66573347e+00,  4.22558812e+01,  7.08614948e+00,
         3.34408978e+00,  1.55402144e+02,  1.23892290e+02],
       [ 3.23471957e-02,  1.50000000e+01,  1.00000000e+00,
         9.88866754e+01,  3.37720015e+02,  6.68866487e+01,
         7.95653514e+00,  4.36694992e+00,  8.29540150e+00,
         3.84531902e+00,  5.78759488e+01,  7.35512597e+00,
         2.86993650e+00,  1.55521075e+02,  1.22925448e+02],
       [ 8.48960710e-02,  1.40000000e+01,  1.00000000e+01,
         9.75289220e+01,  2.39837399e+02,  7.13031796

In [60]:
ensemble.predict(samples)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 966ms/step


array([[0.4874581 ],
       [0.52553487],
       [0.49454844],
       [0.46397263],
       [0.4494729 ]], dtype=float32)

# Testing

In [10]:
ptnt_dir = PATHS.patient_dirs()[0]
ptnt_dir

PatientDir('/data/home/webb/UNEEG_data/20240201_UNEEG_ForMayo/K37N36L4D')

In [11]:
ensemble = create_ptnt_mlp_ensemble(ptnt_dir)

I0000 00:00:1766424626.308344 4152948 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22377 MB memory:  -> device: 0, name: NVIDIA RTX A5000, pci bus id: 0000:31:00.0, compute capability: 8.6


Epoch 1/2


2025-12-22 18:30:28.149447: I external/local_xla/xla/service/service.cc:163] XLA service 0x789d3c003e50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-12-22 18:30:28.149486: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA RTX A5000, Compute Capability 8.6
2025-12-22 18:30:28.245665: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-12-22 18:30:28.520381: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800


[1m 23/118[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 7ms/step - AUC: 0.4768 - accuracy: 0.7445 - loss: 0.7512 - recall: 0.0931

I0000 00:00:1766424630.749095 4153385 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 36ms/step - AUC: 0.4706 - accuracy: 0.7010 - loss: 0.7464 - recall: 0.1670
Epoch 2/2
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - AUC: 0.4764 - accuracy: 0.6821 - loss: 0.7441 - recall: 0.2124
Epoch 1/2
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 30ms/step - AUC: 0.5226 - accuracy: 0.4311 - loss: 0.8277 - recall: 0.6087
Epoch 2/2
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - AUC: 0.5333 - accuracy: 0.4287 - loss: 0.8131 - recall: 0.6122


In [12]:
ensemble.summary()