From f945a14a9ba0d8d3eee3e54ae3d74ab97ea573fe Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 13 Aug 2025 16:03:05 +0100 Subject: [PATCH 1/7] Barebones integration test --- .../test_two_normal_example.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 tests/test_integration/test_two_normal_example.py diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py new file mode 100644 index 0000000..04040b5 --- /dev/null +++ b/tests/test_integration/test_two_normal_example.py @@ -0,0 +1,107 @@ +import sys +from collections.abc import Callable + +import jax +import jax.numpy as jnp +import numpy.typing as npt +import optax +from numpyro.infer import Predictive + +from causalprog.graph import Graph + + +def test_two_normal_example( + rng_key: jax.Array, + two_normal_graph_parametrized_mean: Callable[[], Graph], + n_samples: int = 500, # 1000 causes LLVM memory error... check cleanup of mem + phi_observed: float = 0.0, + epsilon: float = 1.0, + nu_y_starting_value: float = 1.0, + lagrange_mult_sol: float = 1.0, # Solution value of the lagrange multiplier + maxiter: int = 200, + minimisation_tolerance: float = 1.0e-6, +): + g = two_normal_graph_parametrized_mean() + predictive_model = Predictive(g.model, num_samples=n_samples) + + def lagrangian( + parameter_values: dict[str, npt.ArrayLike], + predictive_model: Predictive, + rng_key: jax.Array, + ): + subkeys = jax.random.split(rng_key, predictive_model.num_samples) + l_mult = parameter_values["_l_mult"] + + def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: + return predictive_model(key, **pv)["X"] + + def _ce(pv, subkeys): + return jax.vmap(_x_sampler, in_axes=(None, 0))(pv, subkeys).mean() + + def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: + return predictive_model(key, **pv)["UX"] + + def _constraint(pv, subkeys): + return ( + jnp.abs( + jax.vmap(_ux_sampler, in_axes=(None, 0))(pv, subkeys).mean() + - phi_observed + ) + - epsilon + ) + + return _ce(parameter_values, subkeys) + l_mult * _constraint( + parameter_values, subkeys + ) + + # objective is euclidean norm of the gradient of the lagrangian + def objective(params, predictive, key): + v = jax.grad(lagrangian)(params, predictive, key) + return sum(value**2 for value in v.values()) + + # Try starting close to the optimal parameter values + mu_x_sol = phi_observed - epsilon + nu_y_sol = nu_y_starting_value # nu_y is free - it does not affect the outcome + lambda_sol = lagrange_mult_sol + params = { + "mu_x": mu_x_sol, + "nu_y": nu_y_sol, + "_l_mult": lambda_sol, + } + + # Setup optimiser + adams_learning_rate = 1.0e-1 + optimiser = optax.adam(adams_learning_rate) + opt_state = optimiser.init(params) + + converged = False + for i in range(maxiter): + sys.stdout.write(f"{i}, ") + # Actual iteration loop + grads = jax.jacobian(objective)(params, predictive_model, rng_key) + updates, opt_state = optimiser.update(grads, opt_state) + params = optax.apply_updates(params, updates) + + # Early break if needed + objective_value = objective(params, predictive_model, rng_key) + if jnp.abs(objective_value) <= minimisation_tolerance: + converged = True + sys.stdout.write("CONVERGED - ") + break + sys.stdout.write("END ITERATIONS\n") + + assert jnp.isclose(nu_y_starting_value, params["nu_y"]), ( + "nu_y value has changed, despite gradient being independent of it" + ) + assert converged, f"Did not converge, final objective value: {objective_value}" + + sys.stdout.write( + f"Converged at: mu_x={params['mu_x']:.5e}, nu_y={params['nu_y']:.5e}" + ) + + assert params["_l_mult"] > 0.0, ( + f"Converged, but not to a minimiser (lagrange multiplier = {params['_l_mult']})" + ) + rtol = jnp.sqrt(1.0 / n_samples) + assert jnp.isclose(params["mu_x"], mu_x_sol, rtol=rtol) + assert jnp.isclose(params["_l_mult"], lagrange_mult_sol, atol=rtol) From 85b63948e6405bdb1ff3aa8b36864fc65b113b11 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 13 Aug 2025 16:05:58 +0100 Subject: [PATCH 2/7] Add optax to test requirements --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 9596f2c..477cdee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ optional-dependencies = {dev = [ "mkdocstrings-python", ], test = [ "numpy", + "optax", "pytest", "pytest-cov", ]} From 1fd63eff6426612032b572d698451c534e135aa7 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 13 Aug 2025 16:16:41 +0100 Subject: [PATCH 3/7] Docstring to actually give some help about the method --- .../test_two_normal_example.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 04040b5..48018d6 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -21,6 +21,38 @@ def test_two_normal_example( maxiter: int = 200, minimisation_tolerance: float = 1.0e-6, ): + """Solves the 'two normal' graph example problem. + + Assume we have the following model: + mu_x -> X ~ N(mu_x, 1.0) + | + v + nu_y -> Y ~ N(X, nu_y) + + and are interested in the causal estimand + + sigma(mu_x, nu_y) = E[Y] = mu_x, + + with constraints + + phi(mu_x, nu_y) = E[X] = mu_x. + + With observed data phi_observed, and tolerance in the data epsilon, we are + effectively looking to solve the minimisation problem; + + min_{mu_x, nu_y} mu_x, subject to |mu_x - phi_observed| <= epsilon. + + The solution to this is mu_x^* = mu_x - phi_observed. The value of nu_y can be any + positive value. + + The corresponding Lagrangian that we will form will be + + L(mu_x, nu_y, l_mult) = mu_x + l_mult * (|mu_x - phi_observed| - epsilon) + + which has stationary points when mu_x = mu_x^* and l_mult = +/ 1. + + TODO: solve max problem too....? + """ g = two_normal_graph_parametrized_mean() predictive_model = Predictive(g.model, num_samples=n_samples) From 1cc34bbd84c5eaef563ba3b1115ee80c3a2c0c44 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 13 Aug 2025 16:21:04 +0100 Subject: [PATCH 4/7] Typing --- tests/test_integration/test_two_normal_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 48018d6..59b62e3 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -86,7 +86,7 @@ def _constraint(pv, subkeys): parameter_values, subkeys ) - # objective is euclidean norm of the gradient of the lagrangian + # Objective is euclidean norm of the gradient of the lagrangian def objective(params, predictive, key): v = jax.grad(lagrangian)(params, predictive, key) return sum(value**2 for value in v.values()) From 0ef4eb6d8cb8cdfafab8e86453367ceea46a841f Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 14 Aug 2025 10:21:57 +0100 Subject: [PATCH 5/7] Add progress statements so I can debug --- .../test_two_normal_example.py | 92 +++++++++++++------ 1 file changed, 65 insertions(+), 27 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 59b62e3..af65f83 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -5,21 +5,33 @@ import jax.numpy as jnp import numpy.typing as npt import optax +import pytest from numpyro.infer import Predictive from causalprog.graph import Graph +@pytest.mark.parametrize( + "is_solving_max", + [ + pytest.param(False, id="Minimise"), + pytest.param(True, id="Maximise"), + ], +) def test_two_normal_example( rng_key: jax.Array, two_normal_graph_parametrized_mean: Callable[[], Graph], + adams_learning_rate: float = 1.0e-1, n_samples: int = 500, # 1000 causes LLVM memory error... check cleanup of mem - phi_observed: float = 0.0, - epsilon: float = 1.0, - nu_y_starting_value: float = 1.0, + phi_observed: float = 0.0, # The observed data + epsilon: float = 1.0, # The tolerance in the observed data + nu_y_starting_value: float = 1.0, # Where to start nu_y, the independent parameter lagrange_mult_sol: float = 1.0, # Solution value of the lagrange multiplier - maxiter: int = 200, + maxiter: int = 100, # Max iterations to allow (~100 sufficient for test cases) + # Threshold for minimisation function value being considered 0 minimisation_tolerance: float = 1.0e-6, + *, + is_solving_max: bool, ): """Solves the 'two normal' graph example problem. @@ -42,17 +54,17 @@ def test_two_normal_example( min_{mu_x, nu_y} mu_x, subject to |mu_x - phi_observed| <= epsilon. - The solution to this is mu_x^* = mu_x - phi_observed. The value of nu_y can be any - positive value. + The solution to this is mu_x^* = mu_x +/- phi_observed (+ in the maximisation case). + The value of nu_y can be any positive value. The corresponding Lagrangian that we will form will be - L(mu_x, nu_y, l_mult) = mu_x + l_mult * (|mu_x - phi_observed| - epsilon) + L(mu_x, nu_y, l_mult) = +/- mu_x + l_mult * (|mu_x - phi_observed| - epsilon) - which has stationary points when mu_x = mu_x^* and l_mult = +/ 1. - - TODO: solve max problem too....? + (again with + in the max case). In both cases, this is minimised at + L(mu_x^*, nu_y, 1). """ + # Setup the optimisation problem from the graph g = two_normal_graph_parametrized_mean() predictive_model = Predictive(g.model, num_samples=n_samples) @@ -60,6 +72,8 @@ def lagrangian( parameter_values: dict[str, npt.ArrayLike], predictive_model: Predictive, rng_key: jax.Array, + *, + ce_prefactor: float, ): subkeys = jax.random.split(rng_key, predictive_model.num_samples) l_mult = parameter_values["_l_mult"] @@ -68,7 +82,10 @@ def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: return predictive_model(key, **pv)["X"] def _ce(pv, subkeys): - return jax.vmap(_x_sampler, in_axes=(None, 0))(pv, subkeys).mean() + return ( + ce_prefactor + * jax.vmap(_x_sampler, in_axes=(None, 0))(pv, subkeys).mean() + ) def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: return predictive_model(key, **pv)["UX"] @@ -86,54 +103,75 @@ def _constraint(pv, subkeys): parameter_values, subkeys ) - # Objective is euclidean norm of the gradient of the lagrangian - def objective(params, predictive, key): - v = jax.grad(lagrangian)(params, predictive, key) + # In both cases, the Lagrange multiplier has the value 1.0 at the minimum. + lambda_sol = 1.0 + ce_prefactor = 1.0 if not is_solving_max else -1.0 + mu_x_sol = phi_observed - ce_prefactor * epsilon + + # We'll be seeking stationary points of the Lagrangian, using the + # naive approach of minimising the norm of its gradient. We will need to + # ensure we "converge" to a minimum value suitably close to 0. + def objective(params, predictive, key, ce_prefactor=ce_prefactor): + v = jax.grad(lagrangian)(params, predictive, key, ce_prefactor=ce_prefactor) return sum(value**2 for value in v.values()) - # Try starting close to the optimal parameter values - mu_x_sol = phi_observed - epsilon - nu_y_sol = nu_y_starting_value # nu_y is free - it does not affect the outcome - lambda_sol = lagrange_mult_sol + # Choose a starting guess that is at the optimal solution, in the hopes that + # SGD converges quickly. We almost certainly will not have this luxury in general. + # The value of nu_y is free; the Lagrangian is independent of it. + # As such, it can take any value and should not change during the optimisation + # iterations. params = { "mu_x": mu_x_sol, - "nu_y": nu_y_sol, + "nu_y": nu_y_starting_value, "_l_mult": lambda_sol, } - - # Setup optimiser - adams_learning_rate = 1.0e-1 + # Setup SGD optimiser optimiser = optax.adam(adams_learning_rate) opt_state = optimiser.init(params) converged = False for i in range(maxiter): - sys.stdout.write(f"{i}, ") # Actual iteration loop - grads = jax.jacobian(objective)(params, predictive_model, rng_key) + grads = jax.jacobian(objective)( + params, predictive_model, rng_key, ce_prefactor=ce_prefactor + ) updates, opt_state = optimiser.update(grads, opt_state) params = optax.apply_updates(params, updates) - # Early break if needed - objective_value = objective(params, predictive_model, rng_key) + # Convergence "check" and progress update + objective_value = objective( + params, predictive_model, rng_key, ce_prefactor=ce_prefactor + ) + + sys.stdout.write( + f"\n\t{i}, F_val={objective_value:.2e}, " + f"mu_x={params['mu_x']:.3e}, l_mult={params['_l_mult']:.3e}" + ) + if jnp.abs(objective_value) <= minimisation_tolerance: converged = True sys.stdout.write("CONVERGED - ") break + sys.stdout.write("END ITERATIONS\n") + # Confirm that nu_y has not changed, being an independent variable. assert jnp.isclose(nu_y_starting_value, params["nu_y"]), ( "nu_y value has changed, despite gradient being independent of it" ) assert converged, f"Did not converge, final objective value: {objective_value}" sys.stdout.write( - f"Converged at: mu_x={params['mu_x']:.5e}, nu_y={params['nu_y']:.5e}" + f"Converged at: mu_x={params['mu_x']:.5e}, l_mult={params['_l_mult']:.5e}" ) + # Confirm that we found a minimiser that does satisfy the inequality constraints. assert params["_l_mult"] > 0.0, ( f"Converged, but not to a minimiser (lagrange multiplier = {params['_l_mult']})" ) + + # Give a generous error margin in mu_x and the Lagrange multiplier, + # given SGD is being used on MC-integral functions. rtol = jnp.sqrt(1.0 / n_samples) assert jnp.isclose(params["mu_x"], mu_x_sol, rtol=rtol) assert jnp.isclose(params["_l_mult"], lagrange_mult_sol, atol=rtol) From 13ca8ece137b3aca54743f6d4a78a6f25a159bde Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 14 Aug 2025 10:22:48 +0100 Subject: [PATCH 6/7] Satisfied with answers, purge prints --- .../test_integration/test_two_normal_example.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index af65f83..290d875 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,4 +1,3 @@ -import sys from collections.abc import Callable import jax @@ -22,7 +21,7 @@ def test_two_normal_example( rng_key: jax.Array, two_normal_graph_parametrized_mean: Callable[[], Graph], adams_learning_rate: float = 1.0e-1, - n_samples: int = 500, # 1000 causes LLVM memory error... check cleanup of mem + n_samples: int = 500, phi_observed: float = 0.0, # The observed data epsilon: float = 1.0, # The tolerance in the observed data nu_y_starting_value: float = 1.0, # Where to start nu_y, the independent parameter @@ -130,7 +129,7 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): opt_state = optimiser.init(params) converged = False - for i in range(maxiter): + for _ in range(maxiter): # Actual iteration loop grads = jax.jacobian(objective)( params, predictive_model, rng_key, ce_prefactor=ce_prefactor @@ -143,28 +142,16 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): params, predictive_model, rng_key, ce_prefactor=ce_prefactor ) - sys.stdout.write( - f"\n\t{i}, F_val={objective_value:.2e}, " - f"mu_x={params['mu_x']:.3e}, l_mult={params['_l_mult']:.3e}" - ) - if jnp.abs(objective_value) <= minimisation_tolerance: converged = True - sys.stdout.write("CONVERGED - ") break - sys.stdout.write("END ITERATIONS\n") - # Confirm that nu_y has not changed, being an independent variable. assert jnp.isclose(nu_y_starting_value, params["nu_y"]), ( "nu_y value has changed, despite gradient being independent of it" ) assert converged, f"Did not converge, final objective value: {objective_value}" - sys.stdout.write( - f"Converged at: mu_x={params['mu_x']:.5e}, l_mult={params['_l_mult']:.5e}" - ) - # Confirm that we found a minimiser that does satisfy the inequality constraints. assert params["_l_mult"] > 0.0, ( f"Converged, but not to a minimiser (lagrange multiplier = {params['_l_mult']})" From f234cb75090d237a44ebd89e1cb1ef9d2069bada Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 14 Aug 2025 14:51:44 +0100 Subject: [PATCH 7/7] Adapt to changes from #70, which also fixes a bug with missing edges in the original graph --- .../test_two_normal_example.py | 66 ++++++++++--------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 290d875..5283fb1 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -19,12 +19,12 @@ ) def test_two_normal_example( rng_key: jax.Array, - two_normal_graph_parametrized_mean: Callable[[], Graph], + two_normal_graph: Callable[..., Graph], adams_learning_rate: float = 1.0e-1, n_samples: int = 500, phi_observed: float = 0.0, # The observed data epsilon: float = 1.0, # The tolerance in the observed data - nu_y_starting_value: float = 1.0, # Where to start nu_y, the independent parameter + nu_x_starting_value: float = 1.0, # Where to start nu_x in the solver initial guess lagrange_mult_sol: float = 1.0, # Solution value of the lagrange multiplier maxiter: int = 100, # Max iterations to allow (~100 sufficient for test cases) # Threshold for minimisation function value being considered 0 @@ -32,39 +32,47 @@ def test_two_normal_example( *, is_solving_max: bool, ): - """Solves the 'two normal' graph example problem. + r"""Solves the 'two normal' graph example problem. - Assume we have the following model: - mu_x -> X ~ N(mu_x, 1.0) - | - v - nu_y -> Y ~ N(X, nu_y) + We use the `two_normal_graph` with `cov=1.0`. For the purposes of this test, we will + write $\mu_{ux}$ for the parameter `mean`, and $\nu_{x}$ for the parameter `cov2`, + giving us the following model: - and are interested in the causal estimand + $$ + \mu_{ux} \rightarrow UX \sim \mathcal{N}(\mu_{ux}, 1.0) + \rightarrow X, X \vert UX \sim \mathcal{N}(UX, \nu_{x}) + \leftarrow \nu_{x}. + $$ - sigma(mu_x, nu_y) = E[Y] = mu_x, + We will be interested in the causal estimand - with constraints + $$ \sigma(\mu_{ux}, \nu_{x}) = \mathbb{E}[X] = \mu_{ux}, $$ - phi(mu_x, nu_y) = E[X] = mu_x. + with observed data (constraints) - With observed data phi_observed, and tolerance in the data epsilon, we are + $$ \phi(\mu_{ux}, \nu_{x}) = \mathbb{E}[UX] = \mu_{ux}. $$ + + With observed data $\phi_{obs}$, and tolerance in the data $\epsilon$, we are effectively looking to solve the minimisation problem; - min_{mu_x, nu_y} mu_x, subject to |mu_x - phi_observed| <= epsilon. + $$ \mathrm{min}_{\mu_{ux}, \nu_{x}} \mu_{ux}, \quad + \text{subject to } \vert \mu_{ux} - \phi_{obs} \vert \leq \epsilon. + $$ - The solution to this is mu_x^* = mu_x +/- phi_observed (+ in the maximisation case). - The value of nu_y can be any positive value. + The solution to this is $\mu_{ux}^{*} = \mu_{ux} \pm \phi_{obs}$ ($+$ in the + maximisation case). The value of $\nu_{x}$ can be any positive value, since in this + setup both $\phi$ and $\sigma$ are independent of it. The corresponding Lagrangian that we will form will be - L(mu_x, nu_y, l_mult) = +/- mu_x + l_mult * (|mu_x - phi_observed| - epsilon) + $$ \mathcal{L}(\mu_{ux}, \nu_{x}, \lambda) = \pm \mu_{ux} + + \lambda(\vert \mu_{ux} - \phi_{obs} \vert - \epsilon), $$ - (again with + in the max case). In both cases, this is minimised at - L(mu_x^*, nu_y, 1). + (again with $+\mu_{ux}$ in the maximisation case). In both cases, $\mathcal{L}$ is + minimised at $(\mu_{ux}^{*}, \nu_x, 1)$. """ # Setup the optimisation problem from the graph - g = two_normal_graph_parametrized_mean() + g = two_normal_graph(cov=1.0) predictive_model = Predictive(g.model, num_samples=n_samples) def lagrangian( @@ -78,7 +86,7 @@ def lagrangian( l_mult = parameter_values["_l_mult"] def _x_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["X"] + return predictive_model(key, **pv)["UX"] def _ce(pv, subkeys): return ( @@ -87,7 +95,7 @@ def _ce(pv, subkeys): ) def _ux_sampler(pv: dict[str, npt.ArrayLike], key: jax.Array) -> float: - return predictive_model(key, **pv)["UX"] + return predictive_model(key, **pv)["X"] def _constraint(pv, subkeys): return ( @@ -116,12 +124,12 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): # Choose a starting guess that is at the optimal solution, in the hopes that # SGD converges quickly. We almost certainly will not have this luxury in general. - # The value of nu_y is free; the Lagrangian is independent of it. + # The value of nu_x is free; the Lagrangian is independent of it. # As such, it can take any value and should not change during the optimisation # iterations. params = { - "mu_x": mu_x_sol, - "nu_y": nu_y_starting_value, + "mean": mu_x_sol, + "cov2": nu_x_starting_value, "_l_mult": lambda_sol, } # Setup SGD optimiser @@ -146,10 +154,6 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): converged = True break - # Confirm that nu_y has not changed, being an independent variable. - assert jnp.isclose(nu_y_starting_value, params["nu_y"]), ( - "nu_y value has changed, despite gradient being independent of it" - ) assert converged, f"Did not converge, final objective value: {objective_value}" # Confirm that we found a minimiser that does satisfy the inequality constraints. @@ -157,8 +161,8 @@ def objective(params, predictive, key, ce_prefactor=ce_prefactor): f"Converged, but not to a minimiser (lagrange multiplier = {params['_l_mult']})" ) - # Give a generous error margin in mu_x and the Lagrange multiplier, + # Give a generous error margin in mu_ux and the Lagrange multiplier, # given SGD is being used on MC-integral functions. rtol = jnp.sqrt(1.0 / n_samples) - assert jnp.isclose(params["mu_x"], mu_x_sol, rtol=rtol) + assert jnp.isclose(params["mean"], mu_x_sol, rtol=rtol) assert jnp.isclose(params["_l_mult"], lagrange_mult_sol, atol=rtol)