In [1]:
# Libraries
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 40)
pd.set_option('display.width', 2000)
import random

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch import optim
from sklearn.metrics import roc_auc_score

import optuna
from optuna.trial import TrialState
import torch.optim as optim

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Remove printing error
pd.options.mode.chained_assignment = None

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set the random seeds for deterministic results.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

<torch._C.Generator at 0x7f832a69a190>

In [3]:
# Import
path = r'switch_data/icare_switch_data_preprocessed.csv'
icare_df_preprocessed = pd.read_csv(path)

model_data = icare_df_preprocessed.drop(columns=['date', 'ROUTE', '24_hour_flag', '48_hour_flag', 'iv_treatment_length'])
# Rename
model_data.rename(columns={'SPELL_IDENTIFIER': 'stay_id'}, inplace=True)

# Split into 'preprocessing dataset' and a 'train/vaid/test dataset'
stays = model_data['stay_id'].unique()
random.Random(0).shuffle(stays)
model_data2 = model_data.set_index("stay_id").loc[stays].reset_index()
n = round(0.5 * len(stays))
preprocessing_stays = stays[:n]
hold_out_stays = stays[n:]
preprocessing_data = model_data2[model_data2['stay_id'].isin(preprocessing_stays)]
hold_out_data = model_data2[model_data2['stay_id'].isin(hold_out_stays)]

In [4]:
columns_to_drop = [
 'Diastolic Blood Pressure2',
 'Diastolic Blood Pressure3',
 'Diastolic Blood Pressure4',
 'Diastolic Blood Pressure5',
 'Diastolic Blood Pressure6',
 'Diastolic Blood Pressure7',
 'Diastolic Blood Pressure8',
 'Diastolic Blood Pressure9',
 'Diastolic Blood Pressure11',
 'Diastolic Blood Pressure12',
 'Diastolic Blood Pressure13',
 'Diastolic Blood Pressure14',
 'Diastolic Blood Pressure15',
 'Diastolic Blood Pressure16',
 'Diastolic Blood Pressure18',
 'Diastolic Blood Pressure19',
 'Diastolic Blood Pressure20',
 'Diastolic Blood Pressure21',
 'Diastolic Blood Pressure2_current_stay',
 'Diastolic Blood Pressure3_current_stay',
 'Diastolic Blood Pressure4_current_stay',
 'Diastolic Blood Pressure5_current_stay',
 'Diastolic Blood Pressure6_current_stay',
 'Diastolic Blood Pressure8_current_stay',
 'Diastolic Blood Pressure12_current_stay',
 'Diastolic Blood Pressure13_current_stay',
 'Diastolic Blood Pressure14_current_stay',
 'Diastolic Blood Pressure16_current_stay',
 'Diastolic Blood Pressure18_current_stay',
 'Diastolic Blood Pressure19_current_stay',
 'Diastolic Blood Pressure20_current_stay',
 'Diastolic Blood Pressure21_current_stay',
 'Glasgow Coma Score0',
 'Glasgow Coma Score1',
 'Glasgow Coma Score2',
 'Glasgow Coma Score3',
 'Glasgow Coma Score4',
 'Glasgow Coma Score5',
 'Glasgow Coma Score6',
 'Glasgow Coma Score7',
 'Glasgow Coma Score8',
 'Glasgow Coma Score9',
 'Glasgow Coma Score10',
 'Glasgow Coma Score11',
 'Glasgow Coma Score12',
 'Glasgow Coma Score13',
 'Glasgow Coma Score14',
 'Glasgow Coma Score15',
 'Glasgow Coma Score16',
 'Glasgow Coma Score17',
 'Glasgow Coma Score18',
 'Glasgow Coma Score19',
 'Glasgow Coma Score20',
 'Glasgow Coma Score21',
 'Glasgow Coma Score0_current_stay',
 'Glasgow Coma Score1_current_stay',
 'Glasgow Coma Score2_current_stay',
 'Glasgow Coma Score3_current_stay',
 'Glasgow Coma Score4_current_stay',
 'Glasgow Coma Score5_current_stay',
 'Glasgow Coma Score6_current_stay',
 'Glasgow Coma Score7_current_stay',
 'Glasgow Coma Score8_current_stay',
 'Glasgow Coma Score10_current_stay',
 'Glasgow Coma Score11_current_stay',
 'Glasgow Coma Score12_current_stay',
 'Glasgow Coma Score13_current_stay',
 'Glasgow Coma Score14_current_stay',
 'Glasgow Coma Score15_current_stay',
 'Glasgow Coma Score16_current_stay',
 'Glasgow Coma Score17_current_stay',
 'Glasgow Coma Score18_current_stay',
 'Glasgow Coma Score19_current_stay',
 'Glasgow Coma Score20_current_stay',
 'Glasgow Coma Score21_current_stay',
 'Heart Rate2',
 'Heart Rate3',
 'Heart Rate4',
 'Heart Rate5',
 'Heart Rate6',
 'Heart Rate7',
 'Heart Rate8',
 'Heart Rate9',
 'Heart Rate11',
 'Heart Rate12',
 'Heart Rate13',
 'Heart Rate14',
 'Heart Rate15',
 'Heart Rate16',
 'Heart Rate18',
 'Heart Rate19',
 'Heart Rate20',
 'Heart Rate21',
 'Heart Rate2_current_stay',
 'Heart Rate3_current_stay',
 'Heart Rate4_current_stay',
 'Heart Rate5_current_stay',
 'Heart Rate6_current_stay',
 'Heart Rate8_current_stay',
 'Heart Rate12_current_stay',
 'Heart Rate13_current_stay',
 'Heart Rate14_current_stay',
 'Heart Rate16_current_stay',
 'Heart Rate18_current_stay',
 'Heart Rate19_current_stay',
 'Heart Rate20_current_stay',
 'Heart Rate21_current_stay',
 'Mean Arterial Pressure2',
 'Mean Arterial Pressure3',
 'Mean Arterial Pressure4',
 'Mean Arterial Pressure5',
 'Mean Arterial Pressure6',
 'Mean Arterial Pressure7',
 'Mean Arterial Pressure8',
 'Mean Arterial Pressure9',
 'Mean Arterial Pressure11',
 'Mean Arterial Pressure12',
 'Mean Arterial Pressure13',
 'Mean Arterial Pressure14',
 'Mean Arterial Pressure15',
 'Mean Arterial Pressure16',
 'Mean Arterial Pressure18',
 'Mean Arterial Pressure19',
 'Mean Arterial Pressure20',
 'Mean Arterial Pressure21',
 'Mean Arterial Pressure2_current_stay',
 'Mean Arterial Pressure3_current_stay',
 'Mean Arterial Pressure4_current_stay',
 'Mean Arterial Pressure5_current_stay',
 'Mean Arterial Pressure6_current_stay',
 'Mean Arterial Pressure8_current_stay',
 'Mean Arterial Pressure12_current_stay',
 'Mean Arterial Pressure13_current_stay',
 'Mean Arterial Pressure14_current_stay',
 'Mean Arterial Pressure16_current_stay',
 'Mean Arterial Pressure18_current_stay',
 'Mean Arterial Pressure19_current_stay',
 'Mean Arterial Pressure20_current_stay',
 'Mean Arterial Pressure21_current_stay',
 'NEWS Conscious Level Score0',
 'NEWS Conscious Level Score1',
 'NEWS Conscious Level Score2',
 'NEWS Conscious Level Score3',
 'NEWS Conscious Level Score4',
 'NEWS Conscious Level Score5',
 'NEWS Conscious Level Score6',
 'NEWS Conscious Level Score7',
 'NEWS Conscious Level Score8',
 'NEWS Conscious Level Score9',
 'NEWS Conscious Level Score10',
 'NEWS Conscious Level Score11',
 'NEWS Conscious Level Score12',
 'NEWS Conscious Level Score13',
 'NEWS Conscious Level Score14',
 'NEWS Conscious Level Score15',
 'NEWS Conscious Level Score16',
 'NEWS Conscious Level Score17',
 'NEWS Conscious Level Score18',
 'NEWS Conscious Level Score19',
 'NEWS Conscious Level Score20',
 'NEWS Conscious Level Score21',
 'NEWS Conscious Level Score0_current_stay',
 'NEWS Conscious Level Score1_current_stay',
 'NEWS Conscious Level Score2_current_stay',
 'NEWS Conscious Level Score3_current_stay',
 'NEWS Conscious Level Score4_current_stay',
 'NEWS Conscious Level Score5_current_stay',
 'NEWS Conscious Level Score6_current_stay',
 'NEWS Conscious Level Score7_current_stay',
 'NEWS Conscious Level Score8_current_stay',
 'NEWS Conscious Level Score9_current_stay',
 'NEWS Conscious Level Score10_current_stay',
 'NEWS Conscious Level Score11_current_stay',
 'NEWS Conscious Level Score12_current_stay',
 'NEWS Conscious Level Score13_current_stay',
 'NEWS Conscious Level Score14_current_stay',
 'NEWS Conscious Level Score15_current_stay',
 'NEWS Conscious Level Score16_current_stay',
 'NEWS Conscious Level Score17_current_stay',
 'NEWS Conscious Level Score18_current_stay',
 'NEWS Conscious Level Score19_current_stay',
 'NEWS Conscious Level Score20_current_stay',
 'NEWS Conscious Level Score21_current_stay',
 'NEWS Supplemental Oxygen Calc0',
 'NEWS Supplemental Oxygen Calc1',
 'NEWS Supplemental Oxygen Calc2',
 'NEWS Supplemental Oxygen Calc3',
 'NEWS Supplemental Oxygen Calc4',
 'NEWS Supplemental Oxygen Calc5',
 'NEWS Supplemental Oxygen Calc6',
 'NEWS Supplemental Oxygen Calc7',
 'NEWS Supplemental Oxygen Calc8',
 'NEWS Supplemental Oxygen Calc9',
 'NEWS Supplemental Oxygen Calc10',
 'NEWS Supplemental Oxygen Calc11',
 'NEWS Supplemental Oxygen Calc12',
 'NEWS Supplemental Oxygen Calc13',
 'NEWS Supplemental Oxygen Calc14',
 'NEWS Supplemental Oxygen Calc15',
 'NEWS Supplemental Oxygen Calc16',
 'NEWS Supplemental Oxygen Calc17',
 'NEWS Supplemental Oxygen Calc18',
 'NEWS Supplemental Oxygen Calc19',
 'NEWS Supplemental Oxygen Calc20',
 'NEWS Supplemental Oxygen Calc21',
 'NEWS Supplemental Oxygen Calc0_current_stay',
 'NEWS Supplemental Oxygen Calc1_current_stay',
 'NEWS Supplemental Oxygen Calc2_current_stay',
 'NEWS Supplemental Oxygen Calc3_current_stay',
 'NEWS Supplemental Oxygen Calc4_current_stay',
 'NEWS Supplemental Oxygen Calc5_current_stay',
 'NEWS Supplemental Oxygen Calc6_current_stay',
 'NEWS Supplemental Oxygen Calc7_current_stay',
 'NEWS Supplemental Oxygen Calc8_current_stay',
 'NEWS Supplemental Oxygen Calc9_current_stay',
 'NEWS Supplemental Oxygen Calc10_current_stay',
 'NEWS Supplemental Oxygen Calc11_current_stay',
 'NEWS Supplemental Oxygen Calc12_current_stay',
 'NEWS Supplemental Oxygen Calc13_current_stay',
 'NEWS Supplemental Oxygen Calc14_current_stay',
 'NEWS Supplemental Oxygen Calc15_current_stay',
 'NEWS Supplemental Oxygen Calc16_current_stay',
 'NEWS Supplemental Oxygen Calc17_current_stay',
 'NEWS Supplemental Oxygen Calc18_current_stay',
 'NEWS Supplemental Oxygen Calc19_current_stay',
 'NEWS Supplemental Oxygen Calc20_current_stay',
 'NEWS Supplemental Oxygen Calc21_current_stay',
 'Respiratory Rate0',
 'Respiratory Rate2',
 'Respiratory Rate3',
 'Respiratory Rate4',
 'Respiratory Rate5',
 'Respiratory Rate6',
 'Respiratory Rate7',
 'Respiratory Rate8',
 'Respiratory Rate9',
 'Respiratory Rate11',
 'Respiratory Rate12',
 'Respiratory Rate13',
 'Respiratory Rate14',
 'Respiratory Rate15',
 'Respiratory Rate16',
 'Respiratory Rate18',
 'Respiratory Rate19',
 'Respiratory Rate20',
 'Respiratory Rate21',
 'Respiratory Rate0_current_stay',
 'Respiratory Rate2_current_stay',
 'Respiratory Rate3_current_stay',
 'Respiratory Rate4_current_stay',
 'Respiratory Rate5_current_stay',
 'Respiratory Rate6_current_stay',
 'Respiratory Rate8_current_stay',
 'Respiratory Rate12_current_stay',
 'Respiratory Rate13_current_stay',
 'Respiratory Rate14_current_stay',
 'Respiratory Rate16_current_stay',
 'Respiratory Rate18_current_stay',
 'Respiratory Rate19_current_stay',
 'Respiratory Rate20_current_stay',
 'Respiratory Rate21_current_stay',
 'SpO20',
 'SpO22',
 'SpO23',
 'SpO24',
 'SpO25',
 'SpO26',
 'SpO27',
 'SpO28',
 'SpO29',
 'SpO211',
 'SpO212',
 'SpO213',
 'SpO214',
 'SpO215',
 'SpO216',
 'SpO218',
 'SpO219',
 'SpO220',
 'SpO221',
 'SpO20_current_stay',
 'SpO22_current_stay',
 'SpO23_current_stay',
 'SpO24_current_stay',
 'SpO25_current_stay',
 'SpO26_current_stay',
 'SpO28_current_stay',
 'SpO212_current_stay',
 'SpO213_current_stay',
 'SpO214_current_stay',
 'SpO216_current_stay',
 'SpO218_current_stay',
 'SpO219_current_stay',
 'SpO220_current_stay',
 'SpO221_current_stay',
 'Systolic Blood Pressure2',
 'Systolic Blood Pressure3',
 'Systolic Blood Pressure4',
 'Systolic Blood Pressure5',
 'Systolic Blood Pressure6',
 'Systolic Blood Pressure7',
 'Systolic Blood Pressure8',
 'Systolic Blood Pressure9',
 'Systolic Blood Pressure11',
 'Systolic Blood Pressure12',
 'Systolic Blood Pressure13',
 'Systolic Blood Pressure14',
 'Systolic Blood Pressure15',
 'Systolic Blood Pressure16',
 'Systolic Blood Pressure18',
 'Systolic Blood Pressure19',
 'Systolic Blood Pressure20',
 'Systolic Blood Pressure21',
 'Systolic Blood Pressure2_current_stay',
 'Systolic Blood Pressure3_current_stay',
 'Systolic Blood Pressure4_current_stay',
 'Systolic Blood Pressure5_current_stay',
 'Systolic Blood Pressure6_current_stay',
 'Systolic Blood Pressure8_current_stay',
 'Systolic Blood Pressure12_current_stay',
 'Systolic Blood Pressure13_current_stay',
 'Systolic Blood Pressure14_current_stay',
 'Systolic Blood Pressure16_current_stay',
 'Systolic Blood Pressure18_current_stay',
 'Systolic Blood Pressure19_current_stay',
 'Systolic Blood Pressure20_current_stay',
 'Systolic Blood Pressure21_current_stay',
 'Temperature0',
 'Temperature2',
 'Temperature3',
 'Temperature4',
 'Temperature5',
 'Temperature6',
 'Temperature7',
 'Temperature8',
 'Temperature9',
 'Temperature11',
 'Temperature12',
 'Temperature13',
 'Temperature14',
 'Temperature15',
 'Temperature16',
 'Temperature18',
 'Temperature19',
 'Temperature20',
 'Temperature21',
 'Temperature0_current_stay',
 'Temperature2_current_stay',
 'Temperature3_current_stay',
 'Temperature4_current_stay',
 'Temperature5_current_stay',
 'Temperature6_current_stay',
 'Temperature8_current_stay',
 'Temperature12_current_stay',
 'Temperature13_current_stay',
 'Temperature14_current_stay',
 'Temperature16_current_stay',
 'Temperature18_current_stay',
 'Temperature19_current_stay',
 'Temperature20_current_stay',
 'Temperature21_current_stay',
 'Diastolic Blood Pressure2_difference',
 'Diastolic Blood Pressure3_difference',
 'Diastolic Blood Pressure4_difference',
 'Diastolic Blood Pressure5_difference',
 'Diastolic Blood Pressure6_difference',
 'Diastolic Blood Pressure7_difference',
 'Diastolic Blood Pressure8_difference',
 'Diastolic Blood Pressure9_difference',
 'Diastolic Blood Pressure11_difference',
 'Diastolic Blood Pressure12_difference',
 'Diastolic Blood Pressure13_difference',
 'Diastolic Blood Pressure14_difference',
 'Diastolic Blood Pressure15_difference',
 'Diastolic Blood Pressure16_difference',
 'Diastolic Blood Pressure18_difference',
 'Diastolic Blood Pressure19_difference',
 'Diastolic Blood Pressure20_difference',
 'Diastolic Blood Pressure21_difference',
 'Diastolic Blood Pressure2_current_stay_difference',
 'Diastolic Blood Pressure3_current_stay_difference',
 'Diastolic Blood Pressure4_current_stay_difference',
 'Diastolic Blood Pressure5_current_stay_difference',
 'Diastolic Blood Pressure6_current_stay_difference',
 'Diastolic Blood Pressure8_current_stay_difference',
 'Diastolic Blood Pressure12_current_stay_difference',
 'Diastolic Blood Pressure13_current_stay_difference',
 'Diastolic Blood Pressure14_current_stay_difference',
 'Diastolic Blood Pressure16_current_stay_difference',
 'Diastolic Blood Pressure18_current_stay_difference',
 'Diastolic Blood Pressure19_current_stay_difference',
 'Diastolic Blood Pressure21_current_stay_difference',
 'Glasgow Coma Score0_difference',
 'Glasgow Coma Score1_difference',
 'Glasgow Coma Score2_difference',
 'Glasgow Coma Score3_difference',
 'Glasgow Coma Score4_difference',
 'Glasgow Coma Score5_difference',
 'Glasgow Coma Score6_difference',
 'Glasgow Coma Score7_difference',
 'Glasgow Coma Score8_difference',
 'Glasgow Coma Score9_difference',
 'Glasgow Coma Score10_difference',
 'Glasgow Coma Score11_difference',
 'Glasgow Coma Score12_difference',
 'Glasgow Coma Score13_difference',
 'Glasgow Coma Score14_difference',
 'Glasgow Coma Score15_difference',
 'Glasgow Coma Score16_difference',
 'Glasgow Coma Score17_difference',
 'Glasgow Coma Score18_difference',
 'Glasgow Coma Score19_difference',
 'Glasgow Coma Score20_difference',
 'Glasgow Coma Score21_difference',
 'Glasgow Coma Score0_current_stay_difference',
 'Glasgow Coma Score1_current_stay_difference',
 'Glasgow Coma Score2_current_stay_difference',
 'Glasgow Coma Score3_current_stay_difference',
 'Glasgow Coma Score4_current_stay_difference',
 'Glasgow Coma Score5_current_stay_difference',
 'Glasgow Coma Score6_current_stay_difference',
 'Glasgow Coma Score7_current_stay_difference',
 'Glasgow Coma Score8_current_stay_difference',
 'Glasgow Coma Score10_current_stay_difference',
 'Glasgow Coma Score11_current_stay_difference',
 'Glasgow Coma Score12_current_stay_difference',
 'Glasgow Coma Score13_current_stay_difference',
 'Glasgow Coma Score14_current_stay_difference',
 'Glasgow Coma Score15_current_stay_difference',
 'Glasgow Coma Score16_current_stay_difference',
 'Glasgow Coma Score17_current_stay_difference',
 'Glasgow Coma Score18_current_stay_difference',
 'Glasgow Coma Score19_current_stay_difference',
 'Glasgow Coma Score20_current_stay_difference',
 'Glasgow Coma Score21_current_stay_difference',
 'Heart Rate2_difference',
 'Heart Rate3_difference',
 'Heart Rate4_difference',
 'Heart Rate5_difference',
 'Heart Rate6_difference',
 'Heart Rate7_difference',
 'Heart Rate8_difference',
 'Heart Rate9_difference',
 'Heart Rate11_difference',
 'Heart Rate12_difference',
 'Heart Rate13_difference',
 'Heart Rate14_difference',
 'Heart Rate15_difference',
 'Heart Rate16_difference',
 'Heart Rate18_difference',
 'Heart Rate19_difference',
 'Heart Rate20_difference',
 'Heart Rate21_difference',
 'Heart Rate2_current_stay_difference',
 'Heart Rate3_current_stay_difference',
 'Heart Rate4_current_stay_difference',
 'Heart Rate5_current_stay_difference',
 'Heart Rate6_current_stay_difference',
 'Heart Rate8_current_stay_difference',
 'Heart Rate12_current_stay_difference',
 'Heart Rate13_current_stay_difference',
 'Heart Rate14_current_stay_difference',
 'Heart Rate16_current_stay_difference',
 'Heart Rate18_current_stay_difference',
 'Heart Rate19_current_stay_difference',
 'Heart Rate21_current_stay_difference',
 'Mean Arterial Pressure2_difference',
 'Mean Arterial Pressure3_difference',
 'Mean Arterial Pressure4_difference',
 'Mean Arterial Pressure5_difference',
 'Mean Arterial Pressure6_difference',
 'Mean Arterial Pressure7_difference',
 'Mean Arterial Pressure8_difference',
 'Mean Arterial Pressure9_difference',
 'Mean Arterial Pressure11_difference',
 'Mean Arterial Pressure12_difference',
 'Mean Arterial Pressure13_difference',
 'Mean Arterial Pressure14_difference',
 'Mean Arterial Pressure15_difference',
 'Mean Arterial Pressure16_difference',
 'Mean Arterial Pressure18_difference',
 'Mean Arterial Pressure19_difference',
 'Mean Arterial Pressure20_difference',
 'Mean Arterial Pressure21_difference',
 'Mean Arterial Pressure2_current_stay_difference',
 'Mean Arterial Pressure3_current_stay_difference',
 'Mean Arterial Pressure4_current_stay_difference',
 'Mean Arterial Pressure5_current_stay_difference',
 'Mean Arterial Pressure6_current_stay_difference',
 'Mean Arterial Pressure8_current_stay_difference',
 'Mean Arterial Pressure12_current_stay_difference',
 'Mean Arterial Pressure13_current_stay_difference',
 'Mean Arterial Pressure14_current_stay_difference',
 'Mean Arterial Pressure16_current_stay_difference',
 'Mean Arterial Pressure18_current_stay_difference',
 'Mean Arterial Pressure19_current_stay_difference',
 'Mean Arterial Pressure21_current_stay_difference',
 'NEWS Conscious Level Score0_difference',
 'NEWS Conscious Level Score1_difference',
 'NEWS Conscious Level Score2_difference',
 'NEWS Conscious Level Score3_difference',
 'NEWS Conscious Level Score4_difference',
 'NEWS Conscious Level Score5_difference',
 'NEWS Conscious Level Score6_difference',
 'NEWS Conscious Level Score7_difference',
 'NEWS Conscious Level Score8_difference',
 'NEWS Conscious Level Score9_difference',
 'NEWS Conscious Level Score10_difference',
 'NEWS Conscious Level Score11_difference',
 'NEWS Conscious Level Score12_difference',
 'NEWS Conscious Level Score13_difference',
 'NEWS Conscious Level Score14_difference',
 'NEWS Conscious Level Score15_difference',
 'NEWS Conscious Level Score16_difference',
 'NEWS Conscious Level Score17_difference',
 'NEWS Conscious Level Score18_difference',
 'NEWS Conscious Level Score19_difference',
 'NEWS Conscious Level Score20_difference',
 'NEWS Conscious Level Score21_difference',
 'NEWS Conscious Level Score0_current_stay_difference',
 'NEWS Conscious Level Score1_current_stay_difference',
 'NEWS Conscious Level Score2_current_stay_difference',
 'NEWS Conscious Level Score3_current_stay_difference',
 'NEWS Conscious Level Score4_current_stay_difference',
 'NEWS Conscious Level Score5_current_stay_difference',
 'NEWS Conscious Level Score6_current_stay_difference',
 'NEWS Conscious Level Score7_current_stay_difference',
 'NEWS Conscious Level Score8_current_stay_difference',
 'NEWS Conscious Level Score9_current_stay_difference',
 'NEWS Conscious Level Score10_current_stay_difference',
 'NEWS Conscious Level Score11_current_stay_difference',
 'NEWS Conscious Level Score12_current_stay_difference',
 'NEWS Conscious Level Score13_current_stay_difference',
 'NEWS Conscious Level Score14_current_stay_difference',
 'NEWS Conscious Level Score15_current_stay_difference',
 'NEWS Conscious Level Score16_current_stay_difference',
 'NEWS Conscious Level Score17_current_stay_difference',
 'NEWS Conscious Level Score18_current_stay_difference',
 'NEWS Conscious Level Score19_current_stay_difference',
 'NEWS Conscious Level Score20_current_stay_difference',
 'NEWS Conscious Level Score21_current_stay_difference',
 'NEWS Supplemental Oxygen Calc0_difference',
 'NEWS Supplemental Oxygen Calc1_difference',
 'NEWS Supplemental Oxygen Calc2_difference',
 'NEWS Supplemental Oxygen Calc3_difference',
 'NEWS Supplemental Oxygen Calc4_difference',
 'NEWS Supplemental Oxygen Calc5_difference',
 'NEWS Supplemental Oxygen Calc6_difference',
 'NEWS Supplemental Oxygen Calc7_difference',
 'NEWS Supplemental Oxygen Calc8_difference',
 'NEWS Supplemental Oxygen Calc9_difference',
 'NEWS Supplemental Oxygen Calc10_difference',
 'NEWS Supplemental Oxygen Calc11_difference',
 'NEWS Supplemental Oxygen Calc12_difference',
 'NEWS Supplemental Oxygen Calc13_difference',
 'NEWS Supplemental Oxygen Calc14_difference',
 'NEWS Supplemental Oxygen Calc15_difference',
 'NEWS Supplemental Oxygen Calc16_difference',
 'NEWS Supplemental Oxygen Calc17_difference',
 'NEWS Supplemental Oxygen Calc18_difference',
 'NEWS Supplemental Oxygen Calc19_difference',
 'NEWS Supplemental Oxygen Calc20_difference',
 'NEWS Supplemental Oxygen Calc21_difference',
 'NEWS Supplemental Oxygen Calc0_current_stay_difference',
 'NEWS Supplemental Oxygen Calc1_current_stay_difference',
 'NEWS Supplemental Oxygen Calc2_current_stay_difference',
 'NEWS Supplemental Oxygen Calc3_current_stay_difference',
 'NEWS Supplemental Oxygen Calc4_current_stay_difference',
 'NEWS Supplemental Oxygen Calc5_current_stay_difference',
 'NEWS Supplemental Oxygen Calc6_current_stay_difference',
 'NEWS Supplemental Oxygen Calc7_current_stay_difference',
 'NEWS Supplemental Oxygen Calc8_current_stay_difference',
 'NEWS Supplemental Oxygen Calc10_current_stay_difference',
 'NEWS Supplemental Oxygen Calc12_current_stay_difference',
 'NEWS Supplemental Oxygen Calc13_current_stay_difference',
 'NEWS Supplemental Oxygen Calc14_current_stay_difference',
 'NEWS Supplemental Oxygen Calc15_current_stay_difference',
 'NEWS Supplemental Oxygen Calc16_current_stay_difference',
 'NEWS Supplemental Oxygen Calc18_current_stay_difference',
 'NEWS Supplemental Oxygen Calc19_current_stay_difference',
 'NEWS Supplemental Oxygen Calc20_current_stay_difference',
 'NEWS Supplemental Oxygen Calc21_current_stay_difference',
 'Respiratory Rate2_difference',
 'Respiratory Rate3_difference',
 'Respiratory Rate4_difference',
 'Respiratory Rate5_difference',
 'Respiratory Rate6_difference',
 'Respiratory Rate7_difference',
 'Respiratory Rate8_difference',
 'Respiratory Rate9_difference',
 'Respiratory Rate11_difference',
 'Respiratory Rate12_difference',
 'Respiratory Rate13_difference',
 'Respiratory Rate14_difference',
 'Respiratory Rate15_difference',
 'Respiratory Rate16_difference',
 'Respiratory Rate18_difference',
 'Respiratory Rate19_difference',
 'Respiratory Rate20_difference',
 'Respiratory Rate21_difference',
 'Respiratory Rate0_current_stay_difference',
 'Respiratory Rate2_current_stay_difference',
 'Respiratory Rate3_current_stay_difference',
 'Respiratory Rate4_current_stay_difference',
 'Respiratory Rate5_current_stay_difference',
 'Respiratory Rate6_current_stay_difference',
 'Respiratory Rate8_current_stay_difference',
 'Respiratory Rate12_current_stay_difference',
 'Respiratory Rate13_current_stay_difference',
 'Respiratory Rate14_current_stay_difference',
 'Respiratory Rate16_current_stay_difference',
 'Respiratory Rate18_current_stay_difference',
 'Respiratory Rate19_current_stay_difference',
 'Respiratory Rate21_current_stay_difference',
 'SpO22_difference',
 'SpO23_difference',
 'SpO24_difference',
 'SpO25_difference',
 'SpO26_difference',
 'SpO27_difference',
 'SpO28_difference',
 'SpO29_difference',
 'SpO211_difference',
 'SpO212_difference',
 'SpO213_difference',
 'SpO214_difference',
 'SpO215_difference',
 'SpO216_difference',
 'SpO218_difference',
 'SpO219_difference',
 'SpO220_difference',
 'SpO221_difference',
 'SpO22_current_stay_difference',
 'SpO23_current_stay_difference',
 'SpO24_current_stay_difference',
 'SpO25_current_stay_difference',
 'SpO26_current_stay_difference',
 'SpO28_current_stay_difference',
 'SpO212_current_stay_difference',
 'SpO213_current_stay_difference',
 'SpO214_current_stay_difference',
 'SpO216_current_stay_difference',
 'SpO218_current_stay_difference',
 'SpO219_current_stay_difference',
 'SpO221_current_stay_difference',
 'Systolic Blood Pressure2_difference',
 'Systolic Blood Pressure3_difference',
 'Systolic Blood Pressure4_difference',
 'Systolic Blood Pressure5_difference',
 'Systolic Blood Pressure6_difference',
 'Systolic Blood Pressure7_difference',
 'Systolic Blood Pressure8_difference',
 'Systolic Blood Pressure9_difference',
 'Systolic Blood Pressure11_difference',
 'Systolic Blood Pressure12_difference',
 'Systolic Blood Pressure13_difference',
 'Systolic Blood Pressure14_difference',
 'Systolic Blood Pressure15_difference',
 'Systolic Blood Pressure16_difference',
 'Systolic Blood Pressure18_difference',
 'Systolic Blood Pressure19_difference',
 'Systolic Blood Pressure20_difference',
 'Systolic Blood Pressure21_difference',
 'Systolic Blood Pressure2_current_stay_difference',
 'Systolic Blood Pressure3_current_stay_difference',
 'Systolic Blood Pressure4_current_stay_difference',
 'Systolic Blood Pressure5_current_stay_difference',
 'Systolic Blood Pressure6_current_stay_difference',
 'Systolic Blood Pressure8_current_stay_difference',
 'Systolic Blood Pressure12_current_stay_difference',
 'Systolic Blood Pressure13_current_stay_difference',
 'Systolic Blood Pressure14_current_stay_difference',
 'Systolic Blood Pressure16_current_stay_difference',
 'Systolic Blood Pressure18_current_stay_difference',
 'Systolic Blood Pressure19_current_stay_difference',
 'Systolic Blood Pressure21_current_stay_difference',
 'Temperature2_difference',
 'Temperature3_difference',
 'Temperature4_difference',
 'Temperature5_difference',
 'Temperature6_difference',
 'Temperature7_difference',
 'Temperature8_difference',
 'Temperature9_difference',
 'Temperature11_difference',
 'Temperature12_difference',
 'Temperature13_difference',
 'Temperature14_difference',
 'Temperature15_difference',
 'Temperature16_difference',
 'Temperature18_difference',
 'Temperature19_difference',
 'Temperature20_difference',
 'Temperature21_difference',
 'Temperature2_current_stay_difference',
 'Temperature3_current_stay_difference',
 'Temperature4_current_stay_difference',
 'Temperature5_current_stay_difference',
 'Temperature6_current_stay_difference',
 'Temperature8_current_stay_difference',
 'Temperature12_current_stay_difference',
 'Temperature13_current_stay_difference',
 'Temperature14_current_stay_difference',
 'Temperature16_current_stay_difference',
 'Temperature18_current_stay_difference',
 'Temperature19_current_stay_difference',
 'Temperature21_current_stay_difference']

In [5]:
stays = preprocessing_data['stay_id'].unique()
# Random shuffle
random.Random(5).shuffle(stays)
preprocessing_data2 = preprocessing_data.set_index("stay_id").loc[stays].reset_index()

# Filter for features
X_data = preprocessing_data2.drop(columns=['stay_id', 'po_flag'])
X_data = X_data.drop(columns=columns_to_drop)
model_data = pd.concat([preprocessing_data2[['stay_id', 'po_flag']], X_data], axis=1)

# Split into training, validation and testing
n = round(0.7 * len(stays))
n2 = round(0.85 * len(stays))
train_stays = stays[:n]
validation_stays = stays[n:n2]
test_stays = stays[n2:]
train_data = model_data[model_data['stay_id'].isin(train_stays)]
valid_data = model_data[model_data['stay_id'].isin(validation_stays)]
test_data = model_data[model_data['stay_id'].isin(test_stays)]

In [6]:
# Oversampling train set
from typing import Counter
import imblearn
from imblearn.over_sampling import SMOTE
# Split X y
train_data_X = train_data.drop(columns=['stay_id', 'po_flag'])
train_data_y = train_data['po_flag']
Counter(train_data_y)
oversample = SMOTE()
train_data_X, train_data_y = oversample.fit_resample(train_data_X, train_data_y)
Counter(train_data_y)
train_data_y = pd.DataFrame(train_data_y, columns=['po_flag'])
train_data_y['stay_id'] = 'x'
train_data = pd.concat([train_data_y, train_data_X], axis=1)
train_data = train_data.sample(frac=1, random_state=0).reset_index(drop=True)

Counter({1: 11725, 0: 8334})

Counter({1: 11725, 0: 11725})

In [7]:
def define_model(trial):
    # We optimize the number of layers, hidden units and dropout ratio in each layer.
    n_layers = trial.suggest_int("n_layers", 1, 5)
    layers = []
    in_features = 253
    layers.append(nn.BatchNorm1d(in_features))
    p = trial.suggest_categorical("dropout", [0, 0.1, 0.2, 0.3, 0.4, 0.5])

    for i in range(n_layers):
        out_features = trial.suggest_categorical("n_units_l{}".format(i), [2, 4, 8, 16, 32, 64, 128, 256, 512])
        #layers.append(nn.BatchNorm1d(in_features))
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        #p = trial.suggest_float("dropout_l{}".format(i), 0, 0.5)
        layers.append(nn.Dropout(p))

        in_features = out_features
    layers.append(nn.Linear(in_features, 1))

    return nn.Sequential(*layers)

In [8]:
def objective(trial):

    # Generate the model.
    model = define_model(trial)

    # Generate the optimizers.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    #lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    lr = trial.suggest_categorical("lr", [0.01, 0.001, 0.0001])
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
    batch_size = trial.suggest_categorical("batch_size", [128, 256, 512])

    # Get the dataset.
    train_dataset = MIMICDataset(train_data)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, collate_fn=train_dataset.collate_fn_padd)

    valid_dataset = MIMICDataset(valid_data)
    valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batch_size, collate_fn=valid_dataset.collate_fn_padd)

    # Define loss
    criterion = nn.BCEWithLogitsLoss()

    # Training of the model.
    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0
        batch_prediction_list = []
        batch_label_list = []
        for batch_idx, (labels, features) in enumerate(train_dataloader):
            features = features.float()
            labels = labels.float()

            optimizer.zero_grad()
            features2 = torch.permute(features, (1, 0))
            labels2 = torch.permute(labels, (1, 0))
            output = model(features2)
            loss = criterion(output, labels2)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            sig = torch.nn.Sigmoid()
            output = sig(output)      
            np_predictions = output.cpu().detach().numpy()
            np_labels = labels2.cpu().detach().numpy()
            np_predictions = np_predictions.squeeze()
            np_labels = np_labels.squeeze()
            np_predictions = np_predictions.flatten()
            np_labels = np_labels.flatten()
            for x in np_predictions:
                batch_prediction_list.append(x)
            for x in np_labels:
                batch_label_list.append(x)
        final_predictions = np.array(batch_prediction_list)
        final_labels = np.array(batch_label_list)
        train_auroc = roc_auc_score(final_labels, final_predictions)
        train_loss = epoch_loss / len(train_dataloader)
        print('epoch:', epoch)
        print('train_auroc:', train_auroc)
        print('train_loss:', train_loss)

        # Validation of the model.
        model.eval()
        epoch_loss = 0
        batch_prediction_list = []
        batch_label_list = []
        with torch.no_grad():
            for batch_idx, (labels, features) in enumerate(valid_dataloader):
                features = features.float()
                labels = labels.float()

                features2 = torch.permute(features, (1, 0))
                labels2 = torch.permute(labels, (1, 0))
                output = model(features2)
                loss = criterion(output, labels2)
                epoch_loss += loss.item()
                sig = torch.nn.Sigmoid()
                output = sig(output)  
                np_predictions = output.cpu().detach().numpy()
                np_labels = labels2.cpu().detach().numpy()
                np_predictions = np_predictions.squeeze()
                np_labels = np_labels.squeeze()
                np_predictions = np_predictions.flatten()
                np_labels = np_labels.flatten()
                # Create list
                for x in np_predictions:
                    batch_prediction_list.append(x)
                for x in np_labels:
                    batch_label_list.append(x)
        final_predictions = np.array(batch_prediction_list)
        final_labels = np.array(batch_label_list)
        valid_auroc = roc_auc_score(final_labels, final_predictions)
        valid_loss = epoch_loss / len(train_dataloader)
        print('valid_auroc:', valid_auroc)
        print('valid_loss:', valid_loss)
        #print(type(valid_auroc))

        trial.report(valid_auroc, epoch)

        if valid_auroc == 0.5:
            raise optuna.exceptions.TrialPruned()

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return valid_auroc

In [10]:
class MIMICDataset(Dataset):

    def __init__(self, path):
        self.mimic_df = path

    def __len__(self):
        return len(self.mimic_df)

    def __getitem__(self, idx):
        selected_rows = self.mimic_df.iloc[[idx]] # Double brackets to return df
        selected_rows.drop(['stay_id'], axis=1, inplace=True)
        labels = selected_rows[['po_flag']].to_numpy()
        features = selected_rows.drop(['po_flag'], axis=1).to_numpy()
        sample = {"labels": torch.from_numpy(labels).squeeze(0), "features": torch.from_numpy(features).squeeze(0)}
        return sample

    def collate_fn_padd(self, batch):
        # Extract contexts and actions from data list
        labels = [sample["labels"] for sample in batch]
        features = [sample["features"] for sample in batch]

        pad_key = 123456  # Key to identify padded values
        padded_features = torch.nn.utils.rnn.pad_sequence(features, batch_first=False, padding_value=pad_key) # Keep to convert to tensor
        padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=False, padding_value=pad_key)
        
        return padded_labels, padded_features


In [11]:
EPOCHS = 10
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10)

pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print("Study statistics:")
print("Number of finished trials:", len(study.trials))
print("Number of pruned trials:", len(pruned_trials))
print("Number of complete trials:", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("Value:", trial.value)

print("Params:")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

[I 2024-01-22 13:31:24,334] A new study created in memory with name: no-name-dd812fe9-fb4a-4340-a20e-3f3cc2a73fd6
[I 2024-01-22 13:38:12,826] Trial 0 finished with value: 0.7837282667910515 and parameters: {'n_layers': 3, 'dropout': 0, 'n_units_l0': 128, 'n_units_l1': 16, 'n_units_l2': 512, 'optimizer': 'Adam', 'lr': 0.01, 'batch_size': 512}. Best is trial 0 with value: 0.7837282667910515.
[I 2024-01-22 13:44:48,651] Trial 1 finished with value: 0.7583156785318802 and parameters: {'n_layers': 4, 'dropout': 0.2, 'n_units_l0': 4, 'n_units_l1': 256, 'n_units_l2': 8, 'n_units_l3': 2, 'optimizer': 'Adam', 'lr': 0.001, 'batch_size': 256}. Best is trial 0 with value: 0.7837282667910515.
[I 2024-01-22 13:51:28,495] Trial 2 finished with value: 0.7509335384257496 and parameters: {'n_layers': 5, 'dropout': 0.3, 'n_units_l0': 4, 'n_units_l1': 32, 'n_units_l2': 8, 'n_units_l3': 2, 'n_units_l4': 512, 'optimizer': 'Adam', 'lr': 0.0001, 'batch_size': 128}. Best is trial 0 with value: 0.78372826679105

epoch: 0
train_auroc: 0.7943011606602988
train_loss: 0.548647693965746
valid_auroc: 0.8150157795351476
valid_loss: 0.10335239444089972
epoch: 1
train_auroc: 0.8273842653925014
train_loss: 0.502729295388512
valid_auroc: 0.8227163588500108
valid_loss: 0.0992954820394516
epoch: 2
train_auroc: 0.838936575120135
train_loss: 0.48660575112570886
valid_auroc: 0.814843563623157
valid_loss: 0.10260871182317319
epoch: 3
train_auroc: 0.8520216038297698
train_loss: 0.4681334974973098
valid_auroc: 0.8135904584551366
valid_loss: 0.10402940926344498
epoch: 4
train_auroc: 0.8597516577938817
train_loss: 0.45513190717800805
valid_auroc: 0.798871279584403
valid_loss: 0.10719509941080342
epoch: 5
train_auroc: 0.8730952341551457
train_loss: 0.4342272793469222
valid_auroc: 0.8009891436323434
valid_loss: 0.11348708233107692
epoch: 6
train_auroc: 0.8853940980446534
train_loss: 0.41288857680300006
valid_auroc: 0.8044544696022893
valid_loss: 0.1128316651219907
epoch: 7
train_auroc: 0.8960811307459049
train_loss:

In [12]:
print("Study statistics:")
print("Number of finished trials:", len(study.trials))
print("Number of pruned trials:", len(pruned_trials))
print("Number of complete trials:", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("AUROC Value:", trial.value)

print("Params:")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

Study statistics:
Number of finished trials: 10
Number of pruned trials: 3
Number of complete trials: 7
Best trial:
AUROC Value: 0.8230529572830951
Params:
    n_layers: 1
    dropout: 0
    n_units_l0: 256
    optimizer: RMSprop
    lr: 0.0001
    batch_size: 256
