Skip to content

Commit

Permalink
final clean up; tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Jun 21, 2015
1 parent 9ef3877 commit b59dbaa
Show file tree
Hide file tree
Showing 10 changed files with 640 additions and 684 deletions.
26 changes: 22 additions & 4 deletions lifelines/_base_fitter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
import numpy as np
import pandas as pd
from lifelines.plotting import plot_estimate
from lifelines.utils import _conditional_time_to_event_
from lifelines.utils import qth_survival_times


class BaseFitter(object):

Expand All @@ -20,6 +22,7 @@ def __repr__(self):
s = """<lifelines.%s>""" % classname
return s


class UnivariateFitter(BaseFitter):

def _plot_estimate(self, *args):
Expand All @@ -46,7 +49,6 @@ def subtract(other):
subtract.__doc__ = doc_string
return subtract


def _divide(self, estimate):
class_name = self.__class__.__name__
doc_string = """
Expand All @@ -68,7 +70,6 @@ def divide(other):
divide.__doc__ = doc_string
return divide


def _predict(self, estimate, label):
class_name = self.__class__.__name__
doc_string = """
Expand All @@ -93,4 +94,21 @@ def predict(time):

@property
def conditional_time_to_event_(self):
return _conditional_time_to_event_(self)
return self._conditional_time_to_event_()

def _conditional_time_to_event_(self):
"""
Return a DataFrame, with index equal to survival_function_, that estimates the median
duration remaining until the death event, given survival up until time t. For example, if an
individual exists until age 1, their expected life remaining *given they lived to time 1*
might be 9 years.
Returns:
conditional_time_to_: DataFrame, with index equal to survival_function_
"""
age = self.survival_function_.index.values[:, None]
columns = ['%s - Conditional time remaining to event' % self._label]
return pd.DataFrame(qth_survival_times(self.survival_function_[self._label] * 0.5, self.survival_function_).T.sort(ascending=False).values,
index=self.survival_function_.index,
columns=columns) - age
5 changes: 1 addition & 4 deletions lifelines/aalen_additive_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
from __future__ import print_function
import numpy as np
import pandas as pd

from numpy import dot, exp
from numpy.linalg import LinAlgError
from scipy.integrate import trapz
from lifelines._base_fitter import BaseFitter
from lifelines.utils import _get_index, inv_normal_cdf, epanechnikov_kernel, \
ridge_regression as lr, qth_survival_times
ridge_regression as lr, qth_survival_times
from lifelines.progress_bar import progress_bar
from lifelines.plotting import plot_regressions

Expand Down Expand Up @@ -410,4 +408,3 @@ def predict_expectation(self, X):

def predict(self, X):
return self.predict_median(X)

5 changes: 1 addition & 4 deletions lifelines/breslow_fleming_harrington_fitter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
import numpy as np
import pandas as pd

from lifelines._base_fitter import UnivariateFitter
from lifelines.nelson_aalen_fitter import NelsonAalenFitter
from lifelines.utils import preprocess_inputs, _additive_estimate, median_survival_times,\
inv_normal_cdf
from lifelines.utils import median_survival_times


class BreslowFlemingHarringtonFitter(UnivariateFitter):
Expand Down Expand Up @@ -69,4 +67,3 @@ def fit(self, durations, event_observed=None, timeline=None, entry=None,
self.plot = self._plot_estimate("survival_function_")
self.plot_survival_function = self.plot
return self

2 changes: 1 addition & 1 deletion lifelines/coxph_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from lifelines._base_fitter import BaseFitter
from lifelines.utils import survival_table_from_events, inv_normal_cdf, normalize,\
significance_code, concordance_index, _get_index, qth_survival_times
significance_code, concordance_index, _get_index, qth_survival_times


class CoxPHFitter(BaseFitter):
Expand Down
8 changes: 1 addition & 7 deletions lifelines/estimation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
# -*- coding: utf-8 -*-
from __future__ import print_function

import numpy as np
import pandas as pd

from lifelines._base_fitter import BaseFitter
from lifelines.weibull_fitter import WeibullFitter
from lifelines.exponential_fitter import ExponentialFitter
from lifelines.nelson_aalen_fitter import NelsonAalenFitter
from lifelines.kaplan_meier_fitter import KaplanMeierFitter
from lifelines.breslow_fleming_harrington_fitter import BreslowFlemingHarringtonFitter
from lifelines.coxph_fitter import CoxPHFitter
from lifelines.coxph_fitter import CoxPHFitter
from lifelines.aalen_additive_fitter import AalenAdditiveFitter

1 change: 1 addition & 0 deletions lifelines/exponential_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from lifelines._base_fitter import UnivariateFitter
from lifelines.utils import inv_normal_cdf


class ExponentialFitter(UnivariateFitter):

"""
Expand Down
10 changes: 5 additions & 5 deletions lifelines/kaplan_meier_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pandas as pd

from lifelines._base_fitter import UnivariateFitter
from lifelines.utils import preprocess_inputs, _additive_estimate, StatError, inv_normal_cdf,\
median_survival_times
from lifelines.utils import _preprocess_inputs, _additive_estimate, StatError, inv_normal_cdf,\
median_survival_times


class KaplanMeierFitter(UnivariateFitter):
Expand All @@ -28,7 +28,7 @@ def fit(self, durations, event_observed=None, timeline=None, entry=None, label='
event_observed: an array, or pd.Series, of length n -- True if the the death was observed, False if the event
was lost (right-censored). Defaults all True if event_observed==None
entry: an array, or pd.Series, of length n -- relative time when a subject entered the study. This is
useful for left-truncated (not left-censored) observations. If None, all members of the population
useful for left-truncated (not left-censored) observations. If None, all members of the population
were born at time 0.
label: a string to name the column of the estimate.
alpha: the alpha value in the confidence intervals. Overrides the initializing
Expand All @@ -44,7 +44,7 @@ def fit(self, durations, event_observed=None, timeline=None, entry=None, label='
"""
# if the user is interested in left-censorship, we return the cumulative_density_, no survival_function_,
estimate_name = 'survival_function_' if not left_censorship else 'cumulative_density_'
v = preprocess_inputs(durations, event_observed, timeline, entry)
v = _preprocess_inputs(durations, event_observed, timeline, entry)
self.durations, self.event_observed, self.timeline, self.entry, self.event_table = v
self._label = label
alpha = alpha if alpha else self.alpha
Expand Down Expand Up @@ -98,4 +98,4 @@ def _additive_f(self, population, deaths):

def _additive_var(self, population, deaths):
np.seterr(divide='ignore')
return (1. * deaths / (population * (population - deaths))).replace([np.inf], 0)
return (1. * deaths / (population * (population - deaths))).replace([np.inf], 0)
7 changes: 4 additions & 3 deletions lifelines/nelson_aalen_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pandas as pd

from lifelines._base_fitter import UnivariateFitter
from lifelines.utils import preprocess_inputs, _additive_estimate, epanechnikov_kernel,\
inv_normal_cdf
from lifelines.utils import _preprocess_inputs, _additive_estimate, epanechnikov_kernel,\
inv_normal_cdf


class NelsonAalenFitter(UnivariateFitter):

Expand Down Expand Up @@ -55,7 +56,7 @@ def fit(self, durations, event_observed=None, timeline=None, entry=None,
"""

v = preprocess_inputs(durations, event_observed, timeline, entry)
v = _preprocess_inputs(durations, event_observed, timeline, entry)
self.durations, self.event_observed, self.timeline, self.entry, self.event_table = v

cumulative_hazard_, cumulative_sq_ = _additive_estimate(self.event_table, self.timeline,
Expand Down
2 changes: 1 addition & 1 deletion lifelines/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def plot(ix=None, iloc=None, flat=False, show_censors=False,
estimate_ = cls.smoothed_hazard_(bandwidth)
confidence_interval_ = \
cls.smoothed_hazard_confidence_intervals_(bandwidth,
hazard_=estimate_.values[:, 0])
hazard_=estimate_.values[:, 0])
else:
confidence_interval_ = getattr(cls, 'confidence_interval_')
estimate_ = getattr(cls, estimate)
Expand Down
Loading

0 comments on commit b59dbaa

Please sign in to comment.