Description
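We previously used a reparametrisation trick for drawing samples from a normal distribution. Namely, it is possible to sample from a normal distribution with arbitrary mean and covariance by drawing samples from a fixed, parameter-independent standard normal and transforming them with vector/matrix arithmetic.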
This transformation works even when
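For reference, here is the standard form of the trick for a multivariate normal. We assume $L$ is any matrix with $LL^\top = \Sigma$ (for example, the Cholesky factor of $\Sigma$). This issue does not say which square root was used.

```math
\begin{aligned}
z &\sim \mathcal{N}(0, I), \qquad x = \mu + L z, \qquad L L^\top = \Sigma, \\
\mathbb{E}[x] &= \mu + L\,\mathbb{E}[z] = \mu, \\
\operatorname{Cov}(x) &= L \operatorname{Cov}(z) L^\top = L I L^\top = \Sigma.
\end{aligned}
```

Since $x$ is an affine function of a Gaussian vector, it is itself Gaussian, so $x \sim \mathcal{N}(\mu, \Sigma)$. All of the parameter dependence sits in $\mu$ and $L$, while $z$ always comes from the same fixed distribution.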
As such, we should work this back into the codebase, and at the same time compartmentalize the behaviour of the sample, expectation, and similar methods.
Remove the DistributionFamily class
Though we initially thought we would need it, it now seems we do not need (and will not need) the functionality that the DistributionFamily class (and its subclasses) provides. Instead, we can simply give DistributionNode._dist a Distribution class, and replace statements like Node._dist.construct(**) with Node._dist(**).
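A minimal sketch of the intended change follows. It uses hypothetical stand-in classes: Distribution, Normal, DistributionNode and the helper method build are illustrative only, not the codebase's actual definitions. The point is that _dist stores the Distribution class itself, so building a distribution becomes a direct call rather than a construct call on a family object.

```python
from typing import Any, Dict, Type


class Distribution:
    """Hypothetical minimal stand-in for the codebase's Distribution base class."""

    def __init__(self, **parameters: Any) -> None:
        self.parameters = parameters


class Normal(Distribution):
    """Hypothetical Normal distribution, parametrised by mean and cov."""

    def __init__(self, mean: float, cov: float) -> None:
        super().__init__(mean=mean, cov=cov)


class DistributionNode:
    """Sketch of a node whose _dist attribute holds a Distribution *class*."""

    def __init__(self, dist: Type[Distribution], constant_parameters: Dict[str, Any]) -> None:
        # Previously a DistributionFamily instance; now just the class.
        self._dist = dist
        self.constant_parameters = constant_parameters

    def build(self, **sampled_parameters: Any) -> Distribution:
        # Hypothetical helper: was self._dist.construct(**...), now self._dist(**...).
        return self._dist(**self.constant_parameters, **sampled_parameters)


node = DistributionNode(Normal, {"cov": 1.0})
dist = node.build(mean=0.5)
print(type(dist).__name__, dist.parameters)  # Normal {'mean': 0.5, 'cov': 1.0}
```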
We should do this first, so that the changes in the later stages do not have to straddle more classes than necessary.
Separate sample behaviour
Currently, Node.sample is responsible for sampling "up" the graph, then constructing the resulting distributions for the node itself, and then drawing and returning the samples. We would like to separate this behaviour, delegating the actual sample-drawing process to the Distribution class (so we can then implement "speed-ups" in sampling).
Note that we may also want to do something similar in the future for expectation and, more generally, for computing moments of distributions.
Ideally, we would have that:
- Node.sample is told the values of the (dependent) parameters that have already been sampled (currently we pass these in a dict), as well as the number of samples to draw and the rng key to use.
- Node.sample combines the values of the dependent parameters with its constant parameters, and passes these values to self._dist to do the actual sampling. In pseudo-code: self._dist(**relevant_parameters).sample(n_samples, rng_key).
- Distribution.sample draws the actual samples. In the base class, this should be done in the same manner as it is now (drawing one sample at a time); however, we will later override this for certain distributions in the final portion of this issue.
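The division of responsibilities above might look roughly like this sketch. Several pieces are assumptions: the class bodies, the dependent_names attribute and the _sample_one hook are hypothetical, and a NumPy Generator stands in for the rng key (the codebase may use JAX-style keys instead).

```python
from typing import Any, Dict, Iterable, Type

import numpy as np


class Distribution:
    """Hypothetical base class; sample() keeps the current one-at-a-time behaviour."""

    def __init__(self, **parameters: Any) -> None:
        self.parameters = parameters

    def _sample_one(self, rng: np.random.Generator) -> np.ndarray:
        # Hypothetical hook: each concrete distribution knows how to draw one sample.
        raise NotImplementedError

    def sample(self, n_samples: int, rng: np.random.Generator) -> np.ndarray:
        # Default: draw one sample per iteration and stack the results.
        # Subclasses may override this with a faster, vectorised scheme.
        return np.stack([self._sample_one(rng) for _ in range(n_samples)])


class Normal(Distribution):
    def __init__(self, mean: Any, cov: Any) -> None:
        super().__init__(mean=np.asarray(mean, dtype=float), cov=np.asarray(cov, dtype=float))

    def _sample_one(self, rng: np.random.Generator) -> np.ndarray:
        return rng.multivariate_normal(self.parameters["mean"], self.parameters["cov"])


class Node:
    def __init__(
        self,
        dist: Type[Distribution],
        constant_parameters: Dict[str, Any],
        dependent_names: Iterable[str],
    ) -> None:
        self._dist = dist
        self.constant_parameters = constant_parameters
        self.dependent_names = tuple(dependent_names)

    def sample(
        self, sampled: Dict[str, Any], n_samples: int, rng: np.random.Generator
    ) -> np.ndarray:
        # Keep only the already-sampled values this node depends on...
        dependent = {name: sampled[name] for name in self.dependent_names}
        # ...merge them with the node's constant parameters...
        relevant_parameters = {**self.constant_parameters, **dependent}
        # ...and delegate the drawing of samples to the Distribution.
        return self._dist(**relevant_parameters).sample(n_samples, rng)


rng = np.random.default_rng(0)
node = Node(Normal, constant_parameters={"cov": np.eye(2)}, dependent_names=["mean"])
samples = node.sample({"mean": np.array([1.0, -1.0]), "other": 3.0}, n_samples=5, rng=rng)
print(samples.shape)  # (5, 2)
```

With this split, Node.sample no longer knows how samples are drawn. A subclass can swap in a faster Distribution.sample without touching the node.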
Sampling via Parameterisations
We should change the Normal distribution class we provide to sample in the manner described at the start of this issue: that is, draw samples from a fixed (parameter-independent) standard normal and transform them via vector/matrix arithmetic. This requires overriding the .sample method in the appropriate class.
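The following sketch shows one way to write such an override. It assumes NumPy, a Generator in place of the rng key, and a hypothetical Normal class parametrised by a mean vector and a covariance matrix. We use the Cholesky factor as the matrix L with L L^T = cov. That is our own choice, and it requires cov to be strictly positive definite. A singular covariance would need a different square root, for example one built from an eigendecomposition.

```python
import numpy as np


class Distribution:
    """Hypothetical base class: the default sample() draws one sample at a time."""

    def __init__(self, **parameters):
        self.parameters = parameters

    def _sample_one(self, rng):
        raise NotImplementedError

    def sample(self, n_samples, rng):
        return np.stack([self._sample_one(rng) for _ in range(n_samples)])


class Normal(Distribution):
    """Multivariate normal that overrides sample() with the reparametrisation trick."""

    def __init__(self, mean, cov):
        super().__init__(mean=np.asarray(mean, dtype=float), cov=np.asarray(cov, dtype=float))

    def sample(self, n_samples, rng):
        mean = self.parameters["mean"]
        cov = self.parameters["cov"]
        # Parameter-independent draws z ~ N(0, I), one row per sample.
        z = rng.standard_normal((n_samples, mean.shape[0]))
        # Lower-triangular L with L @ L.T == cov (needs cov positive definite).
        L = np.linalg.cholesky(cov)
        # x = mean + L z for every row z; for row vectors this is z @ L.T.
        return mean + z @ L.T


rng = np.random.default_rng(0)
mean = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.5], [0.5, 1.0]])
samples = Normal(mean, cov).sample(200_000, rng)
print(samples.mean(axis=0))           # close to [1, -2]
print(np.cov(samples, rowvar=False))  # close to cov
```

All n_samples draws come from a single standard_normal call and one matrix product, with no per-sample Python loop.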
Parametrisation could be taken further, particularly since it will be useful when providing derivatives to optimisation functions. It could also be abstracted into the base class, so that other distributions admitting a similar trick can use it. However, this is beyond the scope of this issue; we will leave it for a follow-on issue once everything else is done.
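One possible shape for that base-class abstraction is sketched below. The names ReparametrisedDistribution, _base_sample and _transform are hypothetical. The Exponential distribution is only an illustration, since this issue does not name other distributions. Samples are drawn from a fixed base distribution and then passed through a transform that is deterministic in the parameters. For fixed base draws, derivatives with respect to the parameters are therefore well defined, and the usage at the end checks one numerically.

```python
import numpy as np


class ReparametrisedDistribution:
    """Hypothetical base class for distributions sampled as x = transform(eps).

    eps comes from a fixed, parameter-independent base distribution; every
    dependence on the parameters lives in the deterministic _transform.
    """

    def __init__(self, **parameters):
        self.parameters = parameters

    def _base_sample(self, n_samples, rng):
        raise NotImplementedError

    def _transform(self, eps):
        raise NotImplementedError

    def sample(self, n_samples, rng):
        return self._transform(self._base_sample(n_samples, rng))


class Exponential(ReparametrisedDistribution):
    """Illustrative example: if eps ~ Exp(1), then eps / rate ~ Exp(rate)."""

    def __init__(self, rate):
        super().__init__(rate=float(rate))

    def _base_sample(self, n_samples, rng):
        return rng.standard_exponential(n_samples)

    def _transform(self, eps):
        return eps / self.parameters["rate"]


rng = np.random.default_rng(0)
dist = Exponential(rate=2.0)
print(dist.sample(100_000, rng).mean())  # close to 1 / rate = 0.5

# With eps held fixed, x = eps / rate, so dx/drate = -eps / rate**2.
eps = dist._base_sample(5, rng)
h = 1e-6
finite_diff = (Exponential(rate=2.0 + h)._transform(eps) - dist._transform(eps)) / h
print(np.allclose(finite_diff, -eps / 2.0**2, rtol=1e-4))  # True
```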