# Pycox: DeepSurv Stratified by Batch


In [1]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

from sklearn_pandas import DataFrameMapper
import torch
import torchtuples as tt


os.chdir("../")
from pycox.models.cox import CoxPH, CoxPHStratified
from pycox.evaluation.eval_surv import EvalSurv, EvalSurvStratified
from scr.utils import *
from scr.runDeepSurvModels import *

# Full process

In [5]:
# Select the simulated dataset: files are looked up in `folder` using `keywords`.
folder = 'linear'
keywords = ['moderate', "latest", 'RW']

# Load the data, with test_size=0.2 passed to the loader for the train/test split.
train_df, test_df = load_simulate_survival_data(
    folder=folder,
    keywords=keywords,
    test_size=0.2,
)

# Preview the first rows of the training frame.
train_df.head()

Unnamed: 0,hsa.let.7a.2..1,hsa.let.7a.3..1,hsa.let.7a..2..1,hsa.let.7b.1,hsa.let.7b..1,hsa.let.7c.1,hsa.let.7c..1,hsa.let.7d.1,hsa.let.7d..1,hsa.let.7e.1,...,hsa.miR.96.1,hsa.miR.96..1,hsa.miR.98.1,hsa.miR.98..1,hsa.miR.99a.1,hsa.miR.99a..1,hsa.miR.99b.1,hsa.miR.99b..1,time,status
0,0.161142,16.139914,9.440727,12.377491,5.524234,11.493406,3.315916,13.211118,4.830156,10.275969,...,4.311983,0.000551,11.930738,6.893816,9.471225,1.244932,1.820734,0.499726,17.195448,0
1,0.737665,15.622746,7.350143,12.387646,4.14715,12.739532,3.693167,11.653065,6.348644,10.511325,...,5.211017,0.010937,10.01897,4.504616,12.851019,4.45857,9.536961,4.829052,1.190342,1
2,1.137077,15.838051,7.474055,14.062855,5.124704,13.175972,3.549938,12.224586,5.85129,10.720255,...,6.218995,0.023063,11.03204,5.361192,12.062099,3.112622,9.023745,4.720672,6.049298,1
3,2.152887,18.113934,7.656181,12.87939,4.768,13.147153,3.518407,12.940474,6.602133,11.160047,...,3.165273,0.015713,11.539787,5.979588,13.096312,4.459175,9.323495,5.122238,8.483713,1
4,3.068358,14.510637,7.261556,11.505635,5.917256,12.665515,4.810049,11.739157,6.542291,11.056379,...,17.231062,0.005042,10.274817,5.15452,13.246133,5.570715,11.376455,6.263673,23.514074,1


## Feature transforms


In [3]:
# Survival outcome columns; these are excluded from the covariates and used as labels.
survival_cols = ['time', 'status']

In [4]:
# Split the training frame into train/validation parts (80/20), stratified on
# `status` and with a fixed seed so the split is reproducible.
tr_df, val_df = train_test_split(
    train_df,
    test_size=0.2,
    shuffle=True,
    random_state=42,
    stratify=train_df['status'],
)

# Transform data: every column that is not a survival column is standardized.
# NOTE(review): this includes any saved-index column (e.g. 'Unnamed: 0') if the
# frame carries one — confirm that is intended.
covariate_cols = [col for col in train_df.columns if col not in survival_cols]
standardize = [([col], StandardScaler()) for col in covariate_cols]
# `leave` is not passed to the mapper below; it is kept only so the name stays defined.
leave = [(col, None) for col in survival_cols]
x_mapper = DataFrameMapper(standardize)

# Covariate matrices.
# The scalers are fitted on the training split ONLY. Validation and test data
# are transformed with those same training statistics; refitting on the
# validation split would leak its statistics and leave the test set scaled
# with validation-derived parameters instead of training-derived ones.
x_train = x_mapper.fit_transform(tr_df[covariate_cols]).astype('float32')
x_val = x_mapper.transform(val_df[covariate_cols]).astype('float32')
x_test = x_mapper.transform(test_df[covariate_cols]).astype('float32')


# prepare labels
def get_target(df):
    """Return the ``(time, status)`` columns of ``df`` as a tuple of arrays."""
    return df['time'].values, df['status'].values


y_train = get_target(tr_df)
y_val = get_target(val_df)
t_test, e_test = get_target(test_df)
# Validation data in the (input, target) form passed to `fit` as `val_data`.
val = x_val, y_val

## Neural net

We create a simple MLP with two hidden layers, ReLU activations, batch norm and dropout. 
Here, we just use the `torchtuples.practical.MLPVanilla` net to do this.


In [None]:
# Architecture settings for the MLP.
in_features = x_train.shape[1]  # number of columns in the training matrix
num_nodes = [32, 16]            # sizes of the two hidden layers
out_features = 1
batch_norm = True
dropout = 0.2
output_bias = True

# Build the network with torchtuples' ready-made MLP.
net = tt.practical.MLPVanilla(
    in_features,
    num_nodes,
    out_features,
    batch_norm,
    dropout,
    output_bias=output_bias,
)

## Training the model

To train the model we need to define an optimizer; instead of one from `torch.optim`, we use one from `tt.optim` as it has some added functionality.
We use the `Adam` optimizer with weight decay and, rather than searching with `model.lr_finder`, set the learning rate manually to `1e-3`.

In [None]:
# Adam from torchtuples with weight_decay=0.01; the learning rate is set separately below.
optimizer = tt.optim.Adam(weight_decay=0.01)

# Wrap the network and optimizer in the stratified Cox PH model.
be_model = CoxPHStratified(net, optimizer)

# Set the learning rate manually to 0.001 (no learning-rate finder is run here).
be_model.optimizer.set_lr(1e-3)

We include the `EarlyStopping` callback to stop training when the validation loss stops improving. After training, this callback will also load the best performing model in terms of validation loss.

In [None]:
%%time
# Training configuration: mini-batch size, maximum number of epochs, callbacks.
batch_size = 64
epochs = 500
# Early stopping configured with patience=20 and min_delta=5e-2.
callbacks = [tt.callbacks.EarlyStopping(patience=20, min_delta=5e-2)]
verbose = True

# One label per training sample; all ones, so every sample carries the same label.
batch_indices = np.ones(len(y_train[1]))
# NOTE(review): the recorded run of this cell ended with
#   TypeError: forward() missing 1 required positional argument: 'batch_indices'
# Presumably some forward() in the stratified model/loss is not receiving the
# batch labels — possibly because `val` only carries (x_val, y_val). Confirm
# against the signature of CoxPHStratified.fit before re-running.
log = be_model.fit(x_train, y_train,
                batch_indices,
                batch_size,
                epochs,
                callbacks, 
                verbose=verbose,
                val_data=val, val_batch_size=batch_size
                )

TypeError: forward() missing 1 required positional argument: 'batch_indices'