# Pycox: DeepSurv Stratified by Batch


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchtuples as tt
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper

os.chdir("../")
from pycox.models.cox import CoxPH, CoxPHStratified, StratifiedDataset
from pycox.evaluation.eval_surv import EvalSurv, EvalSurvStratified
from utils import *
from runDeepSurvModels import *

## *Test: Network*

In [None]:
from torch import Tensor

def cox_ph_loss_sorted(log_h: Tensor, events: Tensor, eps: float = 1e-7) -> Tensor:
    r"""Negative Cox partial log-likelihood for inputs sorted by descending duration.

    Requires the input to be sorted by descending duration time.
    See DatasetDurationSorted.

    We calculate the negative log of $(\frac{h_i}{\sum_{j \in R_i} h_j})^d$,
    where h = exp(log_h) are the hazards and R is the risk set, and d is event.

    We just compute a cumulative sum, and not the true Risk sets. This is a
    limitation, but simple and fast.

    Arguments:
        log_h {torch.Tensor} -- Log hazard predictions (flattened to 1-D here).
        events {torch.Tensor} -- Event indicators, bool or numeric (flattened to 1-D here).

    Keyword Arguments:
        eps {float} -- Added to the cumulative hazard sum before taking the log (default: {1e-7})

    Returns:
        torch.Tensor -- Scalar loss: the summed event terms divided by `events.sum()`.
            The division is 0/0 (NaN) when `events` contains no event, so callers
            should skip that case.
    """
    # NOTE: the docstring is a raw string on purpose; otherwise `\f` in `\frac`
    # is parsed as a form feed and `\s` / `\i` are invalid escape sequences.
    if events.dtype is torch.bool:
        events = events.float()
    events = events.view(-1)
    log_h = log_h.view(-1)
    # Subtract the max before exponentiating (log-sum-exp trick) to avoid overflow.
    gamma = log_h.max()
    log_cumsum_h = log_h.sub(gamma).exp().cumsum(0).add(eps).log().add(gamma)
    return - log_h.sub(log_cumsum_h).mul(events).sum().div(events.sum())

####### [UPDATE] 1105
def stratified_cox_ph_loss(log_h: Tensor, durations: Tensor, events: Tensor, batch_indices: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Stratified CoxPH loss: the partial-likelihood loss is computed separately
    within each stratum (batch label) and the per-stratum losses are summed.

    Arguments:
        log_h {torch.Tensor} -- Log hazard predictions for each instance.
        durations {torch.Tensor} -- Duration times for each instance.
        events {torch.Tensor} -- Event indicators (1 if event, 0 if censored).
        batch_indices {torch.Tensor} -- Stratum (batch) label for each instance.
        eps {float} -- Small epsilon for numerical stability.

    Returns:
        torch.Tensor -- Scalar: sum over strata of `cox_ph_loss_sorted` evaluated
            on that stratum's instances. Strata without any event are skipped.
            If no stratum has an event the result is a zero that is still
            connected to `log_h`, so `.backward()` remains valid.
    """
    # Work on flat views so (n,) and (n, 1) inputs index consistently.
    log_h = log_h.reshape(-1)
    durations = durations.reshape(-1)
    events = events.reshape(-1)
    batch_indices = batch_indices.reshape(-1)

    stratum_losses = []
    for stratum in torch.unique(batch_indices):
        mask = batch_indices == stratum
        events_stratum = events[mask]
        # A stratum without events contributes nothing (and would give 0/0 below).
        if events_stratum.sum() == 0:
            continue

        # cox_ph_loss_sorted expects descending durations within the stratum.
        order = torch.argsort(durations[mask], descending=True)
        stratum_losses.append(
            cox_ph_loss_sorted(log_h[mask][order], events_stratum[order], eps)
        )

    if not stratum_losses:
        # Keep the autograd graph alive: a plain constant zero has no grad_fn
        # and makes `loss.backward()` raise in the training loop.
        return log_h.sum() * 0.0

    return torch.stack(stratum_losses).sum()


import torchtuples.callbacks as cb

### UPDATE 11/05
class CoxPHLossStratified(torch.nn.Module):
    r"""Loss for CoxPH model with batch variable.

    We calculate the batch-stratified negative log of $(\frac{h_i}{\sum_{j \in R_i} h_j})^d$,
    where h = exp(log_h) are the hazards and R is the risk set, and d is event.

    We just compute a cumulative sum, and not the true Risk sets. This is a
    limitation, but simple and fast.
    """
    # NOTE: the class docstring is a raw string so the LaTeX backslashes are kept
    # literally (`\f` would otherwise be read as a form feed).

    def forward(self, log_h: Tensor, durations: Tensor, events: Tensor, batch_indices: Tensor) -> Tensor:
        """Print diagnostics for suspicious inputs, then return the stratified loss.

        Arguments:
            log_h {torch.Tensor} -- Log hazard predictions.
            durations {torch.Tensor} -- Duration times.
            events {torch.Tensor} -- Event indicators.
            batch_indices {torch.Tensor} -- Stratum (batch) label per instance.

        Returns:
            torch.Tensor -- Result of `stratified_cox_ph_loss` on the inputs.
        """
        # The checks only report; they never alter the inputs or stop the computation.
        if torch.isnan(log_h).any():
            print("NaNs detected in log hazards")
        if torch.isnan(durations).any():
            print("NaNs detected in input survival time")
        if torch.isnan(events).any():
            print("NaNs detected in input events")
        if (events.sum() == 0).item():
            print("No observed events in batch (val)")
        return stratified_cox_ph_loss(log_h, durations, events, batch_indices)

In [90]:
import warnings

# Import the entire base module
from pycox.models import base
from pycox.models import loss as Loss
from pycox.models import utils

def search_sorted_idx(array, values):
    '''For sorted array, get index of values.
    If value not in array, give left index of value.

    Arguments:
        array {np.array} -- Sorted 1-D array to search in.
        values {np.array} -- 1-D array of values to look up.

    Returns:
        np.array -- For each value, the index of the matching element, or of the
            nearest element to its left. Values larger than the last element map
            to the last index; values smaller than the first element map to 0
            (with a warning).
    '''
    array = np.asarray(array)
    values = np.asarray(values)
    n = len(array)
    idx = np.searchsorted(array, values)
    # Compare against a clamped index so values beyond the last element do not
    # index out of bounds. The subtraction is applied to the *unclamped* index:
    # clamping first and then subtracting returned n-2 instead of n-1 for values
    # larger than every element.
    not_exact = values != array[np.minimum(idx, n - 1)]
    idx = idx - not_exact
    below_first = idx < 0
    if below_first.any():
        warnings.warn('Given value smaller than first value')
        idx[below_first] = 0
    return idx


class _CoxBase(base.SurvBase):
    """Shared functionality for Cox-type models: remembers the training data,
    derives baseline (cumulative) hazards from it, turns them into survival
    predictions, and saves/loads the baseline hazards next to the network.

    Subclasses must implement `_compute_baseline_hazards` and
    `_predict_cumulative_hazards`.
    """
    duration_col = 'duration'
    event_col = 'event'

    def fit(self, input, target, batch_size=256, epochs=1, callbacks=None, verbose=True,
            num_workers=0, shuffle=True, metrics=None, val_data=None, val_batch_size=8224,
            **kwargs):
        """Fit  model with inputs and targets. Where 'input' is the covariates, and
        'target' is a tuple with (durations, events).

        The (input, target) pair is stored in `self.training_data` so that
        `compute_baseline_hazards` can be called without arguments afterwards.

        Arguments:
            input {np.array, tensor or tuple} -- Input x passed to net.
            target {np.array, tensor or tuple} -- Target [durations, events].

        Keyword Arguments:
            batch_size {int} -- Elements in each batch (default: {256})
            epochs {int} -- Number of epochs (default: {1})
            callbacks {list} -- list of callbacks (default: {None})
            verbose {bool} -- Print progress (default: {True})
            num_workers {int} -- Number of workers used in the dataloader (default: {0})
            shuffle {bool} -- If we should shuffle the order of the dataset (default: {True})
            metrics, val_data, val_batch_size -- Forwarded unchanged to the parent `fit`.
            **kwargs are passed to 'make_dataloader' method.

        Returns:
            TrainingLogger -- Training log
        """
        self.training_data = tt.tuplefy(input, target)
        return super().fit(input, target, batch_size, epochs, callbacks, verbose,
                           num_workers, shuffle, metrics, val_data, val_batch_size,
                           **kwargs)

    def _compute_baseline_hazards(self, input, df, max_duration, batch_size, eval_=True, num_workers=0):
        """Model-specific baseline hazard estimate; implemented by subclasses."""
        raise NotImplementedError

    def target_to_df(self, target):
        """Turn a (durations, events) target into a DataFrame with columns
        `duration_col` and `event_col`."""
        durations, events = tt.tuplefy(target).to_numpy()
        df = pd.DataFrame({self.duration_col: durations, self.event_col: events})
        return df

    def compute_baseline_hazards(self, input=None, target=None, max_duration=None, sample=None, batch_size=8224,
                                set_hazards=True, eval_=True, num_workers=0):
        """Computes the Breslow estimates form the data defined by `input` and `target`
        (if `None` use training data).

        Typically call
        model.compute_baseline_hazards() after fitting.

        Keyword Arguments:
            input  -- Input data (train input) (default: {None})
            target  -- Target data (train target) (default: {None})
            max_duration {float} -- Don't compute estimates for duration higher (default: {None})
            sample {float or int} -- Compute estimates of subsample of data (default: {None})
            batch_size {int} -- Batch size (default: {8224})
            set_hazards {bool} -- Set hazards in model object, or just return hazards. (default: {True})
            eval_ {bool} -- Passed on to the subclass estimator (default: {True})
            num_workers {int} -- Passed on to the subclass estimator (default: {0})

        Raises:
            ValueError -- If neither `input`/`target` nor stored training data is
                available, or if only one of `input` and `target` is given.

        Returns:
            pd.Series -- Pandas series with baseline hazards. Index is duration_col.
        """
        if (input is None) and (target is None):
            if not hasattr(self, 'training_data'):
                raise ValueError("Need to give a 'input' and 'target' to this function.")
            input, target = self.training_data
        elif (input is None) or (target is None):
            # Fail early with a clear message instead of an obscure unpacking error below.
            raise ValueError("Need to give a 'input' and 'target' to this function.")
        df = self.target_to_df(target)#.sort_values(self.duration_col)
        if sample is not None:
            # `sample` >= 1 is an absolute row count, otherwise a fraction of the rows.
            if sample >= 1:
                df = df.sample(n=sample)
            else:
                df = df.sample(frac=sample)
        # Keep only the covariate rows that survived the (optional) subsampling.
        input = tt.tuplefy(input).to_numpy().iloc[df.index.values]
        base_haz = self._compute_baseline_hazards(input, df, max_duration, batch_size,
                                                  eval_=eval_, num_workers=num_workers)
        if set_hazards:
            self.compute_baseline_cumulative_hazards(set_hazards=True, baseline_hazards_=base_haz)
        return base_haz

    def compute_baseline_cumulative_hazards(self, input=None, target=None, max_duration=None, sample=None,
                                            batch_size=8224, set_hazards=True, baseline_hazards_=None,
                                            eval_=True, num_workers=0):
        """See `compute_baseline_hazards. This is the cumulative version.

        Either pass data (`input`/`target`) or precomputed `baseline_hazards_`, not
        both. With `set_hazards=True` both `self.baseline_hazards_` and
        `self.baseline_cumulative_hazards_` are stored on the model.

        Returns:
            pd.Series -- Cumulative sum of the baseline hazards.
        """
        if ((input is not None) or (target is not None)) and (baseline_hazards_ is not None):
            raise ValueError("'input', 'target' and 'baseline_hazards_' can not both be different from 'None'.")
        if baseline_hazards_ is None:
            baseline_hazards_ = self.compute_baseline_hazards(input, target, max_duration, sample, batch_size,
                                                             set_hazards=False, eval_=eval_, num_workers=num_workers)
        assert baseline_hazards_.index.is_monotonic_increasing,\
            'Need index of baseline_hazards_ to be monotonic increasing, as it represents time.'
        bch = (baseline_hazards_
                .cumsum()
                .rename('baseline_cumulative_hazards'))
        if set_hazards:
            self.baseline_hazards_ = baseline_hazards_
            self.baseline_cumulative_hazards_ = bch
        return bch

    def predict_cumulative_hazards(self, input, max_duration=None, batch_size=8224, verbose=False,
                                   baseline_hazards_=None, eval_=True, num_workers=0):
        """See `predict_survival_function`.

        Uses `baseline_hazards_` if given, otherwise the hazards stored on the
        model; raises ValueError when neither exists.
        """
        # isinstance (not an exact type check) so DataFrame subclasses are converted too.
        if isinstance(input, pd.DataFrame):
            input = self.df_to_input(input)
        if baseline_hazards_ is None:
            if not hasattr(self, 'baseline_hazards_'):
                raise ValueError('Need to compute baseline_hazards_. E.g run `model.compute_baseline_hazards()`')
            baseline_hazards_ = self.baseline_hazards_
        assert baseline_hazards_.index.is_monotonic_increasing,\
            'Need index of baseline_hazards_ to be monotonic increasing, as it represents time.'
        return self._predict_cumulative_hazards(input, max_duration, batch_size, verbose, baseline_hazards_,
                                                eval_, num_workers=num_workers)

    def _predict_cumulative_hazards(self, input, max_duration, batch_size, verbose, baseline_hazards_,
                                    eval_=True, num_workers=0):
        """Model-specific cumulative hazard prediction; implemented by subclasses."""
        raise NotImplementedError

    def predict_surv_df(self, input, max_duration=None, batch_size=8224, verbose=False, baseline_hazards_=None,
                        eval_=True, num_workers=0):
        """Predict survival function for `input`. S(x, t) = exp(-H(x, t))
        Require computed baseline hazards.

        Arguments:
            input {np.array, tensor or tuple} -- Input x passed to net.

        Keyword Arguments:
            max_duration {float} -- Don't compute estimates for duration higher (default: {None})
            batch_size {int} -- Batch size (default: {8224})
            baseline_hazards_ {pd.Series} -- Baseline hazards. If `None` used `model.baseline_hazards_` (default: {None})
            eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
            num_workers {int} -- Number of workers in created dataloader (default: {0})

        Returns:
            pd.DataFrame -- Survival estimates. One columns for each individual.
        """
        return np.exp(-self.predict_cumulative_hazards(input, max_duration, batch_size, verbose, baseline_hazards_,
                                                       eval_, num_workers))

    def predict_surv(self, input, max_duration=None, batch_size=8224, numpy=None, verbose=False,
                     baseline_hazards_=None, eval_=True, num_workers=0):
        """Predict survival function for `input`. S(x, t) = exp(-H(x, t))
        Require computed baseline hazards.

        Arguments:
            input {np.array, tensor or tuple} -- Input x passed to net.

        Keyword Arguments:
            max_duration {float} -- Don't compute estimates for duration higher (default: {None})
            batch_size {int} -- Batch size (default: {8224})
            numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
                (default: {None})
            baseline_hazards_ {pd.Series} -- Baseline hazards. If `None` used `model.baseline_hazards_` (default: {None})
            eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
            num_workers {int} -- Number of workers in created dataloader (default: {0})

        Returns:
            np.array or torch.Tensor -- The values of `predict_surv_df` transposed,
                i.e. one row per individual.
        """
        surv = self.predict_surv_df(input, max_duration, batch_size, verbose, baseline_hazards_,
                                    eval_, num_workers)
        surv = torch.from_numpy(surv.values.transpose())
        return tt.utils.array_or_tensor(surv, numpy, input)

    def save_net(self, path, **kwargs):
        """Save self.net and baseline hazards to file.

        The network goes to `path` ('.pt' is appended when `path` has no
        extension); stored baseline hazards, if any, go to '<path stem>_blh.pickle'.

        Arguments:
            path {str} -- Path to file.
            **kwargs are passed to torch.save

        Returns:
            None
        """
        path, extension = os.path.splitext(path)
        if extension == "":
            extension = '.pt'
        super().save_net(path+extension, **kwargs)
        if hasattr(self, 'baseline_hazards_'):
            self.baseline_hazards_.to_pickle(path+'_blh.pickle')

    def load_net(self, path, **kwargs):
        """Load net and hazards from file.

        If a '<path stem>_blh.pickle' file exists, the baseline hazards are
        restored from it and the cumulative hazards are recomputed.

        Arguments:
            path {str} -- Path to file.
            **kwargs are passed to torch.load

        Returns:
            None
        """
        path, extension = os.path.splitext(path)
        if extension == "":
            extension = '.pt'
        super().load_net(path+extension, **kwargs)
        blh_path = path+'_blh.pickle'
        if os.path.isfile(blh_path):
            # SECURITY NOTE: unpickling executes arbitrary code; only load
            # baseline-hazard files that come from a trusted source.
            self.baseline_hazards_ = pd.read_pickle(blh_path)
            self.baseline_cumulative_hazards_ = self.baseline_hazards_.cumsum()

    def df_to_input(self, df):
        """Extract the covariate columns `self.input_cols` from `df` as an array.

        NOTE(review): `input_cols` is not defined in this class; it is presumably
        set by a subclass or by the caller -- confirm before relying on this path.
        """
        input = df[self.input_cols].values
        return input
    

class _CoxPHBase(_CoxBase):
    """Proportional-hazards implementation of the `_CoxBase` hooks: Breslow-type
    baseline hazards, cumulative hazard predictions and the partial
    log-likelihood."""

    def _compute_baseline_hazards(self, input, df_target, max_duration, batch_size, eval_=True, num_workers=0):
        """Baseline hazard per distinct duration: number of events at that time
        divided by the summed exp(prediction) over everyone with an equal or
        longer duration.

        Returns:
            pd.Series -- 'baseline_hazards', indexed by duration (ascending),
                restricted to durations <= `max_duration`.
        """
        if max_duration is None:
            max_duration = np.inf

        # Here we are computing when expg when there are no events.
        #   Could be made faster, by only computing when there are events.
        return (df_target
                .assign(expg=np.exp(self.predict(input, batch_size, True, eval_, num_workers=num_workers)))
                .groupby(self.duration_col)
                .agg({'expg': 'sum', self.event_col: 'sum'})
                # Descending time + cumsum gives the risk-set sum for each duration.
                .sort_index(ascending=False)
                .assign(expg=lambda x: x['expg'].cumsum())
                .pipe(lambda x: x[self.event_col]/x['expg'])
                .fillna(0.)
                .iloc[::-1]
                .loc[lambda x: x.index <= max_duration]
                .rename('baseline_hazards'))

    def _predict_cumulative_hazards(self, input, max_duration, batch_size, verbose, baseline_hazards_,
                                    eval_=True, num_workers=0):
        """Cumulative hazards H(x, t) = H0(t) * exp(g(x)).

        Returns:
            pd.DataFrame -- Rows indexed by duration, one column per individual.
        """
        max_duration = np.inf if max_duration is None else max_duration
        # Reuse the stored cumulative hazards only when the caller passed the very
        # object stored on the model. `getattr` avoids an AttributeError when
        # explicit baseline hazards are supplied before any were stored.
        stored = getattr(self, 'baseline_hazards_', None)
        if stored is not None and baseline_hazards_ is stored:
            bch = self.baseline_cumulative_hazards_
        else:
            bch = self.compute_baseline_cumulative_hazards(set_hazards=False,
                                                           baseline_hazards_=baseline_hazards_)
        bch = bch.loc[lambda x: x.index <= max_duration]
        expg = np.exp(self.predict(input, batch_size, True, eval_, num_workers=num_workers)).reshape(1, -1)
        # Outer product: (times x 1) . (1 x individuals)
        return pd.DataFrame(bch.values.reshape(-1, 1).dot(expg),
                            index=bch.index)

    def partial_log_likelihood(self, input, target, g_preds=None, batch_size=8224, eps=1e-7, eval_=True,
                               num_workers=0):
        '''Calculate the partial log-likelihood for the events in `target`.
        This likelihood does not sample the controls.
        Note that censored data (non events) does not have a partial log-likelihood.

        Arguments:
            input {tuple, np.ndarray, or torch.tensor} -- Input to net.
            target {tuple, np.ndarray, or torch.tensor} -- Target labels.

        Keyword Arguments:
            g_preds {np.array} -- Predictions from `model.predict` (default: {None})
            batch_size {int} -- Batch size (default: {8224})
            eps {float} -- Added inside the log for numerical stability (default: {1e-7})
            eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
            num_workers {int} -- Number of workers in created dataloader (default: {0})

        Returns:
            pd.Series -- Partial log-likelihood, one entry per observed event.
        '''
        df = self.target_to_df(target)
        if g_preds is None:
            g_preds = self.predict(input, batch_size, True, eval_, num_workers=num_workers)
        return (df
                .assign(_g_preds=g_preds)
                .sort_values(self.duration_col, ascending=False)
                # Running sum of exp(g) in descending time; taking the max within
                # tied durations gives every tied row the full risk-set sum.
                .assign(_cum_exp_g=(lambda x: x['_g_preds']
                                    .pipe(np.exp)
                                    .cumsum()
                                    .groupby(x[self.duration_col])
                                    .transform('max')))
                .loc[lambda x: x[self.event_col] == 1]
                .assign(pll=lambda x: x['_g_preds'] - np.log(x['_cum_exp_g'] + eps))
                ['pll'])


class CoxPH(_CoxPHBase):
    """Neural-network Cox proportional hazards model (essentially DeepSurv [1]).

    Training minimises something close to, but not exactly, the negative partial
    log-likelihood: tied event times are handled in random order rather than by
    including every individual that had an event at that time.

    Arguments:
        net {torch.nn.Module} -- A pytorch net.

    Keyword Arguments:
        optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None})
        device {str, int, torch.device} -- Device to compute on. (default: {None})
            Preferably pass a torch.device object.
            If 'None': use default gpu if available, else use cpu.
            If 'int': used that gpu: torch.device('cuda:<device>').
            If 'string': string is passed to torch.device('string').
        loss {torch.nn.Module} -- Loss to train with; `Loss.CoxPHLoss()` is used
            when 'None' (default: {None})

    [1] Jared L. Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger.
        Deepsurv: personalized treatment recommender system using a Cox proportional hazards deep neural network.
        BMC Medical Research Methodology, 18(1), 2018.
        https://bmcmedresmethodol.biomedcentral.com/articles/10.1186/s12874-018-0482-1
    """
    def __init__(self, net, optimizer=None, device=None, loss=None):
        # Fall back to the standard Cox partial-likelihood loss.
        chosen_loss = Loss.CoxPHLoss() if loss is None else loss
        super().__init__(net, chosen_loss, optimizer, device)

In [None]:
### Update 7/1/2025
class StratifiedDataset(torch.utils.data.Dataset):
    """Dataset yielding `(x, duration, event, batch_id)` for each sample index.

    Arguments:
        x -- Covariates, indexable along the sample dimension.
        durations -- Duration time per sample.
        events -- Event indicator per sample.
        batch_ids -- Stratum (batch) label per sample.

    Raises:
        ValueError -- If the four inputs do not all have the same length.
    """
    def __init__(self, x, durations, events, batch_ids):
        # Validate up front: a length mismatch would otherwise only surface as an
        # IndexError (or silently dropped samples) in the middle of an epoch.
        lengths = {
            'x': len(x),
            'durations': len(durations),
            'events': len(events),
            'batch_ids': len(batch_ids),
        }
        if len(set(lengths.values())) != 1:
            raise ValueError(f"All inputs must have the same length, got {lengths}")
        self.x = x
        self.durations = durations
        self.events = events
        self.batch_ids = batch_ids

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.durations[idx], self.events[idx], self.batch_ids[idx]


### UPDATE 11/05/2024
class CoxPHStratified(_CoxPHBase):
    """Cox proportional hazards model parameterized with a neural net.

    The loss function is not quite the partial log-likelihood, but close.
    The difference is that we stratify events by batch (strata) when
    calculating the partial log-likelihood.

    Arguments:
        net {torch.nn.Module} -- A pytorch net.

    Keyword Arguments:
        optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None})
        device {str, int, torch.device} -- Device to compute on. (default: {None})
            Preferably pass a torch.device object.
            If 'None': use default gpu if available, else use cpu.
        loss {torch.nn.Module} -- Loss called as `loss(log_h, durations, events, batch_ids)`;
            `CoxPHLossStratified()` is used when 'None' (default: {None})
    """
    def __init__(self, net, optimizer=None, device=None, loss=None):
        self.batch_ids = None
        if loss is None:
            loss = CoxPHLossStratified()
        super().__init__(net, loss, optimizer, device)

    def _batch_to_device(self, data):
        """Unpack a `(x, durations, events, batch_ids)` batch and move every tensor
        in it to the model device (non-tensor entries are passed through)."""
        # NOTE(review): relies on the `device` attribute provided by the
        # torchtuples model base class -- confirm against the installed version.
        device = self.device
        x, durations, events, batch_ids = (
            item.to(device) if torch.is_tensor(item) else item for item in data
        )
        return x, durations, events, batch_ids

    def compute_metrics(self, data, metrics=None):
        """Evaluate `metrics` on one `(x, durations, events, batch_ids)` batch.

        Arguments:
            data {tuple} -- Batch of (x, durations, events, batch_ids).

        Keyword Arguments:
            metrics {dict} -- Name -> callable(log_h, durations, events, batch_ids).
                Uses `self.metrics` when 'None' (default: {None})

        Returns:
            dict -- Name -> metric value. If the network output contains NaNs,
                every metric is reported as float('nan') without being evaluated.
        """
        if metrics is None:
            metrics = self.metrics
        if self.loss is None and self.loss in metrics.values():
            raise RuntimeError("Need to set `self.loss`.")
        x, durations, events, batch_ids = self._batch_to_device(data)
        log_h = self.net(x)

        if torch.isnan(log_h).any():
            print("NaNs detected in log_h during compute_metrics()")
            return {name: float('nan') for name in metrics}

        return {name: metric(log_h, durations, events, batch_ids) for name, metric in metrics.items()}

    def fit_dataloader(self, dataloader, epochs=1, callbacks=None, verbose=True, metrics=None, val_dataloader=None):
        """
        Custom training loop for CoxPHStratified that reads (x, duration, event, batch_id) from DataLoader.

        Args:
            dataloader (DataLoader): training data loader returning 4-tuples.
            epochs (int): number of training epochs.
            callbacks (list): optional callbacks.
            verbose (bool): print training progress.
            metrics (dict): optional metrics.
            val_dataloader (DataLoader): optional validation dataloader. Its batches
                are presumably scored through `compute_metrics`, which unpacks
                4-tuples, so it should yield (x, duration, event, batch_id) as well.

        Returns:
            TrainingLogger
        """
        self._setup_train_info(dataloader)
        self.metrics = self._setup_metrics(metrics)
        self.log.verbose = verbose
        self.val_metrics.dataloader = val_dataloader

        if callbacks is None:
            callbacks = []
        self.callbacks = cb.TrainingCallbackHandler(
            self.optimizer, self.train_metrics, self.log, self.val_metrics, callbacks
        )
        self.callbacks.give_model(self)

        stop = self.callbacks.on_fit_start()
        for _ in range(epochs):
            if stop:
                break
            stop = self.callbacks.on_epoch_start()
            if stop:
                break

            for data in dataloader:
                stop = self.callbacks.on_batch_start()
                if stop:
                    break

                # The DataLoader may hand out CPU tensors while the net lives on
                # another device; align them before the forward pass.
                x, durations, events, batch_ids = self._batch_to_device(data)

                self.optimizer.zero_grad()
                log_h = self.net(x)
                loss = self.loss(log_h, durations, events, batch_ids)
                self.batch_loss = loss
                self.batch_metrics = {"loss": loss}

                # A mini-batch without any event can yield a constant loss that is
                # detached from the graph; backward() on it would raise.
                if loss.requires_grad:
                    loss.backward()
                stop = self.callbacks.before_step()
                if stop:
                    break
                self.optimizer.step()
                stop = self.callbacks.on_batch_end()
                if stop:
                    break
            else:
                # Only reached when the batch loop finished without `break`.
                stop = self.callbacks.on_epoch_end()
        self.callbacks.on_fit_end()

        return self.log

## *Test: Debugging*

In [92]:
# Small hand-made example: 10 instances spread over 4 strata (batch labels 1-4).
# Stratum 1 holds a single censored instance, so it has no events.
batch_indices = torch.tensor([1, 2, 2, 2, 2, 2, 3, 3, 4, 4], dtype=torch.float32)
durations = torch.tensor([169.5, 0.6, 12.3, 1.5, 3.8, 0.1, 0.1, 0.1, 0.6, 0.1], dtype=torch.float32)
events = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32)
log_h = torch.tensor([-4.1238, 2.1188, -1.5863, -1.2239, 0.9088, 5.6637, 1.2920, 4.5356, 1.5392, 5.0004], dtype=torch.float32)

# Test out the function: step through the same per-stratum loop as
# stratified_cox_ph_loss, printing intermediate values.
device = batch_indices.device
unique_batches = torch.unique(batch_indices)
losses = torch.zeros(len(unique_batches), device=device)

for i, batch in enumerate(unique_batches):
    # To inspect a single stratum by hand, set `i` and `batch` manually, e.g.:
    # i = 1
    # batch = 1
    # print(i)
    mask = (batch_indices == batch)
    if mask.sum() == 0:
        print(f"batch {batch} is empty")
        continue
    # Sort the stratum by descending duration, as cox_ph_loss_sorted requires.
    idx = torch.argsort(durations[mask], descending=True)
    # idx = durations[mask].sort(descending=True)[1]
    log_h_batch = log_h[mask][idx]
    events_batch = events[mask][idx]
    
    print(events_batch)
    # Strata without events are skipped; their slot in `losses` stays 0.
    if events_batch.sum() == 0:
        print(f"batch {int(batch)} has no events")
        continue
    
    losses[i] = cox_ph_loss_sorted(log_h_batch, events_batch, eps=1e-7)
    
# Total stratified loss; should equal stratified_cox_ph_loss on the same inputs.
losses.sum()

tensor([0.])
batch 1 has no events
tensor([1., 1., 1., 1., 1.])
tensor([1., 1.])
tensor([1., 1.])


tensor(0.5826)

In [94]:
# Same toy inputs through the packaged function; expected to match the manual loop.
stratified_cox_ph_loss(log_h, durations, events, batch_indices)

tensor(0.5826)

# Full process

In [2]:
# --- Scenario selection -------------------------------------------------------
batchNormType = 'BE00Asso00_normNone'
dataType = 'linear-moderate'
keywords = ['061825']

# --- Split / reproducibility settings -----------------------------------------
test_size = 10000
random_state = 42

# --- Column names in the simulated data frames --------------------------------
time_col = 'time'
status_col = 'status'
batch_col = 'batch.id'

# Load the simulated train/test frames. `keep_batch=True` presumably retains the
# batch column needed for stratification -- see load_simulate_survival_data.
train_df, test_df = load_simulate_survival_data(
    batchNormType=batchNormType,
    dataName=dataType,
    keywords=keywords,
    keep_batch=True,
)

print(f"Training data dimensions: {train_df.shape}")
print(f"Testing data dimensions:  {test_df.shape}")

Training data dimensions: (90000, 541)
Testing data dimensions:  (10000, 541)


In [3]:
def _preprocess_data(df, mapper=None, fit_scaler=True):
    """Split `df` into standardized covariates and survival labels.

    Every column other than the module-level `time_col` / `status_col` is treated
    as a covariate and standardized column-wise.

    Arguments:
        df {pd.DataFrame} -- Data containing covariates plus the time and status columns.

    Keyword Arguments:
        mapper {DataFrameMapper} -- Previously fitted mapper to reuse (default: {None})
        fit_scaler {bool} -- Fit a new mapper on `df`. A new mapper is also fitted
            whenever `mapper` is None, regardless of this flag (default: {True})

    Returns:
        tuple -- (x, y, mapper): float32 covariate array, the pair
            (durations, events) of value arrays, and the mapper that was used.
    """
    label_cols = (time_col, status_col)
    feature_cols = [name for name in df.columns if name not in label_cols]
    features = df[feature_cols]

    must_fit = fit_scaler or mapper is None
    if must_fit:
        # One StandardScaler per covariate column.
        mapper = DataFrameMapper([([name], StandardScaler()) for name in feature_cols])
        scaled = mapper.fit_transform(features)
    else:
        scaled = mapper.transform(features)
    x = scaled.astype('float32')

    # Survival labels: (durations, events)
    y = (df[time_col].values, df[status_col].values)

    return x, y, mapper

# Draw a 1000-row subsample from each split, stratified jointly on event status
# and batch; the remainder of each split is discarded.
train_sub, _ = train_test_split(
    train_df,
    train_size=1000,
    shuffle=True,
    random_state=42,
    stratify=train_df[[status_col, batch_col]],
)
test_sub, _ = train_test_split(
    test_df,
    train_size=1000,
    shuffle=True,
    random_state=42,
    stratify=test_df[[status_col, batch_col]],
)

# Pull the batch labels out as flat arrays, then remove the column so it is not
# standardized as a covariate.
batch_ids_train = train_sub[batch_col].to_numpy().reshape(-1)
batch_ids_test = test_sub[batch_col].to_numpy().reshape(-1)
train_sub = train_sub.drop(columns=[batch_col])
test_sub = test_sub.drop(columns=[batch_col])

# Fit the scaler on the training subsample only and reuse it for the test subsample.
x_train, y_train, mapper = _preprocess_data(train_sub, fit_scaler=True)
x_test, y_test, _ = _preprocess_data(test_sub, mapper=mapper, fit_scaler=False)

durations_train, events_train = y_train
durations_test, events_test = y_test

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _to_device_tensor(array, cast=None):
    """Wrap a numpy array as a tensor, optionally cast it, and move it to `device`."""
    tensor = torch.from_numpy(array)
    if cast is not None:
        tensor = cast(tensor)
    return tensor.to(device)


# Covariates keep the float32 dtype produced by preprocessing.
x_train = _to_device_tensor(x_train)
x_test = _to_device_tensor(x_test)

# (durations, events) tuple -> tensor with one row per sample.
# NOTE: torch.tensor on a tuple of numpy arrays emits a slow-conversion warning.
y_train = torch.tensor(y_train).transpose(0,1).to(device)
y_test = torch.tensor(y_test).transpose(0,1).to(device)

# Separate float tensors for the stratified loss.
durations_train = _to_device_tensor(durations_train, torch.Tensor.float)
durations_test = _to_device_tensor(durations_test, torch.Tensor.float)
events_train = _to_device_tensor(events_train, torch.Tensor.float)
events_test = _to_device_tensor(events_test, torch.Tensor.float)

# Batch labels as integer tensors.
batch_ids_train = _to_device_tensor(batch_ids_train, torch.Tensor.long)
batch_ids_test = _to_device_tensor(batch_ids_test, torch.Tensor.long)

# Sanity-check shapes
print(x_test.shape)           # Should be [n_samples, n_features]
print(y_test.shape)           # Should be [n_samples, 2]
print(durations_test.shape)   # Should be [n_samples]
print(events_test.shape)      # Should be [n_samples]
print(batch_ids_test.shape)   # Should be [n_samples]

torch.Size([1000, 538])
torch.Size([1000, 2])
torch.Size([1000])
torch.Size([1000])
torch.Size([1000])


  y_train = torch.tensor(y_train).transpose(0,1).to(device)


In [5]:
# ---- Network shape -----------------------------------------------------------
input_size = x_train.shape[1]   # number of covariates
output_size = 1                 # single log-hazard output
num_nodes = [32]                # hidden layers and their widths
activation = torch.nn.ReLU
batch_norm = True
output_bias = True
dropout = 0.3

# ---- Optimisation ------------------------------------------------------------
learning_rate = 1e-4
weight_decay = 1e-4
batch_size = 128
epochs = 100

net = tt.practical.MLPVanilla(
    in_features=input_size,
    num_nodes=num_nodes,
    out_features=output_size,
    batch_norm=batch_norm,
    dropout=dropout,
    activation=activation,
    output_bias=output_bias,
).to(device)
optimizer = tt.optim.Adam(lr=learning_rate, weight_decay=weight_decay)

# ---- Early stopping ----------------------------------------------------------
patience = 40
min_delta = 1e-4
callbacks = [tt.callbacks.EarlyStopping(patience=patience, min_delta=min_delta)]

### CoxPH

In [13]:
# CoxPH model (non-stratified baseline for comparison).
# The held-out subsample doubles as validation data for the early-stopping callback.
model = CoxPH(net, optimizer=optimizer)
log = model.fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=callbacks, 
    verbose=True,
    val_data=(x_test, y_test),
    val_batch_size=batch_size
)

0:	[0s / 0s],		train_loss: 3.8831,	val_loss: 3.5626
1:	[0s / 0s],		train_loss: 3.5485,	val_loss: 3.4992
2:	[0s / 0s],		train_loss: 3.4668,	val_loss: 3.4378
3:	[0s / 0s],		train_loss: 3.3699,	val_loss: 3.4001
4:	[0s / 0s],		train_loss: 3.2844,	val_loss: 3.3809
5:	[0s / 0s],		train_loss: 3.2342,	val_loss: 3.3483
6:	[0s / 0s],		train_loss: 3.1337,	val_loss: 3.3225
7:	[0s / 0s],		train_loss: 3.0652,	val_loss: 3.3020
8:	[0s / 0s],		train_loss: 3.0529,	val_loss: 3.2784
9:	[0s / 0s],		train_loss: 2.9835,	val_loss: 3.2556
10:	[0s / 0s],		train_loss: 2.9500,	val_loss: 3.2556
11:	[0s / 0s],		train_loss: 2.9056,	val_loss: 3.2324
12:	[0s / 0s],		train_loss: 2.8455,	val_loss: 3.2226
13:	[0s / 0s],		train_loss: 2.8475,	val_loss: 3.2211
14:	[0s / 0s],		train_loss: 2.8370,	val_loss: 3.2256
15:	[0s / 0s],		train_loss: 2.7594,	val_loss: 3.1857
16:	[0s / 0s],		train_loss: 2.7236,	val_loss: 3.1786
17:	[0s / 0s],		train_loss: 2.7048,	val_loss: 3.1772
18:	[0s / 0s],		train_loss: 2.6561,	val_loss: 3.1447
19:

In [14]:
# ==================== Evaluation ====================
# With no arguments, the baseline hazards are computed from the training data
# stored by `fit`.
_ = model.compute_baseline_hazards()

# NOTE(review): the torch tensors are passed to EvalSurv as-is below; the
# conversion back to numpy (kept here for reference) is currently disabled.
# x_train_np = x_train.detach().cpu().numpy()
# x_test_np = x_test.detach().cpu().numpy()

# durations_train_np = durations_train.detach().cpu().numpy()
# events_train_np    = events_train.detach().cpu().numpy()
# durations_val_np   = durations_test.detach().cpu().numpy()
# events_val_np      = events_test.detach().cpu().numpy()

# Predicted survival curves (one column per individual), wrapped in EvalSurv
tr_surv  = model.predict_surv_df(x_train)
te_surv = model.predict_surv_df(x_test)

tr_ev = EvalSurv(tr_surv, durations_train, events_train, censor_surv='km')
te_ev = EvalSurv(te_surv, durations_test, events_test, censor_surv='km')

# Concordance index (train, test) ----------------
tr_c_index = tr_ev.concordance_td() 
te_c_index = te_ev.concordance_td() 
tr_c_index, te_c_index

(0.9502699643045979, 0.8175973395539272)

### Stratified CoxPH

In [56]:
# Disabled debugging snippet (kept for reference, nothing in this cell runs):
# it builds the train/test DataLoaders and prints the batch ids and event
# indicators of every mini-batch to inspect how samples are grouped.
# train_dataset = StratifiedDataset(x_train, durations_train, events_train, batch_ids_train)
# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# test_dataset = torch.utils.data.TensorDataset(x_test, durations_test, events_test, batch_ids_test)
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# # # Access batches
# # for idx, (inputs, durations, events, batch_ids) in enumerate(train_loader):
# #     print(f"Batch {idx + 1}:")
# #     print("batch ids:", batch_ids)
# #     # print("Time:", durations)
# #     print("Events:", events)
# #     print()
    
# ## Test
# for idx, (inputs, durations, events, batch_ids) in enumerate(test_loader):
#     print(f"Batch {idx + 1}:")
#     print("batch ids:", batch_ids)
#     # print("Time:", durations)
#     print("Events:", events)
#     print()

In [6]:
# Train the stratified CoxPH model from DataLoaders and report wall-clock time.
# Relies on `net`, `optimizer`, `batch_size`, `epochs` and `callbacks` defined
# in earlier cells.
# NOTE(review): the recorded run of this cell failed with
# "NameError: name 'StratifiedDataset' is not defined" -- presumably the import
# cell had not been executed in that kernel session; run the imports first.

# Disabled: default early-stopping settings (use if `callbacks` is undefined).
# patience = 20
# min_delta = 0
# callbacks = [tt.callbacks.EarlyStopping(patience=patience, min_delta=min_delta)]

# Each dataset item carries features, duration, event and batch (stratum) id.
# Training batches are shuffled; the test loader keeps a fixed order.
train_dataset = StratifiedDataset(x_train, durations_train, events_train, batch_ids_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# test_dataset = torch.utils.data.TensorDataset(x_test, durations_test, events_test, batch_ids_test)
test_dataset = StratifiedDataset(x_test, durations_test, events_test, batch_ids_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Stratified CoxPH model
model = CoxPHStratified(net, optimizer=optimizer)
# Report the model's own loss on the validation loader as 'val_loss'.
model.metrics = {'val_loss': model.loss}
start = time.time() # Record iteration start time
# NOTE(review): the *test* loader is passed as validation data; if `callbacks`
# contains EarlyStopping, model selection is driven by the test set and the
# test metrics become optimistic -- confirm or switch to a separate split.
log = model.fit_dataloader(
    train_loader,
    epochs=epochs,
    callbacks=callbacks,
    verbose=True,
    val_dataloader=test_loader  # optional for now
)
stop = time.time() # Record time when training finished
duration = round(stop - start, 2)  # elapsed seconds, rounded to 2 decimals
print(f"Training time: {duration}")

NameError: name 'StratifiedDataset' is not defined

In [99]:
# ==================== Evaluation ====================
# Score the stratified model on the training and test splits using the
# time-dependent concordance index.

# Baseline hazards are computed from the training data, which is passed
# explicitly here (features plus the (durations, events) target).
# NOTE(review): no batch/stratum ids are supplied -- confirm whether
# CoxPHStratified computes one shared baseline or one per stratum.
_ = model.compute_baseline_hazards(input=x_train, target=(durations_train, events_train))

# Disabled: tensor -> numpy conversion (would overwrite the tensor variables).
# Convert torch tensors back to numpy objects for evaluation
# x_train = x_train.detach().cpu().numpy()
# x_test  = x_test.detach().cpu().numpy()
# durations_train = durations_train.detach().cpu().numpy()
# durations_test  = durations_test.detach().cpu().numpy()
# events_train    = events_train.detach().cpu().numpy()
# events_test     = events_test.detach().cpu().numpy()

# Predicted survival curves wrapped in EvalSurv objects; censor_surv='km'
# asks EvalSurv to estimate the censoring distribution with Kaplan-Meier.
# NOTE(review): EvalSurvStratified is imported at the top of the notebook but
# plain EvalSurv is used here -- confirm this is intended.
tr_surv  = model.predict_surv_df(x_train)
te_surv = model.predict_surv_df(x_test)
tr_ev = EvalSurv(tr_surv, durations_train, events_train, censor_surv='km')
te_ev = EvalSurv(te_surv, durations_test, events_test, censor_surv='km')

# Concordance index ----------------
tr_c_index  = tr_ev.concordance_td() 
te_c_index = te_ev.concordance_td() 

# Last expression of the cell: displays the (train, test) concordance pair.
tr_c_index, te_c_index

(0.8718146735088339, 0.7258435801802039)

# ==== Archive ====

In [5]:
# Select and load the simulated survival dataset.
# `folder` and `keywords` are forwarded to the project helper
# `load_simulate_survival_data`; presumably they pick the simulation files
# to read -- see the helper for the exact matching rules.
folder = "linear"
keywords = ["moderate", "latest", "RW"]

train_df, test_df = load_simulate_survival_data(
    folder=folder,
    keywords=keywords,
    test_size=0.2,
)

# Preview the first rows of the training frame.
train_df.head()

Unnamed: 0,hsa.let.7a.2..1,hsa.let.7a.3..1,hsa.let.7a..2..1,hsa.let.7b.1,hsa.let.7b..1,hsa.let.7c.1,hsa.let.7c..1,hsa.let.7d.1,hsa.let.7d..1,hsa.let.7e.1,...,hsa.miR.96.1,hsa.miR.96..1,hsa.miR.98.1,hsa.miR.98..1,hsa.miR.99a.1,hsa.miR.99a..1,hsa.miR.99b.1,hsa.miR.99b..1,time,status
0,0.161142,16.139914,9.440727,12.377491,5.524234,11.493406,3.315916,13.211118,4.830156,10.275969,...,4.311983,0.000551,11.930738,6.893816,9.471225,1.244932,1.820734,0.499726,17.195448,0
1,0.737665,15.622746,7.350143,12.387646,4.14715,12.739532,3.693167,11.653065,6.348644,10.511325,...,5.211017,0.010937,10.01897,4.504616,12.851019,4.45857,9.536961,4.829052,1.190342,1
2,1.137077,15.838051,7.474055,14.062855,5.124704,13.175972,3.549938,12.224586,5.85129,10.720255,...,6.218995,0.023063,11.03204,5.361192,12.062099,3.112622,9.023745,4.720672,6.049298,1
3,2.152887,18.113934,7.656181,12.87939,4.768,13.147153,3.518407,12.940474,6.602133,11.160047,...,3.165273,0.015713,11.539787,5.979588,13.096312,4.459175,9.323495,5.122238,8.483713,1
4,3.068358,14.510637,7.261556,11.505635,5.917256,12.665515,4.810049,11.739157,6.542291,11.056379,...,17.231062,0.005042,10.274817,5.15452,13.246133,5.570715,11.376455,6.263673,23.514074,1


## Feature transforms


In [3]:
# Outcome columns (follow-up time and event indicator); everything else in the
# data frames is treated as a covariate.
survival_cols = ['time', 'status']

In [4]:
# Hold out 20% of the training frame for validation, stratified on the event
# indicator so both parts keep the same event/censoring ratio.
tr_df, val_df = train_test_split(
    train_df,
    test_size=0.2,
    shuffle=True,
    random_state=42,
    stratify=train_df['status'],
)

# Transform data: one StandardScaler per covariate column.
covariate_cols = [col for col in train_df.columns if col not in survival_cols]
standardize = [([col], StandardScaler()) for col in covariate_cols]
# Not consumed by x_mapper below; kept because other cells may reference it.
leave = [(col, None) for col in survival_cols]
x_mapper = DataFrameMapper(standardize)

# gene expression data
# Fit the scalers on the training split ONLY, then reuse those statistics for
# validation and test. Calling fit_transform on the validation split would
# refit the scalers, so validation would be scaled by its own mean/std and the
# test set by validation statistics instead of the training ones.
x_train = x_mapper.fit_transform(tr_df[covariate_cols]).astype('float32')
x_val = x_mapper.transform(val_df[covariate_cols]).astype('float32')
x_test = x_mapper.transform(test_df[covariate_cols]).astype('float32')


# prepare labels
def get_target(df):
    """Return the (durations, events) arrays taken from `df`'s 'time' and 'status' columns."""
    return df['time'].values, df['status'].values


y_train = get_target(tr_df)
y_val = get_target(val_df)
t_test, e_test = get_target(test_df)
# Validation data as an (input, target) pair.
val = x_val, y_val

## Neural net

We create a simple MLP with two hidden layers, ReLU activations, batch norm and dropout. 
Here, we just use the `torchtuples.practical.MLPVanilla` net to do this.


In [None]:
# Architecture hyperparameters of the risk network.
in_features = x_train.shape[1]  # one input unit per covariate column
num_nodes = [32, 16]            # widths of the two hidden layers
out_features = 1                # single output per sample
batch_norm = True
dropout = 0.2
output_bias = True

# Plain MLP from torchtuples built from the settings above.
net = tt.practical.MLPVanilla(
    in_features,
    num_nodes,
    out_features,
    batch_norm=batch_norm,
    dropout=dropout,
    output_bias=output_bias,
)

## Training the model

To train the model we need to define a `torch.optim` optimizer; here we instead use one from `tt.optim` as it has some added functionality.
We use the `Adam` optimizer. The learning rate could be chosen with `model.lr_finder`, but here we set it manually to 0.001.

In [None]:
# Adam optimizer from torchtuples, configured with weight_decay=0.01.
optimizer = tt.optim.Adam(weight_decay=0.01)

# Wrap the network and the optimizer in the stratified Cox model.
be_model = CoxPHStratified(net, optimizer)

# Set the learning rate manually to 0.001 (no learning-rate finder is run).
be_model.optimizer.set_lr(1e-3)

We include the `EarlyStopping` callback to stop training when the validation loss stops improving. After training, this callback will also load the best performing model in terms of validation loss.

In [None]:
%%time
# Fit the stratified Cox model with early stopping on the validation split.
# NOTE(review): the recorded run of this cell failed with
# "TypeError: forward() missing 1 required positional argument: 'batch_indices'"
# -- presumably the batch indices are not forwarded to the loss; verify against
# CoxPHStratified.fit, whose signature is not shown in this notebook.
batch_size = 64
epochs = 500
# Early stopping with a patience of 20 epochs and min_delta of 0.05.
callbacks = [tt.callbacks.EarlyStopping(patience=20, min_delta=5e-2)]
verbose = True

# Every training sample gets the same stratum label (an array of ones), i.e.
# a single stratum.
batch_indices = np.ones(len(y_train[1]))
# NOTE(review): batch_indices, batch_size, epochs and callbacks are passed
# positionally; what each one binds to depends on CoxPHStratified.fit's
# parameter order -- confirm, or pass them by keyword.
log = be_model.fit(x_train, y_train,
                batch_indices,
                batch_size,
                epochs,
                callbacks, 
                verbose=verbose,
                val_data=val, val_batch_size=batch_size
                )

TypeError: forward() missing 1 required positional argument: 'batch_indices'