## Demo for CD-CVAE

This demo Jupyter Notebook helps you reproduce the experiments in Censor-dependent Variational Inference. It includes:

- Script for training and tuning a state-of-the-art survival model.

- Script for training and tuning CD-CVAE model and the variants.


### 1. Load and process the dataset via the customized data loader

In [None]:
# Packages to import
import os
import sys
import math
import numpy as np
import pandas as pd
import torch
import pandas as pd
import sys
import argparse

import warnings
warnings.filterwarnings('ignore')
# Make the project root (two directory levels above this file) importable so
# that the `utils` and `model` packages used below can be resolved.
try:
    _base_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Jupyter/IPython kernels do not define __file__; fall back to the
    # directory the notebook kernel is running in.
    _base_dir = os.getcwd()
project_dir = os.path.abspath(os.path.join(_base_dir, '../../'))
# Guard against piling up duplicate entries when the cell is re-executed.
if project_dir not in sys.path:
    sys.path.append(project_dir)

from auton_survival.datasets import load_dataset #https://github.com/autonlab/auton-survival

from utils.override_functions import plot_performance_metrics
from utils.preprocess import pre_process
from utils.data_load import data_loader
from pycox.evaluation import EvalSurv #https://github.com/havakv/pycox/blob/master/pycox/evaluation/eval_surv.py
from auton_survival.estimators import SurvivalModel
from utils.override_functions import survival_regression_metric_modified
from sklearn.model_selection import ParameterGrid


# Experiment configuration: dataset identifier plus the metric names handed
# to the evaluation helper for validation-time and test-time scoring.
name = 'SUPPORT'
validation_metric = "ctdpycox"
test_metric = "ctd"

# Load the raw outcomes/features, then split them into train/validation/test.
outcomes, features = data_loader(name)
print(f"_______Preprocessing{name!s} dataset started____________")
splits = pre_process(features, outcomes, dataset=name, to_numpy=False, log=False)
x_tr, y_tr, x_val, y_val, x_te, y_te = splits
print(f"_______Preprocessing{name!s}  finished____________")

# Evaluation grid: ten evenly spaced quantiles (0.1 through 1.0) of the
# training times restricted to rows whose event indicator equals 1.
observed_mask = y_tr['event'] == 1
observed_times = y_tr['time'][observed_mask]
quantile_levels = np.linspace(0.1, 1, 10)
times = np.quantile(observed_times, quantile_levels).tolist()



### 2. Training and tuning Deep Survival Machine (DSM)

In [None]:
modelname = "dsm"

# Hyperparameter grid explored for the model named by `modelname`.
param_grid = {'k' : [3,  6],
              'distribution' : ['LogNormal', 'Weibull'],
              'learning_rate' : [ 1e-4, 1e-3],
              'layers' : [ [100], [100, 100] ]
             }
params = ParameterGrid(param_grid)

# Concordance/AUC-type scores improve as they grow, whereas Brier-type scores
# improve as they shrink, so the selection direction depends on the metric.
# NOTE(review): assumes survival_regression_metric_modified returns these
# scores on their natural scale (not negated) -- confirm against the helper.
higher_is_better = validation_metric in ("ctd", "ctdpycox", "auc")

models = []
for param in params:
    model = SurvivalModel(modelname, random_seed=20, hyperparams=param)

    # Train this configuration on the training split.
    model.fit(x_tr, y_tr)

    # Predicted survival probabilities for the validation split at `times`.
    predictions_val = model.predict_survival(x_val, times)

    # Score the validation predictions with the configured validation metric.
    metric_val = survival_regression_metric_modified(validation_metric, y_val, predictions_val, times, y_tr)
    models.append([metric_val, model])

# Select the best model based on the mean metric value computed for the
# validation set; np.mean collapses a per-time score to one number and leaves
# a scalar score unchanged. Ties resolve to the first configuration tried.
metric_vals = [float(np.mean(entry[0])) for entry in models]
best_val = max(metric_vals) if higher_is_better else min(metric_vals)
best_idx = metric_vals.index(best_val)
model = models[best_idx][1]

# Obtain survival probabilities for the test set at the 0.75 quantile of the
# test times with event == 1 (the same horizon is listed three times).
# NOTE(review): the repeated horizon looks deliberate (see `#len(times)`
# below) -- confirm what the metric helper expects.
times = np.quantile(y_te['time'][y_te['event']==1], [0.75,0.75,0.75]).tolist()
predictions_te = model.predict_survival(x_te, times)
metric = survival_regression_metric_modified(test_metric,y_te,predictions_te,times,y_tr) #len(times)
print(metric)
print("_______Evaluating the performance on test dataset using "+str(test_metric)+ ", the average value is "+str(round(np.mean(metric),4)))

### 3. Training and tuning Censor-dependent Variational Autoencoders (CD-CVAEs)

In [None]:
# available models
from model.cd_cvae import CDCVAE
from model.cd_diwae import CDDIWAE
from model.cd_iwae import CDIWAE
from model.cvae import CVAE

# Pick CUDA when available, otherwise CPU, and create a seeded NumPy generator.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
rng = np.random.default_rng(seed=1)

# Every layer width below is a multiple (or half) of the input feature count.
n_features = x_tr.shape[-1]
half_features = int(n_features/2)
encoder_sizes = [n_features+1, n_features*12, n_features*24, n_features*6, n_features*2]
decoder_sizes = [n_features*4, n_features*8, n_features*2, half_features, 1]

model = CDCVAE(
    encoder_layer_sizes=encoder_sizes,
    latent_dim=half_features,
    px=n_features,
    decoder_layer_sizes=decoder_sizes,
    sigma_learning="joint",
    primative="gumbel",
    dropout=0.95,
)

# Fit CDCVAE on the training split; the validation split and
# `validation_metric` are handed over as the criterion.
# NOTE(review): `patience` presumably controls early stopping -- confirm in CDCVAE.fit.
fit_options = dict(
    batch_size=200,
    num_epochs=5000,
    learning_rate=0.001,
    criterion=validation_metric,
    patience=1000,
    temperature=0.9,
)
model.fit(train_data=[x_tr, y_tr, x_val, y_val], **fit_options)

# Evaluate CDCVAE on the test split at the 0.75 quantile of the test times
# with event == 1 (alternative: times = np.unique(y_te['time'][y_te['event']==1]).tolist()).
event_times_te = y_te['time'][y_te['event'] == 1]
times = np.quantile(event_times_te, 0.75).tolist()
predictions = model.predict(x_te, times, format="pre", expo=True)
print(predictions)
metric = survival_regression_metric_modified(test_metric, y_te, predictions, times, y_tr)  # len(times)
average_metric = round(np.mean(metric), 4)
print(f"_______Evaluating the performance on test dataset using ,{test_metric!s}, the average value is {average_metric!s}")

If you want to use the simulated dataset on its own, the example below is helpful.

In [None]:
from utils.metrics import iwae_loss_fn, KL_divergence
# Options forwarded to the data loader for the simulated dataset.
# NOTE(review): "clevel" presumably selects the censoring setting -- confirm in data_loader.
kwargs = dict(clevel="all_censor")
outcomes, features, latent_dict = data_loader("SIMULATE", **kwargs)


def _float_tensor(values):
    """Copy *values* into a new float32 torch tensor."""
    return torch.tensor(values, dtype=torch.float32)


# Covariate and observed time become (n, 1) column tensors; the event
# indicator keeps the shape produced by `to_numpy()`.
n = len(features['x'])
x_train = _float_tensor(features["x"].to_numpy()).reshape(n, 1)
y_train = _float_tensor(outcomes["time"].to_numpy()).reshape(n, 1)
e_train = _float_tensor(outcomes["event"].to_numpy())

# Latent mean supplied by the loader, paired with an all-zero log-variance
# tensor that inherits shape, dtype and device from `true_mu`.
true_mu = _float_tensor(latent_dict["mu"])
true_log_var = torch.zeros_like(true_mu)
