In [27]:
import os
import numpy as np
import pandas as pd
from joblib import dump, load
import sys
import seaborn as sns
import json
sys.path.append(os.path.join(os.path.abspath('../'), 'predictions_collapsed'))
sys.path.append(os.path.join(os.path.abspath('../'), 'src'))
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import torch
import skorch
import glob
from config_loader import (
    D_CONFIG,
    DATASET_STD_PATH, DATASET_SPLIT_PATH,
    DATASET_PERTSTEP_SPLIT_PATH, PROJECT_REPO_DIR, PROJECT_CONDA_ENV_YAML,
    RESULTS_PATH, RESULTS_PERTSTEP_PATH, CLF_TRAIN_TEST_SPLIT_PATH)

sys.path.append(os.path.join(PROJECT_REPO_DIR, 'src'))
sys.path.append(os.path.join(PROJECT_REPO_DIR, 'src', 'rnn'))
from feature_transformation import *
from filter_admissions_by_tslice import get_preprocessed_data
from merge_features_all_tslices import merge_data_dicts, get_all_features_data
import matplotlib.pyplot as plt
from sklearn.metrics import (accuracy_score, balanced_accuracy_score, f1_score,
                             average_precision_score, confusion_matrix, log_loss,
                             roc_auc_score, roc_curve, precision_recall_curve)
from utils import load_data_dict_json
from dataset_loader import TidySequentialDataCSVLoader
from RNNBinaryClassifier import RNNBinaryClassifier
from sklearn.model_selection import GridSearchCV



## Get the labs, vitals and static features from all hospital admissions
- Merge the labs, vitals and demographics into a giant feature dataframe

## Split them into train and test
- Use split_dataset.py

## Read the features into the RNN Binary Classifier 
- Use `dataset_loader.py` in `time-series-prediction/src/rnn/` to load the data, whose sequences have different time-lengths, for RNN training and testing

## Train classifier with RNN


In [16]:
# Read the preprocessed tables found under DATASET_STD_PATH; every dataframe
# comes back paired with its data dictionary.
(
    labs_df, labs_data_dict,
    vitals_df, vitals_data_dict,
    demographics_df, demographics_data_dict,
    outcomes_df, outcomes_data_dict,
) = get_preprocessed_data(DATASET_STD_PATH)

# Combine labs, vitals and demographics into a single feature table
# (and a single data dictionary describing it).
features_df, features_data_dict = get_all_features_data(
    labs_df, labs_data_dict,
    vitals_df, vitals_data_dict,
    demographics_df, demographics_data_dict,
)

# Column roles are looked up from the merged data dictionary.
id_cols = parse_id_cols(features_data_dict)
time_col = parse_time_col(features_data_dict)
feature_cols = parse_feature_cols(features_data_dict)
## impute values
# Sort by ids and timestamp *before* imputing. A forward fill copies the last
# seen value downwards, so it only makes sense once the rows of each id group
# are in time order; filling first and sorting afterwards could propagate a
# later measurement onto an earlier row.
features_df = features_df.sort_values(by=id_cols + [time_col])
outcomes_df = outcomes_df.sort_values(by=id_cols)

# Step 1: forward-fill inside each id group. Every non-id column is filled
# (the id columns are the grouping keys themselves). Using the groupby
# `ffill` keeps the original index and columns, unlike
# `groupby(...).apply(...)`, which may add the group keys to the index.
non_id_cols = [col for col in features_df.columns if col not in id_cols]
features_df[non_id_cols] = features_df.groupby(id_cols)[non_id_cols].ffill()

# Step 2: whatever is still missing in a feature column (e.g. values before
# the first observation of a group) is replaced by that column's mean.
# Assigning the result back avoids chained in-place `fillna`, which is not
# guaranteed to modify `features_df`.
# NOTE(review): the mean is computed over every row of features_df; if a
# held-out split is later taken from this table, these statistics include it.
feature_means = features_df[feature_cols].mean()
features_df[feature_cols] = features_df[feature_cols].fillna(feature_means)

# Destination paths, all inside DATASET_SPLIT_PATH.
x_csv_filename, y_csv_filename = (
    os.path.join(DATASET_SPLIT_PATH, csv_name)
    for csv_name in ('x_train.csv', 'y_train.csv')
)
x_data_dict_filename, y_data_dict_filename = (
    os.path.join(DATASET_SPLIT_PATH, json_name)
    for json_name in ('x_dict.json', 'y_dict.json')
)

# Write the feature table and the outcome table, without the index column.
for frame, csv_path in ((features_df, x_csv_filename),
                        (outcomes_df, y_csv_filename)):
    frame.to_csv(csv_path, index=False)

# Writing the data dictionaries is currently disabled:
# with open(x_data_dict_filename, 'w') as f:
#     json.dump(features_data_dict, f, indent=4)
# with open(y_data_dict_filename, 'w') as f:
#     json.dump(outcomes_data_dict, f, indent=4)

In [21]:
# Arguments for the sequence loader: the CSVs written above, the feature and
# id columns taken from the data dictionary, and the outcome column to predict.
loader_kwargs = dict(
    x_csv_path=x_csv_filename,
    y_csv_path=y_csv_filename,
    x_col_names=feature_cols,
    idx_col_names=id_cols,
    y_col_name='clinical_deterioration_outcome',
    y_label_type='per_sequence',  # presumably one label per id group; confirm in dataset_loader.py
)
train_vitals = TidySequentialDataCSVLoader(**loader_kwargs)

# Fetch batch 0 as the training arrays used by the classifier cell.
X_train, y_train = train_vitals.get_batch_data(batch_id=0)

In [30]:
# Seed the numpy and torch random number generators so that repeated runs of
# this cell start from the same random state (training is stochastic).
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Candidate learning rates for the grid search, which is currently disabled
# (see the commented-out GridSearchCV line below).
learning_rate = [0.0001, 0.001, 0.01, 0.1, 1]
hyperparameters = dict(lr=learning_rate)

# grid search
# A 2-layer LSTM with 32 hidden units, trained on CPU with Adam; the input
# width is taken from the last axis of X_train. The two EpochScoring callbacks
# add ROC-AUC columns for the training and validation data to the epoch log.
rnn = RNNBinaryClassifier(
    max_epochs=10,
    batch_size=1024,
    device='cpu',
    lr=0.001,
    callbacks=[
            skorch.callbacks.EpochScoring('roc_auc', lower_is_better=False, on_train=True, name='aucroc_score_train'),
            skorch.callbacks.EpochScoring('roc_auc', lower_is_better=False, on_train=False, name='aucroc_score_valid'),
    ],
    module__rnn_type='LSTM',
    module__n_inputs=X_train.shape[-1],
    module__n_hiddens=32,
    module__n_layers=2,
    optimizer=torch.optim.Adam)

# classifier = GridSearchCV(rnn, hyperparameters, n_jobs=-1, cv=5, verbose=10)
# Fit a single model with the fixed lr above instead of searching over lr.
clf = rnn.fit(X_train, y_train)

  epoch    aucroc_score_train    aucroc_score_valid    train_loss    valid_loss       dur
-------  --------------------  --------------------  ------------  ------------  --------
      1                [36m0.4935[0m                [32m0.5772[0m        [35m0.5168[0m        [31m0.3613[0m  296.4492
      2                [36m0.5046[0m                [32m0.7123[0m        [35m0.3408[0m        [31m0.3292[0m  293.0061
      3                [36m0.5300[0m                0.7024        [35m0.3315[0m        [31m0.3271[0m  292.9192
      4                [36m0.5479[0m                0.7016        [35m0.3303[0m        [31m0.3264[0m  291.6375
      5                [36m0.5562[0m                0.7001        [35m0.3298[0m        [31m0.3260[0m  293.5312
      6                [36m0.5619[0m                0.7002        [35m0.3294[0m        [31m0.3257[0m  295.8792
      7                [36m0.5667[0m                0.6994        [35m0.3292[0m        [31m0.32

In [15]:
# Show the merged feature table; the notebook renders a truncated view
# (first and last rows, middle columns elided).
# NOTE(review): the rendered output includes patient/admission ids and birth
# dates; confirm it is safe to share or display a redacted sample instead.
features_df

Unnamed: 0,patient_id,hospital_admission_id,facility_code,hours_since_admission,CO2_venous_blood,alanine_aminotransferase,albumin_in_serum,alkaline_phosphatase,aspartate_aminotransferase,basophils,...,sodium_in_serum,systolic_blood_pressure,triglyceride_in_serum,urate_in_serum,weight,birth_date,admission_timestamp,age_at_admission,gender_is_male,gender_is_unknown
0,124,16817248,1,1.607,44.063773,51.378171,3.204043,118.535469,61.058401,0.032526,...,138.086294,126.0,128.89276,5.580641,75.343354,1946-11-21,2019-08-17 22:58:00,72.786301,1.0,0.0
1,124,16817248,1,8.379,44.063773,51.378171,3.204043,118.535469,61.058401,0.032526,...,138.086294,126.0,128.89276,5.580641,75.343354,1946-11-21,2019-08-17 22:58:00,72.786301,1.0,0.0
2,124,16817248,1,15.459,44.063773,51.378171,3.204043,118.535469,61.058401,0.032526,...,138.086294,131.0,128.89276,5.580641,75.343354,1946-11-21,2019-08-17 22:58:00,72.786301,1.0,0.0
3,124,16817248,1,18.656,44.063773,51.378171,3.204043,118.535469,61.058401,0.032526,...,138.086294,146.0,128.89276,5.580641,75.343354,1946-11-21,2019-08-17 22:58:00,72.786301,1.0,0.0
4,124,16817248,1,19.072,44.063773,51.378171,3.204043,118.535469,61.058401,0.032526,...,138.086294,146.0,128.89276,5.580641,75.343354,1946-11-21,2019-08-17 22:58:00,72.786301,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1540940,1999981,16226113,1,251.133,55.000000,32.000000,3.800000,67.000000,30.000000,0.090000,...,138.700000,122.0,136.00000,5.400000,80.000000,1987-02-15,2023-01-05 21:36:00,35.912329,0.0,0.0
1540941,1999981,16226113,1,251.400,55.000000,32.000000,3.700000,67.000000,30.000000,0.090000,...,138.700000,122.0,136.00000,5.400000,80.000000,1987-02-15,2023-01-05 21:36:00,35.912329,0.0,0.0
1540942,1999981,16226113,1,251.500,55.000000,32.000000,3.700000,67.000000,30.000000,0.090000,...,138.700000,122.0,136.00000,5.400000,80.000000,1987-02-15,2023-01-05 21:36:00,35.912329,0.0,0.0
1540943,1999981,16226113,1,251.967,55.000000,32.000000,3.700000,67.000000,30.000000,0.090000,...,139.300000,122.0,136.00000,5.400000,80.000000,1987-02-15,2023-01-05 21:36:00,35.912329,0.0,0.0


In [26]:
# NOTE(review): `ii` is not assigned in any cell visible here, so this cell
# presumably relies on leftover kernel state and would fail on a fresh
# Restart & Run All. Confirm, and remove it if it is a debugging leftover.
ii

41