Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 120 additions & 40 deletions src/pyrecest/filters/partitioned_so3_product_particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from collections.abc import Callable, Sequence

# pylint: disable=no-name-in-module,no-member,too-many-positional-arguments
from pyrecest.backend import all, array, exp, ndim, ones, random, stack, sum, to_numpy
from pyrecest.backend import all, array, log, ndim, ones, random, stack, sum, to_numpy
from pyrecest.distributions import SO3DiracDistribution
from pyrecest.distributions._so3_helpers import geodesic_distance

from .so3_product_particle_filter import SO3ProductParticleFilter

Expand Down Expand Up @@ -290,10 +289,44 @@ def update_with_block_likelihoods(
if not all(likelihood_values >= 0.0):
raise ValueError("likelihood values must be nonnegative.")

return self.update_with_block_log_likelihoods(
log(likelihood_values),
resample=resample,
ess_threshold=ess_threshold,
)

def update_with_block_log_likelihoods(
self,
log_likelihood: Callable | Sequence,
measurement=None,
resample: bool = True,
ess_threshold=None,
):
"""Update block weights from block log-likelihoods.

The log-likelihood must evaluate to an array shaped
``(n_blocks, n_particles)``. Each row updates the corresponding block's
weights independently using log-sum-exp normalization.
"""
if callable(log_likelihood):
if measurement is None:
log_likelihood_values = log_likelihood(self.particles)
else:
log_likelihood_values = log_likelihood(measurement, self.particles)
else:
log_likelihood_values = log_likelihood
log_likelihood_values = array(log_likelihood_values, dtype=float)
if log_likelihood_values.shape != (len(self.partition), self.n_particles):
raise ValueError(
"block log-likelihoods must have shape "
f"({len(self.partition)}, {self.n_particles})."
)

self._block_weights = stack(
[
self._normalize_weights(
self._block_weights[block_idx] * likelihood_values[block_idx]
self._normalize_log_weights(
log(self._block_weights[block_idx])
+ log_likelihood_values[block_idx]
)
for block_idx in range(len(self.partition))
],
Expand Down Expand Up @@ -330,16 +363,73 @@ def update_with_component_likelihoods(
if not all(component_likelihoods >= 0.0):
raise ValueError("likelihood values must be nonnegative.")

block_likelihoods = []
return self.update_with_component_log_likelihoods(
log(component_likelihoods),
resample=resample,
ess_threshold=ess_threshold,
)

def update_with_component_log_likelihoods(
    self,
    component_log_likelihoods,
    *,
    resample: bool = True,
    ess_threshold=None,
):
    """Update from per-component log-likelihoods shaped ``(n_particles, K)``.

    Each partition block's log-likelihood is the sum of its member
    components' log-likelihoods (equivalent to a product of likelihoods in
    linear space, but numerically stable). The per-block rows are then
    forwarded to ``update_with_block_log_likelihoods``.

    Parameters
    ----------
    component_log_likelihoods:
        Array-like of shape ``(n_particles, num_rotations)`` holding one
        log-likelihood per particle and rotation component.
    resample:
        Forwarded to the block-level update.
    ess_threshold:
        Forwarded to the block-level update.

    Raises
    ------
    ValueError
        If the input does not have shape ``(n_particles, num_rotations)``.
    """
    component_log_likelihoods = array(component_log_likelihoods, dtype=float)
    if component_log_likelihoods.shape != (self.n_particles, self.num_rotations):
        raise ValueError(
            "component_log_likelihoods must have shape "
            f"({self.n_particles}, {self.num_rotations})."
        )

    # Sum the log-likelihood columns belonging to each block; stacking on
    # axis=1 keeps the per-particle layout so the reduction yields one
    # value per particle for every block.
    block_log_likelihoods = [
        sum(
            stack(
                [
                    component_log_likelihoods[:, component_idx]
                    for component_idx in block
                ],
                axis=1,
            ),
            axis=1,
        )
        for block in self.partition
    ]
    return self.update_with_block_log_likelihoods(
        stack(block_log_likelihoods, axis=0),
        resample=resample,
        ess_threshold=ess_threshold,
    )

def update_with_geodesic_log_likelihood(
    self,
    measurement,
    noise_std=None,
    *,
    component_noise_std=None,
    mask=None,
    confidence=None,
    max_noise_std=None,
    confidence_exponent: float = 1.0,
    outlier_prob: float = 0.0,
    resample: bool = True,
    ess_threshold=None,
):
    """Update partition weights with masked component geodesic log-likelihoods.

    Thin convenience wrapper: first evaluates the per-component geodesic
    log-likelihoods for ``measurement`` (honoring the mask/confidence
    shaping options), then hands the resulting ``(n_particles, K)`` array
    to the component-level log-likelihood update.
    """
    # Evaluate the per-component terms once, then delegate the actual
    # weight update (including normalization and optional resampling).
    per_component_log_likelihoods = self.component_geodesic_log_likelihood(
        measurement,
        noise_std,
        component_noise_std=component_noise_std,
        mask=mask,
        confidence=confidence,
        max_noise_std=max_noise_std,
        confidence_exponent=confidence_exponent,
        outlier_prob=outlier_prob,
    )
    return self.update_with_component_log_likelihoods(
        per_component_log_likelihoods,
        resample=resample,
        ess_threshold=ess_threshold,
    )
def update_with_geodesic_likelihood(
    self,
    measurement,
    noise_std,
    *,
    component_noise_std=None,
    mask=None,
    confidence=None,
    max_noise_std=None,
    confidence_exponent: float = 1.0,
    outlier_prob: float = 0.0,
    resample: bool = True,
    ess_threshold=None,
):
    """Update with masked geodesic likelihoods per partition block.

    This preserves the existing likelihood-space API while delegating to
    the log-likelihood implementation for numerical stability: instead of
    exponentiating squared geodesic distances (which can underflow for
    distant particles), the log-domain path normalizes with log-sum-exp.

    Parameters
    ----------
    measurement:
        Product-of-rotations measurement point.
    noise_std:
        Isotropic geodesic noise standard deviation.
    component_noise_std, mask, confidence, max_noise_std,
    confidence_exponent, outlier_prob:
        Per-component shaping options, forwarded unchanged to
        ``update_with_geodesic_log_likelihood``.
    resample:
        Whether to resample after the weight update.
    ess_threshold:
        Effective-sample-size threshold forwarded to the update.
    """
    return self.update_with_geodesic_log_likelihood(
        measurement,
        noise_std,
        component_noise_std=component_noise_std,
        mask=mask,
        confidence=confidence,
        max_noise_std=max_noise_std,
        confidence_exponent=confidence_exponent,
        outlier_prob=outlier_prob,
        resample=resample,
        ess_threshold=ess_threshold,
    )
Loading
Loading