Commit

nvm
CamDavidsonPilon committed May 20, 2020
1 parent b6de77e commit 01be2c6
Showing 2 changed files with 12 additions and 9 deletions.
lifelines/fitters/kaplan_meier_fitter.py (10 changes: 7 additions & 3 deletions)
@@ -125,10 +125,10 @@ def fit_interval_censoring(
         label=None,
         alpha=None,
         ci_labels=None,
-        show_progress=False,
         entry=None,
         weights=None,
-        tol=1e-7,
+        tol: float = 1e-5,
+        show_progress: bool = False,
     ) -> "KaplanMeierFitter":
         """
         Fit the model to a interval-censored dataset using non-parametric MLE. This estimator is
@@ -168,6 +168,10 @@ def fit_interval_censoring(
             if providing a weighted dataset. For example, instead
             of providing every subject as a single element of `durations` and `event_observed`, one could
             weigh subject differently.
+        tol: float, optional
+            minimum difference in log likelihood changes for iterative algorithm.
+        show_progress: bool, optional
+            display information during fitting.

         Returns
         -------
@@ -203,7 +207,7 @@ def fit_interval_censoring(

         self._label = coalesce(label, self._label, "NPMLE_estimate")

-        results = npmle(self.lower_bound, self.upper_bound, verbose=show_progress)
+        results = npmle(self.lower_bound, self.upper_bound, verbose=show_progress, tol=tol)
         self.survival_function_ = reconstruct_survival_function(*results, self.timeline, label=self._label).loc[self.timeline]
         self.cumulative_density_ = 1 - self.survival_function_

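For context, a minimal usage sketch of the updated public signature follows. The data values and the show_progress=True call are illustrative only (not from this commit); np.inf in the upper bound is the usual way to mark a right-censored subject.

import numpy as np
from lifelines import KaplanMeierFitter

# Illustrative interval-censored data: each event time is only known to lie
# in [lower_bound, upper_bound]; np.inf marks right-censoring.
lower_bound = np.array([0, 2, 5, 5, 7])
upper_bound = np.array([2, 4, 6, np.inf, 9])

kmf = KaplanMeierFitter()
kmf.fit_interval_censoring(
    lower_bound,
    upper_bound,
    tol=1e-5,            # new default, now passed through to npmle()
    show_progress=True,  # prints the per-iteration delta from check_convergence
)
print(kmf.survival_function_.head())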
lifelines/fitters/npmle.py (11 changes: 5 additions & 6 deletions)
@@ -100,10 +100,10 @@ def create_turnbull_lookup(turnbull_intervals, observation_intervals):
     return {o: list(s) for o, s in turnbull_lookup.items()}


-def check_convergence(p_new, p_old, turnbull_lookup, weights, tol, i, verbose=False):
+def check_convergence(p_new, p_old, tol, i, verbose=False):
     if verbose:
         print("Iteration %d: delta: %.6f" % (i, norm(p_new - p_old)))
-    if log_likelihood(p_new, turnbull_lookup, weights) - log_likelihood(p_old, turnbull_lookup, weights) < tol:
+    if norm(p_new - p_old) < tol:
         return True
     return False

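The stopping rule now measures the distance between successive probability vectors instead of the change in log likelihood. A small standalone numeric sketch of the new criterion, assuming the norm used in npmle.py is numpy's Euclidean norm (as the printed delta suggests):

import numpy as np
from numpy.linalg import norm

tol = 1e-5  # the new default in npmle()

p_old = np.array([0.25, 0.25, 0.25, 0.25])
p_new = np.array([0.250004, 0.249996, 0.25, 0.25])

delta = norm(p_new - p_old)   # ~5.7e-6
print(delta, delta < tol)     # True -> the EM loop would stop at this iteration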
@@ -120,7 +120,7 @@ def probs(o):
     return o / (o + 1)


-def npmle(left, right, tol=1e-10, weights=None, verbose=False, max_iter=1e5):
+def npmle(left, right, tol=1e-5, weights=None, verbose=False, max_iter=1e5):
     """
     left and right are closed intervals.
     TODO: extend this to open-closed intervals.
@@ -150,17 +150,16 @@ def npmle(left, right, tol=1e-5, weights=None, verbose=False, max_iter=1e5):

     while (not converged) and (i < max_iter):
         p_new = E_step_M_step(observation_intervals, p, turnbull_lookup, weights)
-        converged = check_convergence(p_new, p, turnbull_lookup, weights, tol, i, verbose=verbose)
+        converged = check_convergence(p_new, p, tol, i, verbose=verbose)

         # find alpha that maximizes ll using a line search
-        best_alpha, best_p, best_ll = None, None, -np.inf
+        best_p, best_ll = None, -np.inf
         delta = odds(p_new) - odds(p)
         for alpha in np.array([1.0, 1.25, 1.75, 2.5]):
             p_temp = probs(odds(p) + alpha * delta)
             ll_temp = log_likelihood(p_temp, turnbull_lookup, weights)
             if best_ll < ll_temp:
                 best_ll = ll_temp
-                best_alpha = alpha
                 best_p = p_temp

         p = best_p
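The surviving loop is an over-relaxed EM update: the step from p to p_new is extrapolated on the odds scale by a few multipliers, and the candidate with the best likelihood is kept, which is why the unused best_alpha bookkeeping could be dropped. A standalone sketch of that pattern, with a toy quadratic objective standing in for log_likelihood(p, turnbull_lookup, weights) and made-up vectors for p and p_new:

import numpy as np

def odds(p):
    return p / (1 - p)

def probs(o):
    return o / (o + 1)

def toy_log_likelihood(p):
    # stand-in objective: peaks at p = (0.5, 0.3, 0.2)
    return -np.sum((p - np.array([0.5, 0.3, 0.2])) ** 2)

p = np.array([0.40, 0.35, 0.25])      # current iterate
p_new = np.array([0.45, 0.32, 0.23])  # result of one E-step/M-step

best_p, best_ll = None, -np.inf
delta = odds(p_new) - odds(p)
for alpha in np.array([1.0, 1.25, 1.75, 2.5]):
    p_temp = probs(odds(p) + alpha * delta)   # extrapolate on the odds scale
    ll_temp = toy_log_likelihood(p_temp)
    if best_ll < ll_temp:
        best_ll, best_p = ll_temp, p_temp

p = best_p                                    # accept the best candidate
print(p, best_ll)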
