# Implementing time-varying covariates

In this notebook, we analyse a simulated dataset with time-varying covariates and survival outcomes. `TorchSurv` is used to train a model that predicts the relative risk of subjects based on covariates observed over time. We explain the elements needed to understand our implementation; for a detailed treatment of time-varying survival models, refer to Chapter 6 of [Dynamic Regression Models for Survival Data](https://link.springer.com/book/10.1007/0-387-33960-4). For a briefer explanation, please refer to these [slides](https://ms.uky.edu/~mai/sta635/Cox%20model.pdf). Below is a summary of the necessary information.

### Future project ideas:
Future projects can take on various themes: testing edge cases of this implementation, making the code more robust to different data types, and including the Weibull distribution to compare against this approach.

Testing edge cases:
- use the simulated data and change different parameters to see how they affect performance; this can help guide appropriate use
- design slightly different simulations known for being difficult or easy in specific scenarios
- use a dataset with known properties

Improving code to be more robust:
- generalise the loss functions and define the input formatting required for them to work, e.g. for different time scales
- extend the method to deal with multiple types of covariates in one loss function or a combination of multiple losses. This can extend to multiple time-varying covariates and to mixing time-invariant and time-varying ones.

Weibull:
- Extend the Cox loss function to also include the Weibull distribution; this is described for both the log-likelihood and the simulation

Comparison:
- One could compare this approach to other loss functions or statistical models to get an idea of its benefits and challenges. This comparison can be done via simulation or on a real dataset.


## Partial log likelihood for time-varying covariates

### Context and statistical set-up

Let $i$ be the index of a subject with failure time $\tau^*_i$, and let $C$ be the censoring time. For the moment $C$ is the same for all subjects, but there are extensions that allow $C$ to vary over $i$. Let $\tau_i = \min(\tau^*_i, C)$. We use $\delta_i$ to denote whether $\tau^*_i$ was observed.

We use $Z(t)$ to denote the value of covariate $Z$ at time $t$ and $\overline{Z}(t)$ to denote the set of covariates from the beginning up to time $t$: $\overline{Z}(t) = \{ Z(s): 0 \leq s \leq t\}$.
Let $t_k$ for $k \in \{1, \dots, K\}$ denote the time points at which the covariates are observed. For the moment, we assume that all subjects have been observed on the same time grid. $R_k$ is the set of individuals who are at risk at $t_k$.

The conditional hazard function of $T$ given $\overline{Z}(t)$ is defined as
$$ \lambda(t|\overline{Z}(t))\,dt = Pr(T \in [t, t+ dt)|T \geq t, \overline{Z}(t)), $$
in other words, it is the probability that an event will occur in the next time instant, given that we have observed covariates up to time $t$ and that the subject has not yet experienced an event.

The typical Cox proportional hazards model with constant covariates $Z$ assumes a constant hazard ratio: $\lambda(t|Z)= \lambda_0(t) \exp(\beta Z)$, where $\beta$ is an unknown set of regression parameters and $\lambda_0(t)$ is an unspecified baseline hazard function. In this case $\frac{\lambda(t|Z)}{\lambda_0(t)} = \exp(\beta Z)$. The cumulative hazard is defined as $\Lambda(t) = \int_0^t \lambda(s)ds$.

In a time-varying Cox model, the hazard ratio is now dependent on time:
$$ \frac{\lambda(t|Z)}{\lambda_0(t)} = \exp(\beta Z(t)) $$
and the proportional hazards model specifies:
$$ \lambda(t|Z) = \lambda_0(t)\exp(\beta Z(t)) $$

Let $i_j$ denote the label or identity of the individual who fails at time $\tau_j$, with time-varying covariate history $\{ Z_{i_j}(t): t \in [0, \tau_j] \}$ during their time in the study, and let $R_j$ be the set of individuals at risk at $\tau_j$. The partial likelihood is:
$$ L (\beta) = \prod_j \Big (\frac{\lambda(\tau_j; Z_{i_j}(\tau_j))}{\sum_{l \in R_j} \lambda(\tau_j; Z_l(\tau_j))} \Big),$$
which in terms of the model form becomes:
$$ L (\beta) = \prod_j \Big (\frac{\exp(\beta Z_{i_j}(\tau_j))}{\sum_{l \in R_j} \exp(\beta Z_l(\tau_j))} \Big).$$

Taking the log on both sides, we get the partial log-likelihood:
$$ \log L (\beta) = \sum_j \Big (\beta Z_{i_j}(\tau_j) - \log \Big[\sum_{l \in R_j} \exp(\beta Z_l(\tau_j))\Big]\Big ). $$
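
To make the partial log-likelihood concrete, here is a minimal NumPy sketch for a single covariate observed on a shared time grid. The function name `cox_time_varying_pll` and the toy inputs are hypothetical and only serve as an illustration; the sketch assumes no tied event times and is not the implementation used later in this notebook.

```python
import numpy as np


def cox_time_varying_pll(beta, Z, exit_idx, delta):
    """Partial log-likelihood for one time-varying covariate on a shared grid.

    Z[i, k]     covariate of subject i at grid time t_k
    exit_idx[i] grid index of subject i's event or censoring time
    delta[i]    1 if the event was observed, 0 if censored
    """
    total = 0.0
    for j in np.flatnonzero(delta):  # only observed events contribute a term
        k = exit_idx[j]  # grid index of tau_j
        at_risk = exit_idx >= k  # risk set: subjects still under observation
        scores = beta * Z[at_risk, k]  # beta * Z_l(tau_j) for every l at risk
        total += beta * Z[j, k] - np.log(np.exp(scores).sum())
    return total


# Hypothetical toy data: 3 subjects, 3 grid points
Z = np.array([[0.0, 1.0, 2.0], [1.0, 1.0, 1.0], [2.0, 1.0, 0.0]])
exit_idx = np.array([0, 1, 2])
delta = np.array([1, 0, 1])
pll = cox_time_varying_pll(0.5, Z, exit_idx, delta)
print(pll)
```

In this toy example the first event sees all three subjects at risk, while the last event has only the failing subject in its risk set, so its term is zero.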


### Extension to neural networks

Consider a more general form of the Cox proportional hazards model:
$$\lambda(t|\overline{Z}(t))= \lambda_0(t) \theta(Z(t))$$

Additionally, consider some network that maps the input covariates $Z(t)$ to the log relative hazards: $\log \theta(Z(t))$.

The partial log-likelihood with respect to $\theta(Z(\tau_j))$ is written as:
$$ \log L(\theta) = \sum_j \Big( \log \theta(Z_{i_j}(\tau_j)) - \log \Big[\sum_{l \in R_j} \theta (Z_l(\tau_j))\Big] \Big),$$
where $R_j$ is the set of individuals at risk at $\tau_j$. It only considers the covariate values at the time of event or censoring, denoted as $\tau_j$; all prior covariates are not considered.

As the output of the network is set to be $\log \theta(Z(t))$, the code is written to account for this. To show this explicitly, set $\phi(Z(t)) = \log \theta(Z(t))$ and write the log-likelihood in terms of $\phi$:

$$ \log L(\theta) = \sum_j \Big( \phi(Z_{i_j}(\tau_j)) - \log \Big[\sum_{l \in R_j} \exp \phi(Z_l(\tau_j))\Big] \Big).$$
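
The expression in terms of $\phi$ maps directly onto tensor operations. The sketch below is one possible way to write it in PyTorch, assuming network outputs of shape `(K, n)` on a shared grid and no special handling of ties. The name `neg_partial_log_likelihood_sketch` is hypothetical: it is neither the `TorchSurv` API nor the `loss_time_covariates` module imported later in this notebook.

```python
import torch


def neg_partial_log_likelihood_sketch(phi, exit_idx, delta):
    """phi: (K, n) log relative hazards; exit_idx: (n,) long; delta: (n,) bool."""
    # phi_at_exit[j, l] = phi_l(tau_j): every subject's output at subject j's exit time
    phi_at_exit = phi[exit_idx, :]
    # at_risk[j, l] is True when subject l is still under observation at tau_j
    at_risk = exit_idx.unsqueeze(0) >= exit_idx.unsqueeze(1)
    # Subjects outside the risk set get -inf so they vanish inside logsumexp
    masked = phi_at_exit.masked_fill(~at_risk, float("-inf"))
    log_denominator = torch.logsumexp(masked, dim=1)
    log_numerator = phi_at_exit.diagonal()  # phi_j(tau_j)
    # Sum over observed events only, then flip the sign to obtain a loss
    return -(log_numerator - log_denominator)[delta].sum()


torch.manual_seed(0)
K, n = 6, 5
rnn = torch.nn.RNN(input_size=1, hidden_size=1)
covariates = torch.randn(K, n, 1)  # hypothetical inputs
phi, _ = rnn(covariates)  # (K, n, 1)
exit_idx = torch.tensor([0, 2, 2, 4, 5])
delta = torch.tensor([True, False, True, True, False])
loss = neg_partial_log_likelihood_sketch(phi.squeeze(-1), exit_idx, delta)
loss.backward()
print(loss.item())
```

Using `logsumexp` rather than `log(sum(exp(...)))` is a numerical-stability choice; the two are mathematically equivalent.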


### Dependencies

To run this notebook, dependencies must be installed. The recommended method is to use our development conda environment (**preferred**). Instructions to install all optional dependencies can be found [here](https://opensource.nibr.com/torchsurv/devnotes.html#set-up-a-development-environment-via-conda). The other method is to install only the required packages using the commands below:

In [1]:
# Install only required packages (optional)
# %pip install lifelines
# %pip install matplotlib
# %pip install scikit-learn
# %pip install scipy
# %pip install pandas

In [2]:
import warnings

warnings.filterwarnings("ignore")

In [3]:
import numpy as np
import pandas as pd
import torch

# Our package
# from torchsurv.loss.time_varying import neg_partial_log_likelihood2
# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py

## Simulating realistic data

A good approach for simulating data is described in detail by [Ngwa et al 2020](https://pmc.ncbi.nlm.nih.gov/articles/PMC7731987/). If this is not yet implemented, it would be a good starting point for checking that both methods work as expected. There are two parts to simulating such a dataset: first the longitudinal observational data, then the survival data. Below we describe methodologies for both.

### Longitudinal data (covariates)

We use $i \in \{1, \dots, n\}$ to index subjects and $j \in \{1, \dots, m_i\}$ to index time points where $m_i$ is the final time point for subject $i$.
We simulate covariates independently:
- age at baseline $Age_i \sim N(35,5)$
- sex $\sim Bernoulli(p=0.54)$

Generate expected longitudinal trajectories $\varphi_{\beta}(t_{ij})$:

$$ \varphi_{\beta}(t_{ij}) = b_{i1} + b_{i2} \cdot t_{ij} + \alpha Age_i, $$

where $b_{i1}, b_{i2}$ are random effects.

We generate $b_{i1}, b_{i2}$ from a multivariate normal distribution with mean zero and covariance matrix $G = [[0.29, -0.00465],[-0.00465, 0.000320]]$. Each sample gives the random intercept and slope of one subject.

The observed longitudinal measures $Y_{ij}(t_{ij})$ are drawn from a multivariate normal distribution with mean $\varphi_{\beta}(t_{ij})$ and variance $V$:

$$ V = Z_i GZ_i ^T + R_i, \text{ where }Z_i = [[1,1,1,1,1,1]^T, [0,5,10,15,20,25]^T]$$

and $R_i = diag(\sigma^2)$, with $\sigma^2$ set to $0.1161$. The code below uses the shorter time grid $0, 1, \dots, 5$.

Note: Compared to the paper, we slightly adjust steps 3 and 4 from the simulation algorithm section (6.1) to avoid fitting a random effects model which adds more complexity in terms of data formatting. 
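
As an illustration of the covariance structure $V = Z_i G Z_i^T + R_i$, the following NumPy sketch draws observations from the marginal distribution. It relies on a standard fact that is an assumption about how to read the model here: if $Y_i = Z_i b_i + \alpha Age_i + \varepsilon_i$ with $b_i \sim N(0, G)$ and $\varepsilon_i \sim N(0, \sigma^2 I)$, then marginally $Y_i \sim N(\alpha Age_i, V)$, so the mean only contains the fixed age effect. The variable names are hypothetical and $\alpha = 1$ mirrors the value used in the code cell that follows.

```python
import numpy as np

rng = np.random.default_rng(123)

n = 100
times = np.array([0.0, 5.0, 10.0, 15.0, 20.0, 25.0])
G = np.array([[0.29, -0.00465], [-0.00465, 0.000320]])
sigma2 = 0.1161
alpha_age = 1.0  # assumed, as in the notebook's simulation parameters

# Design matrix of the random effects: a column of ones and the time grid
Z = np.column_stack([np.ones_like(times), times])  # (6, 2)
# Marginal covariance of the six repeated measures of one subject
V = Z @ G @ Z.T + sigma2 * np.eye(len(times))  # (6, 6)

age = rng.normal(35.0, 5.0, size=n)
# Marginal mean: the random effects have mean zero, so only the age term remains
mean = alpha_age * age[:, None] * np.ones((1, len(times)))  # (n, 6)
Y = mean + rng.multivariate_normal(np.zeros(len(times)), V, size=n)
print(Y.shape)
```

Sampling the random effects explicitly and adding independent noise, as the next cell does in simplified form, targets the same kind of trajectories while keeping $b_{i1}, b_{i2}$ available for the survival simulation.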

In [4]:
import torch.distributions as dist

# Set random seed for reproducibility
torch.manual_seed(123)

n = 100  # Number of subjects
T = torch.tensor(6)  # Number of time points
time_vec = torch.tensor([0, 1, 2, 3, 4, 5])

# Simulation parameters
age_mean = 35
age_std = 5
sex_prob = 0.54
G = torch.tensor([[0.29, -0.00465], [-0.00465, 0.000320]])
Z = torch.tensor([[1, 1, 1, 1, 1, 1], time_vec], dtype=torch.float32).T
sigma = torch.tensor([0.1161])
alpha = 1

# Simulate age at baseline
age_dist = dist.Normal(age_mean, age_std)
age = age_dist.sample((n,))

# Simulate sex
sex_dist = dist.Bernoulli(probs=sex_prob)
sex = sex_dist.sample((n,))

# Simulate random effects
random_effects_dist = dist.MultivariateNormal(torch.zeros(2), G)
random_effects = random_effects_dist.sample((n,))

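# Sample one random error per subject; sigma is passed as the standard deviation
# and the (n, 1) sample broadcasts across all time points
error_sample = dist.Normal(0, sigma).sample((n,))

# Generate longitudinal trajectories: random intercept + random slope * time + alpha * age + error
# (a simplification that does not follow the paper's algorithm exactly)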
trajectories = (
    random_effects[:, 0].unsqueeze(1)
    + random_effects[:, 1].unsqueeze(1) * Z[:, 1]
    + alpha * age.unsqueeze(1)
    + error_sample
)

print(trajectories[1:5, :])

tensor([[34.2016, 34.2186, 34.2356, 34.2526, 34.2696, 34.2866],
        [33.4380, 33.4308, 33.4235, 33.4163, 33.4091, 33.4018],
        [31.5581, 31.5564, 31.5548, 31.5531, 31.5515, 31.5498],
        [35.7813, 35.7953, 35.8093, 35.8233, 35.8373, 35.8513]])


In [5]:
## ANOTHER WAY OF GENERATING DATA

# # Simulate observed longitudinal measures
# R = torch.diag_embed(sigma.repeat(T))
# V = torch.matmul(torch.matmul(Z, G), Z.T) + R

# #get a mean trajectory
# b1 = torch.tensor([4.250])
# b2 = torch.tensor([0.250])
# mean_trajectory =  b1.item() + b2.item() * Z[:,1] + alpha * age_mean

# #define the distribution to sample the trajectories from
# observed_data_dist = dist.MultivariateNormal(trajectories, V)

# #sample from the distribution to get an n x T matrix of observations/covariates
# observed_data = observed_data_dist.sample((1,)).squeeze()

# print(observed_data[1:5, :])

### Survival data (outcomes)

Here we describe how to obtain the survival and censoring times for all the subjects simulated above, and then implement it in Python.

Specify (varying) values for the parameter estimates for $Age$, $sex$ and the link parameter $\gamma$, which measures the strength of the association between the longitudinal measures $Y_{ij}(t_{ij})$ and the time-to-event $\tau_j$.

Let $Q \sim Unif(0,1)$ be a random variable that determines the hazard of a subject. Using the time-varying Cox model, it can be expressed as:

$$ Q(t;X,Y) = \exp[-H_0(t)\cdot \exp(X^T\alpha + \gamma (b_{i1} + b_{i2} \cdot t))],$$
where $X^T$ is a vector of time-invariant covariates and $\alpha$ a vector of regression coefficients.

We take $H_0(t) = \lambda t$. If $h_0(t)>0$ for all $t$, then $H_0$ can be inverted, and taking logs on both sides gives:
$$-\log(Q) = \lambda t \cdot \exp[X^T \alpha + \gamma (b_{i1} + b_{i2} \cdot t) ] $$
This expression can be rearranged to generate the times-to-event.

Generate the time-to-event $\tau_j$ using the following equation for the Cox Exponential model:
$$ t = \frac{1}{\gamma \cdot b_{i2}} W \Big( \frac{-\gamma b_{i2} \log(Q)}{\lambda \exp (X^T \alpha + \gamma b_{i1})} \Big), $$

where $W$ is the Lambert W function (LWF). [Corless et al. 1996](https://link.springer.com/article/10.1007/BF02124750) provide a history, theory and applications of the LWF. The LWF is the inverse of the function $f(p) = p \cdot \exp(p)$.

Generate the censoring variable $C \sim Unif(25, 30)$ so that censoring occurs late in the study (the code below uses $Unif(3, 5.5)$ instead). From the survival and censoring times, we obtain the censoring indicator $\delta_i$, which is defined as 1 if $\tau^*_i < C_i$ and 0 otherwise.
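
The Lambert W step can be checked numerically: a time generated from the closed-form expression should satisfy $-\log(Q) = \lambda t \exp[X^T\alpha + \gamma(b_{i1} + b_{i2} t)]$. The sketch below does this with `scipy.special.lambertw`. The function name `event_time_lambert` and the input values are hypothetical, and the check uses a positive slope $b_{i2}$ so that the argument of $W$ is positive and the principal branch is real.

```python
import numpy as np
from scipy.special import lambertw


def event_time_lambert(Q, b1, b2, x_alpha, gamma, lam):
    """Solve -log(Q) = lam * t * exp(x_alpha + gamma * (b1 + b2 * t)) for t."""
    a = gamma * b2
    arg = -a * np.log(Q) / (lam * np.exp(x_alpha + gamma * b1))
    # lambertw returns complex values; the principal branch is real for arg >= -1/e
    return np.real(lambertw(arg)) / a


# Hypothetical inputs, using the parameter values quoted in this notebook
Q = np.array([0.2, 0.5, 0.9])
b1, b2 = 0.1, 0.02
x_alpha = 0.05 * 35.0 - 0.1 * 1.0  # alpha_age * age + alpha_sex * sex
gamma, lam = 1.2, 0.04

t = event_time_lambert(Q, b1, b2, x_alpha, gamma, lam)
lhs = -np.log(Q)
rhs = lam * t * np.exp(x_alpha + gamma * (b1 + b2 * t))
print(np.allclose(lhs, rhs))
```

For a negative slope the argument of $W$ becomes negative, and below $-1/e$ the principal branch is no longer real, which is worth keeping in mind when the random slopes are sampled around zero.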


In [6]:
# import the Lambert W function

from scipy.special import lambertw

Note: pre-determined parameters such as $\alpha, \gamma, \lambda_0$ have a large effect on the event time outcomes. The values used here are:
- $\alpha_{age} = 0.05$,
- $\alpha_{sex} = -0.1$,
- $\gamma = 1.2$,
- $\lambda_0 = 0.04$


In [7]:
# Specify the values for parameters, generate the random variables and call on relevant variables defined previously

alpha = torch.tensor([0.05, -0.1])  # regression coefficient for time-invariant covariates
gamma = torch.tensor(1.2)  # association strength between longitudinal measures and time-to-event
lambda_0 = torch.tensor(0.04)  # baseline hazard rate

torch.manual_seed(123)

# Generate the random variables for hazard of a subject and censoring
Q = dist.Uniform(0, 1).sample((n,))  # Random variable for hazard (Q)
C = dist.Uniform(3, 5.5).sample((n,))  # Random variable for censoring

# age and sex are the names of variables corresponding to those covariates
# create the X matrix of covariates
XX = torch.stack((age, sex), dim=1)

# get b1 and b2 from the random sample we made before
b1 = random_effects[:, 0]
b2 = random_effects[:, 1]

# Generate time to event T using the equation above
log_Q = torch.log(Q)
lambert_W_numerator = gamma * b2 * log_Q
lambert_W_denominator = torch.exp(alpha @ XX.T + gamma * b1)
# below should give a vector of length sample_size
lambert_W = lambertw(-lambert_W_numerator / (lambda_0 * lambert_W_denominator))
time_to_event = lambert_W / (gamma * b2)

# take the real part of the LWF output, the imaginary part is 0
outcome_LWF = time_to_event.real
# round up so that event times fall on the integer time grid
outcome_LWF = torch.ceil(outcome_LWF)
print(outcome_LWF)

# placeholder censoring: a subject is flagged as having an event when C < 5,
# which does not compare C with the simulated event time
events = C < 5
events

tensor([ 4.,  3.,  8.,  5., 22.,  1., 12.,  7.,  7.,  2.,  3.,  3., 12.,  3.,
        10.,  4., 17.,  1.,  3.,  2.,  1.,  3.,  1.,  1., 12.,  5.,  9.,  1.,
         2.,  8.,  1.,  2.,  1.,  1.,  1.,  7.,  7., 13., 11.,  3.,  6., 10.,
         5.,  1.,  3.,  3.,  3.,  4.,  1.,  1.,  3.,  2.,  2.,  2.,  5., 38.,
         4.,  1.,  5.,  9.,  3.,  1., 19.,  4.,  5.,  4.,  4., 16., 15.,  1.,
         1.,  2.,  4.,  9.,  2.,  3.,  6., 19.,  7.,  1.,  1.,  4.,  6.,  3.,
         8.,  6.,  8.,  2.,  1.,  2., 32.,  1.,  2.,  1., 13.,  2.,  1.,  2.,
         1.,  3.], dtype=torch.float64)


tensor([False,  True,  True,  True,  True,  True,  True, False,  True,  True,
        False,  True,  True,  True,  True,  True,  True,  True,  True, False,
         True,  True, False,  True, False,  True,  True,  True,  True, False,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True, False,  True, False,  True,  True,  True,  True,
         True, False,  True, False, False, False,  True, False, False, False,
         True,  True,  True,  True,  True,  True,  True,  True, False,  True,
         True,  True, False,  True,  True,  True,  True,  True, False,  True,
         True,  True,  True, False,  True,  True,  True,  True, False,  True])

A simpler method for generating the time-to-event assumes that the covariate has a more straightforward relation with time, $Z(t) = kt$ for some $k>0$. This approach is suggested by [Peter C. Austin 2012](https://pmc.ncbi.nlm.nih.gov/articles/PMC3546387/pdf/sim0031-3946.pdf), where, with $u \sim Unif(0,1)$,
$$ t = \frac{1}{\gamma k} \log \Big ( 1 + \frac{\gamma k (-\log(u))}{\lambda \exp(\alpha X)}\Big). $$
The above equation has been adapted to remain consistent with the parameters defined before. In our case, $k$ could be replaced with $b_{i2}$ if $b_{i2}$ were sampled such that it is strictly positive. In the above configuration that is not the case.
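
The closed-form expression above can be sketched and checked in a few lines. The check assumes the hazard $\lambda \exp(\alpha X + \gamma k t)$, whose cumulative hazard is $\lambda \exp(\alpha X)(\exp(\gamma k t) - 1)/(\gamma k)$, and confirms that it equals $-\log(u)$ at the generated time. The function name `event_time_linear_covariate` and the input values are hypothetical.

```python
import numpy as np


def event_time_linear_covariate(u, k, x_alpha, gamma, lam):
    """Event time when the covariate grows linearly in time, Z(t) = k * t with k > 0."""
    a = gamma * k
    # log1p(x) computes log(1 + x) accurately for small x
    return np.log1p(a * (-np.log(u)) / (lam * np.exp(x_alpha))) / a


# Hypothetical inputs
u = np.array([0.1, 0.5, 0.9])
k, x_alpha, gamma, lam = 0.3, 1.65, 1.2, 0.04

t = event_time_linear_covariate(u, k, x_alpha, gamma, lam)

# Cumulative hazard of lam * exp(x_alpha + gamma * k * s), integrated from 0 to t
cumulative_hazard = lam * np.exp(x_alpha) * np.expm1(gamma * k * t) / (gamma * k)
print(np.allclose(cumulative_hazard, -np.log(u)))
```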



### Data Format

Here we create a single matrix of data that corresponds to one covariate being observed over time for some dataset.
The time series is padded with zeros so that each subject has a vector of the same length. The vector contains their covariate $Z_i(t)$ up until the failure time $\tau_j$, and values beyond that are zero.

In general, prior to fitting a survival model or a network, one should consider how to handle missing data. This is most important for covariates that are missing at the event time $\tau_j$. Data imputation methods can vary depending on the use case, but some to consider are:
- use the most recent value (assumes step function),
- interpolate,
- impute based on some model.
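
The padding and the first imputation option (carrying the most recent value forward) can be sketched with pandas and NumPy as follows. The toy matrix, the `exit_idx` vector and the variable names are hypothetical; `NaN` marks a missed visit.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: 3 subjects on a grid of 5 time points
Z = pd.DataFrame(
    [
        [1.0, np.nan, 1.4, 1.6, 1.8],
        [2.0, 2.1, np.nan, np.nan, 2.5],
        [0.5, 0.7, 0.9, 1.1, 1.3],
    ]
)
exit_idx = np.array([2, 3, 4])  # grid index of each subject's event or censoring time

# Carry the most recent observed value forward along the time axis
Z_locf = Z.ffill(axis=1)

# Zero-pad every grid point after the subject's own exit time
grid = np.arange(Z.shape[1])
after_exit = grid[None, :] > exit_idx[:, None]  # (n, K) boolean mask
Z_padded = Z_locf.to_numpy().copy()
Z_padded[after_exit] = 0.0
print(Z_padded)
```

Imputing before padding matters here: the second subject's value at their exit time (grid index 3) is missing and is filled from the last observed visit rather than left undefined.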

## Training the RNN 

Below we will give an example set up of how to use the partial log likelihood in a loss function. We import the python file containing the loss and set up an RNN to work with our simulated data.

In [8]:
from importlib import reload

import loss_time_covariates

reload(loss_time_covariates)
log_likelihood = loss_time_covariates._partial_likelihood_time_cox
neg_loss_function = loss_time_covariates.neg_partial_time_log_likelihood

In [9]:
# from torchsurv.loss import time_covariates
# from torchsurv.metrics.cindex import ConcordanceIndex

# Parameters
input_size = 1
output_size = 1
num_layers = 2
seq_length = T
batch_size = n

# Create simple RNN model
rnn = torch.nn.RNN(input_size, output_size, num_layers)
inputs = torch.randn(seq_length, batch_size, input_size)
test = trajectories.T.unsqueeze(2)
print(test.shape)
print(inputs.shape)

# initialize hidden state
h0 = torch.randn(num_layers, batch_size, output_size)
print(h0.shape)
# Forward pass time series input
outputs, _ = rnn(test, h0)
print(outputs.shape)

# outcome_LWF is the time at which each subject experiences an event
loss = neg_loss_function(outputs, outcome_LWF, events)
print(f"loss = {loss}, has gradient = {loss.requires_grad}")

torch.Size([6, 100, 1])
torch.Size([6, 100, 1])
torch.Size([2, 100, 1])
torch.Size([6, 100, 1])
loss = 1.0968555212020874, has gradient = True


## Comparison to Lifelines package

Re-format the simulated data to fit a standard time-varying Cox model with the lifelines package.

In [10]:
from numpy import sum as array_sum_to_scalar

# as a reminder, trajectories is the matrix of covariates: a row corresponds to a subject and a column to their observation at some time
# trajectories is not zero-padded after the event; the loop below only reads columns before each subject's event time,
# reusing the last column for events beyond the grid
# Use the simulated trajectories as the covariate matrix
torch_matrix = trajectories
# Convert torch matrix to pandas dataframe
# set time to integer
max_time = max(time_vec.type(torch.int64))
print(max_time)
event_time = outcome_LWF
vars = []
start = []
stop = []
event = []
subjs = []

for i in range(n):
    subj_counter = 0
    subj_event_time = int(event_time[i].item())
    # print(subj_event_time)
    for j in range(subj_event_time):
        if j < max_time:
            vars.append(torch_matrix[i, j].item())
            start.append(j)
            stop.append(j + 1)
            event.append(False)
            subjs.append(i)
            subj_counter += 1
        if j >= max_time:
            vars.append(torch_matrix[i, -1].item())
            start.append(j)
            stop.append(subj_event_time)
            event.append(False)
            subjs.append(i)
            subj_counter += 1
            break
    # set the last value to have an event
    event[-1] = True
    # if you want censoring use below
    # if events[i]==True: event[-1]=True

# for every time point before they experience an event
# record all their variables until event time
df = pd.DataFrame(
    {
        "subj": subjs,
        "start": start,
        "stop": stop,
        "events": event,
        "var": vars,
    }
)

df.head(10)

tensor(5)


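| | subj | start | stop | events | var |
|---|---|---|---|---|---|
| 0 | 0 | 0 | 1 | False | 37.006706 |
| 1 | 0 | 1 | 2 | False | 37.01387 |
| 2 | 0 | 2 | 3 | False | 37.021034 |
| 3 | 0 | 3 | 4 | True | 37.028198 |
| 4 | 1 | 0 | 1 | False | 34.201637 |
| 5 | 1 | 1 | 2 | False | 34.218636 |
| 6 | 1 | 2 | 3 | True | 34.235634 |
| 7 | 2 | 0 | 1 | False | 33.438019 |
| 8 | 2 | 1 | 2 | False | 33.430782 |
| 9 | 2 | 2 | 3 | False | 33.42355 |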


We will compute the log-likelihood using the code from lifelines to compare our method to theirs. This snippet of code is taken from [cox_time_varying_fitter.py](https://github.com/CamDavidsonPilon/lifelines/blob/master/lifelines/fitters/cox_time_varying_fitter.py), lines 499-550.

In [11]:
stop_times = df["stop"]
event_bool = df["events"]
unique_death_times = np.unique(stop_times[event_bool])
covariates = df["var"]
# the following is an internal column in lifelines, since we do not define it in this simulation it is set to 1.0
# this is also done in the code at lines 182-185
weights = np.ones(len(df))


# helper defined at line 50 of the lifelines file: sums its input over axis 0
def matrix_axis_0_sum_to_1d_array(m):
    return np.sum(m, 0)


# we will be replacing x*beta from the code with our outputs from the network as written in the beginning of this notebook
# network_out = outputs
# print(network_out.shape)
beta = np.array(
    [1],
)
# print(beta)
# print(beta.shape)
# np.dot(X_at_t, beta)
for t in unique_death_times:
    # returns a boolean vector of length nxT in our case
    ix = (start < t) & (t <= stop)
    # returns a vector of covariates at event time
    X_at_t = covariates[ix]
    weights_at_t = weights[ix]
    stops_events_at_t = stop_times[ix]
    events_at_t = event_bool[ix]

    # changed dot product to multiply cause dot is no longer supported in that way
    phi_i = weights_at_t * np.exp(np.multiply(X_at_t, beta))
    print(phi_i.shape)
    # removed indexing from original code cause we only have 1 dim
    phi_x_i = phi_i * X_at_t
    phi_x_x_i = np.dot(X_at_t.T, phi_x_i)

    # Calculate sums of Risk set
    risk_phi = array_sum_to_scalar(phi_i)
    risk_phi_x = matrix_axis_0_sum_to_1d_array(phi_x_i)
    risk_phi_x_x = phi_x_x_i

    # Calculate the sums of Tie set
    deaths = events_at_t & (stops_events_at_t == t)

    tied_death_counts = array_sum_to_scalar(deaths.astype(int))  # should always at least 1. Why? TODO

    xi_deaths = X_at_t[deaths]

    x_death_sum = matrix_axis_0_sum_to_1d_array(weights_at_t[deaths] * xi_deaths)

    weight_count = array_sum_to_scalar(weights_at_t[deaths])
    weighted_average = weight_count / tied_death_counts

    # no tensors here, but do some casting to make it easier in the converging step next.
    denom = 1.0 / np.array([risk_phi])
    number = risk_phi_x
    a1 = risk_phi_x_x * denom

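# NOTE: these lines run once, after the loop, so log_lik only reflects the last death time
# (a single subject at risk, hence 0); accumulating the update inside the loop would give the full log-likelihood
summand = number * denom[:, None]
a2 = summand.T.dot(summand)
log_lik = np.dot(x_death_sum, beta) + weighted_average * np.log(denom).sum()

log_lik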

(100,)
(76,)
(62,)
(47,)
(38,)
(32,)
(28,)
(23,)
(19,)
(16,)
(14,)
(13,)
(10,)
(8,)
(7,)
(6,)
(5,)
(3,)
(2,)
(1,)


array([0.])
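
For comparison, here is a minimal sketch of a partial log-likelihood that accumulates one term per death time on start/stop data. It uses the Breslow form for tied events, which is an assumption made for brevity and may differ from how lifelines treats ties. The function name `breslow_partial_log_likelihood` and the toy dataframe are hypothetical; the column names follow the dataframe built above.

```python
import numpy as np
import pandas as pd


def breslow_partial_log_likelihood(df, beta, covariate="var"):
    """Accumulate the partial log-likelihood over all unique death times."""
    start = df["start"].to_numpy()
    stop = df["stop"].to_numpy()
    event = df["events"].to_numpy(dtype=bool)
    x = df[covariate].to_numpy(dtype=float)

    log_lik = 0.0
    for t in np.unique(stop[event]):
        at_risk = (start < t) & (t <= stop)  # intervals that cover time t
        deaths = at_risk & event & (stop == t)  # events occurring exactly at t
        log_risk = np.log(np.exp(beta * x[at_risk]).sum())
        # Each death adds beta * x minus the log of the risk-set sum
        log_lik += beta * x[deaths].sum() - deaths.sum() * log_risk
    return log_lik


# Hypothetical toy data: subject 2 is censored at time 2
toy = pd.DataFrame(
    {
        "subj": [0, 0, 1, 2, 2],
        "start": [0, 1, 0, 0, 1],
        "stop": [1, 2, 1, 1, 2],
        "events": [False, True, True, False, False],
        "var": [1.0, 2.0, 0.5, 0.0, 0.0],
    }
)
toy_log_lik = breslow_partial_log_likelihood(toy, beta=1.0)
print(toy_log_lik)
```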

Fitting a Cox regression model using the lifelines package.


In [None]:
from lifelines import CoxTimeVaryingFitter

ctv = CoxTimeVaryingFitter(penalizer=0)
ctv.fit(df, id_col="subj", event_col="events", start_col="start", stop_col="stop", show_progress=True)
ctv.print_summary()
ctv.plot()

## Real life data: heart transplant survival

This section demonstrates the method with a neural network; the example is inspired by the [lifelines example](https://lifelines.readthedocs.io/en/latest/Time%20varying%20survival%20regression.html#).

This is a classic dataset for survival regression with time varying covariates. The original dataset is from J Crowley and M Hu. 'Covariance analysis of heart transplant survival data', and this dataset is from R’s survival library.


In [None]:
import lifelines

df = lifelines.datasets.load_stanford_heart_transplants()
df.head(5)

The dataset contains the following:

- `start`: entry time,
- `stop`: exit time,
- `event`: status for this interval of time,
- `age`: subject's age minus 48 years,
- `year`: year of acceptance (in years after 1 Nov 1967),
- `surgery`: prior bypass surgery 1=yes
- `transplant`: received transplant 1=yes
- `id`: patient id