<a href="https://colab.research.google.com/github/ZiyuWang1121/Deep-machine-learning-meets-survival-analysis/blob/main/DeepSurv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DeepSurv

In this notebook we will train the [Cox-PH method](http://jmlr.org/papers/volume20/18-424/18-424.pdf), also known as [DeepSurv](https://bmcmedresmethodol.biomedcentral.com/articles/10.1186/s12874-018-0482-1).

In [None]:
!pip install torchtuples
!pip install pycox

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper

import torch
import torchtuples as tt

from pycox.models import CoxPH
from pycox.evaluation import EvalSurv

from sklearn.model_selection import ParameterSampler

In [None]:
# Seed the global NumPy and PyTorch random generators so reruns draw the same numbers.
np.random.seed(seed=1234)
_ = torch.manual_seed(seed=123)

## 1. Dataset

In [None]:
#pip install scikit-survival

In [None]:
# Mount Google Drive at /content/drive (Colab only) so files stored there can be read.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load the preprocessed table from the mounted Drive folder.
brca_csv_path = '/content/drive/My Drive/3799/brca.csv'
brca = pd.read_csv(brca_csv_path)

In [None]:
from sklearn.decomposition import PCA
# Apply PCA to reduce dimensionality; every column after the first two is used as input.
# NOTE(review): the projection is fit on all rows before the train/val/test split, so
# held-out samples influence it — consider fitting PCA on the training split only.
pca = PCA(n_components=0.99)  # Keep as many components as needed to retain 99% of the variance
x_pca = pca.fit_transform(brca.iloc[:,2:])
x_pca.shape

In [None]:
# Visualize dimension reduction results: first two principal components,
# coloured by survival time. The original coloured by `brca.iloc[:, 2]`, which is
# the first feature column rather than the `time` column.
fig, ax = plt.subplots(figsize=(12, 6))
points = ax.scatter(x_pca[:, 0], x_pca[:, 1], c=brca['time'], cmap='viridis', marker='o')
fig.colorbar(points, ax=ax, label='time')
ax.set_xlabel('Principal component 1')
ax.set_ylabel('Principal component 2')
ax.set_title('PCA visualization of the data')
plt.show()

In [None]:
# Put the first two columns of `brca` back in front of the PCA components.
# The component frame reuses `brca.index` so the column-wise concat lines rows up
# by position even if `brca` does not carry a default 0..n-1 index.
brca_pca = pd.concat(
    [brca.iloc[:, :2], pd.DataFrame(x_pca, index=brca.index)],
    axis=1,
)

### t-SNE

In [None]:
from sklearn.manifold import TSNE

# Nonlinear dimensionality reduction with t-SNE
#tsne = TSNE(n_components=3, random_state=0)
#x_tsne = tsne.fit_transform(brca.iloc[:,2:])

In [None]:
# Visualize dimension reduction results
#plt.figure(figsize=(12, 6))
# component versus survival time
#plt.scatter(x_tsne[:,0], x_tsne[:,1],c=brca.iloc[:,2], cmap='viridis', marker='o')
#plt.colorbar()
#plt.title('t-SNE visualization of the data')
#plt.show()

In [None]:
#x_tsne.shape

In [None]:
# Concatenating the dataframes
#brca_tsne = pd.concat([brca.iloc[:, :2], pd.DataFrame(x_tsne)], axis=1)

## 2. Modelling

### Train, valid, and test set split

In [None]:
# Use the PCA-reduced table as the dataset for everything that follows.
data = brca_pca

In [None]:
#df_train = data
#df_test = df_train.sample(frac=0.2)
#df_train = df_train.drop(df_test.index)

#df_val = df_train.sample(frac=0.2)
#df_train = df_train.drop(df_val.index)

In [None]:
from sklearn.model_selection import train_test_split

# Two successive 80/20 splits, each stratified on `status`:
# first carve off the test set, then split the remainder into train and validation.
split_options = dict(test_size=0.2, random_state=42)
df_train_val, df_test = train_test_split(
    data, stratify=data['status'], **split_options
)
df_train, df_val = train_test_split(
    df_train_val, stratify=df_train_val['status'], **split_options
)

#df_train_val, df_test = train_test_split(data, test_size=0.2, random_state=42, stratify=data['status'])
#df_train, df_val = train_test_split(df_train_val, test_size=0.2, random_state=42, stratify=df_train_val['status'])

### Feature transforms
All variables need to be of type `'float32'`, as this is required by PyTorch.

In [None]:
from sklearn_pandas import DataFrameMapper

# Feature columns are everything after the first two columns of `data`.
all_cols = data.columns[2:].tolist()

# Pair each feature with `None` as its transformer entry.
all_cols_tuples = [(feature, None) for feature in all_cols]

# Mapper that selects exactly those columns.
x_mapper = DataFrameMapper(all_cols_tuples)

In [None]:
# Fit the mapper on the training split only, then reuse it for validation and test.
# Everything is cast to float32.
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val, x_test = (
    x_mapper.transform(split).astype('float32') for split in (df_val, df_test)
)

In [None]:
def get_target(df):
    """Return the `(time, status)` value arrays of `df` as a tuple."""
    return (df['time'].values, df['status'].values)


# Targets for each split
y_train = get_target(df_train)
y_val = get_target(df_val)
y_test = get_target(df_test)
durations_test, events_test = get_target(df_test)

# Validation features and targets bundled as one pair
val = (x_val, y_val)

### Neural net

We create a simple MLP with two hidden layers, ReLU activations, batch norm and dropout.
Here, we just use the `torchtuples.practical.MLPVanilla` net to do this.

Note that we set `out_features` to 1 and disable `output_bias`.

In [None]:
from sklearn.base import BaseEstimator

class DeepSURVSklearnAdapter(BaseEstimator):
    """Scikit-learn style wrapper around a pycox ``CoxPH`` (DeepSurv) model.

    The constructor only stores hyperparameters; the network and the ``CoxPH``
    model are built in :meth:`fit`. Targets are passed everywhere as a pair
    ``(durations, events)`` where ``y[0]`` is the time to event and ``y[1]`` is
    the event indicator.

    Parameters
    ----------
    learning_rate : float
        Learning rate set on the Adam optimizer.
    batch_norm : bool
        Passed to ``MLPVanilla`` as its ``batch_norm`` argument.
    dropout : float
        Passed to ``MLPVanilla`` as its ``dropout`` argument.
    num_nodes : list of int
        Hidden layer sizes passed to ``MLPVanilla``.
    batch_size : int
        Mini-batch size used during training.
    epochs : int
        Maximum number of training epochs.
    """

    def __init__(
        self,
        learning_rate=1e-4,
        batch_norm=True,
        dropout=0.0,
        num_nodes=[32, 32],
        batch_size=128,
        epochs=10,
    ):
        # Store hyperparameters verbatim; `num_nodes` is never mutated by this class.
        self.learning_rate = learning_rate
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.num_nodes = num_nodes
        self.batch_size = batch_size
        self.epochs = epochs

    def fit(self, X, y, val_data=None):
        """Build and train the network, then compute the baseline hazards.

        Parameters
        ----------
        X : array
            Feature matrix; ``X.shape[1]`` sets the network input width.
        y : pair ``(durations, events)``
            Training targets.
        val_data : pair ``(X_val, (durations_val, events_val))``, optional
            When given, it is forwarded to ``CoxPH.fit`` and an
            ``EarlyStopping`` callback is enabled. When omitted, training runs
            for the full ``epochs`` without callbacks, as before.

        Returns
        -------
        self
        """
        self.net_ = tt.practical.MLPVanilla(
            X.shape[1],
            self.num_nodes,
            1,  # single output: the log-risk score
            self.batch_norm,
            self.dropout,
            output_bias=False,  # no bias on the output layer
        )
        self.deepsurv_ = CoxPH(self.net_, tt.optim.Adam)
        self.deepsurv_.optimizer.set_lr(self.learning_rate)

        # CoxPH expects the target as a tuple of two arrays (durations, events).
        y_ = (y[0], y[1])

        # Early stopping monitors validation data, so only enable it when
        # validation data is actually supplied.
        callbacks = [tt.callbacks.EarlyStopping()] if val_data is not None else None
        self.log_ = self.deepsurv_.fit(
            X,
            y_,
            self.batch_size,
            self.epochs,
            callbacks=callbacks,
            verbose=False,
            val_data=val_data,
        )

        # Survival predictions need the baseline hazards; compute them once here so
        # `predict` works right after `fit` instead of only after a scoring call.
        _ = self.deepsurv_.compute_baseline_hazards()

        return self

    def _eval_surv(self, X, y):
        """Build an ``EvalSurv`` object for the predicted survival curves of ``X``."""
        surv = self.deepsurv_.predict_surv_df(X)
        return EvalSurv(
            surv,
            y[0],  # time to event
            y[1],  # event
            censor_surv="km",
        )

    def score(self, X, y):
        """Return the time-dependent concordance index on ``(X, y)``."""
        return self._eval_surv(X, y).concordance_td()

    def brier_score(self, X, y, time_grid):
        """Return ``EvalSurv.brier_score`` evaluated on ``time_grid`` for ``(X, y)``."""
        return self._eval_surv(X, y).brier_score(time_grid)

    def integrated_brier_score(self, X, y, time_grid):
        """Return ``EvalSurv.integrated_brier_score`` over ``time_grid`` for ``(X, y)``."""
        return self._eval_surv(X, y).integrated_brier_score(time_grid)

    def predict(self, X):
        """Return the survival probabilities from ``CoxPH.predict_surv_df`` for ``X``."""
        return self.deepsurv_.predict_surv_df(X)

In [None]:
# Candidate values for each hyperparameter explored by the random search
param_grid = dict(
    learning_rate=[1e-1, 1e-2, 1e-3, 1e-4],
    dropout=[0.0, 0.1, 0.2],
    num_nodes=[[32, 32], [64, 64], [32, 64]],
    batch_size=[128, 256, 512],
    epochs=[512],
)

# How many parameter settings to draw
n_iter = 10

# Sampler that yields `n_iter` parameter dictionaries drawn from the grid
random_search = ParameterSampler(param_grid, n_iter=n_iter, random_state=123)

# Running best, judged by validation C-index
best_score, best_params = float('-inf'), None
separator = "=" * 40

for params in random_search:
    print("Current parameters:", params)
    model = DeepSURVSklearnAdapter(**params).fit(x_train, y_train)
    c_index = model.score(x_val, y_val)
    print("C-index:", round(c_index,3))
    if c_index > best_score:
        best_score, best_params = c_index, params
    print(separator)

print("Best parameters:", best_params)
print("Best C-index:", best_score)

## Evaluation and Prediction

In [None]:
# Refit a fresh model on the training split using the selected hyperparameters.
# `fit` returns the estimator itself, so construction and training can be chained.
best_model = DeepSURVSklearnAdapter(**best_params).fit(x_train, y_train)
best_model

In [None]:
# C-index of the selected model on the data it was trained on.
train_c_index = best_model.score(x_train, y_train)
print("C-index on train data:", round(train_c_index,3))

In [None]:
# Evaluate the best model on the held-out test split (C-index, rounded to 3 decimals for display)
test_c_index = best_model.score(x_test, y_test)
print("C-index on test data:", round(test_c_index,3))

In [None]:
# Predicted survival probabilities for the test set, as returned by the
# model's `predict` (which wraps CoxPH.predict_surv_df)
predictions = best_model.predict(x_test)

In [None]:
# Predicted survival curves for the first five columns of `predictions`
ax = predictions.iloc[:, :5].plot()
ax.set_ylabel('S(t | x)')
_ = ax.set_xlabel('Time')

In [None]:
# Calculate the hazard function.
# -log S(t) is the cumulative hazard; differencing it along the time axis (the rows
# of `predictions`) gives the hazard accumulated between consecutive time points for
# each patient (column). The original used axis=1, which subtracts one patient's
# curve from the next instead of stepping through time.
cumulative_hazard = -np.log(predictions)
hazard_function = cumulative_hazard.diff(axis=0)

In [None]:
# Dashed hazard curves for the first five columns of `hazard_function`
ax = hazard_function.iloc[:, :5].plot(style='--', title='Hazard Functions for the First 5 Patients')
ax.set_ylabel('Hazard Function')
ax.set_xlabel('Time')
plt.show()

In [None]:
# 100 evenly spaced evaluation times spanning the observed test durations
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)

# Brier score of the best model at each grid time, plotted against time
brier_curve = best_model.brier_score(x_test, y_test, time_grid)
_ = brier_curve.plot()

In [None]:
# Integrated Brier score of the best model on the test split over `time_grid`
print("Integrated brier score:",best_model.integrated_brier_score(x_test, y_test, time_grid))