Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 65 additions & 5 deletions src/pyrecest/filters/abstract_particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,71 @@


class AbstractParticleFilter(AbstractFilter):
def __init__(self, initial_filter_state=None):
def __init__(
self,
initial_filter_state=None,
resampling_criterion: Callable | None = None,
):
AbstractFilter.__init__(self, initial_filter_state)
self.resampling_criterion = resampling_criterion

@property
def resampling_criterion(self):
"""Criterion deciding whether to resample after an update.

``None`` preserves the historical behavior and always resamples.
Otherwise, the callable receives the current weighted filter state and
must return a truthy value if the particle set should be resampled.
"""
return self._resampling_criterion

@resampling_criterion.setter
def resampling_criterion(self, criterion: Callable | None):
if criterion is not None and not callable(criterion):
raise TypeError("resampling_criterion must be callable or None")
self._resampling_criterion = criterion

def set_resampling_criterion(self, criterion: Callable | None):
"""Set the post-update resampling criterion and return the filter."""
self.resampling_criterion = criterion
return self

def should_resample(self) -> bool:
"""Return whether the current weighted particle set should resample.

The default criterion, ``None``, always returns ``True`` to retain the
previous update behavior.
"""
if self.resampling_criterion is None:
return True
return bool(self.resampling_criterion(self.filter_state))

def resample(self):
"""Manually resample particles according to their current weights.

The particle locations are sampled with replacement from the current
weighted particle set, and the resulting weights are reset to uniform.
"""
self._filter_state.d = self.filter_state.sample(
self.filter_state.w.shape[0]
)
self._filter_state.w = (
ones_like(self.filter_state.w) / self.filter_state.w.shape[0]
)
return self

def resample_if_needed(self) -> bool:
"""Resample if the configured criterion requests it.

Returns
-------
bool
``True`` if resampling was performed, otherwise ``False``.
"""
if self.should_resample():
self.resample()
return True
return False

def predict_identity(self, noise_distribution):
self.predict_nonlinear(
Expand Down Expand Up @@ -166,10 +229,7 @@ def update_nonlinear_using_likelihood(self, likelihood, measurement=None):
lambda x: likelihood(measurement, x)
)

self._filter_state.d = self.filter_state.sample(self.filter_state.w.shape[0])
self._filter_state.w = (
1 / self.filter_state.w.shape[0] * ones_like(self.filter_state.w)
)
self.resample_if_needed()

def association_likelihood(self, likelihood: AbstractManifoldSpecificDistribution):
likelihood_val = sum(likelihood.pdf(self.filter_state.d) * self.filter_state.w)
Expand Down
25 changes: 24 additions & 1 deletion tests/filters/test_euclidean_particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import array, mean, ones, random, vstack, zeros, zeros_like
from pyrecest.distributions import GaussianDistribution
from pyrecest.distributions import GaussianDistribution, LinearDiracDistribution
from pyrecest.filters.euclidean_particle_filter import EuclideanParticleFilter


Expand Down Expand Up @@ -66,6 +66,29 @@ def test_predict_update_cycle_3d_forced_particle_pos_no_pred(self):
self.pf.get_point_estimate(), force_first_particle_pos, atol=0.2
)

def test_update_can_skip_automatic_resampling(self):
pf = EuclideanParticleFilter(n_particles=4, dim=1)
particles = array([[0.0], [1.0], [2.0], [3.0]])
pf.filter_state = LinearDiracDistribution(particles)
pf.set_resampling_criterion(lambda _state: False)

pf.update_nonlinear_using_likelihood(lambda x: x[:, 0] + 1.0)

npt.assert_allclose(pf.filter_state.d, particles)
npt.assert_allclose(pf.filter_state.w, array([0.1, 0.2, 0.3, 0.4]))

def test_manual_resample_resets_weights(self):
pf = EuclideanParticleFilter(n_particles=4, dim=1)
pf.filter_state = LinearDiracDistribution(
array([[0.0], [1.0], [2.0], [3.0]]),
array([0.0, 0.0, 0.0, 1.0]),
)

pf.resample()

npt.assert_allclose(pf.filter_state.d, array([[3.0], [3.0], [3.0], [3.0]]))
npt.assert_allclose(pf.filter_state.w, ones(4) / 4)


if __name__ == "__main__":
unittest.main()
Loading