Skip to content

Commit

Permalink
updates to AJF model, including fixing #688 (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Apr 8, 2019
1 parent b92122b commit cca9bc8
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 32 deletions.
75 changes: 44 additions & 31 deletions lifelines/fitters/aalen_johansen_fitter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from textwrap import dedent
import numpy as np
import pandas as pd
import warnings
Expand Down Expand Up @@ -103,8 +104,10 @@ def fit(

if ties:
warnings.warn(
"""Tied event times were detected. The Aalen-Johansen estimator cannot handle tied event times.
To resolve ties, data is randomly jittered.""",
dedent(
"""Tied event times were detected. The Aalen-Johansen estimator cannot handle tied event times.
To resolve ties, data is randomly jittered."""
),
Warning,
)
durations = self._jitter(
Expand All @@ -114,28 +117,32 @@ def fit(
seed=self._seed,
)

alpha = alpha if alpha else self.alpha

# Creating label for event of interest & indicator for that event
cmprisk_label = "CIF_" + str(int(event_of_interest))
self.label_cmprisk = "observed_" + str(int(event_of_interest))
event_of_interest = int(event_of_interest)
cmprisk_label = "CIF_" + str(event_of_interest)
self.label_cmprisk = "observed_" + str(event_of_interest)

# Fitting Kaplan-Meier for either event of interest OR competing risk
km = KaplanMeierFitter()
km.fit(durations, event_observed=event_observed, timeline=timeline, entry=entry, weights=weights)
km = KaplanMeierFitter().fit(
durations, event_observed=event_observed, timeline=timeline, entry=entry, weights=weights
)
aj = km.event_table
aj["overall_survival"] = km.survival_function_
aj["lagged_overall_survival"] = aj["overall_survival"].shift()

# Setting up table for calculations and to return to user
event_spec = np.where(pd.Series(event_observed) == event_of_interest, 1, 0)
event_spec_proc = _preprocess_inputs(
event_spec = pd.Series(event_observed) == event_of_interest
self.durations, self.event_observed, *_, event_table = _preprocess_inputs(
durations=durations, event_observed=event_spec, timeline=timeline, entry=entry, weights=weights
)
event_spec_times = event_spec_proc[-1]["observed"]
event_spec_times = event_table["observed"]
event_spec_times = event_spec_times.rename(self.label_cmprisk)
aj = pd.concat([aj, event_spec_times], axis=1).reset_index()

# Estimator of Cumulative Incidence (Density) Function
aj[cmprisk_label] = ((aj[self.label_cmprisk]) / (aj["at_risk"]) * aj["lagged_overall_survival"]).cumsum()
aj[cmprisk_label] = (aj[self.label_cmprisk] / aj["at_risk"] * aj["lagged_overall_survival"]).cumsum()
aj.loc[0, cmprisk_label] = 0 # Setting initial CIF to be zero
aj = aj.set_index("event_at")

Expand All @@ -145,52 +152,61 @@ def fit(
self._predict_label = label
self._update_docstrings()

alpha = alpha if alpha else self.alpha
self._label = label
self.cumulative_density_ = pd.DataFrame(aj[cmprisk_label])

# Technically, cumulative incidence, but consistent with KaplanMeierFitter
self.event_table = aj[
["removed", "observed", self.label_cmprisk, "censored", "entrance", "at_risk"]
] # Event table

if self._calc_var:
self.variance, self.confidence_interval_ = self._bounds(
self.variance_, self.confidence_interval_ = self._bounds(
aj["lagged_overall_survival"], alpha=alpha, ci_labels=ci_labels
)
else:
self.variance_, self.confidence_interval_ = None, None

return self

def _jitter(self, durations, event, jitter_level, seed=None):
"""Determine extent to jitter tied event times. Automatically called by fit if tied event times are detected
"""
np.random.seed(seed)

if jitter_level <= 0:
raise ValueError("The jitter level is less than zero, please select a jitter value greater than 0")
if seed is not None:
np.random.seed(seed)

event_time = durations.loc[event != 0].copy()
# Determining whether to randomly shift event times up or down
mark = np.random.choice([-1, 1], size=event_time.shape[0])
event_times = durations[event != 0].copy()
n = event_times.shape[0]

# Determining extent to jitter event times up or down
shift = np.random.uniform(size=event_time.shape[0]) * jitter_level
# Jittering times
event_time += mark * shift
durations_jitter = event_time.align(durations)[0].fillna(durations)
shift = np.random.uniform(low=-1, high=1, size=n) * jitter_level
event_times += shift
durations_jitter = event_times.align(durations)[0].fillna(durations)

# Recursive call if event times are still tied after jitter
if self._check_for_duplicates(durations=durations_jitter, events=event):
return self._jitter(durations=durations_jitter, event=event, jitter_level=jitter_level, seed=seed)
return durations_jitter

def _bounds(self, lagged_survival, alpha, ci_labels):
"""Bounds are based on pg411 of "Modelling Survival Data in Medical Research" David Collett 3rd Edition, which
"""Bounds are based on pg 411 of "Modelling Survival Data in Medical Research" David Collett 3rd Edition, which
is derived from Greenwood's variance estimator. Confidence intervals are obtained using the delta method
transformation of SE(log(-log(F_j))). This ensures that the confidence intervals all lie between 0 and 1.
Formula for the variance follows:
Var(F_j) = sum((F_j(t) - F_j(t_i))**2 * d/(n*(n-d) + S(t_i-1)**2 * ((d*(n-d))/n**3) +
-2 * sum((F_j(t) - F_j(t_i)) * S(t_i-1) * (d/n**2)
.. math::
Var(F_j) = sum((F_j(t) - F_j(t_i))**2 * d/(n*(n-d) + S(t_i-1)**2 * ((d*(n-d))/n**3) +
-2 * sum((F_j(t) - F_j(t_i)) * S(t_i-1) * (d/n**2)
Delta method transformation:
SE(log(-log(F_j) = SE(F_j) / (F_j * absolute(log(F_j)))
.. math::
SE(log(-log(F_j) = SE(F_j) / (F_j * |log(F_j)|)
More information can be found at: https://support.sas.com/documentation/onlinedoc/stat/141/lifetest.pdf
There is also an alternative method (Aalen) but this is not currently implemented
Expand Down Expand Up @@ -239,16 +255,13 @@ def _check_for_duplicates(durations, events):
where the events are of different types
"""
# Setting up DataFrame to detect duplicates
df = pd.DataFrame()
df["t"] = durations
df["e"] = events
df = pd.DataFrame({"t": durations, "e": events})

# Finding duplicated event times
dup_times = pd.Series(df["t"].loc[df["e"] != 0]).duplicated(keep=False)
dup_times = df.loc[df["e"] != 0, "t"].duplicated(keep=False)

# Finding duplicated events and event times
dup_events = df.loc[df["e"] != 0, ["t", "e"]].duplicated(keep=False)

# Detect duplicated times with different event types
ties = np.any(dup_times & (~dup_events))
return ties > 0
return (dup_times & (~dup_events)).any()
2 changes: 1 addition & 1 deletion tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4022,7 +4022,7 @@ def test_variance_calculation_against_sas(self, fitter, duration, event_observed
variance_from_sas = np.array([0.0, 0.0, 0.0, 0.0, 0.032, 0.048, 0.048])

fitter.fit(duration, event_observed, event_of_interest=2)
npt.assert_allclose(variance_from_sas, np.array(fitter.variance))
npt.assert_allclose(variance_from_sas, np.array(fitter.variance_))

def test_ci_calculation_against_sas(self, fitter, duration, event_observed):
ci_from_sas = np.array(
Expand Down
28 changes: 28 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
LogLogisticFitter,
WeibullAFTFitter,
ExponentialFitter,
AalenJohansenFitter,
)

from tests.test_estimation import known_parametric_univariate_fitters
Expand Down Expand Up @@ -118,6 +119,33 @@ def test_naf_plotting_with_custom_colours(self, block):
self.plt.show(block=block)
return

def test_ajf_plotting(self, block):
E = [0, 1, 1, 2, 2, 0]
T = [1, 2, 3, 4, 5, 6]
ajf = AalenJohansenFitter().fit(T, E, event_of_interest=1)
ajf.plot()
self.plt.title("test_ajf_plotting")
self.plt.show(block=block)
return

def test_ajf_plotting_no_confidence_intervals(self, block):
E = [0, 1, 1, 2, 2, 0]
T = [1, 2, 3, 4, 5, 6]
ajf = AalenJohansenFitter(calculate_variance=False).fit(T, E, event_of_interest=1)
ajf.plot(ci_show=False)
self.plt.title("test_ajf_plotting_no_confidence_intervals")
self.plt.show(block=block)
return

def test_ajf_plotting_with_add_count_at_risk(self, block):
E = [0, 1, 1, 2, 2, 0]
T = [1, 2, 3, 4, 5, 6]
ajf = AalenJohansenFitter().fit(T, E, event_of_interest=1)
ajf.plot(at_risk_counts=True)
self.plt.title("test_ajf_plotting_with_add_count_at_risk")
self.plt.show(block=block)
return

def test_aalen_additive_plot(self, block):
# this is a visual test of the fitting the cumulative
# hazards.
Expand Down

0 comments on commit cca9bc8

Please sign in to comment.