In [None]:
import numpy as np
import pandas as pd
import os
import math

from sas7bdat import SAS7BDAT

pd.set_option('display.max_rows', 300, 'display.max_columns', 300)
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import adabound
from pycox.evaluation.concordance import concordance_td
from sklearn.metrics import accuracy_score, roc_auc_score
from lifelines import KaplanMeierFitter, NelsonAalenFitter, AalenJohansenFitter
torch.cuda.is_available()

In [None]:
class dynamicdeephit(nn.Module):
    """Recurrent survival network with temporal attention and two cause-specific heads.

    A GRU encodes every longitudinal sequence except its last measurement. The last
    measurement attends over the GRU outputs to form a context vector; the pair
    (last measurement, context) is concatenated with the fixed covariates and fed to
    two cause-specific sub-networks whose outputs are jointly soft-maxed.

    Args:
        fixed_size: number of time-invariant covariates.
        timevar_size: number of time-varying covariates per time step.
        ntime: number of time steps in the padded longitudinal input.
        state_size: GRU hidden size.
        num_layers: number of stacked GRU layers.
        atten_size: hidden width of the attention scorer.
        cs_size: hidden width of each cause-specific sub-network.
    """

    def __init__(self, fixed_size, timevar_size, ntime, state_size, num_layers=1, atten_size=128, cs_size=128):
        super(dynamicdeephit, self).__init__()

        self.fixed_size = fixed_size
        self.timevar_size = timevar_size
        self.ntime = ntime
        self.state_size = state_size
        self.atten_size = atten_size
        self.num_layers = num_layers

        self.cs_size = cs_size

        # Layer attribute names are kept as-is so previously saved state_dicts still load.
        self.gru = nn.GRU(input_size = self.timevar_size,
                          hidden_size = self.state_size,
                          num_layers = self.num_layers,
                          batch_first = True,
                         )
        # Maps each GRU output to a per-step prediction of the time-varying covariates.
        self.linear = nn.Linear(self.state_size, self.timevar_size)
        # Two-layer attention scorer over [last measurement, encoder output].
        self.attention1 = nn.Linear(self.timevar_size + self.state_size, self.atten_size)
        self.attention2 = nn.Linear(self.atten_size, 1)

        # One two-layer head per cause; each emits ntime + 2 logits.
        self.csnet1_1 = nn.Linear(self.timevar_size + self.state_size + self.fixed_size, self.cs_size)
        self.csnet1_2 = nn.Linear(self.cs_size, self.ntime + 2)

        self.csnet2_1 = nn.Linear(self.timevar_size + self.state_size + self.fixed_size, self.cs_size)
        self.csnet2_2 = nn.Linear(self.cs_size, self.ntime + 2)

        self.activation = nn.LeakyReLU(inplace=True)

    def forward(self, x, input_length):
        """Run the shared encoder/attention network and the cause-specific heads.

        Args:
            x: tuple ``(fixed, timevar)`` with ``fixed`` of shape (batch, fixed_size)
                and ``timevar`` of shape (batch, ntime, timevar_size).
            input_length: (batch,) integer tensor with the number of observed time
                steps per sample. The batch must be sorted by descending
                ``input_length``: ``pack_padded_sequence`` requires it, and the rows
                with more than one step are concatenated in front of the single-step
                rows before being joined with ``fixed``.

        Returns:
            pred: (batch, 2 * (ntime + 2)) probabilities, soft-maxed jointly over both
                causes and all time bins.
            output_timevar: (n_multiple, ntime - 1, timevar_size) sigmoid outputs of the
                per-step covariate predictor, one row per sample with more than one
                observed step (in batch order).
        """
        fixed, timevar = x
        # Follow the input's device instead of relying on a notebook-level global.
        device = timevar.device

        input_length = torch.as_tensor(input_length)
        is_multiple = input_length > 1
        is_single = input_length == 1

        ## SHARED NET

        timevar_multiple = timevar[is_multiple.to(device)]
        timevar_single = timevar[is_single.to(device)][:, 0, :]

        n_multiple = timevar_multiple.shape[0]
        n_single = timevar_single.shape[0]

        # The encoder sees every step but the last one; lengths must live on the CPU
        # for pack_padded_sequence.
        encoder_lengths = (input_length[is_multiple] - 1).to(device='cpu', dtype=torch.int64)

        if n_multiple > 0:
            last_index = encoder_lengths.to(device)
            rows = torch.arange(n_multiple, device=device)
            # Last observed measurement of every sequence: (n_multiple, timevar_size)
            timevar_multiple_last = timevar_multiple[rows, last_index]

            packed_timevar = nn.utils.rnn.pack_padded_sequence(timevar_multiple, encoder_lengths, batch_first=True)
            h0 = timevar.new_zeros(self.num_layers, n_multiple, self.state_size)

            output, _ = self.gru(packed_timevar, h0)
            # Pad back to a fixed ntime - 1 steps so the per-step predictions line up
            # with timevar[:, 1:, :] whatever the longest sequence in the batch is.
            output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True, total_length=self.ntime - 1)
            output_timevar = torch.sigmoid(self.linear(output))

            # Attention scores for all samples and steps at once: (n_multiple, steps, 1)
            steps = output.shape[1]
            decoder_input = timevar_multiple_last.unsqueeze(1).expand(-1, steps, -1)
            scores = self.attention1(torch.cat([decoder_input, output], dim=2))
            scores = self.attention2(self.activation(scores))
            scores = self.activation(scores)

            # Padded steps get -inf so the softmax only spreads weight over real steps;
            # every row has at least one real step because encoder_lengths >= 1.
            is_real_step = torch.arange(steps, device=device).unsqueeze(0) < last_index.unsqueeze(1)
            scores = scores.masked_fill(~is_real_step.unsqueeze(2), float('-inf'))
            weights = torch.softmax(scores, dim=1)
            context = torch.sum(weights * output, dim=1)  # (n_multiple, state_size)

            shared_multiple = torch.cat([timevar_multiple_last, context], 1)
        else:
            # No sequence has history to encode: emit empty tensors of the right width.
            output_timevar = timevar.new_zeros(0, max(self.ntime - 1, 0), self.timevar_size)
            shared_multiple = timevar.new_zeros(0, self.timevar_size + self.state_size)

        shared_parts = [shared_multiple]
        if n_single > 0:
            # Single-measurement samples have no history, so their context is all zeros.
            empty_context = timevar.new_zeros(n_single, self.state_size)
            shared_parts.append(torch.cat([timevar_single, empty_context], 1))

        shared_output = torch.cat(shared_parts, 0)
        shared_output = torch.cat([shared_output, fixed], 1)

        ## CAUSE-SPECIFIC NET

        pred1 = self.csnet1_1(shared_output)
        pred1 = self.csnet1_2(self.activation(pred1))

        pred2 = self.csnet2_1(shared_output)
        pred2 = self.csnet2_2(self.activation(pred2))

        # Joint softmax: probabilities sum to one over both causes and all time bins.
        pred = torch.cat([pred1, pred2], 1)
        pred = torch.softmax(pred, dim=1)

        return pred, output_timevar

In [None]:
def f_get_fc_mask1(meas_time, num_Event, num_Category):
    '''
        Build the mask used for the denominator of the conditional probability.

        meas_time is indexed as meas_time[i, 0] and holds the last measurement time of
        sample i. The returned float tensor has size [N, num_Event, num_Category] and,
        for every event, contains 1's in the first int(meas_time[i, 0] + 1) time bins
        (i.e. up to and including the last measurement time) and 0's elsewhere.
    '''
    n_samples = np.shape(meas_time)[0]
    mask = torch.zeros(n_samples, num_Event, num_Category)

    # Exclusive end of the filled range for every sample.
    cutoffs = [int(meas_time[row, 0] + 1) for row in range(n_samples)]
    for row, cutoff in enumerate(cutoffs):
        mask[row, :, :cutoff] = 1

    return mask


def f_get_fc_mask2(time, label, num_Event, num_Category):
    '''
        Build the mask used for the log-likelihood loss.

        time[i, 0] is the event/censoring time and label[i, 0] the event label of
        sample i (0 means censored). The returned float tensor has size
        [N, num_Event, num_Category]:
            if not censored : a single 1 at [label - 1, time] (0 elsewhere)
            if censored     : 1's in every time bin after the censoring time, for all events
    '''
    n_samples = np.shape(time)[0]
    mask = torch.zeros(n_samples, num_Event, num_Category)

    for row in range(n_samples):
        cause = label[row, 0]
        if cause == 0:
            # censored: keep all probability mass strictly after the censoring time
            mask[row, :, int(time[row, 0] + 1):] = 1
            continue
        # observed event: pick the bin of that cause at the event time
        mask[row, int(cause - 1), int(time[row, 0])] = 1

    return mask


def f_get_fc_mask3(time, meas_time, num_Category):
    '''
        Build the mask used by the ranking loss (pair-wise comparison).

        The returned float tensor has size [N, num_Category].
        - When meas_time has a non-empty shape (longitudinal measurements):
             1's after the last measurement time meas_time[i, 0] (exclusive) up to the
             event/censoring time time[i, 0] (inclusive)
        - Otherwise (single measurement, e.g. a scalar meas_time):
             1's from the first time bin up to the event/censoring time (inclusive)
    '''
    n_samples = np.shape(time)[0]
    mask = torch.zeros(n_samples, num_Category)

    # An empty shape tuple (scalar meas_time) selects the single-measurement layout.
    longitudinal = bool(np.shape(meas_time))

    for row in range(n_samples):
        # the window starts right after the last measurement, or at bin 0 without one
        start = int(meas_time[row, 0]) + 1 if longitudinal else 0
        stop = int(time[row, 0]) + 1  # exclusive end, so the event time itself is included
        mask[row, start:stop] = 1

    return mask

In [None]:
def _masked_mean(values, mask):
    """Return sum(values * mask) / sum(mask), or 0 when the mask selects nothing."""
    weight = mask.sum()
    if weight.item() == 0:
        # Avoids 0 / 0 = NaN when nothing is observed or the slice is empty.
        return values.new_zeros(())
    return torch.sum(values * mask) / weight


def custom_loss(pmf, timevar, input_length, obs_mask, long_pred, event, time, mask1, mask2, mask3, eps=1.e-9, alpha=0.1, beta=0.1, num_event=2, num_binary=19):
    """Compute the three loss terms: log-likelihood, ranking and longitudinal prediction.

    Args:
        pmf: (N, num_event, num_Category) predicted probability mass.
        timevar: (N, ntime, n_features) time-varying covariates.
        input_length: (N,) number of observed time steps per sample; only samples with
            more than one step contribute to the third term.
        obs_mask: same layout as ``timevar``; weights each entry in the third term
            (presumably 1 = observed, 0 = missing - confirm against the data pipeline).
        long_pred: (n_multiple, ntime - 1, n_features) per-step covariate predictions,
            compared against ``timevar[input_length > 1, 1:, :]``.
        event: (N, 1) event label, 0 for censored and e + 1 for cause e.
        time: (N, 1) event/censoring time.
        mask1: (N, num_event, num_Category) mask for the conditioning denominator.
        mask2: (N, num_event, num_Category) mask for the likelihood numerator.
        mask3: (N, num_Category) mask for the ranking loss.
        eps: numerical floor used inside logs and for clamping the denominator.
        alpha: weight of the ranking loss.
        beta: weight of the longitudinal prediction loss.
        num_event: number of competing events.
        num_binary: the first ``num_binary`` feature columns are scored with binary
            cross-entropy, the remaining ones with squared error.

    Returns:
        Tuple ``(loss_1, alpha * loss_2, beta * loss_3)`` of scalar tensors.
    """

    ## loss 1: log-likelihood of the conditional distribution
    I_1 = torch.sign(event)

    denom = 1 - torch.sum(torch.sum(mask1 * pmf, dim=2), dim=1, keepdim=True)
    denom = torch.clamp(denom, eps, 1 - eps)

    # mask2 already selects the event bin (uncensored) or the post-censoring mass
    # (censored), so both cases share the same log term.
    log_lik = torch.log(torch.sum(torch.sum(mask2 * pmf, dim=2), dim=1, keepdim=True) / denom + eps)
    logpdf = I_1 * log_lik
    logsurv = (1. - I_1) * log_lik

    loss_1 = - torch.mean(logpdf + 1.0*logsurv)

    ## loss 2: pair-wise ranking loss
    sigma1 = 0.1

    ## T_{ij}=1 if t_i < t_j  and T_{ij}=0 if t_i >= t_j
    time_col = time.reshape(-1, 1)
    earlier = (torch.t(time_col) > time_col).to(pmf.dtype)

    eta = []
    for e in range(num_event):
        # Row indicator: sample i experienced event e + 1. Scaling rows by it is the
        # same as left-multiplying by diag(I_2) and also works for a batch of one.
        I_2 = (event == e+1).to(pmf.dtype).reshape(-1, 1)

        R = torch.matmul(pmf[:,e,:], torch.t(mask3))
        ## R_{ij} = risk of i-th pat based on j-th time-condition (last meas. time ~ event time) , i.e. R_i(T_{j})

        diag_R = torch.reshape(torch.diag(R), [-1, 1])
        ## R_{ij} = R_i(T_i) - R_j(T_i)
        R = diag_R - torch.t(R)

        T = I_2 * earlier

        tmp_eta = torch.mean(T * torch.exp(-R/sigma1), dim=1, keepdim=True)

        eta.append(tmp_eta)

    eta = torch.cat(eta, dim=1)  # (N, num_event)
    eta = torch.mean(eta, dim=1, keepdim=True)

    loss_2 = torch.sum(eta)

    ## loss 3: prediction of the next longitudinal measurement

    is_multiple = input_length > 1
    timevar_multiple = timevar[is_multiple,1:,:]
    observed = obs_mask[is_multiple,1:,:]

    # Cross entropy loss (for categorical)
    target_bin = timevar_multiple[:,:,:num_binary]
    pred_bin = long_pred[:,:,:num_binary]
    cross_entropy = -(target_bin*torch.log(pred_bin + eps) + (1 - target_bin)*torch.log(1 - pred_bin + eps))
    loss_3 = _masked_mean(cross_entropy, observed[:,:,:num_binary])
    # Mean square loss (for continuous)
    squared_error = (timevar_multiple[:,:,num_binary:] - long_pred[:,:,num_binary:])**2
    loss_3 = loss_3 + _masked_mean(squared_error, observed[:,:,num_binary:])

    return loss_1, alpha * loss_2, beta * loss_3

In [None]:
# Hyper-parameter grid search. For every configuration: train with early stopping on the
# mean training loss, plot the loss curves, reload the best checkpoint and print the
# time-dependent concordance and Brier scores on the train and test sets.
# This cell relies on names defined outside of it (device, ntime, train_loader,
# train_max_timevar and the *_train / *_test arrays) - they must exist in the kernel.
for weight_decay in [1e-3]:
    for state_size in [128, 256]:
        for num_layers in [1, 2]:
            for atten_size in [128, 256]:
                for cs_size in [128, 256]:

                    path = './DynamicDeepHit_{}statesize_{}numlayers_{}attensize_{}cssize_{:.0e}'.format(state_size, num_layers, atten_size, cs_size, weight_decay)
                    #if os.path.isfile(path):
                    #    continue

                    model = dynamicdeephit(fixed_train.shape[-1], timevar_compact_train.shape[-1], ntime, state_size, num_layers, atten_size, cs_size).to(device)
                    #if os.path.isfile(path):
                    #    model.load_state_dict(torch.load(path, map_location = device))

                    lr = 1e-2
                    # The optimizer's own weight decay is disabled; `weight_decay` scales the
                    # explicit L1 penalty added to the loss below.
                    optimizer = adabound.AdaBound(model.parameters(), lr=lr, final_lr=0.1, weight_decay=0)

                    loss_array = []
                    loss_array1 = []
                    loss_array2 = []
                    loss_array3 = []
                    patience = 0
                    min_loss = np.inf
                    for e in range(int(1e6)):

                        loss1_array_tmp = []
                        loss2_array_tmp = []
                        loss3_array_tmp = []

                        for fixed_batch, timevar_compact_batch, event_batch, time_batch, obs_mask_batch in train_loader:

                            # Sequence length recovered from the last time-varying column.
                            # NOTE(review): .int() truncates; if the rescaled sum can land just
                            # below an integer (float rounding) a length would be off by one - confirm.
                            input_lengths_batch = ((timevar_compact_batch[:,:,-1] * train_max_timevar[0,0,-1]).sum(1) + 1).int()
                            # The model needs the batch sorted by descending sequence length.
                            input_lengths_batch, sorted_idx = input_lengths_batch.sort(0, descending=True)

                            fixed_batch_sort = fixed_batch[sorted_idx].float()
                            timevar_compact_batch_sort = timevar_compact_batch[sorted_idx].float()
                            event_batch_sort = event_batch[sorted_idx].float()
                            time_batch_sort = time_batch[sorted_idx].float()
                            obs_mask_batch_sort = obs_mask_batch[sorted_idx].float()

                            mask1_batch_sort = f_get_fc_mask1(input_lengths_batch.reshape(-1, 1) - 1, num_Event=2, num_Category=ntime+2)
                            mask2_batch_sort = f_get_fc_mask2(time_batch_sort.reshape(-1, 1), event_batch_sort.reshape(-1, 1), num_Event=2, num_Category=ntime+2)
                            mask3_batch_sort = f_get_fc_mask3(time_batch_sort.reshape(-1, 1), input_lengths_batch.reshape(-1, 1) - 1, ntime+2)

                            y_pred, long_pred = model((fixed_batch_sort.to(device), timevar_compact_batch_sort.to(device)), input_lengths_batch)

                            # L1 norm of all parameters (explicit regulariser).
                            norm = 0.
                            for parameter in model.parameters():
                                norm += torch.norm(parameter, p=1)

                            loss1, loss2, loss3 = custom_loss(pmf = y_pred.reshape(-1, 2, ntime+2), 
                                                              timevar = timevar_compact_batch_sort.to(device),
                                                              input_length = input_lengths_batch.to(device),
                                                              obs_mask = obs_mask_batch_sort.to(device),
                                                              long_pred = long_pred.to(device), 
                                                              event = event_batch_sort.reshape(-1, 1).to(device), 
                                                              time = time_batch_sort.reshape(-1, 1).to(device),
                                                              mask1 = mask1_batch_sort.to(device), 
                                                              mask2 = mask2_batch_sort.to(device), 
                                                              mask3 = mask3_batch_sort.to(device))

                            loss = loss1 + loss2 + loss3 + weight_decay*norm
                            loss1_array_tmp.append(loss1.item())
                            loss2_array_tmp.append(loss2.item())
                            loss3_array_tmp.append(loss3.item())

                            model.zero_grad()

                            loss.backward()

                            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                            optimizer.step()

                        # Epoch-level means; the tracked total excludes the L1 penalty.
                        loss_array.append(np.mean(loss1_array_tmp) + np.mean(loss2_array_tmp) + np.mean(loss3_array_tmp))
                        loss_array1.append(np.mean(loss1_array_tmp))
                        loss_array2.append(np.mean(loss2_array_tmp))
                        loss_array3.append(np.mean(loss3_array_tmp))
                        if e % 100 == 0:
                            print('Epoch: ' + str(e) + 
                                  ', TotalLoss: '+ f'{loss_array[-1]:.4e}' +
                                  ', Loss1: '+ f'{loss_array1[-1]:.4e}' + 
                                  ', Loss2: '+ f'{loss_array2[-1]:.4e}' + 
                                  ', Loss3: '+ f'{loss_array3[-1]:.4e}')
                        # Early stopping: checkpoint on every improvement of the training loss,
                        # stop after more than 1000 epochs without one.
                        if min_loss > loss_array[-1]:
                            patience = 0
                            min_loss = loss_array[-1]
                            torch.save(model.state_dict(), path)
                        else:
                            patience += 1

                        torch.cuda.empty_cache()

                        if patience > 1000:
                            break

                    plt.plot(loss_array, label='Total Loss')
                    plt.plot(loss_array1, label='Loss 1')
                    plt.plot(loss_array2, label='Loss 2')
                    plt.plot(loss_array3, label='Loss 3')
                    plt.ylabel('loss')
                    plt.xlabel('epoch')
                    #plt.yscale('log')
                    plt.title(path[2:])
                    plt.legend()
                    plt.show()

                    # Evaluate the best checkpoint, not the last epoch.
                    model.load_state_dict(torch.load(path, map_location = device))
                    model.eval()

                    print(path[9:])

                    # The predictions do not depend on the evaluation horizon, so they are
                    # computed once, without building an autograd graph over the full data sets.
                    with torch.no_grad():
                        input_lengths_train = torch.IntTensor((timevar_compact_train[:,:,-1] * train_max_timevar[0,0,-1]).sum(1) + 1)
                        input_lengths_train, sorted_idx = input_lengths_train.sort(0, descending=True)

                        fixed_train_sort = torch.tensor(fixed_train)[sorted_idx].float()
                        timevar_compact_train_sort = torch.tensor(timevar_compact_train)[sorted_idx].float()
                        event_train_sort = torch.tensor(event_train)[sorted_idx].float()
                        time_train_sort = torch.tensor(time_train)[sorted_idx].float()
                        obs_mask_compact_train_sort = torch.tensor(obs_mask_compact_train)[sorted_idx].float()

                        input_lengths_test = torch.IntTensor((timevar_compact_test[:,:,-1] * train_max_timevar[0,0,-1]).sum(1) + 1)
                        input_lengths_test, sorted_idx = input_lengths_test.sort(0, descending=True)

                        fixed_test_sort = torch.tensor(fixed_test)[sorted_idx].float()
                        timevar_compact_test_sort = torch.tensor(timevar_compact_test)[sorted_idx].float()
                        event_test_sort = torch.tensor(event_test)[sorted_idx].float()
                        time_test_sort = torch.tensor(time_test)[sorted_idx].float()
                        obs_mask_compact_test_sort = torch.tensor(obs_mask_compact_test)[sorted_idx].float()

                        y_train, _ = model((fixed_train_sort.to(device), timevar_compact_train_sort.to(device)), input_lengths_train)
                        y_test, _ = model((fixed_test_sort.to(device), timevar_compact_test_sort.to(device)), input_lengths_test)

                        mask1_train_sort = f_get_fc_mask1(input_lengths_train.reshape(-1, 1) - 1, num_Event=2, num_Category=ntime+2).to(device)
                        mask1_test_sort = f_get_fc_mask1(input_lengths_test.reshape(-1, 1) - 1, num_Event=2, num_Category=ntime+2).to(device)

                        # Cumulative incidence: accumulate only the mass after the last measurement.
                        CIF_train = torch.cumsum(y_train.reshape(-1,2,ntime+2) * (1-mask1_train_sort), 2)
                        CIF_test = torch.cumsum(y_test.reshape(-1,2,ntime+2) * (1-mask1_test_sort), 2)

                    time_train_np = time_train_sort.detach().cpu().numpy()
                    event_train_np = event_train_sort.detach().cpu().numpy()
                    time_test_np = time_test_sort.detach().cpu().numpy()
                    event_test_np = event_test_sort.detach().cpu().numpy()

                    for delt in [1, 3, 6, 9, 12]:

                        print('****************************************')
                        print('Delta t = {}'.format(delt))
                        print('****************************************')

                        # Administrative censoring at the horizon: times are capped at delt - 1
                        # and events occurring later are relabelled as censored (0).
                        time_train_trunc = np.int32(np.where(time_train_np > delt-1, delt-1, time_train_np))
                        time_test_trunc = np.int32(np.where(time_test_np > delt-1, delt-1, time_test_np))

                        event_train_trunc = np.int32(np.where(time_train_np > delt-1, 0, event_train_np))
                        event_test_trunc = np.int32(np.where(time_test_np > delt-1, 0, event_test_np))

                        print('train ctd for CVD: {:.4f}'.format(concordance_td(time_train_trunc, event_train_trunc==1, 1.-CIF_train[:,0,:].detach().cpu().numpy().T, time_train_trunc)))
                        print('train ctd for Death: {:.4f}'.format(concordance_td(time_train_trunc, event_train_trunc==2, 1.-CIF_train[:,1,:].detach().cpu().numpy().T, time_train_trunc)))

                        print('test ctd for CVD: {:.4f}'.format(concordance_td(time_test_trunc, event_test_trunc==1, 1.-CIF_test[:,0,:].detach().cpu().numpy().T, time_test_trunc)))
                        print('test ctd for Death: {:.4f}'.format(concordance_td(time_test_trunc, event_test_trunc==2, 1.-CIF_test[:,1,:].detach().cpu().numpy().T, time_test_trunc)) + '\n')

                        # Expected (predicted CIF at the truncated time) vs observed event indicator.
                        E_CVD_train = np.array([CIF_train[i,0,j].item() for (i,j) in zip(range(len(time_train_trunc)), time_train_trunc)])
                        O_CVD_train = event_train_trunc==1
                        E_Death_train = np.array([CIF_train[i,1,j].item() for (i,j) in zip(range(len(time_train_trunc)), time_train_trunc)])
                        O_Death_train = event_train_trunc==2

                        E_CVD_test = np.array([CIF_test[i,0,j].item() for (i,j) in zip(range(len(time_test_trunc)), time_test_trunc)])
                        O_CVD_test = event_test_trunc==1
                        E_Death_test = np.array([CIF_test[i,1,j].item() for (i,j) in zip(range(len(time_test_trunc)), time_test_trunc)])
                        O_Death_test = event_test_trunc==2

                        print('train Brier for CVD: {:.4f}'.format(((E_CVD_train - O_CVD_train)**2).mean()))
                        print('train Brier for Death: {:.4f}'.format(((E_Death_train - O_Death_train)**2).mean()))

                        print('test Brier for CVD: {:.4f}'.format(((E_CVD_test - O_CVD_test)**2).mean()))
                        print('test Brier for Death: {:.4f}'.format(((E_Death_test - O_Death_test)**2).mean()) + '\n')