Cox's proportional hazards model is widely used in medical research and other fields to identify risk factors associated with an outcome of interest and to estimate how these risk factors affect the time to that outcome while controlling for other variables. Because the model is semi-parametric, leaving the baseline hazard unspecified, it also offers great flexibility.

The hazard function for the survival time $T$ associated with a $p$-vector of possibly time-varying covariates $\mathbf{Z}(\cdot)$ takes the form
$$\lambda(t\mid \mathbf{Z})=\lambda_0(t)\exp\{\boldsymbol{\beta}^{\prime}\mathbf{Z}(t)\}$$
under the proportional hazards model, where $\boldsymbol{\beta}$ is a $p$-vector of regression parameters and $\lambda_0(\cdot)$ is an unspecified baseline hazard function.
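The defining property of this model is that the ratio of the hazards of two subjects does not depend on $t$, because $\lambda_0(t)$ cancels. The sketch below checks this numerically under stated assumptions: the covariates are time-fixed, and the Weibull baseline hazard and the values of `beta`, `z_a`, `z_b` are hypothetical choices made only for illustration.

```python
import numpy as np

def weibull_baseline_hazard(t, shape=1.5, scale=1.0):
    """Hypothetical baseline hazard lambda_0(t) = (k/s) * (t/s)^(k-1)."""
    return (shape / scale) * (t / scale) ** (shape - 1)

def cox_hazard(t, z, beta):
    """Hazard lambda(t | Z) = lambda_0(t) * exp(beta' Z) for a time-fixed Z."""
    return weibull_baseline_hazard(t) * np.exp(np.dot(beta, z))

beta = np.array([0.8, -0.5])   # hypothetical regression parameters
z_a = np.array([1.0, 0.0])     # covariates of subject a
z_b = np.array([0.0, 1.0])     # covariates of subject b
times = np.linspace(0.1, 1.0, 10)

# The baseline hazard cancels, leaving exp(beta'(z_a - z_b)) at every t
ratio = cox_hazard(times, z_a, beta) / cox_hazard(times, z_b, beta)
print(ratio)  # every entry is, up to rounding, exp(1.3), about 3.669
```

Swapping in any other positive baseline hazard leaves `ratio` unchanged, which is why $\lambda_0(\cdot)$ can be left unspecified when estimating $\boldsymbol{\beta}$.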

For various reasons, survival times are not always fully observed; we consider only right censoring here and let $C$ be the censoring time. Denote the observed time by $X=\min\{T,C\}$ and the censoring indicator by $\delta=I(T\leq C)$. Define the observed-failure counting process $N(t)=I(X\leq t,\delta=1)$, which registers whether an uncensored failure has occurred by time $t$, and the corresponding at-risk indicator $Y(t)=I(X\geq t)$. Without loss of generality, let 1 be the terminal time of observation, and suppose we observe $n$ independent samples $\{(X_i,\delta_i,\mathbf{Z}_i): i=1,2,\ldots,n\}$.
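To make these definitions concrete, here is a minimal sketch that builds $X_i$, $\delta_i$, $N_i(t)$ and $Y_i(t)$ from simulated data. The exponential failure times and uniform censoring times on $[0,1]$ are hypothetical assumptions, chosen so that every observed time falls before the terminal time 1.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
T = rng.exponential(scale=0.5, size=n)  # latent failure times (hypothetical)
C = rng.uniform(0.0, 1.0, size=n)       # censoring times, all before time 1
X = np.minimum(T, C)                    # observed time X = min{T, C}
delta = (T <= C).astype(int)            # censoring indicator delta = I(T <= C)

def N(i, t):
    """Counting process N_i(t) = I(X_i <= t, delta_i = 1)."""
    return int(X[i] <= t and delta[i] == 1)

def Y(i, t):
    """At-risk indicator Y_i(t) = I(X_i >= t)."""
    return int(X[i] >= t)

print("observed times:", X.round(3))
print("indicators:", delta)
print("censoring rate:", 1 - delta.mean())
```

$N_i$ jumps from 0 to 1 at $X_i$ only when the failure is uncensored, while $Y_i$ drops to 0 right after $X_i$ whether or not the subject was censored.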

The negative log partial likelihood function of the proportional hazards model is
$$L_2(\boldsymbol{\beta})=-\sum_{i=1}^{n}\int_0^1 \left\{\boldsymbol{\beta}^{\prime}\mathbf{Z}_i(t) - \log \left(\sum_{j=1}^n Y_j(t)\exp\{\boldsymbol{\beta}^{\prime}\mathbf{Z}_j(t)\}\right)\right\} \mathrm{d} N_i(t).$$
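Because $N_i(t)$ jumps only once, at $t=X_i$ and only when $\delta_i=1$, the integral reduces to a sum over uncensored subjects: $L_2(\boldsymbol{\beta})=-\sum_{i:\delta_i=1}\{\boldsymbol{\beta}^{\prime}\mathbf{Z}_i-\log\sum_{j}Y_j(X_i)\exp(\boldsymbol{\beta}^{\prime}\mathbf{Z}_j)\}$. The sketch below evaluates this form directly, assuming time-fixed covariates; tied times are all kept in the risk set through $X_j \geq X_i$. The function name is a hypothetical choice.

```python
import numpy as np

def neg_log_partial_likelihood(beta, Z, X, delta):
    """L_2(beta) for time-fixed covariates Z of shape (n, p)."""
    eta = Z @ beta                          # linear predictors beta'Z_j
    value = 0.0
    for i in np.flatnonzero(delta):         # dN_i(t) is nonzero only for failures
        at_risk = X >= X[i]                 # Y_j(X_i) = I(X_j >= X_i)
        value -= eta[i] - np.log(np.sum(np.exp(eta[at_risk])))
    return value

# Three subjects, the last one censored; with beta = 0 every risk-set
# member has weight 1, so L_2 = log(3) + log(2) = log(6), about 1.7918
Z = np.zeros((3, 1))
X = np.array([1.0, 2.0, 3.0])
delta = np.array([1, 1, 0])
print(neg_log_partial_likelihood(np.zeros(1), Z, X, delta))
```

The censored third subject contributes no term of its own but still enters the risk sets of the two earlier failures.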
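We can estimate $\boldsymbol{\beta}$ by minimizing the negative log partial likelihood function under a sparsity constraint:
$$\arg\min_{\boldsymbol{\beta} \in \mathbb{R}^p} L_2(\boldsymbol{\beta}) \quad \text{s.t.} \quad \|\boldsymbol{\beta}\|_0 \leq s,$$
where $\|\boldsymbol{\beta}\|_0$ counts the nonzero entries of $\boldsymbol{\beta}$ and $s$ is the desired sparsity level.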
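For small $p$ the constrained problem can be solved exactly by enumerating every support of size $s$, fitting an unrestricted Cox model on each, and keeping the support with the smallest $L_2$. The sketch below does this with Newton's method and step halving. It is a plain exhaustive baseline that is only feasible for a handful of covariates and is not a description of how ScopeSolver works; the simulated data, `beta_true`, and all function names are hypothetical.

```python
import itertools
import numpy as np

def nlpl(beta, Z, X, delta):
    """L_2(beta) for time-fixed covariates, using a stable log-sum-exp."""
    eta = Z @ beta
    total = 0.0
    for i in np.flatnonzero(delta):
        r = eta[X >= X[i]]                  # linear predictors in the risk set
        total -= eta[i] - (r.max() + np.log(np.exp(r - r.max()).sum()))
    return total

def grad_hess(beta, Z, X, delta):
    """Gradient and Hessian of L_2; the shift in w cancels in every ratio."""
    eta = Z @ beta
    w = np.exp(eta - eta.max())
    p = Z.shape[1]
    grad, hess = np.zeros(p), np.zeros((p, p))
    for i in np.flatnonzero(delta):
        r = X >= X[i]
        s0 = w[r].sum()
        zbar = (w[r] @ Z[r]) / s0                       # weighted risk-set mean
        s2 = (Z[r] * w[r][:, None]).T @ Z[r] / s0
        grad -= Z[i] - zbar
        hess += s2 - np.outer(zbar, zbar)
    return grad, hess

def fit_cox(Z, X, delta, iters=50, tol=1e-8):
    """Newton-Raphson with step halving, started at beta = 0."""
    beta = np.zeros(Z.shape[1])
    for _ in range(iters):
        g, H = grad_hess(beta, Z, X, delta)
        step = np.linalg.solve(H, g)
        f0, t = nlpl(beta, Z, X, delta), 1.0
        while nlpl(beta - t * step, Z, X, delta) > f0 and t > 1e-8:
            t /= 2
        beta = beta - t * step
        if np.max(np.abs(t * step)) < tol:
            break
    return beta

def best_subset_cox(Z, X, delta, s):
    """Exhaustive search over all supports with exactly s covariates."""
    best_val, best_support, best_coef = np.inf, None, None
    for support in itertools.combinations(range(Z.shape[1]), s):
        cols = list(support)
        coef = fit_cox(Z[:, cols], X, delta)
        val = nlpl(coef, Z[:, cols], X, delta)
        if val < best_val:
            best_val, best_support, best_coef = val, support, coef
    beta = np.zeros(Z.shape[1])
    beta[list(best_support)] = best_coef
    return beta, best_support, best_val

# Hypothetical data: baseline hazard 1, so T ~ Exp(rate = exp(beta'Z))
rng = np.random.default_rng(1)
n, p, s = 300, 5, 2
Z = rng.standard_normal((n, p))
beta_true = np.array([1.5, 0.0, 0.0, -1.5, 0.0])
T = rng.exponential(scale=1.0, size=n) * np.exp(-Z @ beta_true)
C = rng.exponential(scale=2.0, size=n)
X, delta = np.minimum(T, C), (T <= C).astype(int)

beta_hat, support, value = best_subset_cox(Z, X, delta, s)
print("support:", support, "estimate:", beta_hat.round(2))
```

The cost grows with the binomial coefficient $\binom{p}{s}$, which is why a dedicated sparse solver is needed once $p$ is more than a few dozen.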

Here is Python code that solves the sparse proportional hazards model with ScopeSolver:

from abess.datasets import make_glm_data
import jax.numpy as jnp
from scope import ScopeSolver

# n samples, p covariates, k of which have nonzero coefficients
n, p, k = 200, 10, 2
data = make_glm_data(n, p, k, family="cox", c=3)
# data.y[:, 0] holds the observed times X_i, data.y[:, 1] the indicators delta_i
time, status = data.y[:, 0], data.y[:, 1]
# risk_set[i, j] = Y_j(X_i) = I(X_j >= X_i), computed once outside the objective
risk_set = jnp.asarray((time[None, :] >= time[:, None]).astype(float))

def phazard_objective(params):
    # Negative log partial likelihood divided by n (the scaling does not change the minimizer)
    Xbeta = jnp.matmul(data.x, params)
    # log of sum_j Y_j(X_i) exp(beta'Z_j), for every i at once
    logsum = jnp.log(jnp.matmul(risk_set, jnp.exp(Xbeta)))
    return jnp.dot(status, logsum) / n - jnp.dot(status, Xbeta) / n

solver = ScopeSolver(p, k)
solver.solve(phazard_objective, jit=True)

print("Estimated parameter:", solver.get_result()["params"], "objective:", solver.get_result()["value_of_objective"])
print("True parameter:", data.coef_, "objective:", phazard_objective(data.coef_))

censoring rate:0.355
Estimated parameter: [0.         0.         0.         0.         5.25791676 0.
 0.         0.         5.45310829 0.        ] objective: 1.5464740991592407
True parameter: [0.         0.         0.         0.         4.63711743 0.
 0.         0.         4.59908961 0.        ] objective: 1.5630498
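The risk-set term $\sum_j Y_j(X_i)\exp(\boldsymbol{\beta}^{\prime}\mathbf{Z}_j)$ can be evaluated for all $i$ either with a per-observation loop or with a single $n\times n$ indicator matrix. The sketch below checks on hypothetical random inputs that the two agree; it uses NumPy, but the same broadcasting expressions work with `jax.numpy`.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
time = rng.exponential(size=n)   # hypothetical observed times X_j
eta = rng.standard_normal(n)     # hypothetical linear predictors beta'Z_j

# Loop form: for each i, sum exp(eta_j) over the risk set {j : X_j >= X_i}
loop_sums = np.array([np.exp(eta)[time >= time[i]].sum() for i in range(n)])

# Matrix form: R[i, j] = I(X_j >= X_i), so R @ exp(eta) yields all n sums
R = (time[None, :] >= time[:, None]).astype(float)
matrix_sums = R @ np.exp(eta)

print(np.allclose(loop_sums, matrix_sums))  # True
```

The matrix form trades $O(n^2)$ memory for a single vectorized product, which is convenient under `jit` for moderate $n$ such as the 200 samples used here.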
