diff --git a/doc/source/apiref/resamplers.rst b/doc/source/apiref/resamplers.rst index f81c3cc..f6bd070 100644 --- a/doc/source/apiref/resamplers.rst +++ b/doc/source/apiref/resamplers.rst @@ -20,6 +20,16 @@ algorithm as the effective sample size is reduced, *resampling* is used to adaptively move particles so as to better represent the posterior distribution. **QInfer** allows for such algorithms to be specified in a modular way. +:class:`Resampler` - Abstract base class for resampling algorithms +------------------------------------------------------------------ + +Class Reference +~~~~~~~~~~~~~~~ + +.. autoclass:: Resampler + :members: + :special-members: __call__ + :class:`LiuWestResampler` - Liu and West (2000) resampling algorithm --------------------------------------------------------------------- @@ -28,3 +38,4 @@ Class Reference .. autoclass:: LiuWestResampler :members: + :special-members: __call__ diff --git a/src/qinfer/resamplers.py b/src/qinfer/resamplers.py index 16e2aa6..f691e08 100644 --- a/src/qinfer/resamplers.py +++ b/src/qinfer/resamplers.py @@ -32,6 +32,7 @@ # We use __all__ to restrict what globals are visible to external modules. __all__ = [ + 'Resampler', 'LiuWestResampler' ] @@ -43,6 +44,9 @@ from .utils import outer_product, particle_meanfn, particle_covariance_mtx +from abc import ABCMeta, abstractmethod, abstractproperty +from future.utils import with_metaclass + import qinfer.clustering from qinfer._exceptions import ResamplerWarning, ResamplerError @@ -54,6 +58,35 @@ ## CLASSES #################################################################### +class Resampler(with_metaclass(ABCMeta, object)): + @abstractmethod + def __call__(self, model, particle_weights, particle_locations, + n_particles=None, + precomputed_mean=None, precomputed_cov=None + ): + """ + Resample the particles given by ``particle_weights`` and + ``particle_locations``, drawing ``n_particles`` new particles. + + :param Model model: Model from which the particles are drawn, + used to define the valid region for resampling. + :param np.ndarray particle_weights: Weights of each particle, + represented as an array of shape ``(n_original_particles, )`` + and dtype :obj:`float`. + :param np.ndarray particle_locations: Locations of each particle, + represented as an array of shape ``(n_original_particles, + model.n_modelparams)`` and dtype :obj:`float`. + :param int n_particles: Number of new particles to draw, or + `None` to draw the same number as the original distribution. + :param np.ndarray precomputed_mean: Mean of the original + distribution, or `None` if this should be computed by the resampler. + :param np.ndarray precomputed_cov: Covariance of the original + distribution, or `None` if this should be computed by the resampler. + + :return np.ndarray new_weights: Weights of each new particle. + :return np.ndarray new_locations: Locations of each new particle. + """ + class ClusteringResampler(object): r""" Creates a resampler that breaks the particles into clusters, then applies @@ -127,7 +160,7 @@ def __call__(self, model, particle_weights, particle_locations): return new_weights, new_locs -class LiuWestResampler(object): +class LiuWestResampler(Resampler): r""" Creates a resampler instance that applies the algorithm of [LW01]_ to redistribute the particles.