# The Problem

This notebook demonstrates how to use causalprog to approximate bounds for a simple problem.

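For the following distributions (the second argument of $\mathcal{N}$ is the standard deviation, i.e. the `scale` of numpyro's `Normal`)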

$$
\begin{aligned}
X &\sim \mathcal{N}(\mu_{x}, 1.0) \\
Y \mid X &\sim \mathcal{N}(X, \nu_{y})
\end{aligned}
$$

we wish to calculate bounds for the causal estimand

$$ \sigma(\mu_{x}, \nu_{y}) = \mathbb{E}[Y]$$

given observed data (constraints)

$$ \phi(\mu_{x}, \nu_{y}) = \mathbb{E}[X]$$

and tolerance in the observed data $\epsilon$.

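Since $\mathbb{E}[Y] = \mathbb{E}[\mathbb{E}[Y \mid X]] = \mathbb{E}[X] = \mu_{x}$, both $\sigma$ and $\phi$ reduce to $\mu_{x}$. Therefore, we aim to solve the following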

$$ \max_{\mu_{x}, \nu_{y}} \,/\, \min_{\mu_{x}, \nu_{y}} \; \mu_{x}, \quad
\text{subject to } \vert \mu_{x} - \phi_{obs} \vert \leq \epsilon.
$$

The solution is $\mu_{x}^{*} = \phi_{obs} + \epsilon$ for the upper bound and $\mu_{x}^{*} = \phi_{obs} - \epsilon$ for the lower bound.
The value of $\nu_{y}$ can be any positive value, since in this
setup neither $\phi$ nor $\sigma$ depends on it.

The corresponding Lagrangians, with the maximisation rewritten as the minimisation of $-\mu_{x}$, are:

$$ \mathcal{L}_{\min}(\mu_{x}, \nu_{y}, \lambda) =  \mu_{x}
+ \lambda(|\mu_{x} - \phi_{obs}| - \epsilon) \qquad \lambda \geq 0 $$

$$ \mathcal{L}_{\max}(\mu_{x}, \nu_{y}, \lambda) = - \mu_{x}
+ \lambda(|\mu_{x} - \phi_{obs}| - \epsilon) \qquad \lambda \geq 0 $$


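Their KKT (primal-dual) solutions are $(\mu_{x}^*, \nu_{y}, \lambda^*) = (\phi_{obs} - \epsilon, \nu_{y}, 1)$ for $\mathcal{L}_{\min}$ and $(\phi_{obs} + \epsilon, \nu_{y}, 1)$ for $\mathcal{L}_{\max}$, for any $\nu_{y} > 0$.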
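For instance, the minimum-bound solution follows directly from the KKT conditions of $\mathcal{L}_{\min}$; the maximum case is symmetric. At the kink $\mu_{x} = \phi_{obs}$ the constraint is slack ($-\epsilon < 0$), so complementary slackness forces $\lambda = 0$, and then the subgradient $1 + \lambda[-1, 1] = \{1\}$ cannot vanish. We can therefore take $\operatorname{sign}(\mu_{x} - \phi_{obs}) = \pm 1$:

$$
\begin{aligned}
\frac{\partial \mathcal{L}_{\min}}{\partial \mu_{x}} = 1 + \lambda \operatorname{sign}(\mu_{x} - \phi_{obs}) = 0,\ \lambda \geq 0
&\implies \operatorname{sign}(\mu_{x} - \phi_{obs}) = -1,\ \lambda^{*} = 1, \\
\lambda^{*} \left( |\mu_{x} - \phi_{obs}| - \epsilon \right) = 0,\ \lambda^{*} > 0
&\implies |\mu_{x} - \phi_{obs}| = \epsilon
\implies \mu_{x}^{*} = \phi_{obs} - \epsilon.
\end{aligned}
$$

Since $\mathcal{L}_{\min}$ does not depend on $\nu_{y}$, every $\nu_{y} > 0$ satisfies stationarity in $\nu_{y}$.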

In this notebook, we will use causalprog to attempt to find these solutions with the naive approach of minimising $\| \nabla \mathcal{L} \|_2^2$, i.e. by searching for a stationary point of the Lagrangian.
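The following is a minimal sketch in plain JAX, independent of causalprog, that shows what this naive objective measures. It writes $\mathcal{L}_{\min}$ and $\mathcal{L}_{\max}$ in closed form for $\phi_{obs} = 0$ and $\epsilon = 1$, drops $\nu_{y}$ (neither Lagrangian depends on it), and evaluates $\| \nabla \mathcal{L} \|_2^2$ at the KKT points. The names `lagrangian`, `naive_objective`, `PHI_OBS` and `EPS` are illustrative only.

```python
import jax
import jax.numpy as jnp

PHI_OBS = 0.0  # observed value of phi = E[X]
EPS = 1.0  # tolerance epsilon


def lagrangian(mu_x: jax.Array, lam: jax.Array, maximum_problem: bool) -> jax.Array:
    """Closed-form L_max (objective -mu_x) or L_min (objective mu_x)."""
    objective = -mu_x if maximum_problem else mu_x
    return objective + lam * (jnp.abs(mu_x - PHI_OBS) - EPS)


def naive_objective(mu_x: jax.Array, lam: jax.Array, maximum_problem: bool) -> jax.Array:
    """||grad L||^2 over (mu_x, lam); maximum_problem is not differentiated."""
    d_mu, d_lam = jax.grad(lagrangian, argnums=(0, 1))(mu_x, lam, maximum_problem)
    return d_mu**2 + d_lam**2


# The naive objective vanishes at both KKT points (lambda* = 1)
lam_star = jnp.array(1.0)
print(naive_objective(jnp.array(PHI_OBS - EPS), lam_star, False))  # min bound
print(naive_objective(jnp.array(PHI_OBS + EPS), lam_star, True))  # max bound
```

Because only `argnums=(0, 1)` are differentiated, the Python flag `maximum_problem` passes through `jax.grad` unchanged.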



# Building the Graph

Before solving, we must define the problem's corresponding DAG.

In [None]:
from causalprog.graph import Graph

# First initialise the graph
graph = Graph(label="two_normal_graph")

## Adding Nodes
For each distribution in the causal graph (`X` and `Y` in our case), add a `DistributionNode`.

Then, we must set the parameters for our distributions.

### Setting Distribution Parameters

1. **Decision Parameters** (e.g. $\mu_{x}$, $\nu_{y}$)
   - Add a `ParameterNode`.
   - Reference it in the `parameters` dictionary of the corresponding `DistributionNode`.

2. **Derived Parameters** (e.g. $\mu_{y}$, the mean of `Y`, which is the sampled value of `X`)
   - If a parameter is the result of a previous distribution,
     reference the corresponding `DistributionNode` in the `parameters` dictionary.

3. **Constant Parameters** (e.g. $\nu_{x}$, the standard deviation of `X`)
   - Set it directly in the `constant_parameters` dictionary of the `DistributionNode`.
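The sketch below makes the three kinds of parameter concrete by sampling the same model by hand in plain JAX, outside causalprog. `sample_graph` and `X_SCALE` are hypothetical names. The decision parameters arrive in a `params` dict keyed by the `ParameterNode` labels. The derived parameter (the mean of `Y`) is the sampled `X`. The constant parameter (the scale of `X`) is fixed in advance. The sketch treats `Y_cov` as a standard deviation, matching the `scale` argument used in the graph.

```python
import jax

# Constant parameter: fixed when the node is defined (the scale of X)
X_SCALE = 1.0


def sample_graph(
    params: dict[str, float], key: jax.Array, n_samples: int
) -> dict[str, jax.Array]:
    """Hypothetical ancestral sampler mirroring the two-normal graph."""
    key_x, key_y = jax.random.split(key)
    # X: loc from the decision parameter X_mean, scale from the constant X_SCALE
    x = params["X_mean"] + X_SCALE * jax.random.normal(key_x, (n_samples,))
    # Y: loc derived from the sampled X, scale from the decision parameter Y_cov
    y = x + params["Y_cov"] * jax.random.normal(key_y, (n_samples,))
    return {"X": x, "Y": y}


samples = sample_graph({"X_mean": 0.3, "Y_cov": 2.0}, jax.random.key(0), 100_000)

# Quantities are computed from samples by node label, like the functions
# passed to Constraint and CausalEstimand later in this notebook
phi_hat = (lambda **pv: pv["X"].mean())(**samples)
sigma_hat = (lambda **pv: pv["Y"].mean())(**samples)
print(float(phi_hat), float(sigma_hat))  # both estimate E[X] = E[Y] = X_mean
```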



In [None]:
from numpyro.distributions import Normal

from causalprog.graph import DistributionNode, ParameterNode

graph.add_node(ParameterNode(label="X_mean"))
graph.add_node(
    DistributionNode(
        distribution=Normal,
        label="X",
        constant_parameters={"scale": 1},  # Set X std to 1
        parameters={"loc": "X_mean"},  # Set X mean as result of ParameterNode X_mean
    )
)

graph.add_node(ParameterNode(label="Y_cov"))
graph.add_node(
    DistributionNode(
        distribution=Normal,
        label="Y",
        parameters={
            "loc": "X",  # Set Y mean as result of DistributionNode X
            "scale": "Y_cov",  # Set Y std as result of ParameterNode Y_cov
        },
    )
)

## Adding Edges

We must add edges to define the dependencies between nodes. Each edge points from a node to a node that uses it as a parameter, e.g. `X_mean` → `X`.
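The edges only record which node depends on which. As an illustration outside causalprog, the same three edges can be written as a dependency map and ordered with Python's standard-library `graphlib.TopologicalSorter`. Any valid order places `X_mean` before `X`, `X` before `Y`, and `Y_cov` before `Y`, which is the order an ancestral sampler needs. `EDGES` is simply a copy of the `add_edge` calls in the next cell.

```python
from graphlib import TopologicalSorter

# (parent, child) pairs, mirroring the graph.add_edge calls
EDGES = [("X", "Y"), ("X_mean", "X"), ("Y_cov", "Y")]

# TopologicalSorter expects a mapping from each node to its predecessors
predecessors: dict[str, set[str]] = {}
for parent, child in EDGES:
    predecessors.setdefault(child, set()).add(parent)
    predecessors.setdefault(parent, set())

order = list(TopologicalSorter(predecessors).static_order())
print(order)
```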

In [None]:
# DistributionNode edges
graph.add_edge("X", "Y")

# ParameterNode edges
graph.add_edge("X_mean", "X")
graph.add_edge("Y_cov", "Y")

# Defining the Problem

We can now use the graph to define the problem we wish to solve.

To do so, we define our constraints with `Constraint` and our causal estimand with `CausalEstimand`.  

Each `Constraint` requires a function that computes the constrained quantity, here $\phi = \mathbb{E}[X]$, from samples of the graph. In the function, we reference nodes with the labels we assigned previously (`"X"` and `"Y"`).

Likewise, each `CausalEstimand` requires a function that computes the causal estimand, here $\sigma = \mathbb{E}[Y]$, from samples of the graph, again referencing nodes by their labels.

We then use our `Constraint` and `CausalEstimand` to define our problem with `CausalProblem`.  
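In the notation of the Lagrangians above, a constraint with observed value $\phi_{obs}$ and tolerance $\epsilon$ corresponds to the inequality $g = |\hat{\phi} - \phi_{obs}| - \epsilon \leq 0$, where $\hat{\phi}$ is the sample estimate. The following is a minimal sketch of that function. It assumes this is how the `data` and `tolerance` arguments combine, and `constraint_value` is a hypothetical name, not causalprog's API.

```python
import jax
import jax.numpy as jnp


def constraint_value(
    samples: dict[str, jax.Array], data: float, tolerance: float
) -> jax.Array:
    """Hypothetical g = |phi_hat - data| - tolerance; feasible when g <= 0."""
    phi_hat = samples["X"].mean()  # same estimator as model_quantity below
    return jnp.abs(phi_hat - data) - tolerance


toy_samples = {"X": jnp.array([0.2, 0.4, 0.6])}  # toy values, sample mean 0.4
print(float(constraint_value(toy_samples, data=0.0, tolerance=1.0)))  # about -0.6
```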



In [None]:
from causalprog.causal_problem.causal_problem import CausalProblem
from causalprog.causal_problem.components import CausalEstimand, Constraint

# Define the constraint using observed data and a tolerance level.
PHI_OBSERVED = 0.0
EPSILON = 1
constraint = Constraint(
    model_quantity=lambda **pv: pv["X"].mean(),
    data=PHI_OBSERVED,
    tolerance=EPSILON,
)

# Define the causal estimand
causal_estimand = CausalEstimand(do_with_samples=lambda **pv: pv["Y"].mean())

# Define the problem using the graph, constraint, and causal estimand.
causal_problem = CausalProblem(
    graph,
    constraint,
    causal_estimand=causal_estimand,
)

# Calculating the Bounds

Now we calculate the bounds of the causal estimand.
To do this, we'll seek the stationary points of the Lagrangian, using the naive approach of minimising $\| \nabla \mathcal{L} \|_2^2$.
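The decision parameters are a dictionary and the multiplier is an array. As a result, `jax.grad(..., argnums=(0, 1))` returns a tuple of pytrees, and its squared norm must sum over every leaf. `pytree_l2_normsq` below is a hypothetical stand-in showing one way to compute this. causalprog's own `l2_normsq` may differ in detail.

```python
from typing import Any

import jax
import jax.numpy as jnp


def pytree_l2_normsq(tree: Any) -> jax.Array:
    """Sum of squares over all leaves of a pytree (dicts, tuples, arrays)."""
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)


# Shaped like the gradient w.r.t. (params dict, Lagrange multiplier array)
grad_like = ({"X_mean": jnp.array(1.0), "Y_cov": jnp.array(2.0)}, jnp.array([3.0]))
print(pytree_l2_normsq(grad_like))  # 1 + 4 + 9 = 14
```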


In [None]:
import jax
import jax.numpy as jnp
import numpy.typing as npt
import optax

from causalprog.solvers.sgd import stochastic_gradient_descent
from causalprog.solvers.solver_callbacks import tqdm_callback
from causalprog.solvers.solver_result import SolverResult
from causalprog.utils.norms import l2_normsq

# Define our initial guess for the decision variables and Lagrange multiplier
LAGRANGE_MULTIPLIER_INIT = jnp.atleast_1d(0.5)
INIT_PARAMS = {
    "X_mean": 0.0,
    "Y_cov": 1.0,
}

# Choose stochastic gradient descent settings
RNG_KEY = jax.random.key(42)
LEARNING_RATE = 1.0e-1
MAX_OPTIMISER_ITER = 200
optimiser = optax.adam(LEARNING_RATE)


# Define a function to find a bound using a naive Lagrangian approach
def naive_lagrangian_approach(maximum_problem: bool) -> SolverResult:
    lagrangian = causal_problem.lagrangian(
        n_samples=300,
        maximum_problem=maximum_problem,
    )

    # Define the objective function for optimization
    # arg x structure should match stochastic_gradient_descent arg initial_guess
    def objective(x: npt.ArrayLike, key: jax.Array) -> npt.ArrayLike:
        v = jax.grad(lagrangian, argnums=(0, 1))(*x, rng_key=key)
        return l2_normsq(v)

    # Run the optimisation
    return stochastic_gradient_descent(
        obj_fn=objective,
        initial_guess=(INIT_PARAMS, LAGRANGE_MULTIPLIER_INIT),
        fn_kwargs={"key": RNG_KEY},
        maxiter=MAX_OPTIMISER_ITER,
        optimiser=optimiser,
        history_logging_interval=1,  # Log optimisation history every iteration
        callbacks=tqdm_callback(MAX_OPTIMISER_ITER),  # Add a progress bar
    )


# Calculate the bounds
max_result = naive_lagrangian_approach(maximum_problem=True)
min_result = naive_lagrangian_approach(maximum_problem=False)

# Observing the Results


Each call to `stochastic_gradient_descent` returns a `SolverResult` object, which stores the results and details of the optimisation.

The attributes of `SolverResult` most relevant to this notebook are:

- `fn_args`: The objective arguments at the final iteration.
- `obj_val`: The objective value at `fn_args`.
- `fn_args_history`: The value of `fn_args` at each logged iteration.
- `obj_val_history`: The value of `obj_val` at each logged iteration.

We will use these to define a function that prints the best iterate found by the solver, i.e. the one with the lowest logged objective value.

In [None]:
import numpy as np

from causalprog.solvers.solver_result import SolverResult


# We define a function to print the results of the SGD runs given the SolverResult
def print_sgd_results(result: SolverResult, maximum_problem: bool = True) -> None:
    """Print the results of the stochastic gradient descent runs."""
    if maximum_problem:
        print("-------------Max Bound Results-------------")
    else:
        print("-------------Min Bound Results-------------")
    mindex = np.argmin(result.obj_val_history)
    print(f"Best Parameters: {result.fn_args_history[mindex][0]}")
    print(f"Best Lagrange Multiplier: {result.fn_args_history[mindex][1][0]}")
    print(f"Best Objective Value: {result.obj_val_history[mindex]}\n")


print_sgd_results(max_result, maximum_problem=True)
print_sgd_results(min_result, maximum_problem=False)

Notice that the result we found for the minimum bound is incorrect. 
This is because
\begin{align*}
\frac{\partial \mathcal{L}_{\min}}{\partial \mu_{x}}(\mu_x, \nu_y, \lambda) 
\Big|_{\phi_{\text{obs}} = 0,\, \epsilon = 1}
&= 1 + \lambda \operatorname{sign}(\mu_x) \\[12pt]
\frac{\partial \mathcal{L}_{\min}}{\partial \lambda}(\mu_x, \nu_y, \lambda)
\Big|_{\phi_{\text{obs}} = 0,\, \epsilon = 1}
&= |\mu_x| - 1
\end{align*}

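Since the derivative with respect to $\nu_{y}$ is identically zero, this means that both
$$
\begin{aligned}
(\mu_{x}=1, \lambda=-1) &\implies \| \nabla \mathcal{L}_{\min} \|_2^2 = 0, \\[12pt]
(\mu_{x}=-1, \lambda=1) &\implies \| \nabla \mathcal{L}_{\min} \|_2^2 = 0.
\end{aligned}
$$

Only the second point is the KKT solution. The first has $\lambda = -1 < 0$, which violates dual feasibility. The naive objective has no term that penalises negative multipliers, so it cannot tell the two points apart.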
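This can be confirmed numerically with a closed-form sketch of $\mathcal{L}_{\min}$ for $\phi_{obs} = 0$ and $\epsilon = 1$. The function names are illustrative, not causalprog functions. Both points give a zero naive objective, but only one has $\lambda \geq 0$.

```python
import jax
import jax.numpy as jnp


def lagrangian_min(mu_x: jax.Array, lam: jax.Array) -> jax.Array:
    """L_min = mu_x + lam * (|mu_x| - 1), i.e. phi_obs = 0 and epsilon = 1."""
    return mu_x + lam * (jnp.abs(mu_x) - 1.0)


def naive_objective_min(mu_x: jax.Array, lam: jax.Array) -> jax.Array:
    """Squared norm of (dL/dmu_x, dL/dlam)."""
    d_mu, d_lam = jax.grad(lagrangian_min, argnums=(0, 1))(mu_x, lam)
    return d_mu**2 + d_lam**2


for mu_x, lam in [(1.0, -1.0), (-1.0, 1.0)]:
    value = naive_objective_min(jnp.array(mu_x), jnp.array(lam))
    print(f"mu_x={mu_x:+.0f}, lam={lam:+.0f}: objective={float(value)}, dual feasible={lam >= 0}")
```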

Clearly this is not ideal. 
To combat this, we could retry the search with multiple starting points, select the best one we find and hope it's optimal.

Or we could try something else...

# Recalculating the Bounds

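The advantage of stochastic gradient descent is that we can customise the objective function freely, as long as JAX can differentiate it.

This problem is convex, has a finite optimal value, and satisfies Slater's condition: any $\mu_{x}$ with $|\mu_{x} - \phi_{obs}| < \epsilon$ is strictly feasible. Under these conditions the KKT conditions are both necessary and sufficient for optimality. We can therefore use an objective function that targets the KKT conditions directly, and any point at which it reaches zero is optimal.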
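The complementary-slackness term in the objective that follows uses the Fischer-Burmeister function $\varphi(a, b) = \sqrt{a^2 + b^2} - a - b$. This function is zero exactly when $a \geq 0$, $b \geq 0$ and $ab = 0$. Substituting $a = \lambda$ and $b = -g$, where $g \leq 0$ is the constraint, a zero of $\varphi$ encodes dual feasibility, primal feasibility and complementary slackness at once. The standalone JAX check below illustrates this; the function name is illustrative.

```python
import jax.numpy as jnp


def fischer_burmeister(lam: jnp.ndarray, g: jnp.ndarray) -> jnp.ndarray:
    """phi(lam, -g): zero iff lam >= 0, g <= 0 and lam * g == 0."""
    return jnp.sqrt(lam**2 + g**2) - lam + g


# (lam, g) pairs: two consistent with the KKT conditions, three violating one
cases = {
    "active constraint, lam > 0": (1.0, 0.0),
    "inactive constraint, lam = 0": (0.0, -1.0),
    "slackness violated (lam > 0, g < 0)": (1.0, -1.0),
    "dual infeasible (lam < 0)": (-1.0, 0.0),
    "primal infeasible (g > 0)": (0.0, 1.0),
}
for name, (lam, g) in cases.items():
    print(f"{name}: {float(fischer_burmeister(jnp.array(lam), jnp.array(g))):.3f}")
```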

In [None]:
from collections.abc import Callable

import optax


def build_kkt_residual_obj(
    cp: CausalProblem,
    *,
    n_samples: int = 1000,
    maximum_problem: bool = False,
    alpha: float = 1.0,  # weight: primal feasibility
    beta: float = 1.0,  # weight: complementary slackness
    gamma: float = 1.0,  # weight: dual feasibility
) -> Callable[..., jax.Array]:
    lagrangian = cp.lagrangian(n_samples=n_samples, maximum_problem=maximum_problem)

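    def obj(x: npt.ArrayLike, *, rng_key: jax.Array) -> jax.Array:
        params, lam = x
        # Gradients w.r.t. the decision parameters and the multiplier. Assuming the
        # Lagrangian has the form f + lam * g (as above), the gradient w.r.t. lam
        # is the constraint value g, which is feasible when g <= 0
        grad_theta, g = jax.grad(lagrangian, argnums=(0, 1))(
            params, lam, rng_key=rng_key
        )

        # 1) Stationarity: grad_theta == 0
        stat = l2_normsq(grad_theta)

        # 2) Primal feasibility: penalise g > 0
        primal = jnp.sum(jnp.maximum(g, 0.0) ** 2)

        # 3) Complementary slackness via the Fischer-Burmeister function
        #    phi(lam, -g), which is zero iff lam >= 0, g <= 0 and lam * g == 0
        fb = jnp.sqrt(lam**2 + g**2) - lam + g
        fb = jnp.sum(fb**2)

        # 4) Dual feasibility: penalise lam < 0
        dual = jnp.sum(jnp.minimum(lam, 0.0) ** 2)

        return stat + alpha * primal + beta * fb + gamma * dual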

    return obj


ALPHA = 1.0  # Primal feasibility
BETA = 1.0  # Complementary slackness
GAMMA = 1.0  # Dual feasibility

# Separate optimiser settings, so the naive approach's `optimiser` is not overwritten
KKT_LEARNING_RATE = 1.0e-2
MAX_KKT_ITER = 1000
kkt_optimiser = optax.adam(KKT_LEARNING_RATE)


# Define a function to find a bound using the KKT residual approach
def kkt_residual_approach(maximum_problem: bool) -> SolverResult:
    objective = build_kkt_residual_obj(
        causal_problem,
        n_samples=200,
        maximum_problem=maximum_problem,
        alpha=ALPHA,
        beta=BETA,
        gamma=GAMMA,
    )

    return stochastic_gradient_descent(
        obj_fn=objective,
        initial_guess=(INIT_PARAMS, LAGRANGE_MULTIPLIER_INIT),
        convergence_criteria=lambda x, _: jnp.abs(x),
        fn_kwargs={"rng_key": RNG_KEY},
        maxiter=MAX_KKT_ITER,
        tolerance=1e-5,  # Tolerance used to decide when convergence is reached
        optimiser=kkt_optimiser,
        history_logging_interval=1,
        callbacks=[tqdm_callback(MAX_KKT_ITER)],
    )


max_result = kkt_residual_approach(maximum_problem=True)
min_result = kkt_residual_approach(maximum_problem=False)

print_sgd_results(max_result, maximum_problem=True)
print_sgd_results(min_result, maximum_problem=False)

We have found the correct bounds! 🎉
