## Mortality Prediction

### Import Packages

In [58]:
import sys
import warnings
sys.path.append("../")
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import pickle
from data_class import embeddings
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier

### Data Loading

In [59]:
# Label table: one row per subject, with an `ID` column and a `flag` column
# (the rows printed below show 0/1 values).
flag_csv_path = "./embeddings/flag.csv"
expire_flag_df = pd.read_csv(flag_csv_path)
print(expire_flag_df)

       ID  flag
0    S000     1
1    S001     1
2    S002     1
3    S003     1
4    S004     1
..    ...   ...
195  S195     0
196  S196     0
197  S197     0
198  S198     0
199  S199     0

[200 rows x 2 columns]


In [60]:
# Load every subject's pickled embedding object, in the same order as the rows
# of `expire_flag_df`, so that position i in the list lines up with label row i.
#
# SECURITY NOTE: pickle.load can execute arbitrary code while deserializing.
# Only run this cell against embedding files that come from a trusted source.
# NOTE(review): unpickling presumably requires the project's `data_class`
# module to be importable — confirm.
subject_data_list = []
for subject_id in expire_flag_df.ID.values:
    with open(f"./embeddings/{subject_id}.pkl", "rb") as f:
        # Append directly: avoids leaving a stray global named `data`, which is
        # reused later in the notebook for the feature matrix.
        subject_data_list.append(pickle.load(f))

In [61]:
# Preview the first subject: for each modality, print the embedding length and
# then the embedding itself laid out as a single-row frame.
subject_data = subject_data_list[0]

# (label used in the printed header, attribute holding that embedding)
modality_previews = [
    ("tabular", "tabular_emb"),
    ("timeseries(Vitalsigns)", "timeseries_emb"),
    ("signal(ECG)", "signal_emb"),
    ("image(Chest X-Ray)", "image_emb"),
    ("note(Chest X-Ray)", "note_emb"),
]

for modality_label, emb_attr in modality_previews:
    modality_emb = getattr(subject_data, emb_attr)
    print("-"*50)
    print(f"# {modality_label} embeddings:", len(modality_emb))
    print(modality_emb.to_frame().T)

--------------------------------------------------
# tabular embeddings: 36
    age  gender  ...  red_blood_cells  white_blood_cells
0  78.0     1.0  ...             3.28                3.9

[1 rows x 36 columns]
--------------------------------------------------
# timeseries(Vitalsigns) embeddings: 55
   heart_rate_max  heart_rate_min  ...  blood_oxygen_npeaks  blood_oxygen_trend
0           114.0            72.0  ...                  6.0           -0.060769

[1 rows x 55 columns]
--------------------------------------------------
# signal(ECG) embeddings: 326
      1dAVb          RBBB  ...  ecg_dense_318  ecg_dense_319
0  0.000176  8.829986e-07  ...            0.0            0.0

[1 rows x 326 columns]
--------------------------------------------------
# image(Chest X-Ray) embeddings: 1042
   Atelectasis  Consolidation  ...  cxr_dense_1022  cxr_dense_1023
0     0.243355       0.174304  ...        0.002957        0.250387

[1 rows x 1042 columns]
--------------------------------------

### Read data from different modalities

In [62]:
def get_data(subject_data_list, tabular = True, timeseries = True, signal = True, note = True, image = True):
    """Build a feature matrix with one row per subject from the selected modalities.

    For each subject, the enabled embeddings are concatenated end to end in a
    fixed order (tabular, timeseries, signal, note, image) and become one row
    of the returned frame.

    Parameters
    ----------
    subject_data_list : iterable
        Objects exposing the attributes ``tabular_emb``, ``timeseries_emb``,
        ``signal_emb``, ``note_emb`` and ``image_emb``. Each attribute is used
        as a ``pandas.Series`` (it is concatenated and converted with
        ``to_frame``). Only the attributes of enabled modalities are read.
    tabular, timeseries, signal, note, image : bool, default True
        Whether to include the corresponding embedding in every row.

    Returns
    -------
    pandas.DataFrame
        One row per subject, in input order, with a fresh 0..n-1 index.
        An empty frame is returned when ``subject_data_list`` is empty.

    Raises
    ------
    ValueError
        If there is at least one subject but every modality is disabled.
    """
    modality_switches = [
        ("tabular_emb", tabular),
        ("timeseries_emb", timeseries),
        ("signal_emb", signal),
        ("note_emb", note),
        ("image_emb", image),
    ]
    selected_attrs = [attr for attr, enabled in modality_switches if enabled]

    subjects = list(subject_data_list)
    if not subjects:
        return pd.DataFrame()
    if not selected_attrs:
        # pd.concat would raise a ValueError here anyway; make the cause explicit.
        raise ValueError("get_data requires at least one modality to be enabled")

    # Collect all single-row frames first and concatenate once. Growing a
    # DataFrame inside the loop re-copies every previous row on each iteration.
    subject_rows = [
        pd.concat([getattr(subject, attr) for attr in selected_attrs], axis = 0).to_frame().T
        for subject in subjects
    ]
    return pd.concat(subject_rows, ignore_index=True)

# Feature matrix with every modality enabled. The flags are spelled out for
# readability; they match get_data's defaults.
data = get_data(
    subject_data_list,
    tabular=True,
    timeseries=True,
    signal=True,
    note=True,
    image=True,
)
print(data)

      age  gender  Septicemia  ...  cxr_dense_1021  cxr_dense_1022  cxr_dense_1023
0    78.0     1.0         0.0  ...        0.003784        0.002957        0.250387
1    66.0     0.0         0.0  ...        0.007224        0.002480        0.006324
2    68.0     1.0         0.0  ...        0.003386        0.003358        0.405208
3    64.0     0.0         0.0  ...        0.014333        0.012056        0.073979
4    61.0     1.0         0.0  ...        0.000000        0.000000        0.058607
..    ...     ...         ...  ...             ...             ...             ...
195  65.0     0.0         0.0  ...        0.000331        0.009345        0.000195
196  43.0     1.0         0.0  ...        0.000000        0.019390        0.000000
197  84.0     1.0         0.0  ...        0.002252        0.005291        0.368608
198  84.0     1.0         0.0  ...        0.000000        0.017738        0.029614
199  83.0     0.0         0.0  ...        0.045845        0.044703        0.000000

[20

### Model Training

In [63]:
def run_model(data, labels=None):
    """Fit a grid-searched random forest on ``data`` and report train/test AUC.

    The rows are split 80/20 by position (``random_state=0``). A
    ``RandomForestClassifier`` is tuned over ``max_depth`` and
    ``n_estimators`` with 5-fold cross-validation scored by ROC AUC, and the
    tuned search object is then scored on both splits.

    Parameters
    ----------
    data : pandas.DataFrame
        Feature matrix, one row per subject.
    labels : array-like, optional
        Labels aligned by position with the rows of ``data``. When omitted,
        the notebook-level ``expire_flag_df.flag`` column is used, which
        matches the previous behaviour of this function.

    Returns
    -------
    tuple of float
        ``(auc_train, auc_test)``; both values are also printed.
    """
    if labels is None:
        # Backward-compatible default: the label table loaded earlier in the notebook.
        labels = expire_flag_df.flag.values
    labels = np.asarray(labels)

    # Split positional row ids so features and labels are indexed identically.
    train_id, test_id = train_test_split(range(len(data)), test_size=0.2, random_state=0)
    x_train = data.iloc[train_id].values
    y_train = labels[train_id]

    x_test = data.iloc[test_id].values
    y_test = labels[test_id]

    gs_metric = 'roc_auc'
    param_grid = {'max_depth': [1,2,3],
                  'n_estimators': [10, 20, 30, 40]}

    clf = RandomForestClassifier(random_state=0)
    gs = GridSearchCV(estimator = clf, param_grid=param_grid, scoring=gs_metric, cv=5)
    gs.fit(x_train, y_train)

    # Only class probabilities are needed for ROC AUC; the hard-label
    # predictions that used to be computed here were never used.
    y_pred_prob_train = gs.predict_proba(x_train)
    y_pred_prob_test = gs.predict_proba(x_test)

    # Column 1 is the probability of the second class in the fitted class
    # order (flag == 1 when the labels are 0/1).
    auc_train =  metrics.roc_auc_score(y_train, y_pred_prob_train[:,1])
    print(f'AUC for Training Set is: {auc_train}')

    auc_test =  metrics.roc_auc_score(y_test, y_pred_prob_test[:,1])
    print(f'AUC for Testing Set is: {auc_test}')
    return auc_train, auc_test

In [64]:
# Train and evaluate on `data`, the feature matrix built with get_data's
# default flags (all modalities enabled).
print("Multiple Modalities")
auc_train, auc_test = run_model(data)

Multiple Modalities
AUC for Training Set is: 0.8227016885553471
AUC for Testing Set is: 0.75


In [65]:
# Single-modality baseline: tabular embeddings only.
print("Tabular Data")
tabular_data = get_data(
    subject_data_list,
    tabular=True,
    timeseries=False,
    signal=False,
    note=False,
    image=False,
)
auc_train, auc_test = run_model(tabular_data)

Tabular Data
AUC for Training Set is: 0.940744215134459
AUC for Testing Set is: 0.6085858585858587


In [66]:
# Single-modality baseline: time-series embeddings only.
print("TimeSeries Data")
timeseries_data = get_data(
    subject_data_list,
    tabular=False,
    timeseries=True,
    signal=False,
    note=False,
    image=False,
)
auc_train, auc_test = run_model(timeseries_data)

TimeSeries Data
AUC for Training Set is: 0.9122889305816135
AUC for Testing Set is: 0.7095959595959596


In [67]:
# Single-modality baseline: signal embeddings only.
print("Signal Data")
signal_data = get_data(
    subject_data_list,
    tabular=False,
    timeseries=False,
    signal=True,
    note=False,
    image=False,
)
auc_train, auc_test = run_model(signal_data)

Signal Data
AUC for Training Set is: 0.7547686053783615
AUC for Testing Set is: 0.5959595959595959


In [68]:
# Single-modality baseline: note embeddings only.
print("Note Data")
note_data = get_data(
    subject_data_list,
    tabular=False,
    timeseries=False,
    signal=False,
    note=True,
    image=False,
)
auc_train, auc_test = run_model(note_data)

Note Data
AUC for Training Set is: 0.8959505941213258
AUC for Testing Set is: 0.601010101010101


In [69]:
# Single-modality baseline: image embeddings only.
print("Image Data")
image_data = get_data(
    subject_data_list,
    tabular=False,
    timeseries=False,
    signal=False,
    note=False,
    image=True,
)
auc_train, auc_test = run_model(image_data)

Image Data
AUC for Training Set is: 0.9409787367104441
AUC for Testing Set is: 0.6363636363636364
