In [1]:
## Import basic requirements
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
from tqdm import tqdm
import pickle as pk
import pandas as pd
import random


In [2]:
### Load our EHR model related files
import models as model 
from EHRDataloader import EHRdataFromPickles, EHRdataloader  
import utils as ut 
from EHREmb import EHREmbeddings

In [3]:
## Just using original code for now
def sigmoid(x):
    return (1 / (1 + np.exp(-x)))


In [4]:
## Load our Pretrained model

best_model = torch.load('../models/hf.trainEHRmodel1.pth')


In [5]:
best_model.eval()

EHR_RNN(
  (embed): Embedding(30000, 128, padding_idx=0)
  (rnn_c): GRU(128, 64, batch_first=True, dropout=0.1)
  (out): Linear(in_features=64, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [6]:
best_model.state_dict()

OrderedDict([('embed.weight',
              tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                        0.0000e+00,  0.0000e+00],
                      [ 3.8867e-01, -5.0994e-01, -4.2380e-01,  ...,  1.4341e-01,
                       -2.1141e-02,  1.8216e-01],
                      [ 2.1339e-01,  1.0675e-01,  6.4853e-01,  ...,  1.6804e-01,
                       -3.0365e-01,  3.8755e-01],
                      ...,
                      [ 2.1458e-10, -6.2775e-02, -6.7921e-42,  ...,  3.4711e-04,
                        7.6896e-01, -7.1290e-39],
                      [ 4.8890e-03, -5.9261e-42, -1.5059e-04,  ..., -3.4045e-04,
                       -2.1071e-03, -1.3142e-36],
                      [ 4.1841e-02, -1.9407e-26, -1.2908e-24,  ..., -2.0760e-02,
                       -7.4471e-02,  1.7137e-03]], device='cuda:0')),
             ('rnn_c.weight_ih_l0',
              tensor([[-0.0331, -0.0380, -0.0410,  ...,  0.1258, -0.0227,  0.0188],
                 

In [7]:
ds = EHRdataFromPickles(root_dir = '../data/', 
                              file = '/dhf_test_60Kb_cid_cscl1.combined.valid', 
                              sort= True,
                              model='RNN')
mbs_list = list(tqdm(EHRdataloader(ds, batch_size=1, packPadMode=True)))


100%|██████████| 12000/12000 [00:19<00:00, 619.12it/s] 


In [8]:
# load the token dict
tk_dc= pk.load(open('../data/dhf_test_60Kb_cid_cscl1.types','rb'))
diag_t= pd.DataFrame.from_dict(tk_dc,orient='index').reset_index()
diag_t.columns=['DIAGNOSIS_ID','TOKEN_ID']
diag=pd.read_csv('../data/HF_D_DIAGNOSIS', sep='|')
diag_tk=pd.merge(diag_t, diag, how='left', on='DIAGNOSIS_ID',sort=True)


In [9]:
diag_tk[diag_tk['DIAGNOSIS_ID']==3]

Unnamed: 0,DIAGNOSIS_ID,TOKEN_ID,DIAGNOSIS_TYPE,DIAGNOSIS_CODE,DIAGNOSIS_DESCRIPTION
0,3,19281,ICD9,1.1,CHOLERA DUE TO VIBRIO CHOLERAE EL TOR


In [13]:
# attribution for the prediction
weights = best_model.rnn_c.state_dict()
_, W_iz, _ = np.split(weights['weight_ih_l0'].cpu().numpy(), 3, 0) ## to get the weights of the middle gate
_, W_hz, _ = np.split(weights['weight_hh_l0'].cpu().numpy(), 3, 0)
_, b_z, _ = np.split(weights['bias_ih_l0'].cpu().numpy() + weights['bias_hh_l0'].cpu().numpy(), 3)
weights_linear = best_model.out.state_dict()
W = weights_linear['weight'].cpu().numpy()
b = weights_linear['bias'].cpu().numpy()


In [14]:
W[0]

array([ 0.26667246, -0.1561053 ,  0.21302247,  0.22679123, -0.21923877,
        0.29989025,  0.10841151, -0.20886146,  0.03460821,  0.2597521 ,
        0.10061447, -0.24237964,  0.15679884,  0.2215598 ,  0.21295518,
       -0.3469017 ,  0.03846961,  0.00176884, -0.13798143, -0.14447302,
        0.2666596 , -0.14771777,  0.2685269 ,  0.13044381, -0.25845376,
        0.2722527 , -0.17379832,  0.25891566,  0.21417032,  0.31167227,
       -0.23342112,  0.2018602 , -0.20814937,  0.20248793, -0.13926002,
        0.29849377,  0.22623911,  0.23675893, -0.1887609 , -0.10474794,
       -0.19818154, -0.30080616, -0.21992539,  0.19540717, -0.20844392,
       -0.22251134,  0.21512419, -0.25633875,  0.18334119,  0.21913089,
        0.24736837, -0.19827954,  0.21285443,  0.21771058, -0.14350338,
        0.08306476,  0.21354948,  0.23521651, -0.2028091 , -0.25210553,
       -0.23640904,  0.1114354 ,  0.07720751,  0.13997132], dtype=float32)

In [18]:
len(mbs_list[122:125])

3

In [30]:
### predictions - Explanation for visit level
use_cuda=True
random.shuffle(mbs_list)
for c,mb in enumerate(mbs_list[:6]):
    print('mb',c)
    x1, label,seq_len,time_diff = mb 
    x_emb4D =best_model.embed(x1)
    print (x_emb4D.shape)
    x_in  = best_model.EmbedPatient_MB(x1,time_diff) 
    print(x_in.shape)
    output, hidden = best_model.rnn_c(x_in) ## hn is the model output
    pred = best_model.sigmoid(best_model.out(hidden[-1])).squeeze()
    #print ('label', label.view(1,-1),'pred',pred)
    print('outshape',output.shape)
    
    x = x_in.cpu().data.numpy()
    hn = output.cpu().data.numpy()
    x_c=x_emb4D.cpu().data.numpy()
    print(hn.shape)
    
    z_pdict=[]
    mb_score_dict=[]
    mb_code_score_dict=[]
    zc_pts_dict=[]
    for pt in range(output.shape[0]):
        z_dict = []
        zc_gdict=[]
        #z_dict.append(np.ones(150)) #### need to ask we have this
        for i in range(output.shape[1]): ### seq lenghth
            zc_dict=[]
            for cd in range(x_c[pt,i,:,:].shape[0]):
                #i = i + 1
                if i==0:
                    #print(type(W_iz), type(x[pt,i,:]),b_z)
                    #print(W_iz.shape,x[pt,i,:].shape,x_emb4D[pt,i,:,:].shape)
                    zc = np.matmul(W_in, x_c[pt,i,c,:])
                    z = sigmoid(np.matmul(W_iz, x[pt,i,:])+ b_z)
                else:
                    zc = np.matmul(W_in, x_c[pt,i,c,:])
                    zc_n = zc.data.norm(dim=-1)### need to use L1 norm or simple sum
                    #zc_nn= zc_n.normalize(dim=0)
                    zc_nn=zc_n/zc_n.sum(dim=0)
                    #zc = sigmoid(np.matmul(W_in, x_c[pt,i,c,:])+ np.matmul(W_hz, hn[pt,i-1,:])+ b_z)
                    z = sigmoid(np.matmul(W_iz, x[pt,i,:]) + np.matmul(W_hz, hn[pt,i-1,:]) + b_z)
                
                zc_dict.append(zc)
            z_dict.append(z)
            zc_gdict.append(zc_dict)
                                 
        alpha_dict = z_dict
        Beta_dict = zc_gdict
        #print(len(alpha_dict))
        #print(alpha_dict[0].shape)
            
        score_dict = []
        code_score_dict=[]
        
        for i in range(len(alpha_dict)):
            if i == 0:
                updating = hn[pt,i,:]
            else:
                updating = hn[pt,i,:] - alpha_dict[i] * hn[pt,i-1,:]
            forgetting = alpha_dict[0]
            for j in range(i+1, len(alpha_dict)):
                forgetting = forgetting*alpha_dict[j]
            score = np.matmul( W[0], updating * forgetting) #+ b[target_class]
            #contribution of each code is
            code_cont=zc_nn*score
            #print(score)
            score_dict.append(score)
            
        #print(len(score_dict))    
        
        
            
        z_pdict.append(z_dict)
        mb_score_dict.append(score_dict) 
        
        for ic,cl in enumerate(Beta_dict):
            vc_score_dict=[]
            for iic in range(len(cl)):
                if ic == 0:
                    updatingc = hn[pt,ic,:]
                else:
                    updatingc = hn[pt,ic,:] - Beta_dict[ic][iic] * hn[pt,i-1,:]
                forgettingc = Beta_dict[ic][0]
                for jc in range(iic+1, len(cl)):
                    forgettingc = forgettingc*cl[jc]
                
                c_score = np.matmul( W[0], updating * forgetting) #+ b[target_class]
                #print(score)
                vc_score_dict.append(c_score)
            code_score_dict.append(vc_score_dict)
        #print(len(score_dict))    
            
        z_pdict.append(z_dict)
        mb_score_dict.append(score_dict) 
        mb_code_score_dict.append(code_score_dict)
        
        

    #print(len(z_pdict), len(z_pdict[-1]))
    #print(z_pdict[0][0].shape)
    #print (len(mb_score_dict),x1.shape,x1[-1],x_in[-1],mb_score_dict[-1])
    print('Patient 1','label: ',label.squeeze().cpu(), '\npred_score :' , pred.cpu().data)
          #,'\npatient visits: ' ,(x1[1].cpu().data.numpy()).tolist(),
    print('\nExplanation score per visit: ',mb_score_dict[0])
    print('\npatient visits: ')
    for g,l in enumerate((x1[0].cpu().data.numpy()).tolist()):
        dls=[diag_tk[diag_tk['TOKEN_ID'].isin(l)]['TOKEN_ID'].tolist(), diag_tk[diag_tk['TOKEN_ID'].isin(l)]['DIAGNOSIS_DESCRIPTION'].tolist()]
        print('Visit:',g+1,' with Explanation Score:',mb_score_dict[0][g],l,dls,'Explanation scores for codes:',mb_code_score_dict[0][g],'\n')
        

mb 0
torch.Size([1, 13, 32, 128])
torch.Size([1, 13, 128])
outshape torch.Size([1, 13, 64])
(1, 13, 64)
Patient 1 label:  tensor(0.) 
pred_score : tensor(0.3950)

Explanation score per visit:  [0.010629646, 0.0042384034, 0.009414, 0.002506926, -0.00996822, -0.0021500927, -0.028929096, 0.020672511, 0.061526824, -0.18433881, 0.20614833, 0.21537498, 0.08338158]

patient visits: 
Visit: 1  with Explanation Score: 0.010629646 [335, 84, 86, 8, 1880, 233, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [[7, 8, 335, 84, 86, 233, 1880], ['BACKACHE, UNSPECIFIED', 'OTHER AND UNSPECIFIED HYPERLIPIDEMIA', 'Panic disorder without agoraphobia', 'PAIN IN JOINT INVOLVING PELVIC REGION AND THIGH', 'PAIN IN LIMB', 'ABDOMINAL PAIN, UNSPECIFIED SITE', 'PERSONAL HISTORY OF UNSPECIFIED CIRCULATORY DISEASE']] Explanation scores for codes: [0.08338158, 0.08338158, 0.08338158, 0.08338158, 0.08338158, 0.08338158, 0.08338158, 0.08338158, 0.08338158, 0.08338158, 0.08338158, 0.08338158

Visit: 15  with Explanation Score: 3.1881696e-05 [5, 17, 185, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [[5, 17, 185], ['Diabetes mellitus without mention of complication, type II or unspecified type, not stated as uncontrolled', 'BENIGN ESSENTIAL HYPERTENSION', 'Long-Term (Current) Use of Anticoagulants']] Explanation scores for codes: [0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935] 

Visit: 16  with Explanation Score: -2.321725e-05 [2231, 540, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [[2231, 540], ['UNSPECIFIED OSTEOMYELITIS, SITE UNSPECIFIED', 'Solitary Pulmonary Nodule']] Explanation scores for codes: [0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.11376935, 0.1137693