36 changes: 25 additions & 11 deletions src/causalprog/solvers/sgd.py
@@ -8,6 +8,7 @@
 import numpy.typing as npt
 import optax
 
+from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq
 
 
@@ -22,7 +23,7 @@ def stochastic_gradient_descent(
     maxiter: int = 1000,
     optimiser: optax.GradientTransformationExtraArgs | None = None,
     tolerance: float = 1.0e-8,
-) -> tuple[PyTree, npt.ArrayLike, npt.ArrayLike, int]:
+) -> SolverResult:
     """
     Minimise a function of one argument using Stochastic Gradient Descent (SGD).
 
@@ -87,20 +88,33 @@ def objective(x: npt.ArrayLike) -> npt.ArrayLike:
     def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         return convergence_criteria(x, dx) < tolerance
 
-    gradient = jax.grad(objective)
-    opt_state = optimiser.init(initial_guess)
+    converged = False
+
+    opt_state = optimiser.init(initial_guess)
     current_params = deepcopy(initial_guess)
-    gradient_value = gradient(current_params)
-    for i in range(maxiter):
-        updates, opt_state = optimiser.update(gradient_value, opt_state)
-        current_params = optax.apply_updates(current_params, updates)
+    gradient = jax.grad(objective)
+
+    for _ in range(maxiter + 1):
         objective_value = objective(current_params)
         gradient_value = gradient(current_params)
 
-        if is_converged(objective_value, gradient_value):
-            return current_params, objective_value, gradient_value, i + 1
+        if converged := is_converged(objective_value, gradient_value):
+            break
+
+        updates, opt_state = optimiser.update(gradient_value, opt_state)
+        current_params = optax.apply_updates(current_params, updates)
 
-    msg = f"Did not converge after {i + 1} iterations."
-    raise RuntimeError(msg)
+    iters_used = _
+    reason_msg = (
+        f"Did not converge after {iters_used} iterations" if not converged else ""
+    )
+
+    return SolverResult(
+        fn_args=current_params,
+        grad_val=gradient_value,
+        iters=iters_used,
+        maxiter=maxiter,
+        obj_val=objective_value,
+        reason=reason_msg,
+        successful=converged,
+    )
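For readers migrating call sites, the sketch below (not part of the diff) shows how the old four-element tuple unpacking maps onto the new SolverResult fields. The quadratic objective and starting point are illustrative values borrowed from the test suite, and the keyword arguments are simply the documented defaults.

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent

# Minimise f(x) = sum(x**2) from a starting guess of 10.0.
result = stochastic_gradient_descent(
    lambda x: (x**2).sum(),
    jnp.atleast_1d(10.0),
    maxiter=1000,
    tolerance=1.0e-8,
)

# Before this change: params, obj_val, grad_val, iters = stochastic_gradient_descent(...)
if result.successful:
    print(result.fn_args, result.obj_val, result.iters)
else:
    print(result.reason)  # e.g. "Did not converge after 1000 iterations"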
38 changes: 38 additions & 0 deletions src/causalprog/solvers/solver_result.py
@@ -0,0 +1,38 @@
+"""Container class for outputs from solver methods."""
+
+from dataclasses import dataclass
+
+import numpy.typing as npt
+
+from causalprog.utils.norms import PyTree
+
+
+@dataclass(frozen=True)
+class SolverResult:
+    """
+    Container class for outputs from solver methods.
+
+    Instances of this class provide a container for the useful information
+    produced by running one of the solver methods on a causal problem.
+
+    Attributes:
+        fn_args: Argument to the objective function at the final iteration (the
+            solution, if `successful` is `True`).
+        grad_val: Value of the gradient of the objective function at `fn_args`.
+        iters: Number of iterations performed.
+        maxiter: Maximum number of iterations the solver was permitted to perform.
+        obj_val: Value of the objective function at `fn_args`.
+        reason: Human-readable string explaining success or reasons for solver failure.
+        successful: `True` if the solver converged, in which case `fn_args` is the
+            argument to the objective function at the solution of the problem being
+            solved. `False` otherwise.
+
+    """
+
+    fn_args: PyTree
+    grad_val: PyTree
+    iters: int
+    maxiter: int
+    obj_val: npt.ArrayLike
+    reason: str
+    successful: bool
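As a usage note (again not part of the diff): because the class is declared with frozen=True, instances are immutable once constructed, so downstream code can pass them around without risk of accidental mutation. A minimal sketch, using made-up field values:

import dataclasses

import jax.numpy as jnp

from causalprog.solvers.solver_result import SolverResult

res = SolverResult(
    fn_args=jnp.atleast_1d(0.0),
    grad_val=jnp.atleast_1d(0.0),
    iters=3,
    maxiter=10,
    obj_val=jnp.atleast_1d(0.0),
    reason="",
    successful=True,
)

try:
    res.iters = 4  # frozen dataclass: attribute assignment is rejected
except dataclasses.FrozenInstanceError:
    print("SolverResult is read-only")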
6 changes: 4 additions & 2 deletions tests/test_integration/test_two_normal_example.py
@@ -107,7 +107,7 @@ def objective(x, key):
     }
     l_mult = jnp.atleast_1d(lagrange_mult_sol)
 
-    opt_params, _, _, _ = stochastic_gradient_descent(
+    result = stochastic_gradient_descent(
         objective,
         (params, l_mult),
         convergence_criteria=lambda x, _: jnp.abs(x),
@@ -116,8 +116,10 @@ def objective(x, key):
         maxiter=maxiter,
         tolerance=minimisation_tolerance,
     )
+    assert result.successful, "SGD did not converge."
+
     # Unpack concatenated arguments
-    params, l_mult = opt_params
+    params, l_mult = result.fn_args
 
     # The lagrangian is independent of nu_x, thus it should not have changed value.
     assert jnp.isclose(params["cov2"], nu_x_starting_value), (
17 changes: 8 additions & 9 deletions tests/test_solvers/test_sgd.py
@@ -28,7 +28,7 @@
         pytest.param(
             lambda x: (x**2).sum(),
             jnp.atleast_1d(10.0),
-            RuntimeError("Did not converge after 1 iterations"),
+            "Did not converge after 1 iterations",
             {"maxiter": 1},
             id="Reaches iteration limit",
         ),
@@ -67,8 +67,7 @@ def test_sgd(
     obj_fn: Callable[[PyTree], npt.ArrayLike],
     initial_guess: PyTree,
     kwargs_to_sgd: dict[str, Any],
-    expected: PyTree | Exception,
-    raises_context,
+    expected: PyTree | str,
 ) -> None:
     """Test the SGD method on a (deterministic) problem.
 
@@ -79,12 +78,12 @@
     if not kwargs_to_sgd:
         kwargs_to_sgd = {}
 
-    if isinstance(expected, Exception):
-        with raises_context(expected):
-            stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)
-    else:
-        result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)[0]
+    result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)
 
+    if isinstance(expected, str):
+        assert not result.successful
+        assert result.reason == expected
+    else:
         assert jax.tree_util.tree_all(
-            jax.tree_util.tree_map(jax.numpy.allclose, result, expected)
+            jax.tree_util.tree_map(jax.numpy.allclose, result.fn_args, expected)
         )