Skip to content

Commit

Permalink
finish example
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed May 16, 2019
1 parent 5d66f3d commit 220dcb9
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions lifelines/fitters/piecewise_exponential_regression_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,25 +203,37 @@ def fit(
Examples
--------
TODO
>>> from lifelines import WeibullAFTFitter
>>>
>>> df = pd.DataFrame({
>>> 'T': [5, 3, 9, 8, 7, 4, 4, 3, 2, 5, 6, 7],
>>> 'E': [1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0],
>>> 'var': [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2],
>>> 'age': [4, 3, 9, 8, 7, 4, 4, 3, 2, 5, 6, 7],
>>> })
>>>
>>> aft = WeibullAFTFitter()
>>> aft.fit(df, 'T', 'E')
>>> aft.print_summary()
>>> aft.predict_median(df)
>>>
>>> aft = WeibullAFTFitter()
>>> aft.fit(df, 'T', 'E', ancillary_df=df)
>>> aft.print_summary()
>>> aft.predict_median(df)
>>> N, d = 80000, 2
>>> # some numbers take from http://statwonk.com/parametric-survival.html
>>> breakpoints = (1, 31, 34, 62, 65)
>>> betas = np.array(
>>> [
>>> [1.0, -0.2, np.log(15)],
>>> [5.0, -0.4, np.log(333)],
>>> [9.0, -0.6, np.log(18)],
>>> [5.0, -0.8, np.log(500)],
>>> [2.0, -1.0, np.log(20)],
>>> [1.0, -1.2, np.log(500)],
>>> ]
>>> )
>>> X = 0.1 * np.random.exponential(size=(N, d))
>>> X = np.c_[X, np.ones(N)]
>>> T = np.empty(N)
>>> for i in range(N):
>>> lambdas = np.exp(-betas.dot(X[i, :]))
>>> T[i] = piecewise_exponential_survival_data(1, breakpoints, lambdas)[0]
>>> T_censor = np.minimum(
>>> T.mean() * np.random.exponential(size=N), 110
>>> ) # 110 is the end of observation, eg. current time.
>>> df = pd.DataFrame(X[:, :-1], columns=["var1", "var2"])
>>> df["T"] = np.round(np.maximum(np.minimum(T, T_censor), 0.1), 1)
>>> df["E"] = T <= T_censor
>>> pew = PiecewiseExponentialRegressionFitter(breakpoints=breakpoints, penalizer=0.0001).fit(df, "T", "E")
>>> pew.print_summary()
>>> pew.plot()
"""
if duration_col is None:
Expand Down Expand Up @@ -279,7 +291,7 @@ def fit(
assert "_intercept" not in df
df["_intercept"] = 1.0

self._LOOKUP_SLICE = self._create_slicer(len(df.columns)) # TODO
self._LOOKUP_SLICE = self._create_slicer(len(df.columns))

_norm_std = df.std(0)
self._norm_mean = df.mean(0)
Expand Down

0 comments on commit 220dcb9

Please sign in to comment.