# Implementing time-varying covariates

In this notebook, we analyse a simulated dataset with time-varying covariates and survival outcomes. `TorchSurv` is used to train a model that predicts relative risk of subjects based on covariates observed over time. We will attempt to thoroughly explain the necessary elements to understand our implementation, but for a detailed read on time-varying survival models refer to Chapter 6 of [Dynamic Regression Models for Survival Data](https://link.springer.com/book/10.1007/0-387-33960-4). 

### Dependencies

To run this notebook, dependencies must be installed. The recommended method is to use our development conda environment (**preferred**). Instructions to install all optional dependencies can be found [here](https://opensource.nibr.com/torchsurv/devnotes.html#set-up-a-development-environment-via-conda). The other method is to install only the required packages using the commands below:

In [None]:
# Install only required packages (optional)
# %pip install lifelines
# %pip install matplotlib
# %pip install scikit-learn
# %pip install pandas

In [1]:
import warnings

warnings.filterwarnings("ignore")

In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Our package
#from torchsurv.loss.time_varying import neg_partial_log_likelihood2

# PyTorch boilerplate - see https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py
from helpers_introduction import Custom_dataset, plot_losses

## Simulating a dataset

We will simulate a dataset of 100 subjects, each with 6 follow-up times at which a covariate is observed. The covariates change slightly over time but are generated from one random variable per subject so that 

In [None]:
import math

# defining parameters
sample_size = 100  # number of subjects to generate
obs_time = 6  # number of observations over time for each subject

# one random variable per subject, drawn from a normal distribution N(1, 1)
mean = 1
standard_dev = 1
random_vars = torch.randn(sample_size) * standard_dev + mean

# obs_time equidistant time points from 0 to 2*pi
t = torch.linspace(0, 2 * math.pi, obs_time)

# matrix of sin values: every row holds sin(t_k) for k = 1, ..., obs_time
matrix = torch.sin(t).repeat(sample_size, 1)  # (sample_size, obs_time)

# scale each subject's row by that subject's random variable,
# giving one covariate trajectory per subject
result = matrix * random_vars.unsqueeze(1)  # (sample_size, obs_time)

In [None]:
# make random boolean events
events = random_vars > 0.5
print(events)  # boolean tensor with one entry per subject

# make random positive time to event (the absolute value keeps the times positive)
time = torch.abs(random_vars) * 100
print(time)  # float tensor with one entry per subject

## Implementing partial log likelihood for time-varying covariates

Let $T^*_i$ be the failure time of interest for subject $i$ and $C$ be the censoring time. Let $T_i = \min(T^*_i, C)$. We use $\delta_i$ to denote whether $T^*_i$ was observed. We use $Z(t)$ to denote the value of covariate $Z$ at time $t$. Let $t_k$ for $k \in \{1, \dots, K\}$ denote the time points at which the covariates are observed. For the moment, we assume that all subjects are observed on the same time grid. $R_k$ is the set of individuals who are at risk at $t_k$.
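
To make the notation concrete, the risk sets $R_k$ can be computed with a single broadcast comparison. This is a minimal sketch on hypothetical toy values (the `time` and `grid` tensors below are made up for illustration), and it assumes the usual convention that subject $i$ is at risk at $t_k$ when $T_i \geq t_k$.

```python
import torch

# Hypothetical toy data: observed times T_i for four subjects
time = torch.tensor([2.5, 0.7, 4.0, 1.2])

# Hypothetical common grid t_1, ..., t_K at which covariates are observed
grid = torch.tensor([0.0, 1.0, 2.0, 3.0])

# at_risk[i, k] is True when subject i is still under observation at t_k
# (assumed convention: T_i >= t_k), so column k encodes the risk set R_k
at_risk = time.unsqueeze(1) >= grid.unsqueeze(0)  # (n, K)

# Number of subjects in each risk set R_k
risk_set_sizes = at_risk.sum(dim=0)
print(risk_set_sizes)  # tensor([4, 3, 2, 1])
```

Each column of `at_risk` can then be used as a mask when summing over $j \in R_k$.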


Consider a network that outputs $\theta(t_k)$ for each observed covariate $Z(t_k)$. These values are collected in the vector $\theta_K$. Similarly, $Z_K$ denotes the vector of the covariate history up to time $t_K$.

The log likelihood in terms of $\theta(t_k)$ can be written as follows.

$$ l(\theta) = \sum_{i=1}^n \delta_i \left( \frac{\sum_{j \in R_i} \exp(\theta_K) Z_K Z_K^T}{\sum_{j \in R_i} \exp(\theta_K)} - \frac{\left[\sum_{j \in R_i} \exp(\theta_K) Z_K\right]\left[\sum_{j \in R_i} \exp(\theta_K) Z_K\right]^T}{\sum_{j \in R_i} \exp(\theta_K)} \right) $$
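
For comparison, here is a minimal sketch of the commonly stated Cox partial log-likelihood with time-varying log relative hazards, $\sum_i \delta_i \big[\theta_i(T_i) - \log \sum_{j \in R(T_i)} \exp(\theta_j(T_i))\big]$. This is the textbook form, not the expression above and not a TorchSurv function. The function name, the `(n, K)` layout of `log_hz` and the rule of evaluating $\theta$ at the last grid point at or before $T_i$ are all assumptions made for illustration.

```python
import torch


def tv_partial_log_likelihood_sketch(log_hz, event, time, grid):
    """Textbook time-varying Cox partial log-likelihood (illustrative only).

    log_hz: (n, K) log relative hazards theta_i(t_k) on a shared grid
    event:  (n,) boolean event indicators delta_i
    time:   (n,) observed times T_i
    grid:   (K,) increasing observation times t_1 < ... < t_K
    """
    # index of the last grid point at or before each observed time
    k = torch.searchsorted(grid, time, right=True) - 1
    k = k.clamp(min=0, max=grid.numel() - 1)

    # theta_cols[i, j] = theta_j evaluated at the grid point of subject i's time
    theta_cols = log_hz[:, k].T  # (n, n)

    # at_risk[i, j] is True when subject j is still at risk at T_i
    at_risk = time.unsqueeze(0) >= time.unsqueeze(1)  # (n, n)

    # log of the risk-set sum; subjects outside the risk set contribute exp(-inf) = 0
    masked = theta_cols.masked_fill(~at_risk, float("-inf"))
    log_denominator = torch.logsumexp(masked, dim=1)  # (n,)

    # subject i's own log hazard at its observed time
    own = theta_cols.diagonal()

    # sum over subjects with an observed event
    return ((own - log_denominator) * event).sum()


# Tiny usage example with made-up values
grid = torch.tensor([0.0, 1.5])
time = torch.tensor([1.0, 2.0])
event = torch.tensor([True, True])
log_hz = torch.tensor([[0.0, 5.0], [1.0, 7.0]])
pll = tv_partial_log_likelihood_sketch(log_hz, event, time, grid)
print(pll)
```

The sketch builds an `(n, n)` matrix, so it is meant for checking small examples rather than for training on large batches.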



In [None]:

def time_partial_log_likelihood(
    log_hz: torch.Tensor,  # (n, 1) log hazards
    event: torch.Tensor,  # (n,) boolean event indicators
    time: torch.Tensor,  # (n,) times to event or censoring
    covariates: torch.Tensor,  # (n, p) covariates, p = number of parameters
) -> torch.Tensor:

    # sort data by time-to-event or censoring
    time_sorted, idx = torch.sort(time)
    log_hz_sorted = log_hz[idx]
    event_sorted = event[idx]

    exp_log_hz = torch.exp(log_hz_sorted)
    # sort the covariates so that the rows match the ordering
    covariates_sorted = covariates[idx, :]

    # left-hand term of the equation
    # Z_K Z_K^T as an (n, n) matrix
    covariate_inner_product = torch.matmul(covariates_sorted, covariates_sorted.T)

    # numerator of the left-hand term: (1, n) @ (n, n) -> (1, n)
    log_nominator_left = torch.matmul(exp_log_hz.T, covariate_inner_product)

    # right-hand term of the equation
    # the bracket \sum exp(theta) Z_K, shape (n, p)
    bracket = torch.mul(exp_log_hz, covariates_sorted)
    nominator_right = torch.matmul(bracket, bracket.T)  # (n, n)
    # not sure whether the next line should be this
    # log_nominator_right = torch.sum(nominator_right, dim=0).unsqueeze(0)
    # or this
    log_nominator_right = nominator_right[0,].unsqueeze(0)  # (1, n)
    # the denominator is the same in both terms; logcumsumexp runs along dim 0
    # note: as written, this is the log of the risk-set sums, shape (n, 1)
    log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)
    # (1, n) / (n, 1) broadcasts to (n, n)
    partial_log_likelihood = torch.div(log_nominator_left - log_nominator_right, log_denominator)
    # keep only the rows of subjects with an observed event
    return partial_log_likelihood[event_sorted]


## Testing on the old dataset for dimensions' sake

Here we use the data from the introduction notebook just to make sure the dimensions work; this is not a correct implementation.

In [37]:
import lifelines

In [None]:
# Load GBSG2 dataset
df = lifelines.datasets.load_gbsg2()
df.head(5)

In [None]:
# Constant parameters across models
# Detect available accelerator; Downgrade batch size if only CPU available
if any([torch.cuda.is_available(), torch.backends.mps.is_available()]):
    print("CUDA-enabled GPU/TPU is available.")
    BATCH_SIZE = 128  # batch size for training
else:
    print("No CUDA-enabled GPU found, using CPU.")
    BATCH_SIZE = 32  # batch size for training

EPOCHS = 100
LEARNING_RATE = 1e-2

In [None]:
df_onehot = pd.get_dummies(df, columns=["horTh", "menostat", "tgrade"]).astype("float")
df_onehot.drop(
    ["horTh_no", "menostat_Post", "tgrade_I"],
    axis=1,
    inplace=True,
)
df_onehot.head(5)

In [None]:
df_train, df_test = train_test_split(df_onehot, test_size=0.3)
df_train, df_val = train_test_split(df_train, test_size=0.3)
print(
    f"(Sample size) Training:{len(df_train)} | Validation:{len(df_val)} | Testing:{len(df_test)}"
)

In [None]:
# Dataloader
dataloader_train = DataLoader(
    Custom_dataset(df_train), batch_size=BATCH_SIZE, shuffle=True
)
dataloader_val = DataLoader(
    Custom_dataset(df_val), batch_size=len(df_val), shuffle=False
)
dataloader_test = DataLoader(
    Custom_dataset(df_test), batch_size=len(df_test), shuffle=False
)

In [None]:
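# Number of input features, taken from one training batch
num_features = next(iter(dataloader_train))[0].size(1)

cox_model = torch.nn.Sequential(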
    torch.nn.BatchNorm1d(num_features),  # Batch normalization
    torch.nn.Linear(num_features, 32),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(64, 1),  # Estimating log hazards for Cox models
)

In [None]:
# This is for testing the loss function
x_test, (test_event, test_time) = next(iter(dataloader_train))

log_hz = cox_model(x_test)

In [None]:
print('x_test', x_test.shape)
print('events', test_event.shape)
print('times', test_time.shape)

time_sorted, idx = torch.sort(test_time)
log_hz_sorted = log_hz[idx]
event_sorted = test_event[idx]
time_unique = torch.unique(time_sorted)
print('')
print("time_sorted", time_sorted.shape)
print('log_hz_sorted', log_hz_sorted.shape)
print('event_sorted', event_sorted.shape)
print("time_unique", time_unique.shape)

print('-'*30)
cov_fake = torch.clone(x_test)
print('covariates', cov_fake.shape)
covariates_sorted = cov_fake[idx, :]
covariate_inner_product = torch.matmul(covariates_sorted, covariates_sorted.T)
print('cov_inner', covariate_inner_product.shape)
log_nominator_left = torch.matmul(log_hz_sorted.T, covariate_inner_product)
print('log_nom_left', log_nominator_left.shape)
bracket = torch.mul(log_hz_sorted, covariates_sorted)
print('bracket', bracket.shape)
log_nominator_right = torch.matmul(bracket, bracket.T)
print('log_nom_right', log_nominator_right.shape)
sum_nominator_right = log_nominator_right[0,].unsqueeze(0)
print('sum_nom', sum_nominator_right.shape)
log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0).T
print('log_denom', log_denominator.shape)
last_bit = torch.div(log_nominator_left - sum_nominator_right, log_denominator)
print('last_bit', last_bit.shape)
last_bit


tensor([1.3927, 1.5773, 0.0192, 0.1983])

## RNN Example from GitHub

In [None]:
import torch
from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex

# Parameters
input_size = 10
output_size = 1
num_layers = 2
seq_length = 5
batch_size = 8

# make random boolean events
events = torch.rand(batch_size) > 0.5
print(events)  # tensor([ True, False,  True,  True, False, False,  True, False])

# make random positive time to event
time = torch.rand(batch_size) * 100
print(time)  # tensor([32.8563, 38.3207, 24.6015, 72.2986, 19.9004, 65.2180, 73.2083, 21.2663])

# Create simple RNN model
rnn = torch.nn.RNN(input_size, output_size, num_layers)
inputs = torch.randn(seq_length, batch_size, input_size)
h0 = torch.randn(num_layers, batch_size, output_size)

# Forward pass time series input
outputs, _ = rnn(inputs, h0)
estimates = outputs[-1]  # Keep only last predictions, many to one approach
print(estimates.size())  # torch.Size([8, 1])
print(f"Estimate shape for {batch_size} samples = {estimates.size()}")  # Estimate shape for 8 samples = torch.Size([8, 1])


loss = cox.neg_partial_log_likelihood(estimates, events, time)
print(f"loss = {loss}, has gradient = {loss.requires_grad}")  # loss = 1.0389232635498047, has gradient = True

cindex = ConcordanceIndex()
print(f"c-index = {cindex(estimates, events, time)}")  # c-index = 0.20000000298023224
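
The example above keeps only the last RNN output (many to one). For time-varying covariates, one option is to keep the output at every time step so that each subject gets one value $\theta(t_k)$ per observation time. The sketch below only shows the reshaping. The name `log_hz_over_time` is hypothetical, and treating each RNN step as one grid point $t_k$ is an assumption, not a TorchSurv convention.

```python
import torch

torch.manual_seed(0)

# Same sizes as in the RNN example above
input_size, output_size, num_layers = 10, 1, 2
seq_length, batch_size = 5, 8

rnn = torch.nn.RNN(input_size, output_size, num_layers)
inputs = torch.randn(seq_length, batch_size, input_size)

# outputs has shape (seq_length, batch_size, output_size)
outputs, _ = rnn(inputs)

# drop the trailing singleton dimension and put subjects first
log_hz_over_time = outputs.squeeze(-1).T  # (batch_size, seq_length)
print(log_hz_over_time.shape)  # torch.Size([8, 5])
```

The last column of `log_hz_over_time` equals the `estimates` tensor of the many-to-one approach.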