# Solving a Causal Problem with `causalprog`

This notebook demonstrates how to use `causalprog` to compute bounds on the causal estimand for a simple causal problem.


## Problem Setup and Analytic Solution

We will assume that we have the following random variables:

\begin{align*}
X \sim \mathcal{N}(\mu_X, 1.0),
&\quad 
Y \mid X \sim \mathcal{N}(X, \nu_Y)
\end{align*}

where $\mu_X$ and $\nu_Y$ are the parameters for our problem.

**Note:** to disambiguate between terms, `causalprog` refers to $\mu_X$ and $\nu_Y$ as the **model parameters**.
Model parameters are the variables that fully define the model described in the causal problem (equivalently, the variables that parametrise the RVs appearing in the DAG).

Our causal estimand $\sigma$ is given by

\begin{align*}
\sigma(\mu_X, \nu_Y) = \mathbb{E}[Y],
\end{align*}

which we can compute analytically using the tower property: $\sigma(\mu_X, \nu_Y) = \mathbb{E}[\mathbb{E}[Y \mid X]] = \mathbb{E}[X] = \mu_X$.
The quantity that our model predicts for the observed data $\phi_{obs}$ (and which forms our constraint) is given by

\begin{align*}
\phi(\mu_X, \nu_Y) = \mathbb{E}[X],
\end{align*}

where again we can analytically compute that $\phi(\mu_X, \nu_Y) = \mu_X$.
We'll also assume that we have some tolerance $\epsilon$ in our data.
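As a quick sanity check on these analytic results, we can estimate $\phi = \mathbb{E}[X]$ and $\sigma = \mathbb{E}[Y]$ by Monte Carlo sampling. This sketch samples the model directly with JAX (it does not use `causalprog`), treats $\nu_Y$ as the standard deviation of $Y \mid X$ (the role it plays as numpyro's `scale` argument in the graph built below), and uses arbitrary illustrative values of $\mu_X$ and $\nu_Y$.

```python
import jax


def estimate_phi_sigma(mu_x: float, nu_y: float, key: jax.Array, n: int = 100_000):
    """Monte Carlo estimates of phi = E[X] and sigma = E[Y] for the two-normal model."""
    key_x, key_y = jax.random.split(key)
    # X ~ N(mu_X, 1)
    x = mu_x + jax.random.normal(key_x, (n,))
    # Y | X ~ N(X, nu_Y), with nu_Y used as the standard deviation
    y = x + nu_y * jax.random.normal(key_y, (n,))
    return x.mean(), y.mean()


key = jax.random.key(0)
for nu_y in (0.5, 2.0):
    phi_hat, sigma_hat = estimate_phi_sigma(1.5, nu_y, key)
    # Both estimates should be close to mu_X = 1.5, whatever the value of nu_Y
    print(f"nu_Y={nu_y}: phi ~ {float(phi_hat):.3f}, sigma ~ {float(sigma_hat):.3f}")
```

Changing $\nu_Y$ only changes the spread of $Y$, not its mean, which is why neither quantity depends on it.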

Thus, we're aiming to solve the following optimisation problems:

\begin{align*}
\max/\min_{\mu_X, \nu_Y} \sigma
&= \max/\min_{\mu_X, \nu_Y} \mu_X, \\
\text{subject to } \quad \vert \phi(\mu_X, \nu_Y) - \phi_{obs} \vert &= \vert \mu_X - \phi_{obs} \vert \leq \epsilon.
\end{align*}

The solution to this is $\mu_X^{*} = \phi_{obs} \pm \epsilon$ (with addition in the maximisation case).
The value of $\nu_Y$ can be any positive value, since both $\phi$ and $\sigma$ are independent of it.
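The analytic solution can also be confirmed by brute force: evaluate $\sigma = \mu_X$ on a fine grid of $\mu_X$ values, keep only the feasible points, and take the extreme values. This sketch uses the values $\phi_{obs} = 0$ and $\epsilon = 1$ that appear later in the notebook, and omits $\nu_Y$ because neither $\sigma$ nor $\phi$ depends on it.

```python
import numpy as np

PHI_OBS, EPS = 0.0, 1.0

# Candidate values of mu_X; the estimand and the model quantity both equal mu_X.
mu_grid = np.linspace(-3.0, 3.0, 6001)
sigma = mu_grid
phi = mu_grid

# Keep only points satisfying |phi - phi_obs| <= eps (a tiny slack absorbs
# floating-point round-off exactly at the boundary).
feasible = np.abs(phi - PHI_OBS) <= EPS + 1e-12

lower, upper = sigma[feasible].min(), sigma[feasible].max()
print(lower, upper)  # close to phi_obs - eps and phi_obs + eps
```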

The corresponding Lagrangians are: 

\begin{align*}
\mathcal{L}_{\min}(\mu_X, \nu_Y, \lambda) =  \mu_X
+ \lambda(|\mu_X - \phi_{obs}| - \epsilon), \\
\mathcal{L}_{\max}(\mu_X, \nu_Y, \lambda) = - \mu_X
+ \lambda(|\mu_X - \phi_{obs}| - \epsilon)
\end{align*}

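where we require that $\lambda \geq 0$ in each case (dual feasibility), so that the stationary point corresponds to the bound we are seeking rather than the opposite one.
The KKT (primal-dual) solutions are $(\mu_X^*, \nu_Y, \lambda^*) = (\phi_{obs} \pm \epsilon, \nu_Y, 1)$ for any $\nu_Y > 0$, again with addition in the maximisation case.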
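These KKT solutions can be checked with automatic differentiation. The sketch below writes the two Lagrangians exactly as above (dropping $\nu_Y$, on which they do not depend), picks the illustrative values $\phi_{obs} = 2$ and $\epsilon = 0.5$, and uses `jax.grad` to confirm that both partial derivatives vanish at $(\mu_X^*, \lambda^*) = (\phi_{obs} \pm \epsilon, 1)$.

```python
import jax
import jax.numpy as jnp

PHI_OBS, EPS = 2.0, 0.5  # illustrative values


def lagrangian_min(mu_x, lam):
    # L_min = mu_X + lam * (|mu_X - phi_obs| - eps)
    return mu_x + lam * (jnp.abs(mu_x - PHI_OBS) - EPS)


def lagrangian_max(mu_x, lam):
    # L_max = -mu_X + lam * (|mu_X - phi_obs| - eps)
    return -mu_x + lam * (jnp.abs(mu_x - PHI_OBS) - EPS)


# Partial derivatives with respect to (mu_X, lambda) at the analytic solutions
grad_min = jax.grad(lagrangian_min, argnums=(0, 1))(PHI_OBS - EPS, 1.0)
grad_max = jax.grad(lagrangian_max, argnums=(0, 1))(PHI_OBS + EPS, 1.0)
print(grad_min, grad_max)  # every partial derivative is zero
```

The derivative with respect to $\lambda$ is just the constraint value $\vert \mu_X - \phi_{obs} \vert - \epsilon$, so its vanishing says the constraint is active at the solution.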

In this notebook, with assistance from `causalprog`, we will attempt to find these solutions using the naive approach of minimising $\| \nabla \mathcal{L} \|_2^2$, i.e. searching for a point at which the gradient of the Lagrangian vanishes.

## Building the Graph

To set up a problem in `causalprog`, we need to construct the Directed Acyclic Graph (DAG) that represents the causal problem.
The `Graph` class is how DAGs are represented in `causalprog`.

In [None]:
from causalprog.graph import Graph

# Initialises the graph. Note that this graph currently doesn't
# have any nodes or edges in it!
graph = Graph(label="two_normal_graph")

### Adding Nodes

For each parameter and random variable (RV) in our causal problem, we need a node to represent it.

1. **Model Parameters** are represented with `ParameterNode`s. The model parameters are the set of variables that fully parametrise the causal problem (equivalently, the RVs that appear in the DAG).
   - In our example, these are the values $\mu_X$ and $\nu_Y$.
   - For each of these, we add a `ParameterNode`. 
   - Model parameters that are referenced in the `parameters` dictionary of a `DistributionNode` will be used when constructing the `DistributionNode`'s RV.

2. **Derived (RV) Parameters** are parameters of RVs that are the result of sampling from a previous distribution, or take the value of a model parameter.
   - In our example, the mean of $Y$ is such a quantity. This quantity is determined by the RV $X$.
   - Also in our example, the mean of $X$ is such a quantity, being determined by $\mu_X$ (a model parameter).
   - Derived parameters should be provided to the `parameters` argument of a `DistributionNode`, and are stored in its eponymous attribute.

3. **Constant (RV) Parameters** are constant values that appear in the problem, but are still needed to construct some of the RVs.
   - In our example, the (co)variance of $X$ is a constant parameter, taking the value $1$.
   - These values should be passed to the `constant_parameters` argument of a `DistributionNode`, and are stored in its eponymous attribute.
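To make the three kinds of parameter concrete, here is a hypothetical, dictionary-based sketch of how a node's RV could be assembled and sampled. It is not `causalprog`'s implementation: the `NODES` layout and the `sample_graph` helper are illustrative assumptions. Constant parameters are used as-is, while each derived parameter is looked up by label and resolves either to the samples of an upstream RV or to the value of a model parameter. The values $\mu_X = 2$ and $\nu_Y = 0.5$ are arbitrary.

```python
import jax

# Hypothetical description of the two RVs (NOT causalprog's API). Each entry in
# "parameters" maps a distribution argument to the label of another node.
NODES = {
    "X": {"parameters": {"loc": "mu_X"}, "constant_parameters": {"scale": 1.0}},
    "Y": {"parameters": {"loc": "X", "scale": "nu_Y"}, "constant_parameters": {}},
}


def sample_graph(model_params, key, n_samples=1000, order=("X", "Y")):
    """Ancestral sampling of normal RVs, visiting nodes in a valid order."""
    samples = {}
    for label in order:
        node = NODES[label]
        # Constant parameters are copied straight in.
        kwargs = dict(node["constant_parameters"])
        for arg, ref in node["parameters"].items():
            # Derived parameters: samples of an upstream RV take priority,
            # otherwise the reference names a model parameter.
            kwargs[arg] = samples[ref] if ref in samples else model_params[ref]
        key, subkey = jax.random.split(key)
        noise = jax.random.normal(subkey, (n_samples,))
        samples[label] = kwargs["loc"] + kwargs["scale"] * noise
    return samples


samples = sample_graph({"mu_X": 2.0, "nu_Y": 0.5}, jax.random.key(1))
print(samples["X"].mean(), samples["Y"].mean())  # both close to mu_X = 2.0
```

Note how $Y$'s `loc` is an entire array of $X$ samples: this is what makes $Y$ conditional on $X$.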



In [None]:
from numpyro.distributions import Normal

from causalprog.graph import DistributionNode, ParameterNode

graph.add_node(ParameterNode(label="mu_X"))
graph.add_node(
    DistributionNode(
        distribution=Normal,
        label="X",
        constant_parameters={"scale": 1},  # Std. dev. (and variance) of X is 1
        parameters={"loc": "mu_X"},  # Mean of X is given by mu_X
    )
)

graph.add_node(ParameterNode(label="nu_Y"))
graph.add_node(
    DistributionNode(
        distribution=Normal,
        label="Y",
        parameters={
            "loc": "X",  # Mean of Y determined by the RV X
            "scale": "nu_Y",  # Scale (standard deviation) of Y is given by nu_Y
        },
    )
)

### Adding Edges

We must add edges between nodes so that the graph picks up the dependencies between model parameters and the RVs they define, and between RVs and the downstream RVs whose derived parameters reference them.
There is no need to add nodes for constant parameters (nor a need for edges to have these constants picked up by the nodes that use them).

Note that edges should be directed **into** the dependent RV.
That is, `DistributionNode`s should have edges directed into them from the nodes referenced by their derived parameters.
These nodes may be `ParameterNode`s: for example, our RV $X$ needs an edge from the `ParameterNode` for $\mu_X$ into the `DistributionNode` representing $X$.
They may also be other `DistributionNode`s; for example, $Y$ has $X$ as a derived parameter since we know that $Y \vert X \sim \mathcal{N}(X, \nu_Y)$.
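Directing edges into the dependent RV is what allows the graph to be sampled in a valid order: every node's inputs are available before the node itself is visited. As a `causalprog`-independent illustration, Python's standard-library `graphlib.TopologicalSorter` recovers such an order from the same three edges added in the cell that follows (it expects each node mapped to the nodes it depends on).

```python
from graphlib import TopologicalSorter

# Same edges as in the notebook, written as (source, target) pairs.
edges = [("X", "Y"), ("mu_X", "X"), ("nu_Y", "Y")]

# TopologicalSorter takes a mapping from each node to its predecessors.
predecessors: dict[str, set[str]] = {}
for source, target in edges:
    predecessors.setdefault(target, set()).add(source)

order = list(TopologicalSorter(predecessors).static_order())
print(order)  # every node appears after all of its parents
```

Reversing any edge (e.g. `("Y", "X")`) would create a cycle with `X -> Y`, and `static_order()` would raise `graphlib.CycleError`.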

In [None]:
# We know Y | X, so the edge is directed X -> Y
graph.add_edge("X", "Y")

# Parameters should point to the RVs they are used to define
graph.add_edge("mu_X", "X")
graph.add_edge("nu_Y", "Y")

## Defining the Problem

We can now use our graph to define the problem we wish to solve.
To do so, we define our constraints using the `Constraint` class, and our causal estimand with the `CausalEstimand` class.  

Each `Constraint` requires a function to calculate the constrained quantity, e.g. $\mathbb{E}[X]$, from samples drawn from the RVs of the graph.
In the function, we reference nodes with the labels we assigned previously (`"X"` and `"Y"`).  

Likewise, each `CausalEstimand` requires a function to calculate the causal estimand, $\mathbb{E}[Y]$, from samples drawn from the RVs of the graph.
Again, we use the labels we assigned previously.

We then pass our `Constraint`s and `CausalEstimand` to the `CausalProblem` class.  

In [None]:
from causalprog.causal_problem.causal_problem import CausalProblem
from causalprog.causal_problem.components import CausalEstimand, Constraint

# Define the constraint using observed data and a tolerance level.
PHI_OBSERVED = 0.0
EPSILON = 1
constraint = Constraint(
    model_quantity=lambda **pv: pv["X"].mean(),
    data=PHI_OBSERVED,
    tolerance=EPSILON,
)

# Define the causal estimand
causal_estimand = CausalEstimand(do_with_samples=lambda **pv: pv["Y"].mean())

# Define the problem using the graph, constraint, and causal estimand.
causal_problem = CausalProblem(
    graph,
    constraint,
    causal_estimand=causal_estimand,
)

## Solving the Causal Problem

Now we calculate the bounds of the causal estimand.
To do this, we'll seek the stationary points of the Lagrangian, using the naive approach of minimising $\| \nabla \mathcal{L} \|_2^2$.

`causalprog` provides a few out-of-the-box solvers that we can utilise, and the `CausalProblem` class can construct some helpful functions for us to use, like the Lagrangian $\mathcal{L}$ for our problem!
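It can help to see what such a Lagrangian might look like for this problem. The following is a hypothetical, pure-JAX stand-in for `causal_problem.lagrangian(...)`, not `causalprog`'s source; its behaviour is an assumption based on the Lagrangians written out at the top of this notebook. It draws reparametrised samples of $X$ and $Y$ from `rng_key`, evaluates functions mirroring the `Constraint` and `CausalEstimand` defined above, negates the estimand for the maximisation problem, and is called as `lagrangian(params, lam, rng_key=...)`, like in the next cell. The naive objective is the squared L2-norm of its gradient with respect to both the parameters and $\lambda$.

```python
import jax
import jax.numpy as jnp

PHI_OBSERVED, EPSILON = 0.0, 1.0


# Same functions of the samples as the Constraint and CausalEstimand above.
def model_quantity(**pv):
    return pv["X"].mean()


def estimand(**pv):
    return pv["Y"].mean()


def make_lagrangian(n_samples: int, maximum_problem: bool):
    """Hypothetical Monte Carlo Lagrangian for the two-normal problem."""
    sign = -1.0 if maximum_problem else 1.0

    def lagrangian(params, lam, *, rng_key):
        key_x, key_y = jax.random.split(rng_key)
        # Reparametrised samples, so gradients flow through to mu_X and nu_Y
        x = params["mu_X"] + jax.random.normal(key_x, (n_samples,))
        y = x + params["nu_Y"] * jax.random.normal(key_y, (n_samples,))
        pv = {"X": x, "Y": y}
        g = jnp.abs(model_quantity(**pv) - PHI_OBSERVED) - EPSILON
        return sign * estimand(**pv) + jnp.sum(lam * g)

    return lagrangian


def naive_objective(x, rng_key, lagrangian):
    # Squared L2-norm of the gradient w.r.t. (params, lambda)
    grads = jax.grad(lagrangian, argnums=(0, 1))(*x, rng_key=rng_key)
    return sum(jnp.sum(leaf**2) for leaf in jax.tree_util.tree_leaves(grads))


lag_min = make_lagrangian(n_samples=10_000, maximum_problem=False)
x_star = ({"mu_X": -1.0, "nu_Y": 1.0}, jnp.atleast_1d(1.0))
# Small at the true minimiser, though Monte Carlo noise keeps it above zero
print(naive_objective(x_star, jax.random.key(0), lag_min))
```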

In [None]:
import jax
import jax.numpy as jnp
import numpy.typing as npt
import optax

from causalprog.solvers.sgd import stochastic_gradient_descent
from causalprog.solvers.solver_callbacks import tqdm_callback
from causalprog.solvers.solver_result import SolverResult
from causalprog.utils.norms import l2_normsq

# Define our initial guess for the decision variables and Lagrange multiplier
LAGRANGE_MULTIPLIER_INIT = jnp.atleast_1d(0.5)
INIT_PARAMS = {
    "mu_X": 0.0,
    "nu_Y": 1.0,
}

# Choose stochastic gradient descent settings
RNG_KEY = jax.random.key(42)
LEARNING_RATE = 1.0e-1
MAX_OPTIMISER_ITER = 200
optimiser = optax.adam(LEARNING_RATE)


# Define a function to find a bound using a naive Lagrangian approach
def naive_lagrangian_approach(maximum_problem: bool) -> SolverResult:
    # Have the CausalProblem class construct the Lagrangian for us.
    # We can specify here whether we are seeking the maximum or minimum
    # bound for the causal estimand.
    lagrangian = causal_problem.lagrangian(
        n_samples=300,
        maximum_problem=maximum_problem,
    )

    # Define the objective function for optimisation.
    # Our naive approach requires us to minimise the squared L2-norm of the
    # Lagrangian's gradient.
    def objective(x: npt.ArrayLike, key: jax.Array) -> npt.ArrayLike:
        v = jax.grad(lagrangian, argnums=(0, 1))(*x, rng_key=key)
        return l2_normsq(v)

    # Invoke the SGD solver to return the solution.
    return stochastic_gradient_descent(
        obj_fn=objective,
        initial_guess=(INIT_PARAMS, LAGRANGE_MULTIPLIER_INIT),
        fn_kwargs={"key": RNG_KEY},
        maxiter=MAX_OPTIMISER_ITER,
        optimiser=optimiser,
        history_logging_interval=1,  # Log optimisation history every iteration
        callbacks=tqdm_callback(MAX_OPTIMISER_ITER),  # Add a progress bar
    )


# Calculate the bounds
max_result = naive_lagrangian_approach(maximum_problem=True)
min_result = naive_lagrangian_approach(maximum_problem=False)

### Interpreting the Results

`causalprog`'s solvers, like `stochastic_gradient_descent`, return a `SolverResult` object.
The `SolverResult` stores helpful information about the optimisation process.  

The attributes of `SolverResult` that are most relevant are:

- `fn_args`: The objective arguments at the final iteration. 
- `obj_val`: The objective value at `fn_args`.
- `fn_args_history`: The history of `fn_args` at each logged iteration of the optimisation.
- `obj_val_history`: The history of `obj_val` at each logged iteration of the optimisation.

We can inspect these attributes to display the results of our optimisation.

In [None]:
import numpy as np

from causalprog.solvers.solver_result import SolverResult


# Since we want to print results for both the maximisation and minimisation problem,
# we'll define a function to print the results of the SGD runs given the SolverResult
def print_sgd_results(result: SolverResult, maximum_problem: bool = True) -> None:
    """Print the results of the stochastic gradient descent runs."""
    if maximum_problem:
        print("-------------Max Bound Results-------------")
    else:
        print("-------------Min Bound Results-------------")
    mindex = np.argmin(result.obj_val_history)
    print(f"Best Parameters: {result.fn_args_history[mindex][0]}")
    print(f"Best Lagrange Multiplier: {result.fn_args_history[mindex][1][0]}")
    print(f"Best Objective Value: {result.obj_val_history[mindex]}\n")


print_sgd_results(max_result, maximum_problem=True)
print_sgd_results(min_result, maximum_problem=False)

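Notice that the result we found for the minimum bound is incorrect.
This is because, with $\phi_{obs} = 0$ and $\epsilon = 1$, both $(\mu_X, \lambda) = (-1, 1)$ (the true minimiser) and $(\mu_X, \lambda) = (1, -1)$ give $\| \nabla \mathcal{L}_{\min} \|_2^2 = 0$. The second point violates $\lambda \geq 0$, but the naive objective has no way to penalise this.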
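The closed-form $\mathcal{L}_{\min}$ makes this easy to check. The short JAX sketch below uses the analytic Lagrangian rather than `causalprog`'s sampled version, and evaluates the naive objective at both points; for contrast, it also evaluates $(\mu_X, \lambda) = (1, 1)$, which is not stationary.

```python
import jax
import jax.numpy as jnp

PHI_OBS, EPS = 0.0, 1.0


def lagrangian_min(mu_x, lam):
    # Closed-form L_min = mu_X + lam * (|mu_X - phi_obs| - eps)
    return mu_x + lam * (jnp.abs(mu_x - PHI_OBS) - EPS)


def naive_objective(mu_x, lam):
    # ||grad L_min||_2^2 over (mu_X, lambda)
    d_mu, d_lam = jax.grad(lagrangian_min, argnums=(0, 1))(mu_x, lam)
    return d_mu**2 + d_lam**2


for mu_x, lam in [(-1.0, 1.0), (1.0, -1.0)]:
    # Both print 0.0; only the second point violates lam >= 0
    print((mu_x, lam), float(naive_objective(mu_x, lam)), "dual feasible:", lam >= 0)

print(float(naive_objective(1.0, 1.0)))  # 4.0: d(L_min)/d(mu_X) = 2 here
```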

Clearly this is not ideal. 
To combat this, we could retry the search with multiple starting points, select the best one we find and hope it's optimal.

Or we could try something else...

### Solving in a Smarter Way

`causalprog` provides us with an easy way to assemble functions that are relevant to the optimisation problem we need to solve.
But we are free to further manipulate these functions to suit our particular problem.

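Our problem is convex, finite, and satisfies Slater's condition (for example, $\mu_X = \phi_{obs}$ is strictly feasible), so the KKT conditions (stationarity, primal feasibility, dual feasibility, and complementary slackness) are both necessary and sufficient for optimality.
We can therefore use an objective function that penalises violations of each KKT condition directly: if all four conditions are met, our solution will be optimal.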
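To see why this avoids the spurious point found above, here is a closed-form version of the same four penalty terms for $\mathcal{L}_{\min}$ with $\phi_{obs} = 0$, $\epsilon = 1$ and unit weights: a sketch that mirrors the `obj` function in the next cell, but without sampling. Complementary slackness uses the Fischer-Burmeister function $\sqrt{a^2 + b^2} - a - b$, which is zero exactly when $a \geq 0$, $b \geq 0$ and $ab = 0$; with $a = \lambda$ and $b = -g$ (where $g = \vert \mu_X - \phi_{obs} \vert - \epsilon$ is the constraint value) it encodes complementary slackness together with both sign conditions.

```python
import jax
import jax.numpy as jnp

PHI_OBS, EPS = 0.0, 1.0


def lagrangian_min(mu_x, lam):
    # Closed-form L_min = mu_X + lam * (|mu_X - phi_obs| - eps)
    return mu_x + lam * (jnp.abs(mu_x - PHI_OBS) - EPS)


def kkt_residual(mu_x, lam):
    # d_mu is the stationarity term; dL/dlam equals the constraint value g
    d_mu, g = jax.grad(lagrangian_min, argnums=(0, 1))(mu_x, lam)
    stat = d_mu**2
    primal = jnp.maximum(g, 0.0) ** 2  # penalise g > 0
    fb = (jnp.sqrt(lam**2 + g**2) - lam + g) ** 2  # Fischer-Burmeister, a=lam, b=-g
    dual = jnp.minimum(lam, 0.0) ** 2  # penalise lam < 0
    return stat + primal + fb + dual


print(float(kkt_residual(-1.0, 1.0)))  # 0.0 at the true minimiser
print(float(kkt_residual(1.0, -1.0)))  # 5.0: complementarity (4) + dual (1) terms
```

Both points are stationary and primal feasible, so it is the complementarity and dual-feasibility penalties that rule out $(1, -1)$.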

In [None]:
from collections.abc import Callable

import optax


def build_kkt_residual_obj(
    cp: CausalProblem,
    *,
    n_samples: int = 1000,
    maximum_problem: bool = False,
    alpha: float = 1.0,  # weight: primal feasibility
    beta: float = 1.0,  # weight: complementary slackness
    gamma: float = 1.0,  # weight: dual feasibility
) -> Callable[..., jax.Array]:
    lagrangian = cp.lagrangian(n_samples=n_samples, maximum_problem=maximum_problem)

    def obj(x: npt.ArrayLike, *, rng_key: jax.Array) -> jax.Array:
        params, lam = x
        grad_theta, g = jax.grad(lagrangian, argnums=(0, 1))(
            params, lam, rng_key=rng_key
        )

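        # 1) Stationarity: gradient of the Lagrangian w.r.t. the model parameters
        stat = l2_normsq(grad_theta)

        # 2) Primal feasibility: penalise g > 0, where g = dL/dlam is the
        #    constraint value |phi - phi_obs| - epsilon
        primal = jnp.sum(jnp.maximum(g, 0.0) ** 2)

        # 3) Complementary slackness: Fischer-Burmeister function with
        #    a = lam and b = -g, zero iff lam >= 0, g <= 0 and lam * g = 0
        fb = jnp.sqrt(lam**2 + g**2) - lam + g
        fb = jnp.sum(fb**2)

        # 4) Dual feasibility: penalise lam < 0
        dual = jnp.sum(jnp.minimum(lam, 0.0) ** 2)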

        return stat + alpha * primal + beta * fb + gamma * dual

    return obj


ALPHA = 1.0  # Primal feasibility
BETA = 1.0  # Complementary slackness
GAMMA = 1.0  # Dual feasibility

optimiser = optax.adam(0.01)


# Define a function to find a bound using the KKT residual approach
def kkt_residual_approach(maximum_problem: bool) -> SolverResult:
    objective = build_kkt_residual_obj(
        causal_problem,
        n_samples=200,
        maximum_problem=maximum_problem,
        alpha=ALPHA,
        beta=BETA,
        gamma=GAMMA,
    )

    return stochastic_gradient_descent(
        obj_fn=objective,
        initial_guess=(INIT_PARAMS, LAGRANGE_MULTIPLIER_INIT),
        convergence_criteria=lambda x, _: jnp.abs(x),
        fn_kwargs={"rng_key": RNG_KEY},
        maxiter=1000,
        tolerance=1e-5,  # We can set a tolerance for which convergence is successful
        optimiser=optimiser,
        history_logging_interval=1,
        callbacks=[tqdm_callback(1000)],
    )


max_result = kkt_residual_approach(maximum_problem=True)
min_result = kkt_residual_approach(maximum_problem=False)

print_sgd_results(max_result, maximum_problem=True)
print_sgd_results(min_result, maximum_problem=False)

We have found (an approximation to) the correct bounds! 🎉
