In [1]:
from typing import Any, List, Optional

import numpy as np
import pdb
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sksurv.datasets import load_whas500
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.preprocessing import OneHotEncoder

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from typing import Optional

import torch
from torch import nn

# Fit Cox PH Model with sksurv

In [2]:
X, y = load_whas500()
# Keep only the samples whose feature row has no missing value; the same
# boolean selection is applied to the outcome array so both stay aligned.
mask = X.notnull().all(axis=1)
X = X.loc[mask]
y = y[mask.values]

# Encode the features, then standardize every resulting column.
Xe = OneHotEncoder().fit_transform(X)
Xt = StandardScaler().fit_transform(Xe)

# Reference fit: these coefficients are what the PyTorch model below is
# compared against.
coxph = CoxPHSurvivalAnalysis()
coxph.fit(Xt, y)
coxph_coef = pd.Series(coxph.coef_, index=Xe.columns, name="sksurv")
coxph_coef

afb=1       0.012764
age         0.684880
av3=1       0.045116
bmi        -0.254596
chf=1       0.338557
cvd=1      -0.024386
diasbp     -0.272298
gender=1   -0.142440
hr          0.284784
los        -0.032106
miord=1     0.039950
mitype=1   -0.110669
sho=1       0.251575
sysbp       0.049986
Name: sksurv, dtype: float64

## Fit a Cox Model with PyTorch

In [3]:
def make_riskset(time: np.ndarray) -> np.ndarray:
    """Compute the mask that represents each sample's risk set.

    Parameters
    ----------
    time : np.ndarray, shape=(n_samples,)
        Observed times. The values may be given in any order and may
        contain ties; no pre-sorting is required. NaN values are not
        supported.

    Returns
    -------
    risk_set : np.ndarray, shape=(n_samples, n_samples), dtype=uint8
        Matrix where the `i`-th row denotes the risk set of the `i`-th
        instance: entry `(i, j)` is 1 if the observed time
        `y_j >= y_i` and 0 otherwise.
    """
    assert time.ndim == 1, "expected 1D array"

    # Sample j is in the risk set of sample i exactly when
    # time[j] >= time[i]. Comparing a row view against a column view
    # evaluates every (i, j) pair at once, so no explicit sorting or
    # tie handling is needed.
    at_risk = time[np.newaxis, :] >= time[:, np.newaxis]
    return at_risk.astype(np.uint8)


def cox_collate_fn(
    batch: List[Any], time_index: Optional[int] = -1, data_collate=default_collate
) -> List[torch.Tensor]:
    """Collate a batch and append the risk-set matrix derived from it.

    Parameters
    ----------
    batch : list
        Samples of the batch; every sample is a sequence of fields.
    time_index : int
        Position, within a sample, of the field holding the observed time.
    data_collate : callable
        Function used to collate each field across the batch.

    Returns
    -------
    list
        One collated entry per field, in the original field order,
        followed by the risk-set matrix as the last element.
    """
    # Regroup from "one tuple per sample" to "one tuple per field".
    fields = list(zip(*batch))
    observed_time = np.array(fields[time_index])

    collated = [data_collate(field) for field in fields]
    collated.append(torch.from_numpy(make_riskset(observed_time)))
    return collated


def safe_normalize(x: torch.Tensor) -> torch.Tensor:
    """Normalize risk scores to avoid exp underflowing.

    Only risk scores relative to each other matter, so adding the same
    constant to all of them is harmless. If the minimum risk score
    along dimension 0 is negative, the scores are shifted so that this
    minimum is at zero; otherwise they are returned unchanged in value.

    Args:
        x (torch.Tensor):
            Risk scores; the minimum is taken along dimension 0.

    Returns:
        torch.Tensor:
            Shifted scores with the same shape, dtype and device as `x`.
    """
    x_min, _ = torch.min(x, dim=0)
    # zeros_like inherits dtype and device from the scores, so both
    # branches of `where` agree even for non-default dtypes (float64,
    # float16) instead of mixing them with a default-dtype zero tensor.
    norm = torch.where(x_min < 0, -x_min, torch.zeros_like(x_min))
    return x + norm


def logsumexp_masked(
    risk_scores: torch.Tensor, mask: torch.Tensor, dim: int = 0, keepdim: Optional[bool] = None
) -> torch.Tensor:
    """Log of the sum of exponentials along `dim`, restricted by `mask`.

    Args:
        risk_scores (torch.Tensor):
            Scores to reduce.
        mask (torch.Tensor):
            Tensor of the same rank as `risk_scores`; it is cast to the
            dtype of `risk_scores` and multiplied in, so entries equal
            to zero drop out of the sum.
        dim (int):
            Dimension to reduce over.
        keepdim (bool, optional):
            If truthy, the reduced dimension is kept with size 1;
            otherwise (including the default `None`) it is squeezed out.

    Returns:
        torch.Tensor:
            The masked log-sum-exp along `dim`.
    """
    assert risk_scores.dim() == mask.dim(), "risk_scores and mask must have same rank"

    weights = mask.type_as(risk_scores)
    # Masked-out entries are replaced by zero rather than removed, so the
    # shift below is the maximum over the selected scores *and* those zeros.
    zeroed_scores = risk_scores * weights
    # For numerical stability, subtract the maximum value before taking
    # the exponential; it is added back after the logarithm.
    shift = torch.max(zeroed_scores, dim=dim, keepdim=True).values
    contributions = torch.exp(zeroed_scores - shift) * weights
    total = contributions.sum(dim, keepdim=True)
    result = torch.log(total) + shift

    if keepdim:
        return result
    return result.squeeze(dim=dim)


class CoxphLoss(nn.Module):
    """Loss module for Cox's proportional hazards model."""

    @staticmethod
    def _check_inputs(predictions: torch.Tensor, event: torch.Tensor, riskset: torch.Tensor) -> None:
        """Raise ValueError unless the inputs have the expected ranks and shapes."""
        if predictions is None or predictions.dim() != 2:
            raise ValueError("predictions must be a 2D tensor.")
        n_outputs = predictions.shape[1]
        if n_outputs != 1:
            raise ValueError("last dimension of predictions ({}) must be 1.".format(n_outputs))

        if event is None:
            raise ValueError("event must not be None.")
        pred_rank, event_rank = predictions.dim(), event.dim()
        if pred_rank != event_rank:
            raise ValueError(
                "Rank of predictions ({}) must equal rank of event ({})".format(pred_rank, event_rank)
            )
        n_event_cols = event.shape[1]
        if n_event_cols != 1:
            raise ValueError("last dimension event ({}) must be 1.".format(n_event_cols))

        if riskset is None:
            raise ValueError("riskset must not be None.")

    def forward(self, predictions: torch.Tensor, event: torch.Tensor, riskset: torch.Tensor) -> torch.Tensor:
        """Negative partial log-likelihood of Cox's proportional
        hazards model.

        Args:
            predictions (torch.Tensor):
                The predicted outputs. Must be a rank 2 tensor whose
                last dimension is 1.
            event (torch.Tensor):
                Binary tensor of the same rank as `predictions`, with
                last dimension 1, where 1 indicates an event and
                0 censoring.
            riskset (torch.Tensor):
                Boolean matrix where the `i`-th row denotes the
                risk set of the `i`-th instance, i.e. the indices `j`
                for which the observed time `y_j >= y_i`.

        Returns:
            loss (torch.Tensor):
                Scalar loss.

        Raises:
            ValueError: If an argument is None or has an unexpected
                rank or last dimension.

        References:
            .. [1] Faraggi, D., & Simon, R. (1995).
            A neural network model for survival data. Statistics in Medicine,
            14(1), 73–82. https://doi.org/10.1002/sim.4780140108
        """
        self._check_inputs(predictions, event, riskset)

        # Bring the indicator tensors to the dtype of the predictions so
        # they can be used in floating-point arithmetic.
        event = event.type_as(predictions)
        riskset = riskset.type_as(predictions)
        scores = safe_normalize(predictions)

        # Transposing moves the batch dimension to the end, so the scores
        # are broadcast row-wise against the risk-set matrix; the result
        # is the log of the sum over each sample's risk set.
        log_risk_sum = logsumexp_masked(scores.t(), riskset, dim=1, keepdim=True)
        assert log_risk_sum.size() == scores.size()

        # Samples with event == 0 contribute zero to the numerator, but
        # the mean is still taken over all samples.
        per_sample = event * (log_risk_sum - scores)
        return torch.mean(per_sample)


Define the dataset and our model.

In [4]:
class WhasDataset(Dataset):
    """In-memory dataset built from `load_whas500()`.

    Every item is a `(features, event, time)` tuple: a float32 feature
    vector, a length-1 uint8 array taken from the `fstat` field, and a
    float32 scalar taken from the `lenfol` field.
    """

    def __init__(self):
        features, outcome = load_whas500()
        # Drop samples with any missing feature value, keeping the
        # outcome array aligned with the remaining rows.
        complete = features.notnull().all(axis=1)
        features = features.loc[complete]
        outcome = outcome[complete.values]

        encoded = OneHotEncoder().fit_transform(features)
        scaled = StandardScaler().fit_transform(encoded).astype(np.float32)
        # The extra axis gives each sample's event value shape (1,).
        event = outcome["fstat"][:, np.newaxis].astype(np.uint8)
        time = outcome["lenfol"].astype(np.float32)

        # Number of model inputs, i.e. columns after encoding.
        self.n_features = scaled.shape[1]
        # One (features, event, time) tuple per sample.
        self.data = list(zip(scaled, event, time))

    def __getitem__(self, index: int):
        """Return the `(features, event, time)` tuple at `index`."""
        sample = self.data[index]
        return sample

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.data)


class CoxModel(nn.Module):
    """Linear model without bias that maps features to one risk score."""

    def __init__(self, n_inputs):
        """Create the single linear layer taking `n_inputs` features."""
        super().__init__()

        # No bias term: bias=False leaves the weight as the only parameter.
        self.layer = nn.Linear(in_features=n_inputs, out_features=1, bias=False)

    def forward(self, x):
        """Apply the linear layer to `x` and return its output."""
        risk_score = self.layer(x)
        return risk_score

Run the training loop; this may take a while.

In [7]:
torch.manual_seed(25)
# Use the GPU when one is available, otherwise fall back to the CPU so
# the notebook also runs on machines without CUDA.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = WhasDataset()
# A single batch holds the whole dataset, so the risk sets built by
# cox_collate_fn span all samples.
train_loader = DataLoader(
    train_dataset, collate_fn=cox_collate_fn, batch_size=len(train_dataset)
)
model = CoxModel(train_dataset.n_features).to(dev)
opt = Adam(model.parameters(), lr=5e-4)
loss_fn = CoxphLoss()

model.train()
for i in range(10000):
    # y_time is yielded by the loader but not needed here: the loss only
    # uses the risk set that was derived from it.
    for x, y_event, y_time, y_riskset in train_loader:
        x = x.to(dev)
        y_event = y_event.to(dev)
        y_riskset = y_riskset.to(dev)

        opt.zero_grad()
        # Call the module itself (not .forward) so nn.Module hooks run.
        logits = model(x)

        loss = loss_fn(logits, y_event, y_riskset)

        loss.backward()
        opt.step()

    # Report the loss of the last batch every 1000 epochs.
    if i % 1000 == 0:
        print(i, loss.detach().cpu().numpy())

0 2.715545
1000 2.2721267
2000 2.236773
3000 2.2314878
4000 2.2305958
5000 2.2305293
6000 2.2305286
7000 2.2305286
8000 2.2305286
9000 2.2305286


Compare Coefficients

In [6]:
# The model's first (and only) parameter is the weight of its linear
# layer; squeeze() drops the size-1 output dimension so that one value
# per feature column remains.
torch_weights = next(model.parameters()).detach().cpu().numpy().squeeze()
torch_coef = pd.Series(torch_weights, index=Xe.columns, name="pytorch")

# Side-by-side table: sksurv reference fit vs. PyTorch fit.
pd.concat([coxph_coef, torch_coef], axis=1)

Unnamed: 0,sksurv,pytorch
afb=1,0.012764,0.012764
age,0.68488,0.684879
av3=1,0.045116,0.045116
bmi,-0.254596,-0.254596
chf=1,0.338557,0.338557
cvd=1,-0.024386,-0.024386
diasbp,-0.272298,-0.272298
gender=1,-0.14244,-0.14244
hr,0.284784,0.284784
los,-0.032106,-0.032106
