Skip to content

Commit

Permalink
Merge e21f3f4 into 19af3ca
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Jan 10, 2016
2 parents 19af3ca + e21f3f4 commit 42161ff
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 86 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
### Changelogs

#### Forthcoming 0.9.0
- new prediction function in `CoxPHFitter`, `predict_log_hazard_relative_to_mean`, that mimics what R's `predict.coxph` does.
- removing the `predict` method in CoxPHFitter and AalenAdditiveFitter. This is because the choice of `predict_median` as a default was causing too much confusion, and no other natual choice as a default was available. All other `predict_` methods remain.
- Default predict method in `k_fold_cross_validation` is now `predict_expectation`

#### 0.8.1
- supports matplotlib 1.5.
- introduction of a param `nn_cumulative_hazards` in AalenAdditiveModel's `__init__` (default True). This parameter will truncate all non-negative cumulative hazards in prediction methods to 0.
Expand Down
61 changes: 55 additions & 6 deletions lifelines/fitters/aalen_additive_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from lifelines.fitters import BaseFitter
from lifelines.utils import _get_index, inv_normal_cdf, epanechnikov_kernel, \
ridge_regression as lr, qth_survival_times
ridge_regression as lr, qth_survival_times, coalesce

from lifelines.utils.progress_bar import progress_bar
from lifelines.plotting import plot_regressions
from lifelines.plotting import fill_between_steps


class AalenAdditiveFitter(BaseFitter):
Expand Down Expand Up @@ -218,7 +219,6 @@ def _fit_static(self, dataframe, duration_col, event_col=None,
self.durations = T
self.event_observed = C
self._compute_confidence_intervals()
self.plot = plot_regressions(self)

return

Expand Down Expand Up @@ -310,7 +310,6 @@ def _fit_varying(self, dataframe, duration_col="T", event_col="E",
self.durations = T
self.event_observed = C
self._compute_confidence_intervals()
self.plot = plot_regressions(self)

return

Expand Down Expand Up @@ -415,5 +414,55 @@ def predict_expectation(self, X):
t = self.cumulative_hazards_.index
return pd.DataFrame(trapz(self.predict_survival_function(X)[index].values.T, t), index=index)

def predict(self, X):
return self.predict_median(X)
def plot(self, ix=None, iloc=None, columns=[], legend=True, **kwargs):
""""
A wrapper around plotting. Matplotlib plot arguments can be passed in, plus:
ix: specify a time-based subsection of the curves to plot, ex:
.plot(ix=slice(0.,10.)) will plot the time values between t=0. and t=10.
iloc: specify a location-based subsection of the curves to plot, ex:
.plot(iloc=slice(0,10)) will plot the first 10 time points.
columns: If not empty, plot a subset of columns from the cumulative_hazards_. Default all.
legend: show legend in figure.
"""
from matplotlib import pyplot as plt


def shaded_plot(ax, x, y, y_upper, y_lower, **kwargs):
base_line, = ax.plot(x, y, drawstyle='steps-post', **kwargs)
fill_between_steps(x, y_lower, y2=y_upper, ax=ax, alpha=0.25,
color=base_line.get_color(), linewidth=1.0)


assert (ix is None or iloc is None), 'Cannot set both ix and iloc in call to .plot'

get_method = "ix" if ix is not None else "iloc"
if iloc == ix is None:
user_submitted_ix = slice(0, None)
else:
user_submitted_ix = ix if ix is not None else iloc
get_loc = lambda df: getattr(df, get_method)[user_submitted_ix]

if len(columns) == 0:
columns = self.cumulative_hazards_.columns

if 'ax' in kwargs:
# don't use a .get here, as the default parameter will be called. In this case,
# plt.figure().add_subplot(111), which instantiates a new window
ax = kwargs['ax']
else:
ax = plt.figure().add_subplot(111)

x = get_loc(self.cumulative_hazards_).index.values.astype(float)

for column in columns:
y = get_loc(self.cumulative_hazards_[column]).values
y_upper = get_loc(self.confidence_intervals_[column].ix['upper']).values
y_lower = get_loc(self.confidence_intervals_[column].ix['lower']).values
shaded_plot(ax, x, y, y_upper, y_lower, label=kwargs.get('label', column))

if legend:
ax.legend()

return ax
46 changes: 36 additions & 10 deletions lifelines/fitters/coxph_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ def print_summary(self):

def predict_partial_hazard(self, X):
"""
X: a (n,d) covariate matrix
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
If covariates were normalized during fitting, they are normalized
in the same way here.
Expand All @@ -425,9 +427,24 @@ def predict_partial_hazard(self, X):

return pd.DataFrame(exp(np.dot(X, self.hazards_.T)), index=index)

def predict_log_hazard_relative_to_mean(self, X):
"""
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
Returns the log hazard relative to the hazard of the mean covariates. This is the behaviour
of R's predict.coxph.
"""
mean_covariates = self.data.mean(0).to_frame().T
return np.log(self.predict_partial_hazard(X)/self.predict_partial_hazard(mean_covariates).squeeze())


def predict_cumulative_hazard(self, X):
"""
X: a (n,d) covariate matrix
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
Returns the cumulative hazard for the individuals.
"""
Expand All @@ -438,39 +455,48 @@ def predict_cumulative_hazard(self, X):

def predict_survival_function(self, X):
"""
X: a (n,d) covariate matrix
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
Returns the survival functions for the individuals
Returns the estimated survival functions for the individuals
"""
return exp(-self.predict_cumulative_hazard(X))

def predict_percentile(self, X, p=0.5):
"""
X: a (n,d) covariate matrix
Returns the median lifetimes for the individuals.
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
By default, returns the median lifetimes for the individuals.
http://stats.stackexchange.com/questions/102986/percentile-loss-functions
"""
index = _get_index(X)
return qth_survival_times(p, self.predict_survival_function(X)[index])

def predict_median(self, X):
"""
X: a (n,d) covariate matrix
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
Returns the median lifetimes for the individuals
"""
return self.predict_percentile(X, 0.5)

def predict_expectation(self, X):
"""
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
Compute the expected lifetime, E[T], using covarites X.
"""
index = _get_index(X)
v = self.predict_survival_function(X)[index]
return pd.DataFrame(trapz(v.values.T, v.index), index=index)

def predict(self, X):
return self.predict_median(X)

def _compute_baseline_hazard(self):
# http://courses.nus.edu.sg/course/stacar/internet/st3242/handouts/notes3.pdf
ind_hazards = self.predict_partial_hazard(self.data).values
Expand Down
56 changes: 1 addition & 55 deletions lifelines/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,60 +211,6 @@ def plot_lifetimes(lifetimes, event_observed=None, birthtimes=None,
return


def shaded_plot(x, y, y_upper, y_lower, **kwargs):
from matplotlib import pyplot as plt

ax = kwargs.pop('ax', plt.gca())
base_line, = ax.plot(x, y, drawstyle='steps-post', **kwargs)
fill_between_steps(x, y_lower, y2=y_upper, ax=ax, alpha=0.25,
color=base_line.get_color(), linewidth=1.0)
return


def plot_regressions(self):
def plot(ix=None, iloc=None, columns=[], legend=True, **kwargs):
""""
A wrapper around plotting. Matplotlib plot arguments can be passed in, plus:
ix: specify a time-based subsection of the curves to plot, ex:
.plot(ix=slice(0.,10.)) will plot the time values between t=0. and t=10.
iloc: specify a location-based subsection of the curves to plot, ex:
.plot(iloc=slice(0,10)) will plot the first 10 time points.
columns: If not empty, plot a subset of columns from the cumulative_hazards_. Default all.
legend: show legend in figure.
"""
from matplotlib import pyplot as plt

assert (ix is None or iloc is None), 'Cannot set both ix and iloc in call to .plot'

get_method = "ix" if ix is not None else "iloc"
if iloc == ix is None:
user_submitted_ix = slice(0, None)
else:
user_submitted_ix = ix if ix is not None else iloc
get_loc = lambda df: getattr(df, get_method)[user_submitted_ix]

if len(columns) == 0:
columns = self.cumulative_hazards_.columns

if "ax" not in kwargs:
kwargs["ax"] = plt.figure().add_subplot(111)

x = get_loc(self.cumulative_hazards_).index.values.astype(float)
for column in columns:
y = get_loc(self.cumulative_hazards_[column]).values
y_upper = get_loc(self.confidence_intervals_[column].ix['upper']).values
y_lower = get_loc(self.confidence_intervals_[column].ix['lower']).values
shaded_plot(x, y, y_upper, y_lower, ax=kwargs["ax"], label=coalesce(kwargs.get('label'), column))

if legend:
kwargs["ax"].legend()

return kwargs["ax"]
return plot


def set_kwargs_ax(kwargs):
from matplotlib import pyplot as plt
if "ax" not in kwargs:
Expand All @@ -281,7 +227,7 @@ def set_kwargs_color(kwargs):
next(kwargs["ax"]._get_lines.color_cycle))

def set_kwargs_drawstyle(kwargs):
kwargs['drawstyle'] = coalesce(kwargs.get('drawstyle'), 'steps-post')
kwargs['drawstyle'] = kwargs.get('drawstyle', 'steps-post')


def create_dataframe_slicer(iloc, ix):
Expand Down
6 changes: 3 additions & 3 deletions lifelines/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,8 @@ def AandS_approximation(p):

def k_fold_cross_validation(fitters, df, duration_col, event_col=None,
k=5, evaluation_measure=concordance_index,
predictor="predict_median", fitter_kwargs={},
predictor_kwargs={}):
predictor="predict_expectation", predictor_kwargs={},
fitter_kwargs={}):
"""
Perform cross validation on a dataset. If multiple models are provided,
all models will train on each of the k subsets.
Expand All @@ -441,7 +441,7 @@ def k_fold_cross_validation(fitters, df, duration_col, event_col=None,
between two series of event times
predictor: a string that matches a prediction method on the fitter instances.
For example, "predict_expectation" or "predict_percentile".
Default is "predict_median"
Default is "predict_expectation"
The interface for the method is:
predict(self, data, **optional_kwargs)
fitter_kwargs: keyword args to pass into fitter.fit method
Expand Down
2 changes: 1 addition & 1 deletion lifelines/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import unicode_literals

__version__ = '0.8.1.0'
__version__ = '0.9.0.0'
47 changes: 36 additions & 11 deletions tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
from collections import Counter, Iterable
import os

try:
from StringIO import StringIO as stringio, StringIO
except ImportError:
from io import StringIO, BytesIO as stringio

import numpy as np
import pandas as pd
import pytest



from pandas.util.testing import assert_frame_equal, assert_series_equal
import numpy.testing as npt

Expand Down Expand Up @@ -514,6 +521,13 @@ def test_BHF_fit(self):

class TestRegressionFitters():

def test_pickle(self, rossi):
from pickle import dump
for fitter in [CoxPHFitter, AalenAdditiveFitter]:
output = stringio()
f = fitter().fit(rossi, 'week', 'arrest')
dump(f, output)

def test_fit_methods_require_duration_col(self):
X = load_regression_dataset()

Expand Down Expand Up @@ -544,16 +558,15 @@ def test_fit_methods_can_accept_optional_event_col_param(self):

def test_predict_methods_in_regression_return_same_types(self):
X = load_regression_dataset()
x = X[X.columns - ['T', 'E']]

aaf = AalenAdditiveFitter()
cph = CoxPHFitter()

aaf.fit(X, duration_col='T', event_col='E')
cph.fit(X, duration_col='T', event_col='E')

for fit_method in ['predict_percentile', 'predict_median', 'predict_expectation', 'predict_survival_function', 'predict', 'predict_cumulative_hazard']:
assert isinstance(getattr(aaf, fit_method)(x), type(getattr(cph, fit_method)(x)))
for fit_method in ['predict_percentile', 'predict_median', 'predict_expectation', 'predict_survival_function', 'predict_cumulative_hazard']:
assert isinstance(getattr(aaf, fit_method)(X), type(getattr(cph, fit_method)(X)))

def test_duration_vector_can_be_normalized(self):
df = load_kidney_transplant()
Expand All @@ -575,13 +588,11 @@ def test_prediction_methods_respect_index(self, data_pred2):
cph.fit(data_pred2, duration_col='t', event_col='E')
npt.assert_array_equal(cph.predict_partial_hazard(x).index, expected_index)
npt.assert_array_equal(cph.predict_percentile(x).index, expected_index)
npt.assert_array_equal(cph.predict(x).index, expected_index)
npt.assert_array_equal(cph.predict_expectation(x).index, expected_index)

aaf = AalenAdditiveFitter()
aaf.fit(data_pred2, duration_col='t', event_col='E')
npt.assert_array_equal(aaf.predict_percentile(x).index, expected_index)
npt.assert_array_equal(aaf.predict(x).index, expected_index)
npt.assert_array_equal(aaf.predict_expectation(x).index, expected_index)


Expand All @@ -604,11 +615,6 @@ def test_summary(self, rossi):
def test_print_summary(self, rossi):

import sys
try:
from StringIO import StringIO
except:
from io import StringIO

saved_stdout = sys.stdout
try:
out = StringIO()
Expand Down Expand Up @@ -924,6 +930,25 @@ def test_strata_against_r_output(self, rossi):
assert abs(cp._log_likelihood - -436.9339) / 436.9339 < 0.01


def test_predict_log_hazard_relative_to_mean_with_normalization(self, rossi):
cox = CoxPHFitter(normalize=True)
cox.fit(rossi, 'week', 'arrest')

# they are equal because the data is normalized, so the mean of the covarites is all 0,
# thus exp(beta * 0) == 1, so exp(beta * X)/exp(beta * 0) = exp(beta * X)
assert_frame_equal(cox.predict_log_hazard_relative_to_mean(rossi), np.log(cox.predict_partial_hazard(rossi)))


def test_predict_log_hazard_relative_to_mean_without_normalization(self, rossi):
cox = CoxPHFitter(normalize=False)
cox.fit(rossi, 'week', 'arrest')
log_relative_hazards = cox.predict_log_hazard_relative_to_mean(rossi)
means = rossi.mean(0).to_frame().T
assert cox.predict_partial_hazard(means).values[0][0] != 1.0
assert_frame_equal(log_relative_hazards, np.log(cox.predict_partial_hazard(rossi) / cox.predict_partial_hazard(means).squeeze()))



class TestAalenAdditiveFitter():

def test_nn_cumulative_hazard_will_set_cum_hazards_to_0(self, rossi):
Expand Down Expand Up @@ -1118,7 +1143,7 @@ def test_crossval_for_aalen_add(self, data_pred2, data_pred1):

expected = 0.90
msg = "Expected min-mean c-index {:.2f} < {:.2f}"
assert np.mean(mean_scores) > expected, msg.format(expected, scores.mean())
assert np.mean(mean_scores) > expected, msg.format(expected, np.mean(scores))

def test_predict_cumulative_hazard_inputs(self, data_pred1):
aaf = AalenAdditiveFitter()
Expand Down

0 comments on commit 42161ff

Please sign in to comment.