Skip to content

Conversation

@mscroggs
Copy link
Collaborator

@mscroggs mscroggs commented Aug 8, 2025

  • Add a moment function that estimates the moment of any order from a graph.
  • Make the expectation and standard_deviation functions use this
  • Change the default std dev in tests from 1, so that it is not equal to variance (!)
  • Use numpyro for sampling: progress towards NumPyro for Sampling #58
  • Make the output of Node.sample have the correct shape - this is progress towards fixing Out of memory error #74
  • Removes test that samples are within machine precision, as numpyro appears to not be using the rng_key in exactly the same was as jax was

Follows on from #69, so we should merge that first.

@mscroggs mscroggs requested a review from willGraham01 August 13, 2025 10:32
Copy link
Collaborator

@willGraham01 willGraham01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The moment and other functions implemented look good, so this is OK to merge. Though one thing that catches my attention is that we can now use numpyro to draw the samples from the Graph.model, rather than doing our current "sample down the tree" implementation. This should also fix our current issues with test speed too, but I'll open a separate issue for this.

That, and (from my understanding) if a ParameterNode does not have a current value set, the sample and thus other moment methods will throw an error if we try to use them? But this is a problem I think we can sidestep by switching to using Graph.model in the sample method instead.

Approving the PR because none of these problems are introduced in this PR (they're already in main since we're currently meshing numpyro into everything). We can fix these in another PR.

samples: int,
rng_key: jax.Array,
) -> npt.NDArray[float]:
d = self._dist(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again this is more a note rather than something that needs to be done in this PR; but now that we're sticking to numpyro we can avoid implementing this sequential sampling down the nodes we've got going right now, and instead use Graph.model to generate the samples for us.

d,
rng_key=rng_key,
sample_shape=(samples,),
sample_shape=(samples,) if d.batch_shape == () and samples > 1 else (),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, see my comment above. Using the Graph.model will save us having to do things like this.

@willGraham01
Copy link
Collaborator

willGraham01 commented Aug 14, 2025

(Disregard the above comment, I forgot that we already discussed this in #58 (comment) ). We're already planning to swap to doing an MCMC using Graph.model )

@mscroggs mscroggs enabled auto-merge (squash) August 14, 2025 10:19
@mscroggs mscroggs merged commit 8d21af5 into main Aug 14, 2025
5 checks passed
@mscroggs mscroggs deleted the mscroggs/moments branch August 14, 2025 10:21
willGraham01 added a commit that referenced this pull request Aug 14, 2025
@mscroggs mscroggs mentioned this pull request Aug 14, 2025
willGraham01 added a commit that referenced this pull request Aug 14, 2025
* Barebones integration test

* Add optax to test requirements

* Docstring to actually give some help about the method

* Typing

* Add progress statements so I can debug

* Satisfied with answers, purge prints

* Adapt to changes from #70, which also fixes a bug with missing edges in the original graph
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants