From c082f43de787fc08cd1406ed1b82fcf68be4fbf5 Mon Sep 17 00:00:00 2001
From: willGraham01 <1willgraham@gmail.com>
Date: Wed, 24 Sep 2025 08:26:08 +0100
Subject: [PATCH 1/6] Container class for solver results

---
 src/causalprog/solvers/solver_result.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 src/causalprog/solvers/solver_result.py

diff --git a/src/causalprog/solvers/solver_result.py b/src/causalprog/solvers/solver_result.py
new file mode 100644
index 0000000..1d06cff
--- /dev/null
+++ b/src/causalprog/solvers/solver_result.py
@@ -0,0 +1,25 @@
+"""Container class for outputs from solver methods."""
+
+from dataclasses import dataclass
+
+import numpy.typing as npt
+
+from causalprog.utils.norms import PyTree
+
+
+@dataclass(frozen=True)
+class SolverResult:
+    """
+    Container class for outputs from solver methods.
+
+    Instances of this class provide a container for useful information that
+    comes out of running one of the solver methods on a causal problem.
+    """
+
+    successful: bool
+    reason: str
+    parameters: PyTree
+    obj_val: npt.ArrayLike
+    grad_val: PyTree
+    iters: int
+    maxiters: int

From e3a60e13c5d27131d858d6312d3c2b9003aa504a Mon Sep 17 00:00:00 2001
From: willGraham01 <1willgraham@gmail.com>
Date: Wed, 24 Sep 2025 08:38:34 +0100
Subject: [PATCH 2/6] SGD uses solver_result return value

---
 src/causalprog/solvers/sgd.py           | 34 +++++++++++++++++--------
 src/causalprog/solvers/solver_result.py | 23 +++++++++++++----
 2 files changed, 42 insertions(+), 15 deletions(-)

diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py
index 6205c36..153a9aa 100644
--- a/src/causalprog/solvers/sgd.py
+++ b/src/causalprog/solvers/sgd.py
@@ -8,6 +8,7 @@
 import numpy.typing as npt
 import optax
 
+from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq
 
 
@@ -87,20 +88,33 @@ def objective(x: npt.ArrayLike) -> npt.ArrayLike:
     def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         return convergence_criteria(x, dx) < tolerance
 
-    gradient = jax.grad(objective)
-    opt_state = optimiser.init(initial_guess)
+    converged = False
 
+    opt_state = optimiser.init(initial_guess)
     current_params = deepcopy(initial_guess)
-    gradient_value = gradient(current_params)
-    for i in range(maxiter):
-        updates, opt_state = optimiser.update(gradient_value, opt_state)
-        current_params = optax.apply_updates(current_params, updates)
+    gradient = jax.grad(objective)
+    for _ in range(maxiter + 1):
         objective_value = objective(current_params)
         gradient_value = gradient(current_params)
 
-        if is_converged(objective_value, gradient_value):
-            return current_params, objective_value, gradient_value, i + 1
+        if converged := is_converged(objective_value, gradient_value):
+            break
+
+        updates, opt_state = optimiser.update(gradient_value, opt_state)
+        current_params = optax.apply_updates(current_params, updates)
 
-    msg = f"Did not converge after {i + 1} iterations."
-    raise RuntimeError(msg)
+    iters_used = _
+    reason_msg = (
+        f"Did not converge after {iters_used} iterations." if not converged else ""
+    )
+
+    return SolverResult(
+        arg_result=current_params,
+        grad_val=gradient_value,
+        iters=iters_used,
+        maxiter=maxiter,
+        obj_val=objective_value,
+        reason=reason_msg,
+        successful=converged,
+    )
 
diff --git a/src/causalprog/solvers/solver_result.py b/src/causalprog/solvers/solver_result.py
index 1d06cff..b7ff7ff 100644
--- a/src/causalprog/solvers/solver_result.py
+++ b/src/causalprog/solvers/solver_result.py
@@ -14,12 +14,25 @@ class SolverResult:
 
     Instances of this class provide a container for useful information that
     comes out of running one of the solver methods on a causal problem.
+
+    Attributes:
+        arg_result: Argument to the objective function at the final iteration (the solution,
+            if `successful` is `True`).
+        grad_val: Value of the gradient of the objective function at `arg_result`.
+        iters: Number of iterations performed.
+        maxiter: Maximum number of iterations the solver was permitted to perform.
+        obj_val: Value of the objective function at `arg_result`.
+        reason: Human-readable string explaining success or reasons for solver failure.
+        successful: `True` if solver converged, in which case `arg_result` is the
+            argument to the objective function at the solution of the problem being
+            solved. `False` otherwise.
+
     """
 
-    successful: bool
-    reason: str
-    parameters: PyTree
-    obj_val: npt.ArrayLike
+    arg_result: PyTree
     grad_val: PyTree
     iters: int
-    maxiters: int
+    maxiter: int
+    obj_val: npt.ArrayLike
+    reason: str
+    successful: bool

From 0d9c3898802ff7e3a96605c7c521a502226e7a92 Mon Sep 17 00:00:00 2001
From: willGraham01 <1willgraham@gmail.com>
Date: Wed, 24 Sep 2025 08:39:00 +0100
Subject: [PATCH 3/6] Update typehint return type for SGD

---
 src/causalprog/solvers/sgd.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py
index 153a9aa..a52c97a 100644
--- a/src/causalprog/solvers/sgd.py
+++ b/src/causalprog/solvers/sgd.py
@@ -23,7 +23,7 @@ def stochastic_gradient_descent(
     maxiter: int = 1000,
     optimiser: optax.GradientTransformationExtraArgs | None = None,
     tolerance: float = 1.0e-8,
-) -> tuple[PyTree, npt.ArrayLike, npt.ArrayLike, int]:
+) -> SolverResult:
     """
     Minimise a function of one argument using Stochastic Gradient Descent (SGD).
 

From ce62cea8711a8f59f545c7065f1013af0e4ad042 Mon Sep 17 00:00:00 2001
From: willGraham01 <1willgraham@gmail.com>
Date: Wed, 24 Sep 2025 08:42:19 +0100
Subject: [PATCH 4/6] Fix SGD tests

---
 src/causalprog/solvers/sgd.py  |  2 +-
 tests/test_solvers/test_sgd.py | 17 ++++++++---------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py
index a52c97a..fd9c2d3 100644
--- a/src/causalprog/solvers/sgd.py
+++ b/src/causalprog/solvers/sgd.py
@@ -106,7 +106,7 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
 
     iters_used = _
     reason_msg = (
-        f"Did not converge after {iters_used} iterations." if not converged else ""
+        f"Did not converge after {iters_used} iterations" if not converged else ""
     )
 
     return SolverResult(
diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py
index 7d87696..1fc8a04 100644
--- a/tests/test_solvers/test_sgd.py
+++ b/tests/test_solvers/test_sgd.py
@@ -28,7 +28,7 @@
     pytest.param(
         lambda x: (x**2).sum(),
         jnp.atleast_1d(10.0),
-        RuntimeError("Did not converge after 1 iterations"),
+        "Did not converge after 1 iterations",
         {"maxiter": 1},
         id="Reaches iteration limit",
     ),
@@ -67,8 +67,7 @@ def test_sgd(
     obj_fn: Callable[[PyTree], npt.ArrayLike],
     initial_guess: PyTree,
     kwargs_to_sgd: dict[str, Any],
-    expected: PyTree | Exception,
-    raises_context,
+    expected: PyTree | str,
 ) -> None:
     """Test the SGD method on a (deterministic) problem.
 
@@ -79,12 +78,12 @@ def test_sgd(
     if not kwargs_to_sgd:
         kwargs_to_sgd = {}
 
-    if isinstance(expected, Exception):
-        with raises_context(expected):
-            stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)
-    else:
-        result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)[0]
+    result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)
 
+    if isinstance(expected, str):
+        assert not result.successful
+        assert result.reason == expected
+    else:
         assert jax.tree_util.tree_all(
-            jax.tree_util.tree_map(jax.numpy.allclose, result, expected)
+            jax.tree_util.tree_map(jax.numpy.allclose, result.arg_result, expected)
         )

From 42a6626a8f5c2d1bd1a33d01eac91313008914ee Mon Sep 17 00:00:00 2001
From: willGraham01 <1willgraham@gmail.com>
Date: Wed, 24 Sep 2025 08:50:14 +0100
Subject: [PATCH 5/6] Fix two normal example test

---
 tests/test_integration/test_two_normal_example.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py
index c5bcb03..ec449cd 100644
--- a/tests/test_integration/test_two_normal_example.py
+++ b/tests/test_integration/test_two_normal_example.py
@@ -107,7 +107,7 @@ def objective(x, key):
     }
     l_mult = jnp.atleast_1d(lagrange_mult_sol)
 
-    opt_params, _, _, _ = stochastic_gradient_descent(
+    result = stochastic_gradient_descent(
         objective,
         (params, l_mult),
         convergence_criteria=lambda x, _: jnp.abs(x),
@@ -116,8 +116,10 @@ def objective(x, key):
         maxiter=maxiter,
         tolerance=minimisation_tolerance,
     )
+    assert result.successful, "SGD did not converge."
+
     # Unpack concatenated arguments
-    params, l_mult = opt_params
+    params, l_mult = result.arg_result
 
     # The lagrangian is independent of nu_x, thus it should not have changed value.
     assert jnp.isclose(params["cov2"], nu_x_starting_value), (

From f2a4279d181e2ab376f169356120c5270f78b5fb Mon Sep 17 00:00:00 2001
From: willGraham01 <1willgraham@gmail.com>
Date: Wed, 24 Sep 2025 08:51:06 +0100
Subject: [PATCH 6/6] Rename attribute to more sensible name

---
 src/causalprog/solvers/sgd.py                     |  2 +-
 src/causalprog/solvers/solver_result.py           | 10 +++++-----
 tests/test_integration/test_two_normal_example.py |  2 +-
 tests/test_solvers/test_sgd.py                    |  2 +-
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py
index fd9c2d3..141d5e3 100644
--- a/src/causalprog/solvers/sgd.py
+++ b/src/causalprog/solvers/sgd.py
@@ -110,7 +110,7 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
     )
 
     return SolverResult(
-        arg_result=current_params,
+        fn_args=current_params,
         grad_val=gradient_value,
         iters=iters_used,
         maxiter=maxiter,
diff --git a/src/causalprog/solvers/solver_result.py b/src/causalprog/solvers/solver_result.py
index b7ff7ff..eb09457 100644
--- a/src/causalprog/solvers/solver_result.py
+++ b/src/causalprog/solvers/solver_result.py
@@ -16,20 +16,20 @@ class SolverResult:
     comes out of running one of the solver methods on a causal problem.
 
     Attributes:
-        arg_result: Argument to the objective function at the final iteration (the solution,
+        fn_args: Argument to the objective function at the final iteration (the solution,
             if `successful` is `True`).
-        grad_val: Value of the gradient of the objective function at `arg_result`.
+        grad_val: Value of the gradient of the objective function at `fn_args`.
         iters: Number of iterations performed.
         maxiter: Maximum number of iterations the solver was permitted to perform.
-        obj_val: Value of the objective function at `arg_result`.
+        obj_val: Value of the objective function at `fn_args`.
         reason: Human-readable string explaining success or reasons for solver failure.
-        successful: `True` if solver converged, in which case `arg_result` is the
+        successful: `True` if solver converged, in which case `fn_args` is the
             argument to the objective function at the solution of the problem being
             solved. `False` otherwise.
 
     """
 
-    arg_result: PyTree
+    fn_args: PyTree
     grad_val: PyTree
     iters: int
     maxiter: int
diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py
index ec449cd..3d7db90 100644
--- a/tests/test_integration/test_two_normal_example.py
+++ b/tests/test_integration/test_two_normal_example.py
@@ -119,7 +119,7 @@ def objective(x, key):
     assert result.successful, "SGD did not converge."
 
     # Unpack concatenated arguments
-    params, l_mult = result.arg_result
+    params, l_mult = result.fn_args
 
     # The lagrangian is independent of nu_x, thus it should not have changed value.
     assert jnp.isclose(params["cov2"], nu_x_starting_value), (
diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py
index 1fc8a04..f602ee6 100644
--- a/tests/test_solvers/test_sgd.py
+++ b/tests/test_solvers/test_sgd.py
@@ -85,5 +85,5 @@ def test_sgd(
         assert result.reason == expected
     else:
         assert jax.tree_util.tree_all(
-            jax.tree_util.tree_map(jax.numpy.allclose, result.arg_result, expected)
+            jax.tree_util.tree_map(jax.numpy.allclose, result.fn_args, expected)
         )
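
Series note: after these six patches, `stochastic_gradient_descent` returns a
frozen `SolverResult` dataclass instead of the old
`(params, objective_value, gradient_value, iterations)` tuple, and hitting the
iteration limit no longer raises `RuntimeError`; failure is reported through
the container. A minimal usage sketch of the final API, assuming the `src/`
layout maps to the `causalprog` package as the diff paths suggest (the
quadratic objective and starting point are illustrative values borrowed from
tests/test_solvers/test_sgd.py):

    import jax.numpy as jnp

    from causalprog.solvers.sgd import stochastic_gradient_descent

    # Minimise f(x) = sum(x**2), starting from x = 10.
    result = stochastic_gradient_descent(lambda x: (x**2).sum(), jnp.atleast_1d(10.0))

    # Convergence is now checked explicitly on the returned container.
    if result.successful:
        print("solution:", result.fn_args, "objective value:", result.obj_val)
    else:
        print(f"stopped after {result.iters}/{result.maxiter} iterations:", result.reason)

Keeping the success flag and diagnostics alongside the solution lets callers
fail loudly at their own discretion, as the two-normal integration test now
does by asserting `result.successful` before unpacking `result.fn_args`.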