# Pycox: DeepSurv Stratified by Batch


In [1]:
import os
os.getcwd()

'/home/nfs/dengy/dl-survival-miRNA'

In [2]:
import os
import numpy as np
import torch
import torchtuples as tt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn_pandas import DataFrameMapper

# os.chdir("dl-survival-miRNA") 
# os.chdir("../")
from pss.pycox.models import CoxPH, CoxPHStratified, StratifiedDataset
from pss.pycox.evaluation.eval_surv import EvalSurv
from pss.utils import load_simulate_survival_data
from pss.run_models import DeepSurvPipeline, train_over_subsets

## *Test: Debugging*

### *Network code*

In [None]:
import numpy as np
import torch
from torch import Tensor

def cox_ph_loss_sorted(log_h: Tensor, events: Tensor, eps: float = 1e-7) -> Tensor:
    r"""Negative Cox partial log-likelihood for input sorted by descending duration.

    Requires the input to be sorted by descending duration time.
    See DatasetDurationSorted.

    We calculate the negative log of $(\frac{h_i}{\sum_{j \in R_i} h_j})^d$,
    where h = exp(log_h) are the hazards and R is the risk set, and d is event.

    We just compute a cumulative sum, and not the true Risk sets. This is a
    limitation, but simple and fast.

    Arguments:
        log_h {torch.Tensor} -- Log hazard predictions; flattened to 1-D.
        events {torch.Tensor} -- Event indicators (bool or numeric); flattened to 1-D.
        eps {float} -- Added to the cumulative hazard sum before taking the log.

    Returns:
        torch.Tensor -- Scalar loss averaged over the number of events. When
        there are no events, a zero that is still computed from `log_h`.
    """
    if events.dtype is torch.bool:
        events = events.float()
    events = events.view(-1)
    log_h = log_h.view(-1)
    n_events = events.sum()
    if n_events == 0:
        # update 08/11/25: safe dummy loss. Multiplying by zero (instead of
        # returning a constant) keeps the result connected to log_h and avoids
        # the 0/0 division below.
        return log_h.sum() * 0.0

    # Subtract the maximum before exponentiating and add it back after the log
    # so that exp() cannot overflow.
    gamma = log_h.max()
    log_cumsum_h = log_h.sub(gamma).exp().cumsum(0).add(eps).log().add(gamma)
    return - log_h.sub(log_cumsum_h).mul(events).sum().div(n_events)


def cox_ph_loss(log_h: Tensor, durations: Tensor, events: Tensor, eps: float = 1e-7) -> Tensor:
    r"""Loss for CoxPH model. If data is sorted by descending duration, see `cox_ph_loss_sorted`.

    We calculate the negative log of $(\frac{h_i}{\sum_{j \in R_i} h_j})^d$,
    where h = exp(log_h) are the hazards and R is the risk set, and d is event.

    We just compute a cumulative sum, and not the true Risk sets. This is a
    limitation, but simple and fast.

    Arguments:
        log_h {torch.Tensor} -- Log hazard predictions, one row per instance.
        durations {torch.Tensor} -- Duration times, one per instance.
        events {torch.Tensor} -- Event indicators, one row per instance.
        eps {float} -- Forwarded to `cox_ph_loss_sorted`.

    Returns:
        torch.Tensor -- The value returned by `cox_ph_loss_sorted` on the
        re-ordered input.
    """
    # Flatten first: Tensor.sort works along the last dimension, so a (n, 1)
    # durations tensor would otherwise yield an all-zero "permutation".
    idx = durations.reshape(-1).sort(descending=True)[1]
    events = events[idx]
    log_h = log_h[idx]
    return cox_ph_loss_sorted(log_h, events, eps)


####### [UPDATE] 07/07/2025
def stratified_cox_ph_loss(log_h: Tensor, durations: Tensor, events: Tensor, batch_indices: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Stratified CoxPH loss: the partial likelihood is evaluated separately
    within each stratum (batch label) and the per-stratum losses are summed.

    Arguments:
        log_h {torch.Tensor} -- Log hazard predictions for each instance.
        durations {torch.Tensor} -- Duration times for each instance.
        events {torch.Tensor} -- Event indicators (1 if event, 0 if censored).
        batch_indices {torch.Tensor} -- Batch (stratum) labels for each instance.
        eps {float} -- Small epsilon for numerical stability.

    Returns:
        torch.Tensor -- Sum over strata of the loss returned by
        `cox_ph_loss_sorted` for that stratum. Strata without any event are
        skipped; if every stratum is skipped, a zero computed from `log_h`
        is returned.
    """
    # Work on flat per-instance vectors so boolean masking and argsort act on
    # samples even when the network output arrives as (n, 1).
    log_h = log_h.reshape(-1)
    durations = durations.reshape(-1)
    events = events.reshape(-1)
    batch_indices = batch_indices.reshape(-1)

    stratum_losses = []
    for batch in torch.unique(batch_indices):
        # Select data for the current stratum (never empty: the label comes
        # from torch.unique of the same tensor).
        mask = (batch_indices == batch)
        events_batch = events[mask]
        if events_batch.sum() == 0:
            continue  # a stratum with no events has no partial-likelihood term

        # Sort by descending durations, as cox_ph_loss_sorted requires
        idx = torch.argsort(durations[mask], descending=True)
        stratum_losses.append(
            cox_ph_loss_sorted(log_h[mask][idx], events_batch[idx], eps)
        )

    if not stratum_losses:
        # Zero that is still computed from log_h (same dummy loss as
        # cox_ph_loss_sorted uses when there are no events).
        return log_h.sum() * 0.0
    # Stacking keeps the result on log_h's device and dtype.
    return torch.stack(stratum_losses).sum()

In [None]:
# Small hand-made example: 10 instances spread over 4 strata. Stratum 1 holds a
# single censored instance, so it exercises the "no events" skip below.
batch_indices = torch.tensor([1, 2, 2, 2, 2, 2, 3, 3, 4, 4], dtype=torch.float32)
durations = torch.tensor([169.5, 0.6, 12.3, 1.5, 3.8, 0.1, 0.1, 0.1, 0.6, 0.1], dtype=torch.float32)
events = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32)
log_h = torch.tensor([-4.1238, 2.1188, -1.5863, -1.2239, 0.9088, 5.6637, 1.2920, 4.5356, 1.5392, 5.0004], dtype=torch.float32)

# Test out the function: step through the body of stratified_cox_ph_loss by hand
device = batch_indices.device
unique_batches = torch.unique(batch_indices)
losses = torch.zeros(len(unique_batches), device=device)

for i, batch in enumerate(unique_batches):
    # To debug a single stratum, fix the loop variables manually:
    # i = 1
    # batch = 1
    # print(i)
    mask = (batch_indices == batch)
    if mask.sum() == 0:
        print(f"batch {batch} is empty")
        continue
    # Order the stratum by descending duration, as cox_ph_loss_sorted requires
    idx = torch.argsort(durations[mask], descending=True)
    # Equivalent alternative: idx = durations[mask].sort(descending=True)[1]
    log_h_batch = log_h[mask][idx]
    events_batch = events[mask][idx]
    
    print(events_batch)
    if events_batch.sum() == 0:
        print(f"batch {int(batch)} has no events")
        continue
    
    losses[i] = cox_ph_loss_sorted(log_h_batch, events_batch, eps=1e-7)
    
# Sum of the per-stratum losses (skipped strata stay at 0)
losses.sum()

In [None]:
# Same example through the function itself; compare with losses.sum() from the manual loop
stratified_cox_ph_loss(log_h, durations, events, batch_indices)

# Full process

In [3]:
# Selection of the simulated data set and the column names used below
batchNormType='BE00Asso00_normTMM'
dataType = 'linear-moderate'
keywords = ['061825']
# NOTE(review): test_size and random_state are not passed to the loader call
# below, and the later train_test_split calls hard-code random_state=42.
test_size=10000
random_state=42
time_col='time'
status_col='status'
batch_col='batch.id'

# keep_batch=True: the batch column is needed later for stratification
# (presumably this flag keeps `batch.id` in the returned frames -- confirm in pss.utils)
train_df, test_df = load_simulate_survival_data(batchNormType=batchNormType,
                                                dataName=dataType,
                                                keywords=keywords, 
                                                keep_batch=True)

print(f"Training data dimensions: {train_df.shape}")
print(f"Testing data dimensions:  {test_df.shape}")

Training data dimensions: (90000, 541)
Testing data dimensions:  (10000, 541)


In [4]:
def _preprocess_data(df, mapper=None, fit_scaler=True):
    """Split a survival data frame into standardized features and labels.

    Every column other than the global `time_col` / `status_col` is treated as
    a covariate and passed through its own StandardScaler.

    Arguments:
        df -- Data frame holding covariates plus the time and status columns.
        mapper -- Previously fitted DataFrameMapper to reuse, or None.
        fit_scaler -- If True (or if no mapper is given), build and fit a new
            mapper on `df`; otherwise only apply the given mapper.

    Returns:
        (x, y, mapper) -- float32 feature matrix, a (durations, events) tuple
        of column values, and the mapper that produced `x`.
    """
    label_cols = (time_col, status_col)
    feature_cols = [name for name in df.columns if name not in label_cols]
    features = df[feature_cols]

    # Transform features (miRNA expression)
    needs_fit = fit_scaler or mapper is None
    if needs_fit:
        mapper = DataFrameMapper([([name], StandardScaler()) for name in feature_cols])
        scaled = mapper.fit_transform(features)
    else:
        scaled = mapper.transform(features)

    # Prepare labels (survival data)
    targets = (df[time_col].values, df[status_col].values)

    return scaled.astype('float32'), targets, mapper

# Draw fixed-size subsamples (2000 train / 1000 test rows), stratified jointly
# on event status and batch label.
train_sub,_ = train_test_split(train_df,
                            train_size=2000, 
                            shuffle=True, random_state=42,
                            stratify=train_df[[status_col, batch_col]])
test_sub, _ = train_test_split(test_df,
                            train_size=1000, 
                            shuffle=True, random_state=42,
                            stratify=test_df[[status_col, batch_col]])

# Pull the batch labels out as flat arrays before the column is dropped
batch_ids_train = train_sub[batch_col].to_numpy().reshape(-1)
batch_ids_test = test_sub[[batch_col]].to_numpy().reshape(-1)

# The batch column must not be fed to the network as a covariate
train_sub = train_sub.drop(columns=[batch_col])
test_sub = test_sub.drop(columns=[batch_col])

# Fit the scalers on the training subsample only and reuse them for the test subsample
x_train, y_train, mapper = _preprocess_data(train_sub, fit_scaler=True)
x_test, y_test, _ = _preprocess_data(test_sub, mapper=mapper, fit_scaler=False)

durations_train, events_train = y_train[0], y_train[1]
durations_test, events_test = y_test[0], y_test[1]

# Prepare data: move everything to torch tensors on the selected device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x_train = torch.from_numpy(x_train).to(device)
x_test = torch.from_numpy(x_test).to(device)

durations_train = torch.from_numpy(durations_train).float().to(device)
durations_test = torch.from_numpy(durations_test).float().to(device)
events_train = torch.from_numpy(events_train).float().to(device)
events_test = torch.from_numpy(events_test).float().to(device)
# From here on y_train / y_test are (durations, events) tuples of tensors
y_train = (durations_train, events_train)
y_test = (durations_test, events_test)
        
batch_ids_train = torch.from_numpy(batch_ids_train).long().to(device)
batch_ids_test = torch.from_numpy(batch_ids_test).long().to(device)

print(device)
print(x_test.shape)           # Should be [n_samples, n_features]
# y_test is a tuple, so it has no .shape to print:
# print(y_test.shape)
print(durations_test.shape)   # Should be [n_samples]
print(events_test.shape)      # Should be [n_samples]
print(batch_ids_test.shape)   # Should be [n_samples]

cpu
torch.Size([1000, 538])
torch.Size([1000])
torch.Size([1000])
torch.Size([1000])


In [5]:
# Network and optimizer configuration shared by the CoxPH and stratified CoxPH runs below
# x_train.shape
input_size = x_train.shape[1]
output_size = 1
num_nodes = [32,16]            # Default # layers & nodes
dropout = 0.2                    # Default dropout rate
learning_rate = 1e-3      # Default learning rate
batch_size = 128               # Default batch size
epochs = 500                      # Default number of epochs
batch_norm = True             # Default batch normalization
output_bias = True           # Default output bias
weight_decay = 1e-4         # Default weight decay
activation = torch.nn.ReLU

# MLP with one output node (the log hazard)
net = tt.practical.MLPVanilla(
    in_features=input_size,
    out_features=output_size,
    num_nodes=num_nodes,
    dropout=dropout, 
    batch_norm=batch_norm,
    activation=activation,
    output_bias=output_bias
).to(device)
optimizer = tt.optim.Adam(weight_decay=weight_decay, lr=learning_rate)

# Get default early stopping settings if not defined 
patience = 30
min_delta = 1e-3
callbacks = [tt.callbacks.EarlyStopping(patience=patience, min_delta=min_delta)]

### CoxPH

In [None]:
# CoxPH model (non-stratified), trained on the subsample
# NOTE(review): `net`, `optimizer` and `callbacks` are the same objects the
# stratified model uses later; running both cells in one session means the
# second fit starts from the weights left by the first -- confirm this is intended.
model = CoxPH(net, optimizer=optimizer)
log = model.fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=callbacks, 
    verbose=True,
    val_data=(x_test, y_test),
    val_batch_size=batch_size
)

In [None]:
# ==================== Evaluation ====================
_ = model.compute_baseline_hazards(input=x_train, target=(durations_train, events_train))

# Convert torch tensors to numpy objects for evaluation.
# The numpy copies get their own *_np names so the original tensors survive:
# later cells still wrap x_test / durations_* / events_* in StratifiedDataset
# and call `.detach()` on them, which fails once they are numpy arrays.
x_train_np = x_train.detach().cpu().numpy()
x_test_np = x_test.detach().cpu().numpy()
durations_train_np = durations_train.detach().cpu().numpy()
durations_test_np = durations_test.detach().cpu().numpy()
events_train_np = events_train.detach().cpu().numpy()
events_test_np = events_test.detach().cpu().numpy()

# Initialize EvalSurv objects (predictions from the tensors, labels as numpy)
tr_surv = model.predict_surv_df(x_train)
te_surv = model.predict_surv_df(x_test)
tr_ev = EvalSurv(tr_surv, durations_train_np, events_train_np, censor_surv='km')
te_ev = EvalSurv(te_surv, durations_test_np, events_test_np, censor_surv='km')

# Concordance index ----------------
# concordance_td() of this EvalSurv is unpacked as a 2-tuple in the one-batch
# evaluation cell; only the first element (the index) is kept here as well.
tr_c_index, _ = tr_ev.concordance_td()
te_c_index, _ = te_ev.concordance_td()

tr_c_index, te_c_index

### Stratified CoxPH

In [None]:
# train_dataset = StratifiedDataset(x_train, durations_train, events_train, batch_ids_train)
# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# test_dataset = torch.utils.data.TensorDataset(x_test, durations_test, events_test, batch_ids_test)
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# # # Access batches
# # for idx, (inputs, durations, events, batch_ids) in enumerate(train_loader):
# #     print(f"Batch {idx + 1}:")
# #     print("batch ids:", batch_ids)
# #     # print("Time:", durations)
# #     print("Events:", events)
# #     print()
    
# ## Test
# for idx, (inputs, durations, events, batch_ids) in enumerate(test_loader):
#     print(f"Batch {idx + 1}:")
#     print("batch ids:", batch_ids)
#     # print("Time:", durations)
#     print("Events:", events)
#     print()

In [6]:
# Get default early stopping settings if not defined 
# (fresh callback instance for the stratified run)
patience = 30
min_delta = 1e-3
callbacks = [tt.callbacks.EarlyStopping(patience=patience, min_delta=min_delta)]


# Earlier variant using plain TensorDataset, kept for reference:
# train_dataset = torch.utils.data.TensorDataset(x_train, durations_train, events_train, batch_ids_train)
# test_dataset   = torch.utils.data.TensorDataset(x_test, durations_test, events_test, batch_ids_test)
# Each item carries (features, duration, event, batch id), as the loop below unpacks
train_dataset = StratifiedDataset(x_train, durations_train, events_train, batch_ids_train)
test_dataset = StratifiedDataset(x_test, durations_test, events_test, batch_ids_test)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: number of events per mini-batch and per stratum in the
# validation loader (a stratum without events contributes nothing to the loss)
for xb, db, eb, bb in test_loader:
    print("VAL batch total events:", int(eb.sum().item()),
          "| per-stratum:", {int(s): int(eb[bb==s].sum().item()) for s in bb.unique().tolist()})

VAL batch total events: 95 | per-stratum: {1: 11, 2: 10, 3: 10, 4: 11, 5: 6, 6: 7, 7: 7, 8: 9, 9: 7, 10: 17}
VAL batch total events: 96 | per-stratum: {1: 13, 2: 9, 3: 11, 4: 8, 5: 7, 6: 6, 7: 9, 8: 11, 9: 10, 10: 12}
VAL batch total events: 103 | per-stratum: {1: 8, 2: 10, 3: 6, 4: 12, 5: 14, 6: 10, 7: 11, 8: 11, 9: 15, 10: 6}
VAL batch total events: 93 | per-stratum: {1: 13, 2: 11, 3: 11, 4: 7, 5: 4, 6: 10, 7: 11, 8: 6, 9: 8, 10: 12}
VAL batch total events: 94 | per-stratum: {1: 13, 2: 15, 3: 12, 4: 10, 5: 10, 6: 12, 7: 4, 8: 6, 9: 7, 10: 5}
VAL batch total events: 96 | per-stratum: {1: 8, 2: 9, 3: 7, 4: 13, 5: 15, 6: 8, 7: 11, 8: 5, 9: 12, 10: 8}
VAL batch total events: 94 | per-stratum: {1: 8, 2: 13, 3: 4, 4: 13, 5: 6, 6: 7, 7: 16, 8: 10, 9: 9, 10: 8}
VAL batch total events: 77 | per-stratum: {1: 11, 2: 3, 3: 6, 4: 10, 5: 8, 6: 12, 7: 6, 8: 9, 9: 5, 10: 7}


In [7]:
import time

# Stratified CoxPH model
model = CoxPHStratified(net, optimizer=optimizer)
# Register the model's own loss under the name 'val_loss'
# (presumably so the validation loss is logged/monitored when training via
# fit_dataloader -- confirm against the torchtuples/pss implementation)
model.metrics = {'val_loss': model.loss}
start = time.time() # Record iteration start time
log = model.fit_dataloader(
    train_loader,
    epochs=epochs,
    callbacks=callbacks,
    verbose=True,
    val_dataloader=test_loader  # optional for now
)
stop = time.time() # Record time when training finished
duration = round(stop - start, 2)
print(f"Training time: {duration}")

0:	[0s / 0s],		train_loss: 20.6112,	val_loss: 19.5040
1:	[0s / 0s],		train_loss: 18.8728,	val_loss: 18.8106
2:	[0s / 0s],		train_loss: 18.3032,	val_loss: 18.4668
3:	[0s / 0s],		train_loss: 17.5007,	val_loss: 18.1772
4:	[0s / 0s],		train_loss: 17.1657,	val_loss: 17.9118
5:	[0s / 0s],		train_loss: 16.7562,	val_loss: 17.6523
6:	[0s / 0s],		train_loss: 16.2896,	val_loss: 17.4696
7:	[0s / 0s],		train_loss: 15.7985,	val_loss: 17.2974
8:	[0s / 0s],		train_loss: 15.4151,	val_loss: 17.2267
9:	[0s / 0s],		train_loss: 14.9886,	val_loss: 16.8337
10:	[0s / 0s],		train_loss: 14.4737,	val_loss: 16.5885
11:	[0s / 0s],		train_loss: 13.8276,	val_loss: 16.5437
12:	[0s / 0s],		train_loss: 13.7400,	val_loss: 16.2938
13:	[0s / 1s],		train_loss: 13.0749,	val_loss: 15.9844
14:	[0s / 1s],		train_loss: 12.8311,	val_loss: 15.8871
15:	[0s / 1s],		train_loss: 12.5133,	val_loss: 15.6785
16:	[1s / 2s],		train_loss: 12.2128,	val_loss: 15.5450
17:	[0s / 3s],		train_loss: 11.9292,	val_loss: 15.4642
18:	[0s / 3s],		trai

### Evaluation

#### *Stratified C-index* 

In [8]:
# ==================== Evaluation ====================
# Convert torch tensors back to numpy objects for evaluation
# (stored under *_np names; the tensors themselves are still used below)
durations_train_np = durations_train.detach().cpu().numpy()
durations_test_np  = durations_test.detach().cpu().numpy()
events_train_np    = events_train.detach().cpu().numpy()
events_test_np     = events_test.detach().cpu().numpy()

# Compute baseline hazards (per-batch)
baseline_hazards_strata = model.compute_baseline_hazards(input=x_train, target=(durations_train, events_train), batch_ids=batch_ids_train)

# Initialize EvalSurv objects 
# Survival curves are predicted with the batch ids supplied for each sample
tr_surv  = model.predict_surv_df(x_train, batch_ids = batch_ids_train)
te_surv = model.predict_surv_df(x_test, batch_ids = batch_ids_test)
tr_ev = EvalSurv(tr_surv, durations_train_np, events_train_np, censor_surv='km')
te_ev = EvalSurv(te_surv, durations_test_np, events_test_np, censor_surv='km')

# Concordance index ----------------
# Stratified variant: batch labels are passed so concordance is evaluated per batch
tr_strat_c_index  = tr_ev.stratified_concordance_td(batch_indices=batch_ids_train) 
te_strat_c_index = te_ev.stratified_concordance_td(batch_indices=batch_ids_test) 

print(tr_strat_c_index, te_strat_c_index)

0.9417363122787157 0.8178225734837048


#### *One-batch C-index* 

In [9]:
# Baseline hazards computed without batch ids, i.e. all samples treated as one batch
baseline_hazards_1batch = model.compute_baseline_hazards(input=x_train, target=(durations_train, events_train))

# Initialize EvalSurv objects 
# The pooled baseline is passed explicitly to the survival prediction
tr_surv  = model.predict_surv_df(x_train, baseline_hazards_=baseline_hazards_1batch)
te_surv = model.predict_surv_df(x_test, baseline_hazards_=baseline_hazards_1batch)
tr_ev = EvalSurv(tr_surv, durations_train_np, events_train_np, censor_surv='km')
te_ev = EvalSurv(te_surv, durations_test_np, events_test_np, censor_surv='km')

# Concordance index (non-stratified) ----------------
# concordance_td() returns a pair here; only the first element (the index) is kept
tr_c_index, _  = tr_ev.concordance_td() 
te_c_index, _ = te_ev.concordance_td() 
print(tr_c_index, te_c_index)

0.9387092920643851 0.8268118138566364


In [None]:
# Manual test: recompute the stratified C-index batch by batch and combine the
# per-batch values weighted by their number of comparable pairs.
from pycox.evaluation.concordance import concordance_td
from pycox.evaluation import ipcw

batch_indices = batch_ids_train.detach().cpu().numpy() if not isinstance(batch_ids_train, np.ndarray) else batch_ids_train
batches = np.unique(batch_indices)
c_index_ls , n_pairs_ls = np.zeros(len(batches)), np.zeros(len(batches))

for i, batch in enumerate(batches):
    # Filter data by batch. The mask is never empty because the label comes
    # from np.unique of the same array.
    mask = (batch_indices == batch)
    # Use the numpy copies: concordance_td works on numpy arrays, not torch tensors
    batch_durations = durations_train_np[mask]
    batch_events = events_train_np[mask]
    if batch_events.sum() == 0:
        continue  # no events -> leave c-index and pair count at 0
    batch_surv = tr_ev.surv.iloc[:, mask]
    
    # Compute concordance for the current batch
    c_index_batch, n_pairs_batch = concordance_td(
        batch_durations, batch_events, batch_surv.values,
        tr_ev.idx_at_times(batch_durations), method='adj_antolini'
    )
    print(n_pairs_batch)
    n_pairs_ls[i] = n_pairs_batch
    c_index_ls[i] = c_index_batch
    
# Pair-weighted average over batches (nan if no batch had comparable pairs)
print("Final score: %f" % (np.sum(c_index_ls*n_pairs_ls) / np.sum(n_pairs_ls) if np.sum(n_pairs_ls) > 0 else float('nan')))

for e, c in zip(n_pairs_ls, c_index_ls):
    print(f"{int(e)} comparable pairs: {round(c,3)}")

#### *Test: Stratified Integrated Brier score*

In [None]:
# Integrated Brier score -----------
# Evaluate on the time range covered by both subsamples. The numpy copies are
# used because durations_train / durations_test are torch tensors at this
# point and np.min / np.max do not reduce torch tensors reliably.
min_surv = np.ceil(max(durations_train_np.min(), durations_test_np.min()))
max_surv = np.floor(min(durations_train_np.max(), durations_test_np.max()))
times = np.linspace(min_surv, max_surv, 20)

tr_brier = tr_ev.integrated_brier_score(time_grid=times)
te_brier = te_ev.integrated_brier_score(time_grid=times)

print(tr_brier, te_brier)

In [None]:
# Stratified integrated Brier score on the same time grid, with batch labels supplied
tr_strat_brier  = tr_ev.stratified_integrated_brier_score(time_grid=times, batch_indices=batch_ids_train) 
te_strat_brier =  te_ev.stratified_integrated_brier_score(time_grid=times, batch_indices=batch_ids_test)
print(tr_strat_brier, te_strat_brier)

In [None]:
# print(tr_ev.surv.values.shape) 
# print(tr_ev.censor_surv.surv.values.shape)
# print(tr_ev.index_surv.shape)
# print(tr_ev.censor_surv.index_surv.shape) 
# print(tr_ev.steps)
# print(tr_ev.censor_surv.steps)

In [None]:
# Manual test: recompute the stratified integrated Brier score batch by batch
# and combine the per-batch values weighted by their number of events.
from pycox.evaluation import ipcw  # imported here so the cell runs on its own

# .cpu() before .numpy() so this also works when the tensors live on a GPU
batch_indices = batch_ids_train.detach().cpu().numpy() if not isinstance(batch_ids_train, np.ndarray) else batch_ids_train
batches = np.unique(batch_indices)
brier_ls, n_events_ls = np.zeros(len(batches)), np.zeros(len(batches))

for i, batch in enumerate(batches):
    # Filter data by batch. The mask is never empty because the label comes
    # from np.unique of the same array.
    mask = (batch_indices == batch)
    # Use the numpy copies: the ipcw routines work on numpy arrays, not torch tensors
    batch_durations = durations_train_np[mask]
    batch_events = events_train_np[mask]
    if batch_events.sum() == 0:
        continue  # no events -> leave score and weight at 0
    batch_surv_values = tr_ev.surv.values[:, mask]
    batch_censor_surv_values = tr_ev.censor_surv.surv.values[:, mask] 
    
    # Compute integrated brier score for the current batch; the time indices
    # (index_surv) are shared by all batches, only the columns are filtered.
    brier_batch = ipcw.integrated_brier_score(times, batch_durations, batch_events, 
                                    batch_surv_values, batch_censor_surv_values, 
                                    tr_ev.index_surv, tr_ev.censor_surv.index_surv, np.inf, 
                                    tr_ev.steps, tr_ev.censor_surv.steps)
    n_events_ls[i] = batch_events.sum()
    brier_ls[i] = brier_batch
    
# Event-weighted average over batches (nan if no batch had events)
print("Final score: %f\n" % (np.sum(brier_ls*n_events_ls) / np.sum(n_events_ls) if np.sum(n_events_ls) > 0 else float('nan')))

for e, c in zip(n_events_ls, brier_ls):
    print(f"{int(e)} events: {round(c,3)}")

# Pipeline Test

In [None]:
# Load data for the pipeline test (non-normalized, non-linear "shiftquad" simulation)
batchNormType='BE00Asso00_normNone'
dataName='nl-shiftquad'
keywords = ['061825']
# NOTE(review): test_size and random_state are not used in this cell
test_size=10000
random_state=42
time_col='time'
status_col='status'
batch_col='batch.id'

train_df, test_df = load_simulate_survival_data(batchNormType=batchNormType,
                                                dataName=dataName,
                                                keywords=keywords, 
                                                keep_batch=True)

print(f"Training data dimensions: {train_df.shape}")
print(f"Testing data dimensions:  {test_df.shape}")
# NOTE(review): plot_simulation_data is not imported anywhere in the visible
# notebook (only load_simulate_survival_data is imported from pss.utils);
# this call raises NameError unless it is imported elsewhere -- confirm its module.
plot_simulation_data(train_df, test_df)

In [None]:
# Search space for hyperparameter tuning: categorical choices or float ranges
# ("log": True presumably requests log-scale sampling -- confirm in DeepSurvPipeline)
hyperparameters = {
    "num_nodes": {"type": "categorical", "choices": [[64,64], [32,32], [16,16]]},
    "dropout": {"type": "float", "low": 0.1, "high": 0.5},
    "weight_decay": {"type": "float", "low": 1e-5, "high": 1e-2, "log": True},
    "learning_rate": {"type": "float", "low": 1e-4, "high": 1e-2, "log": True},
    "batch_size": {"type": "categorical", "choices": [128, 64, 32, 16]}
}

# Pipeline in stratified mode (is_stratified=True)
ds = DeepSurvPipeline(
    train_df, test_df, 
    batchNormType=batchNormType, 
    dataName=dataName,
    hyperparameters=hyperparameters,
    is_stratified=True
)

# optuna.logging.disable_default_handler()
# Single-size smoke test; the trailing comments name the full-run variables
# that the literal lists stand in for.
stratified_results = ds.train_over_subsets(subset_sizes=[2000],#subset_sizes, 
                                runs_per_size=[5],#runs_per_size, 
                                splits_per_size=[10],#splits_per_size,
                                trials_per_size=[5],#trials_per_size,
                                is_tune=True, 
                                is_save=False, 
                                n_jobs=-1,
                                trial_threshold=5                               
)
stratified_results

# ==== Archive ====

In [None]:
# prepare data
# NOTE(review): archived cell -- this call uses folder=/test_size= keywords,
# unlike the batchNormType=/dataName= calls earlier in the notebook; it
# presumably targets an older load_simulate_survival_data signature -- confirm before running.
folder = 'linear'
keywords = ['moderate', "latest", 'RW']

train_df, test_df = load_simulate_survival_data(folder=folder, keywords=keywords, test_size=0.2)

train_df.head()

## Feature transforms


In [None]:
survival_cols = ['time', 'status']

In [None]:
# Hold out 20% of the training data for validation, stratified on event status
tr_df, val_df = train_test_split(train_df, 
                                test_size=0.2,
                                shuffle=True, random_state=42,
                                stratify=train_df['status'])

# Transform data: one StandardScaler per covariate column
covariate_cols = [col for col in train_df.columns if col not in survival_cols]
standardize = [([col], StandardScaler()) for col in covariate_cols]
leave = [(col, None) for col in survival_cols]
x_mapper = DataFrameMapper(standardize)

# gene expression data
# Fit the scalers on the training split only. Validation and test data are
# merely transformed; refitting on the validation split would leak its
# statistics and would also change the scaling later applied to the test set.
x_train = x_mapper.fit_transform(tr_df[covariate_cols]).astype('float32')
x_val = x_mapper.transform(val_df[covariate_cols]).astype('float32')
x_test = x_mapper.transform(test_df[covariate_cols]).astype('float32')

# prepare labels
get_target = lambda df: (df['time'].values, df['status'].values)
y_train = get_target(tr_df)
y_val = get_target(val_df)
t_test, e_test = get_target(test_df)
val = x_val, y_val

## Neural net

We create a simple MLP with two hidden layers, ReLU activations, batch norm and dropout. 
Here, we just use the `torchtuples.practical.MLPVanilla` net to do this.


In [None]:
# Network dimensions and regularisation settings for the archived run
in_features = x_train.shape[1]
num_nodes = [32, 16]
out_features = 1
batch_norm = True
dropout = 0.2
output_bias = True

# Arguments are passed positionally here (in_features, num_nodes, out_features,
# batch_norm, dropout), unlike the keyword call used earlier in the notebook
net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                            dropout, output_bias=output_bias)

## Training the model

To train the model we need to define a `torch.optim` optimizer; here we instead use one from `tt.optim` as it has some added functionality.
We use the `Adam` optimizer and, rather than searching for a learning rate with `model.lr_finder`, set it manually to 0.001 below.

In [None]:
# Adam with weight decay 0.01; the learning rate is set separately below
optimizer = tt.optim.Adam(weight_decay=0.01)

be_model = CoxPHStratified(net, optimizer)

# we  set it manually to 0.001
be_model.optimizer.set_lr(1e-3)

We include the `EarlyStopping` callback to stop training when the validation loss stops improving. After training, this callback will also load the best performing model in terms of validation loss.

In [None]:
%%time
batch_size = 64
epochs = 500
callbacks = [tt.callbacks.EarlyStopping(patience=20, min_delta=5e-2)]
verbose = True

# All samples get the same label (1), i.e. a single stratum
batch_indices = np.ones(len(y_train[1]))
# NOTE(review): archived call -- batch_indices, batch_size, epochs and callbacks
# are passed positionally, which presumably matches an older
# CoxPHStratified.fit signature; confirm against the current pss.pycox API.
log = be_model.fit(x_train, y_train,
                batch_indices,
                batch_size,
                epochs,
                callbacks, 
                verbose=verbose,
                val_data=val, val_batch_size=batch_size
                )