In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')
import os
from pycox.evaluation.eval_surv import EvalSurv

from nfm.nfm.eps_config import ParetoEps
from nfm.nfm.base import FullyNeuralNLL
from nfm.nfm.datasets import SurvivalDataset
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm 




os.environ['CUDA_VISIBLE_DEVICES'] = '0'


class Net(nn.Module):

    def __init__(self, num_features):
        super(Net, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features=1 + num_features,
                      out_features=128, bias=False),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=1, bias=False)
        )

    def forward(self, y, z):
        inputs = torch.cat([z, y], dim=1)
        return torch.exp(self.mlp(inputs))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np


class mimic_dataset(Dataset):
    def __init__(self, x_path, t_path, e_path):
        self.x = np.load(x_path)
        self.t = np.load(t_path)
        self.e = np.load(e_path)

    def __getitem__(self, index):
        x_i = self.x[index]
        t_i = self.t[index]
        e_i = self.e[index]
        return x_i, t_i, e_i

    def __len__(self):
        return len(self.x)


In [3]:
def nfm(dataset, n_iter, learning_rate):
    if dataset == 'mimic':
        trainset = mimic_dataset(x_path='x_train.npy' , t_path='t_train.npy', e_path='e_train.npy')
        testset = mimic_dataset(x_path='x_test.npy' , t_path='t_test.npy', e_path='e_test.npy')
        loader = DataLoader(trainset, batch_size=128)
    nll = FullyNeuralNLL(eps_conf=ParetoEps(learnable=True), encoder=Net(num_features = trainset.x.shape[-1])).cuda()
    optimizer = torch.optim.Adam(lr=learning_rate, weight_decay=1e-3, params=nll.parameters())
    for epoch in tqdm(range(n_iter)):
        for i, (x, t, e) in enumerate(loader):
            if i == 4:
                x = x.to(torch.float32)
                t = torch.unsqueeze(t.to(torch.float32),1)
                e = e.to(torch.float32)
                nll.train()
                # print(torch.isnan(x.mean(dim=1).cuda()).any())
                # print(torch.isnan(t.cuda()).any())
                # print(torch.isnan(e.cuda()).any())
                loss = nll(z=x.mean(dim=1).cuda(), y=t.cuda() / 24., delta=e.cuda())
                if epoch % 500 == 0 :
                    print(loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    # nll.eval()
    with torch.no_grad():
        # y_valid, delta_valid, z_valid = valid_folds[i].sort()
        # y_test, delta_test, z_test = test_folds[i].sort()
        # y_valid, y_test = normalize(y_valid), normalize(y_test)
        # valid_loss = nll(z_valid, y_valid, delta_valid)
        # print(z_valid, y_valid, delta_valid)
        # valid_losses.append(valid_loss.item())
        # tg_test = np.linspace(y_test.cpu().numpy().min(), y_test.cpu().numpy().max(), 100)

        x_test, t_test, e_test = testset[:]
        x_train, t_train, e_train = trainset[:]

        index = np.where(t_test >= t_train.max())
        t_test = np.delete(t_test, index)
        e_test = np.delete(e_test, index)
        x_test = np.delete(x_test, index, axis=0)
        
        horizons = [0.25, 0.5, 0.75, 0.9]
        
        x = np.concatenate((np.array(x_train), np.array(x_test)), axis=0)
        t = np.concatenate((np.array(t_train), np.array(t_test)), axis=0)
        e = np.concatenate((np.array(e_train), np.array(e_test)), axis=0)

        tg_test = np.quantile(t[e==1], horizons)

        out_survival = nll.get_survival_prediction(
            z_test=torch.tensor(x_test.mean(axis=1), dtype=torch.float).cuda(), y_test=torch.tensor(tg_test, dtype=torch.float).view(-1, 1).cuda()).cpu().numpy()
        out_survival = out_survival.T
        out_risk = 1 - out_survival

        x_train = np.array(x_train)
        t_train = np.array(t_train)
        e_train = np.array(e_train)
        x_test = np.array(x_test)
        t_test = np.array(t_test)
        e_test = np.array(e_test)


        from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

        cis = []
        brs = []

        et_train = np.array([(e_train[i], t_train[i]) for i in range(len(e_train))],
                        dtype = [('e', bool), ('t', float)])
        
        et_test = np.array([(e_test[i], t_test[i]) for i in range(len(e_test))],
                        dtype = [('e', bool), ('t', float)])
        for i, _ in enumerate(tg_test):
            cis.append(concordance_index_ipcw(et_train, et_test, out_risk[:, i], tg_test[i])[0])
        brs.append(brier_score(et_train, et_test, out_survival, tg_test)[1])
        roc_auc = []
        for i, _ in enumerate(tg_test):
            roc_auc.append(cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], tg_test[i])[0])
        for horizon in enumerate(horizons):
            print(f"For {horizon[1]} quantile")
            print("TD Concordance Index:", cis[horizon[0]])
            print("Brier Score:", brs[0][horizon[0]])
            print("ROC AUC ", roc_auc[horizon[0]][0], "\n")

In [4]:
nfm('mimic', 1000, 1e-5)

  0%|          | 1/1000 [00:01<20:20,  1.22s/it]

tensor(2.3659, device='cuda:0', grad_fn=<DivBackward0>)


 50%|█████     | 501/1000 [00:57<00:55,  9.02it/s]

tensor(2.2980, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 1000/1000 [01:52<00:00,  8.86it/s]


For 0.25 quantile
TD Concordance Index: 0.6426854396985576
Brier Score: 0.6648968789897121
ROC AUC  0.6579923652203764 

For 0.5 quantile
TD Concordance Index: 0.6480737784401933
Brier Score: 0.5243192330965464
ROC AUC  0.6918982645036194 

For 0.75 quantile
TD Concordance Index: 0.6485842765429347
Brier Score: 0.32609769832958657
ROC AUC  0.737697061443326 

For 0.9 quantile
TD Concordance Index: 0.6434166123994957
Brier Score: 0.16846010817179827
ROC AUC  0.7588884773294169 



In [6]:
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features=1 + data_train.num_features, out_features=128, bias=False),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=1, bias=False)
        )

    def forward(self, y, z):
        inputs = torch.cat([z, y], dim=1)
        return torch.exp(self.mlp(inputs))


data_train = SurvivalDataset.kkbox('train')
fold_c_indices = []
fold_ibs = []
fold_nbll = []
normalizing_factor = 366.25


def normalize(y):
    return (y + 1) / normalizing_factor

In [7]:
kkbox_loader = DataLoader(data_train, batch_size=128)

In [8]:
for x, t, e in kkbox_loader:
    print(x.shape)

torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size([128, 58])
torch.Size