Skip to content

Commit

Permalink
v0.20.3
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Mar 23, 2019
1 parent f7ca123 commit 4bdfdf1
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 22 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
### Changelog

#### 0.20.3

##### New features
- Now `cumulative_density_` & `survival_function_` are _always_ present on a fitted `KaplanMeierFitter`.
- New attributes/methods on `KaplanMeierFitter`: `plot_cumulative_density()`, `confidence_interval_cumulative_density_`, `plot_survival_function` and `confidence_interval_survival_function_`.


#### 0.20.2

##### New features
Expand Down
1 change: 1 addition & 0 deletions docs/Examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ Invert axis
.. image:: /images/invert_y_axis.png

.. note:: This is deprecated and we suggest to use `kmf.plot_cumulative_density()` instead.


Displaying at-risk counts below plots
Expand Down
15 changes: 12 additions & 3 deletions docs/Quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,23 @@ After calling the ``fit`` method, we have access to new properties like ``surviv
.. code:: python
kmf.survival_function_
kmf.confidence_interval_
kmf.cumulative_density_
kmf.median_
kmf.plot()
kmf.plot_survival_function() # or just kmf.plot()
.. image:: images/quickstart_kmf.png

By specifying the ``timeline`` keyword argument in ``fit``, we can change how the above models are index:
Alternatively, you can plot the cumulative density function:

.. code:: python
kmf.plot_cumulative_density()
.. image:: images/quickstart_kmf_cdf.png


By specifying the ``timeline`` keyword argument in ``fit``, we can change how the above models are indexed:

.. code:: python
Expand Down
Binary file modified docs/images/coxph_plot_quickstart.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/quickstart_aaf.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/quickstart_kmf.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/quickstart_kmf_cdf.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/quickstart_multi.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/quickstart_predict_aaf.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/waft_plot_quickstart.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
55 changes: 44 additions & 11 deletions lifelines/fitters/kaplan_meier_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StatisticalWarning,
coalesce,
)
from lifelines.plotting import plot_loglogs
from lifelines.plotting import plot_loglogs, _plot_estimate


class KaplanMeierFitter(UnivariateFitter):
Expand Down Expand Up @@ -108,6 +108,7 @@ def fit(
self._check_values(event_observed)

self.left_censorship = left_censorship
self._label = label

if weights is not None:
if (weights.astype(int) != weights).any():
Expand All @@ -121,13 +122,15 @@ def fit(
)

# if the user is interested in left-censorship, we return the cumulative_density_, no survival_function_,
estimate_name = "survival_function_" if not left_censorship else "cumulative_density_"
v = _preprocess_inputs(durations, event_observed, timeline, entry, weights)
self.durations, self.event_observed, self.timeline, self.entry, self.event_table = v
primary_estimate_name = "survival_function_" if not self.left_censorship else "cumulative_density_"
secondary_estimate_name = "cumulative_density_" if not self.left_censorship else "survival_function_"

self.durations, self.event_observed, self.timeline, self.entry, self.event_table = _preprocess_inputs(
durations, event_observed, timeline, entry, weights
)

self._label = label
alpha = alpha if alpha else self.alpha
log_survival_function, cumulative_sq_ = _additive_estimate(
log_estimate, cumulative_sq_ = _additive_estimate(
self.event_table, self.timeline, self._additive_f, self._additive_var, left_censorship
)

Expand All @@ -145,19 +148,23 @@ def fit(
)

# estimation
setattr(self, estimate_name, pd.DataFrame(np.exp(log_survival_function), columns=[self._label]))
self.__estimate = getattr(self, estimate_name)
setattr(self, primary_estimate_name, pd.DataFrame(np.exp(log_estimate), columns=[self._label]))
setattr(self, secondary_estimate_name, pd.DataFrame(1 - np.exp(log_estimate), columns=[self._label]))

self.__estimate = getattr(self, primary_estimate_name)
self.confidence_interval_ = self._bounds(cumulative_sq_[:, None], alpha, ci_labels)
self.median_ = median_survival_times(self.__estimate, left_censorship=left_censorship)
self._cumulative_sq_ = cumulative_sq_

setattr(self, "confidence_interval_" + primary_estimate_name, self.confidence_interval_)
setattr(self, "confidence_interval_" + secondary_estimate_name, 1 - self.confidence_interval_)

# estimation methods
self._estimation_method = estimate_name
self._estimate_name = estimate_name
self._estimation_method = primary_estimate_name
self._estimate_name = primary_estimate_name
self._predict_label = label
self._update_docstrings()

setattr(self, "plot_" + estimate_name.rstrip("_"), self.plot)
return self

def _check_values(self, array):
Expand Down Expand Up @@ -185,6 +192,22 @@ def survival_function_at_times(self, times, label=None):
label = coalesce(label, self._label)
return pd.Series(self.predict(times), index=_to_array(times), name=label)

def plot_survival_function(self, **kwargs):
return _plot_estimate(
self,
estimate=self.survival_function_,
confidence_intervals=self.confidence_interval_survival_function_,
**kwargs
)

def plot_cumulative_density(self, **kwargs):
return _plot_estimate(
self,
estimate=self.cumulative_density_,
confidence_intervals=self.confidence_interval_cumulative_density_,
**kwargs
)

def _bounds(self, cumulative_sq_, alpha, ci_labels):
# This method calculates confidence intervals using the exponential Greenwood formula.
# See https://www.math.wustl.edu/%7Esawyer/handouts/greenwood.pdf
Expand All @@ -207,3 +230,13 @@ def _additive_f(self, population, deaths):
def _additive_var(self, population, deaths):
np.seterr(divide="ignore")
return (deaths / (population * (population - deaths))).replace([np.inf], 0)

def plot_cumulative_hazard(self, **kwargs):
raise NotImplementedError(
"The Kaplan-Meier estimator is not used to estimate the cumulative hazard. Try the NelsonAalenFitter or any other parametric model"
)

def plot_hazard(self, **kwargs):
raise NotImplementedError(
"The Kaplan-Meier estimator is not used to estimate the hazard. Try the NelsonAalenFitter or any other parametric model"
)
18 changes: 17 additions & 1 deletion lifelines/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

import warnings
from textwrap import dedent
import numpy as np
from lifelines.utils import coalesce
from scipy import stats
Expand Down Expand Up @@ -484,7 +485,7 @@ def _plot_estimate(
.plot(iloc=slice(0,10))
will plot the first 10 time points.
invert_y_axis: bool
boolean to invert the y-axis, useful to show cumulative graphs instead of survival graphs.
boolean to invert the y-axis, useful to show cumulative graphs instead of survival graphs. (Deprecated, use ``plot_cumulative_density()``)
Returns
-------
Expand All @@ -495,6 +496,21 @@ def _plot_estimate(
cls, estimate, confidence_intervals, loc, iloc, show_censors, censor_styles, **kwargs
)

if invert_y_axis:
warnings.warn(
dedent(
"""
The invert_y_axis will be removed in lifelines 0.21.0. Likely you are trying to plot the cumulative density function?
That's now part of the KaplanMeierFitter,
>>> kmf.plot_cumulative_density()
>>> # nice
"""
),
PendingDeprecationWarning,
)

dataframe_slicer = create_dataframe_slicer(iloc, loc)

if show_censors and cls.event_table["censored"].sum() > 0:
Expand Down
31 changes: 27 additions & 4 deletions tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def test_llf_small_values(self, llf):


class TestWeibullFitter:
@flaky
@flaky(max_runs=3, min_passes=2)
@pytest.mark.parametrize("N", [50, 100, 500, 1000])
def test_left_censorship_inference(self, N):
T_actual = 0.5 * np.random.weibull(5, size=N)
Expand Down Expand Up @@ -902,7 +902,6 @@ def test_passing_in_left_censorship_creates_a_cumulative_density(self, sample_li
kmf.fit(T, C, left_censorship=True)
assert hasattr(kmf, "cumulative_density_")
assert hasattr(kmf, "plot_cumulative_density")
assert not hasattr(kmf, "survival_function_")

def test_kmf_left_censorship_stats(self):
# from http://www.public.iastate.edu/~pdixon/stat505/Chapter%2011.pdf
Expand Down Expand Up @@ -1025,6 +1024,29 @@ def test_late_entry_with_against_R(self):
expected = [1.0, 1.0, 0.667, 0.444, 0.222, 0.111, 0.0]
npt.assert_allclose(kmf.survival_function_.values.reshape(7), expected, rtol=1e-2)

def test_kmf_has_both_survival_function_and_cumulative_density(self):
# right censoring
kmf = KaplanMeierFitter().fit(np.arange(100), left_censorship=False)
assert hasattr(kmf, "survival_function_")
assert hasattr(kmf, "plot_survival_function")
assert hasattr(kmf, "confidence_interval_survival_function_")
assert_frame_equal(kmf.confidence_interval_survival_function_, kmf.confidence_interval_)

assert hasattr(kmf, "cumulative_density_")
assert hasattr(kmf, "plot_cumulative_density")
assert hasattr(kmf, "confidence_interval_cumulative_density_")

# left censoring
kmf = KaplanMeierFitter().fit(np.arange(100), left_censorship=True)
assert hasattr(kmf, "survival_function_")
assert hasattr(kmf, "plot_survival_function")
assert hasattr(kmf, "confidence_interval_survival_function_")

assert hasattr(kmf, "cumulative_density_")
assert hasattr(kmf, "plot_cumulative_density")
assert hasattr(kmf, "confidence_interval_cumulative_density_")
assert_frame_equal(kmf.confidence_interval_cumulative_density_, kmf.confidence_interval_)

def test_late_entry_with_tied_entry_and_death(self):
np.random.seed(101)

Expand Down Expand Up @@ -3814,8 +3836,9 @@ def test_jitter(self, fitter):
def test_tied_input_data(self, fitter):
# Based on new setup of ties, this counts as a valid tie
d = [1, 2, 2, 4, 5, 6]
fitter.fit(durations=d, event_observed=[0, 1, 2, 1, 2, 0], event_of_interest=2)
npt.assert_equal(np.any(np.not_equal([0] + d, fitter.event_table.index)), True)
with pytest.warns(Warning, match="Tied event times"):
fitter.fit(durations=d, event_observed=[0, 1, 2, 1, 2, 0], event_of_interest=2)
npt.assert_equal(np.any(np.not_equal([0] + d, fitter.event_table.index)), True)

def test_updated_input_ties(self, fitter):
# Based on the new setup of ties, should not detect any ties as existing
Expand Down
6 changes: 3 additions & 3 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,10 @@ def test_kmf_left_censorship_plots(self, block):
lcd_dataset = load_lcd()
alluvial_fan = lcd_dataset.loc[lcd_dataset["group"] == "alluvial_fan"]
basin_trough = lcd_dataset.loc[lcd_dataset["group"] == "basin_trough"]
kmf.fit(alluvial_fan["T"], alluvial_fan["C"], left_censorship=True, label="alluvial_fan")
kmf.fit(alluvial_fan["T"], alluvial_fan["E"], left_censorship=True, label="alluvial_fan")
ax = kmf.plot()

kmf.fit(basin_trough["T"], basin_trough["C"], left_censorship=True, label="basin_trough")
kmf.fit(basin_trough["T"], basin_trough["E"], left_censorship=True, label="basin_trough")
ax = kmf.plot(ax=ax)
self.plt.title("test_kmf_left_censorship_plots")
self.plt.show(block=block)
Expand Down Expand Up @@ -542,7 +542,7 @@ def test_qq_plot_left_censoring2(self, block):
fig, axes = self.plt.subplots(2, 2, figsize=(9, 5))
axes = axes.reshape(4)
for i, model in enumerate([WeibullFitter(), LogNormalFitter(), LogLogisticFitter(), ExponentialFitter()]):
model.fit(df["T"], df["C"], left_censorship=True)
model.fit(df["T"], df["E"], left_censorship=True)
ax = qq_plot(model, ax=axes[i])
assert ax is not None
self.plt.suptitle("test_qq_plot_left_censoring2")
Expand Down

0 comments on commit 4bdfdf1

Please sign in to comment.