Skip to content

Commit

Permalink
example + some improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Mar 27, 2019
1 parent 3d5a7b2 commit 117942b
Show file tree
Hide file tree
Showing 3 changed files with 692 additions and 78 deletions.
653 changes: 653 additions & 0 deletions examples/Piecewise exponential regression demo.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions lifelines/fitters/piecewise_exponential_fitter.py
Expand Up @@ -77,6 +77,7 @@ def __init__(self, breakpoints, *args, **kwargs):
raise ValueError("Do not add inf to the breakpoints.")

if breakpoints and breakpoints[0] <= 0:
raise ValueError("Starting breakpoint must be greater than 0.")

self.breakpoints = np.append(breakpoints, [np.inf])
n_breakpoints = len(self.breakpoints)
Expand Down
116 changes: 38 additions & 78 deletions lifelines/fitters/piecewise_exponential_regression_fitter.py
Expand Up @@ -52,48 +52,50 @@ class PiecewiseExponentialRegressionFitter(BaseFitter):
...
\end{cases}
You specify the breakpoints, :math:`\tau_i`, and *lifelines* will find the
optimal values for the parameters.
and $\lambda_i(x) = \exp(\mathbf{\beta}_i x^T), \;\; \mathbf{\beta}_i = (\beta_{i,1}, \beta_{i,2}, ...)$. That is, each period has a hazard rate, $\lambda_i$, that is the exponential of a linear model. The parameters of each linear model are unique to that period - different periods have different parameters (later we will generalize this).
Why do I want a model like this? Well, it offers lots of flexibility (at the cost of efficiency though), but importantly I can see:
1. Influence of variables over time.
2. Looking at important variables at specific "drops" (or regime changes). For example, what variables cause the large drop at the start? What variables prevent death at the second billing?
3. Predictive power: since we model the hazard more accurately (we hope) than a simpler parametric form, we have better estimates of a subject's survival curve.
After calling the `.fit` method, you have access to properties like: ``params_``
A summary of the fit is available with the method ``print_summary()``
Parameters
-----------
breakpoints: list
a list of times when a new exponential model is constructed.
alpha: float, optional (default=0.05)
the level in the confidence intervals.
fit_intercept: boolean, optional (default=True)
Allow lifelines to add an intercept column of 1s to df, and ancillary_df if applicable.
penalizer: float, optional (default=0.0)
the penalizer coefficient to the size of the coefficients. See `l1_ratio`. Must be equal to or greater than 0.
Attributes
----------
cumulative_hazard_ : DataFrame
The estimated cumulative hazard (with custom timeline if provided)
confidence_interval_cumulative_hazard_ : DataFrame
The lower and upper confidence intervals for the cumulative hazard
hazard_ : DataFrame
The estimated hazard (with custom timeline if provided)
confidence_interval_hazard_ : DataFrame
The lower and upper confidence intervals for the hazard
survival_function_ : DataFrame
The estimated survival function (with custom timeline if provided)
confidence_interval_survival_function_ : DataFrame
The lower and upper confidence intervals for the survival function
cumulative_density_ : DataFrame
The estimated cumulative density function (with custom timeline if provided)
confidence_interval_cumulative_density_ : DataFrame
The lower and upper confidence intervals for the cumulative density
params_ : DataFrame
The estimated coefficients
confidence_intervals_ : DataFrame
The lower and upper confidence intervals for the coefficients
durations: Series
The durations provided
event_observed: Series
The event_observed variable provided
weights: Series
The weights provided
variance_matrix_ : numpy array
The variance matrix of the coefficients
median_: float
The median time to event
lambda_i_: float
The fitted parameter in the model, for i = 0, 1 ... n-1 breakpoints
durations: array
The durations provided
event_observed: array
The event_observed variable provided
timeline: array
The time line to use for plotting and indexing
entry: array or None
The entry array provided, or None
breakpoints: array
The provided breakpoints
standard_errors_: Series
the standard errors of the estimates
score_: float
the concordance index of the model.
"""

def __init__(self, breakpoints, alpha=0.05, penalizer=0.0, fit_intercept=True, *args, **kwargs):
Expand All @@ -117,8 +119,7 @@ def __init__(self, breakpoints, alpha=0.05, penalizer=0.0, fit_intercept=True, *
def _cumulative_hazard(self, params, T, X):
n = T.shape[0]
T = T.reshape((n, 1))
bp = self.breakpoints
M = np.minimum(np.tile(bp, (n, 1)), T)
M = np.minimum(np.tile(self.breakpoints, (n, 1)), T)
M = np.hstack([M[:, tuple([0])], np.diff(M, axis=1)])
lambdas_ = np.array(
[np.exp(-np.dot(X, params[self._LOOKUP_SLICE["lambda_%d_" % i]])) for i in range(self.n_breakpoints)]
Expand Down Expand Up @@ -373,7 +374,7 @@ def _compute_variance_matrix(self):
unit_scaled_variance_matrix_ = np.linalg.pinv(self._hessian_)
warning_text = dedent(
"""\
The hessian was not invertable. We will instead approximate it using the psuedo-inverse.
The Hessian was not invertible. We will instead approximate it using the pseudo-inverse.
It's advisable to not trust the variances reported, and to be suspicious of the
fitted parameters too. Perform plots of the cumulative hazard to help understand
Expand Down Expand Up @@ -742,53 +743,12 @@ def plot(self, columns=None, parameter=None, **errorbar_kwargs):

return ax

def plot_covariate_groups(self, covariates, values, plot_baseline=True, **kwargs):
"""
Produces a visual representation comparing the baseline survival curve of the model versus
what happens when a covariate(s) is varied over values in a group. This is useful to compare
subjects' survival as we vary covariate(s), all else being held equal. The baseline survival
curve is equal to the predicted survival curve at all average values in the original dataset.
Parameters
----------
covariates: string or list
a string (or list of strings) of the covariate in the original dataset that we wish to vary.
values: 1d or 2d iterable
an iterable of the values we wish the covariate to take on.
plot_baseline: bool
also display the baseline survival, defined as the survival at the mean of the original dataset.
kwargs:
pass in additional plotting commands
Returns
-------
ax: matplotlib axis, or list of axis'
the matplotlib axis that be edited.
Examples
---------
>>> from lifelines import datasets, WeibullAFTFitter
>>> rossi = datasets.load_rossi()
>>> wf = WeibullAFTFitter().fit(rossi, 'week', 'arrest')
>>> wf.plot_covariate_groups('prio', values=np.arange(0, 15), cmap='coolwarm')
>>> # multiple variables at once
>>> wf.plot_covariate_groups(['prio', 'paro'], values=[[0, 0], [5, 0], [10, 0], [0, 1], [5, 1], [10, 1]], cmap='coolwarm')
>>> # if you have categorical variables, you can simply things:
>>> wf.plot_covariate_groups(['dummy1', 'dummy2', 'dummy3'], values=np.eye(3))
"""
raise NotImplementedError()

def _prep_inputs_for_prediction_and_return_parameters(self, X):
X = X.copy()

if isinstance(X, pd.DataFrame):
X = X[self.params_["lambda_0_"].index]
if self.fit_intercept:
X["_intercept"] = 1.0
X = X[self.params_["lambda_0_"].index]

return np.array([np.exp(np.dot(X, self.params_["lambda_%d_" % i])) for i in range(self.n_breakpoints)])

0 comments on commit 117942b

Please sign in to comment.