Skip to content

Commit

Permalink
Merge 57eb82c into 05bc785
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Dec 30, 2015
2 parents 05bc785 + 57eb82c commit 960f6bd
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 31 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Expand Up @@ -2,7 +2,8 @@

#### Forthcoming 0.9.0
- new prediction function in `CoxPHFitter`, `predict_log_hazard_relative_to_mean`, that mimics what R's `predict.coxph` does.
- changing the default `predict` method in lifelines to return _not the median_, but another value dependent on the fitter that is calling it. This is because too often the `predict_median` function was returning inf values, which would significantly damage measures of the concordance index.
- removing the `predict` method in CoxPHFitter and AalenAdditiveFitter. This is because the choice of `predict_median` as a default was causing too much confusion, and no other natural choice as a default was available. All other `predict_` methods remain.
- Default predict method in `k_fold_cross_validation` is now `predict_expectation`

#### 0.8.1
- supports matplotlib 1.5.
Expand Down
13 changes: 1 addition & 12 deletions lifelines/fitters/aalen_additive_fitter.py
Expand Up @@ -414,17 +414,6 @@ def predict_expectation(self, X):
t = self.cumulative_hazards_.index
return pd.DataFrame(trapz(self.predict_survival_function(X)[index].values.T, t), index=index)

def predict(self, X):
"""
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
Returns the median lifetimes for the individuals. Alias for predict_median
"""
return self.predict_median(X)


def plot(self, ix=None, iloc=None, columns=[], legend=True, **kwargs):
""""
A wrapper around plotting. Matplotlib plot arguments can be passed in, plus:
Expand Down Expand Up @@ -476,4 +465,4 @@ def shaded_plot(ax, x, y, y_upper, y_lower, **kwargs):
if legend:
ax.legend()

return ax
return ax
12 changes: 0 additions & 12 deletions lifelines/fitters/coxph_fitter.py
Expand Up @@ -497,18 +497,6 @@ def predict_expectation(self, X):
v = self.predict_survival_function(X)[index]
return pd.DataFrame(trapz(v.values.T, v.index), index=index)

def predict(self, X):
"""
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
Returns the median predicted lifetime. This is different from R's predict.coxph, which returns
the linear predictor for the log-hazard relative to a mean survival estimate (this is
available in predict_log_hazard_relative_to_mean)
"""
return self.predict_median(X)

def _compute_baseline_hazard(self):
# http://courses.nus.edu.sg/course/stacar/internet/st3242/handouts/notes3.pdf
ind_hazards = self.predict_partial_hazard(self.data).values
Expand Down
4 changes: 2 additions & 2 deletions lifelines/utils/__init__.py
Expand Up @@ -416,7 +416,7 @@ def AandS_approximation(p):

def k_fold_cross_validation(fitters, df, duration_col, event_col=None,
k=5, evaluation_measure=concordance_index,
predictor="predict_median", predictor_kwargs={}):
predictor="predict_expectation", predictor_kwargs={}):
"""
Perform cross validation on a dataset. If multiple models are provided,
all models will train on each of the k subsets.
Expand All @@ -440,7 +440,7 @@ def k_fold_cross_validation(fitters, df, duration_col, event_col=None,
between two series of event times
predictor: a string that matches a prediction method on the fitter instances.
For example, "predict_expectation" or "predict_percentile".
Default is "predict_median"
Default is "predict_expectation"
The interface for the method is:
predict(self, data, **optional_kwargs)
predictor_kwargs: keyword args to pass into predictor-method.
Expand Down
6 changes: 2 additions & 4 deletions tests/test_estimation.py
Expand Up @@ -565,7 +565,7 @@ def test_predict_methods_in_regression_return_same_types(self):
aaf.fit(X, duration_col='T', event_col='E')
cph.fit(X, duration_col='T', event_col='E')

for fit_method in ['predict_percentile', 'predict_median', 'predict_expectation', 'predict_survival_function', 'predict', 'predict_cumulative_hazard']:
for fit_method in ['predict_percentile', 'predict_median', 'predict_expectation', 'predict_survival_function', 'predict_cumulative_hazard']:
assert isinstance(getattr(aaf, fit_method)(X), type(getattr(cph, fit_method)(X)))

def test_duration_vector_can_be_normalized(self):
Expand All @@ -588,13 +588,11 @@ def test_prediction_methods_respect_index(self, data_pred2):
cph.fit(data_pred2, duration_col='t', event_col='E')
npt.assert_array_equal(cph.predict_partial_hazard(x).index, expected_index)
npt.assert_array_equal(cph.predict_percentile(x).index, expected_index)
npt.assert_array_equal(cph.predict(x).index, expected_index)
npt.assert_array_equal(cph.predict_expectation(x).index, expected_index)

aaf = AalenAdditiveFitter()
aaf.fit(data_pred2, duration_col='t', event_col='E')
npt.assert_array_equal(aaf.predict_percentile(x).index, expected_index)
npt.assert_array_equal(aaf.predict(x).index, expected_index)
npt.assert_array_equal(aaf.predict_expectation(x).index, expected_index)


Expand Down Expand Up @@ -1145,7 +1143,7 @@ def test_crossval_for_aalen_add(self, data_pred2, data_pred1):

expected = 0.90
msg = "Expected min-mean c-index {:.2f} < {:.2f}"
assert np.mean(mean_scores) > expected, msg.format(expected, scores.mean())
assert np.mean(mean_scores) > expected, msg.format(expected, np.mean(scores))

def test_predict_cumulative_hazard_inputs(self, data_pred1):
aaf = AalenAdditiveFitter()
Expand Down

0 comments on commit 960f6bd

Please sign in to comment.