Skip to content

Commit

Permalink
adding .plot method to coxphfitter
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Jun 11, 2017
1 parent 8e3272a commit f8858e3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
4 changes: 2 additions & 2 deletions lifelines/fitters/aalen_additive_fitter.py
Expand Up @@ -32,15 +32,15 @@ class AalenAdditiveFitter(BaseFitter):
For example, this shrinks the absolute value of c_{i,t}. Recommended, even if a small value.
smoothing_penalizer: Attach a L2 penalizer to difference between adjacent (over time) coefficents. For
example, this shrinks the absolute value of c_{i,t} - c_{i,t+1}.
nn_cumulative_hazard: If True, forces the negative values in cumulative hazards to be 0 instead. Default True.
nn_cumulative_hazard: If True, forces the negative values in predicted cumulative hazards to be 0 instead. Default True.
"""

def __init__(self, fit_intercept=True, alpha=0.95, coef_penalizer=0.5, smoothing_penalizer=0., nn_cumulative_hazard=True):
if not (0 < alpha <= 1.):
raise ValueError('alpha parameter must be between 0 and 1.')
if coef_penalizer < 0 or smoothing_penalizer < 0:
raise ValueError("penalizer parameter must be >= 0.")
raise ValueError("penalizer parameters must be >= 0.")

self.fit_intercept = fit_intercept
self.alpha = alpha
Expand Down
32 changes: 31 additions & 1 deletion lifelines/fitters/coxph_fitter.py
Expand Up @@ -381,7 +381,7 @@ def print_summary(self):
print('n={}, number of events={}'.format(self.data.shape[0],
np.where(self.event_observed)[0].shape[0]),
end='\n\n')
print(df.to_string(float_format=lambda f: '{:.3e}'.format(f)))
print(df.to_string(float_format=lambda f: '{:4.4f}'.format(f)))
# Significance code explanation
print('---')
print("Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ",
Expand Down Expand Up @@ -540,3 +540,33 @@ def _compute_baseline_survival(self):
if self.strata is None:
survival_df.columns = ['baseline survival']
return survival_df

def plot(self, standardized=False):
"""
standardized: standardize each estimated coefficient and confidence interval endpoints by the standard error of the estimate.
"""
from matplotlib import pyplot as plt

ax = plt.figure().add_subplot(111)
yaxis_locations = xrange(len(self.hazards_.columns))

lower_bound = self.confidence_intervals_.loc['lower-bound'].copy()
upper_bound = self.confidence_intervals_.loc['upper-bound'].copy()
hazards = self.hazards_.values[0].copy()

if standardized:
se = self._compute_standard_errors().loc['se']
lower_bound /= se
upper_bound /= se
hazards /= se

order = np.argsort(hazards)
ax.scatter(upper_bound.values[order], yaxis_locations, marker='|', c='k')
ax.scatter(lower_bound.values[order], yaxis_locations, marker='|', c='k')
ax.scatter(hazards[order], yaxis_locations, marker='o', c='k')
ax.hlines(yaxis_locations, lower_bound.values[order], upper_bound.values[order], color='k', lw=1)

plt.yticks(yaxis_locations, self.hazards_.columns[order])
plt.xlabel("standardized coef" if standardized else "coef")
return ax

0 comments on commit f8858e3

Please sign in to comment.