
Commit

Merge 33f680d into 8136156
CamDavidsonPilon committed Dec 29, 2016
2 parents 8136156 + 33f680d commit cca18ac
Showing 3 changed files with 54 additions and 17 deletions.
47 changes: 34 additions & 13 deletions lifelines/fitters/coxph_fitter.py
@@ -322,7 +322,7 @@ def fit(self, df, duration_col, event_col=None,
self.durations = T
self.event_observed = E

self.baseline_hazard_ = self._compute_baseline_hazard()
self.baseline_hazard_ = self._compute_baseline_hazards(df, T, E)
self.baseline_cumulative_hazard_ = self.baseline_hazard_.cumsum()
self.baseline_survival_ = exp(-self.baseline_cumulative_hazard_)
return self
@@ -447,13 +447,21 @@ def predict_cumulative_hazard(self, X):
X: a (n,d) covariate numpy array or DataFrame. If a DataFrame, columns
can be in any order. If a numpy array, columns must be in the
same order as the training data.
Returns the cumulative hazard for the individuals.
"""
v = self.predict_partial_hazard(X)
s_0 = self.baseline_survival_
col = _get_index(X)
return pd.DataFrame(-np.dot(np.log(s_0), v.T), index=self.baseline_survival_.index, columns=col)
if self.strata:
cumulative_hazard_ = pd.DataFrame()
for stratum, stratified_X in X.groupby(self.strata):
s_0 = self.baseline_survival_[[stratum]]
col = _get_index(stratified_X)
v = self.predict_partial_hazard(stratified_X)
cumulative_hazard_ = cumulative_hazard_.merge(pd.DataFrame(-np.dot(np.log(s_0), v.T), index=s_0.index, columns=col), how='outer', right_index=True, left_index=True)
else:
v = self.predict_partial_hazard(X)
s_0 = self.baseline_survival_
col = _get_index(X)
cumulative_hazard_ = pd.DataFrame(-np.dot(np.log(s_0), v.T), columns=col, index=s_0.index)

return cumulative_hazard_

def predict_survival_function(self, X):
"""
@@ -499,23 +507,36 @@ def predict_expectation(self, X):
v = self.predict_survival_function(X)[index]
return pd.DataFrame(trapz(v.values.T, v.index), index=index)

def _compute_baseline_hazard(self):
def _compute_baseline_hazard(self, data, durations, event_observed, name):
# http://courses.nus.edu.sg/course/stacar/internet/st3242/handouts/notes3.pdf
ind_hazards = self.predict_partial_hazard(self.data).values
ind_hazards = self.predict_partial_hazard(data).values

event_table = survival_table_from_events(self.durations.values,
self.event_observed.values)
event_table = survival_table_from_events(durations, event_observed)

baseline_hazard_ = pd.DataFrame(np.zeros((event_table.shape[0], 1)),
index=event_table.index,
columns=['baseline hazard'])
columns=[name])

for t, s in event_table.iterrows():
less = np.array(self.durations >= t)
less = np.array(durations >= t)
if ind_hazards[less].sum() == 0:
v = 0
else:
v = (s['observed'] / ind_hazards[less].sum())
baseline_hazard_.ix[t] = v

return baseline_hazard_

def _compute_baseline_hazards(self, df, T, E):
if self.strata:
baseline_hazards_ = pd.DataFrame(index=self.durations.unique())
for stratum in df.index.unique():
baseline_hazards_ = baseline_hazards_.merge(
self._compute_baseline_hazard(data=df.ix[[stratum]], durations=T.ix[[stratum]], event_observed=E.ix[[stratum]], name=stratum),
left_index=True,
right_index=True,
how='left')
return baseline_hazards_.fillna(0)

else:
return self._compute_baseline_hazard(data=df, durations=T, event_observed=E, name='baseline hazard')
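The _compute_baseline_hazard loop above is a Breslow-type estimator: at each event time, the baseline hazard increment is the number of observed events divided by the sum of the predicted partial hazards over everyone still at risk. A minimal numpy sketch of that calculation on made-up data (not the lifelines API):

import numpy as np
import pandas as pd

durations = np.array([3., 5., 5., 8.])               # follow-up times
observed = np.array([1, 1, 0, 1])                    # 1 = event, 0 = censored
partial_hazards = np.array([0.8, 1.2, 0.5, 2.0])     # exp(beta'x) per subject

increments = {}
for t in np.unique(durations[observed == 1]):
    at_risk = durations >= t                          # risk set at time t
    d_t = ((durations == t) & (observed == 1)).sum()  # events observed at t
    increments[t] = d_t / partial_hazards[at_risk].sum()

baseline_hazard = pd.Series(increments).sort_index()
baseline_cumulative_hazard = baseline_hazard.cumsum()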
16 changes: 15 additions & 1 deletion tests/test_estimation.py
@@ -916,9 +916,10 @@ def test_strata_works_if_only_a_single_element_is_in_the_strata(self):

def test_strata_against_r_output(self, rossi):
"""
> library(survival)
> ross = read.csv('rossi.csv')
> r = coxph(formula = Surv(week, arrest) ~ fin + age + strata(race,
paro, mar, wexp) + prio, data = rossi)
> r
> r$loglik
"""

@@ -928,6 +929,19 @@ def test_strata_against_r_output(self, rossi):
npt.assert_almost_equal(cp.summary['coef'].values, [-0.335, -0.059, 0.100], decimal=3)
assert abs(cp._log_likelihood - -436.9339) / 436.9339 < 0.01

def test_hazard_works_as_intended_with_strata_against_R_output(self, rossi):
"""
> library(survival)
> ross = read.csv('rossi.csv')
> r = coxph(formula = Surv(week, arrest) ~ fin + age + strata(race,
paro, mar, wexp) + prio, data = rossi)
> basehaz(r, centered=FALSE)
"""
cp = CoxPHFitter(normalize=False)
cp.fit(rossi, 'week', 'arrest', strata=['race', 'paro', 'mar', 'wexp'])
npt.assert_almost_equal(cp.baseline_cumulative_hazard_[(0, 0, 0, 0)].ix[[14, 35, 37, 43, 52]].values, [0.28665890, 0.63524149, 1.01822603, 1.48403930, 1.48403930], decimal=2)
npt.assert_almost_equal(cp.baseline_cumulative_hazard_[(0, 0, 0, 1)].ix[[27, 43, 48, 52]].values, [0.35738173, 0.76415714, 1.26635373, 1.26635373], decimal=2)

def test_predict_log_hazard_relative_to_mean_with_normalization(self, rossi):
cox = CoxPHFitter(normalize=True)
cox.fit(rossi, 'week', 'arrest')
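The new strata test above exercises the same path a user would take: fit with strata, then read the per-stratum baseline cumulative hazard. A hedged usage sketch, assuming the rossi data can be loaded with lifelines' load_rossi helper (the tests load the same data through a fixture):

from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi   # assumed loader; tests use a fixture

rossi = load_rossi()
cp = CoxPHFitter(normalize=False)
cp.fit(rossi, 'week', 'arrest', strata=['race', 'paro', 'mar', 'wexp'])

# One column per stratum tuple (race, paro, mar, wexp), indexed by event time.
print(cp.baseline_cumulative_hazard_[(0, 0, 0, 0)].head())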
8 changes: 5 additions & 3 deletions tests/utils/test_utils.py
@@ -207,13 +207,15 @@ def test_group_survival_table_from_events_on_waltons_data():
assert all(removed.index == observed.index)
assert all(removed.index == censored.index)


def test_survival_table_from_events_at_risk_column():
df = load_waltons()
# from R
expected = [163.0, 162.0, 160.0, 157.0, 154.0, 152.0, 151.0, 148.0, 144.0, 139.0, 134.0, 133.0, 130.0, 128.0, 126.0, 119.0, 118.0,
expected = [163.0, 162.0, 160.0, 157.0, 154.0, 152.0, 151.0, 148.0, 144.0, 139.0, 134.0, 133.0, 130.0, 128.0, 126.0, 119.0, 118.0,
108.0, 107.0, 99.0, 96.0, 89.0, 87.0, 69.0, 65.0, 49.0, 38.0, 36.0, 27.0, 24.0, 14.0, 1.0]
df = utils.survival_table_from_events(df['T'], df['E'])
assert list(df['at_risk'][1:]) == expected # skip the first event as that is the birth time, 0.
assert list(df['at_risk'][1:]) == expected # skip the first event as that is the birth time, 0.


def test_survival_table_to_events_casts_to_float():
T, C = np.array([1, 2, 3, 4, 4, 5]), np.array([True, False, True, True, True, True])
@@ -306,7 +308,7 @@ def test_both_concordance_index_function_deal_with_ties_the_same_way():
actual_times = np.array([1, 1, 2])
predicted_times = np.array([1, 2, 3])
obs = np.ones(3)
assert fast_cindex(actual_times, predicted_times, obs) == slow_cindex(actual_times, predicted_times, obs) == 1.0
assert fast_cindex(actual_times, predicted_times, obs) == slow_cindex(actual_times, predicted_times, obs) == 1.0


def test_survival_table_from_events_with_non_negative_T_and_no_lagged_births():
