-
Notifications
You must be signed in to change notification settings - Fork 0
Add moment function #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0afe05f
221f79c
a6c04b7
278a631
1e77f8f
83aca68
3d50a90
6cf500d
5a18ca5
fca9fad
15ac295
ef08b45
2291b6b
4779188
2276f0d
e8e1911
1762b4f
b431a69
b30169b
cc70daa
bf1cbcb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| """Algorithms.""" | ||
|
|
||
| from .do import do | ||
| from .expectation import expectation, standard_deviation | ||
| from .moments import expectation, moment, standard_deviation |
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| """Algorithms for estimating the expectation and standard deviation.""" | ||
|
|
||
| import jax | ||
| import numpy.typing as npt | ||
|
|
||
| from causalprog.graph import Graph | ||
|
|
||
|
|
||
| def sample( | ||
| graph: Graph, | ||
| outcome_node_label: str, | ||
| samples: int, | ||
| *, | ||
| rng_key: jax.Array, | ||
| ) -> npt.NDArray[float]: | ||
| """Sample data from (a random variable attached to) a node in a graph.""" | ||
| nodes = graph.roots_down_to_outcome(outcome_node_label) | ||
|
|
||
| values: dict[str, npt.NDArray[float]] = {} | ||
| keys = jax.random.split(rng_key, len(nodes)) | ||
|
|
||
| for node, key in zip(nodes, keys, strict=False): | ||
| values[node.label] = node.sample(values, samples, rng_key=key) | ||
| return values[outcome_node_label] | ||
|
|
||
|
|
||
| def expectation( | ||
| graph: Graph, | ||
| outcome_node_label: str, | ||
| samples: int, | ||
| *, | ||
| rng_key: jax.Array, | ||
| ) -> float: | ||
| """Estimate the expectation of (a random variable attached to) a node in a graph.""" | ||
| return moment(1, graph, outcome_node_label, samples, rng_key=rng_key) | ||
|
|
||
|
|
||
| def standard_deviation( | ||
| graph: Graph, | ||
| outcome_node_label: str, | ||
| samples: int, | ||
| *, | ||
| rng_key: jax.Array, | ||
| rng_key_first_moment: jax.Array | None = None, | ||
| ) -> float: | ||
| """Estimate the standard deviation of (a RV attached to) a node in a graph.""" | ||
| return ( | ||
| moment(2, graph, outcome_node_label, samples, rng_key=rng_key) | ||
| - moment( | ||
| 1, | ||
| graph, | ||
| outcome_node_label, | ||
| samples, | ||
| rng_key=rng_key if rng_key_first_moment is None else rng_key_first_moment, | ||
| ) | ||
| ** 2 | ||
| ) ** 0.5 | ||
|
|
||
|
|
||
| def moment( | ||
| order: int, | ||
| graph: Graph, | ||
| outcome_node_label: str, | ||
| samples: int, | ||
| *, | ||
| rng_key: jax.Array, | ||
| ) -> float: | ||
| """Estimate a moment of (a random variable attached to) a node in a graph.""" | ||
| return ( | ||
| sum(sample(graph, outcome_node_label, samples, rng_key=rng_key) ** order) | ||
| / samples | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -169,19 +169,20 @@ def sample( | |
| samples: int, | ||
| rng_key: jax.Array, | ||
| ) -> npt.NDArray[float]: | ||
| d = self._dist( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again this is more a note rather than something that needs to be done in this PR; but now that we're sticking to |
||
| # Pass in node values derived from construction so far | ||
| **{ | ||
| native_name: sampled_dependencies[node_name] | ||
| for native_name, node_name in self.parameters.items() | ||
| }, | ||
| # Pass in any constant parameters this node sets | ||
| **self.constant_parameters, | ||
| ) | ||
| return numpyro.sample( | ||
| self.label, | ||
| self._dist( | ||
| # Pass in node values derived from construction so far | ||
| **{ | ||
| native_name: sampled_dependencies[node_name] | ||
| for native_name, node_name in self.parameters.items() | ||
| }, | ||
| # Pass in any constant parameters this node sets | ||
| **self.constant_parameters, | ||
| ), | ||
| d, | ||
| rng_key=rng_key, | ||
| sample_shape=(samples,), | ||
| sample_shape=(samples,) if d.batch_shape == () and samples > 1 else (), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likewise, see my comment above. Using the |
||
| ) | ||
|
|
||
| @override | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.