From 8d9d76def3b367648f3506855c2e6e1a5bc5a06b Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 11:32:54 +0100 Subject: [PATCH 1/9] Refactor SGD method --- src/causalprog/solvers/__init__.py | 1 + src/causalprog/solvers/sgd.py | 58 ++++++++++++++++++++++++++++++ src/causalprog/utils/norms.py | 18 ++++++++++ 3 files changed, 77 insertions(+) create mode 100644 src/causalprog/solvers/__init__.py create mode 100644 src/causalprog/solvers/sgd.py create mode 100644 src/causalprog/utils/norms.py diff --git a/src/causalprog/solvers/__init__.py b/src/causalprog/solvers/__init__.py new file mode 100644 index 0000000..953674e --- /dev/null +++ b/src/causalprog/solvers/__init__.py @@ -0,0 +1 @@ +"""Solvers for Causal Problems.""" diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py new file mode 100644 index 0000000..1998f29 --- /dev/null +++ b/src/causalprog/solvers/sgd.py @@ -0,0 +1,58 @@ +"""Minimisation via Stochastic Gradient Descent.""" + +from collections.abc import Callable + +import jax +import jax.numpy as jnp +import numpy.typing as npt +import optax + +from causalprog.utils.norms import PyTree, l2_normsq + + +def minimise( + obj_fn: Callable, + initial_guess: npt.ArrayLike, # should be a pytree really + *, + convergence_criteria: Callable[[PyTree, PyTree], npt.ArrayLike] | None, + fn_args: tuple | None = None, + fn_kwargs: dict | None = None, + learning_rate: float = 1.0e-1, + maxiter: int = 100, + optimiser: optax.GradientTransformationExtraArgs | None = None, + tolerance: float = 1.0e-8, +) -> npt.ArrayLike: + """Minimise a function of one argument using Stochastic Gradient Descent.""" + if not fn_args: + fn_args = () + if not fn_kwargs: + fn_kwargs = {} + if not convergence_criteria: + convergence_criteria = lambda _, dx: jnp.sqrt(l2_normsq(dx)) # noqa: E731 + if not optimiser: + optimiser = optax.adam(learning_rate) + + def objective(x: npt.ArrayLike) -> npt.ArrayLike: + return obj_fn(x, *fn_args, **fn_kwargs) + + def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: + return convergence_criteria(x, dx) <= tolerance + + gradient = jax.grad(objective) + + opt_state = optimiser.init(initial_guess) + + current_params = initial_guess.copy(deep=True) + for _ in range(maxiter): + grads = gradient(current_params) + updates, opt_state = optimiser.update(grads, opt_state) + current_params = optax.apply_updates(current_params, updates) + + objective_value = objective(current_params) + gradient_value = gradient(current_params) + + if is_converged(objective_value, gradient_value): + return current_params + + msg = f"Did not converge after {_} iterations." + raise RuntimeError(msg) diff --git a/src/causalprog/utils/norms.py b/src/causalprog/utils/norms.py new file mode 100644 index 0000000..c118583 --- /dev/null +++ b/src/causalprog/utils/norms.py @@ -0,0 +1,18 @@ +"""Misc collection of norm-like functions for PyTree structures.""" + +from typing import TypeVar + +import jax +import numpy.typing as npt + +PyTree = TypeVar("PyTree") + + +def l2_normsq(x: PyTree) -> npt.ArrayLike: + """ + Square of the l2-norm of a PyTree. + + This is effectively "sum(elements**2 in leaf for leaf in x)". + """ + leaves, _ = jax.tree_util.tree_flatten(x) + return sum(jax.numpy.sum(leaf**2) for leaf in leaves) From 39ec3ac955b7134d0aad23b42751597e2decbc3d Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 11:38:58 +0100 Subject: [PATCH 2/9] Test that l2^2 norm is implemented correctly --- tests/test_utils/test_norms.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/test_utils/test_norms.py diff --git a/tests/test_utils/test_norms.py b/tests/test_utils/test_norms.py new file mode 100644 index 0000000..0d40d54 --- /dev/null +++ b/tests/test_utils/test_norms.py @@ -0,0 +1,27 @@ +from collections.abc import Callable + +import numpy as np +import pytest + +from causalprog.utils.norms import PyTree, l2_normsq + + +@pytest.mark.parametrize( + ("pt", "norm", "expected_value"), + [ + pytest.param(1.0, l2_normsq, 1.0, id="l2^2, scalar"), + pytest.param( + np.array([1.0, 2.0, 3.0]), l2_normsq, 14.0, id="l2^2, numpy array" + ), + pytest.param( + {"a": 1.0, "b": (np.arange(3), [2.0, (-1.0, 0.0)])}, + l2_normsq, + 1.0 + (np.arange(3) ** 2).sum() + 4.0 + 1.0, + id="l2^2, PyTree", + ), + ], +) +def test_norm_value( + pt: PyTree, norm: Callable[[PyTree], float], expected_value: float +) -> None: + assert np.allclose(norm(pt), expected_value) From 6affb023a7e534451d7601625041ffb3dfc2c078 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 12:10:56 +0100 Subject: [PATCH 3/9] Write tests for sgd minimiser --- src/causalprog/solvers/sgd.py | 14 +++--- tests/test_solvers/test_sgd.py | 85 ++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 7 deletions(-) create mode 100644 tests/test_solvers/test_sgd.py diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py index 1998f29..e9afe00 100644 --- a/src/causalprog/solvers/sgd.py +++ b/src/causalprog/solvers/sgd.py @@ -10,15 +10,15 @@ from causalprog.utils.norms import PyTree, l2_normsq -def minimise( - obj_fn: Callable, - initial_guess: npt.ArrayLike, # should be a pytree really +def stochastic_gradient_descent( + obj_fn: Callable[[PyTree], npt.ArrayLike], + initial_guess: PyTree, *, - convergence_criteria: Callable[[PyTree, PyTree], npt.ArrayLike] | None, + convergence_criteria: Callable[[PyTree, PyTree], npt.ArrayLike] | None = None, fn_args: tuple | None = None, fn_kwargs: dict | None = None, learning_rate: float = 1.0e-1, - maxiter: int = 100, + maxiter: int = 1000, optimiser: optax.GradientTransformationExtraArgs | None = None, tolerance: float = 1.0e-8, ) -> npt.ArrayLike: @@ -42,7 +42,7 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: opt_state = optimiser.init(initial_guess) - current_params = initial_guess.copy(deep=True) + current_params = initial_guess.copy() for _ in range(maxiter): grads = gradient(current_params) updates, opt_state = optimiser.update(grads, opt_state) @@ -54,5 +54,5 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: if is_converged(objective_value, gradient_value): return current_params - msg = f"Did not converge after {_} iterations." + msg = f"Did not converge after {_ + 1} iterations." raise RuntimeError(msg) diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py new file mode 100644 index 0000000..597cbc1 --- /dev/null +++ b/tests/test_solvers/test_sgd.py @@ -0,0 +1,85 @@ +from collections.abc import Callable +from typing import Any + +import jax +import jax.numpy as jnp +import numpy.typing as npt +import pytest + +from causalprog.solvers.sgd import stochastic_gradient_descent +from causalprog.utils.norms import PyTree + + +@pytest.mark.parametrize( + ( + "obj_fn", + "initial_guess", + "expected", + "kwargs_to_sgd", + ), + [ + pytest.param( + lambda x: (x**2).sum(), + jnp.atleast_1d(1.0), + jnp.atleast_1d(0.0), + None, + id="Deterministic x**2", + ), + pytest.param( + lambda x: (x**2).sum(), + jnp.atleast_1d(10.0), + RuntimeError("Did not converge after 1 iterations"), + {"maxiter": 1}, + id="Reaches iteration limit", + ), + pytest.param( + lambda x: (x**2).sum(), + jnp.atleast_1d(1.0), + jnp.atleast_1d(0.9), + { + "convergence_criteria": lambda x, _: jnp.abs(x.sum()), + "tolerance": 1.0e0, + "learning_rate": 1e-1, + }, + id="Converge on function value less than 1", + ), + pytest.param( + lambda x, a: ((x - a) ** 2).sum(), + jnp.atleast_1d(1.0), + jnp.atleast_1d(2.0), + { + "fn_args": (2.0,), + }, + id="Fix positional argument", + ), + pytest.param( + lambda x, *, a: ((x - a) ** 2).sum(), + jnp.atleast_1d(1.0), + jnp.atleast_1d(2.0), + { + "fn_kwargs": {"a": 2.0}, + }, + id="Fix keyword argument", + ), + ], +) +def test_sgd( + obj_fn: Callable[[PyTree], npt.ArrayLike], + initial_guess: PyTree, + kwargs_to_sgd: dict[str, Any], + expected: PyTree | Exception, + raises_context, +) -> None: + """""" + if not kwargs_to_sgd: + kwargs_to_sgd = {} + + if isinstance(expected, Exception): + with raises_context(expected): + stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd) + else: + result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd) + + assert jax.tree_util.tree_all( + jax.tree_util.tree_map(jax.numpy.allclose, result, expected) + ) From 6d3497f8c02cc767801cef0eb7bdcce127011dd5 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 12:47:04 +0100 Subject: [PATCH 4/9] Refactor integration test to now use SGD method --- src/causalprog/solvers/sgd.py | 56 ++++++++++++++++--- .../test_two_normal_example.py | 43 +++++--------- 2 files changed, 62 insertions(+), 37 deletions(-) diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py index e9afe00..c752e57 100644 --- a/src/causalprog/solvers/sgd.py +++ b/src/causalprog/solvers/sgd.py @@ -1,6 +1,7 @@ """Minimisation via Stochastic Gradient Descent.""" from collections.abc import Callable +from copy import deepcopy import jax import jax.numpy as jnp @@ -21,8 +22,48 @@ def stochastic_gradient_descent( maxiter: int = 1000, optimiser: optax.GradientTransformationExtraArgs | None = None, tolerance: float = 1.0e-8, -) -> npt.ArrayLike: - """Minimise a function of one argument using Stochastic Gradient Descent.""" +) -> tuple[PyTree, npt.ArrayLike, npt.ArrayLike, int]: + """ + Minimise a function of one argument using Stochastic Gradient Descent (SGD). + + The function provided will be minimised over its first argument. The `fn_args` + and `fn_kwargs` keys can be used to supply additional parameters that need to be + passed to `obj_fn`, but which should be held constant. + + SGD terminates when the `convergence_criteria` is found to be smaller than the + `tolerance`. That is, when + `convergence_criteria(objective_value, gradient_value) <= tolerance` is found to + be `True`, the algorithm considers a minimum to have been found. + + The optimiser to use can be selected by passing in a suitable `optax` optimiser + via the `optimiser` command. By default, `optax.adams` is used with the supplied + `learning_rate. Providing a value for `optimiser` will result in the `learning_rate` + argument being ignored. + + Args: + obj_fn: Function to be minimised over its first argument. + initial_guess: Initial guess for the minimising argument. + convergence_criteria: The quantity that will be tested against `tolerance`, to + determine whether the method has converged to a minimum. It should be a + `callable` that takes the current value of `obj_fn` as its 1st argument, and + the current value of the gradient of `obj_fn` as its 2nd argument. The + default criteria is the l2-norm of the gradient. + fn_args: Positional arguments to be passed to `obj_fn`, and held constant. + fn_kwargs: Keyword arguments to be passed to `obj_fn`, and held constant. + learning_rate: Default learning rate (or step size) to use when using the + default `optimiser`. No effect if `optimiser` is provided explicitly. + maxiter: Maximum number of iterations to perform. An error will be reported if + this number of iterations is exceeded. + optimiser: The `optax` optimiser to use during the update step. + tolerance: `tolerance` used when determining if a minimum has been found. + + Returns: + Minimising argument of `obj_fn`. + Value of `obj_fn` at the minimum. + Gradient of `obj_fn` at the minimum. + Number of iterations performed. + + """ if not fn_args: fn_args = () if not fn_kwargs: @@ -36,23 +77,22 @@ def objective(x: npt.ArrayLike) -> npt.ArrayLike: return obj_fn(x, *fn_args, **fn_kwargs) def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: - return convergence_criteria(x, dx) <= tolerance + return convergence_criteria(x, dx) < tolerance gradient = jax.grad(objective) - opt_state = optimiser.init(initial_guess) - current_params = initial_guess.copy() + current_params = deepcopy(initial_guess) + gradient_value = gradient(current_params) for _ in range(maxiter): - grads = gradient(current_params) - updates, opt_state = optimiser.update(grads, opt_state) + updates, opt_state = optimiser.update(gradient_value, opt_state) current_params = optax.apply_updates(current_params, updates) objective_value = objective(current_params) gradient_value = gradient(current_params) if is_converged(objective_value, gradient_value): - return current_params + return current_params, objective_value, gradient_value, _ + 1 msg = f"Did not converge after {_ + 1} iterations." raise RuntimeError(msg) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 0664d76..979fdbc 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,14 +1,13 @@ -import sys from collections.abc import Callable import jax import jax.numpy as jnp -import optax import pytest from causalprog.causal_problem.causal_problem import CausalProblem from causalprog.causal_problem.components import CausalEstimand, Constraint from causalprog.graph import Graph +from causalprog.solvers.sgd import stochastic_gradient_descent @pytest.mark.parametrize( @@ -92,8 +91,8 @@ def test_two_normal_example( # We'll be seeking stationary points of the Lagrangian, using the # naive approach of minimising the norm of its gradient. We will need to # ensure we "converge" to a minimum value suitably close to 0. - def objective(params, l_mult, key): - v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, key) + def objective(x, key): + v = jax.grad(lagrangian, argnums=(0, 1))(*x, rng_key=key) return sum(value**2 for value in v[0].values()) + (v[1] ** 2).sum() # Choose a starting guess that is at the optimal solution, in the hopes that @@ -107,31 +106,17 @@ def objective(params, l_mult, key): } l_mult = jnp.atleast_1d(lagrange_mult_sol) - # Setup SGD optimiser - optimiser = optax.adam(adams_learning_rate) - opt_state = optimiser.init((params, l_mult)) - - # Run optimisation loop on gradient of the Lagrangian - converged = False - for _ in range(maxiter): - # Actual iteration loop - grads = jax.jacobian(objective, argnums=(0, 1))(params, l_mult, rng_key) - updates, opt_state = optimiser.update(grads, opt_state) - params, l_mult = optax.apply_updates((params, l_mult), updates) - - # Convergence "check" and progress update - objective_value = objective(params, l_mult, rng_key) - sys.stdout.write( - f"{_}, F_val={objective_value:.4e}, " - f"mu_ux={params['mean']:.4e}, " - f"nu_x={params['cov2']:.4e}, " - f"lambda={l_mult[0]:.4e}\n" - ) - if jnp.abs(objective_value) <= minimisation_tolerance: - converged = True - break - - assert converged, f"Did not converge, final objective value: {objective_value}" + opt_params, _, _, _ = stochastic_gradient_descent( + objective, + (params, l_mult), + convergence_criteria=lambda x, _: jnp.abs(x), + fn_kwargs={"key": rng_key}, + learning_rate=adams_learning_rate, + maxiter=maxiter, + tolerance=minimisation_tolerance, + ) + # Unpack concatenated arguments + params, l_mult = opt_params # The lagrangian is independent of nu_x, thus it should not have changed value. assert jnp.isclose(params["cov2"], nu_x_starting_value), ( From ca8e3c317be9f390af9fabf4bc82be7c2ecc709c Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 14:02:37 +0100 Subject: [PATCH 5/9] Use actual norm-function rather than hacky piece-together --- tests/test_integration/test_two_normal_example.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 979fdbc..c5bcb03 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -8,6 +8,7 @@ from causalprog.causal_problem.components import CausalEstimand, Constraint from causalprog.graph import Graph from causalprog.solvers.sgd import stochastic_gradient_descent +from causalprog.utils.norms import l2_normsq @pytest.mark.parametrize( @@ -93,7 +94,7 @@ def test_two_normal_example( # ensure we "converge" to a minimum value suitably close to 0. def objective(x, key): v = jax.grad(lagrangian, argnums=(0, 1))(*x, rng_key=key) - return sum(value**2 for value in v[0].values()) + (v[1] ** 2).sum() + return l2_normsq(v) # Choose a starting guess that is at the optimal solution, in the hopes that # SGD converges quickly. We almost certainly will not have this luxury in general. From 831d5619d6f17d2ed5a6171e808b1d74d819dd9d Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 14:16:03 +0100 Subject: [PATCH 6/9] Tidy docstring --- src/causalprog/solvers/sgd.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py index c752e57..bc48702 100644 --- a/src/causalprog/solvers/sgd.py +++ b/src/causalprog/solvers/sgd.py @@ -26,9 +26,15 @@ def stochastic_gradient_descent( """ Minimise a function of one argument using Stochastic Gradient Descent (SGD). - The function provided will be minimised over its first argument. The `fn_args` - and `fn_kwargs` keys can be used to supply additional parameters that need to be - passed to `obj_fn`, but which should be held constant. + The `obj_fn` provided will be minimised over its first argument. If you wish to + minimise a function over a different argument, or multiple arguments, wrap it in a + suitable `lambda` expression that has the correct call signature. For example, to + minimise a function `f(x, y, z)` over `y` and `z`, use + `g = lambda yz, x: f(x, yz[0], yz[1])`, and pass `g` in as `obj_fn`. Note that + you will also need to provide a constant value for `x` via `fn_args` or `fn_kwargs`. + + The `fn_args` and `fn_kwargs` keys can be used to supply additional parameters that + need to be passed to `obj_fn`, but which should be held constant. SGD terminates when the `convergence_criteria` is found to be smaller than the `tolerance`. That is, when @@ -37,8 +43,8 @@ def stochastic_gradient_descent( The optimiser to use can be selected by passing in a suitable `optax` optimiser via the `optimiser` command. By default, `optax.adams` is used with the supplied - `learning_rate. Providing a value for `optimiser` will result in the `learning_rate` - argument being ignored. + `learning_rate`. Providing an explicit value for `optimiser` will result in the + `learning_rate` argument being ignored. Args: obj_fn: Function to be minimised over its first argument. From 6d933a774e5ede0425759da9dfe48cc3fba1799d Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 14:17:22 +0100 Subject: [PATCH 7/9] Qualify convergence condition default --- src/causalprog/solvers/sgd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py index bc48702..f1029cc 100644 --- a/src/causalprog/solvers/sgd.py +++ b/src/causalprog/solvers/sgd.py @@ -39,7 +39,9 @@ def stochastic_gradient_descent( SGD terminates when the `convergence_criteria` is found to be smaller than the `tolerance`. That is, when `convergence_criteria(objective_value, gradient_value) <= tolerance` is found to - be `True`, the algorithm considers a minimum to have been found. + be `True`, the algorithm considers a minimum to have been found. The default + condition under which the algorithm terminates is when the norm of the gradient + at the current argument value is smaller than the provided `tolerance`. The optimiser to use can be selected by passing in a suitable `optax` optimiser via the `optimiser` command. By default, `optax.adams` is used with the supplied From 1e1a29c1045bbf2e05cf63231f2e25077560bd7a Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 14:21:11 +0100 Subject: [PATCH 8/9] Drop extra SGD info in test return --- tests/test_solvers/test_sgd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py index 597cbc1..5a8b236 100644 --- a/tests/test_solvers/test_sgd.py +++ b/tests/test_solvers/test_sgd.py @@ -78,7 +78,7 @@ def test_sgd( with raises_context(expected): stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd) else: - result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd) + result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)[0] assert jax.tree_util.tree_all( jax.tree_util.tree_map(jax.numpy.allclose, result, expected) From a560dff1aeff0fd5a0884dbdc3567cced42a3642 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 3 Sep 2025 15:50:19 +0100 Subject: [PATCH 9/9] Apply code review suggestions --- src/causalprog/solvers/sgd.py | 6 +++--- tests/test_solvers/test_sgd.py | 7 ++++++- tests/test_utils/test_norms.py | 4 +--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py index f1029cc..6205c36 100644 --- a/src/causalprog/solvers/sgd.py +++ b/src/causalprog/solvers/sgd.py @@ -92,7 +92,7 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: current_params = deepcopy(initial_guess) gradient_value = gradient(current_params) - for _ in range(maxiter): + for i in range(maxiter): updates, opt_state = optimiser.update(gradient_value, opt_state) current_params = optax.apply_updates(current_params, updates) @@ -100,7 +100,7 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: gradient_value = gradient(current_params) if is_converged(objective_value, gradient_value): - return current_params, objective_value, gradient_value, _ + 1 + return current_params, objective_value, gradient_value, i + 1 - msg = f"Did not converge after {_ + 1} iterations." + msg = f"Did not converge after {i + 1} iterations." raise RuntimeError(msg) diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py index 5a8b236..7d87696 100644 --- a/tests/test_solvers/test_sgd.py +++ b/tests/test_solvers/test_sgd.py @@ -70,7 +70,12 @@ def test_sgd( expected: PyTree | Exception, raises_context, ) -> None: - """""" + """Test the SGD method on a (deterministic) problem. + + This is just an assurance check that all the components of the method are working + as intended. In each test case, we minimise (a variation of) x**2, changing the + options that we pass to the SGD solver. + """ if not kwargs_to_sgd: kwargs_to_sgd = {} diff --git a/tests/test_utils/test_norms.py b/tests/test_utils/test_norms.py index 0d40d54..fa0abe6 100644 --- a/tests/test_utils/test_norms.py +++ b/tests/test_utils/test_norms.py @@ -21,7 +21,5 @@ ), ], ) -def test_norm_value( - pt: PyTree, norm: Callable[[PyTree], float], expected_value: float -) -> None: +def test_norm_value(pt: PyTree, norm: Callable[[PyTree], float], expected_value: float): assert np.allclose(norm(pt), expected_value)