Skip to content

Commit

Permalink
Merge pull request #295 from CamDavidsonPilon/fix-scaling-in-cox-ph-p…
Browse files Browse the repository at this point in the history
…rediction

Fix scaling in cox ph prediction
  • Loading branch information
CamDavidsonPilon committed Jun 6, 2017
2 parents 0fe0d10 + a54a3c9 commit 8e3272a
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
### Changelogs

#### 0.10.1
- fix in internal normalization for `CoxPHFitter` predict methods.

#### 0.10.0
- corrected bug that was returning the wrong baseline survival and hazard values in `CoxPHFitter` when `normalize=True`.
- removed `normalize` kwarg in `CoxPHFitter`. This was causing lots of confusion for users, and added code complexity. It's really nice to be able to remove it.
Expand Down
3 changes: 2 additions & 1 deletion lifelines/fitters/coxph_fitter.py
Expand Up @@ -312,7 +312,7 @@ def fit(self, df, duration_col, event_col=None,
self.durations = T
self.event_observed = E

self.baseline_hazard_ = self._compute_baseline_hazards(normalize(df, 0, 1 / self._norm_std), T, E)
self.baseline_hazard_ = self._compute_baseline_hazards(df * self._norm_std + self._norm_mean, T, E)
self.baseline_cumulative_hazard_ = self.baseline_hazard_.cumsum()
self.baseline_survival_ = self._compute_baseline_survival()
return self
Expand Down Expand Up @@ -427,6 +427,7 @@ def predict_log_partial_hazard(self, X):
X = X[order]

index = _get_index(X)
X = normalize(X, self._norm_mean.values, 1)
return pd.DataFrame(np.dot(X, self.hazards_.T), index=index)

def predict_log_hazard_relative_to_mean(self, X):
Expand Down
2 changes: 1 addition & 1 deletion lifelines/version.py
@@ -1,3 +1,3 @@
from __future__ import unicode_literals

__version__ = '0.10.0'
__version__ = '0.10.1'
85 changes: 77 additions & 8 deletions tests/test_estimation.py
Expand Up @@ -21,7 +21,7 @@
NelsonAalenFitter, BreslowFlemingHarringtonFitter, ExponentialFitter, \
WeibullFitter, BaseFitter
from lifelines.datasets import load_larynx, load_waltons, load_kidney_transplant, load_rossi,\
load_lcd, load_panel_test, load_g3, load_holly_molly_polly
load_lcd, load_panel_test, load_g3, load_holly_molly_polly, load_regression_dataset
from lifelines.generate_datasets import generate_hazard_rates, generate_random_lifetimes, cumulative_integral
from lifelines.utils import concordance_index

Expand Down Expand Up @@ -93,6 +93,10 @@ def rossi():
return load_rossi()


@pytest.fixture
def regression_dataset():
return load_regression_dataset()


class TestBaseFitter():

Expand Down Expand Up @@ -683,16 +687,14 @@ def test_fit_method(self, data_nus):
assert np.abs(cf.hazards_.ix[0][0] - -0.0335) < 0.0001

def test_using_dataframes_vs_numpy_arrays(self, data_pred2):
# First without normalization
cf = CoxPHFitter()
cf.fit(data_pred2, 't', 'E')

X = data_pred2[cf.data.columns]
hazards = cf.predict_partial_hazard(X)

# A Numpy array should return the same result
hazards_n = cf.predict_partial_hazard(np.array(X))
assert np.all(hazards == hazards_n)
assert_frame_equal(
cf.predict_partial_hazard(np.array(X)),
cf.predict_partial_hazard(X)
)

def test_data_normalization(self, data_pred2):
# During fit, CoxPH copies the training data and normalizes it.
Expand Down Expand Up @@ -925,12 +927,79 @@ def test_hazard_works_as_intended_with_strata_against_R_output(self, rossi):
npt.assert_almost_equal(cp.baseline_cumulative_hazard_[(0, 0, 0, 0)].ix[[14, 35, 37, 43, 52]].values, [0.076600555, 0.169748261, 0.272088807, 0.396562717, 0.396562717], decimal=2)
npt.assert_almost_equal(cp.baseline_cumulative_hazard_[(0, 0, 0, 1)].ix[[27, 43, 48, 52]].values, [0.095499001, 0.204196905, 0.338393113, 0.338393113], decimal=2)

def test_baseline_survival_is_the_same_indp_of_location(self, regression_dataset):
df = regression_dataset.copy()
cp1 = CoxPHFitter()
cp1.fit(df, event_col='E', duration_col='T')

df_demeaned = regression_dataset.copy()
df_demeaned[['var1', 'var2', 'var3']] = df_demeaned[['var1', 'var2', 'var3']] - df_demeaned[['var1', 'var2', 'var3']].mean()
cp2 = CoxPHFitter()
cp2.fit(df_demeaned, event_col='E', duration_col='T')
assert_frame_equal(cp2.baseline_survival_, cp1.baseline_survival_)

def test_baseline_cumulative_hazard_is_the_same_indp_of_location(self, regression_dataset):
df = regression_dataset.copy()
cp1 = CoxPHFitter()
cp1.fit(df, event_col='E', duration_col='T')

df_demeaned = regression_dataset.copy()
df_demeaned[['var1', 'var2', 'var3']] = df_demeaned[['var1', 'var2', 'var3']] - df_demeaned[['var1', 'var2', 'var3']].mean()
cp2 = CoxPHFitter()
cp2.fit(df_demeaned, event_col='E', duration_col='T')
assert_frame_equal(cp2.baseline_cumulative_hazard_, cp1.baseline_cumulative_hazard_)

def test_survival_prediction_is_the_same_indp_of_location(self, regression_dataset):
df = regression_dataset.copy()

df_demeaned = regression_dataset.copy()
mean = df_demeaned[['var1', 'var2', 'var3']].mean()
df_demeaned[['var1', 'var2', 'var3']] = df_demeaned[['var1', 'var2', 'var3']] - mean

cp1 = CoxPHFitter()
cp1.fit(df, event_col='E', duration_col='T')

cp2 = CoxPHFitter()
cp2.fit(df_demeaned, event_col='E', duration_col='T')

assert_frame_equal(
cp1.predict_survival_function(df.ix[[0]][['var1', 'var2', 'var3']]),
cp2.predict_survival_function(df_demeaned.ix[[0]][['var1', 'var2', 'var3']])
)

def test_baseline_survival_is_the_same_indp_of_scale(self, regression_dataset):
df = regression_dataset.copy()
cp1 = CoxPHFitter()
cp1.fit(df, event_col='E', duration_col='T')

df_descaled = regression_dataset.copy()
df_descaled[['var1', 'var2', 'var3']] = df_descaled[['var1', 'var2', 'var3']] / df_descaled[['var1', 'var2', 'var3']].std()
cp2 = CoxPHFitter()
cp2.fit(df_descaled, event_col='E', duration_col='T')
assert_frame_equal(cp2.baseline_survival_, cp1.baseline_survival_)

def test_survival_prediction_is_the_same_indp_of_scale(self, regression_dataset):
df = regression_dataset.copy()

df_scaled = regression_dataset.copy()
df_scaled[['var1', 'var2', 'var3']] = df_scaled[['var1', 'var2', 'var3']] * 10.0

cp1 = CoxPHFitter()
cp1.fit(df, event_col='E', duration_col='T')

cp2 = CoxPHFitter()
cp2.fit(df_scaled, event_col='E', duration_col='T')

assert_frame_equal(
cp1.predict_survival_function(df.ix[[0]][['var1', 'var2', 'var3']]),
cp2.predict_survival_function(df_scaled.ix[[0]][['var1', 'var2', 'var3']])
)

def test_predict_log_hazard_relative_to_mean(self, rossi):
cox = CoxPHFitter()
cox.fit(rossi, 'week', 'arrest')
log_relative_hazards = cox.predict_log_hazard_relative_to_mean(rossi)
means = rossi.mean(0).to_frame().T
assert cox.predict_partial_hazard(means).values[0][0] != 1.0
assert_frame_equal(log_relative_hazards, np.log(cox.predict_partial_hazard(rossi) / cox.predict_partial_hazard(means).squeeze()))

def test_warning_is_raised_if_df_has_a_near_constant_column(self, rossi):
Expand Down

0 comments on commit 8e3272a

Please sign in to comment.