<a href="https://colab.research.google.com/github/vitaldb/examples/blob/master/eeg_mac.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

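# Prediction of the bispectral index (BIS) from ECG
In this example, we build a model that predicts the bispectral index (BIS), a processed-EEG index of anesthetic depth, from two ECG leads (II and V5). It is adapted from the VitalDB `eeg_mac` example, which predicts anesthetic concentration (age-related MAC) from EEG during sevoflurane anesthesia. The cohort below includes adult cases and excludes those with desflurane, propofol, or remifentanil tracks.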

> Note that **all users of VitalDB, an open biosignal dataset, must agree to the Data Use Agreement below.** If you do not agree, please close this window.
>
> Click here: [Data Use Agreement](https://vitaldb.net/dataset/?query=overview&documentId=13qqajnNZzkN7NZ9aXnaQ-47NWy7kx-a6gbrcEsi-gak&sectionId=h.vcpgs1yemdb5)

## Required libraries

In [None]:
!pip install vitaldb
import vitaldb
import random
import numpy as np
import pandas as pd
import scipy.signal
import matplotlib.pyplot as plt

Defaulting to user installation because normal site-packages is not writeable


## Preprocessing

In [None]:
# Sampling / segmentation settings
SRATE = 128             # sampling rate requested from vitaldb.load_case, in Hz
SEGLEN = SRATE * 4      # samples per 4-second input segment
MAX_CASES = 100         # stop loading after this many accepted cases

# VitalDB metadata tables: track list and per-case clinical information
df_trks = pd.read_csv("https://api.vitaldb.net/trks")
df_cases = pd.read_csv("https://api.vitaldb.net/cases")

# Tracks to load; the column indices below follow this order
track_names = ['SNUADC/ECG_II', 'SNUADC/ECG_V5', 'BIS/BIS']
ECG_II, ECG_V5, BIS = 0, 1, 2


def _cases_with_track(tname):
    """Return the set of caseids that have a track named ``tname``."""
    return set(df_trks.loc[df_trks['tname'] == tname, 'caseid'])


# Inclusion: age > 18 and all three required tracks present
caseids = set(df_cases.loc[df_cases['age'] > 18, 'caseid'])
for tname in ('SNUADC/ECG_II', 'BIS/BIS', 'SNUADC/ECG_V5'):
    caseids &= _cases_with_track(tname)

# Exclusion: cases with any of these tracks (presumably desflurane,
# propofol and remifentanil administration)
for tname in ('Primus/EXP_DES', 'Orchestra/PPF20_CE', 'Orchestra/RFTN50_CE'):
    caseids -= _cases_with_track(tname)

caseids = list(caseids)
print(f'Total {len(caseids)} cases found')

# Per-segment containers (stacked into numpy arrays at the end of the cell)
x_ecg_ii, x_ecg_v5 = [], []   # ECG II / ECG V5 windows of SEGLEN samples
b = []                        # BIS label of each window
c = []                        # caseid of each window
icase = 0                     # number of cases accepted so far


def _screen_case(vals):
    """Apply the per-case exclusion rules to a freshly loaded case.

    Returns ``(reason, None)`` when the case is rejected, where ``reason`` is the
    message to print. Otherwise returns ``(None, trimmed)``, where ``trimmed`` keeps
    only the rows from the first to the last sample with a positive BIS value.
    """
    if np.isnan(np.nanmax(vals[:, ECG_II])):
        return 'no ECG II data', None
    if np.isnan(np.nanmax(vals[:, ECG_V5])):
        return 'no ECG V5 data', None
    bis_rows = np.flatnonzero(vals[:, BIS] > 0)
    if bis_rows.size == 0:
        return 'all bis <= 0', None
    trimmed = vals[bis_rows[0]:bis_rows[-1] + 1, :]
    # require at least 30 minutes (1800 s) of BIS-covered recording
    if len(trimmed) < 1800 * SRATE:
        return 'len bis < 30 min', None
    return None, trimmed


for caseid in caseids:
    print(f'loading caseid={caseid} ({icase + 1}/{MAX_CASES})', end='...', flush=True)

    reason, vals = _screen_case(vitaldb.load_case(caseid, track_names, 1 / SRATE))
    if reason is not None:
        print(reason)
        continue

    # BIS is presumably sampled far less often than SRATE, leaving NaN gaps;
    # carry each reading forward for at most 7 s worth of samples
    vals[:, BIS:] = pd.DataFrame(vals[:, BIS:]).ffill(limit=7*SRATE).values

    # Slide a SEGLEN window in 2-second steps; the label is the BIS value at
    # the sample immediately after the window
    nbefore = len(b)
    for end in range(SEGLEN, len(vals), 2 * SRATE):
        label = vals[end, BIS]
        if np.isnan(label) or label == 0:
            continue
        window = slice(end - SEGLEN, end)
        x_ecg_ii.append(vals[window, ECG_II])
        x_ecg_v5.append(vals[window, ECG_V5])
        b.append(label)
        c.append(caseid)

    icase += 1
    print(f'{len(b) - nbefore} segments read -> total {len(b)} segments ({icase}/{MAX_CASES})')
    if icase >= MAX_CASES:
        break

# Stack the per-segment lists into arrays with one row/entry per segment
x_ecg_ii = np.array(x_ecg_ii)
x_ecg_v5 = np.array(x_ecg_v5)
b = np.array(b)
c = np.array(c)


Total 547 cases found
loading caseid=2 (1/100)...7530 segments read -> total 7530 segments (1/100)
loading caseid=4104 (2/100)...3176 segments read -> total 10706 segments (2/100)
loading caseid=12 (3/100)...15222 segments read -> total 25928 segments (3/100)
loading caseid=4112 (4/100)...10929 segments read -> total 36857 segments (4/100)
loading caseid=18 (5/100)...all bis <= 0
loading caseid=4116 (5/100)...4984 segments read -> total 41841 segments (5/100)
loading caseid=21 (6/100)...5859 segments read -> total 47700 segments (6/100)
loading caseid=4119 (7/100)...1562 segments read -> total 49262 segments (7/100)
loading caseid=25 (8/100)...7072 segments read -> total 56334 segments (8/100)
loading caseid=4126 (9/100)...all bis <= 0
loading caseid=39 (9/100)...all bis <= 0
loading caseid=4135 (9/100)...5881 segments read -> total 62215 segments (9/100)
loading caseid=4141 (10/100)...all bis <= 0
loading caseid=4144 (10/100)...9372 segments read -> total 71587 segments (10/100)
loadi

loading caseid=619 (85/100)...all bis <= 0
loading caseid=4731 (85/100)...10479 segments read -> total 459957 segments (85/100)
loading caseid=643 (86/100)...2801 segments read -> total 462758 segments (86/100)
loading caseid=4740 (87/100)...all bis <= 0
loading caseid=649 (87/100)...6889 segments read -> total 469647 segments (87/100)
loading caseid=653 (88/100)...4290 segments read -> total 473937 segments (88/100)
loading caseid=656 (89/100)...1557 segments read -> total 475494 segments (89/100)
loading caseid=4759 (90/100)...11880 segments read -> total 487374 segments (90/100)
loading caseid=4774 (91/100)...3139 segments read -> total 490513 segments (91/100)
loading caseid=688 (92/100)...all bis <= 0
loading caseid=689 (92/100)...1805 segments read -> total 492318 segments (92/100)
loading caseid=695 (93/100)...2117 segments read -> total 494435 segments (93/100)
loading caseid=703 (94/100)...4840 segments read -> total 499275 segments (94/100)
loading caseid=4800 (95/100)...1287

## Filtering input data

In [None]:
# Per-segment quality control. Every mask below is a boolean array with one
# entry per segment. A scalar (e.g. max(b) - min(b) > 12 over the whole
# dataset) would keep or drop all segments at once, so that is not allowed here.
ECG_ABS_LIMIT = 100   # |ECG| at or above this marks a segment as noisy
MIN_ECG_PTP = 0       # peak-to-peak ECG must exceed this (0 => reject only constant segments)

print('invalid samples...', end='', flush=True)

# 1) no NaN anywhere in either ECG lead
valid_mask_nan = ~np.isnan(x_ecg_ii).any(axis=1) & ~np.isnan(x_ecg_v5).any(axis=1)

# 2) BIS label inside the index's 0-100 scale
valid_mask_bis = (b > 0) & (b <= 100)

# 3) flat-signal check on both leads (the per-segment analogue of an electrode /
#    impedance check): a constant segment likely means a disconnected lead.
#    Rows containing NaN give NaN here and fail the comparison, as in mask 1.
valid_mask_flat = (np.ptp(x_ecg_ii, axis=1) > MIN_ECG_PTP) & (np.ptp(x_ecg_v5, axis=1) > MIN_ECG_PTP)

# 4) noise check: reject implausibly large amplitudes.
#    NOTE(review): 100 was carried over from an EEG (uV) pipeline; the ECG here is
#    presumably in mV, so this bound rarely triggers - confirm units and tune.
valid_mask_noisy = (np.max(np.abs(x_ecg_ii), axis=1) < ECG_ABS_LIMIT) & (np.max(np.abs(x_ecg_v5), axis=1) < ECG_ABS_LIMIT)

# Combine all conditions
valid_mask = valid_mask_nan & valid_mask_bis & valid_mask_flat & valid_mask_noisy
assert valid_mask.shape == b.shape, 'quality mask must have one entry per segment'

# Apply the mask to every segment-aligned array
x_ecg_ii = x_ecg_ii[valid_mask]
x_ecg_v5 = x_ecg_v5[valid_mask]
b = b[valid_mask]
c = c[valid_mask]

print(f'{100*(1-np.mean(valid_mask)):.1f}% removed')


invalid samples...0.0% removed


## Splitting samples into training and test sets (by case)

In [None]:
# Case ID
caseids = list(np.unique(c))
# random.shuffle(caseids)

# Split dataset into training and testing data
ntest = max(1, int(len(caseids) * 0.2))
caseids_train = caseids[ntest:]
caseids_test = caseids[:ntest]

train_mask = np.isin(c, caseids_train)
test_mask = np.isin(c, caseids_test)
x_train_ecg_ii = x_ecg_ii[train_mask]  # Use x_ecg_ii
x_train_ecg_v5 = x_ecg_v5[train_mask]  # Use x_ecg_v5
y_train = b[train_mask]  # BIS as the target variable

x_test_ecg_ii = x_ecg_ii[test_mask]  # Use x_ecg_ii
x_test_ecg_v5 = x_ecg_v5[test_mask]  # Use x_ecg_v5
y_test = b[test_mask]  # BIS as the target variable
c_test = c[test_mask]

print('====================================================')
print(f'total: {len(caseids)} cases {len(b)} samples')
print(f'train: {len(np.unique(c[train_mask]))} cases {len(y_train)} samples')
print(f'test {len(np.unique(c_test))} cases {len(y_test)} samples')
print('====================================================')


total: 100 cases 554817 samples
train: 80 cases 439118 samples
test 20 cases 115699 samples


## Modeling and Evaluation

In [None]:



import numpy as np
import pandas as pd
from keras.models import Model
from keras.layers import LSTM, Dense, Dropout, Input, concatenate, GlobalMaxPooling1D
from keras.callbacks import EarlyStopping, ModelCheckpoint

def _lstm_branch(inputs):
    """One ECG lead: LSTM (64 units, full output sequence) followed by 20% dropout."""
    hidden = LSTM(units=64, return_sequences=True)(inputs)
    return Dropout(0.2)(hidden)


# One input per ECG lead, shaped (timesteps, 1 channel)
inp_ecg2 = Input(shape=(x_train_ecg_ii.shape[1], 1))   # ECG II
inp_ecg5 = Input(shape=(x_train_ecg_v5.shape[1], 1))   # ECG V5

out_ecg2 = _lstm_branch(inp_ecg2)
out_ecg5 = _lstm_branch(inp_ecg5)

# Join both per-timestep feature streams, then take the max over time
merged = concatenate([out_ecg2, out_ecg5])
out = GlobalMaxPooling1D()(merged)

# Regression head producing a single value (the BIS estimate)
out = Dense(128, activation='relu')(out)
out = Dropout(0.2)(out)
out = Dense(1)(out)

model_lstm = Model(inputs=[inp_ecg2, inp_ecg5], outputs=[out])
model_lstm.compile(loss='mean_absolute_error', optimizer='adam', metrics=['mean_absolute_error'])
model_lstm.summary()

# Fit the model on the raw ECG segments. The arrays passed here must match the
# Input layers, shape (n_segments, SEGLEN, 1). They are also the same
# representation the evaluation cell feeds to `predict`, so the explicit
# trailing channel axis is the only change.
#
# NOTE: a scratch pyHRV/biosppy step used to run here and its NN-interval output
# was fed to `fit`. That cannot work:
#   - biosppy.signals.ecg.ecg expects one 1-D recording and defaults to a
#     1000 Hz sampling rate, whereas these data are 2-D segments at SRATE Hz;
#   - nn_intervals returns one variable-length series, not one row per segment
#     aligned with y_train.
# HRV features would need their own per-segment extraction and their own model.
hist_lstm = model_lstm.fit(
    [x_train_ecg_ii[..., np.newaxis], x_train_ecg_v5[..., np.newaxis]], y_train,
    validation_split=0.2, epochs=10, batch_size=4096,
    callbacks=[ModelCheckpoint(monitor='val_loss', filepath='model_lstm_2.h5',
                               verbose=1, save_best_only=True)])



Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_15 (InputLayer)          [(None, 512, 1)]     0           []                               
                                                                                                  
 input_16 (InputLayer)          [(None, 512, 1)]     0           []                               
                                                                                                  
 lstm_14 (LSTM)                 (None, 512, 64)      16896       ['input_15[0][0]']               
                                                                                                  
 lstm_15 (LSTM)                 (None, 512, 64)      16896       ['input_16[0][0]']               
                                                                                            

In [None]:
# Reload the best checkpoint; the filename must match the training cell's
# ModelCheckpoint filepath
from keras.models import load_model

model_lstm = load_model('model_lstm_2.h5')

# One raw prediction per test segment
pred_test_lstm = model_lstm.predict([x_test_ecg_ii, x_test_ecg_v5]).flatten()

# Post-processing: 15-point median filter over each case's prediction sequence,
# applied case by case so smoothing never mixes neighbouring cases
for test_case in np.unique(c_test):
    rows = c_test == test_case
    pred_test_lstm[rows] = scipy.signal.medfilt(pred_test_lstm[rows], 15)

# Mean absolute error against the BIS labels
abs_err = np.abs(y_test - pred_test_lstm)
test_mae_lstm = np.mean(abs_err)
print(f"Test MAE (LSTM): {test_mae_lstm}")

# Percentage of segments whose error is strictly below `threshold`
threshold = 5  # Set your desired threshold
accuracy_percentage = np.mean(abs_err < threshold) * 100
print(f"Accuracy within {threshold} units: {accuracy_percentage:.2f}%")



Test MAE (LSTM): 14.09238393588695
Accuracy within 5 units: 18.24%
