diff --git a/src/causalprog/causal_problem/__init__.py b/src/causalprog/causal_problem/__init__.py
index 3d45fe4..62cdc3d 100644
--- a/src/causalprog/causal_problem/__init__.py
+++ b/src/causalprog/causal_problem/__init__.py
@@ -1 +1,7 @@
 """Classes for defining causal problems."""
+
+from .causal_problem import CausalProblem
+from .components import CausalEstimand, Constraint
+from .handlers import HandlerToApply
+
+__all__ = ("CausalEstimand", "CausalProblem", "Constraint", "HandlerToApply")
diff --git a/src/causalprog/causal_problem/_base_component.py b/src/causalprog/causal_problem/_base_component.py
new file mode 100644
index 0000000..4be1812
--- /dev/null
+++ b/src/causalprog/causal_problem/_base_component.py
@@ -0,0 +1,86 @@
+"""Base class for components of causal problems."""
+
+from collections.abc import Callable
+from typing import Any
+
+import numpy.typing as npt
+
+from causalprog.causal_problem.handlers import EffectHandler, HandlerToApply, Model
+
+
+class _CPComponent:
+    """
+    Base class for components of a Causal Problem.
+
+    A _CPComponent has an attached method that it can apply to samples
+    (`do_with_samples`), which will be passed sample values of the RVs
+    during solution of a Causal Problem and used to evaluate the causal
+    estimand or constraint the instance represents.
+
+    It also has a sequence of effect handlers that need to be applied
+    to the sampling model before samples can be drawn to evaluate this
+    component. For example, if a component requires conditioning on the
+    value of a RV, the `condition` handler needs to be applied to the
+    underlying model, before generating samples to pass to the
+    `do_with_samples` method. `effect_handlers` will be applied to the model
+    in the order they are given.
+    """
+
+    do_with_samples: Callable[..., npt.ArrayLike]
+    effect_handlers: tuple[HandlerToApply, ...]
+
+    @property
+    def requires_model_adaption(self) -> bool:
+        """Return True if effect handlers need to be applied to the model."""
+        return len(self.effect_handlers) > 0
+
+    def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
+        """
+        Evaluate the estimand or constraint, given sample values.
+
+        Args:
+            samples: Mapping of RV (node) labels to samples of that RV.
+
+        Returns:
+            Value of the estimand or constraint, given the samples.
+
+        """
+        return self._do_with_samples(**samples)
+
+    def __init__(
+        self,
+        *effect_handlers: HandlerToApply | tuple[EffectHandler, dict[str, Any]],
+        do_with_samples: Callable[..., npt.ArrayLike],
+    ) -> None:
+        self.effect_handlers = tuple(
+            h if isinstance(h, HandlerToApply) else HandlerToApply.from_pair(h)
+            for h in effect_handlers
+        )
+        self._do_with_samples = do_with_samples
+
+    def apply_effects(self, model: Model) -> Model:
+        """Apply any necessary effect handlers prior to evaluating."""
+        adapted_model = model
+        for handler in self.effect_handlers:
+            adapted_model = handler.handler(adapted_model, handler.options)
+        return adapted_model
+
+    def can_use_same_model_as(self, other: "_CPComponent") -> bool:
+        """
+        Determine if two components use the same (predictive) model.
+
+        Two components rely on the same model if they apply the same handlers
+        to the model, which occurs if and only if `self.effect_handlers` and
+        `other.effect_handlers` contain identical entries, in the same order.
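+
+        For example, two components that both apply `condition` with the same
+        data (and no other handlers) can share a model, whereas a component
+        with no effect handlers requires the unmodified model.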
+        """
+        if (not isinstance(other, _CPComponent)) or (
+            len(self.effect_handlers) != len(other.effect_handlers)
+        ):
+            return False
+
+        return all(
+            my_handler == their_handler
+            for my_handler, their_handler in zip(
+                self.effect_handlers, other.effect_handlers, strict=True
+            )
+        )
diff --git a/src/causalprog/causal_problem/causal_problem.py b/src/causalprog/causal_problem/causal_problem.py
index b0fd20e..2b5785a 100644
--- a/src/causalprog/causal_problem/causal_problem.py
+++ b/src/causalprog/causal_problem/causal_problem.py
@@ -6,7 +6,11 @@
 import numpy.typing as npt
 from numpyro.infer import Predictive
 
-from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint
+from causalprog.causal_problem._base_component import _CPComponent
+from causalprog.causal_problem.components import (
+    CausalEstimand,
+    Constraint,
+)
 from causalprog.graph import Graph
 
 
@@ -40,6 +44,11 @@ class CausalProblem:
     causal_estimand: CausalEstimand
     constraints: list[Constraint]
 
+    @property
+    def _ordered_components(self) -> list[_CPComponent]:
+        """Internal ordering for components of the causal problem."""
+        return [*self.constraints, self.causal_estimand]
+
     def __init__(
         self,
         graph: Graph,
@@ -51,6 +60,83 @@ def __init__(
         self.causal_estimand = causal_estimand
         self.constraints = list(constraints)
 
+    def _associate_models_to_components(
+        self, n_samples: int
+    ) -> tuple[list[Predictive], list[int]]:
+        """
+        Create models to be used by components of the problem.
+
+        Depending on how many constraints (and the causal estimand) require effect
+        handlers to wrap `self._underlying_graph.model`, we will need to create several
+        predictive models to sample from. However, we also want to minimise the number
+        of such models we have to make, in order to minimise the time we spend
+        actually computing samples.
+
+        As such, in this method we:
+        - Determine how many models we will need to build, by grouping the constraints
+          and the causal estimand by the handlers they use.
+        - Build these models, returning them in a list called `models`.
+        - Build another list that maps the index of components in
+          `self._ordered_components` to the index of the model in `models` that they
+          use. The causal estimand is by convention the component at index -1 of this
+          returned list.
+
+        Args:
+            n_samples: Value to be passed to `numpyro.Predictive`'s `num_samples`
+                argument for each of the models that are constructed from the
+                underlying graph.
+
+        Returns:
+            list[Predictive]: List of Predictive models, whose elements contain all the
+                models needed by the components.
+            list[int]: Mapping of component indexes (as per `self._ordered_components`)
+                to the index of the model in the first return argument that the
+                component uses.
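+
+        For example, if the first and third constraints apply `condition` with the
+        same data while the second constraint and the causal estimand apply no
+        handlers, two models are built and the returned mapping is `[0, 1, 0, 1]`.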
+
+        """
+        models: list[Predictive] = []
+        grouped_component_indexes: list[list[int]] = []
+        for index, component in enumerate(self._ordered_components):
+            # Determine if this component uses the same handlers as those of any
+            # existing group.
+            belongs_to_existing_group = False
+            for group in grouped_component_indexes:
+                # Pull any element from the group to compare models to.
+                # Items in a group are known to have the same model, so we can just
+                # pull out the first one.
+                group_element = self._ordered_components[group[0]]
+                # Check if the current component can also use this model.
+                if component.can_use_same_model_as(group_element):
+                    group.append(index)
+                    belongs_to_existing_group = True
+                    break
+
+            # If the component does not fit into any existing group, create a new
+            # group for it, and add the model corresponding to the group to the
+            # list of models.
+            if not belongs_to_existing_group:
+                grouped_component_indexes.append([index])
+
+                models.append(
+                    Predictive(
+                        component.apply_effects(self._underlying_graph.model),
+                        num_samples=n_samples,
+                    )
+                )
+
+        # Now "invert" the grouping, creating a mapping that maps the index of a
+        # component to the (index of the) model it uses.
+        component_index_to_model_index = []
+        for index in range(len(self._ordered_components)):
+            for group_index, group in enumerate(grouped_component_indexes):
+                if index in group:
+                    component_index_to_model_index.append(group_index)
+                    break
+        # All indexes should belong to at least one group (worst case scenario,
+        # their own individual group). Thus, it is safe to do the above to create
+        # the mapping from component index -> model (group) index.
+        return models, component_index_to_model_index
+
     def lagrangian(
         self, n_samples: int = 1000, *, maximum_problem: bool = False
     ) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]:
@@ -89,26 +175,28 @@ def lagrangian(
         """
         maximisation_prefactor = -1.0 if maximum_problem else 1.0
 
+        # Build association between self.constraints and the model-samples that each
+        # one needs to use. We do this here, since once it is constructed, it is
+        # fixed, and doesn't need to be done each time we call the Lagrangian.
+        models, component_to_index_mapping = self._associate_models_to_components(
+            n_samples
+        )
+
         def _inner(
             parameter_values: dict[str, npt.ArrayLike],
             l_mult: jax.Array,
            rng_key: jax.Array,
        ) -> npt.ArrayLike:
-            # In general, we will need to check which of our CE/CONs require masking,
-            # and do multiple predictive models to account for this...
-            # We can always pre-build the predictive models too, so we should replace
-            # the "model" input with something that can map the right predictive models
-            # to the CE/CONS that need them.
-            # TODO: https://github.com/UCL/causalprog/issues/90
-            predictive_model = Predictive(
-                model=self._underlying_graph.model, num_samples=n_samples
+            # Draw samples from all models.
+            all_samples = tuple(
+                sample_model(model, rng_key, parameter_values) for model in models
             )
-            all_samples = sample_model(predictive_model, rng_key, parameter_values)
 
-            value = maximisation_prefactor * self.causal_estimand(all_samples)
+            # The causal estimand is the last component in the ordering, so use the
+            # model (samples) indicated by the final entry of the mapping.
+            value = maximisation_prefactor * self.causal_estimand(
+                all_samples[component_to_index_mapping[-1]]
+            )
             # TODO: https://github.com/UCL/causalprog/issues/87
             value += sum(
-                l_mult[i] * c(all_samples) for i, c in enumerate(self.constraints)
+                l_mult[i] * c(all_samples[component_to_index_mapping[i]])
+                for i, c in enumerate(self.constraints)
             )
 
             return value
diff --git a/src/causalprog/causal_problem/causal_estimand.py b/src/causalprog/causal_problem/components.py
similarity index 66%
rename from src/causalprog/causal_problem/causal_estimand.py
rename to src/causalprog/causal_problem/components.py
index c9d7a30..2b9900c 100644
--- a/src/causalprog/causal_problem/causal_estimand.py
+++ b/src/causalprog/causal_problem/components.py
@@ -6,66 +6,13 @@
 import jax.numpy as jnp
 import numpy.typing as npt
 
+from causalprog.causal_problem._base_component import _CPComponent
+
 Model: TypeAlias = Callable[..., Any]
 EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model]
 ModelMask: TypeAlias = tuple[EffectHandler, dict]
 
 
-class _CPComponent:
-    """
-    Base class for components of a Causal Problem.
-
-    A _CPComponent has an attached method that it can apply to samples
-    (`do_with_samples`), which will be passed sample values of the RVs
-    during solution of a Causal Problem and used to evaluate the causal
-    estimand or constraint the instance represents.
-
-    It also has a sequence of effect handlers that need to be applied
-    to the sampling model before samples can be drawn to evaluate this
-    component. For example, if a component requires conditioning on the
-    value of a RV, the `condition` handler needs to be applied to the
-    underlying model, before generating samples to pass to the
-    `do_with_sample` method. `effect_handlers` will be applied to the model
-    in the order they are given.
-    """
-
-    do_with_samples: Callable[..., npt.ArrayLike]
-    effect_handlers: tuple[ModelMask, ...]
-
-    @property
-    def requires_model_adaption(self) -> bool:
-        """Return True if effect handlers need to be applied to model."""
-        return len(self._effect_handlers) > 0
-
-    def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
-        """
-        Evaluate the estimand or constraint, given sample values.
-
-        Args:
-            samples: Mapping of RV (node) labels to samples of that RV.
-
-        Returns:
-            Value of the estimand or constraint, given the samples.
-
-        """
-        return self._do_with_samples(**samples)
-
-    def __init__(
-        self,
-        *effect_handlers: ModelMask,
-        do_with_samples: Callable[..., npt.ArrayLike],
-    ) -> None:
-        self._effect_handlers = tuple(effect_handlers)
-        self._do_with_samples = do_with_samples
-
-    def apply_effects(self, model: Model) -> Model:
-        """Apply any necessary effect handlers prior to evaluating."""
-        adapted_model = model
-        for handler, handler_options in self._effect_handlers:
-            adapted_model = handler(adapted_model, handler_options)
-        return adapted_model
-
-
 class CausalEstimand(_CPComponent):
     """
     A Causal Estimand.
diff --git a/src/causalprog/causal_problem/handlers.py b/src/causalprog/causal_problem/handlers.py
new file mode 100644
index 0000000..f9958ff
--- /dev/null
+++ b/src/causalprog/causal_problem/handlers.py
@@ -0,0 +1,87 @@
+"""Container class for specifying effect handlers that need to be applied at runtime."""
+
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass, field
+from typing import Any, Concatenate, TypeAlias
+
+Model: TypeAlias = Callable[..., None]
+EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model]
+
+
+@dataclass
+class HandlerToApply:
+    """Specifies a handler that needs to be applied to a model at runtime."""
+
+    handler: EffectHandler
+    options: dict[str, Any] = field(default_factory=dict)
+
+    @classmethod
+    def from_pair(cls, pair: Sequence) -> "HandlerToApply":
+        """
+        Create an instance from an effect handler and its options.
+
+        The two objects should be passed in as the elements of a container of length
+        2. They can be passed in any order:
+        - One element must be a dictionary, which will be interpreted as the `options`
+          for the effect handler.
+        - The other element must be callable, and will be interpreted as the `handler`
+          itself.
+
+        Args:
+            pair: Container of two elements, one being the effect handler callable and
+                the other being the options to pass to it (as a dictionary).
+
+        Returns:
+            Class instance corresponding to the effect handler and options passed.
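+
+        For example, `HandlerToApply.from_pair((my_handler, {"X": 1.0}))` and
+        `HandlerToApply.from_pair(({"X": 1.0}, my_handler))` construct equal
+        instances, where `my_handler` is any effect handler callable.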
+
+        """
+        if len(pair) != 2:  # noqa: PLR2004
+            msg = (
+                f"{cls.__name__} can only be constructed from a container of 2 elements"
+            )
+            raise ValueError(msg)
+
+        # __post_init__ will catch cases when the incorrect type for one or both items
+        # is passed, so we can just naively if-else here.
+        handler: EffectHandler
+        options: dict
+        if callable(pair[0]):
+            handler = pair[0]
+            options = pair[1]
+        else:
+            handler = pair[1]
+            options = pair[0]
+
+        return cls(handler=handler, options=options)
+
+    def __post_init__(self) -> None:
+        """
+        Validate set attributes.
+
+        - The handler is a callable object.
+        - The options have been passed as a dictionary of keyword-value pairs.
+        """
+        if not callable(self.handler):
+            msg = f"{type(self.handler).__name__} is not callable."
+            raise TypeError(msg)
+        if not isinstance(self.options, dict):
+            msg = (
+                "Options should be dictionary mapping option arguments to values "
+                f"(got {type(self.options).__name__})."
+            )
+            raise TypeError(msg)
+
+    def __eq__(self, other: object) -> bool:
+        """
+        Equality operation.
+
+        `HandlerToApply`s are considered equal if they use the same handler function
+        and provide the same options to this function.
+
+        Comparison to other types returns `False`.
+        """
+        return (
+            isinstance(other, HandlerToApply)
+            and self.handler is other.handler
+            and self.options == other.options
+        )
diff --git a/tests/test_causal_problem/test_associate_models_to_components.py b/tests/test_causal_problem/test_associate_models_to_components.py
new file mode 100644
index 0000000..631f44b
--- /dev/null
+++ b/tests/test_causal_problem/test_associate_models_to_components.py
@@ -0,0 +1,115 @@
+"""Test the association of models to the components of the CausalProblem.
+
+One model is needed for each unique combination of handlers that the Constraints and
+CausalEstimand possess. We can mimic this behaviour by defining a single handler, and
+then passing different numbers of copies of this handler into our Constraints and
+CausalEstimand. Different numbers of handlers force different models, and thus we should
+end up with one model for each unique number of copies that we use.
+
+Components are examined in the `_ordered_components` order, which goes through the
+`constraints` list first, and then the CausalEstimand at the end. As such, models are
+also "created" in this order.
+"""
+
+from collections.abc import Callable
+
+import pytest
+
+from causalprog.causal_problem import (
+    CausalEstimand,
+    CausalProblem,
+    Constraint,
+    HandlerToApply,
+)
+from causalprog.graph import Graph, ParameterNode
+
+
+@pytest.fixture
+def underlying_graph() -> Graph:
+    """The underlying graph is not actually important for checking model association,
+    so we just return a single node graph.
+    """
+    g = Graph(label="Placeholder")
+    g.add_node(ParameterNode(label="p"))
+    return g
+
+
+def placeholder_handler_fn(*args, **kwargs) -> None:
+    """Identity (`is`) comparison means this function needs to be defined statically.
+
+    args[0] is the model input, so we just effectively return the same model.
+    """
+    return args[0]
+
+
+@pytest.fixture
+def placeholder_handler() -> Callable[[], HandlerToApply]:
+    """Creates a HandlerToApply instance.
+
+    We will use copies of the returned instance to "trick" the CausalProblem
+    class into creating additional models due to different numbers of handlers
+    being applied to its components.
+ """ + + def _inner() -> HandlerToApply: + return HandlerToApply(placeholder_handler_fn, {}) + + return _inner + + +@pytest.mark.parametrize( + ( + "handlers_to_give_to_constraints", + "handlers_to_give_estimand", + "expected_components_to_models_mapping", + ), + [ + # There are no constraints, so the CausalEstimand uses the only model created. + pytest.param([], 0, [0], id="No constraints"), + # All constraints and the causal estimand use the same model. + pytest.param([0] * 3, 0, [0] * 4, id="Same model used by all"), + # 1st constraint: Model w/ 1 handler created, taking model index 0. + # 2nd constraint: re-uses model index 0 (with 1 handler). + # CE: Model w/ 0 handlers created, taking model index 1. + pytest.param([1] * 2, 0, [0, 0, 1], id="CausalEstimand is always last model"), + # 1st constraint: Model w/ 1 handler created, taking model index 0. + # 2nd constraint: Model w/ 0 handlers created, taking model index 1. + # 3rd constraint: re-uses model index 0 (with 1 handler). + # 4th constraint: Model w/ 2 handlers created, taking model index 2. + # CE: re-uses model index 0 (with 1 handler). + pytest.param( + [1, 0, 1, 2], 1, [0, 1, 0, 2, 0], id="Models created in particular order" + ), + ], +) +def test_associate_models_to_components( + handlers_to_give_to_constraints: list[int], + handlers_to_give_estimand: int, + expected_components_to_models_mapping: list[int], + placeholder_handler: Callable[[], HandlerToApply], + underlying_graph: Graph, + n_samples: int = 1, +) -> None: + # The number of models is the number of 'unique numbers of handlers' given to the + # constraints and causal estimand. + expected_number_of_models = len( + set(handlers_to_give_to_constraints).union({handlers_to_give_estimand}) + ) + + constraints = [ + Constraint( + *(placeholder_handler() for _ in range(copies)), + model_quantity=lambda **pv: 0.0, + ) + for copies in handlers_to_give_to_constraints + ] + ce = CausalEstimand( + *(placeholder_handler() for _ in range(handlers_to_give_estimand)), + do_with_samples=lambda **pv: None, + ) + + cp = CausalProblem(underlying_graph, *constraints, causal_estimand=ce) + models, association = cp._associate_models_to_components(n_samples) # noqa: SLF001 + + assert len(models) == expected_number_of_models + assert association == expected_components_to_models_mapping diff --git a/tests/test_causal_problem/test_cpcomponent.py b/tests/test_causal_problem/test_cpcomponent.py index a331d10..d5d46e6 100644 --- a/tests/test_causal_problem/test_cpcomponent.py +++ b/tests/test_causal_problem/test_cpcomponent.py @@ -5,7 +5,8 @@ import pytest from numpyro.handlers import condition, do -from causalprog.causal_problem.causal_estimand import Model, ModelMask, _CPComponent +from causalprog.causal_problem import HandlerToApply +from causalprog.causal_problem._base_component import Model, _CPComponent from causalprog.graph import Graph @@ -150,7 +151,7 @@ def _inner(**two_normal_graph_options: float) -> Callable[..., None]: ], ) def test_apply_handlers( - handlers: tuple[ModelMask], + handlers: tuple[HandlerToApply], expected_model: Model, two_normal_graph: Callable[..., Graph], request: pytest.FixtureRequest, diff --git a/tests/test_causal_problem/test_cpcomponent_same_model.py b/tests/test_causal_problem/test_cpcomponent_same_model.py new file mode 100644 index 0000000..a8b0c96 --- /dev/null +++ b/tests/test_causal_problem/test_cpcomponent_same_model.py @@ -0,0 +1,105 @@ +"""Tests for _CPComponent.can_use_same_model. + +- This method should be symmetric in its arguments. 
+- Returns False when one of the arguments is not a _CPComponent instance.
+- CausalEstimands and Constraints should still be able to share models.
+- Models can be shared IFF the same handlers, in the same order, are applied.
+"""
+
+import pytest
+
+from causalprog.causal_problem import (
+    CausalEstimand,
+    Constraint,
+    HandlerToApply,
+)
+from causalprog.causal_problem._base_component import _CPComponent
+
+
+# HandlerToApply compares the handler argument by identity (`is`), so the handler
+# functions need to be defined statically here.
+def handler_a(**pv) -> None:
+    return
+
+
+def handler_b(**pv) -> None:
+    return
+
+
+@pytest.mark.parametrize(
+    ("component_1", "component_2", "expected_result"),
+    [
+        pytest.param(
+            _CPComponent(
+                HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: None
+            ),
+            _CPComponent(
+                HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: None
+            ),
+            True,
+            id="Same model as self",
+        ),
+        pytest.param(
+            _CPComponent(do_with_samples=lambda **pv: None),
+            _CPComponent(do_with_samples=lambda **pv: None),
+            True,
+            id="No effect handlers case is handled",
+        ),
+        pytest.param(
+            _CPComponent(
+                HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: 1.0
+            ),
+            _CPComponent(
+                HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: 2.0
+            ),
+            True,
+            id="_do_with_samples does not affect model compatibility",
+        ),
+        pytest.param(
+            CausalEstimand(
+                HandlerToApply(handler_a, {}), do_with_samples=lambda **pv: None
+            ),
+            Constraint(HandlerToApply(handler_a, {}), model_quantity=lambda **pv: 0.0),
+            True,
+            id="CausalEstimand and Constraints can share models",
+        ),
+        pytest.param(
+            _CPComponent(
+                HandlerToApply(handler_a, {"option": "a"}),
+                do_with_samples=lambda **pv: None,
+            ),
+            _CPComponent(
+                HandlerToApply(handler_a, {"option": "b"}),
+                do_with_samples=lambda **pv: None,
+            ),
+            False,
+            id="Different handlers deny same model",
+        ),
+        pytest.param(
+            _CPComponent(
+                HandlerToApply(handler_a, {"option": "a"}),
+                HandlerToApply(handler_a, {"option": "b"}),
+                do_with_samples=lambda **pv: None,
+            ),
+            _CPComponent(
+                HandlerToApply(handler_a, {"option": "b"}),
+                HandlerToApply(handler_a, {"option": "a"}),
+                do_with_samples=lambda **pv: None,
+            ),
+            False,
+            id="Different handler order denies same model",
+        ),
+        pytest.param(
+            _CPComponent(do_with_samples=lambda **pv: None),
+            1.0,
+            False,
+            id="Compare to non-_CPComponent",
+        ),
+    ],
+)
+def test_can_use_same_model(
+    component_1: _CPComponent, component_2: _CPComponent, *, expected_result: bool
+) -> None:
+    if isinstance(component_1, _CPComponent):
+        assert component_1.can_use_same_model_as(component_2) == expected_result
+    if isinstance(component_2, _CPComponent):
+        assert component_2.can_use_same_model_as(component_1) == expected_result
diff --git a/tests/test_causal_problem/test_handlers.py b/tests/test_causal_problem/test_handlers.py
new file mode 100644
index 0000000..3801201
--- /dev/null
+++ b/tests/test_causal_problem/test_handlers.py
@@ -0,0 +1,107 @@
+import pytest
+
+from causalprog.causal_problem import HandlerToApply
+from causalprog.causal_problem._base_component import EffectHandler
+
+
+def placeholder_callable() -> EffectHandler:
+    """Stand-in for an effect handler."""
+    return lambda model, **kwargs: (lambda **pv: model(**kwargs))
+
+
+@pytest.mark.parametrize(
+    (
+        "args",
+        "expected_error",
+        "use_classmethod",
+    ),
+    [
+        pytest.param(
+            (placeholder_callable, {}), None, False, id="Standard construction"
+        ),
+        pytest.param((placeholder_callable, {}), None, True, id="Via from_pair"),
+        pytest.param(
+            ({}, placeholder_callable), None, True, id="Via from_pair (out of order)"
+        ),
+        pytest.param(
+            (placeholder_callable, []),
+            TypeError(
+                "Options should be dictionary mapping option arguments to values "
+                "(got list)."
+            ),
+            False,
+            id="Wrong options type",
+        ),
+        pytest.param(
+            (1.0, {}),
+            TypeError("float is not callable."),
+            False,
+            id="Handler is not callable",
+        ),
+        pytest.param(
+            (0, 0, 0),
+            ValueError(
+                "HandlerToApply can only be constructed from a container of 2 elements"
+            ),
+            True,
+            id="Tuple too long",
+        ),
+    ],
+)
+def test_handlertoapply_creation(
+    args: tuple[EffectHandler | dict, ...],
+    expected_error: Exception | None,
+    raises_context,
+    *,
+    use_classmethod: bool,
+):
+    if isinstance(expected_error, Exception):
+        if use_classmethod:
+            with raises_context(expected_error):
+                HandlerToApply.from_pair(args)
+        else:
+            with raises_context(expected_error):
+                HandlerToApply(*args)
+    else:
+        handler = (
+            HandlerToApply.from_pair(args) if use_classmethod else HandlerToApply(*args)
+        )
+
+        assert isinstance(handler.options, dict)
+        assert callable(handler.handler)
+
+
+@pytest.mark.parametrize(
+    ("left", "right", "expected_result"),
+    [
+        pytest.param(
+            HandlerToApply(placeholder_callable, {"option1": 1.0}),
+            HandlerToApply(placeholder_callable, {"option1": 1.0}),
+            True,
+            id="Identical",
+        ),
+        pytest.param(
+            HandlerToApply(lambda **pv: None, {"option1": 1.0}),
+            HandlerToApply(lambda **pv: None, {"option1": 1.0}),
+            False,
+            id="callables compared using IS",
+        ),
+        pytest.param(
+            HandlerToApply(placeholder_callable, {"option1": 1.0}),
+            HandlerToApply(placeholder_callable, {"option2": 1.0}),
+            False,
+            id="Options must match",
+        ),
+        pytest.param(
+            HandlerToApply(placeholder_callable, {"option1": 1.0}),
+            1.0,
+            False,
+            id="Comparison to different object class",
+        ),
+    ],
+)
+def test_handlertoapply_equality(
+    left: object, right: object, *, expected_result: bool
+) -> None:
+    assert (left == right) == expected_result
+    assert (left == right) == (right == left)
diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py
index 381915d..0664d76 100644
--- a/tests/test_integration/test_two_normal_example.py
+++ b/tests/test_integration/test_two_normal_example.py
@@ -6,8 +6,8 @@
 import optax
 import pytest
 
-from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint
 from causalprog.causal_problem.causal_problem import CausalProblem
+from causalprog.causal_problem.components import CausalEstimand, Constraint
 from causalprog.graph import Graph