Re-introduce parametrised sampling #47

@willGraham01

We previously used a re-parametrisation trick for drawing samples from a normal distribution: to sample from a $\mathcal{N}(\mu, \sigma^2)$ distribution, draw samples $S$ from a $\mathcal{N}(0, 1)$ distribution and transform them as $\mu + \sigma S$.

This transformation also works when $\mu$ and $\sigma$ are vectors whose elements correspond to different parameter values: element-wise multiplication produces a suitable sample for each element.
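A minimal sketch of the element-wise version of the trick, using only the Python standard library. The function name `sample_normal` is hypothetical, and a seeded `random.Random` stands in for whatever rng key the codebase actually uses.

```python
import random
import statistics


def sample_normal(mu, sigma, n_samples, rng):
    """Draw n_samples from N(mu_i, sigma_i^2) for each (mu_i, sigma_i) pair.

    Returns a list of n_samples rows; each row holds one draw per pair.
    """
    if len(mu) != len(sigma):
        raise ValueError("mu and sigma must have the same length")
    samples = []
    for _ in range(n_samples):
        # S ~ N(0, 1), drawn without reference to mu or sigma.
        standard = [rng.gauss(0.0, 1.0) for _ in mu]
        # Element-wise transform: mu_i + sigma_i * S_i.
        samples.append([m + s * z for m, s, z in zip(mu, sigma, standard)])
    return samples


rng = random.Random(0)
draws = sample_normal([0.0, 10.0], [1.0, 0.5], 20_000, rng)
first = [row[0] for row in draws]
second = [row[1] for row in draws]
# Empirical mean and standard deviation should be close to (0, 1) and (10, 0.5).
print(statistics.fmean(first), statistics.stdev(first))
print(statistics.fmean(second), statistics.stdev(second))
```

Only the standard normal draws involve randomness; the parameters enter purely through arithmetic, which is what makes the vectorised form possible.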

As such, we should work this back into the codebase and, at the same time, compartmentalise the behaviour of the sample, expectation, and similar methods.

Remove the DistributionFamily class

We initially thought we would need it, but it now seems we do not need (and will not need) the functionality that DistributionFamily and its subclasses provide. We can instead give DistributionNode._dist a Distribution class directly, and replace statements like Node._dist.construct(**) with Node._dist(**).

We should do this just to ensure the changes in the next stages don't have to straddle more classes than necessary.

Separate sample behaviour

Currently, Node.sample is responsible for sampling "up" the graph, then constructing the resulting distributions for the node itself, and then drawing and returning the samples. We would like to separate this behaviour, delegating the actual sample-drawing process to the Distribution class (so we can then implement "speed-ups" in sampling).

Note that we might also want to do something similar in the future for expectation and, more generally, for computing moments of distributions.

Ideally, we would have that:

  • Node.sample is told values of (dependent) parameters that have already been sampled (currently we pass these in a dict), as well as the number of samples to draw and the rng key to use.
    • Node.sample combines the values of the dependent parameters with its constant parameters, and passes these values to its self._dist object to do the actual sampling. In pseudo-code: self._dist(**relevant_parameters).sample(n_samples, rng_key).
  • Distribution.sample draws the actual samples. In the base class, this should be done in the same manner as it is now (drawing one sample at a time); we will later override this for certain distributions, in the final portion of this issue.
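The split described in the list above could look roughly like the sketch below. It is an illustration under assumptions, not the existing code: `_draw_one`, `_constant_parameters`, and the `Uniform` example distribution are hypothetical names, the real class interfaces may differ, and a seeded `random.Random` stands in for the rng key.

```python
import random


class Distribution:
    """Hypothetical base class that owns the actual sample drawing."""

    def _draw_one(self, rng):
        raise NotImplementedError

    def sample(self, n_samples, rng):
        # Default behaviour: draw one sample at a time.
        return [self._draw_one(rng) for _ in range(n_samples)]


class Uniform(Distribution):
    """Example distribution that relies on the base-class sample loop."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def _draw_one(self, rng):
        return rng.uniform(self.low, self.high)


class DistributionNode:
    def __init__(self, dist, constant_parameters):
        self._dist = dist  # a Distribution class, not an instance
        self._constant_parameters = dict(constant_parameters)

    def sample(self, sampled_parameters, n_samples, rng):
        # Merge already-sampled (dependent) parameters with the constants,
        # then delegate the drawing to the Distribution object.
        relevant_parameters = {**self._constant_parameters, **sampled_parameters}
        return self._dist(**relevant_parameters).sample(n_samples, rng)


node = DistributionNode(Uniform, {"low": 0.0})
samples = node.sample({"high": 2.0}, n_samples=5, rng=random.Random(1))
print(samples)
```

With this shape, the node only assembles parameters, so a distribution can later override `sample` without any change to the node.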

Sampling via Parametrisations

We should change the Normal distribution class we provide to sample in the manner described at the start of this issue, i.e. to draw samples from a fixed (parameter-independent) standard normal and transform these via vector / matrix arithmetic. This would require overriding the .sample method in the appropriate class.
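A sketch of what such an override might look like, assuming a base class whose default `sample` draws one value at a time. The `Distribution` base shown here, its `_draw_one` hook, and the `mu` / `sigma` constructor arguments are assumptions made for illustration, and `random.Random` again stands in for the rng key.

```python
import random


class Distribution:
    """Hypothetical base class with the one-at-a-time default."""

    def _draw_one(self, rng):
        raise NotImplementedError

    def sample(self, n_samples, rng):
        return [self._draw_one(rng) for _ in range(n_samples)]


class Normal(Distribution):
    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

    def sample(self, n_samples, rng):
        # Draw from the fixed standard normal first; the draws do not
        # depend on mu or sigma.
        standard = [rng.gauss(0.0, 1.0) for _ in range(n_samples)]
        # Then apply the parameters as a deterministic transformation.
        return [self.mu + self.sigma * z for z in standard]


# With the same seed, both calls see the same standard normal draws, so
# the second set of samples is just a shifted and scaled copy of the first.
base = Normal(0.0, 1.0).sample(4, random.Random(42))
shifted = Normal(3.0, 2.0).sample(4, random.Random(42))
print(base)
print(shifted)
```

The same structure carries over to vector-valued parameters, where the list comprehension would become array arithmetic.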

Parametrisation could be taken further, particularly because it will be useful when providing derivatives to optimisation functions, and it could be abstracted into the base class so that other distributions that admit a similar parametrisation can use it. For the scope of this issue, however, we will leave this as a follow-on issue once everything else is done.
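For context, the link to derivatives can be written out as below. This is the standard pathwise-derivative identity rather than anything specific to this codebase. The function $f$ is a placeholder, assumed differentiable and regular enough that differentiation and expectation can be exchanged, with $S \sim \mathcal{N}(0, 1)$:

$$
\frac{\partial}{\partial \mu} \mathbb{E}\left[ f(\mu + \sigma S) \right] = \mathbb{E}\left[ f'(\mu + \sigma S) \right],
\qquad
\frac{\partial}{\partial \sigma} \mathbb{E}\left[ f(\mu + \sigma S) \right] = \mathbb{E}\left[ S \, f'(\mu + \sigma S) \right].
$$

Because the randomness sits entirely in $S$, the same standard normal draws can be reused to estimate both the expectation and its derivatives with respect to $\mu$ and $\sigma$.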
