Skip to content

Commit

Permalink
Merge pull request #275 from CamDavidsonPilon/baseline-survival-perf
Browse files (browse the repository at this point in the history)
improve baseline survival performance
  • Loading branch information
CamDavidsonPilon committed Dec 31, 2016
2 parents 29cc72e + fe133d6 commit 1e5745c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 16 deletions.
22 changes: 7 additions & 15 deletions lifelines/fitters/coxph_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,23 +509,15 @@ def predict_expectation(self, X):

     def _compute_baseline_hazard(self, data, durations, event_observed, name):
         # http://courses.nus.edu.sg/course/stacar/internet/st3242/handouts/notes3.pdf
-        ind_hazards = self.predict_partial_hazard(data).values
+        ind_hazards = self.predict_partial_hazard(data)
+        ind_hazards['event_at'] = durations
+        ind_hazards_summed_over_durations = ind_hazards.groupby('event_at')[0].sum().sort_index(ascending=False).cumsum()
+        ind_hazards_summed_over_durations.name = 'hazards'
 
         event_table = survival_table_from_events(durations, event_observed)
-
-        baseline_hazard_ = pd.DataFrame(np.zeros((event_table.shape[0], 1)),
-                                        index=event_table.index,
-                                        columns=[name])
-
-        for t, s in event_table.iterrows():
-            less = np.array(durations >= t)
-            if ind_hazards[less].sum() == 0:
-                v = 0
-            else:
-                v = (s['observed'] / ind_hazards[less].sum())
-            baseline_hazard_.ix[t] = v
-
-        return baseline_hazard_
+        event_table = event_table.join(ind_hazards_summed_over_durations)
+        baseline_hazard = pd.DataFrame(event_table['observed'] / event_table['hazards'], columns=[name]).fillna(0)
+        return baseline_hazard

def _compute_baseline_hazards(self, df, T, E):
if self.strata:
Expand Down
1 change: 0 additions & 1 deletion lifelines/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def survival_table_from_events(death_times, event_observed, birth_times=None,
     births = pd.DataFrame(birth_times, columns=['event_at'])
     births[entrance] = 1
     births_table = births.groupby('event_at').sum()
-
     event_table = death_table.join(births_table, how='outer', sort=True).fillna(0)  # http://wesmckinney.com/blog/?p=414
     event_table[at_risk] = event_table[entrance].cumsum() - event_table[removed].cumsum().shift(1).fillna(0)
     return event_table.astype(int)
Expand Down

0 comments on commit 1e5745c

Please sign in to comment.