Solving the Optimisation Problem #75

@willGraham01

Description

Is Your Feature Request Related to a Problem? Please Describe

Related to #38

We are going to need a jax-friendly way of solving the minimisation problem

$$ \mathrm{min}_{\theta} f(\theta), \quad\text{s.t. } g(\theta) \leq 0. $$

Describe the Solution You'd Like

A method (either written ourselves or from an external library) that can solve the above problem. We'd probably need to use the KKT conditions (the generalisation of the method of Lagrange multipliers to inequality constraints), but can make the small simplification that there are no equality constraints.

We could do this manually ourselves, following a very naive method that ignores the stochastic element of our $f$ and $g$ functions (a rough sketch in code is given after the list):

  1. Assemble the Lagrangian

$$ \mathcal{L}(\theta, \lambda) := f(\theta) + \lambda g(\theta), $$

as a callable, jax-friendly Python function. (This should be easy enough if we have a way of evaluating $f$ and $g$ already).

  2. Set $\mathbf{x} = (\theta, \lambda)$ and $\mathbf{F}(\mathbf{x}) := \nabla\mathcal{L}(\mathbf{x})$, again in a jax-friendly way.

  3. Solve $\mathbf{F}(\mathbf{x}) = \mathbf{0}$. This will require something like Newton-Raphson: take an initial guess $\mathbf{x}_0$ and iteratively solve

$$ \nabla\mathbf{F}(\mathbf{x}_n)\,\mathbf{y} = -\mathbf{F}(\mathbf{x}_n) $$

for $\mathbf{y}$, then update $\mathbf{x}_{n+1} = \mathbf{x}_n + \mathbf{y}$.

  4. Given the solution $\mathbf{x}_{\text{sol}} = (\theta_{\text{sol}}, \lambda_{\text{sol}})$ of step 3, we then check that $\lambda_{\text{sol}} \geq 0$. If this is true, then $\theta_{\text{sol}}$ is a (local) minimiser.
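A minimal, jax-friendly sketch of steps 1–4 might look like the following. The quadratic objective and linear constraint are placeholders standing in for our real $f$ and $g$, and the initial guess and fixed iteration count are illustrative choices only.

```python
import jax
import jax.numpy as jnp


def f(theta):
    # Placeholder objective; stands in for the real (stochastic) f.
    return jnp.sum((theta - 1.0) ** 2)


def g(theta):
    # Placeholder scalar constraint, interpreted as g(theta) <= 0.
    return jnp.sum(theta) - 1.0


def lagrangian(x):
    # Step 1: L(theta, lambda) = f(theta) + lambda * g(theta), with x = (theta, lambda).
    theta, lam = x[:-1], x[-1]
    return f(theta) + lam * g(theta)


# Step 2: F(x) := grad L(x); its Jacobian is the Hessian of L.
F = jax.grad(lagrangian)
jac_F = jax.jacobian(F)


def newton_solve(x0, n_steps=20):
    # Step 3: Newton-Raphson for F(x) = 0, solving jac_F(x_n) y = -F(x_n) at each iteration.
    def body(x, _):
        y = jnp.linalg.solve(jac_F(x), -F(x))
        return x + y, None

    x_sol, _ = jax.lax.scan(body, x0, None, length=n_steps)
    return x_sol


x0 = jnp.array([0.0, 0.0, 1.0])  # initial guess for (theta, lambda)
x_sol = newton_solve(x0)
theta_sol, lam_sol = x_sol[:-1], x_sol[-1]

# Step 4: dual-feasibility check.
print(theta_sol, lam_sol, lam_sol >= 0.0)
```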

This has a number of problems:

  • The stochastic manner in which we evaluate $f$ and $g$ is ignored.
  • It requires an initial guess not only for the suspected optimal parameters, but also for the Lagrange multiplier $\lambda$.
  • It requires computing the Hessian of $\mathcal{L}$!

Describe Alternatives You've Considered

In terms of other algorithms we could consider for the minimisation problem (setting the stochasticity aside), SLSQP (Sequential Least SQuares Programming) is an option. scipy.optimize has an implementation, and there is a paper describing the method, though we'd again need to write it ourselves since jax.scipy.optimize does not have a corresponding implementation.
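For reference, the existing scipy interface looks roughly like the following (the toy objective and constraint are placeholders); the point is precisely that we cannot differentiate through this with jax.

```python
import numpy as np
from scipy.optimize import minimize


def f(theta):
    # Placeholder objective.
    return np.sum((theta - 1.0) ** 2)


def g(theta):
    # Placeholder constraint, g(theta) <= 0.
    return np.sum(theta) - 1.0


result = minimize(
    f,
    x0=np.zeros(2),
    method="SLSQP",
    # scipy's convention is that "ineq" constraints satisfy fun(theta) >= 0,
    # so we pass -g to encode g(theta) <= 0.
    constraints=[{"type": "ineq", "fun": lambda theta: -g(theta)}],
)
print(result.x, result.fun)
```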

In terms of addressing the stochasticity elephant in the room, in place of step 3 we could instead try solving

$$ \mathrm{min}_{\mathbf{x}} \left\Vert \mathbf{F}(\mathbf{x}) \right\Vert^2, $$

as an unconstrained minimisation problem. We could use stochastic gradient descent methods like Adam, which we know can be made to work with jax functions. However, we would not be allowed to stop until we found a global minimum (of 0), which opens up another can of worms.
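Continuing from the sketch above (reusing `F` and the initial guess `x0`), this alternative might look something like the following; optax is one possible jax-compatible optimiser library for Adam (an assumption on our part), and the stopping problem remains.

```python
import jax
import jax.numpy as jnp
import optax


def residual_norm_sq(x):
    # ||F(x)||^2, where F = grad L as defined in the earlier sketch.
    r = F(x)
    return jnp.dot(r, r)


optimiser = optax.adam(learning_rate=1e-2)
opt_state = optimiser.init(x0)
x = x0

for _ in range(5_000):
    grads = jax.grad(residual_norm_sq)(x)
    updates, opt_state = optimiser.update(grads, opt_state)
    x = optax.apply_updates(x, updates)

# We can only trust x as a stationary point of L if this is (numerically) zero.
print(residual_norm_sq(x))
```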

Additional Context

Context for Our Use Case

Note that in our particular use case, $\theta$ is a vector (or dict) of the parameters of our causal model, $f = \sigma$ is the causal estimand, and $g$ will be our (scalar-valued!) constraint function of the form

$$ g = \vert \phi(\theta) - \phi_{\text{obs}} \vert - \epsilon, $$

where $\epsilon$ is the associated tolerance we want to allow in the data, $\phi(\theta)$ are the predicted values of the quantities observed in $\phi_{\text{obs}}$, and $\vert\cdot\vert$ is some appropriate norm.
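For concreteness, the constraint function might be assembled like this, where `phi`, `phi_obs` and `epsilon` are placeholders for the model's predicted observables, the observed data and the tolerance, and the 2-norm is just one possible choice of $\vert\cdot\vert$:

```python
import jax.numpy as jnp


def constraint(theta, phi, phi_obs, epsilon):
    # g(theta) = || phi(theta) - phi_obs || - epsilon, required to be <= 0.
    return jnp.linalg.norm(phi(theta) - phi_obs) - epsilon
```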
