-
-
Notifications
You must be signed in to change notification settings - Fork 550
/
aalen_johansen_fitter.py
271 lines (225 loc) · 12.2 KB
/
aalen_johansen_fitter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
# -*- coding: utf-8 -*-
from textwrap import dedent
import numpy as np
import pandas as pd
import warnings
from lifelines.fitters import UnivariateFitter
from lifelines.utils import _preprocess_inputs, inv_normal_cdf, CensoringType
from lifelines import KaplanMeierFitter
class AalenJohansenFitter(UnivariateFitter):
    """Class for fitting the Aalen-Johansen estimate for the cumulative incidence function in a
    competing risks framework.

    Treating competing risks as censoring can result in over-estimated cumulative density
    functions. Using the Kaplan Meier estimator with competing risks as censored is akin to
    estimating the cumulative density if all competing risks had been prevented.

    Aalen-Johansen cannot deal with tied times. We can get around this by randomly jittering the
    event times slightly. This will be done automatically and generates a warning.

    Parameters
    ----------
    alpha: float, option (default=0.05)
        The alpha value associated with the confidence intervals.
    jitter_level: float, option (default=0.0001)
        If tied event times are detected, event times are randomly changed by this factor.
    seed: int, option (default=None)
        To produce replicate results with tied event times, a seed for the jittering process
        can be specified.
    calculate_variance: bool, option (default=True)
        By default, AalenJohansenFitter calculates the variance and corresponding confidence
        intervals. Due to how the variance is calculated, the variance must be calculated for
        each event time individually. This is computationally intensive. For some procedures,
        like bootstrapping, the variance is not necessary. To reduce computation time during
        these procedures, `calculate_variance` can be set to `False` to skip the variance
        calculation.

    Example
    -------
    >>> from lifelines import AalenJohansenFitter
    >>> from lifelines.datasets import load_waltons
    >>> T, E = load_waltons()['T'], load_waltons()['E']
    >>> ajf = AalenJohansenFitter(calculate_variance=True)
    >>> ajf.fit(T, E, event_of_interest=1)
    >>> ajf.cumulative_density_
    >>> ajf.plot()

    References
    ----------
    If you are interested in learning more, we recommend the following open-access
    paper; Edwards JK, Hester LL, Gokhale M, Lesko CR. Methodologic Issues When Estimating Risks in
    Pharmacoepidemiology. Curr Epidemiol Rep. 2016;3(4):285-296.
    """

    def __init__(self, jitter_level=0.0001, seed=None, alpha=0.05, calculate_variance=True):
        UnivariateFitter.__init__(self, alpha=alpha)
        self._jitter_level = jitter_level
        # Seed for the jittering process, so tied-time results are reproducible.
        self._seed = seed
        # Optionally skip calculating the variance to save time (e.g. inside bootstraps).
        self._calc_var = calculate_variance

    @CensoringType.right_censoring
    def fit(
        self,
        durations,
        event_observed,
        event_of_interest,
        timeline=None,
        entry=None,
        label="AJ_estimate",
        alpha=None,
        ci_labels=None,
        weights=None,
    ):  # pylint: disable=too-many-arguments,too-many-locals
        """Fit the Aalen-Johansen estimate of the cumulative incidence function.

        Parameters
        ----------
        durations: an array or pd.Series of length n -- duration of subject was observed for
        event_observed: an array, or pd.Series, of length n. Integer indicator of distinct events. Must be
            only positive integers, where 0 indicates censoring.
        event_of_interest: integer -- indicator for event of interest. All other integers are considered competing events
            Ex) event_observed contains 0, 1, 2 where 0:censored, 1:lung cancer, and 2:death. If event_of_interest=1, then death (2)
            is considered a competing event. The returned cumulative incidence function corresponds to risk of lung cancer
        timeline: return the best estimate at the values in timelines (positively increasing)
        entry: an array, or pd.Series, of length n -- relative time when a subject entered the study. This is
            useful for left-truncated (not left-censored) observations. If None, all members of the population
            were born at time 0.
        label: a string to name the column of the estimate.
        alpha: the alpha value in the confidence intervals. Overrides the initializing
            alpha for this call to fit only.
        ci_labels: add custom column names to the generated confidence intervals
            as a length-2 list: [<upper-bound name>, <lower-bound name>]. Default: <label>_upper_<1-alpha>
        weights: an array, or pd.Series, of length n, if providing a weighted dataset. For example, instead
            of providing every subject as a single element of `durations` and `event_observed`, one could
            weigh subject differently.

        Returns
        -------
        self : AalenJohansenFitter
            self, with new properties like ``cumulative_density_``.
        """
        # The Aalen-Johansen estimator cannot handle tied event times, so jitter them if present.
        ties = self._check_for_duplicates(durations=durations, events=event_observed)
        if ties:
            warnings.warn(
                dedent(
                    """Tied event times were detected. The Aalen-Johansen estimator cannot handle tied event times.
                To resolve ties, data is randomly jittered."""
                ),
                Warning,
            )
            durations = self._jitter(
                durations=pd.Series(durations),
                event=pd.Series(event_observed),
                jitter_level=self._jitter_level,
                seed=self._seed,
            )

        # Only fall back to the fitter-level alpha when none was given; `alpha if alpha else ...`
        # would also (incorrectly) discard an explicit alpha of 0.
        alpha = self.alpha if alpha is None else alpha

        # Creating label for event of interest & indicator for that event
        event_of_interest = int(event_of_interest)
        cmprisk_label = "CIF_" + str(event_of_interest)
        self.label_cmprisk = "observed_" + str(event_of_interest)

        # Fitting Kaplan-Meier for either event of interest OR competing risk: the overall
        # survival (any event) supplies the lagged S(t-) term of the Aalen-Johansen estimator.
        km = KaplanMeierFitter().fit(
            durations, event_observed=event_observed, timeline=timeline, entry=entry, weights=weights
        )
        aj = km.event_table
        aj["overall_survival"] = km.survival_function_
        aj["lagged_overall_survival"] = aj["overall_survival"].shift()

        # Setting up table for calculations and to return to user
        event_spec = pd.Series(event_observed) == event_of_interest
        self.durations, self.event_observed, *_, event_table = _preprocess_inputs(
            durations=durations, event_observed=event_spec, timeline=timeline, entry=entry, weights=weights
        )
        event_spec_times = event_table["observed"]
        event_spec_times = event_spec_times.rename(self.label_cmprisk)
        aj = pd.concat([aj, event_spec_times], axis=1).reset_index()

        # Estimator of the Cumulative Incidence (Density) Function:
        # sum over event times of d_j(t_i)/n(t_i) * S(t_i-).
        aj[cmprisk_label] = (aj[self.label_cmprisk] / aj["at_risk"] * aj["lagged_overall_survival"]).cumsum()
        aj.loc[0, cmprisk_label] = 0  # Setting initial CIF to be zero
        aj = aj.set_index("event_at")

        # Setting attributes
        self._estimation_method = "cumulative_density_"
        self._estimate_name = "cumulative_density_"
        self.timeline = km.timeline
        self._update_docstrings()
        self._label = label
        self.cumulative_density_ = pd.DataFrame(aj[cmprisk_label])
        # Technically, cumulative incidence, but consistent with KaplanMeierFitter
        self.event_table = aj[
            ["removed", "observed", self.label_cmprisk, "censored", "entrance", "at_risk"]
        ]  # Event table
        if self._calc_var:
            self.variance_, self.confidence_interval_ = self._bounds(
                aj["lagged_overall_survival"], alpha=alpha, ci_labels=ci_labels
            )
        else:
            self.variance_, self.confidence_interval_ = None, None
        self.confidence_interval_cumulative_density_ = self.confidence_interval_
        return self

    def _jitter(self, durations, event, jitter_level, seed=None):
        """Determine extent to jitter tied event times. Automatically called by fit if tied event
        times are detected.

        Only observed event times (event != 0) are jittered; censored times are kept as-is.
        Recurses until no ties between distinct event types remain.
        """
        if jitter_level <= 0:
            raise ValueError("The jitter level is less than zero, please select a jitter value greater than 0")

        # Use a local RandomState rather than np.random.seed(seed) so the caller's global RNG
        # state is not mutated as a side effect. For an integer seed this draws the identical
        # stream the global-seed implementation produced.
        rng = np.random.RandomState(seed)

        event_times = durations[event != 0].copy()
        n = event_times.shape[0]

        # Determining extent to jitter event times up or down
        shift = rng.uniform(low=-1, high=1, size=n) * jitter_level
        event_times += shift

        # Re-merge jittered event times with the (unchanged) censored times.
        durations_jitter = event_times.align(durations)[0].fillna(durations)

        # Recursive call if event times are still tied after jitter
        if self._check_for_duplicates(durations=durations_jitter, events=event):
            return self._jitter(durations=durations_jitter, event=event, jitter_level=jitter_level, seed=seed)
        return durations_jitter

    def _bounds(self, lagged_survival, alpha, ci_labels):
        """Bounds are based on pg 411 of "Modelling Survival Data in Medical Research" David Collett 3rd Edition, which
        is derived from Greenwood's variance estimator. Confidence intervals are obtained using the delta method
        transformation of SE(log(-log(F_j))). This ensures that the confidence intervals all lie between 0 and 1.

        Formula for the variance follows:

        .. math::
            Var(F_j) = sum((F_j(t) - F_j(t_i))**2 * d/(n*(n-d) + S(t_i-1)**2 * ((d*(n-d))/n**3) +
            -2 * sum((F_j(t) - F_j(t_i)) * S(t_i-1) * (d/n**2)

        Delta method transformation:

        .. math::
            SE(log(-log(F_j) = SE(F_j) / (F_j * |log(F_j)|)

        More information can be found at: https://support.sas.com/documentation/onlinedoc/stat/141/lifetest.pdf
        There is also an alternative method (Aalen) but this is not currently implemented
        """
        # Preparing environment
        ci = 1 - alpha
        df = self.event_table.copy()
        df["Ft"] = self.cumulative_density_
        df["lagS"] = lagged_survival.fillna(1)
        if ci_labels is None:
            ci_labels = ["%s_upper_%g" % (self._label, ci), "%s_lower_%g" % (self._label, ci)]
        # Explicit validation instead of `assert`, which is stripped under `python -O`.
        if len(ci_labels) != 2:
            raise ValueError("ci_labels should be a length 2 array.")

        # The variance at each time depends on all earlier event times, so each time is computed
        # from its own prefix of the table. Don't think there is a faster way.
        all_vars = []
        for _, r in df.iterrows():
            sf = df.loc[df.index <= r.name].copy()
            F_t = float(r["Ft"])
            first_term = np.sum(
                (F_t - sf["Ft"]) ** 2 * sf["observed"] / sf["at_risk"] / (sf["at_risk"] - sf["observed"])
            )
            second_term = np.sum(
                sf["lagS"] ** 2
                / sf["at_risk"]
                * sf[self.label_cmprisk]
                / sf["at_risk"]
                * (sf["at_risk"] - sf[self.label_cmprisk])
                / sf["at_risk"]
            )
            third_term = np.sum((F_t - sf["Ft"]) / sf["at_risk"] * sf["lagS"] * sf[self.label_cmprisk] / sf["at_risk"])
            variance = first_term + second_term - 2 * third_term
            all_vars.append(variance)
        df["variance"] = all_vars

        # Calculating Confidence Intervals on the log(-log F) scale (keeps bounds in (0, 1)).
        df["F_transformed"] = np.log(-np.log(df["Ft"]))
        df["se_transformed"] = np.sqrt(df["variance"]) / (df["Ft"] * np.absolute(np.log(df["Ft"])))
        zalpha = inv_normal_cdf(1 - alpha / 2)
        # log(-log F) is decreasing in F, so *adding* zalpha*se back-transforms to
        # exp(-exp(m + z*se)) = F**exp(z*se) < F, i.e. the LOWER bound. The previous code
        # attached that smaller value to the "upper"-named column; assign the minus branch
        # to the upper label and the plus branch to the lower label so names match values.
        df[ci_labels[0]] = np.exp(-np.exp(df["F_transformed"] - zalpha * df["se_transformed"]))
        df[ci_labels[1]] = np.exp(-np.exp(df["F_transformed"] + zalpha * df["se_transformed"]))
        return df["variance"], df[ci_labels]

    @staticmethod
    def _check_for_duplicates(durations, events):
        """Checks for duplicated event times in the data set. This is narrowed to detecting duplicated event times
        where the events are of different types
        """
        # Setting up DataFrame to detect duplicates
        df = pd.DataFrame({"t": durations, "e": events})
        # Finding duplicated event times (censored rows are ignored entirely)
        dup_times = df.loc[df["e"] != 0, "t"].duplicated(keep=False)
        # Finding duplicated events and event times
        dup_events = df.loc[df["e"] != 0, ["t", "e"]].duplicated(keep=False)
        # A problem only exists when the same time is shared by DIFFERENT event types
        return (dup_times & (~dup_events)).any()