Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
19b7f1e
Reinstate check after renaming caused bugs in sampling
willGraham01 Aug 15, 2025
a5f55bc
Remove old CausalProblem class
willGraham01 Aug 15, 2025
de904bf
Add barebones classes to be populated later
willGraham01 Aug 15, 2025
32ccf0b
Update two normal example test to use new infrastructure
willGraham01 Aug 15, 2025
c8fbaa5
Rework so that the lagrangian can be passed model parameters and the …
willGraham01 Aug 15, 2025
87d128f
ruffing
willGraham01 Aug 15, 2025
7c171c4
Refactor out g.model argument from the Lagrangian call
willGraham01 Aug 15, 2025
3372193
Make TODOs obvious so I don't forget to do them
willGraham01 Aug 15, 2025
87fc5a2
Add docstrings and more TODOs
willGraham01 Aug 18, 2025
d202568
Todo resolution and addition
willGraham01 Aug 18, 2025
d80aec8
Merge branch 'main' into wgraham/causal-problem-rework
willGraham01 Aug 18, 2025
7f32319
Make _CPConstraint callable
willGraham01 Aug 18, 2025
3189f2d
Hide _CPComponent attributes that we don't expect to change
willGraham01 Aug 18, 2025
c2dab0c
Test __call__ for _CPComponents
willGraham01 Aug 18, 2025
648ff9d
Add note about __call__ in docstring
willGraham01 Aug 18, 2025
deea4dc
Merge branch 'main' into wgraham/make-ce-con-callables
willGraham01 Aug 18, 2025
77eb893
Fix bug in how handlers are applied
willGraham01 Aug 18, 2025
cb7d923
Write tests for features
willGraham01 Aug 18, 2025
1302338
Merge branch 'main' into wgraham/make-ce-con-callables
willGraham01 Aug 18, 2025
05a3d5b
Edit Constraint so it is created in pieces
willGraham01 Aug 19, 2025
f87b2fa
Rework Constraint.__init__ and docstring to match new format
willGraham01 Aug 19, 2025
de2f623
Update two_normal_example integration test
willGraham01 Aug 19, 2025
28f0772
Remove todo note
willGraham01 Aug 19, 2025
b7ba1c9
Merge branch 'main' into wgraham/constraint-creation-ease
willGraham01 Aug 19, 2025
c5f8270
Merge branch 'main' into wgraham/constraint-creation-ease
willGraham01 Aug 20, 2025
533d626
Update src/causalprog/causal_problem/causal_estimand.py
willGraham01 Sep 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 72 additions & 10 deletions src/causalprog/causal_problem/causal_estimand.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable
from typing import Any, Concatenate, TypeAlias

import jax.numpy as jnp
import numpy.typing as npt

Model: TypeAlias = Callable[..., Any]
Expand Down Expand Up @@ -94,13 +95,74 @@ class Constraint(_CPComponent):
and $\epsilon$ is some tolerance.
"""

# TODO: (https://github.com/UCL/causalprog/issues/89)
# Should explain that Constraint needs more inputs and slightly different
# interpretation of the `do_with_samples` object.
# Inputs:
# - include epsilon as an input (allows constraints to have different tolerances)
# - `do_with_samples` should just be $g(\theta)$. Then have the instance build the
# full constraint that will need to be called in the Lagrangian.
# - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in
# the event $g$ is vector-valued.
# If we do this, will also need to override __call__...
data: npt.ArrayLike
tolerance: npt.ArrayLike
_outer_norm: Callable[[npt.ArrayLike], float]

def __init__(
self,
*effect_handlers: ModelMask,
model_quantity: Callable[..., npt.ArrayLike],
outer_norm: Callable[[npt.ArrayLike], float] | None = None,
data: npt.ArrayLike = 0.0,
tolerance: float = 1.0e-6,
) -> None:
r"""
Create a new constraint.

Constraints have the form

$$ c(\theta) :=
\mathrm{norm}\left( g(\theta)
- g_{\mathrm{data}} \right)
- \epsilon $$

where;
- $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
- $g(\theta)$ is the model quantity involved in the constraint
(`model_quantity`),
- $g_{\mathrm{data}}$ is the observed data (`data`),
- $\epsilon$ is the tolerance in the data (`tolerance`).

In a causal problem, each constraint appears as the condition $c(\theta)\leq 0$
in the minimisation / maximisation (hence the inclusion of the $-\epsilon$
term within $c(\theta)$ itself).

$g$ should be a (possibly vector-valued) function that acts on (a subset of)
samples from the random variables of the causal problem. It must accept
variable keyword-arguments only, and should access the samples for each random
variable by indexing via the RV names (node labels). It should return the
model quantity as computed from the samples, that $g_{\mathrm{data}}$ observed.

$g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able with
the return shape of $g$. It defaults to $0$ if not explicitly set.

$\mathrm{norm}$ should be a suitable norm to take on the difference between the
model quantity as predicted by the samples ($g$) and the observed data
($g_{\mathrm{data}}$). It must return a scalar value. The default is the 2-norm.
"""
super().__init__(*effect_handlers, do_with_samples=model_quantity)

if outer_norm is None:
self._outer_norm = jnp.linalg.vector_norm
else:
self._outer_norm = outer_norm

self.data = data
self.tolerance = tolerance

def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
"""
Evaluate the constraint, given RV samples.

Args:
samples: Mapping of RV (node) labels to drawn samples.

Returns:
Value of the constraint.

"""
return (
self._outer_norm(self._do_with_samples(**samples) - self.data)
- self.tolerance
)
4 changes: 3 additions & 1 deletion tests/test_integration/test_two_normal_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def test_two_normal_example(
# Setup the optimisation problem from the graph
ce = CausalEstimand(do_with_samples=lambda **pv: pv["X"].mean())
con = Constraint(
do_with_samples=lambda **pv: jnp.abs(pv["UX"].mean() - phi_observed) - epsilon
model_quantity=lambda **pv: pv["UX"].mean(),
data=phi_observed,
tolerance=epsilon,
)
cp = CausalProblem(
two_normal_graph(cov=1.0),
Expand Down
Loading