ENH deal with tiny loss improvements in line search (#724)
* ENH deal with tiny loss improvements in line search

* DOC add changelog entry
lorentzenchr committed Nov 8, 2023
1 parent d4b7b5d commit a304d55
Showing 2 changed files with 29 additions and 2 deletions.
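Editorial context (not part of the commit page): glum's IRLS solver backtracks along the step direction, halving the step size until the objective has decreased sufficiently (the Armijo condition). Near the optimum, the true loss improvement can shrink below floating-point resolution, so the measured difference is dominated by rounding noise and the Armijo test may never pass. This commit adds a fallback: when the observed loss change is within about 16 machine epsilons of zero relative to the objective, the step is accepted if the L1 norm of the gradient decreased. A minimal, self-contained sketch of that idea, using hypothetical names (f, grad_f, w, d) rather than glum's internals:

import numpy as np

def backtracking_line_search(f, grad_f, w, d, beta=0.5, sigma=1e-4, max_iter=30):
    # Armijo backtracking with a tiny-loss fallback (sketch, not glum's API).
    eps = 16 * np.finfo(np.float64).eps
    f_old = f(w)
    g_old = grad_f(w)
    bound = sigma * (g_old @ d)  # expected decrease per unit step; < 0 for a descent d
    sum_abs_grad_old = np.abs(g_old).sum()
    factor = 1.0
    for _ in range(max_iter):
        w_new = w + factor * d
        loss_improvement = f(w_new) - f_old
        # 1. Armijo / sufficient decrease condition.
        if loss_improvement <= factor * bound:
            return w_new
        # 2. Loss change indistinguishable from rounding noise:
        #    accept if the L1 gradient norm still decreased.
        if abs(loss_improvement) <= abs(f_old) * eps:
            if np.abs(grad_f(w_new)).sum() < sum_abs_grad_old:
                return w_new
        factor *= beta
    raise RuntimeError("line search did not converge")

For instance, with f = lambda w: 0.5 * w @ w and grad_f = lambda w: w, the call backtracking_line_search(f, grad_f, np.ones(2), -np.ones(2)) accepts the full step immediately, since the Armijo condition already holds there.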
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -18,6 +18,7 @@ Unreleased

 - Require Python>=3.9 in line with `NEP 29 <https://numpy.org/neps/nep-0029-deprecation_policy.html#support-table>`_
 - Build and test with Python 3.12 in CI.
+- Added line search stopping criterion for tiny loss improvements based on gradient information.

2.6.0 - 2023-09-05
------------------
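One editorial aside before the solver diff: the bound computed in line_search below is the right-hand side of the Armijo condition adapted to the penalized objective. In the notation of the code's own comments (score is the negative gradient of the smooth part, already including the L2 penalty; P1 holds the elementwise L1 penalty weights), the acceptance test obj_val_wd - state.obj_val <= factor * bound spells out as

    F(w + lambda * d) - F(w) <= lambda * sigma * ( -(score @ d) + ||P1 * (w + d)||_1 - ||P1 * w||_1 )

with lambda = beta^k for k = 0, 1, ... and (beta, sigma) = (0.5, 0.0001). The L1 term enters through its direct difference rather than a derivative because it is not differentiable at zero. This is an editorial reading of the code, not a statement from the commit itself.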
30 changes: 28 additions & 2 deletions src/glum/_solvers.py
@@ -15,7 +15,7 @@
     enet_coordinate_descent_gram,
     identify_active_rows,
 )
-from ._distribution import ExponentialDispersionModel
+from ._distribution import ExponentialDispersionModel, get_one_over_variance
 from ._link import Link
 from ._util import _safe_lin_pred, _safe_sandwich_dot

@@ -758,6 +758,7 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
"""
# line search parameters
(beta, sigma) = (0.5, 0.0001)
eps = 16 * np.finfo(state.obj_val.dtype).eps

# line search by sequence beta^k, k=0, 1, ..
# F(w + lambda d) - F(w) <= lambda * bound
@@ -771,6 +772,9 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
     # Note: the L2 penalty term is included in the score.
     bound = sigma * (-(state.score @ d) + P1wd_1 - P1w_1)
 
+    # np.sum(np.abs(state.score))
+    sum_abs_grad_old = -1  # defer calculation
+
     # The step direction in row space. We'll be multiplying this by varying
     # step sizes during the line search. Factoring this matrix-vector product
     # out of the inner loop improves performance a lot!
@@ -785,8 +789,30 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
         eta_wd, mu_wd, obj_val_wd, coef_wd_P2 = _update_predictions(
             state, data, coef_wd, X_dot_d, factor=factor
         )
-        if (mu_wd.max() < 1e43) and (obj_val_wd - state.obj_val <= factor * bound):
+        # 1. Check Armijo / sufficient decrease condition.
+        loss_improvement = obj_val_wd - state.obj_val
+        if mu_wd.max() < 1e43 and loss_improvement <= factor * bound:
             break
+        # 2. Deal with relative loss differences around machine precision.
+        tiny_loss = np.abs(state.obj_val * eps)
+        if np.abs(loss_improvement) <= tiny_loss:
+            if sum_abs_grad_old < 0:
+                sum_abs_grad_old = linalg.norm(state.score, ord=1)
+            # 2.1 Check sum of absolute gradients as alternative condition.
+            # Therefore, we need the recent gradient, see update_quadratic.
+            sigma_inv = get_one_over_variance(
+                data.family, data.link, mu_wd, eta_wd, 1.0, data.sample_weight
+            )
+            d1 = data.link.inverse_derivative(eta_wd)  # = h'(eta)
+            d1_sigma_inv = d1 * sigma_inv
+            gradient_rows = d1_sigma_inv * (data.y - mu_wd)
+            grad = gradient_rows @ data.X
+            if data.fit_intercept:
+                grad = np.concatenate(([gradient_rows.sum()], grad))
+            grad -= coef_wd_P2
+            sum_abs_grad = linalg.norm(grad, ord=1)
+            if sum_abs_grad < sum_abs_grad_old:
+                break
         factor *= beta
     else:
         warnings.warn(
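A closing editorial note on the fallback branch above: for float64, 16 * np.finfo(...).eps is about 3.6e-15, so with an objective of magnitude one, loss changes below roughly 3.6e-15 are treated as rounding noise. The sum_abs_grad_old = -1 sentinel defers computing the L1 norm of the current score until the fallback is first reached, since most line searches never get there. Inside the branch, the gradient at the trial point is rebuilt as X^T [h'(eta) * (y - mu) / Var(mu)], with an intercept row prepended and the L2 penalty term coef_wd_P2 subtracted. As a concrete illustration (synthetic data and a hypothetical helper, not glum's API): for a Poisson GLM with log link, h'(eta) = mu and 1/Var(mu) = 1/mu cancel, leaving per-row factors of (y - mu):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = rng.poisson(np.exp(X @ np.array([0.5, -0.3, 0.2])))

def sum_abs_grad(coef):
    # L1 norm of the unpenalized Poisson log-link gradient, X^T (mu - y).
    mu = np.exp(X @ coef)
    return np.abs((mu - y) @ X).sum()

coef = np.array([0.45, -0.25, 0.15])  # current iterate
d = np.array([0.05, -0.05, 0.05])     # trial step direction
# The fallback accepts a step whose loss change is at noise level only if
# the L1 gradient norm went down relative to the current iterate.
print(sum_abs_grad(coef + 1e-3 * d) < sum_abs_grad(coef))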
