Is Your Feature Request Related to a Problem? Please Describe
Related to #38
We are going to need a jax-friendly way of solving the minimisation problem: minimise $f(\theta)$ subject to inequality constraints on $g(\theta)$.
Describe the Solution You'd Like
A method (either written ourselves or from an external library) that can solve the above problem. We'd probably need to use the KKT conditions (the generalisation of the method of Lagrange multipliers to inequality constraints), but we can make the small simplification that there are no equality constraints.
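For reference, assuming the standard convention of minimising $f(\theta)$ subject to $g(\theta) \leq 0$ (the sign convention here is an assumption, not something fixed by this issue), the KKT conditions with no equality constraints read

$$
\nabla_\theta f(\theta) + \lambda^\top \nabla_\theta g(\theta) = 0, \qquad
g(\theta) \leq 0, \qquad
\lambda \geq 0, \qquad
\lambda^\top g(\theta) = 0.
$$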
We could do this manually ourselves, following a very naive method that ignores the stochastic element of how we evaluate $f$ and $g$ (a rough jax sketch of these steps is given after the list):
1. Assemble the augmented Lagrangian $\mathcal{L}(\theta, \lambda)$ as a callable, jax-friendly Python function. (This should be easy enough if we have a way of evaluating $f$ and $g$.)
2. Set $\mathbf{x} = (\theta, \lambda)$ and $\mathbf{F}(\mathbf{x}) := \nabla\mathcal{L}(\mathbf{x})$, again in a jax-friendly way.
3. We solve $\mathbf{F}(\mathbf{x}) = 0$. This will require something like Newton-Raphson, where we take an initial guess $\mathbf{x}_0$ and iteratively solve $\nabla\mathbf{F}(\mathbf{x}_k)\,(\mathbf{x}_{k+1} - \mathbf{x}_k) = -\mathbf{F}(\mathbf{x}_k)$ for $\mathbf{x}_{k+1}$.
4. Given the solution $\mathbf{x}_{sol} = (\theta_{sol}, \lambda_{sol})$ of part (3), we then check that $\lambda_{sol} \geq 0$. If this is true, then $\theta_{sol}$ is a (local) minimiser.
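A minimal sketch of steps 1–4, assuming $f$ and $g$ are already available as deterministic jax functions of a flat parameter vector $\theta$ and a single scalar constraint (the concrete `f`, `g` and the initial guess below are illustrative placeholders, not existing code):

```python
import jax
import jax.numpy as jnp

# Illustrative stand-ins for our real objective and (single) inequality constraint.
def f(theta):
    return jnp.sum((theta - 1.0) ** 2)

def g(theta):
    return jnp.sum(theta) - 1.0  # assumed convention: feasible when g(theta) <= 0

# Step 1: the Lagrangian as a jax-friendly callable of x = (theta, lambda).
def lagrangian(x):
    theta, lam = x[:-1], x[-1]
    return f(theta) + lam * g(theta)

# Step 2: F(x) := grad of the Lagrangian; its Jacobian is the Hessian of L.
F = jax.grad(lagrangian)
jac_F = jax.jacfwd(F)

# Step 3: naive Newton-Raphson on F(x) = 0.
def newton_solve(x0, n_steps=50):
    x = x0
    for _ in range(n_steps):
        step = jnp.linalg.solve(jac_F(x), -F(x))
        x = x + step
    return x

x0 = jnp.concatenate([jnp.zeros(3), jnp.array([1.0])])  # initial guess for (theta, lambda)
x_sol = newton_solve(x0)
theta_sol, lam_sol = x_sol[:-1], x_sol[-1]

# Step 4: check the sign of the multiplier.
print(theta_sol, lam_sol, bool(lam_sol >= 0))
```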
This has a number of problems:
- The stochastic manner in which we evaluate $f$ and $g$ is ignored.
- It requires an initial guess of the suspected optimal parameters, but also of the Lagrange multiplier $\lambda$.
- It requires computing the Hessian of $\mathcal{L}$!
Describe Alternatives You've Considered
In terms of other algorithms we could consider to solve the minimisation problem - setting the stochasticity aside - SLSQP (Sequential Least SQuares Programming) is an option. scipy.optimize has an implementation, and there is a paper describing the method, though we'd need to write it ourselves again since jax.scipy.optimize does not have a corresponding implementation.
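For reference, the existing (non-jax) interface we would be mirroring looks roughly like this; `f` and `g` are illustrative placeholders, and scipy's `'ineq'` convention is that constraints are satisfied when the function is non-negative:

```python
import numpy as np
from scipy.optimize import minimize

# Illustrative objective and constraint; scipy's 'ineq' convention is fun(theta) >= 0.
def f(theta):
    return np.sum((theta - 1.0) ** 2)

def g(theta):
    return 1.0 - np.sum(theta)  # feasible when sum(theta) <= 1

result = minimize(
    f,
    x0=np.zeros(3),
    method="SLSQP",
    constraints=[{"type": "ineq", "fun": g}],
)
print(result.x, result.success)
```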
In terms of addressing the stochasticity elephant, in place of part (3) we could try solving $\min_{\mathbf{x}} \|\mathbf{F}(\mathbf{x})\|^2$ as an unconstrained minimisation problem. We could use stochastic gradient descent methods like ADAM, which we know can be made to work with jax functions. However, we would not be allowed to stop until we found a global minimum (of 0), which opens up another can of worms.
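A sketch of that alternative, assuming we use `optax` for the Adam optimiser (the choice of `optax`, the loss $\|\mathbf{F}(\mathbf{x})\|^2$, and the illustrative `f`, `g` and initial guess are all assumptions, not decisions made in this issue):

```python
import jax
import jax.numpy as jnp
import optax

# Illustrative Lagrangian, as in the earlier sketch, with x = (theta, lambda).
def f(theta):
    return jnp.sum((theta - 1.0) ** 2)

def g(theta):
    return jnp.sum(theta) - 1.0

def lagrangian(x):
    theta, lam = x[:-1], x[-1]
    return f(theta) + lam * g(theta)

# F(x) := grad of the Lagrangian; the loss ||F(x)||^2 has global minimum 0
# exactly at stationary points of the Lagrangian.
F = jax.grad(lagrangian)

def loss(x):
    return jnp.sum(F(x) ** 2)

optimizer = optax.adam(learning_rate=1e-2)
x = jnp.concatenate([jnp.zeros(3), jnp.array([1.0])])  # initial guess for (theta, lambda)
opt_state = optimizer.init(x)

@jax.jit
def step(x, opt_state):
    value, grads = jax.value_and_grad(loss)(x)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(x, updates), opt_state, value

for _ in range(5000):
    x, opt_state, value = step(x, opt_state)
    if value < 1e-10:  # only a global minimum (loss ~ 0) counts as converged
        break
```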
Additional Context
Context for Our Use Case
Note that in our particular use case, dict) of our parameters to the causal model,
where