# [Survival Analysis](https://en.wikipedia.org/wiki/Survival_analysis)

Source: https://www.pymc.io/projects/examples/en/latest/survival_analysis/survival_analysis.html

## Theory

If the random variable $T$ is the time to the event we are studying, survival analysis is primarily concerned with the survival function: 
$$
S(t) = Pr(T > t) = 1 - F(t)
$$

Where $F(t)$ is the cumulative distribution function of $T$.

The survival function is the probability that the event has not occurred by time $t$. It is often expressed in terms of the [hazard rate](https://en.wikipedia.org/wiki/Survival_analysis#Hazard_function_and_cumulative_hazard_function), $\lambda(t)$, which is the instantaneous probability that the event occurs at time $t$ given that it has not already occurred by time $t$:

$$
\lambda(t) = \lim_{\Delta t \to 0} \frac{Pr(t < T < t + \Delta t \mid T > t)}{\Delta t}
$$
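
As a quick numerical illustration of this definition, the sketch below approximates the limit with a small finite $\Delta t$, using $Pr(t < T < t + \Delta t \mid T > t) = (S(t) - S(t + \Delta t)) / S(t)$. It assumes an exponentially distributed $T$ with an arbitrary rate of 0.5, chosen only because its hazard is known to be constant. The function names are hypothetical.

```python
import math

RATE = 0.5  # assumed rate of the exponential distribution (illustrative only)


def survival(t):
    """S(t) = Pr(T > t) for an exponential distribution with rate RATE."""
    return math.exp(-RATE * t)


def approx_hazard(t, dt=1e-6):
    """Finite-difference approximation of the hazard rate at time t."""
    # Conditional probability of the event in (t, t + dt), given survival past t.
    prob = (survival(t) - survival(t + dt)) / survival(t)
    return prob / dt


for t in (0.5, 2.0, 10.0):
    print(t, approx_hazard(t))
```

Each printed value should be close to the assumed rate, whatever the value of `t`.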

Since $Pr(t < T < t + \Delta t \mid T > t) = \frac{S(t) - S(t + \Delta t)}{S(t)}$, taking the limit gives
$$
\lambda(t) = -\frac{S'(t)}{S(t)}
$$

Solving this differential equation for the survival function shows that:

$$
S(t) = \exp\left(-\int_0^t \lambda(s) \, ds\right)
$$

This representation of the survival function shows that the cumulative hazard function is:

$$
\Lambda(t) = \int_0^t \lambda(s) ds
$$
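
The relationship between the hazard, the cumulative hazard, and the survival function can be checked numerically. The sketch below assumes a Weibull-distributed $T$ with arbitrary shape and scale, integrates its hazard with the trapezoidal rule, and compares $\exp(-\Lambda(t))$ with the known closed-form Weibull survival function. All names and parameter values are illustrative.

```python
import math

SHAPE, SCALE = 1.5, 2.0  # assumed Weibull parameters (illustrative only)


def hazard(t):
    """Weibull hazard rate: (k / s) * (t / s) ** (k - 1)."""
    return (SHAPE / SCALE) * (t / SCALE) ** (SHAPE - 1)


def cumulative_hazard(t, n_steps=10_000):
    """Trapezoidal approximation of the integral of the hazard over [0, t]."""
    h = t / n_steps
    total = 0.5 * (hazard(0.0) + hazard(t))
    for i in range(1, n_steps):
        total += hazard(i * h)
    return total * h


def survival_from_hazard(t):
    """S(t) recovered from the hazard via S(t) = exp(-Lambda(t))."""
    return math.exp(-cumulative_hazard(t))


def survival_closed_form(t):
    """Known Weibull survival function, exp(-(t / s) ** k)."""
    return math.exp(-((t / SCALE) ** SHAPE))


print(survival_from_hazard(3.0), survival_closed_form(3.0))
```

The two printed values should agree up to the error of the numerical integration.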

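An important, but subtle, point in survival analysis is [censoring](https://en.wikipedia.org/wiki/Survival_analysis#Censoring). An observation is right-censored when the event has not yet been observed by the end of follow-up, so we only know that $T$ exceeds the observed time.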

## Bayesian proportional hazards model
Perhaps the most commonly used risk regression model is [Cox’s proportional hazards model](https://en.wikipedia.org/wiki/Proportional_hazards_model). In this model, if we have covariates $x$ and regression coefficients $\beta$, the hazard rate is modeled as:

$$
\lambda(t) = \lambda_0(t) \exp(x \beta)
$$

Here $\lambda_0(t)$ is the baseline hazard rate which is independent of the covariates $x$.
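
A consequence of this form is that the ratio of hazards between two covariate values does not depend on the baseline hazard, and therefore not on time. The minimal sketch below shows this for a binary covariate, using an illustrative coefficient (not a fitted value) and hypothetical baseline hazard values.

```python
import math


def hazard(baseline_hazard, x, beta):
    """Proportional hazards: lambda(t) = lambda_0(t) * exp(x * beta)."""
    return baseline_hazard * math.exp(x * beta)


beta = 0.75  # illustrative coefficient, not a fitted value
for baseline in (0.01, 0.05, 0.2):  # hypothetical baseline hazards at three times
    ratio = hazard(baseline, 1, beta) / hazard(baseline, 0, beta)
    print(f"baseline={baseline}: hazard ratio = {ratio:.3f}")
```

Every ratio equals $\exp(\beta)$, which is why $\exp(\beta)$ is read as the hazard ratio associated with the covariate.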
In order to perform Bayesian inference with the Cox model, we must specify priors on $\beta$ and $\lambda_0(t)$:
$$\beta \sim \mathrm{Normal}(\mu_\beta, \sigma^2_\beta)$$
$$\mu_\beta \sim \mathrm{Normal}(0, 10^2)$$
$$\sigma^2_\beta \sim \mathrm{Uniform}(0, 10)$$

We place a semiparametric prior on $\lambda_0(t)$ by taking it to be a piecewise-constant function, which requires partitioning the time range into intervals. We then need to choose priors for the values $\lambda_j$ that the baseline hazard takes on each of the $N-1$ intervals. The prior on each $\lambda_j$ is usually a [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution) or a [Gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution).
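
To make the piecewise-constant form concrete, here is a minimal sketch of a baseline hazard defined on a partition with $N = 4$ endpoints, hence $N - 1 = 3$ intervals. The endpoints, the hazard values, the half-open interval convention $[s_j, s_{j+1})$, and the function names are all assumptions made for illustration.

```python
from bisect import bisect_right

# Hypothetical partition: 4 endpoints give 3 intervals, one hazard value each.
ENDPOINTS = [0.0, 3.0, 6.0, 9.0]
LAMBDAS = [0.01, 0.02, 0.04]


def baseline_hazard(t):
    """Return lambda_j for the interval [s_j, s_{j+1}) that contains t."""
    if not ENDPOINTS[0] <= t < ENDPOINTS[-1]:
        raise ValueError("t is outside the partitioned time range")
    j = bisect_right(ENDPOINTS, t) - 1
    return LAMBDAS[j]


def cumulative_baseline_hazard(t):
    """Integrate the piecewise-constant hazard over [0, t]."""
    total = 0.0
    for j, lam in enumerate(LAMBDAS):
        lo, hi = ENDPOINTS[j], ENDPOINTS[j + 1]
        # Length of the overlap between [0, t] and interval j.
        total += lam * max(0.0, min(t, hi) - lo)
    return total


print(baseline_hazard(4.5), cumulative_baseline_hazard(7.0))
```

Because the hazard is constant within each interval, its integral reduces to a sum of hazard times time spent in each interval.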


The key observation is that the piecewise-constant proportional hazard model is closely related to a Poisson regression model. (The models are not identical, but their likelihoods differ by a factor that depends only on the observed data and not on the parameters $\beta$ and $\lambda_j$.)
We define indicator variables based on whether the $i$-th subject died in the $j$-th interval:

$$
d_{ij} = \begin{cases} 
1 & \text{if subject } i \text{ died in interval } j, \\
0 & \text{otherwise.}
\end{cases}
$$

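We also define $t_{ij}$ to be the amount of time the $i$-th subject was at risk in the $j$-th interval. Writing $\lambda_{ij} = \lambda_j \exp(x_i \beta)$ for the hazard of subject $i$ (with covariates $x_i$) in interval $j$, we model $d_{ij}$ as a Poisson random variable with mean $t_{ij} \lambda_{ij}$.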
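
The indicator and time-at-risk arrays can be built directly from the observed times and event flags. The sketch below is a plain-Python illustration under assumed conventions (equal-width intervals, an event at time $t$ assigned to the interval $(s_j, s_{j+1}]$ containing it). It is not the implementation of `m.to_discrete_time` used in the code section, whose internals are not shown here, and the toy data are made up.

```python
import math


def to_discrete_time(times, events, interval_length):
    """Return (bounds, death, exposure) for right-censored data.

    death[i][j]    = 1 if subject i had the event in interval j, else 0
    exposure[i][j] = time subject i was at risk during interval j
    """
    n_intervals = math.ceil(max(times) / interval_length)
    bounds = [j * interval_length for j in range(n_intervals + 1)]
    death, exposure = [], []
    for t, event in zip(times, events):
        d_row, e_row = [], []
        for j in range(n_intervals):
            lo, hi = bounds[j], bounds[j + 1]
            # Time at risk is the overlap between [0, t] and interval j.
            e_row.append(max(0.0, min(t, hi) - lo))
            # Censored subjects (event == 0) never get a death indicator.
            d_row.append(1 if (event and lo < t <= hi) else 0)
        death.append(d_row)
        exposure.append(e_row)
    return bounds, death, exposure


# Toy data: subject 1 is censored at time 2, the others have the event.
bounds, death, exposure = to_discrete_time([5, 2, 7], [1, 0, 1], interval_length=3)
print(death)
print(exposure)
```

Each row of `exposure` sums to that subject's observed time, and each row of `death` contains at most a single 1.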


## Code

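```python
# Explicit imports for the names used below, in case `main` does not re-export them.
import jax.numpy as jnp
import numpy as np
import pandas as pd
import pymc as pm  # only used to fetch the example dataset

from main import *  # project-local module providing `bi`

# Load the mastectomy dataset and encode `metastasized` as 0/1.
df = pd.read_csv(pm.get_data("mastectomy.csv"))
df["metastasized"] = (df["metastasized"] == "yes").astype(np.int64)
m = bi()

# Discretize time: per-interval death indicators and time at risk.
intervals, death, exposure = m.to_discrete_time(df)

# Data passed to the model by argument name.
m.data_on_model = {}
m.data_on_model['intervals'] = jnp.array(intervals)
m.data_on_model['death'] = jnp.array(death)
m.data_on_model['metastasized'] = jnp.array(df.metastasized.values)
```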

jax.local_device_count 32


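```python
def model(intervals, death, metastasized):
    # Vague Gamma prior on the baseline hazard of each interval.
    lambda0 = bi.dist.gamma(0.01, 0.01, shape=intervals.shape, name='lambda0')
    # Vague Normal prior on the regression coefficient.
    beta = bi.dist.normal(0, 1000, name='beta')

    # Hazard of subject i in interval j: lambda0[j] * exp(beta * metastasized[i]).
    lambda_ = numpyro.deterministic('lambda_', jnp.outer(jnp.exp(beta * metastasized), lambda0))
    # Poisson mean: time at risk times hazard (`exposure` is read from the enclosing scope).
    mu = numpyro.deterministic('mu', exposure * lambda_)
    # Add the smallest positive float so the Poisson rate is never exactly zero.
    y = lk('obs', Poisson(mu + jnp.finfo(mu.dtype).tiny), obs=death)

m.run(model, num_samples=500)
m.summary()
```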

sample: 100%|██████████| 1000/1000 [00:12<00:00, 81.80it/s, 511 steps of size 5.37e-03. acc. prob=0.96]


| parameter | mean | sd | hdi_5.5% | hdi_94.5% |
|---|---|---|---|---|
| beta | 0.75 | 0.43 | 0.08 | 1.40 |
| lambda0[0] | 0.00 | 0.00 | 0.00 | 0.00 |
| lambda0[1] | 0.00 | 0.00 | 0.00 | 0.01 |
| lambda0[2] | 0.00 | 0.01 | 0.00 | 0.01 |
| lambda0[3] | 0.00 | 0.01 | 0.00 | 0.01 |
| ... | ... | ... | ... | ... |
| mu[43, 71] | 0.00 | 0.05 | 0.00 | 0.00 |
| mu[43, 72] | 0.01 | 0.12 | 0.00 | 0.00 |
| mu[43, 73] | 0.01 | 0.05 | 0.00 | 0.00 |
| mu[43, 74] | 0.01 | 0.05 | 0.00 | 0.00 |


```python
m.plot_surv(m)
```

```text
NameError: name 'plt' is not defined
```