In [1]:
import numpy as np
import os
import pickle
import random
import math
import copy

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

from utils import utils
from utils.readers import InHospitalMortalityReader
from utils.preprocessing import Discretizer, Normalizer
from utils import metrics
from utils import common_utils

import sklearn
from sklearn.cluster import KMeans
from sklearn.cluster import SpectralClustering

In [2]:
data_path = './data/'

small_part = False
arg_timestep = 1.0
batch_size = 256

In [3]:
# Build readers, discretizers, normalizers
train_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'train'),
                                         listfile=os.path.join(data_path, 'train_listfile.csv'),
                                         period_length=48.0)

val_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'train'),
                                       listfile=os.path.join(data_path, 'val_listfile.csv'),
                                       period_length=48.0)

discretizer = Discretizer(timestep=arg_timestep,
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero')

In [4]:
discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = 'ihm_normalizer'
normalizer_state = os.path.join(os.path.dirname(data_path), normalizer_state)
normalizer.load_params(normalizer_state)

In [5]:
n_trained_chunks = 0
train_raw = utils.load_data(train_reader, discretizer, normalizer, small_part, return_names=True)
val_raw = utils.load_data(val_reader, discretizer, normalizer, small_part, return_names=True)

(14681, 48, 76)
(3222, 48, 76)


In [6]:
demographic_data = []
diagnosis_data = []
idx_list = []
ethnicity_types = []
gender_types = []

demo_path = data_path + 'demographic/'
for cur_name in os.listdir(demo_path):
    cur_id, cur_episode = cur_name.split('_', 1)
    cur_episode = cur_episode[:-4]
    cur_file = demo_path + cur_name

    with open(cur_file, "r") as tsfile:
        header = tsfile.readline().strip().split(',')
        if header[0] != "Icustay":
            continue
        # print(header)
        cur_data = tsfile.readline().strip().split(',')
        
    if len(cur_data) == 1:
        cur_demo = np.zeros(10)
        cur_diag = np.zeros(128)
    else:
        if cur_data[1] not in ethnicity_types:
            ethnicity_types.append(cur_data[1])
        if cur_data[2] not in gender_types:
            gender_types.append(cur_data[2])
        if cur_data[3] == '':
            cur_data[3] = 60.0               # age
        if cur_data[4] == '':
            cur_data[4] = 160                # height
        if cur_data[5] == '':
            cur_data[5] = 60                 # weight

        cur_demo = np.zeros(10)
        cur_demo[int(cur_data[1])] = 1           #ethnicity -- 0-4 ('0','1','2','3','4') 
        cur_demo[5 + int(cur_data[2]) - 1] = 1   #gender    -- 5-6 ('1','2')
        cur_demo[7:] = cur_data[3:6]             #7-9: age/height/weight
        cur_diag = np.array(cur_data[8:], dtype=np.int)

    demographic_data.append(cur_demo)
    diagnosis_data.append(cur_diag)
    idx_list.append(cur_id+'_'+cur_episode)

print(ethnicity_types)
print(gender_types)

for each_idx in range(7,10):
    cur_val = []
    for i in range(len(demographic_data)):
        cur_val.append(demographic_data[i][each_idx])
    cur_val = np.array(cur_val)
    _mean = np.mean(cur_val)
    _std = np.std(cur_val)
    _std = _std if _std > 1e-7 else 1e-7
    for i in range(len(demographic_data)):
        demographic_data[i][each_idx] = (demographic_data[i][each_idx] - _mean) / _std

['4', '1', '3', '0', '2']
['2', '1']


In [32]:
device = torch.device("cuda:0" if torch.cuda.is_available() == True else 'cpu')
#device = torch.device('cpu')
print("available device: {}".format(device))

available device: cuda:0


### Model & Functions

In [8]:
class Dataset(data.Dataset):
    def __init__(self, x, y, name):
        self.x = x
        self.y = y
        self.name = name

    def __getitem__(self, index):#返回的是tensor
        return self.x[index], self.y[index], self.name[index]

    def __len__(self):
        return len(self.x)
    
train_dataset = Dataset(train_raw['data'][0], train_raw['data'][1], train_raw['names'])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataset = Dataset(val_raw['data'][0], val_raw['data'][1], val_raw['names'])
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [9]:
def get_loss(y_pred, y_true):
    loss = torch.nn.BCELoss()
    return loss(y_pred, y_true)

In [38]:
class FinalAttentionQKV(nn.Module):
    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', dropout=None):
        super(FinalAttentionQKV, self).__init__()
        
        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim


        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)

        self.W_out = nn.Linear(attention_hidden_dim, 1)

        self.b_in = nn.Parameter(torch.zeros(1,))
        self.b_out = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))

        self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))
        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
        self.ba = nn.Parameter(torch.zeros(1,))
        
        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        
        self.dropout = nn.Dropout(p=dropout)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, input):
 
        batch_size, time_step, input_dim = input.size() # batch_size * input_dim + 1 * hidden_dim(i)
        input_q = self.W_q(input[:,-1,:]) # b h
        input_k = self.W_k(input)# b t h
        input_v = self.W_v(input)# b t h

        if self.attention_type == 'add': #B*T*I  @ H*I

            q = torch.reshape(input_q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            h = q + input_k + self.b_in # b t h
            h = self.tanh(h) #B*T*H
            e = self.W_out(h) # b t 1
            e = torch.reshape(e, (batch_size, time_step))# b t

        elif self.attention_type == 'mul':
            q = torch.reshape(input_q, (batch_size, self.attention_hidden_dim, 1)) #B*h 1
            e = torch.matmul(input_k, q).squeeze(-1)#b t
            
        elif self.attention_type == 'concat':
            q = input_q.unsqueeze(1).repeat(1,time_step,1)# b t h
            k = input_k
            c = torch.cat((q, k), dim=-1) #B*T*2I
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step)) # b t 
        
        a = self.softmax(e) #B*T
        if self.dropout is not None:
            a = self.dropout(a)
        v = torch.matmul(a.unsqueeze(1), input_v).squeeze() #B*I

        return v, a


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])



class SAFARI(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_clu, output_dim, keep_prob=0.5):
        super(SAFARI, self).__init__()

        # hyperparameters
        self.input_dim = input_dim  
        self.hidden_dim = hidden_dim  # d_model
        self.output_dim = output_dim
        self.keep_prob = keep_prob
        self.n_clu = n_clu
        self.dim_list = [2, 1, 1, 8, 12, 13, 12, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        
        self.GRUs = nn.ModuleList()
        for i in self.dim_list:
            self.GRUs.append(nn.GRU(i+1, self.hidden_dim, batch_first=True))
        self.feature_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        
        self.FinalAttentionQKV = FinalAttentionQKV(self.hidden_dim, self.hidden_dim, attention_type='mul',dropout = 1 - self.keep_prob)

        self.GCN_W1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.GCN_W2 = nn.Linear(self.hidden_dim, self.hidden_dim)

#         self.demo_proj_main = nn.Linear(12, self.hidden_dim)
        self.demo_proj = nn.Linear(10, self.hidden_dim)
        self.output0 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.output1 = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p = 1 - self.keep_prob)
        self.tanh=nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu=nn.ReLU()
        self.elu=nn.ELU()

    def forward(self, input, static, adj_mat):
        
        batch_size = input.size(0)
        time_step = input.size(1)
        feature_dim = len(self.dim_list)
        
        assert(feature_dim == self.input_dim)# input Tensor : 256 * 48 * 76
   
        GRU_embeded_input = None

        # MIMIC features include multi-dimensional one-hot variables, and need to separate them into different channels of encoders
        start_pos = 0
        for i in range(feature_dim):
            mask_pos = 59 + i
            tmp_input = torch.cat((input[:,:,start_pos:start_pos+self.dim_list[i]], input[:,:,mask_pos].unsqueeze(-1)), dim=-1)
            start_pos += self.dim_list[i]
            hidden_0 = Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device)
            embeded_input = self.GRUs[i](tmp_input, hidden_0)[1].squeeze().unsqueeze(1)
            embeded_input = self.feature_proj(embeded_input)
            if GRU_embeded_input is None:
                GRU_embeded_input = embeded_input
            else:
                GRU_embeded_input = torch.cat((GRU_embeded_input, embeded_input), 1)

        static_emb = self.feature_proj(self.relu(self.demo_proj(static))).unsqueeze(1)
        GRU_embeded_input = torch.cat((GRU_embeded_input, static_emb), dim=1)
        posi_input = self.dropout(GRU_embeded_input) # batch_size * d_input * hidden_dim

        contexts = posi_input
        
        clu_context = None
        gcn_hidden = None
        gcn_contexts = None
        #Graph Conv
        if gcn_hidden is None:
            gcn_hidden = self.relu(self.GCN_W1(torch.matmul(adj_mat, contexts)))
        if gcn_contexts is None:
            gcn_contexts = self.relu(self.GCN_W2(torch.matmul(adj_mat, gcn_hidden)))
        

        clu_context = gcn_contexts[:,:,:]

        
        weighted_contexts = self.FinalAttentionQKV(clu_context)[0]
        output = self.relu(self.output0(self.dropout(weighted_contexts)))
        output = self.output1(self.dropout(output))# b 1
        output = self.sigmoid(output)
#         print(weighted_contexts.shape)
          
        return output, weighted_contexts, GRU_embeded_input




In [25]:
# Multi-Relational Graph Update, returns a adjacency matrix and Clustering Info
def GraphUpdate(sim_metric, feature_emb, input_dim, n_clu, feat2clu=None):
    adj_mat = torch.zeros(input_dim+1, input_dim+1).to(device)
    eps = 1e-7
    #print(feature_emb.size())

    if sim_metric == 'euclidean':
        feature_mean_emb = [None for i in range(input_dim)]
        for i in range(input_dim):
            feature_mean_emb[i] = torch.mean(feature_emb[:,i,:].squeeze(), dim=0).cpu().numpy()
        feature_mean_emb = np.array(feature_mean_emb)
        #print(feature_mean_emb.shape)
        
        if feat2clu is None:
            kmeans = KMeans(n_clusters=n_clu, init='random', n_init=2).fit(feature_mean_emb)
            feat2clu = kmeans.labels_
        
        clu2feat = [[] for i in range(n_clu)]
        for i in range(input_dim):
            clu2feat[feat2clu[i]].append(i)

        for clu_id, cur_clu in enumerate(clu2feat):
            for i in cur_clu:
                for j in cur_clu:
                    if i != j:
                        cos_sim = np.dot(feature_mean_emb[i], feature_mean_emb[j])
                        cos_sim = cos_sim / max(eps, float(np.linalg.norm(feature_mean_emb[i]) * np.linalg.norm(feature_mean_emb[j])))
                        adj_mat[i][j] = torch.tensor(cos_sim).to(device)


    elif 'kernel' in sim_metric:
        kernel_mat = torch.zeros((input_dim, input_dim)).to(device)
        sigma = 0
        for i in range(input_dim):
            for j in range(input_dim):
                if sim_metric == 'rbf_kernel':
                    sample_dist = F.pairwise_distance(feature_emb[:,i,:], feature_emb[:,j,:], p=2)
                if sim_metric == 'laplacian_kernel':
                    sample_dist = F.pairwise_distance(feature_emb[:,i,:], feature_emb[:,j,:], p=1)
                sigma += torch.mean(sample_dist)
        
        sigma = sigma / (input_dim * input_dim)
        #sigma = feature_emb.size(-1)
    
        for i in range(input_dim):
            for j in range(input_dim):
                if sim_metric == 'rbf_kernel':
                    sample_dist = F.pairwise_distance(feature_emb[:,i,:], feature_emb[:,j,:], p=2)
                    kernel_mat[i, j] = torch.mean(torch.exp(-(sample_dist * sample_dist) / (2 * (sigma**2))))
                elif sim_metric == 'laplacian_kernel':
                    sample_dist = F.pairwise_distance(feature_emb[:,i,:], feature_emb[:,j,:], p=1)
                    kernel_mat[i, j] = torch.mean(torch.exp(-sample_dist / sigma))
        #print(kernel_mat)
        aff_mat = np.array(kernel_mat.cpu().detach().numpy())
        #print(aff_mat)
        
        if feat2clu is None:
            kmeans = SpectralClustering(n_clusters=n_clu, affinity='precomputed', n_init=20).fit(aff_mat)
            feat2clu = kmeans.labels_
        
        clu2feat = [[] for i in range(n_clu)]
        for i in range(input_dim):
            clu2feat[feat2clu[i]].append(i)

        for clu_id, cur_clu in enumerate(clu2feat):
            for i in cur_clu:
                for j in cur_clu:
                    if i != j:
                        adj_mat[i][j] = torch.tensor(aff_mat[i][j]).to(device)


    for i in range(input_dim + 1):
        adj_mat[i][i] = 1

    for i in range(input_dim):
        adj_mat[i][input_dim] = 1
        adj_mat[input_dim][i] = 1

    
    return adj_mat, feat2clu, clu2feat

### Run for training

In [39]:
RANDOM_SEED = 3407
np.random.seed(RANDOM_SEED) #numpy
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED) # cpu
torch.cuda.manual_seed(RANDOM_SEED) #gpu
torch.backends.cudnn.deterministic=True # cudnn
    
epochs = 180
input_dim = 17
hidden_dim = 32
output_dim = 1
n_clu = 5

model = SAFARI(input_dim = input_dim, hidden_dim = hidden_dim, n_clu=n_clu, output_dim = output_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.2)

In [None]:
file_name = './model/MIMIC_SAFARI'

batch_size = 256
epochs = 150
cluster_epochs = 0.2 * epochs
pad_token = np.zeros(input_dim)
max_roc = 0
max_prc = 0
min_loss = 999
sim_metric = 'laplacian_kernel'
train_loss = []
train_model_loss = []
valid_loss = []
valid_model_loss = []
history = []
np.set_printoptions(threshold=np.inf)
np.set_printoptions(precision=2)
np.set_printoptions(suppress=True)

feat2clu = np.random.randint(0, n_clu, size=input_dim)
clu2feat = [[] for i in range(n_clu)]
for i in range(input_dim):
    clu2feat[feat2clu[i]].append(i)

#Graph Init
adj_mat = torch.zeros(input_dim+1, input_dim+1).to(device)
for clu_id, cur_clu in enumerate(clu2feat):
    for i in cur_clu:
        for j in cur_clu:
            if i != j:
                adj_mat[i][j] = 1

for i in range(input_dim + 1):
    adj_mat[i][i] = 1
    
for i in range(input_dim):
    adj_mat[i][input_dim] = 1
    adj_mat[input_dim][i] = 1

    

for each_epoch in range(epochs):
    batch_loss = []
    model_batch_loss = []
    epoch_loss = []
    print('Current cluster:', clu2feat)
    
    model.train() 
    
    for step, (batch_x, batch_y, batch_name) in enumerate(train_loader):  
        optimizer.zero_grad()
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)

        batch_demo = []
        for i in range(len(batch_name)):
            cur_id, cur_ep, _ = batch_name[i].split('_', 2)
            cur_idx = cur_id + '_' + cur_ep
            cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
            batch_demo.append(cur_demo)
        
        batch_demo = torch.stack(batch_demo).to(device)

        opt, emb, contexts = model(batch_x, batch_demo, adj_mat)
        
        #FL_Loss = get_FL_loss(opt, batch_y.unsqueeze(-1), alpha=0.25, gamma=2)
        FL_Loss = get_loss(opt, batch_y.unsqueeze(-1))
        model_loss =  100*FL_Loss
        loss = model_loss #+ 10000 * decov_loss
        
        batch_loss.append(loss.cpu().detach().numpy())
        model_batch_loss.append(model_loss.cpu().detach().numpy())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 20)
        optimizer.step()
        
        if step % 20 == 0:
            print('Epoch %d Batch %d: Train Loss = %.4f'%(each_epoch, step, np.mean(np.array(batch_loss))))
    train_loss.append(np.mean(np.array(batch_loss)))
    train_model_loss.append(np.mean(np.array(model_batch_loss)))


    batch_loss = []
    model_batch_loss = []
    
    
    y_true = []
    y_pred = []
    with torch.no_grad():
        model.eval()
        for batch_x, batch_y, batch_name in valid_loader:
            batch_x = batch_x.float().to(device)
            batch_y = batch_y.float().to(device)

            batch_demo = []
            for i in range(len(batch_name)):
                cur_id, cur_ep, _ = batch_name[i].split('_', 2)
                cur_idx = cur_id + '_' + cur_ep
                cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
                batch_demo.append(cur_demo)

            batch_demo = torch.stack(batch_demo).to(device)

            opt, emb, contexts = model(batch_x, batch_demo, adj_mat)

            #FL_Loss = get_FL_loss(opt, batch_y.unsqueeze(-1), alpha=0.25, gamma=2)
            FL_Loss = get_loss(opt, batch_y.unsqueeze(-1))
            model_loss =  100*FL_Loss
            loss = model_loss #+ 10000 * decov_loss
        
            
            batch_loss.append(loss.cpu().detach().numpy())
            model_batch_loss.append(model_loss.cpu().detach().numpy())
            y_pred += list(opt.cpu().detach().numpy().flatten())
            y_true += list(batch_y.cpu().numpy().flatten())
            
            
    valid_loss.append(np.mean(np.array(batch_loss)))
    valid_model_loss.append(np.mean(np.array(model_batch_loss)))
    
    print("\n==>Predicting on validation")
    print('Valid Loss = %.4f'%(valid_loss[-1]))
    print('valid_model Loss = %.4f'%(valid_model_loss[-1]))

    y_pred = np.array(y_pred)
    y_pred = np.stack([1 - y_pred, y_pred], axis=1)
    ret = metrics.print_metrics_binary(y_true, y_pred)
    history.append(ret)
    print('')

    cur_prc = ret['auprc']
    if cur_prc > max_prc:
        max_prc = cur_prc
        state = {
            'net': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': each_epoch,
            'cluster': clu2feat,
            'adj_mat': adj_mat
        }
        torch.save(state, file_name+"prc")
        print('\n------------ Save best-prc model ------------\n')
    
    cur_roc = ret['auroc']
    if cur_roc > max_roc:
        max_roc = cur_roc
        state = {
            'net': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': each_epoch,
            'cluster': clu2feat,
            'adj_mat': adj_mat
        }
        torch.save(state, file_name+"roc")
        print('\n------------ Save best-roc model ------------\n')
    
    if valid_loss[-1] < min_loss:
        min_loss = valid_loss[-1]
        state = {
            'net': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': each_epoch,
            'cluster': clu2feat,
            'adj_mat': adj_mat
        }
        torch.save(state, file_name+"loss")
        print('\n------------ Save best-valloss model ------------\n')
    
    feature_emb = None
    if each_epoch < cluster_epochs:
        with torch.no_grad():
            model.eval()
            for step, (batch_x, batch_y, batch_name) in enumerate(train_loader):  
                batch_x = batch_x.float().to(device)
                batch_y = batch_y.float().to(device)

                batch_demo = []
                for i in range(len(batch_name)):
                    cur_id, cur_ep, _ = batch_name[i].split('_', 2)
                    cur_idx = cur_id + '_' + cur_ep
                    cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
                    batch_demo.append(cur_demo)

                batch_demo = torch.stack(batch_demo).to(device)

                opt, emb, contexts = model(batch_x, batch_demo, adj_mat)

                if step % 30 == 0:
                    print("Generating hidden from train in eval mode. Batch %d..." % step)
                #Sampling
                cur_batch_size = batch_x.size(0)
                sample_size = min(32, cur_batch_size)
                indices = torch.tensor(random.sample(range(cur_batch_size), sample_size)).to(device)
                if feature_emb is None:
                    feature_emb = torch.index_select(contexts[:,:-1,:], dim=0, index=indices)
                else:
                    cur_feature_emb = torch.index_select(contexts[:,:-1,:], dim=0, index=indices)
                    feature_emb = torch.cat((feature_emb, cur_feature_emb), dim=0)
                    
        adj_mat, feat2clu, clu2feat = GraphUpdate(sim_metric, feature_emb, input_dim, n_clu)

    scheduler.step()
        
        
print('==============DONE==============')

In [None]:
####Plot Loss
import matplotlib.pyplot as plt
fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1)
x_axis = np.arange(1, 151)
ax1.plot(x_axis, train_model_loss, c='red', label='train')
ax1.plot(x_axis, valid_model_loss, c='blue', label='valid')
plt.legend(loc='best')
plt.show()

### Test

In [15]:
test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'test'),
                                            listfile=os.path.join(data_path, 'test_listfile.csv'),
                                            period_length=48.0)
test_raw = utils.load_data(test_reader, discretizer, normalizer, small_part, return_names=True)
test_dataset = Dataset(test_raw['data'][0], test_raw['data'][1], test_raw['names'])
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

(3236, 48, 76)


In [42]:
file_name = './model/MIMIC_SAFARI'

batch_size = 256
pad_token = np.zeros(34)
checkpoint = torch.load(file_name+'roc')#23
save_epoch = checkpoint['epoch']
clu2feat = checkpoint['cluster']
adj_mat = checkpoint['adj_mat']
print("last saved model is in epoch {}".format(save_epoch))
model.eval()


batch_loss = []
y_true = []
y_pred = []
with torch.no_grad():
    model.eval()
    for step, (batch_x, batch_y, batch_name) in enumerate(test_loader):  
        optimizer.zero_grad()
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)

        batch_demo = []
        for i in range(len(batch_name)):
            cur_id, cur_ep, _ = batch_name[i].split('_', 2)
            cur_idx = cur_id + '_' + cur_ep
            cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
            batch_demo.append(cur_demo)

        batch_demo = torch.stack(batch_demo).to(device)
        opt, emb, contexts = model(batch_x, batch_demo, adj_mat)
        BCE_Loss = get_loss(opt, batch_y.unsqueeze(-1))
        model_loss =  BCE_Loss 

        loss = model_loss
        batch_loss.append(loss.cpu().detach().numpy())
        if step % 20 == 0:
            print('Batch %d: Test Loss = %.4f'%(step, loss.cpu().detach().numpy()))
        y_pred += list(opt.cpu().detach().numpy().flatten())
        y_true += list(batch_y.cpu().numpy().flatten())

print("\n==>Predicting on test")
print('Test Loss = %.4f'%(np.mean(np.array(batch_loss))))
y_pred = np.array(y_pred)
y_pred = np.stack([1 - y_pred, y_pred], axis=1)
test_res = metrics.print_metrics_binary(y_true, y_pred)

last saved model is in epoch 31
Batch 0: Test Loss = 0.2869

==>Predicting on test
Test Loss = 0.2544
confusion matrix:
[[2803   59]
 [ 242  132]]
accuracy = 0.9069839119911194
precision class 0 = 0.9205254316329956
precision class 1 = 0.6910994648933411
recall class 0 = 0.9793850183486938
recall class 1 = 0.3529411852359772
AUC of ROC = 0.865084436671562
AUC of PRC = 0.5379030136044421
min(+P, Se) = 0.5080213903743316
f1_score = 0.4672566288726912


In [21]:
# Bootstrap
N = len(y_true)
N_idx = np.arange(N)
K = 1000

auroc = []
auprc = []
minpse = []
f1 = []
for i in range(K):
    boot_idx = np.random.choice(N_idx, N, replace=True)
    boot_true = np.array(y_true)[boot_idx]
    boot_pred = y_pred[boot_idx, :]
    test_ret = metrics.print_metrics_binary(boot_true, boot_pred, verbose=0)
    auroc.append(test_ret['auroc'])
    auprc.append(test_ret['auprc'])
    minpse.append(test_ret['minpse'])
    f1.append(test_ret['f1_score'])
#     print('%d/%d'%(i+1,K))
    
print('auroc %.4f(%.4f)'%(np.mean(auroc), np.std(auroc)))
print('auprc %.4f(%.4f)'%(np.mean(auprc), np.std(auprc)))
print('minpse %.4f(%.4f)'%(np.mean(minpse), np.std(minpse)))
print('f1 %.4f(%.4f)'%(np.mean(f1), np.std(f1)))

auroc 0.8656(0.0094)
auprc 0.5402(0.0276)
minpse 0.5152(0.0229)
f1 0.4681(0.0254)


In [22]:
all = [auroc, auprc, minpse, f1]

### U-Test

In [23]:
class RNN(nn.Module):
    def __init__(self, input_dim=76, hidden_dim=128, output_dim=1, dropout=0.3):
        super(RNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout = dropout
        self.NUM_LAYERS = 2
        self.demo_proj = nn.Linear(10, self.hidden_dim)
        
        self.gru_encoder = nn.GRU(1, self.hidden_dim, self.NUM_LAYERS, batch_first=True, dropout=self.dropout)
#         for param in self.timenet.parameters():
#             param.requires_grad = True
    
        self.nn_output = nn.Linear(self.hidden_dim*(self.input_dim*2+1), self.output_dim)
        self.sigmoid = nn.Sigmoid()
        self.nn_dropout = nn.Dropout(p=dropout)

    def forward(self, input, static):
        batch_size = input.size(0)
        time_step = input.size(1)
        feature_dim = input.size(2)

        gru_atten = None
        
        for i in range(feature_dim):
            
#             tmp_input = pack_padded_sequence(input[:,:,i].unsqueeze(-1), lens, batch_first=True)
            timenet_feature = self.gru_encoder(input[:,:,i].unsqueeze(-1))[1].transpose(0, 1)
            timenet_feature = torch.reshape(timenet_feature, (batch_size, self.hidden_dim*2))
            if gru_atten is None:
                gru_atten = timenet_feature
            else:
                gru_atten = torch.cat((gru_atten, timenet_feature), dim=1)
        static_emb = self.demo_proj(static)
        gru_atten = torch.cat((gru_atten, static_emb), dim=1)
        hn = gru_atten

        if self.dropout > 0.0:
            hn = self.nn_dropout(hn)

        rnn_output = self.nn_output(hn)
        rnn_output = self.sigmoid(rnn_output)

        return rnn_output

    
input_dim = 76
hidden_dim = 32
output_dim = 1

model = RNN(input_dim = input_dim, hidden_dim = hidden_dim, output_dim = output_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [24]:
file_name = './model/mimic_timenet'

batch_size = 256
checkpoint = torch.load(file_name+'roc')#23
save_epoch = checkpoint['epoch']
print("last saved model is in epoch {}".format(save_epoch))
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()


batch_loss = []
y_true = []
y_pred = []
with torch.no_grad():
    model.eval()
    for step, (batch_x, batch_y, batch_name) in enumerate(test_loader):  
        optimizer.zero_grad()
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)

        batch_demo = []
        for i in range(len(batch_name)):
            cur_id, cur_ep, _ = batch_name[i].split('_', 2)
            cur_idx = cur_id + '_' + cur_ep
            cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
            batch_demo.append(cur_demo)

        batch_demo = torch.stack(batch_demo).to(device)
        opt = model(batch_x, batch_demo)
        BCE_Loss = get_loss(opt, batch_y.unsqueeze(-1))
        model_loss =  BCE_Loss 

        loss = model_loss
        batch_loss.append(loss.cpu().detach().numpy())
        if step % 20 == 0:
            print('Batch %d: Test Loss = %.4f'%(step, loss.cpu().detach().numpy()))
        y_pred += list(opt.cpu().detach().numpy().flatten())
        y_true += list(batch_y.cpu().numpy().flatten())

print("\n==>Predicting on test")
print('Test Loss = %.4f'%(np.mean(np.array(batch_loss))))
y_pred = np.array(y_pred)
y_pred = np.stack([1 - y_pred, y_pred], axis=1)
test_res = metrics.print_metrics_binary(y_true, y_pred)

last saved model is in epoch 40
Batch 0: Test Loss = 0.2718

==>Predicting on test
Test Loss = 0.2526
confusion matrix:
[[2809   53]
 [ 246  128]]
accuracy = 0.9076019525527954
precision class 0 = 0.9194762706756592
precision class 1 = 0.7071823477745056
recall class 0 = 0.9814814925193787
recall class 1 = 0.34224599599838257
AUC of ROC = 0.8643585316726272
AUC of PRC = 0.5245101181936102
min(+P, Se) = 0.506426735218509
f1_score = 0.4612612731545188


In [25]:
# Bootstrap
N = len(y_true)
N_idx = np.arange(N)
K = 1000

cmp_auroc = []
cmp_auprc = []
cmp_minpse = []
cmp_f1 = []
for i in range(K):
    boot_idx = np.random.choice(N_idx, N, replace=True)
    boot_true = np.array(y_true)[boot_idx]
    boot_pred = y_pred[boot_idx, :]
    test_ret = metrics.print_metrics_binary(boot_true, boot_pred, verbose=0)
    cmp_auroc.append(test_ret['auroc'])
    cmp_auprc.append(test_ret['auprc'])
    cmp_minpse.append(test_ret['minpse'])
    cmp_f1.append(test_ret['f1_score'])
#     print('%d/%d'%(i+1,K))

cmp_all = [cmp_auroc, cmp_auprc, cmp_minpse, cmp_f1]
print('auroc %.4f(%.4f)'%(np.mean(cmp_auroc), np.std(cmp_auroc)))
print('auprc %.4f(%.4f)'%(np.mean(cmp_auprc), np.std(cmp_auprc)))
print('minpse %.4f(%.4f)'%(np.mean(cmp_minpse), np.std(cmp_minpse)))
print('f1 %.4f(%.4f)'%(np.mean(cmp_f1), np.std(cmp_f1)))

auroc 0.8642(0.0093)
auprc 0.5245(0.0273)
minpse 0.5083(0.0224)
f1 0.4609(0.0253)


In [31]:
from prettytable import PrettyTable
from scipy.stats import ranksums
table = PrettyTable(['feature','t','Mean in SAFARI','Mean in Best Baseline','p-value','p<0.05','p<0.01'])

col_name = ['AUROC', 'AUPRC', 'minPSE']
for idx in range(3):
    posi_li = list(all[idx])
    nega_li = list(cmp_all[idx])
    
    
    total_li = []
    total_li.extend(posi_li)
    total_li.extend(nega_li)
    total_li.sort()
#     print(total_li)
    s_cut = total_li[int(len(total_li)*0.05)]
    b_cut = total_li[int(len(total_li)*0.95)]
#     print(s_cut)
    t,p = ranksums(posi_li, nega_li)
#     t2,p = mannwhitneyu(nega_li, posi_li)
    flag = ""
    if p < 0.05:
        flag = "Y"
    flag2 = ""
    if p < 0.01:
        flag2 = "Y"
        
#     flag3 = ""
#     if p < 0.1:
#         flag3 = "Y"
        
    rela1 = ''
    rela2 = ''
    if t > 0:
        rela1 = "Y"
    else:
        rela2 = "Y"
    
    table.add_row([col_name[idx],round(t,3), round(np.mean(posi_li),3),round(np.mean(nega_li),3),round(p,3),flag,flag2])
#     break
# table.align["feature"] = "r"  

print(table)

+---------+-------+----------------+-----------------------+---------+--------+--------+
| feature |   t   | Mean in SAFARI | Mean in Best Baseline | p-value | p<0.05 | p<0.01 |
+---------+-------+----------------+-----------------------+---------+--------+--------+
|  AUROC  | 3.276 |     0.866      |         0.864         |  0.001  |   Y    |   Y    |
|  AUPRC  | 12.28 |      0.54      |         0.524         |   0.0   |   Y    |   Y    |
|  minPSE | 6.518 |     0.515      |         0.508         |   0.0   |   Y    |   Y    |
+---------+-------+----------------+-----------------------+---------+--------+--------+
