Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/causalprog/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Algorithms."""

from .do import do
from .expectation import expectation, standard_deviation
from .moments import expectation, moment, standard_deviation
46 changes: 0 additions & 46 deletions src/causalprog/algorithms/expectation.py

This file was deleted.

72 changes: 72 additions & 0 deletions src/causalprog/algorithms/moments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Algorithms for estimating the expectation and standard deviation."""

import jax
import numpy.typing as npt

from causalprog.graph import Graph


def sample(
graph: Graph,
outcome_node_label: str,
samples: int,
*,
rng_key: jax.Array,
) -> npt.NDArray[float]:
"""Sample data from (a random variable attached to) a node in a graph."""
nodes = graph.roots_down_to_outcome(outcome_node_label)

values: dict[str, npt.NDArray[float]] = {}
keys = jax.random.split(rng_key, len(nodes))

for node, key in zip(nodes, keys, strict=False):
values[node.label] = node.sample(values, samples, rng_key=key)
return values[outcome_node_label]


def expectation(
graph: Graph,
outcome_node_label: str,
samples: int,
*,
rng_key: jax.Array,
) -> float:
"""Estimate the expectation of (a random variable attached to) a node in a graph."""
return moment(1, graph, outcome_node_label, samples, rng_key=rng_key)


def standard_deviation(
graph: Graph,
outcome_node_label: str,
samples: int,
*,
rng_key: jax.Array,
rng_key_first_moment: jax.Array | None = None,
) -> float:
"""Estimate the standard deviation of (a RV attached to) a node in a graph."""
return (
moment(2, graph, outcome_node_label, samples, rng_key=rng_key)
- moment(
1,
graph,
outcome_node_label,
samples,
rng_key=rng_key if rng_key_first_moment is None else rng_key_first_moment,
)
** 2
) ** 0.5


def moment(
order: int,
graph: Graph,
outcome_node_label: str,
samples: int,
*,
rng_key: jax.Array,
) -> float:
"""Estimate a moment of (a random variable attached to) a node in a graph."""
return (
sum(sample(graph, outcome_node_label, samples, rng_key=rng_key) ** order)
/ samples
)
21 changes: 11 additions & 10 deletions src/causalprog/graph/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,19 +169,20 @@ def sample(
samples: int,
rng_key: jax.Array,
) -> npt.NDArray[float]:
d = self._dist(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again this is more a note rather than something that needs to be done in this PR; but now that we're sticking to numpyro we can avoid implementing this sequential sampling down the nodes we've got going right now, and instead use Graph.model to generate the samples for us.

# Pass in node values derived from construction so far
**{
native_name: sampled_dependencies[node_name]
for native_name, node_name in self.parameters.items()
},
# Pass in any constant parameters this node sets
**self.constant_parameters,
)
return numpyro.sample(
self.label,
self._dist(
# Pass in node values derived from construction so far
**{
native_name: sampled_dependencies[node_name]
for native_name, node_name in self.parameters.items()
},
# Pass in any constant parameters this node sets
**self.constant_parameters,
),
d,
rng_key=rng_key,
sample_shape=(samples,),
sample_shape=(samples,) if d.batch_shape == () and samples > 1 else (),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, see my comment above. Using the Graph.model will save us having to do things like this.

)

@override
Expand Down
132 changes: 64 additions & 68 deletions tests/fixtures/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,36 @@


@pytest.fixture
def normal_graph() -> Callable[[NormalGraphNodes | None], Graph]:
"""Creates a 3-node graph:
def normal_graph() -> Callable[[float, float], Graph]:
"""Creates a graph with one normal distribution X.

mean (P) cov (P)
|---> outcome <----|

where outcome is a normal distribution.

Parameter nodes are initialised with no `value` set.
Parameter nodes are included if no values are given for the mean and covariance.
"""

def _inner(normal_graph_nodes: NormalGraphNodes | None = None) -> Graph:
if normal_graph_nodes is None:
normal_graph_nodes = {
"mean": ParameterNode(label="mean"),
"cov": ParameterNode(label="cov"),
"outcome": DistributionNode(
Normal,
label="outcome",
parameters={"loc": "mean", "scale": "std"},
),
}
graph = Graph(label="normal dist")
graph.add_edge(normal_graph_nodes["mean"], normal_graph_nodes["outcome"])
graph.add_edge(normal_graph_nodes["cov"], normal_graph_nodes["outcome"])
def _inner(mean: float | None = None, cov: float | None = None):
graph = Graph(label="normal_graph")
parameters = {}
constant_parameters = {}
if mean is None:
graph.add_node(ParameterNode(label="mean"))
parameters["loc"] = "mean"
else:
constant_parameters["loc"] = mean
if cov is None:
graph.add_node(ParameterNode(label="cov"))
parameters["scale"] = "cov"
else:
constant_parameters["scale"] = cov
graph.add_node(
DistributionNode(
Normal,
label="X",
parameters=parameters,
constant_parameters=constant_parameters,
)
)
for node in parameters.values():
graph.add_edge(node, "X")
return graph

return _inner
Expand All @@ -56,65 +61,56 @@ def two_normal_graph() -> Callable[[float, float, float], Graph]:
where UX is a normal distribution with mean `mean` and covariance `cov`, and X is
a normal distrubution with mean UX and covariance `cov2`.

"""

def _inner(mean: float = 5.0, cov: float = 1.0, cov2: float = 1.0) -> Graph:
graph = Graph(label="G0")
graph.add_node(
DistributionNode(
Normal,
label="UX",
constant_parameters={"loc": mean, "scale": cov**2},
)
)
graph.add_node(
DistributionNode(
Normal,
label="X",
parameters={"loc": "UX"},
constant_parameters={"scale": cov2**2},
)
)
graph.add_edge("UX", "X")

return graph

return _inner


@pytest.fixture
def two_normal_graph_parametrized_mean() -> Callable[[float], Graph]:
"""Creates a graph:
SDUX SDX
| |
V v
mu_x --> UX --> X

where UX is a normal distribution with mean mu_x and covariance `co`, and X is
a normal distrubution with mean UX and covariance nu_y.
Parameter nodes are included if no values are given for the mean and covariances.

"""

def _inner(cov: float = 1.0) -> Graph:
graph = Graph(label="G0")
graph.add_node(ParameterNode(label="nu_y"))
graph.add_node(ParameterNode(label="mu_x"))
def _inner(
mean: float | None = None, cov: float | None = None, cov2: float | None = None
) -> Graph:
graph = Graph(label="two_normal_graph")

x_parameters = {"loc": "UX"}
x_constant_parameters = {}
ux_parameters = {}
ux_constant_parameters = {}
if mean is None:
graph.add_node(ParameterNode(label="mean"))
ux_parameters["loc"] = "mean"
else:
ux_constant_parameters["loc"] = mean
if cov is None:
graph.add_node(ParameterNode(label="cov"))
ux_parameters["scale"] = "cov"
else:
ux_constant_parameters["scale"] = cov
if cov2 is None:
graph.add_node(ParameterNode(label="cov2"))
x_parameters["scale"] = "cov2"
else:
x_constant_parameters["scale"] = cov2

graph.add_node(
DistributionNode(
Normal,
label="UX",
parameters={"loc": "mu_x"},
constant_parameters={"scale": cov},
parameters=ux_parameters,
constant_parameters=ux_constant_parameters,
)
)
graph.add_node(
DistributionNode(
Normal,
label="X",
parameters={"loc": "UX", "scale": "nu_y"},
parameters=x_parameters,
constant_parameters=x_constant_parameters,
)
)
graph.add_edge("UX", "X")
for node in ux_parameters.values():
graph.add_edge(node, "UX")
for node in x_parameters.values():
graph.add_edge(node, "X")

return graph

Expand All @@ -125,9 +121,9 @@ def _inner(cov: float = 1.0) -> Graph:
def two_normal_graph_expected_model() -> Callable[..., dict[str, npt.ArrayLike]]:
"""Creates the model that the two_normal_graph should produce."""

def _inner(mu_x: float, nu_y: float) -> dict[str, npt.ArrayLike]:
ux = numpyro.sample("UX", Normal(loc=mu_x, scale=1.0))
x = numpyro.sample("X", Normal(loc=ux, scale=nu_y))
def _inner(mean: float, cov2: float) -> dict[str, npt.ArrayLike]:
ux = numpyro.sample("UX", Normal(loc=mean, scale=1.0))
x = numpyro.sample("X", Normal(loc=ux, scale=cov2))

return {"X": x, "UX": ux}

Expand Down
12 changes: 3 additions & 9 deletions tests/test_causal_problem/test_callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra
"graph_argument": "g",
},
0.0,
# Empirical calculation with 1000 samples with fixture RNG key
# should give 1.8808e-2 as the empirical expectation.
2.0e-2,
3.0e-2,
id="E[x], infer association",
),
pytest.param(
Expand All @@ -106,9 +104,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra
},
# x has fixed std 1, and nu_y will be set to 1.
1.0**2 + 1.0**2,
# Empirical calculation with 1000 samples with fixture RNG key
# should give 1.8506 as the empirical std of y.
2.0e-1,
3.0e-1,
id="Var[y]",
),
pytest.param(
Expand All @@ -120,9 +116,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra
},
# As per the previous test cases
jnp.array([0.0, 1.0**2 + 1.0**2]),
# As per the above cases, both components should be within
# 2.0e-1 of the analytical value.
jnp.array([2.0e-2, 2.0e-1]),
jnp.array([3.0e-2, 2.0e-1]),
id="E[x], Var[y]",
),
],
Expand Down
Loading