Skip to content

Commit

Permalink
Added and documented resampler abc.
Browse files Browse the repository at this point in the history
  • Loading branch information
cgranade committed Sep 23, 2016
1 parent d46aca0 commit 2afc8d4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
11 changes: 11 additions & 0 deletions doc/source/apiref/resamplers.rst
Expand Up @@ -20,6 +20,16 @@ algorithm as the effective sample size is reduced, *resampling* is used to
adaptively move particles so as to better represent the posterior distribution.
**QInfer** allows for such algorithms to be specified in a modular way.

:class:`Resampler` - Abstract base class for resampling algorithms
------------------------------------------------------------------

Class Reference
~~~~~~~~~~~~~~~

.. autoclass:: Resampler
:members:
:special-members: __call__

:class:`LiuWestResampler` - Liu and West (2000) resampling algorithm
---------------------------------------------------------------------

Expand All @@ -28,3 +38,4 @@ Class Reference

.. autoclass:: LiuWestResampler
:members:
:special-members: __call__
35 changes: 34 additions & 1 deletion src/qinfer/resamplers.py
Expand Up @@ -32,6 +32,7 @@

# We use __all__ to restrict what globals are visible to external modules.
__all__ = [
'Resampler',
'LiuWestResampler'
]

Expand All @@ -43,6 +44,9 @@

from .utils import outer_product, particle_meanfn, particle_covariance_mtx

from abc import ABCMeta, abstractmethod, abstractproperty
from future.utils import with_metaclass

import qinfer.clustering
from qinfer._exceptions import ResamplerWarning, ResamplerError

Expand All @@ -54,6 +58,35 @@

## CLASSES ####################################################################

class Resampler(with_metaclass(ABCMeta, object)):
@abstractmethod
def __call__(self, model, particle_weights, particle_locations,
n_particles=None,
precomputed_mean=None, precomputed_cov=None
):
"""
Resample the particles given by ``particle_weights`` and
``particle_locations``, drawing ``n_particles`` new particles.
:param Model model: Model from which the particles are drawn,
used to define the valid region for resampling.
:param np.ndarray particle_weights: Weights of each particle,
represented as an array of shape ``(n_original_particles, )``
and dtype :obj:`float`.
:param np.ndarray particle_locations: Locations of each particle,
represented as an array of shape ``(n_original_particles,
model.n_modelparams)`` and dtype :obj:`float`.
:param int n_particles: Number of new particles to draw, or
`None` to draw the same number as the original distribution.
:param np.ndarray precomputed_mean: Mean of the original
distribution, or `None` if this should be computed by the resampler.
:param np.ndarray precomputed_cov: Covariance of the original
distribution, or `None` if this should be computed by the resampler.
:return np.ndarray new_weights: Weights of each new particle.
:return np.ndarray new_locations: Locations of each new particle.
"""

class ClusteringResampler(object):
r"""
Creates a resampler that breaks the particles into clusters, then applies
Expand Down Expand Up @@ -127,7 +160,7 @@ def __call__(self, model, particle_weights, particle_locations):

return new_weights, new_locs

class LiuWestResampler(object):
class LiuWestResampler(Resampler):
r"""
Creates a resampler instance that applies the algorithm of
[LW01]_ to redistribute the particles.
Expand Down

1 comment on commit 2afc8d4

@taalexander
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good.

Please sign in to comment.