-
Notifications
You must be signed in to change notification settings - Fork 0
Add moment function #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
willGraham01
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The moment and other functions implemented look good, so this is OK to merge. Though one thing that catches my attention is that we can now use numpyro to draw the samples from the Graph.model, rather than doing our current "sample down the tree" implementation. This should also fix our current issues with test speed too, but I'll open a separate issue for this.
That, and (from my understanding) if a ParameterNode does not have a current value set, the sample and thus other moment methods will throw an error if we try to use them? But this is a problem I think we can sidestep by switching to using Graph.model in the sample method instead.
Approving the PR because none of these problems are introduced in this PR (they're already in main since we're currently meshing numpyro into everything). We can fix these in another PR.
| samples: int, | ||
| rng_key: jax.Array, | ||
| ) -> npt.NDArray[float]: | ||
| d = self._dist( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again this is more a note rather than something that needs to be done in this PR; but now that we're sticking to numpyro we can avoid implementing this sequential sampling down the nodes we've got going right now, and instead use Graph.model to generate the samples for us.
| d, | ||
| rng_key=rng_key, | ||
| sample_shape=(samples,), | ||
| sample_shape=(samples,) if d.batch_shape == () and samples > 1 else (), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Likewise, see my comment above. Using the Graph.model will save us having to do things like this.
|
(Disregard the above comment, I forgot that we already discussed this in #58 (comment) ). We're already planning to swap to doing an MCMC using |
* Barebones integration test * Add optax to test requirements * Docstring to actually give some help about the method * Typing * Add progress statements so I can debug * Satisfied with answers, purge prints * Adapt to changes from #70, which also fixes a bug with missing edges in the original graph
Follows on from #69, so we should merge that first.