In [1]:
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from lifelines.datasets import load_rossi
from lifelines.utils import concordance_index

In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

In [3]:
rossi=load_rossi()

In [4]:
X=rossi.drop(columns=['week','arrest']).copy()
E=rossi['arrest'].copy().values
Y=rossi['week'].copy().values
X_train, X_test, E_train, E_test, Y_train, Y_test = train_test_split(X, E, Y, test_size=0.2)

scaler=StandardScaler().fit(X_train)
X_train=scaler.transform(X_train).astype(np.float32)
X_test=scaler.transform(X_test).astype(np.float32)

scaler=StandardScaler().fit(Y_train.reshape(-1,1))
Y_train=scaler.transform(Y_train.reshape(-1,1)).flatten()
Y_test=scaler.transform(Y_test.reshape(-1,1)).flatten()

In [5]:
print(X_train.shape)
print(E_train.shape)
print(Y_train.shape)
print(X_test.shape)
print(E_test.shape)
print(Y_test.shape)

(345, 7)
(345,)
(345,)
(87, 7)
(87,)
(87,)


In [6]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [7]:
# Define survival module
class DeepSurv(nn.Module):
    def __init__(self, input_shape):
        super().__init__()
        self.surv_layers = nn.Sequential(
            nn.Linear(input_shape, 200),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.BatchNorm1d(200),
            nn.Linear(200, 30),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.BatchNorm1d(30),
            nn.Linear(30, 1)
            )

    def forward(self, inputs):
        logits = self.surv_layers(inputs)
        return logits

In [8]:
def negative_log_likelihood(E):
    global device
    E = torch.Tensor(E).to(device)
    def loss(output, target):
        hazard_ratio = torch.exp(output)
        log_risk = torch.log(torch.cumsum(hazard_ratio, 0))
        uncensored_likelihood = torch.transpose(output, 0, 1) - log_risk
        censored_likelihood = uncensored_likelihood * E
        neg_likelihood_ = -torch.sum(censored_likelihood)

        # TODO
        # For some reason, adding num_observed_events does not work.
        # Therefore, for now we will use it as a simple factor of 1.
        # Is it really needed? Isn't it just a scaling factor?
        # num_observed_events = tf.math.cumsum(E)
        # num_observed_events = tf.cast(num_observed_events, dtype=tf.float32)
        # num_observed_events = torch.cumsum(E, 0)
        num_observed_events = torch.ones([1,1],dtype=torch.float32).to(device)
        
        neg_likelihood = neg_likelihood_ / num_observed_events        
        
        return neg_likelihood
    
    return loss

In [9]:
sort_idx = np.argsort(Y_train)[::-1]

In [10]:
X_train = X_train[sort_idx]
E_train = E_train[sort_idx]
Y_train = Y_train[sort_idx]

In [11]:
class CustomSurvData(Dataset):
    def __init__(self, features, events, times):
        assert len(features)==len(events)==len(times)
        self.features = features
        self.events = events
        self.times = times

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.events[idx], self.times[idx]

In [12]:
dataset_train=CustomSurvData(features=X_train, events=E_train, times=Y_train)
dataset_test=CustomSurvData(features=X_test, events=E_test, times=Y_test)

In [13]:
dataloader_train_surv=DataLoader(dataset_train, batch_size=len(dataset_train), shuffle=False)
dataloader_test_surv=DataLoader(dataset_test, batch_size=len(dataset_test), shuffle=False)

In [14]:
model = DeepSurv(input_shape=X_train.shape[1]).to(device)
print(model)

DeepSurv(
  (surv_layers): Sequential(
    (0): Linear(in_features=7, out_features=200, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Linear(in_features=200, out_features=30, bias=True)
    (5): ReLU()
    (6): Dropout(p=0.2, inplace=False)
    (7): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Linear(in_features=30, out_features=1, bias=True)
  )
)


In [15]:
# define loss function and optimizer
optimizer = torch.optim.NAdam(model.parameters(),weight_decay=16)
loss_train = negative_log_likelihood(E=E_train)
loss_test = negative_log_likelihood(E=E_test)

In [16]:
def train(dataloader, model, loss_fn, optimizer):
    global device
    model.train()
    for X, event, time in dataloader:
        X = X.to(device)
        event = event.to(device)

        pred = model(X)
        loss = loss_fn(pred, event)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # raw predictions are risks
        # the computation of lifelines.utils.concordance_index is based on the predicted scores
        # therefore, the risks need to be converted
        pred_score = np.exp(-pred.detach().to('cpu'))
        c_index=concordance_index(event_times=time, predicted_scores=pred_score, event_observed=event.to('cpu'))

        loss = loss.item()
        print(f"Training: \tloss {loss:>7f},\t C index {(c_index):>0.4f}")

In [17]:
def test(dataloader, model: nn.Module, loss_fn):
    global device
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, event, time in dataloader:
            X = X.to(device)
            event = event.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, event).item()

            # raw predictions are risks
            # the computation of lifelines.utils.concordance_index is based on the predicted scores
            # therefore, the risks need to be converted
            pred_score = np.exp(-pred.to('cpu'))
            c_index=concordance_index(event_times=time, predicted_scores=pred_score, event_observed=event.to('cpu'))
    test_loss /= num_batches
    correct /= size
    print(f"Testing: \tAvg loss {test_loss:>8f},\t C index {(c_index):>0.4f}\n")
    return c_index

In [18]:
#training neural network
best_C_index=0.5
epochs = 500
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(dataloader_train_surv, model, loss_train, optimizer)
    temp_c_index=test(dataloader_test_surv, model, loss_test)
    if temp_c_index>best_C_index:
        torch.save(model.state_dict(),"surv_sub%.4f.pt"%(temp_c_index))
        best_C_index=temp_c_index
print("Done!")

Epoch 1
-------------------------------
Training: 	loss 149764.468750,	 C index 0.4616
Testing: 	Avg loss 8511.131836,	 C index 0.5929

Epoch 2
-------------------------------
Training: 	loss 141713.515625,	 C index 0.5866
Testing: 	Avg loss 8516.221680,	 C index 0.5796

Epoch 3
-------------------------------
Training: 	loss 137438.187500,	 C index 0.6330
Testing: 	Avg loss 8522.566406,	 C index 0.5702

Epoch 4
-------------------------------
Training: 	loss 137412.703125,	 C index 0.6290
Testing: 	Avg loss 8536.962891,	 C index 0.5603

Epoch 5
-------------------------------
Training: 	loss 135093.781250,	 C index 0.6558
Testing: 	Avg loss 8553.620117,	 C index 0.5479

Epoch 6
-------------------------------
Training: 	loss 136962.625000,	 C index 0.6522
Testing: 	Avg loss 8566.005859,	 C index 0.5341

Epoch 7
-------------------------------
Training: 	loss 132249.343750,	 C index 0.6745
Testing: 	Avg loss 8581.777344,	 C index 0.5302

Epoch 8
-------------------------------
Training