In [5]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from lifelines import CoxPHFitter
from lifelines import KaplanMeierFitter

from pycox.evaluation import EvalSurv
from sklearn.preprocessing import StandardScaler

In [6]:
# Seed the NumPy and PyTorch global random number generators with one value.
random_seed = 137
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Use the first CUDA device when one is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

cuda:0


In [7]:
# Early stopping class from https://github.com/Bjarten/early-stopping-pytorch
from SurvNODE.EarlyStopping import EarlyStopping
from SurvNODE.SurvNODE_x_sepnet import *

In [8]:
def measures(odesurv,initial,x,Tstart,Tstop,From,To,trans,status, multiplier=1.,points=500):
    """Score the model's predicted curves with pycox's ``EvalSurv``.

    The positional arguments after ``initial`` mirror one dataloader batch so
    that callers can pass ``*ds``; ``Tstart``, ``From``, ``To`` and ``trans``
    are accepted for that reason only and are not read here.

    Parameters
    ----------
    odesurv : model exposing ``predict(x, times)``; the result is indexed as a
        4-D tensor whose first axis follows the time grid.
    initial : 1-D tensor contracted against the third axis of the prediction
        (presumably the initial state distribution -- TODO confirm).
    x : covariate tensor; its device is used for the time grid.
    Tstop, status : per-sample times and event indicators handed to ``EvalSurv``.
    multiplier : upper end of the evaluation time grid (grid starts at 0).
    points : number of grid points.

    Returns
    -------
    tuple
        ``(conc, ibs, inbll)``: Antolini time-dependent concordance,
        integrated Brier score and integrated negative binomial
        log-likelihood, the latter two integrated over the time grid.
    """
    with torch.no_grad():
        # A single grid is shared by the prediction, the DataFrame index and
        # the integrated scores.
        time_grid = np.linspace(0, multiplier, points)
        eval_times = torch.from_numpy(time_grid).float().to(x.device)
        surv_ode = odesurv.predict(x, eval_times)

        # Contract axis 'k' with `initial`, then keep entry 0 of the last axis
        # (presumably the probability of still being in the first state --
        # TODO confirm against SurvNODE.predict).
        pvec = torch.einsum("ilkj,k->ilj", surv_ode, initial)[:, :, 0]
        pvec = pvec.detach().cpu().numpy()

        # EvalSurv expects one row per time point, indexed by time.
        surv_ode_df = pd.DataFrame(pvec)
        surv_ode_df.loc[:,"time"] = time_grid
        surv_ode_df = surv_ode_df.set_index(["time"])

        ev_ode = EvalSurv(surv_ode_df, np.array(Tstop.cpu()), np.array(status.cpu()), censor_surv='km')
        conc = ev_ode.concordance_td('antolini')
        ibs = ev_ode.integrated_brier_score(time_grid)
        inbll = ev_ode.integrated_nbll(time_grid)
    return conc,ibs,inbll

In [16]:
from sklearn_pandas import DataFrameMapper
import pandas as pd

def make_dataloader(df,Tmax,batchsize):
    """Wrap a survival DataFrame in a shuffling ``DataLoader``.

    Reads covariate columns ``x0``..``x8`` plus ``duration`` and ``event``
    from ``df`` and places every tensor on the notebook-level ``device``.
    Covariates are used as they are; no standardisation is applied.

    Each batch yields ``(X, Tstart, T, From, To, trans, E)`` where
    ``T = duration / Tmax`` (exact zeros replaced by ``1e-8``), ``Tstart`` is
    all zeros, ``From`` and ``trans`` are all 1 and ``To`` is all 2.
    """
    covariate_cols = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8']
    features = torch.from_numpy(df[covariate_cols].values).float().to(device)

    # Rescale durations by Tmax and nudge exact zeros to a tiny positive value.
    stop_time = torch.from_numpy(df[["duration"]].values).float().flatten().to(device)
    stop_time = stop_time/Tmax
    stop_time = stop_time.masked_fill(stop_time == 0, 1e-8)

    events = torch.from_numpy(df[["event"]].values).float().flatten().to(device)

    # Constant per-sample columns: start time 0, transition 1, state 1 -> state 2.
    start_time = torch.zeros_like(stop_time)
    from_state = torch.full_like(stop_time, 1, dtype=torch.long)
    to_state = torch.full_like(stop_time, 2, dtype=torch.long)
    transition = torch.full_like(stop_time, 1, dtype=torch.long)

    dataset = TensorDataset(features, start_time, stop_time, from_state, to_state, transition, events)
    return DataLoader(dataset, batch_size=batchsize, shuffle=True)

In [17]:
from sklearn.model_selection import train_test_split

def odesurv_manual_benchmark(df_train, df_test,config,name):
    """Train a SurvNODE model with early stopping and score it on ``df_test``.

    ``df_train`` is split 80/20 (stratified on ``event``) into training and
    validation parts. Durations are rescaled by the training maximum divided
    by ``config["multiplier"]``. Training runs for at most 1000 epochs;
    ``EarlyStopping`` monitors the mean validation loss and writes
    ``name + '_checkpoint.pt'``, which is reloaded before the test metrics are
    computed.

    Parameters
    ----------
    df_train, df_test : DataFrames with columns ``x0``..``x8``, ``duration``
        and ``event`` (as required by ``make_dataloader``).
    config : dict with keys ``lr``, ``weight_decay``, ``num_latent``,
        ``encoder_neurons``, ``num_encoder_layers``, ``encoder_dropout``,
        ``odefunc_neurons1``, ``num_odefunc_layers1``, ``odefunc_neurons2``,
        ``num_odefunc_layers2``, ``batch_size`` (fraction of the training
        set), ``multiplier``, ``mu``, ``softplus_beta``, ``scheduler_epoch``,
        ``scheduler_gamma`` and ``patience``.
    name : prefix of the checkpoint file used by ``EarlyStopping``.

    Returns
    -------
    tuple
        ``(val_loss, conc, ibs, ibnll)``. ``val_loss`` is the lowest mean
        validation loss seen during training, i.e. the loss of the checkpoint
        that is restored for testing (assuming ``EarlyStopping`` saves on
        every improvement, as its log messages indicate). The remaining three
        are test-set metrics averaged over test batches.
    """
    torch.cuda.empty_cache()
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.loc[:,"event"])

    Tmax = df_train["duration"].max()
    time_scale = Tmax/config["multiplier"]

    # batch_size is a fraction of the training set; never let it round to 0,
    # which DataLoader rejects.
    train_batch = max(1, int(len(df_train)*config["batch_size"]))
    train_loader = make_dataloader(df_train,time_scale,train_batch)
    val_loader = make_dataloader(df_val,time_scale,len(df_val))
    test_loader = make_dataloader(df_test,time_scale,len(df_test))

    # Must equal the number of covariate columns selected by make_dataloader.
    num_in = 9
    num_latent = config["num_latent"]
    layers_encoder =  [config["encoder_neurons"]]*config["num_encoder_layers"]
    dropout_encoder = [config["encoder_dropout"]]*config["num_encoder_layers"]
    layers_odefunc1 =  [config["odefunc_neurons1"]]*config["num_odefunc_layers1"]
    layers_odefunc2 =  [config["odefunc_neurons2"]]*config["num_odefunc_layers2"]

    # Two states with a single allowed transition (first -> second); NaN marks
    # entries without a transition.
    trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

    encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
    odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc1,layers_odefunc2,config["softplus_beta"]).to(device)
    block = ODEBlock(odefunc).to(device)
    odesurv = SurvNODE(block,encoder).to(device)

    optimizer = torch.optim.Adam(odesurv.parameters(), weight_decay = config["weight_decay"], lr=config["lr"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["scheduler_gamma"], patience=config["scheduler_epoch"])

    # Weight vector over the two states passed to measures(): all mass on the first.
    initial_state = torch.tensor([1.,0.],device=device)

    early_stopping = EarlyStopping(name=name,patience=config["patience"], verbose=True)
    best_lossval = float("inf")
    for i in tqdm(range(1000)):
        odesurv.train()
        for ds in train_loader:
            myloss,_,_ = loss(odesurv,*ds,mu=config["mu"])
            optimizer.zero_grad()
            myloss.backward()
            optimizer.step()

        odesurv.eval()
        with torch.no_grad():
            lossval,conc,ibs,ibnll = 0., 0., 0., 0.
            # Loss and metrics are computed inside the same pass so that the
            # shuffling loader is iterated exactly once per epoch.
            for ds in val_loader:
                batch_loss,_,_ = loss(odesurv,*ds,mu=config["mu"])
                lossval += batch_loss.item()
                batch_conc,batch_ibs,batch_ibnll = measures(odesurv,initial_state,*ds,multiplier=config["multiplier"])
                conc += batch_conc
                ibs += batch_ibs
                ibnll += batch_ibnll
            mean_lossval = lossval/len(val_loader)
            best_lossval = min(best_lossval, mean_lossval)
            early_stopping(mean_lossval, odesurv)
            scheduler.step(mean_lossval)

            # The reported train loss is that of the last mini-batch of the epoch.
            print("it: "+str(i)+", train loss="+str(myloss.item())+", validation loss="+str(mean_lossval)+", c="+str(conc/len(val_loader))+", ibs="+str(ibs/len(val_loader))+", ibnll="+str(ibnll/len(val_loader)))

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Restore the weights saved by EarlyStopping before scoring the test set.
    odesurv.load_state_dict(torch.load(name+'_checkpoint.pt'))

    odesurv.eval()
    with torch.no_grad():
        conc,ibs,ibnll = 0., 0., 0.
        for ds in test_loader:
            batch_conc,batch_ibs,batch_ibnll = measures(odesurv,initial_state,*ds,multiplier=config["multiplier"])
            conc += batch_conc
            ibs += batch_ibs
            ibnll += batch_ibnll
    # Return the validation loss of the restored checkpoint rather than the
    # loss of whichever epoch happened to run last.
    return best_lossval, conc/len(test_loader), ibs/len(test_loader), ibnll/len(test_loader)

In [18]:
from sklearn.model_selection import StratifiedKFold
from pycox import datasets

# 5-fold stratified cross-validation of SurvNODE on the METABRIC dataset.
kfold = StratifiedKFold(5, shuffle=True)
df_all = datasets.metabric.read_df()
gen = kfold.split(df_all.iloc[:,df_all.columns.values!="event"],df_all.loc[:,"event"])

# Fixed hyperparameters used for every fold.
config = {
    "lr": 5e-4,
    "weight_decay": 1e-4,
    "num_latent": 200,
    "encoder_neurons": 200,
    "num_encoder_layers": 2,
    "encoder_dropout": 0.,
    "odefunc_neurons1": 1000,
    "num_odefunc_layers1": 3,
    "odefunc_neurons2": 1000,
    "num_odefunc_layers2": 3,
    "batch_size": 1/3,
    "multiplier": 1.,
    "mu": 1e-4,
    "softplus_beta": 1.,
    "scheduler_epoch": 5,
    "scheduler_gamma": 0.1,
    "patience": 20
}

# One [conc, ibs, ibnll] row per fold.
odesurv_bench_vals = []
for train_idx, test_idx in gen:
    # NOTE: df_train / df_test stay defined after the loop (last fold) and are
    # read by the hyperparameter-search cell further down.
    df_train = df_all.iloc[train_idx]
    df_test =  df_all.iloc[test_idx]
    # The benchmark returns four values (validation loss first); unpacking
    # only three raised "too many values to unpack".
    val_loss, conc, ibs, ibnll = odesurv_manual_benchmark(df_train,df_test,config,"metabric_test")
    odesurv_bench_vals.append([conc,ibs,ibnll])

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [19]:
# Mean and standard deviation across folds for each metric column.
# Columns follow the order in which rows were appended: c, ibs, ibnll.
bench_array = np.array(odesurv_bench_vals)
for column, label in enumerate(("c", "ibs", "ibnll")):
    fold_values = bench_array[:,column]
    print(label+"="+str(np.mean(fold_values))+"+-"+str(np.std(fold_values)))

c=0.6409015344596731+-0.04692738258148167
ibs=0.16260818148631234+-0.010486489849482484
ibnll=0.4873410756000446+-0.02808576493539076


In [19]:
# Display the raw per-fold [conc, ibs, ibnll] rows collected by the cross-validation loop.
odesurv_bench_vals

[]

In [None]:
from hyperopt import hp
from hyperopt import fmin, tpe, space_eval

# define a search space
# Every hyperopt label matches its dictionary key, so the dictionary returned
# by fmin can be read without guessing which parameter a label belongs to.
args = {
    "lr": hp.choice("lr", [1e-4, 5e-4]),
    "weight_decay": hp.choice("weight_decay", [1e-3, 1e-4]),
    "num_latent": hp.randint('num_latent', 20, 400),
    "encoder_neurons": hp.randint('encoder_neurons', 100, 800),
    "num_encoder_layers": hp.randint("num_encoder_layers", 2, 5),
    "encoder_dropout": 0.,
    "odefunc_neurons1": hp.randint('odefunc_neurons1', 100, 1500),
    # Was labelled "num_encoder_layers1" (copy-paste slip).
    "num_odefunc_layers1": hp.randint("num_odefunc_layers1", 2, 5),
    "odefunc_neurons2": hp.randint('odefunc_neurons2', 100, 1500),
    "num_odefunc_layers2": 3,
    "batch_size": 1/3,
    "multiplier": 1.,
    "mu": 1e-4,
    "softplus_beta": 1.,
    "scheduler_epoch": 5,
    "scheduler_gamma": 0.1,
    "patience": 20
}

# define an objective function
def objective(args):
    """Train with one sampled configuration and return its validation loss.

    ``args`` is a concrete configuration drawn from the search space (it
    shadows the module-level ``args`` inside this function).

    NOTE(review): ``df_train`` and ``df_test`` are notebook globals left over
    from the last fold of the cross-validation cell, so that cell has to run
    first. Test metrics are computed but ignored by the search.
    """
    lossval, conc, ibs, ibnll = odesurv_manual_benchmark(df_train,df_test,args,"metabrick_test")
    return lossval

# minimize the objective over the space
best = fmin(objective, args, algo=tpe.suggest, max_evals=100)


  0%|          | 0/100 [00:00<?, ?trial/s, best loss=?]

  0%|          | 0/1000 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.308389).  Saving model ...
it: 0, train loss=0.3627825975418091, validation loss=0.30838891863822937, c=0.47419365009301007, ibs=0.18561537486190832, ibnll=0.5494290075757851
Validation loss decreased (0.308389 --> 0.290681).  Saving model ...
it: 1, train loss=0.2873215079307556, validation loss=0.29068130254745483, c=0.562379796323738, ibs=0.17089219090813973, ibnll=0.5136373003772895
Validation loss decreased (0.290681 --> 0.286903).  Saving model ...
it: 2, train loss=0.28436478972435, validation loss=0.2869032025337219, c=0.5320175300312135, ibs=0.17096519199344704, ibnll=0.5147577613136745
Validation loss decreased (0.286903 --> 0.284641).  Saving model ...
it: 3, train loss=0.3154808282852173, validation loss=0.2846408784389496, c=0.5595106725100104, ibs=0.17261255066958203, ibnll=0.5198031030113369
EarlyStopping counter: 1 out of 20                     
it: 4, train loss=0.2859744727611542, validation loss=0.2853977382183075, c=0.55175457956

  0%|          | 0/1000 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.449379).  Saving model ...                        
it: 0, train loss=0.444072961807251, validation loss=0.4493785798549652, c=0.41538414278253927, ibs=0.251535327312559, ibnll=0.6931567216709986
Validation loss decreased (0.449379 --> 0.437089).  Saving model ...                   
it: 1, train loss=0.4375377297401428, validation loss=0.4370894432067871, c=0.45163272202254784, ibs=0.24642078960098, ibnll=0.6816585727867762
Validation loss decreased (0.437089 --> 0.423800).  Saving model ...                   
it: 2, train loss=0.4299686551094055, validation loss=0.42380020022392273, c=0.5038859705710688, ibs=0.2407520541315941, ibnll=0.6690452610944831
Validation loss decreased (0.423800 --> 0.410172).  Saving model ...                   
it: 3, train loss=0.39993271231651306, validation loss=0.41017237305641174, c=0.5127330814364267, ibs=0.2347875304024837, ibnll=0.6559099995152446
Validation loss decreased (0.410172 --> 0.396480).  Saving model ..

  0%|          | 0/1000 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.463742).  Saving model ...                        
it: 0, train loss=0.43851351737976074, validation loss=0.4637424647808075, c=0.5331164896839823, ibs=0.2406132245347525, ibnll=0.6665694489603302
Validation loss decreased (0.463742 --> 0.454377).  Saving model ...                   
it: 1, train loss=0.45083343982696533, validation loss=0.4543765187263489, c=0.5512293121591888, ibs=0.23701892737913066, ibnll=0.6585858546660246
Validation loss decreased (0.454377 --> 0.444936).  Saving model ...                   
it: 2, train loss=0.43320485949516296, validation loss=0.44493597745895386, c=0.5149080008928856, ibs=0.23333273038368618, ibnll=0.6504667775079991
Validation loss decreased (0.444936 --> 0.435389).  Saving model ...                   
it: 3, train loss=0.4173411428928375, validation loss=0.43538936972618103, c=0.4908957556044517, ibs=0.229558528239417, ibnll=0.642215920684109
Validation loss decreased (0.435389 --> 0.425885).  Saving mode

In [None]:
# `best` holds the raw hyperopt result (hp.choice entries are option indices);
# space_eval maps it back onto concrete values of the search space. The space
# is named `args` in this notebook -- `space` was never defined.
print(best)
print(space_eval(args, best))