Skip to content

Commit

Permalink
fix plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Apr 12, 2019
1 parent e72ce18 commit 9bbbc24
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 27 deletions.
17 changes: 11 additions & 6 deletions lifelines/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
import numpy as np
from lifelines.utils import coalesce
from lifelines.utils import coalesce, CensoringType
from scipy import stats

__all__ = ["add_at_risk_counts", "plot_lifetimes", "qq_plot"]
Expand Down Expand Up @@ -107,15 +107,20 @@ def qq_plot(model, **plot_kwargs):
COL_EMP = "empirical quantiles"
COL_THEO = "fitted %s quantiles" % dist

kmf = KaplanMeierFitter().fit(
model.durations, model.event_observed, left_censorship=model.left_censorship, label=COL_EMP
)
if model.left_censorship:
is_left_censored = model._censoring_type == CensoringType.LEFT
is_interval_censored = model._censoring_type == CensoringType.INTERVAL
is_right_censored = model._censoring_type == CensoringType.RIGHT

if is_left_censored:
kmf = KaplanMeierFitter().fit_left_censoring(model.durations, model.event_observed, label=COL_EMP)
q = np.unique(kmf.cumulative_density_.values[:, 0])
quantiles = qth_survival_times(q, kmf.cumulative_density_, cdf=True)
else:
elif is_right_censored:
kmf = KaplanMeierFitter().fit_right_censoring(model.durations, model.event_observed, label=COL_EMP)
q = np.unique(1 - kmf.survival_function_.values[:, 0])
quantiles = qth_survival_times(q, 1 - kmf.survival_function_, cdf=True)
elif is_interval_censored:
raise NotImplementedError()

quantiles[COL_THEO] = dist_object.ppf(q)
quantiles = quantiles.replace([-np.inf, 0, np.inf], np.nan).dropna()
Expand Down
27 changes: 6 additions & 21 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,6 @@ def test_kmf_with_risk_counts(self, block, kmf):
self.plt.title("test_kmf_with_risk_counts")
self.plt.show(block=block)

def test_kmf_with_inverted_axis(self, block, kmf):

T = np.random.exponential(size=100)
kmf = KaplanMeierFitter()
kmf.fit(T, label="t2")
ax = kmf.plot(invert_y_axis=True, at_risk_counts=True)

T = np.random.exponential(3, size=100)
kmf = KaplanMeierFitter()
kmf.fit(T, label="t1")
kmf.plot(invert_y_axis=True, ax=ax, ci_force_lines=False)

self.plt.title("test_kmf_with_inverted_axis")
self.plt.show(block=block)

def test_naf_plotting_with_custom_colours(self, block):
data1 = np.random.exponential(5, size=(200, 1))
data2 = np.random.exponential(1, size=(500))
Expand Down Expand Up @@ -436,10 +421,10 @@ def test_kmf_left_censorship_plots(self, block):
lcd_dataset = load_lcd()
alluvial_fan = lcd_dataset.loc[lcd_dataset["group"] == "alluvial_fan"]
basin_trough = lcd_dataset.loc[lcd_dataset["group"] == "basin_trough"]
kmf.fit(alluvial_fan["T"], alluvial_fan["E"], left_censorship=True, label="alluvial_fan")
kmf.fit_left_censoring(alluvial_fan["T"], alluvial_fan["E"], label="alluvial_fan")
ax = kmf.plot()

kmf.fit(basin_trough["T"], basin_trough["E"], left_censorship=True, label="basin_trough")
kmf.fit_left_censoring(basin_trough["T"], basin_trough["E"], label="basin_trough")
ax = kmf.plot(ax=ax)
self.plt.title("test_kmf_left_censorship_plots")
self.plt.show(block=block)
Expand Down Expand Up @@ -537,7 +522,7 @@ def test_left_censorship_cdf_plots(self, block):
fig, axes = self.plt.subplots(2, 2, figsize=(9, 5))
axes = axes.reshape(4)
for i, model in enumerate([WeibullFitter(), LogNormalFitter(), LogLogisticFitter(), ExponentialFitter()]):
model.fit(df["NH4.mg.per.L"], ~df["Censored"], left_censorship=True)
model.fit_left_censoring(df["NH4.mg.per.L"], ~df["Censored"])
ax = cdf_plot(model, ax=axes[i])
assert ax is not None
self.plt.suptitle("test_left_censorship_cdf_plots")
Expand All @@ -559,7 +544,7 @@ def test_qq_plot_left_censoring(self, block):
fig, axes = self.plt.subplots(2, 2, figsize=(9, 5))
axes = axes.reshape(4)
for i, model in enumerate([WeibullFitter(), LogNormalFitter(), LogLogisticFitter(), ExponentialFitter()]):
model.fit(df["NH4.mg.per.L"], ~df["Censored"], left_censorship=True)
model.fit_left_censoring(df["NH4.mg.per.L"], ~df["Censored"])
ax = qq_plot(model, ax=axes[i])
assert ax is not None
self.plt.suptitle("test_qq_plot_left_censoring")
Expand All @@ -570,7 +555,7 @@ def test_qq_plot_left_censoring2(self, block):
fig, axes = self.plt.subplots(2, 2, figsize=(9, 5))
axes = axes.reshape(4)
for i, model in enumerate([WeibullFitter(), LogNormalFitter(), LogLogisticFitter(), ExponentialFitter()]):
model.fit(df["T"], df["E"], left_censorship=True)
model.fit_left_censoring(df["T"], df["E"])
ax = qq_plot(model, ax=axes[i])
assert ax is not None
self.plt.suptitle("test_qq_plot_left_censoring2")
Expand All @@ -593,7 +578,7 @@ def test_qq_plot_left_censoring_with_known_distribution(self, block):
fig, axes = self.plt.subplots(2, 2, figsize=(9, 5))
axes = axes.reshape(4)
for i, model in enumerate([WeibullFitter(), LogNormalFitter(), LogLogisticFitter(), ExponentialFitter()]):
model.fit(T, E, left_censorship=True)
model.fit_left_censoring(T, E)
ax = qq_plot(model, ax=axes[i])
assert ax is not None
self.plt.suptitle("test_qq_plot_left_censoring_with_known_distribution")
Expand Down

0 comments on commit 9bbbc24

Please sign in to comment.