# Causal Neural Survival Clustering on METABRIC

In this notebook, we apply Causal Neural Survival Clustering (CNSC) to the METABRIC dataset.

In [None]:
import sys
# Add the parent directory to the module search path before the `cnsc` imports
# (presumably the package lives one level above this notebook — TODO confirm).
sys.path.append('../')

### Load the Dataset

In [None]:
from cnsc.datasets import load_dataset
import pandas as pd

In [None]:
# Fetch the dataset: covariates (x), treatment indicator (a), times (t),
# event indicators (e) and covariate names (col), as used in the cells below.
dataset_name = 'METABRIC'
x, a, t, e, col = load_dataset(dataset_name)

In [None]:
# Wrap the raw outputs in pandas containers so rows can later be selected by index label
x = pd.DataFrame(x, columns = col)
a, t, e = (pd.Series(values) for values in (a, t, e))

### Compute horizons at which we evaluate the performance of CNSC

Survival predictions are issued at certain time horizons. Here we evaluate the predictions of CNSC
at the 25th, 50th and 75th quantiles of the observed event times, as is standard practice in survival analysis.

In [None]:
# Seed numpy and torch with the same value for reproducibility
import numpy as np
import torch

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Evaluation horizons: quantiles of the times whose event indicator is non-zero
horizons = [0.25, 0.5, 0.75]
observed_event_times = t[e != 0]
times = np.quantile(observed_event_times, horizons).tolist()

In [None]:
# For every treatment arm, print the percentage of patients with each event type
# observed before each horizon, then the arm totals, and finally the overall totals.
line_format = '\t {:.2f} % observed risk {}'
risks = np.unique(e)

for treat in np.unique(a):
    in_arm = (a == treat)
    e_arm, t_arm = e[in_arm], t[in_arm]
    print('-' * 42)
    for time in times:
        print('At time {:.2f} months'.format(time))
        for risk in risks:
            print(line_format.format(100 * ((e_arm == risk) & (t_arm < time)).mean(), risk))
    print('Total')
    for risk in risks:
        print(line_format.format(100 * (e_arm == risk).mean(), risk))

print('-' * 42)
print('Overall')
for risk in risks:
    print(line_format.format(100 * (e == risk).mean(), risk))

### Splitting the data into train, test and validation sets

We will train CNSC on 80% of the data and report performance on the remaining 20% held-out test set. Within that 80%, 10% is set aside for the early-stopping criterion and another 10% for model selection.

In [None]:
from sklearn.model_selection import train_test_split

# The three successive splits share one seed so the partition is reproducible.
split_seed = 42

# 1) hold out 20% of the rows as test set
train, test = train_test_split(x.index, test_size = 0.2, random_state = split_seed)
# 2) carve 20% out of the remaining rows ...
train, val  = train_test_split(train, test_size = 0.2, random_state = split_seed)
# 3) ... and halve it: `val` is used for model selection, `dev` for early stopping
val, dev    = train_test_split(val, test_size = 0.5, random_state = split_seed)

In [None]:
# Largest time observed in the training split. It is computed once and frozen here,
# so that a later re-assignment of `t` cannot silently change (or break) the scaling.
t_train_max = t.loc[train].max()

def minmax(values):
    """Rescale times by the largest training time (training times end up <= 1)."""
    return values / t_train_max

t_ddh = minmax(t)
times_ddh = minmax(np.array(times))

### Setting the parameter grid

Let's set up the parameter grid used to tune the hyper-parameters.

In [None]:
from sklearn.model_selection import ParameterSampler

In [None]:
# Candidate hidden-layer architectures, shared by the two layer hyper-parameters
layers = [[50, 50], [50, 50, 50]]

param_grid = dict([
    ('layers_surv', layers),
    ('k', [3]),
    ('representation', [10]),
    ('layers', layers),
    ('act', ['Tanh']),
])

# Draw 3 configurations from the grid with a fixed seed
n_configurations = 3
params = ParameterSampler(param_grid, n_configurations, random_state = 42)

### Model Training and Selection

In [None]:
from cnsc import CausalNeuralSurvivalClustering

In [None]:
def _as_arrays(index):
    """Return the (x, t, e, a) numpy arrays for the rows selected by `index`, with normalised times."""
    return x.loc[index].values, t_ddh.loc[index].values, e.loc[index].values, a.loc[index].values

models = []
for param in params:
    print(param)

    # Every candidate is fitted on the same training rows with the same early-stopping set
    model = CausalNeuralSurvivalClustering(**param, correct = True, multihead = False)
    model.fit(*_as_arrays(train), n_iter = 1000, bs = 250, lr = 0.001, val_data = _as_arrays(dev))

    # Score the candidate on the model-selection split
    nll = model.compute_nll(*_as_arrays(val))

    # Keep the score next to the fitted model
    models.append([nll, model])


Discrimination should be 0 when negative gamma as it is possible to predict it given the covariates

In [None]:
# Keep the candidate with the lowest negative log-likelihood on the selection split
best_model = min(models, key = lambda scored: scored[0])
_, model = best_model

### Evaluation

We evaluate the performance of CNSC in its discriminative ability (Time Dependent Concordance Index and Cumulative Dynamic AUC) as well as its Brier Score on the **factual** distribution.

In [None]:
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

In [None]:
def _as_survival_array(index):
    """Build the structured (event, time) array for the rows in `index`; only events equal to 1 count as events."""
    return np.array([(e.loc[i] == 1, t.loc[i]) for i in index],
                    dtype = [('e', bool), ('t', float)])

# Factual predictions: survival under the treatment each test patient actually received
out_survival = model.predict_survival(x.loc[test].values, times_ddh.tolist(), a.loc[test].values)
out_risk = 1 - out_survival

# Evaluation in the context of competing risks
et_train = _as_survival_array(train)
et_test = _as_survival_array(test)

# Only keep test patients whose time is below the largest training time
selection = (t.loc[test] < t.loc[train].max())
et_eval = et_test[selection]

cis = [concordance_index_ipcw(et_train, et_eval, out_risk[:, i][selection], horizon_time)[0]
       for i, horizon_time in enumerate(times)]
brs = brier_score(et_train, et_eval, out_survival[selection], times)[1]
roc_auc = [cumulative_dynamic_auc(et_train, et_eval, out_risk[:, i][selection], horizon_time)[0]
           for i, horizon_time in enumerate(times)]

for i, horizon in enumerate(horizons):
    print(f"For {horizon} quantile,")
    print("TD Concordance Index:", cis[i])
    print("Brier Score:", brs[i])
    print("ROC AUC ", roc_auc[i][0], "\n")

##  Treatment effect evaluation

In this section, we evaluate how good the treatment effect estimation is. We display the KM estimate and the estimates of the model's clusters.

In [None]:
# Regular grid of evaluation times between 0 and the largest time in the data
n_eval_points = 100
eval_times = np.linspace(0, t.max(), n_eval_points)

# Same grid on the normalised time scale
norm_eval_times = minmax(eval_times)

# Spacing between two consecutive evaluation times
delta = eval_times[1] - eval_times[0]

In [None]:
# Cluster assignment weights predicted for every test patient, indexed by patient
test_alphas = model.predict_alphas(x.loc[test].values)
alphas = pd.DataFrame(test_alphas, index = test)

In [None]:
# Predicted survival of every test patient on the evaluation grid, under each treatment value
survival_by_arm = {}
for value, treatment in enumerate(['untreated', 'treated']):
    predicted = model.predict_survival(x.loc[test].values, norm_eval_times.tolist(), a = value)
    survival_by_arm[treatment] = pd.DataFrame(predicted, columns = eval_times, index = test)
estimated_survival = pd.concat(survival_by_arm, axis = 1, names = ['Treatment'])

# Cumulative incidence and its difference between the two arms (untreated minus treated)
estimated_cif = 1 - estimated_survival
estimated_rmse = estimated_cif[('untreated',)] - estimated_cif[('treated',)]

# Cluster-level treatment effect: one row per cluster, one column per evaluation time
cluster_effects = model.treatment_effect_cluster(norm_eval_times.tolist())
estimated_cluster_treatment = pd.DataFrame(cluster_effects.T, columns = eval_times)

### Population level

Estimate the population level treatment effect

In [None]:
# `plt` is used by the plotting cells from here on but was never imported in this notebook
import matplotlib.pyplot as plt

# Mean effect over the test patients; `std` holds the half-width of the band,
# i.e. 1.96 x the standard error of the mean at each evaluation time
mean, std = estimated_rmse.mean(0), 1.96 * estimated_rmse.std(0) / np.sqrt(len(estimated_rmse))
ax = mean.rename('Estimate').plot(ls = '-.')
# Shade the band in the same colour as the line that was just drawn
plt.fill_between(mean.index, mean + std, mean - std, alpha = 0.3, color = ax.get_lines()[-1].get_color())

plt.ylabel('Treatment effect')
plt.title('Mean outcome')
plt.grid(alpha = 0.3)
plt.legend()
plt.show()

### Feature importance

Estimate which features most impact the model assignment through a permutation test.

In [None]:
# Use the normalised times (t_ddh): the model was fitted and scored on that scale,
# so raw times would be inconsistent with what it was trained on.
importance, confidence = model.feature_importance(x.loc[test].values, t_ddh.loc[test].values, e.loc[test].values, a.loc[test].values)

In [None]:
# Bar heights and error bars are both shown in percent.
# NOTE(review): this assumes `confidence` is expressed in the same unit as `importance`
# (hence the same x100 scaling) — confirm against `feature_importance`.
importance_table = pd.DataFrame({
    'Value': 100 * np.array(list(importance.values())),
    'Conf': 100 * np.array(list(confidence.values())),
}, index = col)
importance_table.sort_values('Value').plot.bar(yerr = 'Conf')
plt.ylabel('% change in NLL')
plt.xlabel('Covariates')
plt.grid(alpha = 0.3)

### Cluster level

Analyse the clusters by displaying their treatment effects and their differences.

In [None]:
# Reload the covariates with standardisation disabled, for the cluster description below.
# Only `x` and `col` are refreshed: the reloaded treatment, time and event values are discarded
# so that the `a`, `t` and `e` Series built earlier stay intact and earlier cells can be re-run.
# NOTE(review): this call passes path = 'data/' whereas the first load used the default path — confirm.
x, _a_reloaded, _t_reloaded, _e_reloaded, col = load_dataset('METABRIC', path = 'data/', standardisation = False)
x = pd.DataFrame(x, columns = col)

In [None]:
# Hard assignment: position of the largest alpha for every test patient
# (identical for all clusters, so it is computed once before the loop)
cluster_of = alphas.apply(lambda row: row.argmax(), 1)

for k in range(model.torch_model.k):
    alphas_max = (cluster_of == k)
    # Legend entry: cluster id, number of patients and sum of their treatment indicators
    label = 'Cluster {} (n = {}, a = {})'.format(k, alphas_max.sum(), a[test][alphas_max].sum())
    ax = estimated_cluster_treatment.loc[k].rename(label).plot()
    # Average individual effect of the patients assigned to the cluster, in the cluster's colour
    cluster_colour = ax.lines[-1].get_color()
    estimated_rmse[alphas_max].mean(0).rename('Average Effect').plot(ls = '--', color = cluster_colour)
plt.ylabel('Treatment effect')
# NOTE(review): this label says years whereas the horizon report earlier in the notebook prints months — confirm the unit of `t`.
plt.xlabel('Time (in years)')
plt.grid(alpha = 0.3)
plt.legend()
plt.show()

In [None]:
from scipy.stats import kruskal

# Hard cluster assignment of every test patient (column label of the largest alpha)
assignment = alphas.loc[test].idxmax(1)
x_test = x.loc[test]

# "mean (std)" of every covariate within each cluster; rows are covariates, columns are clusters
results = x_test.groupby(assignment).apply(lambda group: pd.Series(["{:.3f} ({:.3f})".format(mean, std) for mean, std in zip(group.mean(), group.std())], index = group.columns)).T

# Compare every non-empty cluster (not just the first two): the number of clusters is a
# hyper-parameter of the model. The list is taken before the 'P-Value' column is added.
clusters = results.columns.tolist()
if len(clusters) > 1:
    results['P-Value'] = [kruskal(*[x_test[column][assignment == cluster] for cluster in clusters]).pvalue for column in results.index]
else:
    # Kruskal-Wallis needs at least two groups
    results['P-Value'] = float('nan')
results.sort_values('P-Value')