diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py
index 6205c36..141d5e3 100644
--- a/src/causalprog/solvers/sgd.py
+++ b/src/causalprog/solvers/sgd.py
@@ -8,6 +8,7 @@
 import numpy.typing as npt
 import optax
 
+from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq
 
 
@@ -22,7 +23,7 @@ def stochastic_gradient_descent(
     maxiter: int = 1000,
     optimiser: optax.GradientTransformationExtraArgs | None = None,
     tolerance: float = 1.0e-8,
-) -> tuple[PyTree, npt.ArrayLike, npt.ArrayLike, int]:
+) -> SolverResult:
     """
     Minimise a function of one argument using Stochastic Gradient Descent (SGD).
 
@@ -87,20 +88,33 @@ def objective(x: npt.ArrayLike) -> npt.ArrayLike:
     def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         return convergence_criteria(x, dx) < tolerance
 
-    gradient = jax.grad(objective)
-    opt_state = optimiser.init(initial_guess)
+    converged = False
+    opt_state = optimiser.init(initial_guess)
     current_params = deepcopy(initial_guess)
-    gradient_value = gradient(current_params)
-    for i in range(maxiter):
-        updates, opt_state = optimiser.update(gradient_value, opt_state)
-        current_params = optax.apply_updates(current_params, updates)
+    gradient = jax.grad(objective)
+    for _ in range(maxiter + 1):
         objective_value = objective(current_params)
         gradient_value = gradient(current_params)
-        if is_converged(objective_value, gradient_value):
-            return current_params, objective_value, gradient_value, i + 1
+        if converged := is_converged(objective_value, gradient_value):
+            break
+
+        updates, opt_state = optimiser.update(gradient_value, opt_state)
+        current_params = optax.apply_updates(current_params, updates)
 
-    msg = f"Did not converge after {i + 1} iterations."
-    raise RuntimeError(msg)
+    iters_used = _
+    reason_msg = (
+        f"Did not converge after {iters_used} iterations" if not converged else ""
+    )
+
+    return SolverResult(
+        fn_args=current_params,
+        grad_val=gradient_value,
+        iters=iters_used,
+        maxiter=maxiter,
+        obj_val=objective_value,
+        reason=reason_msg,
+        successful=converged,
+    )
diff --git a/src/causalprog/solvers/solver_result.py b/src/causalprog/solvers/solver_result.py
new file mode 100644
index 0000000..eb09457
--- /dev/null
+++ b/src/causalprog/solvers/solver_result.py
@@ -0,0 +1,38 @@
+"""Container class for outputs from solver methods."""
+
+from dataclasses import dataclass
+
+import numpy.typing as npt
+
+from causalprog.utils.norms import PyTree
+
+
+@dataclass(frozen=True)
+class SolverResult:
+    """
+    Container class for outputs from solver methods.
+
+    Instances of this class provide a container for useful information that
+    comes out of running one of the solver methods on a causal problem.
+
+    Attributes:
+        fn_args: Argument to the objective function at the final iteration (the
+            solution, if `successful` is `True`).
+        grad_val: Value of the gradient of the objective function at `fn_args`.
+        iters: Number of iterations performed.
+        maxiter: Maximum number of iterations the solver was permitted to perform.
+        obj_val: Value of the objective function at `fn_args`.
+        reason: Human-readable string explaining success or reasons for solver failure.
+        successful: `True` if the solver converged, in which case `fn_args` is the
+            argument to the objective function at the solution of the problem being
+            solved. `False` otherwise.
+
+    """
+
+    fn_args: PyTree
+    grad_val: PyTree
+    iters: int
+    maxiter: int
+    obj_val: npt.ArrayLike
+    reason: str
+    successful: bool
diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py
index c5bcb03..3d7db90 100644
--- a/tests/test_integration/test_two_normal_example.py
+++ b/tests/test_integration/test_two_normal_example.py
@@ -107,7 +107,7 @@ def objective(x, key):
     }
     l_mult = jnp.atleast_1d(lagrange_mult_sol)
 
-    opt_params, _, _, _ = stochastic_gradient_descent(
+    result = stochastic_gradient_descent(
         objective,
         (params, l_mult),
         convergence_criteria=lambda x, _: jnp.abs(x),
@@ -116,8 +116,10 @@ def objective(x, key):
         maxiter=maxiter,
         tolerance=minimisation_tolerance,
     )
+    assert result.successful, "SGD did not converge."
+
     # Unpack concatenated arguments
-    params, l_mult = opt_params
+    params, l_mult = result.fn_args
 
     # The lagrangian is independent of nu_x, thus it should not have changed value.
     assert jnp.isclose(params["cov2"], nu_x_starting_value), (
diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py
index 7d87696..f602ee6 100644
--- a/tests/test_solvers/test_sgd.py
+++ b/tests/test_solvers/test_sgd.py
@@ -28,7 +28,7 @@
         pytest.param(
             lambda x: (x**2).sum(),
             jnp.atleast_1d(10.0),
-            RuntimeError("Did not converge after 1 iterations"),
+            "Did not converge after 1 iterations",
             {"maxiter": 1},
             id="Reaches iteration limit",
         ),
@@ -67,8 +67,7 @@
 def test_sgd(
     obj_fn: Callable[[PyTree], npt.ArrayLike],
     initial_guess: PyTree,
     kwargs_to_sgd: dict[str, Any],
-    expected: PyTree | Exception,
-    raises_context,
+    expected: PyTree | str,
 ) -> None:
     """Test the SGD method on a (deterministic) problem.
@@ -79,12 +78,12 @@ def test_sgd(
     if not kwargs_to_sgd:
         kwargs_to_sgd = {}
 
-    if isinstance(expected, Exception):
-        with raises_context(expected):
-            stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)
-    else:
-        result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)[0]
+    result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)
+    if isinstance(expected, str):
+        assert not result.successful
+        assert result.reason == expected
+    else:
         assert jax.tree_util.tree_all(
-            jax.tree_util.tree_map(jax.numpy.allclose, result, expected)
+            jax.tree_util.tree_map(jax.numpy.allclose, result.fn_args, expected)
         )
 