diff --git a/src/causalprog/solvers/__init__.py b/src/causalprog/solvers/__init__.py new file mode 100644 index 0000000..953674e --- /dev/null +++ b/src/causalprog/solvers/__init__.py @@ -0,0 +1 @@ +"""Solvers for Causal Problems.""" diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py new file mode 100644 index 0000000..6205c36 --- /dev/null +++ b/src/causalprog/solvers/sgd.py @@ -0,0 +1,106 @@ +"""Minimisation via Stochastic Gradient Descent.""" + +from collections.abc import Callable +from copy import deepcopy + +import jax +import jax.numpy as jnp +import numpy.typing as npt +import optax + +from causalprog.utils.norms import PyTree, l2_normsq + + +def stochastic_gradient_descent( + obj_fn: Callable[[PyTree], npt.ArrayLike], + initial_guess: PyTree, + *, + convergence_criteria: Callable[[PyTree, PyTree], npt.ArrayLike] | None = None, + fn_args: tuple | None = None, + fn_kwargs: dict | None = None, + learning_rate: float = 1.0e-1, + maxiter: int = 1000, + optimiser: optax.GradientTransformationExtraArgs | None = None, + tolerance: float = 1.0e-8, +) -> tuple[PyTree, npt.ArrayLike, npt.ArrayLike, int]: + """ + Minimise a function of one argument using Stochastic Gradient Descent (SGD). + + The `obj_fn` provided will be minimised over its first argument. If you wish to + minimise a function over a different argument, or multiple arguments, wrap it in a + suitable `lambda` expression that has the correct call signature. For example, to + minimise a function `f(x, y, z)` over `y` and `z`, use + `g = lambda yz, x: f(x, yz[0], yz[1])`, and pass `g` in as `obj_fn`. Note that + you will also need to provide a constant value for `x` via `fn_args` or `fn_kwargs`. + + The `fn_args` and `fn_kwargs` keys can be used to supply additional parameters that + need to be passed to `obj_fn`, but which should be held constant. + + SGD terminates when the `convergence_criteria` is found to be smaller than the + `tolerance`. That is, when + `convergence_criteria(objective_value, gradient_value) <= tolerance` is found to + be `True`, the algorithm considers a minimum to have been found. The default + condition under which the algorithm terminates is when the norm of the gradient + at the current argument value is smaller than the provided `tolerance`. + + The optimiser to use can be selected by passing in a suitable `optax` optimiser + via the `optimiser` command. By default, `optax.adams` is used with the supplied + `learning_rate`. Providing an explicit value for `optimiser` will result in the + `learning_rate` argument being ignored. + + Args: + obj_fn: Function to be minimised over its first argument. + initial_guess: Initial guess for the minimising argument. + convergence_criteria: The quantity that will be tested against `tolerance`, to + determine whether the method has converged to a minimum. It should be a + `callable` that takes the current value of `obj_fn` as its 1st argument, and + the current value of the gradient of `obj_fn` as its 2nd argument. The + default criteria is the l2-norm of the gradient. + fn_args: Positional arguments to be passed to `obj_fn`, and held constant. + fn_kwargs: Keyword arguments to be passed to `obj_fn`, and held constant. + learning_rate: Default learning rate (or step size) to use when using the + default `optimiser`. No effect if `optimiser` is provided explicitly. + maxiter: Maximum number of iterations to perform. An error will be reported if + this number of iterations is exceeded. + optimiser: The `optax` optimiser to use during the update step. + tolerance: `tolerance` used when determining if a minimum has been found. + + Returns: + Minimising argument of `obj_fn`. + Value of `obj_fn` at the minimum. + Gradient of `obj_fn` at the minimum. + Number of iterations performed. + + """ + if not fn_args: + fn_args = () + if not fn_kwargs: + fn_kwargs = {} + if not convergence_criteria: + convergence_criteria = lambda _, dx: jnp.sqrt(l2_normsq(dx)) # noqa: E731 + if not optimiser: + optimiser = optax.adam(learning_rate) + + def objective(x: npt.ArrayLike) -> npt.ArrayLike: + return obj_fn(x, *fn_args, **fn_kwargs) + + def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: + return convergence_criteria(x, dx) < tolerance + + gradient = jax.grad(objective) + opt_state = optimiser.init(initial_guess) + + current_params = deepcopy(initial_guess) + gradient_value = gradient(current_params) + for i in range(maxiter): + updates, opt_state = optimiser.update(gradient_value, opt_state) + current_params = optax.apply_updates(current_params, updates) + + objective_value = objective(current_params) + gradient_value = gradient(current_params) + + if is_converged(objective_value, gradient_value): + return current_params, objective_value, gradient_value, i + 1 + + msg = f"Did not converge after {i + 1} iterations." + raise RuntimeError(msg) diff --git a/src/causalprog/utils/norms.py b/src/causalprog/utils/norms.py new file mode 100644 index 0000000..c118583 --- /dev/null +++ b/src/causalprog/utils/norms.py @@ -0,0 +1,18 @@ +"""Misc collection of norm-like functions for PyTree structures.""" + +from typing import TypeVar + +import jax +import numpy.typing as npt + +PyTree = TypeVar("PyTree") + + +def l2_normsq(x: PyTree) -> npt.ArrayLike: + """ + Square of the l2-norm of a PyTree. + + This is effectively "sum(elements**2 in leaf for leaf in x)". + """ + leaves, _ = jax.tree_util.tree_flatten(x) + return sum(jax.numpy.sum(leaf**2) for leaf in leaves) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 0664d76..c5bcb03 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,14 +1,14 @@ -import sys from collections.abc import Callable import jax import jax.numpy as jnp -import optax import pytest from causalprog.causal_problem.causal_problem import CausalProblem from causalprog.causal_problem.components import CausalEstimand, Constraint from causalprog.graph import Graph +from causalprog.solvers.sgd import stochastic_gradient_descent +from causalprog.utils.norms import l2_normsq @pytest.mark.parametrize( @@ -92,9 +92,9 @@ def test_two_normal_example( # We'll be seeking stationary points of the Lagrangian, using the # naive approach of minimising the norm of its gradient. We will need to # ensure we "converge" to a minimum value suitably close to 0. - def objective(params, l_mult, key): - v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, key) - return sum(value**2 for value in v[0].values()) + (v[1] ** 2).sum() + def objective(x, key): + v = jax.grad(lagrangian, argnums=(0, 1))(*x, rng_key=key) + return l2_normsq(v) # Choose a starting guess that is at the optimal solution, in the hopes that # SGD converges quickly. We almost certainly will not have this luxury in general. @@ -107,31 +107,17 @@ def objective(params, l_mult, key): } l_mult = jnp.atleast_1d(lagrange_mult_sol) - # Setup SGD optimiser - optimiser = optax.adam(adams_learning_rate) - opt_state = optimiser.init((params, l_mult)) - - # Run optimisation loop on gradient of the Lagrangian - converged = False - for _ in range(maxiter): - # Actual iteration loop - grads = jax.jacobian(objective, argnums=(0, 1))(params, l_mult, rng_key) - updates, opt_state = optimiser.update(grads, opt_state) - params, l_mult = optax.apply_updates((params, l_mult), updates) - - # Convergence "check" and progress update - objective_value = objective(params, l_mult, rng_key) - sys.stdout.write( - f"{_}, F_val={objective_value:.4e}, " - f"mu_ux={params['mean']:.4e}, " - f"nu_x={params['cov2']:.4e}, " - f"lambda={l_mult[0]:.4e}\n" - ) - if jnp.abs(objective_value) <= minimisation_tolerance: - converged = True - break - - assert converged, f"Did not converge, final objective value: {objective_value}" + opt_params, _, _, _ = stochastic_gradient_descent( + objective, + (params, l_mult), + convergence_criteria=lambda x, _: jnp.abs(x), + fn_kwargs={"key": rng_key}, + learning_rate=adams_learning_rate, + maxiter=maxiter, + tolerance=minimisation_tolerance, + ) + # Unpack concatenated arguments + params, l_mult = opt_params # The lagrangian is independent of nu_x, thus it should not have changed value. assert jnp.isclose(params["cov2"], nu_x_starting_value), ( diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py new file mode 100644 index 0000000..7d87696 --- /dev/null +++ b/tests/test_solvers/test_sgd.py @@ -0,0 +1,90 @@ +from collections.abc import Callable +from typing import Any + +import jax +import jax.numpy as jnp +import numpy.typing as npt +import pytest + +from causalprog.solvers.sgd import stochastic_gradient_descent +from causalprog.utils.norms import PyTree + + +@pytest.mark.parametrize( + ( + "obj_fn", + "initial_guess", + "expected", + "kwargs_to_sgd", + ), + [ + pytest.param( + lambda x: (x**2).sum(), + jnp.atleast_1d(1.0), + jnp.atleast_1d(0.0), + None, + id="Deterministic x**2", + ), + pytest.param( + lambda x: (x**2).sum(), + jnp.atleast_1d(10.0), + RuntimeError("Did not converge after 1 iterations"), + {"maxiter": 1}, + id="Reaches iteration limit", + ), + pytest.param( + lambda x: (x**2).sum(), + jnp.atleast_1d(1.0), + jnp.atleast_1d(0.9), + { + "convergence_criteria": lambda x, _: jnp.abs(x.sum()), + "tolerance": 1.0e0, + "learning_rate": 1e-1, + }, + id="Converge on function value less than 1", + ), + pytest.param( + lambda x, a: ((x - a) ** 2).sum(), + jnp.atleast_1d(1.0), + jnp.atleast_1d(2.0), + { + "fn_args": (2.0,), + }, + id="Fix positional argument", + ), + pytest.param( + lambda x, *, a: ((x - a) ** 2).sum(), + jnp.atleast_1d(1.0), + jnp.atleast_1d(2.0), + { + "fn_kwargs": {"a": 2.0}, + }, + id="Fix keyword argument", + ), + ], +) +def test_sgd( + obj_fn: Callable[[PyTree], npt.ArrayLike], + initial_guess: PyTree, + kwargs_to_sgd: dict[str, Any], + expected: PyTree | Exception, + raises_context, +) -> None: + """Test the SGD method on a (deterministic) problem. + + This is just an assurance check that all the components of the method are working + as intended. In each test case, we minimise (a variation of) x**2, changing the + options that we pass to the SGD solver. + """ + if not kwargs_to_sgd: + kwargs_to_sgd = {} + + if isinstance(expected, Exception): + with raises_context(expected): + stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd) + else: + result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)[0] + + assert jax.tree_util.tree_all( + jax.tree_util.tree_map(jax.numpy.allclose, result, expected) + ) diff --git a/tests/test_utils/test_norms.py b/tests/test_utils/test_norms.py new file mode 100644 index 0000000..fa0abe6 --- /dev/null +++ b/tests/test_utils/test_norms.py @@ -0,0 +1,25 @@ +from collections.abc import Callable + +import numpy as np +import pytest + +from causalprog.utils.norms import PyTree, l2_normsq + + +@pytest.mark.parametrize( + ("pt", "norm", "expected_value"), + [ + pytest.param(1.0, l2_normsq, 1.0, id="l2^2, scalar"), + pytest.param( + np.array([1.0, 2.0, 3.0]), l2_normsq, 14.0, id="l2^2, numpy array" + ), + pytest.param( + {"a": 1.0, "b": (np.arange(3), [2.0, (-1.0, 0.0)])}, + l2_normsq, + 1.0 + (np.arange(3) ** 2).sum() + 4.0 + 1.0, + id="l2^2, PyTree", + ), + ], +) +def test_norm_value(pt: PyTree, norm: Callable[[PyTree], float], expected_value: float): + assert np.allclose(norm(pt), expected_value)