In [7]:
### Tools and Packages
##Basics
import pandas as pd
import numpy as np
import sys, random
import math
try:
    # Python 2 ships a separate C implementation under this name.
    import cPickle as pickle
except ImportError:
    # Python 3 has no cPickle; fall back to the standard module.
    import pickle
import string
import re
import os
import time
from tqdm import tqdm

## ML and Stats 
from sklearn import datasets, linear_model
from sklearn.linear_model import LinearRegression
import sklearn.metrics as m
import sklearn.linear_model  as lm
import lifelines
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.tree import export_graphviz
import statsmodels.formula.api as sm
import patsy
from scipy import stats
from termcolor import colored


## Visualization
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
%matplotlib inline
import plotly as py
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
init_notebook_mode(connected=True)
import plotly.tools as tls
import plotly.graph_objs as go
from plotly.graph_objs import *
from IPython.display import HTML

## DL Framework
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim

###GPU enabling and device allocation
## True when PyTorch can see a CUDA device; the model-loading cell moves the models to the GPU when this is set.
use_cuda = torch.cuda.is_available()
#torch.cuda.set_device(1) ## uncomment if you need to pick a specific GPU

#use_cuda=False ## uncomment if you explicitly need to run without a GPU

from importlib import reload

### import pytorch ehr files
#import sys
#sys.path.insert(0, '../ehr_pytorch')

import pytorch_ehr_3.models as model 
from pytorch_ehr_3.EHRDataloader import EHRdataloader
from pytorch_ehr_3.EHRDataloader import EHRdataFromLoadedPickles as EHRDataset
import pytorch_ehr_3.utils as ut 
from pytorch_ehr_3.EHREmb import EHREmbeddings


### Data Preparation

In [None]:
### Read the header of data_preprocess_v5.py for more information on its arguments.
## Arguments used here (judging by the file names): the tab-delimited data file, the tab-delimited
## label file, the medical-code-to-token mapping (.types), the output path prefix, and `nosplit`.
## NOTE(review): the .types mapping is presumably the pretraining vocabulary, so codes receive the ids
## the models were trained on, and `nosplit` presumably writes all patients to a single file; the next
## cell loads `<output prefix>.combined.all`. Confirm against the script header.
!python data_preprocess_v5.py data_tab_delimited.txt label_tab_delimited.txt medicalcode_mapping_to_token_file.types output_folder/file_prefix nosplit


### Data Loading

In [8]:
### load Data
## Load the patient records written by the preprocessing command above (`<output prefix>.combined.all`).
## NOTE: pickle.load can execute arbitrary code; only load files produced by your own preprocessing run.
with open('output_folder/file_prefix.combined.all', 'rb') as pickle_file:
    test_sl = pickle.load(pickle_file, encoding='bytes')

6724


In [None]:
#### Patients with medical codes that were not used during pretraining can cause errors
## (presumably because their ids fall outside the pretrained vocabulary), so you may need to
## exclude them from your test set.

def filter_unseen_codes(patients, max_code_exclusive=123642):
    """Keep only the patients whose largest medical-code id is below ``max_code_exclusive``.

    Each record is indexed the same way the original filter did: ``pt[-1]`` is the sequence of
    visits, and ``visit[1]`` is the code ids of one visit.

    Args:
        patients: iterable of patient records (e.g. ``test_sl``).
        max_code_exclusive: code ids must be strictly below this value for the patient to be kept.
            NOTE(review): the models are built with input size 123641; if that embedding has
            exactly 123641 rows, valid ids are < 123641. Confirm this bound.

    Returns:
        A new list of the retained patient records, in their original order.

    Raises:
        ValueError: if a patient has no visits or a visit has no codes (``max`` of an empty sequence).
    """
    kept = []
    for pt in patients:
        # Largest code id across all of this patient's visits.
        highest_code = max(max(visit[1]) for visit in pt[-1])
        if highest_code < max_code_exclusive:
            kept.append(pt)
    return kept


test_sl_n = filter_unseen_codes(test_sl)

### make sure that you replace every `test_sl` below with the filtered `test_sl_n`

In [23]:
### Load our models
## Depending on the PyTorch version, loading a fully pickled model directly with torch.load may fail.
## In that case each model falls back to rebuilding the architecture and populating its
## parameters from the saved state dictionary.

PRETRAINED_DIR = 'CRWD_Pretrained_Models'
VOCAB_SIZE = 123641  # input size the pretrained models were built with


def _load_pretrained(name, surv):
    """Load the pretrained CovRNN model ``name`` onto the CPU and return it.

    First tries the fully pickled ``<PRETRAINED_DIR>/<name>.pth``. If that fails, builds an
    ``EHR_RNN`` with the pretraining hyper-parameters and loads
    ``<PRETRAINED_DIR>/state_dicts/<name>.st`` into it.

    Args:
        name: file stem shared by the ``.pth`` and ``.st`` files.
        surv: value of ``EHR_RNN``'s ``surv`` flag used in the fallback path.
    """
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines; models are moved
    # to the GPU below when use_cuda is set.
    try:
        return torch.load(os.path.join(PRETRAINED_DIR, name + '.pth'), map_location='cpu')
    except Exception:
        net = model.EHR_RNN([VOCAB_SIZE], embed_dim=64, hidden_size=64, n_layers=1, dropout_r=0.,
                            cell_type='GRU', bii=False, time=True, surv=surv)
        state_dict = torch.load(os.path.join(PRETRAINED_DIR, 'state_dicts', name + '.st'),
                                map_location='cpu')
        net.load_state_dict(state_dict)
        return net


# NOTE(review): the original fallback built every model with surv=True. The classification models'
# outputs are reported as probabilities (mort_prob / vent_prob / plos_prob), while the *_Surv models'
# outputs are log hazards. The classification models are therefore rebuilt with surv=False; confirm
# against EHR_RNN's handling of `surv`.
mort_model = _load_pretrained('CovRNN_iMort_v552', surv=False)
vent_model = _load_pretrained('CovRNN_mVent_v552', surv=False)
plos_model = _load_pretrained('CovRNN_pLOS_v552', surv=False)
mort_surv_model = _load_pretrained('CovRNN_iMort_Surv_v552', surv=True)
vent_surv_model = _load_pretrained('CovRNN_mVent_Surv_v552', surv=True)

# Move to the GPU if available, and switch every model to inference mode.
for loaded_model in (mort_model, vent_model, plos_model, vent_surv_model, mort_surv_model):
    if use_cuda:
        loaded_model.cuda()
    loaded_model.eval()

def pt_predictions(test_set):
    """Score every patient in ``test_set`` with the five pretrained CovRNN models.

    Each patient is wrapped in its own single-patient ``EHRDataset`` / ``EHRdataloader``, so the
    models receive one sequence at a time. Inference runs under ``torch.no_grad()``.

    Args:
        test_set: iterable of patient records as loaded from the preprocessed pickle;
            ``pt[0]`` is used as the patient id.

    Returns:
        pandas.DataFrame with one row per patient and the columns
        pt, mort_label, mort_tte, mort_prob, mort_logHF, vent_label, vent_tte, vent_prob,
        vent_logHF, plos_label, plos_prob. An empty ``test_set`` yields an empty frame that
        still has these columns.
    """
    columns = ['pt', 'mort_label', 'mort_tte', 'mort_prob', 'mort_logHF',
               'vent_label', 'vent_tte', 'vent_prob', 'vent_logHF',
               'plos_label', 'plos_prob']
    pt_preds = []
    with torch.no_grad():
        for pt in test_set:
            pt_id = pt[0]
            pt_ds = EHRDataset([pt], sort=True, model='RNN')
            # A single-patient dataset with batch_size=1 produces exactly one batch.
            loader = EHRdataloader(pt_ds, batch_size=1, packPadMode=True, multilbl=True)
            x1, label, seq_len, time_diff = next(iter(loader))

            # .cpu() is a no-op for tensors already on the CPU, so one code path serves both devices.
            label = label.cpu().squeeze().numpy()
            mort_score = mort_model(x1, seq_len, time_diff).cpu().numpy()
            mort_surv_score = mort_surv_model(x1, seq_len, time_diff).cpu().numpy()
            vent_score = vent_model(x1, seq_len, time_diff).cpu().numpy()
            vent_surv_score = vent_surv_model(x1, seq_len, time_diff).cpu().numpy()
            plos_score = plos_model(x1, seq_len, time_diff).cpu().numpy()

            # NOTE(review): label[4] is not reported; confirm that is intended.
            pt_preds.append([pt_id, label[0], label[1], mort_score, mort_surv_score,
                             label[2], label[3], vent_score, vent_surv_score,
                             label[5], plos_score])

    # Passing columns to the constructor (instead of assigning .columns afterwards) also works
    # when no patients were scored.
    return pd.DataFrame(pt_preds, columns=columns)


In [None]:
## Score every patient in the loaded test set (use `test_sl_n` instead if you filtered out unseen codes).
newData_preds = pt_predictions(test_set=test_sl)

In [20]:
## Preview the per-patient predictions (pandas truncates long frames in the rendered output).
## The rendered table includes patient IDs in the `pt` column; clear this output before sharing the notebook.
newData_preds

Unnamed: 0,pt,mort_label,mort_tte,mort_prob,mort_logHF,vent_label,vent_tte,vent_prob,vent_logHF,plos_label,plos_prob
0,PT551222542,0.0,5.0,0.011483809,-0.07667988,0.0,5.0,0.175104,1.8812087,0.0,0.072981015
1,PT560891258,0.0,5.0,0.17363873,1.1077724,0.0,5.0,0.020506425,-1.9877058,0.0,0.7339609
2,PT347917274,0.0,9.0,0.13791549,-0.7829055,1.0,0.0,0.9708481,5.8927784,1.0,0.5697233
3,PT458192575,0.0,2.0,0.006077234,-0.86824286,0.0,2.0,0.004780006,-1.0253104,0.0,0.014261379
4,PT555833422,0.0,11.0,0.0024757017,-1.2200451,0.0,11.0,0.04967887,1.4017644,1.0,0.43321112
...,...,...,...,...,...,...,...,...,...,...,...
6719,PT262597655,0.0,4.0,0.010645734,0.2021003,0.0,4.0,0.11693959,2.4794743,0.0,0.14157677
6720,PT258212367,0.0,14.0,0.004918831,-0.653102,0.0,14.0,0.0039896052,-2.1123278,1.0,0.4037383
6721,PT299030144,0.0,2.0,0.002901735,-1.5422993,0.0,2.0,0.00012922873,-1.9245168,0.0,0.000733583
6722,PT238528023,0.0,3.0,0.02977879,-0.29111862,0.0,3.0,0.068191856,0.16750655,0.0,0.52785385


In [24]:
# Persist one row per patient, without the DataFrame index column.
preds_csv_path = 'newData_preds_v1.csv'
newData_preds.to_csv(preds_csv_path, index=False)
