In [2]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd

# For preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper 

from pycox.datasets import metabric
from pycox.evaluation import EvalSurv
from pycox.preprocessing.label_transforms import LabTransDiscreteTime



# Preprocessing

In [2]:
n_discrete_times = 10

np.random.seed(1234)
_ = torch.manual_seed(123)
processor = LabTransDiscreteTime(n_discrete_times)


df_train = metabric.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

# features scaling
cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']
cols_leave = ['x4', 'x5', 'x6', 'x7']
standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]
x_mapper = DataFrameMapper(standardize + leave)
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

# preprocessing time and event
get_target = lambda df: (df['duration'].values, df['event'].values)
y_train = processor.fit_transform(*get_target(df_train))
y_val = processor.transform(*get_target(df_val))


# We don't need to transform the test labels
durations_test, events_test = get_target(df_test)

# Custom deephit

In [3]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.event = y[1]
        self.duration = y[0]
    def __len__(self):
        return self.x.shape[0]
    def __getitem__(self, idx):
        return self.x[idx], self.duration[idx], self.event[idx]

# implimentation of pair rank matrix for torch
def pair_rank_mat_torch( idx_durations, events, dtype='float32'):
    idx_durations = idx_durations.reshape(-1)
    events = events.reshape(-1)
    n = len(idx_durations)
    mat = idx_durations.repeat(n, 1)

    mat = (mat.T<mat)  | ((mat.T==mat) & (events.repeat(n, 1)==0))
    mat = mat * events.repeat(n, 1).T

    return mat.float()
    

In [54]:
train_dataset = MyDataset(x_train, y_train)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=True)

val_dataset = MyDataset(x_val, y_val)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=256)

test_dataset = MyDataset(x_test, (durations_test, events_test))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256)

In [55]:
class MyModel(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, output_size)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x 
    
def nll_pmf(phi: torch.Tensor, idx_durations: torch.Tensor, 
            events: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:

    if phi.shape[1] <= idx_durations.max():
        raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
                         f" Need at least `phi.shape[1] = {idx_durations.max().item()+1}`,"+
                         f" but got `phi.shape[1] = {phi.shape[1]}`")
    if events.dtype is torch.bool:
        events = events.float()
    events = events.view(-1)
 
    idx_durations = idx_durations.view(-1, 1)

    # pad for cumsum
    pad = torch.zeros_like(phi[:, :1])
    phi = torch.cat([pad, phi], dim=1)

    # gamma for log-exp trick, not related to thoeretical derivation
    gamma = phi.max(dim = 1)[0]
    cumsum = phi.sub(gamma.view(-1, 1)).exp().cumsum(1)
    sum_ = cumsum[:, -1]
    
    part1 = phi.gather(1, idx_durations).view(-1).sub(gamma).mul(events)
    part2 = - sum_.relu().add(epsilon).log()
    part3 = sum_.sub(cumsum.gather(1, idx_durations).view(-1)).relu().add(epsilon).log().mul(1. - events)
    # need relu() in part3 (and possibly part2) because cumsum on gpu has some bugs and we risk getting negative numbers.
    loss = - part1.add(part2).add(part3)
    return loss.mean()


def _diff_cdf_at_time_i(pmf: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

    n = pmf.shape[0]
    ones = torch.ones((n, 1), device=pmf.device)

    r = pmf.cumsum(1).matmul(y.transpose(0, 1))
    diag_r = r.diag().view(1, -1)
    r = ones.matmul(diag_r) - r
    return r.transpose(0, 1)

def rank_loss_deephit_single(phi: torch.Tensor, idx_durations: torch.Tensor, events: torch.Tensor, rank_mat: torch.Tensor,
                             sigma: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:

    idx_durations = idx_durations.view(-1, 1)
 
    pad = torch.zeros_like(phi[:, :1])
    pmf = torch.cat([pad, phi], dim=1).softmax(1)

    # pmf = utils.pad_col(phi).softmax(1)
    # hit at the time point
    y = torch.zeros_like(pmf).scatter(1, idx_durations, 1.) # one-hot

    r = _diff_cdf_at_time_i(pmf, y)

    rank_loss = rank_mat * torch.exp(-r/sigma)
    rank_loss = rank_loss.mean(1, keepdim=True)
    return rank_loss.mean()

In [80]:
epoch = 30
sigma=0.1
alpha = 0.2
model = MyModel(x_train.shape[1], n_discrete_times)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)



In [81]:
for i in range(epoch):
    losses = []
    for x, duration, event in train_dataloader:
        optimizer.zero_grad()
        output = model(x)
        nll = nll_pmf(output, duration, event)
        rank_mat = pair_rank_mat_torch(duration, event)
        rank_loss = rank_loss_deephit_single(output, duration, event, rank_mat, 2)
        loss = alpha * nll + (1. - alpha) * rank_loss
        
        loss.backward()
        optimizer.step()
        losses.append(loss.cpu().detach().numpy())
    print('epoch: {}, loss: {}'.format(i, np.mean(losses)))

epoch: 0, loss: 0.5758070349693298
epoch: 1, loss: 0.5584962964057922
epoch: 2, loss: 0.5462806820869446
epoch: 3, loss: 0.5346958041191101
epoch: 4, loss: 0.5252451300621033
epoch: 5, loss: 0.5175309181213379
epoch: 6, loss: 0.5105422735214233
epoch: 7, loss: 0.5053091645240784
epoch: 8, loss: 0.501338005065918
epoch: 9, loss: 0.4980289041996002
epoch: 10, loss: 0.4949811100959778
epoch: 11, loss: 0.49251818656921387
epoch: 12, loss: 0.49013611674308777
epoch: 13, loss: 0.4879540205001831
epoch: 14, loss: 0.4858100116252899
epoch: 15, loss: 0.4837271571159363
epoch: 16, loss: 0.4817905128002167
epoch: 17, loss: 0.47983813285827637
epoch: 18, loss: 0.4781160056591034
epoch: 19, loss: 0.4762911796569824
epoch: 20, loss: 0.4744130074977875
epoch: 21, loss: 0.4725854992866516
epoch: 22, loss: 0.47066834568977356
epoch: 23, loss: 0.46893563866615295
epoch: 24, loss: 0.46720772981643677
epoch: 25, loss: 0.46553540229797363
epoch: 26, loss: 0.46393322944641113
epoch: 27, loss: 0.462224543094

In [82]:
outputs = []
durations = []
events = []
for x, duration, event in test_dataloader:
    output = model(x)
    outputs.append(output)
    durations.append(duration)
    events.append(event)

outputs = torch.cat(outputs, dim=0)
surv = outputs.cpu().detach().numpy()
durations = torch.cat(durations, dim=0).cpu().detach().numpy()
events = torch.cat(events, dim=0).cpu().detach().numpy()
surv_df = pd.DataFrame(-surv, columns=processor.cuts).T

In [None]:
final = EvalSurv(surv_df, durations, events, censor_surv='km')
final.concordance_td('antolini')

### Try interpolation (not use)

In [83]:
from scipy.interpolate import CubicSpline

all_risk = []
for i in range(len(surv)):
    spl = CubicSpline(processor.cuts, surv[i])
    risks = []
    for j in range(381):
        risks.append(spl(j).item())
    all_risk.append(risks)
all_risk = np.array(all_risk).cumsum(axis= 0).T
all_risk = pd.DataFrame(all_risk, columns=range(381))

# Deephit from github

In [50]:
from pycox.models import DeepHitSingle
import torchtuples as tt

model_deephit = DeepHitSingle(model, torch.optim.Adam, alpha=0.2, sigma=0.1, duration_index=processor.cuts)

model_deephit.fit(x_train, y_train, 256,250, val_data = (x_val, y_val), callbacks = [tt.callbacks.EarlyStopping()])
# model_deephit.interpolate(10).predict_surv_df(x_test)

0:	[4s / 4s],		train_loss: 2.3483,	val_loss: 7.2094
1:	[0s / 4s],		train_loss: 1.9862,	val_loss: 6.6771
2:	[0s / 4s],		train_loss: 1.8407,	val_loss: 6.2297
3:	[0s / 4s],		train_loss: 1.6534,	val_loss: 5.7992
4:	[0s / 4s],		train_loss: 1.5680,	val_loss: 5.3871
5:	[0s / 4s],		train_loss: 1.4264,	val_loss: 5.0353
6:	[0s / 4s],		train_loss: 1.2725,	val_loss: 4.7366
7:	[0s / 5s],		train_loss: 1.2169,	val_loss: 4.4896
8:	[0s / 5s],		train_loss: 1.1433,	val_loss: 4.2741
9:	[0s / 5s],		train_loss: 1.0449,	val_loss: 4.0920
10:	[0s / 5s],		train_loss: 1.0155,	val_loss: 3.9349
11:	[0s / 5s],		train_loss: 0.9111,	val_loss: 3.7965
12:	[0s / 5s],		train_loss: 0.9212,	val_loss: 3.6782
13:	[0s / 5s],		train_loss: 0.9229,	val_loss: 3.5734
14:	[0s / 5s],		train_loss: 0.8663,	val_loss: 3.4832
15:	[0s / 5s],		train_loss: 0.8197,	val_loss: 3.4005
16:	[0s / 5s],		train_loss: 0.7995,	val_loss: 3.3249
17:	[0s / 5s],		train_loss: 0.8022,	val_loss: 3.2530
18:	[0s / 5s],		train_loss: 0.7919,	val_loss: 3.1921
19:

<torchtuples.callbacks.TrainingLogger at 0x1e97f729820>

In [51]:
all_risk = model_deephit.interpolate(10).predict_surv_df(x_test)

In [52]:
final = EvalSurv(all_risk, durations_test, events_test, censor_surv='km')

In [53]:
final.concordance_td('antolini')

0.655792954343126