Description
Problem Formulation
This is essentially just a small departure from the "two normal example", except we're introducing another constraint and slightly more complexity in the objective functions.
It can be solved analytically too, which means we have some solid information to go off when validating our experiments.
Importantly, we're also requiring an effect handler to be used in one of the constraints, which is included so that we have (a test that can be turned into) an example problem that demonstrates how the HandlerToApply class is used.
Consider a causal problem with three parameter nodes:
The causal estimand of interest is
and we have observable quantities (model constraints)
We have gathered observable data for these quantities
This means that we are looking to solve the following problem:
Analytic Solution
Analytically, we can simplify and solve the above problem to obtain the exact solution, as well as other quantities of interest.
We first observe that
and thus our problem is
Importantly, this means that our solution is independent of
The solutions for the other parameters occur at
with the addition taken in the maximisation case, and subtraction in the minimisation case.
If we attempt to reformulate our problem in terms of Lagrange multipliers, we will end up with the Lagrangian
(note that the subtraction is in the maximisation case, and addition in the minimisation case, here!).
Taking the gradient of this function results in
where
Stationary points of
as we expect.
Note that
Things to Explore
Though this problem is solvable analytically, there are a number of things we may want to explore numerically.
Getting Started
Setting up this problem using causalprog is the first step.
A lot of the steps will be similar to the "two normal example", which is found in `tests/test_integration/test_two_normal_example.py`.
- Create a `Graph` instance (and suitable `DistributionNode`s and `ParameterNode`s) that represent the causal problem above.
- Define a `CausalEstimand` instance representing $\sigma$. Note that the package convention is that the causal estimand and constraints are defined in terms of the RVs, not the parameters.
- Define two `Constraint` instances representing the two constraints of the problem. Note that to fully define these constraints, you will need values for $\phi_{1, obs}$, $\phi_{2, obs}$, $\epsilon_1$, and $\epsilon_2$, so you may want to take the approach of refactoring the setup of the causal problem into a function that accepts values for these quantities as its arguments.
- Define the `CausalProblem` instance for the problem above, and use it to construct the Lagrangian via `CausalProblem.lagrangian`.
From there, you can save the output of CausalProblem.lagrangian to a variable.
You can then apply `jax.grad` to this variable to compute the gradient of the Lagrangian too (again, as a function), to use it in optimisation methods.
Note that you will need to take the `jax.grad` of `CausalProblem.lagrangian` with respect to both the 0th argument (the problem parameters) and the 1st argument (the Lagrange multipliers).
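As a minimal sketch of the differentiation step, assuming a stand-in `lagrangian(params, multipliers)` function with the same two-argument shape (the real callable comes from `CausalProblem.lagrangian`, whose exact signature may differ):

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for CausalProblem.lagrangian: a function of
# (parameters, multipliers). A toy objective plus one constraint term.
def lagrangian(params, multipliers):
    return jnp.sum(params**2) + multipliers[0] * (jnp.sum(params) - 1.0)


# Gradients with respect to the 0th (parameters) and 1st (multipliers)
# arguments, selected via argnums.
grad_params = jax.grad(lagrangian, argnums=0)
grad_mults = jax.grad(lagrangian, argnums=1)

p = jnp.array([0.5, 0.5])
m = jnp.array([1.0])
print(grad_params(p, m))  # gradient w.r.t. the problem parameters
print(grad_mults(p, m))   # gradient w.r.t. the Lagrange multipliers
```

Both `grad_params` and `grad_mults` are themselves functions, so they can be passed straight into an optimisation loop.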
From there, we can start to run experiments on how we go about obtaining the solution points from the Lagrangian.
Stability of the Solution wrt Optimizer Inputs
A good piece of exploratory analysis to start with would be to determine how volatile the solver is with respect to the meta-parameters that the user must provide.
There are several of these:
- The initial guess for the solver. We should observe that the $\nu_y$ initial value does not affect the convergence of the solver. As for the other parameters ($\mu_x$ and $\mu_y$), we should examine what happens as the distance between the initial guess for these parameters and the analytic solution increases. The initial guess for the Lagrange multiplier is also an interesting question that we should address; however, it is likely more interesting to record the distance of the initial guess for the multipliers from $0$, rather than from the analytical solution at which the problem is minimised.
- The learning rate / step size of the solver.
- The "tolerance" at which we treat a value of the gradient as $0$.
- The solver itself (though this ties into "How to Determine the Stationary Points").
We should examine at least the first two of these.
Pertinent information to record is:
- The number of iterations until convergence (though for practical reasons we may want to set a max iteration value, since we're not guaranteed to converge).
- Whether or not the solver converged.
- Whether or not the solver found a minimiser (note that the solver can converge and not find a minimiser!).
- The distance from the found solution to the analytic solution.
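A minimal sketch of an experiment runner that records these quantities, assuming a hypothetical two-argument `lagrangian(params, multipliers)` function and plain gradient descent/ascent (the actual solver interface in causalprog may differ):

```python
import jax
import jax.numpy as jnp


def run_solver(lagrangian, params0, mults0, learning_rate=1e-2,
               tol=1e-5, max_iter=10_000):
    """Gradient descent in the parameters, ascent in the multipliers,
    recording the quantities of interest for the stability analysis."""
    grad_p = jax.jit(jax.grad(lagrangian, argnums=0))
    grad_m = jax.jit(jax.grad(lagrangian, argnums=1))
    params, mults = params0, mults0
    for iteration in range(max_iter):
        gp, gm = grad_p(params, mults), grad_m(params, mults)
        # "Converged" means the full gradient is within tolerance of 0.
        if jnp.sqrt(jnp.sum(gp**2) + jnp.sum(gm**2)) < tol:
            return {"converged": True, "iterations": iteration,
                    "params": params, "multipliers": mults}
        params = params - learning_rate * gp  # descend in the parameters
        mults = mults + learning_rate * gm    # ascend in the multipliers
    return {"converged": False, "iterations": max_iter,
            "params": params, "multipliers": mults}


# Toy saddle-point problem: minimise x^2 subject to x = 1. Its Lagrangian
# x^2 + m(x - 1) has a stationary point at x = 1, m = -2.
result = run_solver(lambda p, m: p[0] ** 2 + m[0] * (p[0] - 1.0),
                    jnp.array([0.0]), jnp.array([0.0]))
```

The returned dictionary carries the iteration count, convergence flag, and the found point, ready to compare against the analytic solution; a "did it find a minimiser" check (e.g. inspecting second-order information or re-evaluating the estimand) would be layered on top.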
Once we have this data, we will hopefully gain some insights into how much we can have the package handle, and how much we will need to rely on the user to tweak meta-parameters when using the package.
In terms of displaying the above outputs: for each quantity being investigated (learning rate, distance of initial guess from solution, etc.), we should plot this along the x-axis against the distance between the found and analytic solutions.
We can do similar for the number of iterations used.
Use different marker colours to denote whether or not the solver converged (and if it did, whether it found a minimiser).
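One possible shape for such a plot, assuming the solver runs have been collected into a list of dictionaries (the keys here are illustrative, not fixed by the package), using matplotlib:

```python
import matplotlib
matplotlib.use("Agg")  # render to file; no display needed
import matplotlib.pyplot as plt

# Hypothetical results: each entry records one solver run.
results = [
    {"initial_distance": 0.1, "solution_distance": 1e-4,
     "converged": True, "is_minimiser": True},
    {"initial_distance": 1.0, "solution_distance": 5e-3,
     "converged": True, "is_minimiser": False},
    {"initial_distance": 5.0, "solution_distance": 2.0,
     "converged": False, "is_minimiser": False},
]


def colour(run):
    # Colour encodes the outcome: converged to a minimiser, converged
    # to some other stationary point, or failed to converge.
    if not run["converged"]:
        return "red"
    return "green" if run["is_minimiser"] else "orange"


fig, ax = plt.subplots()
for run in results:
    ax.scatter(run["initial_distance"], run["solution_distance"],
               color=colour(run))
ax.set_xlabel("Distance of initial guess from analytic solution")
ax.set_ylabel("Distance of found solution from analytic solution")
fig.savefig("stability_initial_guess.png")
```

The same loop works unchanged for the iteration-count plot by swapping the y-axis quantity.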
How to Determine the Stationary Points
The naive trick that @willGraham01 came up with to find the stationary points of
However, other methods may be possible, such as the Quadratic Penalty method.
If we're feeling particularly bold, we could even try some 2nd-order derivative method like Newton-Raphson, but that would likely not work since it isn't designed to be used with stochastic functions.
It's worth us attempting to implement at least one other method in addition to the naive method above, then repeating the above analysis for each solution method.
However, we should also add the runtime of the method to the list of things to record and display (since number of iterations will no longer be a suitable measure of "time" on its own, when comparing across methods).
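For reference, a quadratic penalty sketch on the same kind of toy problem: replace the constrained problem with a sequence of unconstrained minimisations of $f(x) + \tfrac{\mu}{2}\,c(x)^2$ for increasing $\mu$. Everything below is illustrative (function names, schedule, and step sizes are assumptions, not the package's API):

```python
import jax
import jax.numpy as jnp


def quadratic_penalty_solve(objective, constraint, x0, mu0=1.0,
                            mu_factor=5.0, n_outer=4,
                            lr=1e-2, n_inner=2000):
    """Minimise objective(x) subject to constraint(x) == 0 by repeatedly
    minimising objective(x) + (mu/2) * constraint(x)**2 for growing mu."""

    def penalised(x, mu):
        return objective(x) + 0.5 * mu * constraint(x) ** 2

    grad = jax.jit(jax.grad(penalised, argnums=0))
    x, mu = x0, mu0
    for _ in range(n_outer):
        for _ in range(n_inner):  # inner unconstrained minimisation
            x = x - lr * grad(x, mu)
        # Tighten the penalty before the next outer pass. NOTE: with a
        # fixed lr the inner problem becomes stiff as mu grows; a real
        # implementation should shrink lr (or use a line search) with mu.
        mu *= mu_factor
    return x


# Toy check: minimise x^2 subject to x - 1 = 0 (exact solution x = 1);
# the penalty iterates approach x = 1 from below as mu increases.
x_star = quadratic_penalty_solve(lambda x: x[0] ** 2,
                                 lambda x: x[0] - 1.0,
                                 jnp.array([0.0]))
```

A runtime comparison against the Lagrangian-based method would then slot into the same recording/plotting machinery as above.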
Where to Work
Since these benchmarks may prove useful to retain and re-run in the future, but we don't want to be running them on CI every push, I suggest creating a benchmarks folder at the repository root level.
Any scripts and/or supporting functions that are required for the analysis here can then be placed into that folder.
Pull requests can be used to add/update these scripts as needed.
Tasks
- Complete the "Getting Started" section.
- Create a function that sets up the Lagrangian function (and a function that computes the norm of its gradient), and returns these values.
- Create a function that, given suitable inputs for the meta-parameters, sets up the above problem and attempts to solve it, returning the quantities of interest (iterations performed, the found solution, etc).
- Throughout the above, flag (by opening issues) anything about the codebase that is unclear or misleading. Particularly with regards to how to use something, or what something is doing that runs counter to expectations.
- (Optional, but likely quite useful for the package itself) Set up this problem and its solution as an integration test in `tests/test_integration`, operating in a similar way to the currently present `tests/test_integration/test_two_normal_example.py`. "Test solve with effect handlers" might be an appropriate name to give here.
- Set up script(s) to perform the stability analysis, outputting the graphs as appropriate (they can be saved to the `benchmarks` folder for now).
  - We will need to decide on suitable "grid values" over which to run the stability analysis, for each variable. For now, these can just be defined as constants outside the `__name__ == "__main__"` block.
  - We would recommend a single script for each variable that is being varied, though refactoring across the files may help reduce code bloat.
  - For each variable, report findings inside the PR body that introduces that script.
- Explore and implement other methods for solving the optimisation problem.
  - Again, first implement any new solver methods as scripts to allow for running of the analysis.
  - If we are happy with the results that we see, we can then move the methods into the package proper.