In [None]:
import numpy as np
import pandas as pd
import os
import math

from sas7bdat import SAS7BDAT
import datetime

pd.set_option('display.max_rows', 300, 'display.max_columns', 300)
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import adabound
from pycox.evaluation.concordance import concordance_td
from sklearn.metrics import accuracy_score, roc_auc_score
from lifelines import KaplanMeierFitter, NelsonAalenFitter, AalenJohansenFitter

# Report whether PyTorch can see a CUDA device.
# NOTE(review): the returned flag is not stored here; `device`, which later cells use,
# is not defined in the visible part of the notebook - confirm where it is set.
torch.cuda.is_available()

# Cause-specific prediction (single event)

In [None]:
def f_get_fc_mask1(meas_time, num_Event, num_Category):
    '''
        Conditioning mask used for the denominator of the conditional probability.

        Returns a tensor of size [N, num_Event, num_Category], N being the number of
        rows of meas_time. For row i every time bin up to and including the last
        measurement time meas_time[i, 0] holds 1 (for all events); all later bins hold 0.
    '''
    n_rows = np.shape(meas_time)[0]
    fc_mask = torch.zeros(n_rows, num_Event, num_Category)
    for row in range(n_rows):
        # exclusive slice end, so the last-measurement bin itself is included
        stop = int(meas_time[row, 0] + 1)
        fc_mask[row, :, :stop] = 1

    return fc_mask


def f_get_fc_mask2(time, label, num_Event, num_Category):
    '''
        Selection mask for the log-likelihood loss.

        Returns a tensor of size [N, num_Event, num_Category]:
            - label[i, 0] != 0 (event observed): a single 1 at
              [label[i, 0] - 1, time[i, 0]], zeros elsewhere
            - label[i, 0] == 0 (censored): 1's in every bin strictly after
              time[i, 0], for all events
    '''
    n_rows = np.shape(time)[0]
    fc_mask = torch.zeros(n_rows, num_Event, num_Category)
    for row in range(n_rows):
        if label[row, 0] == 0:
            # censored: keep the probability mass lying after the censoring time
            first_open_bin = int(time[row, 0] + 1)
            fc_mask[row, :, first_open_bin:] = 1
        else:
            # observed event: pick the (event, time) cell
            event_idx = int(label[row, 0] - 1)
            time_idx = int(time[row, 0])
            fc_mask[row, event_idx, time_idx] = 1
    return fc_mask


def f_get_fc_mask3(time, meas_time, num_Category):
    '''
        Window mask for the ranking loss (pair-wise comparison).

        Returns a tensor of size [N, num_Category].
        - meas_time has a non-empty shape (longitudinal measurements):
             1's from the bin after the last measurement time up to and including
             the event/censoring time; no denominator is needed because both sides
             of a comparison share it
        - meas_time is a scalar / shapeless value (single measurement):
             1's from bin 0 up to and including the event/censoring time
    '''
    n_rows = np.shape(time)[0]
    fc_mask = torch.zeros(n_rows, num_Category)
    longitudinal = bool(np.shape(meas_time))
    for row in range(n_rows):
        # window start: just after the last measurement, or the very first bin
        begin = int(meas_time[row, 0]) + 1 if longitudinal else 0
        # window end (exclusive), so the event/censoring bin is included
        end = int(time[row, 0]) + 1
        fc_mask[row, begin:end] = 1
    return fc_mask

In [None]:
def custom_loss(pmf, event, time, mask1, mask2, mask3, eps=1.e-9, alpha=0.1, num_event=1, sigma1=0.1):
    '''
        DeepHit-style loss: conditional log-likelihood plus a pair-wise ranking term.

        Parameters
            pmf       : [N, num_event, num_Category] predicted joint probability mass
            event     : [N, 1] event label, 0 = censored, k >= 1 = event k
            time      : [N, 1] event / censoring time
            mask1     : [N, num_event, num_Category] mask for the conditioning denominator
            mask2     : [N, num_event, num_Category] mask selecting the likelihood mass
            mask3     : [N, num_Category] time-window mask for the ranking term
            eps       : numerical floor used in the clamp and inside the log
            alpha     : weight applied to the ranking term
            num_event : number of competing events (second axis of pmf)
            sigma1    : temperature of the exponential ranking penalty

        Returns
            (loss_1, alpha * loss_2) : negative mean log-likelihood and the weighted
            ranking loss, both 0-dim tensors.
    '''

    ## loss 1: log-likelihood
    I_1 = torch.sign(event)

    denom = 1 - torch.sum(torch.sum(mask1 * pmf, dim=2), dim=1, keepdim=True)
    denom = torch.clamp(denom, eps, 1 - eps)

    # mask2 already selects the right mass for both cases (a single cell for an observed
    # event, the tail after the censoring time otherwise), so the log term is shared.
    log_mass = torch.log(torch.sum(torch.sum(mask2 * pmf, dim=2), dim=1, keepdim=True) / denom + eps)
    logpdf = I_1 * log_mass
    logsurv = (1. - I_1) * log_mass

    loss_1 = - torch.mean(logpdf + 1.0*logsurv)

    ## loss 2: ranking
    one_vector = torch.ones_like(event)

    # T_{ij}=1 if t_i < t_j  and T_{ij}=0 if t_i >= t_j (does not depend on the event index)
    T_all = torch.relu(torch.sign(torch.matmul(one_vector, torch.t(time)) - torch.matmul(time, torch.t(one_vector))))

    eta = []
    for e in range(num_event):
        # reshape(-1) instead of squeeze(): squeeze() turns a batch of one sample into a
        # 0-dim tensor, which torch.diag rejects.
        I_2 = torch.diag((event == e+1).to(T_all.dtype).reshape(-1))

        R = torch.matmul(pmf[:,e,:], torch.t(mask3))
        ## R_{ij} = risk of i-th pat based on j-th time-condition (last meas. time ~ event time) , i.e. R_i(T_{j})

        diag_R = torch.reshape(torch.diag(R), [-1, 1])
        R = torch.matmul(one_vector, torch.t(diag_R)) - R   # R_{ij} = R_j(T_j) - R_i(T_j)
        R = torch.t(R)                                      # R_{ij} = R_i(T_i) - R_j(T_i)

        # keep only rows of subjects that experienced event e+1
        T = torch.matmul(I_2, T_all)

        tmp_eta = torch.mean(T * torch.exp(-R/sigma1), dim=1, keepdim=True)

        eta.append(tmp_eta)

    eta = torch.stack(eta, dim=1)
    eta = torch.mean(torch.reshape(eta, [-1, num_event]), dim=1, keepdim=True)

    loss_2 = torch.sum(eta)

    return loss_1, alpha * loss_2

In [None]:
class deephitCS(nn.Module):
    """Single-event network: one hidden layer followed by a softmax over time bins.

    The output has ``ntime + 2`` columns per sample and each row sums to 1.
    """

    def __init__(self, input_size, ntime, shared_size=128):
        super().__init__()

        self.input_size, self.ntime, self.shared_size = input_size, ntime, shared_size

        # Layers are created in the order hidden -> output so that parameter
        # initialisation and the state-dict keys (shared.0.*, shared.2.*) stay the same.
        hidden_layer = nn.Linear(self.input_size, self.shared_size)
        activation = nn.LeakyReLU(inplace=True)
        output_layer = nn.Linear(self.shared_size, self.ntime + 2)
        self.shared = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x0):
        ## x0: (batch, input_size) -> (batch, ntime+2) probabilities
        logits = self.shared(x0)
        return torch.softmax(logits, dim=1)

In [None]:
# Train and evaluate the single-event (cause-specific) model for each hidden-layer width.
# This cell relies on names defined outside it: total_train / total_test, event_train /
# event_test, time_train / time_test, timevar_compact_train / timevar_compact_test,
# obs_mask_compact_train / obs_mask_compact_test, ntime, device and train_loader.
for weight_decay in [1e-3]:
    for shared_size in [16, 32, 64, 128, 256]:

        # Checkpoint file holding the weights with the lowest training loss so far.
        path = './DeepHit_cs_{}statesize_{:.0e}'.format(shared_size, weight_decay)
        #if os.path.isfile(path):
        #    continue

        model = deephitCS(input_size=total_train.shape[-1], ntime=ntime, shared_size=shared_size).to(device)
        #if os.path.isfile(path):
        #    model.load_state_dict(torch.load(path, map_location = device))

        lr = 1e-3
        # The optimizer's own weight decay is switched off; an explicit L1 penalty scaled
        # by `weight_decay` is added to the loss below instead.
        optimizer = adabound.AdaBound(model.parameters(), lr=lr, weight_decay=0)

        loss_array = []
        loss_array1 = []
        loss_array2 = []
        patience = 0
        min_loss = np.inf
        for e in range(int(1e6)):

            loss1_array_tmp = []
            loss2_array_tmp = []

            for total_batch, timevar_compact_batch, event_batch, time_batch, obs_mask_batch in train_loader:

                # NOTE(review): presumably the last channel of timevar_compact flags observed
                # visits, so this is the index of the last measurement + 1 - confirm.
                input_lengths_batch = (timevar_compact_batch[:,:,-1].sum(1) + 1).int()

                total_batch = total_batch.float()
                event_batch = event_batch.float()
                time_batch = time_batch.float()
                obs_mask_batch = obs_mask_batch.float()

                mask1_batch = f_get_fc_mask1(input_lengths_batch.reshape(-1, 1) - 1, num_Event=1, num_Category=ntime+2)
                mask2_batch = f_get_fc_mask2(time_batch.reshape(-1, 1), event_batch.reshape(-1, 1), num_Event=1, num_Category=ntime+2)
                mask3_batch = f_get_fc_mask3(time_batch.reshape(-1, 1), input_lengths_batch.reshape(-1, 1) - 1, ntime+2)

                y_pred = model(total_batch.to(device))

                # L1 norm of all parameters (regularisation term)
                norm = 0.
                for parameter in model.parameters():
                    norm += torch.norm(parameter, p=1)

                # This model has exactly one event, so the event axis is fixed to 1 here,
                # matching the masks above, instead of depending on a global `num_event`.
                loss1, loss2 = custom_loss(pmf = y_pred.reshape(-1, 1, ntime+2),
                                           event = event_batch.reshape(-1, 1).to(device),
                                           time = time_batch.reshape(-1, 1).to(device),
                                           mask1 = mask1_batch.to(device),
                                           mask2 = mask2_batch.to(device),
                                           mask3 = mask3_batch.to(device),
                                           num_event = 1)

                loss = loss1 + loss2 + weight_decay*norm
                loss1_array_tmp.append(loss1.item())
                loss2_array_tmp.append(loss2.item())

                model.zero_grad()

                loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                optimizer.step()

            # Epoch-level training loss (regularisation term excluded).
            loss_array.append(np.mean(loss1_array_tmp) + np.mean(loss2_array_tmp))
            loss_array1.append(np.mean(loss1_array_tmp))
            loss_array2.append(np.mean(loss2_array_tmp))
            if e % 100 == 0:
                print('Epoch: ' + str(e) +
                      ', TotalLoss: '+ f'{loss_array[-1]:.4e}' +
                      ', Loss1: '+ f'{loss_array1[-1]:.4e}' +
                      ', Loss2: '+ f'{loss_array2[-1]:.4e}')
            # Checkpoint on improvement of the training loss; otherwise count towards early stop.
            if min_loss > loss_array[-1]:
                patience = 0
                min_loss = loss_array[-1]
                torch.save(model.state_dict(), path)
            else:
                patience += 1

            torch.cuda.empty_cache()

            if patience > 1000:
                break

        plt.plot(loss_array, label='Total Loss')
        plt.plot(loss_array1, label='Loss 1')
        plt.plot(loss_array2, label='Loss 2')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.yscale('log')
        plt.title(path[2:])
        plt.legend()
        plt.show()

        # Evaluate the best checkpoint, not the last epoch.
        model.load_state_dict(torch.load(path, map_location = device))

        print(path[9:])

        # ---- evaluation: computed once, it does not depend on `delt` ----
        input_lengths_train = torch.IntTensor((timevar_compact_train[:,:,-1]).sum(1) + 1)

        total_train_sort = torch.FloatTensor(total_train)
        event_train_sort = torch.FloatTensor(event_train)
        time_train_sort = torch.FloatTensor(time_train)
        obs_mask_compact_train_sort = torch.FloatTensor(obs_mask_compact_train)

        input_lengths_test = torch.IntTensor((timevar_compact_test[:,:,-1] ).sum(1) + 1)

        total_test_sort = torch.FloatTensor(total_test)
        event_test_sort = torch.FloatTensor(event_test)
        time_test_sort = torch.FloatTensor(time_test)
        obs_mask_compact_test_sort = torch.FloatTensor(obs_mask_compact_test)

        # Inference only: no autograd graph is needed for the full train / test passes.
        with torch.no_grad():
            y_train = model(total_train_sort.to(device))
            y_test = model(total_test_sort.to(device))

        mask1_train_sort = f_get_fc_mask1(input_lengths_train.reshape(-1, 1) - 1, num_Event=1, num_Category=ntime+2).to(device)
        mask1_test_sort = f_get_fc_mask1(input_lengths_test.reshape(-1, 1) - 1, num_Event=1, num_Category=ntime+2).to(device)

        # Cumulative incidence over the bins after the last measurement.
        CIF_train = torch.cumsum(y_train.reshape(-1,1,ntime+2) * (1-mask1_train_sort), 2)
        CIF_test = torch.cumsum(y_test.reshape(-1,1,ntime+2) * (1-mask1_test_sort), 2)

        time_train_np = time_train_sort.detach().cpu().numpy()
        event_train_np = event_train_sort.detach().cpu().numpy()
        time_test_np = time_test_sort.detach().cpu().numpy()
        event_test_np = event_test_sort.detach().cpu().numpy()

        ctd_train_CVD = concordance_td(time_train_np, event_train_np==1, 1.-CIF_train[:,0,:].detach().cpu().numpy().T, np.int32(time_train_np-1))
        ctd_test_CVD = concordance_td(time_test_np, event_test_np==1, 1.-CIF_test[:,0,:].detach().cpu().numpy().T, np.int32(time_test_np-1))

        # NOTE(review): the CIF is read at index time-1; a time of 0 would wrap around to the
        # last bin - confirm times are >= 1.
        E_CVD_train = np.array([CIF_train[i,0,int(j)].item() for (i,j) in zip(range(len(time_train_sort)), time_train_np-1)])
        O_CVD_train = event_train_np==1

        E_CVD_test = np.array([CIF_test[i,0,int(j)].item() for (i,j) in zip(range(len(time_test_sort)), time_test_np-1)])
        O_CVD_test = event_test_np==1

        brier_train_CVD = ((E_CVD_train - O_CVD_train)**2).mean()
        brier_test_CVD = ((E_CVD_test - O_CVD_test)**2).mean()

        # NOTE(review): `delt` only appears in the header; the metrics printed under every
        # header are identical. The loop is kept so the printed output is unchanged.
        for delt in [1,3,6,9,12]:

            print('****************************************')
            print('Delta t = {}'.format(delt))
            print('****************************************')

            print('train ctd for CVD: {:.4f}'.format(ctd_train_CVD))

            print('test ctd for CVD: {:.4f}'.format(ctd_test_CVD) + '\n')

            print('train Brier for CVD: {:.4f}'.format(brier_train_CVD))

            print('test Brier for CVD: {:.4f}'.format(brier_test_CVD) + '\n')

# Competing risks models (multiple events)

In [None]:
class deephitCR(nn.Module):
    """Competing-risks network: a shared layer followed by one head per event.

    Each head receives the shared representation concatenated with the raw input and
    produces ``ntime + 2`` logits. The heads' logits are concatenated and a single
    softmax is taken over all of them, so each output row of width
    ``num_event * (ntime + 2)`` sums to 1.

    Parameters
        input_size  : number of input features
        ntime       : number of time points (each head outputs ntime + 2 bins)
        shared_size : width of the shared hidden layer
        cs_size     : width of each cause-specific hidden layer
        num_event   : number of competing events, i.e. number of heads (>= 1)

    Raises
        ValueError if num_event is smaller than 1.
    """

    def __init__(self, input_size, ntime, shared_size=128, cs_size=128, num_event=2):
        super().__init__()

        if num_event < 1:
            raise ValueError('num_event must be >= 1, got {}'.format(num_event))

        self.input_size = input_size
        self.ntime = ntime

        self.shared_size = shared_size
        self.cs_size = cs_size
        self.num_event = num_event

        self.shared = nn.Sequential(
            nn.Linear(self.input_size, self.shared_size),
            nn.LeakyReLU(inplace=True)
        )
        # One head per event, registered as cs1, cs2, ... so that the default
        # two-event model keeps the attribute names and checkpoint keys it always had.
        for k in range(1, self.num_event + 1):
            self.add_module('cs{}'.format(k), nn.Sequential(
                nn.Linear(self.shared_size + self.input_size, self.cs_size),
                nn.LeakyReLU(inplace=True),
                nn.Linear(self.cs_size, self.ntime+2)
            ))

    def forward(self, x0):
        ## x0: (batch, input_size) -> (batch, num_event * (ntime+2)) probabilities

        ## SHARED NET
        x = self.shared(x0)

        ## CAUSE-SPECIFIC NETS (each sees the shared features plus the raw input)
        head_input = torch.cat([x, x0], 1)
        head_logits = [getattr(self, 'cs{}'.format(k))(head_input)
                       for k in range(1, self.num_event + 1)]

        return torch.softmax(torch.cat(head_logits, 1), 1)

In [None]:
# Train and evaluate the competing-risks model for each (shared width, head width) pair.
# This cell relies on names defined outside it: total_train / total_test, event_train /
# event_test, time_train / time_test, timevar_compact_train / timevar_compact_test,
# obs_mask_compact_train / obs_mask_compact_test, ntime, num_event, device and train_loader.
for weight_decay in [1e-3]:
    for shared_size in [16, 32, 64, 128, 256]:
        for cs_size in [16, 32, 64, 128, 256]:

            # Checkpoint file holding the weights with the lowest training loss so far.
            path = './DeepHit_cr_{}statesize_{}cssize_{:.0e}'.format(shared_size, cs_size, weight_decay)
            #if os.path.isfile(path):
            #    continue

            model = deephitCR(input_size=total_train.shape[-1], ntime=ntime, shared_size=shared_size, cs_size=cs_size, num_event=num_event).to(device)
            #if os.path.isfile(path):
            #    model.load_state_dict(torch.load(path, map_location = device))

            lr = 1e-2
            # The optimizer's own weight decay is switched off; an explicit L1 penalty scaled
            # by `weight_decay` is added to the loss below instead.
            optimizer = adabound.AdaBound(model.parameters(), lr=lr, weight_decay=0)

            loss_array = []
            loss_array1 = []
            loss_array2 = []
            patience = 0
            min_loss = np.inf
            for e in range(int(1e6)):

                loss1_array_tmp = []
                loss2_array_tmp = []

                for total_batch, timevar_compact_batch, event_batch, time_batch, obs_mask_batch in train_loader:

                    # NOTE(review): presumably the last channel of timevar_compact flags observed
                    # visits, so this is the index of the last measurement + 1 - confirm.
                    input_lengths_batch = ((timevar_compact_batch[:,:,-1] ).sum(1) + 1).int()

                    total_batch = total_batch.float()
                    event_batch = event_batch.float()
                    time_batch = time_batch.float()
                    obs_mask_batch = obs_mask_batch.float()

                    # The event axis of the masks follows `num_event`, like the model and the
                    # loss call below, rather than a separate hard-coded 2.
                    mask1_batch = f_get_fc_mask1(input_lengths_batch.reshape(-1, 1) - 1, num_Event=num_event, num_Category=ntime+2)
                    mask2_batch = f_get_fc_mask2(time_batch.reshape(-1, 1), event_batch.reshape(-1, 1), num_Event=num_event, num_Category=ntime+2)
                    mask3_batch = f_get_fc_mask3(time_batch.reshape(-1, 1), input_lengths_batch.reshape(-1, 1) - 1, ntime+2)

                    y_pred = model(total_batch.to(device))

                    # L1 norm of all parameters (regularisation term)
                    norm = 0.
                    for parameter in model.parameters():
                        norm += torch.norm(parameter, p=1)

                    loss1, loss2 = custom_loss(pmf = y_pred.reshape(-1, num_event, ntime+2),
                                               event = event_batch.reshape(-1, 1).to(device),
                                               time = time_batch.reshape(-1, 1).to(device),
                                               mask1 = mask1_batch.to(device),
                                               mask2 = mask2_batch.to(device),
                                               mask3 = mask3_batch.to(device),
                                               num_event = num_event)

                    loss = loss1 + loss2 + weight_decay*norm
                    loss1_array_tmp.append(loss1.item())
                    loss2_array_tmp.append(loss2.item())

                    model.zero_grad()

                    loss.backward()

                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                    optimizer.step()

                # Epoch-level training loss (regularisation term excluded).
                loss_array.append(np.mean(loss1_array_tmp) + np.mean(loss2_array_tmp))
                loss_array1.append(np.mean(loss1_array_tmp))
                loss_array2.append(np.mean(loss2_array_tmp))
                if e % 100 == 0:
                    print('Epoch: ' + str(e) +
                          ', TotalLoss: '+ f'{loss_array[-1]:.4e}' +
                          ', Loss1: '+ f'{loss_array1[-1]:.4e}' +
                          ', Loss2: '+ f'{loss_array2[-1]:.4e}')
                # Checkpoint on improvement of the training loss; otherwise count towards early stop.
                if min_loss > loss_array[-1]:
                    patience = 0
                    min_loss = loss_array[-1]
                    torch.save(model.state_dict(), path)
                else:
                    patience += 1

                torch.cuda.empty_cache()

                if patience > 1000:
                    break

            plt.plot(loss_array, label='Total Loss')
            plt.plot(loss_array1, label='Loss 1')
            plt.plot(loss_array2, label='Loss 2')
            plt.ylabel('loss')
            plt.xlabel('epoch')
            plt.yscale('log')
            plt.title(path[2:])
            plt.legend()
            plt.show()

            # Evaluate the best checkpoint, not the last epoch.
            model.load_state_dict(torch.load(path, map_location = device))

            print(path[9:])

            # ---- evaluation: computed once, it does not depend on `delt` ----
            input_lengths_train = torch.IntTensor((timevar_compact_train[:,:,-1] ).sum(1) + 1)

            total_train_sort = torch.FloatTensor(total_train)
            event_train_sort = torch.FloatTensor(event_train)
            time_train_sort = torch.FloatTensor(time_train)
            obs_mask_compact_train_sort = torch.FloatTensor(obs_mask_compact_train)

            input_lengths_test = torch.IntTensor((timevar_compact_test[:,:,-1] ).sum(1) + 1)

            total_test_sort = torch.FloatTensor(total_test)
            event_test_sort = torch.FloatTensor(event_test)
            time_test_sort = torch.FloatTensor(time_test)
            obs_mask_compact_test_sort = torch.FloatTensor(obs_mask_compact_test)

            # Inference only: no autograd graph is needed for the full train / test passes.
            with torch.no_grad():
                y_train = model(total_train_sort.to(device))
                y_test = model(total_test_sort.to(device))

            mask1_train_sort = f_get_fc_mask1(input_lengths_train.reshape(-1, 1) - 1, num_Event=num_event, num_Category=ntime+2).to(device)
            mask1_test_sort = f_get_fc_mask1(input_lengths_test.reshape(-1, 1) - 1, num_Event=num_event, num_Category=ntime+2).to(device)

            # Cumulative incidence per event over the bins after the last measurement.
            CIF_train = torch.cumsum(y_train.reshape(-1,num_event,ntime+2) * (1-mask1_train_sort), 2)
            CIF_test = torch.cumsum(y_test.reshape(-1,num_event,ntime+2) * (1-mask1_test_sort), 2)

            time_train_np = time_train_sort.detach().cpu().numpy()
            event_train_np = event_train_sort.detach().cpu().numpy()
            time_test_np = time_test_sort.detach().cpu().numpy()
            event_test_np = event_test_sort.detach().cpu().numpy()

            # Event label 1 is reported as "CVD" and label 2 as "Death" (index 0 / 1 of the CIF).
            ctd_test_CVD = concordance_td(time_test_np, event_test_np==1, 1.-CIF_test[:,0,:].detach().cpu().numpy().T, np.int32(time_test_np-1))
            ctd_test_Death = concordance_td(time_test_np, event_test_np==2, 1.-CIF_test[:,1,:].detach().cpu().numpy().T, np.int32(time_test_np-1))

            # NOTE(review): the CIF is read at index time-1; a time of 0 would wrap around to the
            # last bin - confirm times are >= 1.
            E_CVD_train = np.array([CIF_train[i,0,int(j)].item() for (i,j) in zip(range(len(time_train_sort)), time_train_np-1)])
            O_CVD_train = event_train_np==1
            E_Death_train = np.array([CIF_train[i,1,int(j)].item() for (i,j) in zip(range(len(time_train_sort)), time_train_np-1)])
            O_Death_train = event_train_np==2

            E_CVD_test = np.array([CIF_test[i,0,int(j)].item() for (i,j) in zip(range(len(time_test_sort)), time_test_np-1)])
            O_CVD_test = event_test_np==1
            E_Death_test = np.array([CIF_test[i,1,int(j)].item() for (i,j) in zip(range(len(time_test_sort)), time_test_np-1)])
            O_Death_test = event_test_np==2

            brier_test_CVD = ((E_CVD_test - O_CVD_test)**2).mean()
            brier_test_Death = ((E_Death_test - O_Death_test)**2).mean()

            # NOTE(review): `delt` only appears in the header; the metrics printed under every
            # header are identical. The loop is kept so the printed output is unchanged.
            for delt in [1,3,6,9,12]:

                print('****************************************')
                print('Delta t = {}'.format(delt))
                print('****************************************')

                print('test ctd for CVD: {:.4f}'.format(ctd_test_CVD))
                print('test ctd for Death: {:.4f}'.format(ctd_test_Death) + '\n')

                print('test Brier for CVD: {:.4f}'.format(brier_test_CVD))
                print('test Brier for Death: {:.4f}'.format(brier_test_Death) + '\n')