Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
19b7f1e
Reinstate check after renaming caused bugs in sampling
willGraham01 Aug 15, 2025
a5f55bc
Remove old CausalProblem class
willGraham01 Aug 15, 2025
de904bf
Add barebones classes to be populated later
willGraham01 Aug 15, 2025
32ccf0b
Update two normal example test to use new infrastructure
willGraham01 Aug 15, 2025
c8fbaa5
Rework so that the lagrangian can be passed model parameters and the …
willGraham01 Aug 15, 2025
87d128f
ruffing
willGraham01 Aug 15, 2025
7c171c4
Refactor out g.model argument from the Lagrangian call
willGraham01 Aug 15, 2025
3372193
Make TODOs obvious so I don't forget to do them
willGraham01 Aug 15, 2025
87fc5a2
Add docstrings and more TODOs
willGraham01 Aug 18, 2025
d202568
Todo resolution and addition
willGraham01 Aug 18, 2025
d80aec8
Merge branch 'main' into wgraham/causal-problem-rework
willGraham01 Aug 18, 2025
7f32319
Make _CPConstraint callable
willGraham01 Aug 18, 2025
3189f2d
Hide _CPComponent attributes that we don't expect to change
willGraham01 Aug 18, 2025
c2dab0c
Test __call__ for _CPComponents
willGraham01 Aug 18, 2025
648ff9d
Add note about __call__ in docstring
willGraham01 Aug 18, 2025
deea4dc
Merge branch 'main' into wgraham/make-ce-con-callables
willGraham01 Aug 18, 2025
77eb893
Fix bug in how handlers are applied
willGraham01 Aug 18, 2025
cb7d923
Write tests for features
willGraham01 Aug 18, 2025
1302338
Merge branch 'main' into wgraham/make-ce-con-callables
willGraham01 Aug 18, 2025
05a3d5b
Edit Constraint so it is created in pieces
willGraham01 Aug 19, 2025
f87b2fa
Rework Constraint.__init__ and docstring to match new format
willGraham01 Aug 19, 2025
de2f623
Update two_normal_example integration test
willGraham01 Aug 19, 2025
28f0772
Remove todo note
willGraham01 Aug 19, 2025
b7ba1c9
Merge branch 'main' into wgraham/constraint-creation-ease
willGraham01 Aug 19, 2025
3366e94
Create wrapper class to make handlers easier. __eq__ placeholder for now
willGraham01 Aug 19, 2025
fe73614
Tidy eq docstring
willGraham01 Aug 19, 2025
1e396e2
Write necessary comparison methods
willGraham01 Aug 19, 2025
440873b
Write model association method and implement in Lagrangian
willGraham01 Aug 20, 2025
12d2b71
Docstirngs and breakout HandlerToApply class to submodule
willGraham01 Aug 20, 2025
cc3d2a0
Docstirngs and breakout HandlerToApply class to submodule
willGraham01 Aug 20, 2025
65ab781
Tests for HandlerToApply class
willGraham01 Aug 20, 2025
9335477
Tests for can_use_same_model
willGraham01 Aug 20, 2025
8a4483e
Tests for associating models to components of the CP
willGraham01 Aug 20, 2025
c5f8270
Merge branch 'main' into wgraham/constraint-creation-ease
willGraham01 Aug 20, 2025
821072e
Reorganise classes now that structure is somewhat rigid
willGraham01 Aug 20, 2025
2c359ed
Add module-level import for user-facing functions
willGraham01 Aug 20, 2025
fb9025b
Merge branch 'wgraham/constraint-creation-ease' into wgraham/handle-h…
willGraham01 Aug 27, 2025
e7c7b26
Update src/causalprog/causal_problem/causal_problem.py
willGraham01 Sep 3, 2025
783163c
Merge branch 'main' into wgraham/handle-handlers
willGraham01 Sep 3, 2025
5036d81
Reinstate if not else fix
willGraham01 Sep 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/causalprog/causal_problem/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
"""Classes for defining causal problems."""

from .causal_problem import CausalProblem
from .components import CausalEstimand, Constraint
from .handlers import HandlerToApply

__all__ = ("CausalEstimand", "CausalProblem", "Constraint", "HandlerToApply")
86 changes: 86 additions & 0 deletions src/causalprog/causal_problem/_base_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Base class for components of causal problems."""

from collections.abc import Callable
from typing import Any

import numpy.typing as npt

from causalprog.causal_problem.handlers import EffectHandler, HandlerToApply, Model


class _CPComponent:
"""
Base class for components of a Causal Problem.

A _CPComponent has an attached method that it can apply to samples
(`do_with_samples`), which will be passed sample values of the RVs
during solution of a Causal Problem and used to evaluate the causal
estimand or constraint the instance represents.

It also has a sequence of effect handlers that need to be applied
to the sampling model before samples can be drawn to evaluate this
component. For example, if a component requires conditioning on the
value of a RV, the `condition` handler needs to be applied to the
underlying model, before generating samples to pass to the
`do_with_sample` method. `effect_handlers` will be applied to the model
in the order they are given.
"""

do_with_samples: Callable[..., npt.ArrayLike]
effect_handlers: tuple[HandlerToApply, ...]

@property
def requires_model_adaption(self) -> bool:
"""Return True if effect handlers need to be applied to model."""
return len(self.effect_handlers) > 0

def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
"""
Evaluate the estimand or constraint, given sample values.

Args:
samples: Mapping of RV (node) labels to samples of that RV.

Returns:
Value of the estimand or constraint, given the samples.

"""
return self._do_with_samples(**samples)

def __init__(
self,
*effect_handlers: HandlerToApply | tuple[EffectHandler, dict[str, Any]],
do_with_samples: Callable[..., npt.ArrayLike],
) -> None:
self.effect_handlers = tuple(
h if isinstance(h, HandlerToApply) else HandlerToApply.from_pair(h)
for h in effect_handlers
)
self._do_with_samples = do_with_samples

def apply_effects(self, model: Model) -> Model:
"""Apply any necessary effect handlers prior to evaluating."""
adapted_model = model
for handler in self.effect_handlers:
adapted_model = handler.handler(adapted_model, handler.options)
return adapted_model

def can_use_same_model_as(self, other: "_CPComponent") -> bool:
"""
Determine if two components use the same (predictive) model.

Two components rely on the same model if they apply the same handlers
to the model, which occurs if and only if `self.effect_handlers` and
`other.effect_handlers` contain identical entries, in the same order.
"""
if (not isinstance(other, _CPComponent)) or (
len(self.effect_handlers) != len(other.effect_handlers)
):
return False

return all(
my_handler == their_handler
for my_handler, their_handler in zip(
self.effect_handlers, other.effect_handlers, strict=True
)
)
112 changes: 100 additions & 12 deletions src/causalprog/causal_problem/causal_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import numpy.typing as npt
from numpyro.infer import Predictive

from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint
from causalprog.causal_problem._base_component import _CPComponent
from causalprog.causal_problem.components import (
CausalEstimand,
Constraint,
)
from causalprog.graph import Graph


Expand Down Expand Up @@ -40,6 +44,11 @@ class CausalProblem:
causal_estimand: CausalEstimand
constraints: list[Constraint]

@property
def _ordered_components(self) -> list[_CPComponent]:
"""Internal ordering for components of the causal problem."""
return [*self.constraints, self.causal_estimand]

def __init__(
self,
graph: Graph,
Expand All @@ -51,6 +60,83 @@ def __init__(
self.causal_estimand = causal_estimand
self.constraints = list(constraints)

def _associate_models_to_components(
self, n_samples: int
) -> tuple[list[Predictive], list[int]]:
"""
Create models to be used by components of the problem.

Depending on how many constraints (and the causal estimand) require effect
handlers to wrap `self._underlying_graph.model`, we will need to create several
predictive models to sample from. However, we also want to minimise the number
of such models we have to make, in order to minimise the time we spend
actually computing samples.

As such, in this method we determine:
- How many models we will need to build, by grouping the constraints and the
causal estimand by the handlers they use.
- Build these models, returning them in a list called `models`.
- Build another list that maps the index of components in
`self._ordered_components` to the index of the model in `models` that they
use. The causal estimand is by convention the component at index -1 of this
returned list.

Args:
n_samples: Value to be passed to `numpyro.Predictive`'s `num_samples`
argument for each of the models that are constructed from the underlying
graph.

Returns:
list[Predictive]: List of Predictive models, whose elements contain all the
models needed by the components.
list[int]: Mapping of component indexes (as per `self_ordered_components`)
to the index of the model in the first return argument that the
component uses.

"""
models: list[Predictive] = []
grouped_component_indexes: list[list[int]] = []
for index, component in enumerate(self._ordered_components):
# Determine if this constraint uses the same handlers as those of any of
# the other sets.
belongs_to_existing_group = False
for group in grouped_component_indexes:
# Pull any element from the group to compare models to.
# Items in a group are known to have the same model, so we can just
# pull out the first one.
group_element = self._ordered_components[group[0]]
# Check if the current constraint can also use this model.
if component.can_use_same_model_as(group_element):
group.append(index)
belongs_to_existing_group = True
break

# If the component does not fit into any existing group, create a new
# group for it. And add the model corresponding to the group to the
# list of models.
if not belongs_to_existing_group:
grouped_component_indexes.append([index])

models.append(
Predictive(
component.apply_effects(self._underlying_graph.model),
num_samples=n_samples,
)
)

# Now "invert" the grouping, creating a mapping that maps the index of a
# component to the (index of the) model it uses.
component_index_to_model_index = []
for index in range(len(self._ordered_components)):
for group_index, group in enumerate(grouped_component_indexes):
if index in group:
component_index_to_model_index.append(group_index)
break
# All indexes should belong to at least one group (worst case scenario,
# their own individual group). Thus, it is safe to do the above to create
# the mapping from component index -> model (group) index.
return models, component_index_to_model_index

def lagrangian(
self, n_samples: int = 1000, *, maximum_problem: bool = False
) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]:
Expand Down Expand Up @@ -89,26 +175,28 @@ def lagrangian(
"""
maximisation_prefactor = -1.0 if maximum_problem else 1.0

# Build association between self.constraints and the model-samples that each
# one needs to use. We do this here, since once it is constructed, it is
# fixed, and doesn't need to be done each time we call the Lagrangian.
models, component_to_index_mapping = self._associate_models_to_components(
n_samples
)

def _inner(
parameter_values: dict[str, npt.ArrayLike],
l_mult: jax.Array,
rng_key: jax.Array,
) -> npt.ArrayLike:
# In general, we will need to check which of our CE/CONs require masking,
# and do multiple predictive models to account for this...
# We can always pre-build the predictive models too, so we should replace
# the "model" input with something that can map the right predictive models
# to the CE/CONS that need them.
# TODO: https://github.com/UCL/causalprog/issues/90
predictive_model = Predictive(
model=self._underlying_graph.model, num_samples=n_samples
# Draw samples from all models
all_samples = tuple(
sample_model(model, rng_key, parameter_values) for model in models
)
all_samples = sample_model(predictive_model, rng_key, parameter_values)

value = maximisation_prefactor * self.causal_estimand(all_samples)
value = maximisation_prefactor * self.causal_estimand(all_samples[-1])
# TODO: https://github.com/UCL/causalprog/issues/87
value += sum(
l_mult[i] * c(all_samples) for i, c in enumerate(self.constraints)
l_mult[i] * c(all_samples[component_to_index_mapping[i]])
for i, c in enumerate(self.constraints)
)
return value

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,66 +6,13 @@
import jax.numpy as jnp
import numpy.typing as npt

from causalprog.causal_problem._base_component import _CPComponent

Model: TypeAlias = Callable[..., Any]
EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model]
ModelMask: TypeAlias = tuple[EffectHandler, dict]


class _CPComponent:
"""
Base class for components of a Causal Problem.

A _CPComponent has an attached method that it can apply to samples
(`do_with_samples`), which will be passed sample values of the RVs
during solution of a Causal Problem and used to evaluate the causal
estimand or constraint the instance represents.

It also has a sequence of effect handlers that need to be applied
to the sampling model before samples can be drawn to evaluate this
component. For example, if a component requires conditioning on the
value of a RV, the `condition` handler needs to be applied to the
underlying model, before generating samples to pass to the
`do_with_sample` method. `effect_handlers` will be applied to the model
in the order they are given.
"""

do_with_samples: Callable[..., npt.ArrayLike]
effect_handlers: tuple[ModelMask, ...]

@property
def requires_model_adaption(self) -> bool:
"""Return True if effect handlers need to be applied to model."""
return len(self._effect_handlers) > 0

def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
"""
Evaluate the estimand or constraint, given sample values.

Args:
samples: Mapping of RV (node) labels to samples of that RV.

Returns:
Value of the estimand or constraint, given the samples.

"""
return self._do_with_samples(**samples)

def __init__(
self,
*effect_handlers: ModelMask,
do_with_samples: Callable[..., npt.ArrayLike],
) -> None:
self._effect_handlers = tuple(effect_handlers)
self._do_with_samples = do_with_samples

def apply_effects(self, model: Model) -> Model:
"""Apply any necessary effect handlers prior to evaluating."""
adapted_model = model
for handler, handler_options in self._effect_handlers:
adapted_model = handler(adapted_model, handler_options)
return adapted_model


class CausalEstimand(_CPComponent):
"""
A Causal Estimand.
Expand Down
Loading