### Script to convert Transformer outputs to desired results

In [2]:
import pickle
import os
import numpy as np
import torch
import pandas as pd
import glob

In [4]:
fps = '../data_d/'
vocab_fp = '../data_d/pos_vocab_last30_non3digit'
# list        #list    # lst of np              # list
discharge_id, labels, events, importance_score, probability_score = [], [], [], [], []
for fp in glob.glob(os.path.join(fps, 'final_importance*')):
    print(fp)
    print('='*20)
    cur_discharge_id, cur_labels, cur_events, cur_importance_score, cur_probability_score = torch.load(fp)
    
    discharge_id = discharge_id + cur_discharge_id
    labels = labels + cur_labels
    probability_score = probability_score + cur_probability_score
    events = events + cur_events
    importance_score = importance_score + cur_importance_score
    
    print('cur events: ', len(events))
    print('cur score: ', len(importance_score))
    
    
vocab = torch.load(vocab_fp)                                                    

../data_d/final_importance_results_kfold-2
cur events:  312446
cur score:  312446
../data_d/final_importance_results_kfold-4
cur events:  624894
cur score:  624894
../data_d/final_importance_results_kfold-3
cur events:  937342
cur score:  937342
../data_d/final_importance_results_kfold-1
cur events:  1249788
cur score:  1249788
../data_d/final_importance_results_kfold0
cur events:  1562234
cur score:  1562234


In [31]:
# not used, single file
fp = './explain/explainablity'
vocab_fp = './explain/explain_xyz_vocab'
discharge_id, labels, events, importance_score, probability_score = torch.load(fp)
vocab = torch.load(vocab_fp)     

In [5]:
len(events)

1562234

In [6]:
# create df of discharge_id, labels, probability score
probs_df = pd.DataFrame()
probs_df['discharge_id'] = discharge_id
probs_df['target'] = labels
probs_df['predict_probs'] = probability_score

# create dataframe for importance scores & events
imps_df = pd.DataFrame(importance_score)
imps_df['discharge_id'] = discharge_id

events_df = pd.DataFrame(events)
events_df['discharge_id'] = discharge_id

In [7]:
def grab_top_ids(probs_df, upper_q, lower_q, topn):
    print('='*20)
    print('topq:', upper_q)
    print('lowq:', lower_q)
    sub_df = probs_df[(probs_df.predict_probs > probs_df.predict_probs.quantile(lower_q)) &
                      (probs_df.predict_probs <= probs_df.predict_probs.quantile(upper_q))]
    
    print('sub_df:', sub_df.shape)
    
    sub_df.sort_values('predict_probs', ascending=False, inplace=True)
    
    print('returning:', topn)
    return sub_df.discharge_id.values.tolist()[:topn]

In [8]:
def get_top_10_feat(events_df, imps_df, discharge_ids, vocab):
    
    output = pd.DataFrame()
    for did in discharge_ids:
        # create row info
        cur_lst = []
        
        # grab only the top 10 most important features
        cur_imps = imps_df[imps_df.discharge_id == did]
        cur_imps = cur_imps.drop(columns='discharge_id').T
        cur_imps.sort_values(cur_imps.columns.tolist(), ascending=False, inplace=True)
        
        n_feat = 10
        top_idx = cur_imps.index.tolist()[:n_feat]
        
        # find the indxes of the top events
        cur_evnts = events_df.loc[events_df.discharge_id == did, top_idx]
        cur_evnts = cur_evnts.T
        
        # create row
        for idx, evnt, imps in zip(top_idx, cur_evnts.values.tolist(), cur_imps.values.tolist()):
            cur_lst.append(str(vocab.itos[evnt[0]]) + '-' + str(idx))
            cur_lst.append(imps[0])
        
            # missing look up table
        
        output[did] = cur_lst
    output = output.T
    output.index.name = 'discharge_id'
    output.columns = create_colname(n_feat)
    output.reset_index(inplace=True)
        
    return output

In [9]:
def create_colname(n):
    lst = [None] * 2 * n
    for idx in range(n):
        lst[2 * idx] = 'event_'+str(idx+1)
        lst[2 * idx + 1] = 'score_'+str(idx+1)
    return lst

In [10]:
def merge_data(output_df, probs_df):
    
    return pd.merge(output_df, probs_df, how='left', on='discharge_id')

In [11]:
# quantiles that were requested by Merck
buckets = [1.0, 0.995, 0.99, 0.95, 0.9, 0.8]

In [12]:
# create all the files
fdir = './explain/readmit_transf_scores/'    
if not os.path.isdir(fdir):
    os.makedirs(fdir)

for idx in range(len(buckets)-1):
    top_p = grab_top_ids(probs_df, upper_q=buckets[idx], lower_q=buckets[idx+1], topn=20)
    output = get_top_10_feat(events_df, imps_df, top_p, vocab)
    output = merge_data(output, probs_df)
    output['patient_id'] = [x.split('_')[0] for x in output['discharge_id']]
    
    fname = 'output_scores_' + str(buckets[idx+1]).replace('.', 'p') + str('.csv')
      
    output.to_csv(os.path.join(fdir, fname), index=False)

topq: 1.0
lowq: 0.995
sub_df: (7812, 3)


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy


returning: 20
topq: 0.995
lowq: 0.99
sub_df: (7811, 3)


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy


returning: 20
topq: 0.99
lowq: 0.95
sub_df: (62489, 3)


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy


returning: 20
topq: 0.95
lowq: 0.9
sub_df: (78112, 3)


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy


returning: 20
topq: 0.9
lowq: 0.8
sub_df: (156222, 3)


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy


returning: 20


### Save to file

## Manually check the results

In [13]:
imps_df.describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,890,891,892,893,894,895,896,897,898,899
count,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,...,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0,1562234.0
mean,0.01171669,0.001697569,0.001377525,0.00111548,0.0008069283,0.0005158363,0.0002963281,0.0002274258,0.0002191466,0.0001944944,...,2.145172e-05,2.031677e-05,2.317919e-05,2.370735e-05,1.419533e-05,7.666452e-06,5.301879e-06,4.986066e-06,6.491957e-06,6.991463e-06
std,0.009003308,0.00711011,0.00704792,0.005776651,0.00566147,0.003846884,0.002361862,0.002152922,0.002451117,0.002292114,...,0.0006778493,0.0008249431,0.001412193,0.001977498,0.001035664,0.0004608879,0.0003079596,0.0004378172,0.0007299453,0.0009924845
min,5.1807e-07,8.626043e-14,9.361437e-14,1.064761e-13,1.320664e-13,1.236531e-13,1.039176e-13,9.47408e-14,9.492592e-14,1.278436e-13,...,1.077745e-12,2.844391e-12,2.59829e-10,1.951892e-11,1.370525e-11,1.280979e-11,2.085549e-12,3.327562e-12,2.509326e-10,5.983711e-11
25%,0.005074821,3.115043e-09,3.129346e-09,3.097165e-09,3.080108e-09,2.920022e-09,2.762159e-09,2.764368e-09,2.734919e-09,2.717269e-09,...,2.43352e-09,2.558891e-09,2.732807e-09,2.834597e-09,2.747727e-09,2.526277e-09,2.40439e-09,2.504893e-09,2.753046e-09,2.929693e-09
50%,0.01092988,6.194067e-09,6.661266e-09,6.844282e-09,6.666648e-09,6.04627e-09,5.529307e-09,5.524569e-09,5.574388e-09,5.680786e-09,...,4.901509e-09,5.395248e-09,5.809901e-09,5.822421e-09,5.321044e-09,5.044039e-09,4.86749e-09,5.255375e-09,5.870009e-09,6.128454e-09
75%,0.01656035,1.761015e-08,1.552025e-08,1.602264e-08,1.503957e-08,1.402882e-08,1.331843e-08,1.250975e-08,1.289801e-08,1.310003e-08,...,1.176109e-08,1.258562e-08,1.275678e-08,1.232843e-08,1.214483e-08,1.198075e-08,1.141178e-08,1.192351e-08,1.248436e-08,1.242924e-08
max,0.5095705,0.5330237,0.6004187,0.5764338,0.8949229,0.6368256,0.3572277,0.2566267,0.2492524,0.5565163,...,0.1563981,0.3598569,0.8158737,0.9621294,0.5234054,0.1908427,0.1263719,0.3957086,0.6964054,0.652948


In [18]:
df = pd.read_csv('./explain/readmit_transf_scores/output_scores_0p95.csv')
print(df.shape)
df.head(20)

(20, 24)


Unnamed: 0,discharge_id,event_1,score_1,event_2,score_2,event_3,score_3,event_4,score_4,event_5,...,score_7,event_8,score_8,event_9,score_9,event_10,score_10,target,predict_probs,patient_id
0,496859785_20100402,admission-870,0.014749,h_A0425-694,0.014657,h_84100-571,0.01451,h_93010-722,0.014355,h_93010-692,...,0.013955,nan-360,0.013955,nan-90,0.013955,nan-120,0.013955,1,0.746186,496859785
1,112962701_20090427,d_4280-752,0.010742,d_4280-542,0.010742,h_99214-485,0.010554,p_3491-582,0.010166,h_99253-604,...,0.009867,h_99285-453,0.009859,d_5119-754,0.009822,h_88305-784,0.009757,0,0.746184,112962701
2,124045089_20100723,h_A0425-873,0.013896,d_51883-880,0.013491,d_5119-183,0.013096,d_514-333,0.013021,h_A0428-875,...,0.012668,d_5184-184,0.01259,h_72193-155,0.012545,h_93010-872,0.012525,1,0.746184,124045089
3,176385353_20100819,h_36248-697,0.010036,d_5859-751,0.009951,h_90935-876,0.009924,h_A0425-484,0.009913,h_90935-811,...,0.009833,d_5856-510,0.009833,d_5856-600,0.009833,d_5856-810,0.009833,0,0.746183,176385353
4,162753137_20110302,h_78815-1,0.006707,d_5849-752,0.006681,d_5119-210,0.006615,d_1629-576,0.006613,d_1629-456,...,0.006596,d_42731-602,0.006596,h_A0425-424,0.006595,d_78605-692,0.006583,1,0.746183,162753137
5,190510843_20100902,d_3101-590,0.006342,h_99285-271,0.006337,h_88305-340,0.006286,d_4280-692,0.006282,d_4280-602,...,0.006282,h_A0425-512,0.006255,h_A0425-872,0.006255,h_93010-663,0.006219,0,0.746179,190510843
6,141398025_20110729,admission-572,0.061766,admission-271,0.044031,d_78605-333,0.043099,d_5849-273,0.039772,d_28860-362,...,0.026717,d_28860-423,0.026717,d_28860-303,0.026717,d_28860-453,0.026717,0,0.746177,141398025
7,168770701_20111028,d_496-32,0.009095,d_496-66,0.009,h_A0425-214,0.008996,h_A0425-874,0.008996,d_5849-93,...,0.008908,d_5849-65,0.008883,d_5859-336,0.008848,d_496-185,0.008838,0,0.746175,168770701
8,100583249_20110902,h_A0428-873,0.007516,d_85220-753,0.007305,d_85220-153,0.007305,d_85220-423,0.007305,d_85220-483,...,0.007305,d_85220-573,0.007305,d_85220-604,0.007293,h_99253-513,0.007229,0,0.746174,100583249
9,140861573_20100326,p_D1C-188,0.128438,d_4280-182,0.114556,d_4280-32,0.114556,d_78605-92,0.072922,admission-151,...,0.034055,h_90935-392,0.029706,h_90935-842,0.029706,d_5859-393,0.02591,0,0.746173,140861573


In [19]:
df[df.target == 1]

Unnamed: 0,discharge_id,event_1,score_1,event_2,score_2,event_3,score_3,event_4,score_4,event_5,...,score_7,event_8,score_8,event_9,score_9,event_10,score_10,target,predict_probs,patient_id
0,496859785_20100402,admission-870,0.014749,h_A0425-694,0.014657,h_84100-571,0.01451,h_93010-722,0.014355,h_93010-692,...,0.013955,nan-360,0.013955,nan-90,0.013955,nan-120,0.013955,1,0.746186,496859785
2,124045089_20100723,h_A0425-873,0.013896,d_51883-880,0.013491,d_5119-183,0.013096,d_514-333,0.013021,h_A0428-875,...,0.012668,d_5184-184,0.01259,h_72193-155,0.012545,h_93010-872,0.012525,1,0.746184,124045089
4,162753137_20110302,h_78815-1,0.006707,d_5849-752,0.006681,d_5119-210,0.006615,d_1629-576,0.006613,d_1629-456,...,0.006596,d_42731-602,0.006596,h_A0425-424,0.006595,d_78605-692,0.006583,1,0.746183,162753137
10,495260597_20110908,h_93010-601,0.013457,h_93010-841,0.013457,h_93010-452,0.012628,h_93010-662,0.012628,h_G0164-180,...,0.011979,d_5856-781,0.011566,d_5856-751,0.011566,d_5856-692,0.0112,1,0.746173,495260597
11,153930863_20100817,p_8102-499,0.011209,h_A0425-452,0.010298,d_5119-875,0.010168,h_70450-485,0.010108,d_5119-601,...,0.009884,h_A0427-455,0.009796,h_74000-571,0.009777,d_78550-516,0.009684,1,0.746172,153930863
13,121612201_20110113,d_4280-724,0.011444,d_4280-784,0.011444,h_99284-880,0.010335,h_99285-574,0.009714,discharge-874,...,0.009558,h_99214-421,0.009543,h_83880-899,0.009532,h_83874-653,0.009438,1,0.746169,121612201


In [20]:
probs_df.predict_probs.quantile(0.995)

0.7623846456408501

In [21]:
probs_df.sort_values('predict_probs', ascending=False).head(50)

Unnamed: 0,discharge_id,target,predict_probs
1146073,481806887_20091017,0,0.933195
1101725,475336879_20111128,0,0.926729
1214250,488105847_20111114,0,0.925498
1135606,483667883_20101209,0,0.923785
1086607,470556037_20111001,1,0.920088
1197787,485027101_20101012,0,0.91608
1391207,495047599_20090413,0,0.910623
173485,171128485_20100525,0,0.910209
975226,190168551_20110215,0,0.902141
1366345,493257751_20101115,0,0.900174


In [53]:
events_df.head(30)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,891,892,893,894,895,896,897,898,899,discharge_id
0,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100000559_20111006
1,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100003229_20100510
2,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100005357_20111012
3,23,68,78,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100007025_20101217
4,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100007801_20111003
5,9,422,809,2140,2402,4131,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100008869_20101116
6,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100009907_20090702
7,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100010875_20100309
8,44,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100013023_20090926
9,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,100013179_20100831


In [52]:
imps_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,891,892,893,894,895,896,897,898,899,discharge_id
0,0.000627,0.000978,0.001131,0.001083,0.000963,0.000871,0.000796,0.000828,0.001016,0.001055,...,0.001264,0.001195,0.001064,0.001015,0.001075,0.001242,0.001385,0.001357,0.001211,100000559_20111006
1,0.00065,0.000992,0.001153,0.001105,0.000989,0.000897,0.000817,0.000846,0.001055,0.001083,...,0.001342,0.001242,0.001093,0.001044,0.001118,0.001321,0.001472,0.001422,0.001246,100003229_20100510
2,0.0007,0.001005,0.001164,0.001168,0.001062,0.000965,0.000874,0.000853,0.001021,0.001112,...,0.001297,0.001258,0.001149,0.001099,0.001136,0.001279,0.001428,0.001423,0.0013,100005357_20111012
3,0.004046,0.00423,0.002707,0.001005,0.000859,0.000773,0.000718,0.000843,0.001063,0.001036,...,0.001253,0.001174,0.001001,0.00095,0.001044,0.001216,0.001343,0.001325,0.001151,100007025_20101217
4,0.000684,0.000992,0.001154,0.001146,0.00104,0.000945,0.000854,0.000844,0.001029,0.001101,...,0.001329,0.001257,0.001131,0.00108,0.00113,0.001309,0.001466,0.001431,0.001284,100007801_20111003


In [17]:
len(probability_score)

312300