[<font color = 'coral'>References Code</font>](https://github.com/vitaldb/examples/blob/master/eeg_mac.ipynb)

# Prediction of anesthetic concentration from EEG
In this example, we will build a model to predict anesthetic concentration (age-related MAC) from EEG during Sevoflurane anesthesia.

> Note that <b>all users who use Vital DB, an open biosignal dataset, must agree to the Data Use Agreement below.
</b> If you do not agree, please close this window.
Click here: [Data Use Agreement](https://vitaldb.net/dataset/?query=overview&documentId=13qqajnNZzkN7NZ9aXnaQ-47NWy7kx-a6gbrcEsi-gak&sectionId=h.vcpgs1yemdb5)

## Required libraries

In [9]:
!pip install vitaldb
import vitaldb
import random
import numpy as np
import pandas as pd
import scipy.signal
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm



## Preprocessing

In [10]:
SRATE = 128  # in hz
SEGLEN = 4 * SRATE  # segment samples
MAX_CASES = 100

df_trks = pd.read_csv("https://api.vitaldb.net/trks")  # track information
df_cases = pd.read_csv("https://api.vitaldb.net/cases")  # patient information

# track names and column order when loading data
track_names = ['BIS/EEG1_WAV', 'Primus/EXP_SEVO', 'BIS/BIS']
EEG = 0
SEVO = 1
BIS = 2

# Inclusion & Exclusion criteria
caseids = set(df_cases.loc[df_cases['age'] > 18, 'caseid'])
caseids &= set(df_trks.loc[df_trks['tname'] == 'BIS/EEG1_WAV', 'caseid'])
caseids &= set(df_trks.loc[df_trks['tname'] == 'BIS/BIS', 'caseid'])
caseids &= set(df_trks.loc[df_trks['tname'] == 'Primus/EXP_SEVO', 'caseid'])
caseids -= set(df_trks.loc[df_trks['tname'] == 'Primus/EXP_DES', 'caseid'])
caseids -= set(df_trks.loc[df_trks['tname'] == 'Orchestra/PPF20_CE', 'caseid'])
caseids -= set(df_trks.loc[df_trks['tname'] == 'Orchestra/RFTN50_CE', 'caseid'])
caseids = list(caseids)
print(f'Total {len(caseids)} cases found')
#random.shuffle(caseids)

x = []  # eeg
y = []  # sevo
b = []  # bis
c = []  # caseids
icase = 0  # number of loaded cases
for caseid in tqdm(caseids):
    print(f'loading caseid={caseid} ({icase + 1}/{MAX_CASES})', end='...', flush=True)

    # extract data
    vals = vitaldb.load_case(caseid, track_names, 1 / SRATE)
    if np.nanmax(vals[:, SEVO]) < 1:
        print('all sevo < 1%')
        continue

    # convert etsevo to the age related mac
    age = df_cases.loc[df_cases['caseid'] == caseid, 'age'].values[0]
    vals[:, SEVO] /= 1.80 * 10 ** (-0.00269 * (age - 40))

    # exclude cases without bis
    if not np.any(vals[:, BIS] > 0):
        print('all bis <= 0')
        continue

    # exclude cases with use of SEVO shorter than 30 min
    valid_idx = np.where(vals[:, SEVO] > 0.1)[0]
    first_idx = valid_idx[0]
    last_idx = valid_idx[-1]
    if last_idx - first_idx < 1800 * SRATE:
        print('len sevo < 30 min')
        continue
    vals[:first_idx, SEVO] = 0

    # trim cases with BIS
    valid_idx = np.where(vals[:, BIS] > 0)[0]
    first_idx = valid_idx[0]
    last_idx = valid_idx[-1]
    vals = vals[first_idx:last_idx + 1, :]

    # exclude cases shorter than 30 min
    if len(vals) < 1800 * SRATE:
        print('len bis < 30 min')
        continue

    # forward fill in MAC value and BIS value up to 7 sec
    vals[:, SEVO:] = pd.DataFrame(vals[:, SEVO:]).ffill(limit=7*SRATE).values

    # extract segments
    oldlen = len(y)
    for isamp in range(SEGLEN, len(vals), 2 * SRATE):
        bis = vals[isamp, BIS]
        mac = vals[isamp, SEVO]
        if np.isnan(bis) or np.isnan(mac) or bis == 0:
            continue
        # add to dataset
        eeg = vals[isamp - SEGLEN:isamp, EEG]
        x.append(eeg)
        y.append(mac)
        b.append(bis)
        c.append(caseid)

    # print results
    icase += 1
    print(f'{len(y) - oldlen} segments read -> total {len(y)} segments ({icase}/{MAX_CASES})')
    if icase >= MAX_CASES:
        break

# Change the input dataset to a numpy array
x = np.array(x)
y = np.array(y)
b = np.array(b)
c = np.array(c)

Total 1518 cases found


  0%|          | 0/1518 [00:00<?, ?it/s]

loading caseid=2 (1/100)...7277 segments read -> total 7277 segments (1/100)
loading caseid=4 (2/100)...9860 segments read -> total 17137 segments (2/100)
loading caseid=10 (3/100)...10049 segments read -> total 27186 segments (3/100)
loading caseid=12 (4/100)...14683 segments read -> total 41869 segments (4/100)
loading caseid=18 (5/100)...all bis <= 0
loading caseid=21 (5/100)...5637 segments read -> total 47506 segments (5/100)
loading caseid=24 (6/100)...2496 segments read -> total 50002 segments (6/100)
loading caseid=25 (7/100)...6741 segments read -> total 56743 segments (7/100)
loading caseid=27 (8/100)...8015 segments read -> total 64758 segments (8/100)
loading caseid=33 (9/100)...1927 segments read -> total 66685 segments (9/100)
loading caseid=43 (10/100)...6864 segments read -> total 73549 segments (10/100)
loading caseid=49 (11/100)...len sevo < 30 min
loading caseid=56 (11/100)...13725 segments read -> total 87274 segments (11/100)
loading caseid=58 (12/100)...6840 segme

## Filtering input data

In [11]:
# exclude segments
print('invalid samples...', end='', flush=True)
valid_mask = ~(np.max(np.isnan(x), axis=1) > 0) # if there is nan
valid_mask &= (np.nanmax(x, axis=1) - np.nanmin(x, axis=1) > 12)  # bis impedence check
valid_mask &= (np.nanmax(np.abs(x), axis=1) < 100)  # noisy sample
x = x[valid_mask]
y = y[valid_mask]
b = b[valid_mask]
c = c[valid_mask]
print(f'{100*(1-np.mean(valid_mask)):.1f}% removed')

invalid samples...

  valid_mask &= (np.nanmax(x, axis=1) - np.nanmin(x, axis=1) > 12)  # bis impedence check
  valid_mask &= (np.nanmax(np.abs(x), axis=1) < 100)  # noisy sample


12.7% removed


## Splitting samples into training and testing dataset

In [12]:
# caseid
caseids = list(np.unique(c))
#random.shuffle(caseids)

# Split dataset into training and testing data
ntest = max(1, int(len(caseids) * 0.2))
caseids_train = caseids[ntest:]
caseids_test = caseids[:ntest]

train_mask = np.isin(c, caseids_train)
test_mask = np.isin(c, caseids_test)
x_train = x[train_mask]
y_train = y[train_mask]
x_test = x[test_mask]
y_test = y[test_mask]
b_test = b[test_mask]
c_test = c[test_mask]

print('====================================================')
print(f'total: {len(caseids)} cases {len(y)} samples')
print(f'train: {len(np.unique(c[train_mask]))} cases {len(y_train)} samples')
print(f'test {len(np.unique(c_test))} cases {len(y_test)} samples')
print('====================================================')

total: 100 cases 496346 samples
train: 80 cases 384419 samples
test 20 cases 111927 samples


In [13]:
x_train.shape, y_train.shape

((384419, 512), (384419,))

## Modeling and Evaluation

## LSTM

In [14]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from kerastuner.tuners import RandomSearch
from kerastuner.engine.hyperparameters import HyperParameters

# Define the model building function
def build_model(hp):
    model = Sequential()
    model.add(LSTM(units=hp.Int('units', min_value=16, max_value=64, step=16),
                   input_shape=(x_train.shape[1], 1)))
    model.add(Dropout(rate=0.2))  # Adjust the dropout rate as needed
    model.add(Dense(1, activation='sigmoid', kernel_regularizer='l2'))  # Adjust regularization strength as needed

    model.compile(
        optimizer=Adam(learning_rate=hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])),
        loss='mean_absolute_error',
        metrics=['accuracy']
    )

    return model

# Instantiate the tuner
tuner = RandomSearch(
    build_model,
    objective='val_accuracy',
    max_trials=5,
    directory='hyperparameter_tuning',
    project_name='lstm_tuning'
)

# Perform the hyperparameter search
tuner.search(x_train, y_train, epochs=50, batch_size=32, validation_split=0.15, callbacks=[EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)])

# Get the best hyperparameters
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
print(f"Best Hyperparameters: {best_hps}")

# Build and compile the model with the best hyperparameters
best_model = tuner.hypermodel.build(best_hps)
best_model.compile(
    optimizer=Adam(learning_rate=best_hps.get('learning_rate')),
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# Train the model with the best hyperparameters
best_model.fit(x_train, y_train, epochs=50, batch_size=16, validation_split=0.15)

Trial 3 Complete [00h 16m 42s]
val_accuracy: 0.0

Best val_accuracy So Far: 0.0
Total elapsed time: 00h 53m 27s

Search: Running Trial #4

Value             |Best Value So Far |Hyperparameter
32                |16                |units
0.01              |0.01              |learning_rate

Epoch 1/50
Epoch 2/50
Epoch 3/50

KeyboardInterrupt: 

In [15]:
import tensorflow as tf
from keras.layers import Conv2D, Flatten, Dense
# Define the LSTM model
model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(64, input_shape=(512, 1)))
#model.add(Flatten())
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

# Compile the model
model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['accuracy'])

# Train the model
model.fit(x_train, y_train, epochs=10, batch_size=32)
#model.fit(x_train, y_train, epochs=10, batch_size=32, validation_split=0.2)

Epoch 1/10
 1173/12014 [=>............................] - ETA: 3:37 - loss: 0.3040 - accuracy: 3.7298e-04

KeyboardInterrupt: 

In [None]:
# Evaluate the model
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)

# Print the loss and accuracy
print(f"Loss: {loss:.2f}")
print(f"Accuracy: {accuracy:.2f}")


# Default

In [None]:
import keras.models
import tensorflow as tf
from keras.models import Model
from keras.layers import Dense, Dropout, Conv1D, MaxPooling1D, GlobalMaxPooling1D, Input, Activation
from keras.callbacks import EarlyStopping, ModelCheckpoint




out = inp = Input(shape=(x_train.shape[1], 1))
for i in range(4):
    out = Conv1D(filters=32, kernel_size=7, padding='same')(out)
    out = Activation('relu')(out)
    out = MaxPooling1D(2, padding='same')(out)
out = GlobalMaxPooling1D()(out)
out = Dense(128)(out)
out = Dropout(0.2)(out)
out = Dense(1)(out)

model = Model(inputs=[inp], outputs=[out])
model.summary()
model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['mean_absolute_error'])
# fit model. the last 20% of the segments will be used for early stopping
hist = model.fit(x_train, y_train, validation_split=0.2, epochs=100, batch_size=4096,
                callbacks=[ModelCheckpoint(monitor='val_loss', filepath='model.x', verbose=1, save_best_only=True),
                            EarlyStopping(monitor='val_loss', patience=2, verbose=1, mode='auto')])

# prediction
pred_test = model.predict(x_test).flatten()
for caseid in np.unique(c_test):
    case_mask = (c_test == caseid)
    pred_test[case_mask] = scipy.signal.medfilt(pred_test[case_mask], 15)

# calculate the performance
test_mae = np.mean(np.abs(y_test - pred_test))

In [None]:
# pearson correlation coefficient
bis_corr = np.corrcoef(y_test, b_test)[0, 1]
our_corr = np.corrcoef(y_test, pred_test)[0, 1]

# scatter plot
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.scatter(y_test, b_test, s=1, alpha=0.5, c='violet', label=f'BIS ({bis_corr:4f})')
plt.xlim([0, 2])
plt.legend(loc="upper right")
plt.subplot(1, 2, 2)
plt.scatter(y_test, pred_test, s=1, alpha=0.5, c='orange', label=f'Ours ({our_corr:.4f})')
plt.xlim([0, 2])
plt.legend(loc="upper left")
plt.tight_layout()
plt.show()

In [None]:
# plot for each case
for caseid in np.unique(c_test):
    case_mask = (c_test == caseid)
    case_len = np.sum(case_mask)
    if case_len == 0:
        continue

    our_mae = np.mean(np.abs(y_test[case_mask] - pred_test[case_mask]))
    print(f'Total MAE={test_mae:.4f}, CaseID {caseid}, MAE={our_mae:.4f}')

    t = np.arange(0, case_len)
    plt.figure(figsize=(20, 5))
    plt.plot(t, y_test[case_mask], label='MAC')
    plt.plot(t, pred_test[case_mask], label=f'Ours ({our_mae:.4f})')
    plt.legend(loc="upper left")
    plt.tight_layout()
    plt.xlim([0, case_len])
    plt.ylim([0, 2])
    plt.show()