## Demo
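Demo notebook: evaluates the initial respiratory deterioration model (weights from `2016.h5`) on the 2017 test set, then incrementally updates it on the 2017 training data with `self_aware_SGD`.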

In [1]:
# Default imports
import os
import random

# TensorFlow's C++ runtime reads TF_CPP_MIN_LOG_LEVEL once and caches it, so it
# has to be set *before* tensorflow is imported to silence import-time logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# External lib imports
import torch
import pandas as pd
import tensorflow as tf
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score

# Local imports
from Model import get_predictions, get_names_of_predefined_feature_sets, PatientDeteriorationDataset, Respiratory_Deteroration
from proposed_algo import self_aware_SGD

# Seed every RNG this notebook touches: Python's `random`, torch (its global
# generator drives DataLoader shuffling) and TensorFlow (the model below is
# built/trained through the Keras API, so its global seed matters too).
SEED = 10
random.seed(SEED)
torch.manual_seed(SEED)
tf.random.set_seed(SEED)

<torch._C.Generator at 0x7ff8e5fdc0d0>

Load the initial respiratory deterioration model (weights from `./2016.h5`).

In [2]:
# Build the project's Respiratory_Deteroration model (name spelled as exported
# by the Model module), compile it, then load the saved initial weights.
# The build/compile/get_weights/load_weights calls suggest a Keras model.
model = Respiratory_Deteroration()
model.build((None,77))  # unspecified batch dimension, 77 input features
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['AUC'])
# NOTE(review): W is captured BEFORE load_weights, so it holds the freshly built
# weights (whatever the layer initialisers produced), not the ones from
# 2016.h5. W is later passed to self_aware_SGD — confirm this is intended.
W = model.get_weights()
model.load_weights('./2016.h5') # load initial model

Evaluate the initial model on the 2017 test set (ROC AUC).

In [3]:
# Model input columns (the predefined feature set — presumably the 77 features
# matching model.build((None,77))) followed by the target column 'label'.
# Build a new list instead of appending to the returned one, so re-running this
# cell cannot keep mutating a list the helper might share between calls.
used_features = [*get_names_of_predefined_feature_sets(), 'label']

# Evaluate the loaded model on the 2017 test split. Order is irrelevant for
# scoring, so the loader does not shuffle.
test_data = pd.read_csv('./data/2017_test.csv')
test_dataset = PatientDeteriorationDataset(test_data, used_features, 'label')
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False)
# roc_auc_score(y_true, y_score): L is used as the labels, P as the scores.
P, L = get_predictions(model, test_loader)
auc = roc_auc_score(L, P)
print(auc)

0.8814187935419315


Load the incremental 2017 training and validation data.

In [4]:
# Batch size shared by the incremental training and validation loaders.
MLP_BATCH_SIZE = 2048

# Wrap each 2017 CSV split directly in a dataset, so the raw DataFrame is never
# bound to the same name that later holds the dataset.
train_data = PatientDeteriorationDataset(
    pd.read_csv('./data/2017_train.csv'), used_features, 'label'
)
val_data = PatientDeteriorationDataset(
    pd.read_csv('./data/2017_val.csv'), used_features, 'label'
)

# Only the training stream is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_data, batch_size=MLP_BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=MLP_BATCH_SIZE, shuffle=False)

In [5]:
def loss_fn(pred_y, y):
    """Return the batch-averaged binary cross-entropy.

    Argument order is (predictions, targets) — the reverse of Keras's
    ``(y_true, y_pred)`` convention — and is swapped back for the Keras call.

    Args:
        pred_y: predicted probabilities (Keras default ``from_logits=False``).
        y: binary targets.

    Returns:
        A scalar tensor: the mean of the per-sample cross-entropies.
    """
    per_sample_loss = tf.keras.losses.binary_crossentropy(y, pred_y)
    return tf.reduce_mean(per_sample_loss)

In [None]:
# Incrementally update the model on the 2017 stream with self_aware_SGD, using
# Nesterov-momentum SGD; the run name is passed straight through to the routine.
name = '2017_inc'
opt = tf.keras.optimizers.SGD(
    learning_rate=0.001,
    momentum=0.9,
    nesterov=True,
)
model = self_aware_SGD(
    model, W,
    train_loader, val_loader,
    loss_fn, opt, name,
    epochs=50,
)

316it [01:33,  3.56it/s]