In [None]:
import pandas as pd
from itertools import islice
import numpy as np
import xarray
import json
import os
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedShuffleSplit
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.model_selection import ParameterGrid
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc
from sklearn.preprocessing import MinMaxScaler, QuantileTransformer

import numpy as np
import matplotlib.pyplot as plt

# For preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper 
from pycox.models import DeepHitSingle

import torch # For building the networks 
from torch import nn
import torch.nn.functional as F
import torchtuples as tt # Some useful functions

from pycox.datasets import nwtco
from pycox.models import LogisticHazard
from pycox.models import CoxPH
from pycox.models.loss import NLLLogistiHazardLoss, NLLMTLRLoss, BCESurvLoss
from pycox.evaluation import EvalSurv

import seaborn as sn
sn.set_theme(style="white", palette="rocket_r")

np.random.seed(1234)
_ = torch.manual_seed(123)

In [None]:
from pycox.models import PMF

In [None]:
# Load the NWTCO dataset and carve it into test / validation / train partitions.
# No explicit `random_state` is passed to `sample`; presumably the draw is governed by
# the NumPy seed set in the import cell — TODO confirm.
df_train = nwtco.read_df()
# Hold out a random 20% of all rows as the test set and remove them from the training frame.
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
# Hold out a random 20% of the *remaining* rows as the validation set.
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

In [None]:
df_train

In [None]:
# Column groups: `age` goes through a StandardScaler, the other covariates are
# paired with `None` (no transformer).
cols_standardize = ['age']
cols_leave = ['stage', 'in.subcohort', 'instit_2', 'histol_2', 'study_4']

# Build the (columns, transformer) pairs expected by DataFrameMapper.
standardize = list(map(lambda name: ([name], StandardScaler()), cols_standardize))
leave = list(zip(cols_leave, [None] * len(cols_leave)))

x_mapper = DataFrameMapper([*standardize, *leave])

In [None]:
# Fit the mapper on the training frame only, then reuse the fitted mapper for the
# validation and test frames. Everything is cast to float32.
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val, x_test = (
    x_mapper.transform(frame).astype('float32') for frame in (df_val, df_test)
)

# PMF

In [None]:
def get_target(df):
    """Return the `(edrel, rel)` column values of `df` as a (durations, events) pair."""
    return df['edrel'].values, df['rel'].values


# Discretize the time axis into `num_durations` points for the PMF model.
num_durations = 10
labtrans = PMF.label_transform(num_durations)

# The label transform is fitted on the training targets and reused for validation.
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))

train, val = (x_train, y_train), (x_val, y_val)

# Test targets stay untransformed; they are only used for evaluation.
durations_test, events_test = get_target(df_test)

In [None]:
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = labtrans.out_features
batch_norm = True
dropout = 0.5

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)

In [None]:
model = PMF(net, tt.optim.Adam, duration_index=labtrans.cuts)

In [None]:
batch_size = 256
lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=4)
_ = lr_finder.plot()

In [None]:
lr_finder.get_best_lr()

In [None]:
model.optimizer.set_lr(lr_finder.get_best_lr())

In [None]:
epochs = 1000
log = model.fit(x_train, y_train, batch_size, epochs, val_data=val)

In [None]:
_ = log.plot()

In [None]:
surv = model.predict_surv_df(x_test)

In [None]:
surv = model.interpolate(10).predict_surv_df(x_test)
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td('antolini')

In [None]:
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)

In [None]:
ev.integrated_brier_score(time_grid) 

In [None]:
ev.integrated_nbll(time_grid) 

# MTLR

In [None]:
from pycox.models import MTLR

In [None]:
num_durations = 10
labtrans = MTLR.label_transform(num_durations)
get_target = lambda df: (df['edrel'].values, df['rel'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))

train = (x_train, y_train)
val = (x_val, y_val)

# We don't need to transform the test labels
durations_test, events_test = get_target(df_test)

In [None]:
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = labtrans.out_features
batch_norm = True
dropout = 0.1

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)

In [None]:
model = MTLR(net, tt.optim.Adam, duration_index=labtrans.cuts)

In [None]:
batch_size = 128
lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=6)
_ = lr_finder.plot()

In [None]:
lr_finder.get_best_lr()

In [None]:
model.optimizer.set_lr(lr_finder.get_best_lr())

In [None]:
epochs = 1000
log = model.fit(x_train, y_train, batch_size, epochs, val_data=val)

In [None]:
_ = log.plot()

In [None]:
surv = model.predict_surv_df(x_test)

In [None]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td('antolini')

In [None]:
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)

In [None]:
ev.integrated_brier_score(time_grid) 

In [None]:
ev.integrated_nbll(time_grid) 

# BCESurv

In [None]:
from pycox.models import LogisticHazard, BCESurv

In [None]:
# Label setup for the BCESurv model. Note the label transform is taken from
# LogisticHazard (not BCESurv) with 10 discretization points.
labtrans = LogisticHazard.label_transform(10)
# Extracts the (duration, event) arrays from the `edrel` / `rel` columns.
get_dur_ev = lambda df: (df['edrel'].values, df['rel'].values)

# Fit the transform on the training targets, then apply it to validation and test.
y_train = labtrans.fit_transform(*get_dur_ev(df_train))
y_val = labtrans.transform(*get_dur_ev(df_val))
y_test = labtrans.transform(*get_dur_ev(df_test))

# Bundle inputs and targets as torchtuples tuples so they can be unpacked / indexed later.
train = tt.tuplefy(x_train, y_train)
val = tt.tuplefy(x_val, y_val)
test = tt.tuplefy(x_test, y_test)

In [None]:
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = labtrans.out_features
batch_norm = True
dropout = 0.4

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)

In [None]:
# Learning rate for the AdamWR optimizer and the value passed to `model.interpolate`.
lr=0.01
n_itp=20

model = BCESurv(net, tt.optim.AdamWR(lr, cycle_eta_multiplier=0.8), duration_index=labtrans.cuts)
# The two positional 256s follow the inputs/targets; in the other sections of this
# notebook those positions carry `batch_size` and `epochs`.
# Training is run with an EarlyStoppingCycle callback and the validation tuple.
log = model.fit(*train, 256, 256, verbose=False, val_data=val,
                    callbacks=[tt.cb.EarlyStoppingCycle()])
# Predict survival curves for the test inputs (`test[0]`) from the interpolated model.
surv = model.interpolate(n_itp).predict_surv_df(test[0])

In [None]:
_ = model.log.to_pandas().iloc[1:].plot()

In [None]:
# Bind the BCESurv evaluator to `ev` as well: the metric cells that follow call
# `ev.concordance_td(...)`, `ev.integrated_brier_score(...)` and `ev.integrated_nbll(...)`,
# and would otherwise silently report the previous (MTLR) model's scores.
# `ev_bce_true` is kept as an alias so any later reference to it still works.
ev = ev_bce_true = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td('antolini')

In [None]:
# Evaluate on 100 points spanning the observed test durations, as every other section
# of this notebook does. The previous fixed grid `np.linspace(0, 100, 100)` ignored the
# data's time scale and was also reused by the DeepHit and CoxTime metric cells.
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)

In [None]:
ev.integrated_brier_score(time_grid) 

In [None]:
ev.integrated_nbll(time_grid)

# DeepHit

In [None]:
from pycox.models import DeepHitSingle

In [None]:
num_durations = 10
labtrans = DeepHitSingle.label_transform(num_durations)
get_target = lambda df: (df['edrel'].values, df['rel'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))

train = (x_train, y_train)
val = (x_val, y_val)

# We don't need to transform the test labels
durations_test, events_test = get_target(df_test)

In [None]:
in_features = x_train.shape[1]
num_nodes = [32,32]
out_features = labtrans.out_features
batch_norm = True
dropout = 0.4

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)

In [None]:
model = DeepHitSingle(net, tt.optim.Adam, alpha=0.2, sigma=0.1, duration_index=labtrans.cuts)

In [None]:
batch_size = 128
lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=3)
_ = lr_finder.plot()

In [None]:
lr_finder.get_best_lr()

In [None]:
model.optimizer.set_lr(lr_finder.get_best_lr())

In [None]:
epochs = 1000
log = model.fit(x_train, y_train, batch_size, epochs, val_data=val)

In [None]:
_ = log.plot()

In [None]:
surv = model.predict_surv_df(x_test)

In [None]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td('antolini')

In [None]:
ev.integrated_brier_score(time_grid) 

In [None]:
ev.integrated_nbll(time_grid) 

# CoxTime

In [None]:
from pycox.models import CoxTime
from pycox.models.cox_time import MLPVanillaCoxTime

In [None]:
labtrans = CoxTime.label_transform()
get_target = lambda df: (df['edrel'].values, df['rel'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))
durations_test, events_test = get_target(df_test)
val = tt.tuplefy(x_val, y_val)

In [None]:
in_features = x_train.shape[1]
num_nodes = [32, 32]
batch_norm = True
dropout = 0.1
net = MLPVanillaCoxTime(in_features, num_nodes, batch_norm, dropout)

In [None]:
model = CoxTime(net, tt.optim.Adam, labtrans=labtrans)

In [None]:
batch_size = 256
lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=2)
_ = lrfinder.plot()

In [None]:
lrfinder.get_best_lr()

In [None]:
model.optimizer.set_lr(lrfinder.get_best_lr())

In [None]:
epochs = 1000
verbose = True

In [None]:
log = model.fit(x_train, y_train, batch_size, epochs, verbose,
                val_data=val.repeat(10).cat())

In [None]:
_ = log.plot()

In [None]:
_ = model.compute_baseline_hazards()

In [None]:
surv = model.predict_surv_df(x_test)

In [None]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td()

In [None]:
ev.integrated_brier_score(time_grid)

In [None]:
ev.integrated_nbll(time_grid)

# Cox-CC

In [None]:
from pycox.models import CoxCC

In [None]:
get_target = lambda df: (df['edrel'].values, df['rel'].values)
y_train = get_target(df_train)
y_val = get_target(df_val)
durations_test, events_test = get_target(df_test)
val = tt.tuplefy(x_val, y_val)

In [None]:
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = 1
batch_norm = True
dropout = 0.1
output_bias = False

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, output_bias=output_bias)

In [None]:
model = CoxCC(net, tt.optim.Adam)

In [None]:
batch_size = 256
lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=2)
_ = lrfinder.plot()

In [None]:
lrfinder.get_best_lr()

In [None]:
model.optimizer.set_lr(lrfinder.get_best_lr())

In [None]:
epochs = 1000
verbose = True

In [None]:
log = model.fit(x_train, y_train, batch_size, epochs, verbose,
                val_data=val.repeat(10).cat())

In [None]:
_ = log.plot()

In [None]:
_ = model.compute_baseline_hazards()

In [None]:
surv = model.predict_surv_df(x_test)

In [None]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)

In [None]:
ev.concordance_td()

In [None]:
ev.integrated_brier_score(time_grid)

In [None]:
ev.integrated_nbll(time_grid)

# DeepSurv

In [None]:
from pycox.models import CoxPH

In [None]:
get_target = lambda df: (df['edrel'].values, df['rel'].values)
y_train = get_target(df_train)
y_val = get_target(df_val)
durations_test, events_test = get_target(df_test)
val = x_val, y_val

In [None]:
in_features = x_train.shape[1]
num_nodes = [32, 32, 32, 32, 32]
out_features = 1
batch_norm = True
dropout = 0.0
output_bias = False

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, output_bias=output_bias)

In [None]:
model = CoxPH(net, tt.optim.Adam)

In [None]:
batch_size = 256
lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=10)
_ = lrfinder.plot()

In [None]:
lrfinder.get_best_lr()

In [None]:
model.optimizer.set_lr(lrfinder.get_best_lr())

In [None]:
epochs = 1000
verbose = True

In [None]:
log = model.fit(x_train, y_train, batch_size, epochs, verbose,
                val_data=val, val_batch_size=batch_size)

In [None]:
_ = log.plot()

In [None]:
_ = model.compute_baseline_hazards()

In [None]:
surv = model.predict_surv_df(x_test)

In [None]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td()

In [None]:
ev.integrated_brier_score(time_grid)

In [None]:
ev.integrated_nbll(time_grid)

# PCHazard

In [None]:
from pycox.models import PCHazard

In [None]:
num_durations = 10
labtrans = PCHazard.label_transform(num_durations)
get_target = lambda df: (df['edrel'].values.astype(float), df['rel'].values.astype(float))
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))

train = (x_train, y_train)
val = (x_val, y_val)

# We don't need to transform the test labels
durations_test, events_test = get_target(df_test)

In [None]:
in_features = x_train.shape[1]
num_nodes = [32, 32, 32]
out_features = labtrans.out_features
batch_norm = True
dropout = 0.1

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)

In [None]:
model = PCHazard(net, tt.optim.Adam, duration_index=labtrans.cuts)

In [None]:
batch_size = 256
lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=8)
_ = lr_finder.plot()

In [None]:
lr_finder.get_best_lr()

In [None]:
model.optimizer.set_lr(lr_finder.get_best_lr())

In [None]:
epochs = 1000
log = model.fit(x_train, y_train, batch_size, epochs, val_data=val)

In [None]:
_ = log.plot()

In [None]:
surv = model.predict_surv_df(x_test)

In [None]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td('antolini')

In [None]:
ev.integrated_brier_score(time_grid) 

In [None]:
ev.integrated_nbll(time_grid) 

# Logistic Hazard

In [None]:
from pycox.models import LogisticHazard

In [None]:
num_durations = 10

labtrans = LogisticHazard.label_transform(num_durations)
# labtrans = PMF.label_transform(num_durations)
# labtrans = DeepHitSingle.label_transform(num_durations)

get_target = lambda df: (df['edrel'].values, df['rel'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))

train = (x_train, y_train)
val = (x_val, y_val)

# We don't need to transform the test labels
durations_test, events_test = get_target(df_test)

In [None]:
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = labtrans.out_features
batch_norm = True
dropout = 0.1

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)

In [None]:
model = LogisticHazard(net, tt.optim.Adam(0.001), duration_index=labtrans.cuts)

In [None]:
batch_size = 256
epochs = 1000

In [None]:
log = model.fit(x_train, y_train, batch_size, epochs, val_data=val)

In [None]:
_ = log.plot()

In [None]:
surv = model.predict_surv_df(x_test)

In [None]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td('antolini')

In [None]:
ev.integrated_brier_score(time_grid) 

In [None]:
ev.integrated_nbll(time_grid) 

# DySurv

In [None]:
# Label setup for the DySurv model: durations are discretized into `num_durations` points.
num_durations = 10

labtrans = LogisticHazard.label_transform(num_durations)
# Alternative label transforms (disabled):
# labtrans = PMF.label_transform(num_durations)
# labtrans = DeepHitSingle.label_transform(num_durations)

# Extracts the (duration, event) arrays from the `edrel` / `rel` columns.
get_target = lambda df: (df['edrel'].values, df['rel'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))

# The target is a pair: the survival labels plus the input features themselves,
# which serve as the reconstruction target.
train = tt.tuplefy(x_train, (y_train, x_train))
val = tt.tuplefy(x_val, (y_val, x_val))

# We don't need to transform the test labels
durations_test, events_test = get_target(df_test)

In [None]:
class Decoder(nn.Module):
    """Stack of four linear layers mapping a latent code to `output_size` features.

    Layer widths are `no_features -> 3x -> 5x -> 3x -> output_size`, where `x` is
    `no_features`. No activation is applied between the layers.

    Parameters
    ----------
    no_features : int
        Width of the input (latent) vector; also stored as `hidden_size`.
    output_size : int
        Width of the reconstructed output.
    """

    def __init__(self, no_features, output_size):
        super().__init__()

        self.no_features = no_features
        self.hidden_size = no_features
        self.output_size = output_size

        width = self.hidden_size
        # Layers are created in fc1..fc4 order so parameter initialization is unchanged.
        self.fc1 = nn.Linear(width, 3 * width)
        self.fc2 = nn.Linear(3 * width, 5 * width)
        self.fc3 = nn.Linear(5 * width, 3 * width)
        self.fc4 = nn.Linear(3 * width, output_size)

    def forward(self, x):
        """Apply fc1, fc2, fc3 and fc4 in sequence and return the result."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = layer(hidden)
        return hidden

In [None]:
class DySurv(nn.Module):
    """Variational autoencoder with an attached survival head.

    The encoder maps the input through three ReLU-activated linear layers
    (`in -> 3*in -> 5*in -> 3*in`) and then through two parallel linear heads that
    produce the latent mean (`fc14`) and log-variance (`fc24`). A latent sample `z`
    feeds both the reconstruction decoder (`decoder2`) and the survival network
    (`surv_net`).

    Parameters
    ----------
    in_features : int
        Number of input covariates.
    encoded_features : int
        Size of the latent vector.
    out_features : int
        Number of outputs of the survival head.
    """

    def __init__(self, in_features, encoded_features, out_features):
        super().__init__()
        self.fc11 = nn.Linear(in_features, 3*in_features)
        self.fc12 = nn.Linear(3*in_features, 5*in_features)
        self.fc13 = nn.Linear(5*in_features, 3*in_features)
        # Latent mean head.
        self.fc14 = nn.Linear(3*in_features, encoded_features)

        # Latent log-variance head.
        self.fc24 = nn.Linear(3*in_features, encoded_features)

        self.relu = nn.ReLU()

        self.surv_net = nn.Sequential(
            nn.Linear(encoded_features, 3*in_features), nn.ReLU(),
            nn.Linear(3*in_features, 5*in_features), nn.ReLU(),
            nn.Linear(5*in_features, 3*in_features), nn.ReLU(),
            nn.Linear(3*in_features, out_features),
        )

        self.decoder2 = Decoder(encoded_features, in_features)

    def reparameterize(self, mu, logvar):
        """Draw `z = mu + eps * exp(0.5 * logvar)` with `eps ~ N(0, I)`.

        The noise is sampled outside the graph, so gradients flow through `mu`
        and `logvar` only.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encoder(self, x):
        """Return the latent `(mu, logvar)` pair for input `x`."""
        x = self.relu(self.fc11(x))
        x = self.relu(self.fc12(x))
        x = self.relu(self.fc13(x))
        mu_z = self.fc14(x)
        logvar_z = self.fc24(x)

        return mu_z, logvar_z

    def forward(self, input):
        """Return `(reconstruction, survival_output, mu, logvar)` for a batch."""
        mu, logvar = self.encoder(input.float())
        z = self.reparameterize(mu, logvar)
        return self.decoder2(z), self.surv_net(z), mu, logvar

    def predict(self, input):
        """Return only the survival output for a batch.

        Will be used by model.predict later. As this only has the survival output,
        we don't have to change LogisticHazard.
        """
        # Cast like `forward` does so both entry points accept the same inputs.
        # NOTE(review): a fresh latent sample is drawn here too, so predictions are
        # stochastic; consider `self.surv_net(mu)` if deterministic output is wanted.
        mu, logvar = self.encoder(input.float())
        encoded = self.reparameterize(mu, logvar)
        return self.surv_net(encoded)

In [None]:
# Instantiate the DySurv network: the input width comes from the feature matrix,
# the latent size is fixed at 20, and the survival head width comes from the label transform.
in_features = x_train.shape[1]
encoded_features = 20
out_features = labtrans.out_features
net = DySurv(in_features, encoded_features, out_features)

In [None]:
class _Loss(torch.nn.Module):
    """Base class for loss modules that remember a reduction mode.

    Parameters
    ----------
    reduction : str
        Stored on the instance as `reduction`; defaults to 'mean'.
    """

    def __init__(self, reduction: str = 'mean') -> None:
        torch.nn.Module.__init__(self)
        self.reduction = reduction

In [None]:
def nll_logistic_hazard(phi: torch.Tensor, idx_durations: torch.Tensor, events: torch.Tensor,
                        reduction: str = 'mean') -> torch.Tensor:
    """Negative log-likelihood of the discrete-time logistic-hazard parametrization.

    Arguments:
        phi {torch.Tensor} -- 2-D tensor of logits; dimension 1 indexes the discrete durations.
        idx_durations {torch.Tensor} -- Integer index of each sample's duration; must be
            smaller than `phi.shape[1]`.
        events {torch.Tensor} -- Event indicator per sample (bool tensors are cast to float).

    Keyword Arguments:
        reduction {str} -- How to reduce the per-sample loss: 'none', 'mean' or 'sum'.
            (default: {'mean'})

    Returns:
        torch.Tensor -- The per-sample loss (for 'none') or its mean / sum.

    Raises:
        ValueError -- If `phi` has too few columns for `idx_durations`, or if
            `reduction` is not one of the supported values.

    References:
    [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
        with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
        https://arxiv.org/pdf/1910.06724.pdf
    """
    if phi.shape[1] <= idx_durations.max():
        raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
                         f" Need at least `phi.shape[1] = {idx_durations.max().item()+1}`,"+
                         f" but got `phi.shape[1] = {phi.shape[1]}`")
    if events.dtype is torch.bool:
        events = events.float()
    events = events.view(-1, 1)
    idx_durations = idx_durations.view(-1, 1)
    # One-hot style target: the event indicator sits at the sample's duration index.
    y_bce = torch.zeros_like(phi).scatter(1, idx_durations, events)
    bce = F.binary_cross_entropy_with_logits(phi, y_bce, reduction='none')
    # Sum the per-interval terms up to and including each sample's own duration index.
    loss = bce.cumsum(1).gather(1, idx_durations).view(-1)
    # Reduction is done inline: the `_reduction` helper this code was written against
    # is not defined or imported in this notebook.
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError(f"`reduction` = {reduction} is not valid. Use 'none', 'mean' or 'sum'.")

In [None]:
class NLLLogistiHazardLoss(_Loss):
    """Module wrapper around `nll_logistic_hazard` using the stored `reduction`.

    Note: defining this class rebinds the `NLLLogistiHazardLoss` name that the import
    cell brought in from `pycox.models.loss`.
    """

    # Annotations use `torch.Tensor` because a bare `Tensor` name is never imported here.
    def forward(self, phi: torch.Tensor, idx_durations: torch.Tensor,
                events: torch.Tensor) -> torch.Tensor:
        """Return the logistic-hazard negative log-likelihood for one batch."""
        return nll_logistic_hazard(phi, idx_durations, events, self.reduction)

In [None]:
class Loss(nn.Module):
    """Weighted sum of the survival, reconstruction and KL-divergence terms.

    Parameters
    ----------
    alpha : float or sequence of three floats
        * A sequence `(w_surv, w_ae, w_kd)` gives the weight of each term directly.
        * A scalar in [0, 1] is expanded to `(alpha, 1 - alpha, 1 - alpha)`, i.e.
          `alpha` weighs the survival term and `1 - alpha` weighs both autoencoder
          terms (reconstruction and KL).

    Raises
    ------
    AssertionError
        If a scalar `alpha` lies outside [0, 1].
    ValueError
        If a sequence `alpha` does not have exactly three non-negative entries.
    """

    def __init__(self, alpha):
        super().__init__()
        # `forward` indexes three weights, so the argument is always normalized to a 3-tuple.
        self.alpha = self._normalize_alpha(alpha)
        self.loss_surv = NLLLogistiHazardLoss()
        self.loss_ae = nn.MSELoss()

    @staticmethod
    def _normalize_alpha(alpha):
        """Turn a scalar or a 3-sequence into a `(w_surv, w_ae, w_kd)` tuple of floats."""
        if isinstance(alpha, (int, float)):
            assert (alpha >= 0) and (alpha <= 1), 'Need `alpha` in [0, 1].'
            return (float(alpha), 1.0 - alpha, 1.0 - alpha)
        weights = tuple(float(w) for w in alpha)
        if len(weights) != 3:
            raise ValueError(f"Need three weights (surv, ae, kd), got {len(weights)}.")
        if any(w < 0 for w in weights):
            raise ValueError("Need non-negative weights.")
        return weights

    def forward(self, decoded, phi, mu, logvar, target_loghaz, target_ae):
        """Combine the three loss terms for one batch.

        `decoded`, `phi`, `mu` and `logvar` are the network outputs; `target_loghaz`
        is the `(idx_durations, events)` pair and `target_ae` the reconstruction target.
        """
        idx_durations, events = target_loghaz
        w_surv, w_ae, w_kd = self.alpha
        # The survival and KL terms are scaled down by a fixed factor of 10.
        loss_surv = self.loss_surv(phi, idx_durations, events)/10
        loss_ae = self.loss_ae(decoded, target_ae)/1
        # The KL term is summed over the whole batch, not averaged.
        loss_kd = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()))/10
        return w_surv * loss_surv + w_ae * loss_ae + w_kd * loss_kd

In [None]:
# Instantiate the combined training loss; 0.5 is passed as the `alpha` argument of `Loss`.
loss = Loss(0.5)

In [None]:
model = LogisticHazard(net, tt.optim.Adam(0.001), duration_index=labtrans.cuts, loss=loss)

In [None]:
# Per-term monitoring metrics. The `LossAELogHaz` class these used to reference is not
# defined in this notebook, so each term is computed by a small function instead.
# Each function takes the same arguments as the training loss: the four network
# outputs followed by the two targets.
def _metric_loss_surv(decoded, phi, mu, logvar, target_loghaz, target_ae):
    """Survival term only: unscaled logistic-hazard negative log-likelihood."""
    idx_durations, events = target_loghaz
    return NLLLogistiHazardLoss()(phi, idx_durations, events)


def _metric_loss_ae(decoded, phi, mu, logvar, target_loghaz, target_ae):
    """Reconstruction term only: mean squared error between `decoded` and the inputs."""
    return F.mse_loss(decoded, target_ae)


metrics = dict(
    loss_surv = _metric_loss_surv,
    loss_ae   = _metric_loss_ae,
)

In [None]:
# Train the DySurv model. `train` unpacks into the inputs and the
# (survival labels, reconstruction target) pair built earlier.
batch_size = 256
epochs = 1000
# The positional `False` sits where the other sections of this notebook pass `verbose`.
log = model.fit(*train, batch_size, epochs, False, val_data=val, metrics=metrics)

In [None]:
res = model.log.to_pandas()

In [None]:
res.head()

In [None]:
_ = res[['train_loss', 'val_loss']].plot()

In [None]:
surv = model.interpolate(10).predict_surv_df(x_test)

In [None]:
# Plot the predicted survival curves of the first five test samples as step functions.
ax = surv.iloc[:, :5].plot(drawstyle='steps-post')
ax.set_ylabel('S(t | x)')
_ = ax.set_xlabel('Time')

In [None]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td('adj_antolini')

In [None]:
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)

In [None]:
ev.integrated_brier_score(time_grid) 

In [None]:
ev.integrated_nbll(time_grid) 