In [None]:
import copy
import logging
import os
import random

import numpy as np
import pandas as pd
import sklearn.metrics
import torch
import tqdm
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.nonparametric import nelson_aalen_estimator
from sksurv.util import Surv
from torch.utils.data import TensorDataset, DataLoader

from survnam.nam import metrics
from survnam.nam import data_utils
from survnam.nam import *

In [None]:
# Module-level hyper-parameters; train_model() and the setup cell read them as globals.
learning_rate = 1e-6  # Hyper-parameter: learning rate handed to the AdamW optimizer.
l2_regularization = 0.00  # Hyper-parameter: L2 weight decay handed to the AdamW optimizer.
dropout = 0.0  # Hyper-parameter: dropout rate (passed to the model as `hidden_dropout`).
feature_dropout = 0.0  # Hyper-parameter: probability with which features are dropped.

training_epochs = 20 # Number of epochs to run training for.
batch_size = 1  # Hyper-parameter: batch size of the training DataLoader.
seed = 42  # Seed used for reproducibility (passed to seed_everything).
n_basis_functions = 1000  # Number of basis functions to use in a FeatureNN for a real-valued feature.
units_multiplier = 2  # Number of basis functions for a categorical feature.

hidden_units = []  # Numbers of neurons for additional hidden layers, e.g. 64,32,32 (entries are cast to int).
log_file = "survnam.log"  # File where log records are stored; a falsy value disables file logging.
shallow_layer = "exu"  # First-layer type: "exu" selects ExULayer, any other value selects ReLULayer.
hidden_layer = "relu"  # Hidden-layer type: "exu" selects ExULayer, any other value selects ReLULayer.

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` module, NumPy's global generator (which
    ``np.random.multivariate_normal`` uses when neighbourhoods are sampled)
    and PyTorch, and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already running
    # interpreter is not changed by this assignment.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Without this the sampled neighbourhoods (and therefore the
    # explanations) differ between otherwise identical runs.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def train_one_epoch(model, criterion, optimizer, data_loader, device, times, nelson_est):
    """Run one optimisation pass over ``data_loader`` and return the summed loss.

    Args:
        model: Module called as ``model(x)``; it must return a 2-tuple whose
            first element is handed to ``criterion``.
        criterion: Called as ``criterion(logits, chf, times, nelson_est, weight)``
            and expected to return a scalar tensor.
        optimizer: Optimizer stepping the parameters of ``model``.
        data_loader: Sized iterable yielding ``(x, chf, weight)`` batches.
        device: Unused; kept so existing callers keep working.
        times: Forwarded unchanged to ``criterion``.
        nelson_est: Forwarded unchanged to ``criterion``.

    Returns:
        Sum of the per-batch ``loss.item()`` values (``0`` for an empty loader).
    """
    pbar = tqdm.tqdm(enumerate(data_loader, start=1), total=len(data_loader))
    total_loss = 0
    for _step, (xk, chf_k, weight_k) in pbar:
        # The model is fed the batch with one extra leading axis.
        xk = torch.unsqueeze(xk, dim=0)
        # Call the module itself rather than .forward() so registered hooks run.
        logits, _ = model(xk)
        loss = criterion(logits, chf_k, times, nelson_est, weight_k)
        # NOTE(review): retain_graph=True looks unnecessary because a fresh graph
        # is built on every iteration; kept until the loss function is confirmed
        # not to reuse graph state between calls.
        loss.backward(retain_graph=True)
        optimizer.step()
        model.zero_grad()
        total_loss += loss.item()
        # Update the description while the bar is still live; refresh=False
        # leaves redraw throttling to tqdm.
        pbar.set_description(f"train | loss = {total_loss:.5f}", refresh=False)
    return total_loss

In [None]:
def train_model(x_train, chfs, weights, device, times, nelson_est):
    """Fit a NeuralAdditiveModel surrogate on one neighbourhood.

    Architecture and optimisation settings come from the module-level
    hyper-parameters (``learning_rate``, ``l2_regularization``, ``dropout``,
    ``feature_dropout``, ``training_epochs``, ``batch_size``,
    ``n_basis_functions``, ``units_multiplier``, ``hidden_units``,
    ``shallow_layer``, ``hidden_layer``).

    Args:
        x_train: Array of inputs; its last axis is the feature axis.
        chfs: Per-sample targets, paired row-wise with ``x_train``.
        weights: Per-sample weights, paired row-wise with ``x_train``.
        device: Torch device on which the model and all tensors are placed.
        times: Converted to a tensor and forwarded to the loss.
        nelson_est: Converted to a tensor and forwarded to the loss.

    Returns:
        The model as it stands after the last training epoch. No validation
        metric is available to pick an earlier epoch, so no snapshot of the
        weights is kept; with ``training_epochs == 0`` the freshly
        initialised model is returned.
    """
    times = torch.tensor(times, device=device)
    nelson_est = torch.tensor(nelson_est, device=device)

    model = NeuralAdditiveModel(
        input_size=x_train.shape[-1],
        shallow_units=data_utils.calculate_n_units(x_train, n_basis_functions, units_multiplier),
        hidden_units=list(map(int, hidden_units)),
        shallow_layer=ExULayer if shallow_layer == "exu" else ReLULayer,
        hidden_layer=ExULayer if hidden_layer == "exu" else ReLULayer,
        hidden_dropout=dropout,
        feature_dropout=feature_dropout).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=l2_regularization)

    criterion = metrics.survnam_loss
    # Multiply the learning rate by 0.995 after every epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.995, step_size=1)

    train_dataset = TensorDataset(torch.tensor(x_train, device=device),
                                  torch.tensor(chfs, device=device),
                                  torch.tensor(weights, device=device))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(training_epochs):
        model.train()
        total_loss = train_one_epoch(model, criterion, optimizer, train_loader, device, times, nelson_est)
        # record the log of training (training loss)
        logging.info(f"epoch {epoch} | train | {total_loss}")
        scheduler.step()  # update the learning rate

    # The former per-epoch deepcopy of the state dict was overwritten
    # unconditionally and then loaded back into the same model, so it changed
    # nothing; it also crashed on load_state_dict(None) when no epoch ran.
    return model

In [None]:
# Fix all random seeds before anything stochastic runs.
seed_everything(seed)

# Always log to the notebook output; additionally log to `log_file` when it is set.
handlers = [logging.StreamHandler()] + ([logging.FileHandler(log_file)] if log_file else [])
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO, handlers=handlers)

# All tensors and models in this notebook are placed on this device.
device = torch.device("cpu")
print("device:", device)
logging.info("load data")

In [None]:
def get_explanation(x, survnam):
    """Return the second output of ``survnam`` for ``x`` as a squeezed NumPy array.

    Args:
        x: NumPy array holding one observation; it is cast to float32.
        survnam: Callable model returning a tuple; element ``[1]`` is used.

    Returns:
        ``survnam(...)[1]`` converted to NumPy with singleton axes removed.
    """
    features = torch.tensor(x.astype("float32"))
    # Pad to three axes, then swap the last two; for a 1-D x of length p the
    # model should receive shape (1, 1, p) per torch.atleast_3d - TODO confirm.
    batch = torch.transpose(torch.atleast_3d(features), 1, 2).to(device)
    # No gradients are needed for a pure forward pass.
    with torch.no_grad():
        contributions = survnam(batch)[1]
    # .cpu() makes the NumPy conversion work when the module-level `device`
    # is not the CPU; it is a no-op otherwise.
    return contributions.detach().cpu().numpy().squeeze()

# `dataset0`

In [None]:
# Load the pre-split train/test tables for dataset 0; the paths are relative
# to the directory the notebook is run from.
dataset0_train = pd.read_csv("../data/exp2_dataset0_train.csv")
dataset0_test = pd.read_csv("../data/exp2_dataset0_test.csv")

In [None]:
# The first five columns are the features; the "event" and "time" columns are
# combined into the structured survival outcome.
X_train = dataset0_train.iloc[:, :5]
X_test = dataset0_test.iloc[:, :5]
y_train = Surv.from_dataframe("event", "time", dataset0_train)
y_test = Surv.from_dataframe("event", "time", dataset0_test)

In [None]:
# Black-box model to be explained: a random survival forest fitted on the
# dataset 0 training split (random_state is fixed for the forest itself).
rsf = RandomSurvivalForest(n_estimators=150, max_depth=12, max_features=3, min_samples_leaf=6, min_samples_split=10, random_state=123)
rsf.fit(X_train, y_train)

In [None]:
# from SurvNAM article
# Per-feature sampling standard deviation: 5% of each feature's (max - min)
# range over the current test split.
_test_summary = X_test.describe()
sds = np.array(0.05 * (_test_summary.loc["max"] - _test_summary.loc["min"]))
def generate_neighbours(ind, n_samples=1000):
    """Sample a Gaussian neighbourhood around one row of the module-level ``X_test``.

    Args:
        ind: Positional row index into ``X_test``.
        n_samples: Number of rows to return, including the observation itself.
            Defaults to 1000, the value that used to be hard-coded.

    Returns:
        Array of shape ``(n_samples, n_features)`` whose first row is the
        observation itself.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")
    x = X_test.iloc[ind].values
    # Independent noise per feature: diagonal covariance built from the
    # module-level `sds`.
    neighbours = np.random.multivariate_normal(x, np.diag(sds**2), n_samples)
    # Row 0 is overwritten with the unperturbed observation so that it can
    # serve as the reference point of the neighbourhood.
    neighbours[0, ] = x
    return neighbours
    
def get_weights(neighbourhood, distance_metric="euclidean"):
    """Weight every row of ``neighbourhood`` by its closeness to the first row.

    Args:
        neighbourhood: 2-D array; row 0 is the reference point.
        distance_metric: Metric name forwarded to
            ``sklearn.metrics.pairwise_distances``.

    Returns:
        ``exp(-d**2 / 0.5)`` for the distance ``d`` of each row to row 0.
    """
    anchor = neighbourhood[0].reshape(1, -1)
    dist_to_anchor = sklearn.metrics.pairwise_distances(
        neighbourhood, anchor, metric=distance_metric
    )
    flat = dist_to_anchor.ravel()
    return np.exp(-(flat**2) / 0.5).squeeze()

In [None]:
def data_to_explanation(model, data, y):
    """Build the surrogate's training targets from a black-box survival model.

    Args:
        model: Object exposing ``predict_cumulative_hazard_function(data)``,
            which returns one callable per row of ``data``.
        data: Samples for which cumulative hazard functions are predicted.
        y: Structured array whose two fields are (event indicator, time), in
            that order.

    Returns:
        Tuple ``(chfs, times_to_provide, nelson_est)``: the predicted cumulative
        hazards on the time grid (offset by 1e-32), the time grid with one
        extra point appended, and the Nelson-Aalen estimate on the grid.
    """
    event_field, time_field = y.dtype.names
    na_times, na_values = nelson_aalen_estimator(y[event_field], y[time_field])
    # The last point of the Nelson-Aalen estimate is dropped from both arrays.
    grid = na_times[:-1]
    baseline = na_values[:-1]
    step_functions = model.predict_cumulative_hazard_function(data)
    # The 1e-32 offset presumably keeps the values strictly positive for the
    # loss - TODO confirm against metrics.survnam_loss.
    chfs = np.array([fn(grid) for fn in step_functions]) + 1e-32
    # One extra time point just past the end of the grid.
    times_to_provide = np.hstack((grid, [grid[-1] + 1e-10]))
    return chfs, times_to_provide, baseline

## RSF

In [None]:
# Fit one SurvNAM surrogate per test observation and keep the RSF explanation.
n_obs = len(X_test)
survnam_explanations = [None] * n_obs
N = 1000  # not read in this cell; kept in case other cells rely on it
distance_metric = "euclidean"
# enumerate() has no length, so the total is passed explicitly; otherwise
# tqdm shows neither progress percentage nor ETA for this long loop.
for i, obs in tqdm.tqdm(enumerate(X_test.values), total=n_obs):
    neighbourhood = generate_neighbours(i)
    weights = get_weights(neighbourhood, distance_metric=distance_metric)
    chfs, times_to_provide, nelson_est = data_to_explanation(rsf, neighbourhood, y_train)
    survnam_model = train_model(neighbourhood.astype("float32"), chfs,
                                weights, device, times_to_provide, nelson_est)
    survnam_explanations[i] = get_explanation(obs, survnam_model)

In [None]:
# Write the collected explanations (one list entry per test observation) to a
# CSV file in the current working directory.
pd.DataFrame(survnam_explanations).to_csv("exp2_survnam_explanations_dataset0_rsf.csv", index=False)

## CPH

In [None]:
# Second black-box model: a Cox proportional hazards model with default
# settings, fitted on the same training split.
cph = CoxPHSurvivalAnalysis()
cph.fit(X_train, y_train)

In [None]:
# Display the fitted Cox model's score on the held-out test split.
cph.score(X_test, y_test)

In [None]:
# Fit one SurvNAM surrogate per test observation and keep the CPH explanation.
n_obs = len(X_test)
survnam_explanations_dataset0_cph = [None] * n_obs
N = 1000  # not read in this cell; kept in case other cells rely on it
distance_metric = "euclidean"
# enumerate() has no length, so the total is passed explicitly; otherwise
# tqdm shows neither progress percentage nor ETA for this long loop.
for i, obs in tqdm.tqdm(enumerate(X_test.values), total=n_obs):
    neighbourhood = generate_neighbours(i)
    weights = get_weights(neighbourhood, distance_metric=distance_metric)
    chfs, times_to_provide, nelson_est = data_to_explanation(cph, neighbourhood, y_train)
    survnam_model = train_model(neighbourhood.astype("float32"), chfs,
                                weights, device, times_to_provide, nelson_est)
    survnam_explanations_dataset0_cph[i] = get_explanation(obs, survnam_model)

In [None]:
# Write the CPH explanations for dataset 0 to a CSV file in the current
# working directory.
pd.DataFrame(survnam_explanations_dataset0_cph).to_csv("exp2_survnam_explanations_dataset0_cph.csv", index=False)

# `dataset1`

In [None]:
# Load the pre-split train/test tables for dataset 1; the paths are relative
# to the directory the notebook is run from.
dataset1_train = pd.read_csv("../data/exp2_dataset1_train.csv")
dataset1_test = pd.read_csv("../data/exp2_dataset1_test.csv")

In [None]:
# NOTE: rebinds X_train / X_test / y_train / y_test, which held dataset 0
# until here; every later cell therefore operates on dataset 1.
# The first five columns are the features; "event" and "time" form the outcome.
X_train = dataset1_train.iloc[:, :5]
X_test = dataset1_test.iloc[:, :5]
y_train = Surv.from_dataframe("event", "time", dataset1_train)
y_test = Surv.from_dataframe("event", "time", dataset1_test)

In [None]:
# Refit the random survival forest on the dataset 1 training split, with the
# same hyper-parameters as for dataset 0; `rsf` is rebound.
rsf = RandomSurvivalForest(n_estimators=150, max_depth=12, max_features=3, min_samples_leaf=6, min_samples_split=10, random_state=123)
rsf.fit(X_train, y_train)

In [None]:
# from SurvNAM article
# Recompute the per-feature sampling standard deviation for the current
# X_test: 5% of each feature's (max - min) range.
_test_summary = X_test.describe()
sds = np.array(0.05 * (_test_summary.loc["max"] - _test_summary.loc["min"]))
def generate_neighbours(ind, n_samples=1000):
    """Sample a Gaussian neighbourhood around one row of the module-level ``X_test``.

    This repeats the definition given earlier in the notebook; ``X_test`` and
    ``sds`` are looked up at call time, so it picks up their current values.

    Args:
        ind: Positional row index into ``X_test``.
        n_samples: Number of rows to return, including the observation itself.
            Defaults to 1000, the value that used to be hard-coded.

    Returns:
        Array of shape ``(n_samples, n_features)`` whose first row is the
        observation itself.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")
    x = X_test.iloc[ind].values
    # Independent noise per feature: diagonal covariance built from the
    # module-level `sds`.
    neighbours = np.random.multivariate_normal(x, np.diag(sds**2), n_samples)
    # Row 0 is overwritten with the unperturbed observation so that it can
    # serve as the reference point of the neighbourhood.
    neighbours[0, ] = x
    return neighbours

def get_weights(neighbourhood, distance_metric="euclidean"):
    """Weight every row of ``neighbourhood`` by its closeness to the first row.

    This repeats the definition given earlier in the notebook.

    Args:
        neighbourhood: 2-D array; row 0 is the reference point.
        distance_metric: Metric name forwarded to
            ``sklearn.metrics.pairwise_distances``.

    Returns:
        ``exp(-d**2 / 0.5)`` for the distance ``d`` of each row to row 0.
    """
    anchor = neighbourhood[0].reshape(1, -1)
    dist_to_anchor = sklearn.metrics.pairwise_distances(
        neighbourhood, anchor, metric=distance_metric
    )
    flat = dist_to_anchor.ravel()
    return np.exp(-(flat**2) / 0.5).squeeze()

## RSF

In [None]:
# Fit one SurvNAM surrogate per test observation and keep the RSF explanation.
n_obs = len(X_test)
survnam_explanations_dataset1_rsf = [None] * n_obs
N = 1000  # not read in this cell; kept in case other cells rely on it
distance_metric = "euclidean"
# enumerate() has no length, so the total is passed explicitly; otherwise
# tqdm shows neither progress percentage nor ETA for this long loop.
for i, obs in tqdm.tqdm(enumerate(X_test.values), total=n_obs):
    neighbourhood = generate_neighbours(i)
    weights = get_weights(neighbourhood, distance_metric=distance_metric)
    chfs, times_to_provide, nelson_est = data_to_explanation(rsf, neighbourhood, y_train)
    survnam_model = train_model(neighbourhood.astype("float32"), chfs,
                                weights, device, times_to_provide, nelson_est)
    survnam_explanations_dataset1_rsf[i] = get_explanation(obs, survnam_model)

In [None]:
# Write the RSF explanations for dataset 1 to a CSV file in the current
# working directory.
pd.DataFrame(survnam_explanations_dataset1_rsf).to_csv("exp2_survnam_explanations_dataset1_rsf.csv",  index=False)

## CPH

In [None]:
# Refit the Cox proportional hazards model on the dataset 1 training split;
# `cph` is rebound.
cph = CoxPHSurvivalAnalysis()
cph.fit(X_train, y_train)

In [None]:
# Fit one SurvNAM surrogate per test observation and keep the CPH explanation.
n_obs = len(X_test)
survnam_explanations_dataset1_cph = [None] * n_obs
N = 1000  # not read in this cell; kept in case other cells rely on it
distance_metric = "euclidean"
# enumerate() has no length, so the total is passed explicitly; otherwise
# tqdm shows neither progress percentage nor ETA for this long loop.
for i, obs in tqdm.tqdm(enumerate(X_test.values), total=n_obs):
    neighbourhood = generate_neighbours(i)
    weights = get_weights(neighbourhood, distance_metric=distance_metric)
    chfs, times_to_provide, nelson_est = data_to_explanation(cph, neighbourhood, y_train)
    survnam_model = train_model(neighbourhood.astype("float32"), chfs,
                                weights, device, times_to_provide, nelson_est)
    survnam_explanations_dataset1_cph[i] = get_explanation(obs, survnam_model)

In [None]:
# Write the CPH explanations for dataset 1 to a CSV file in the current
# working directory.
pd.DataFrame(survnam_explanations_dataset1_cph).to_csv("exp2_survnam_explanations_dataset1_cph.csv",  index=False)