Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/pyrecest/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@
from .piecewise_constant_filter import PiecewiseConstantFilter
from .random_matrix_tracker import RandomMatrixTracker
from .se2_ukf import SE2UKF
from .so3_grid_transition import (
quaternion_grid_transition_density,
so3_right_multiplication_grid_transition,
)
from .so3_product_particle_filter import SO3ProductParticleFilter
from .spherical_harmonics_eot_tracker import (
SphericalHarmonicsEOTTracker,
Expand Down Expand Up @@ -187,6 +191,7 @@
"RandomMatrixTracker",
"SCGPTracker",
"ScGpTracker",
"quaternion_grid_transition_density",
"Track",
"TrackManager",
"TrackManagerStepResult",
Expand All @@ -200,6 +205,7 @@
"SE2FilterMixin",
"SE2UKF",
"SO3ProductParticleFilter",
"so3_right_multiplication_grid_transition",
"SphericalHarmonicsEOTTracker",
"SphericalHarmonicsExtendedObjectTracker",
"StateSpaceSubdivisionFilter",
Expand Down
154 changes: 154 additions & 0 deletions src/pyrecest/filters/so3_grid_transition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Transition-density helpers for grid filters on SO(3)."""

# pylint: disable=no-name-in-module,no-member,redefined-builtin
from pyrecest.backend import (
abs,
all,
amax,
array,
clip,
exp,
linalg,
ndim,
sum,
transpose,
)
from pyrecest.distributions._so3_helpers import (
exp_map_identity,
normalize_quaternions,
quaternion_multiply,
)
from pyrecest.distributions.conditional.sd_half_cond_sd_half_grid_distribution import (
SdHalfCondSdHalfGridDistribution,
)
from pyrecest.distributions.hypersphere_subset.abstract_hypersphere_subset_distribution import (
AbstractHypersphereSubsetDistribution,
)


def so3_right_multiplication_grid_transition(
grid,
orientation_increment,
kappa,
) -> SdHalfCondSdHalfGridDistribution:
"""Build a soft grid transition for right-multiplicative SO(3) dynamics.

The returned conditional density represents

``q_next = q_current * delta_q``

on a scalar-last unit-quaternion grid. Columns condition on the current
grid point and rows correspond to the next grid point, i.e.
``grid_values[i, j] = f(grid[i] | grid[j])``. This matches
:meth:`HyperhemisphericalGridFilter.predict_nonlinear_via_transition_density`.

Parameters
----------
grid : array_like or object with ``get_grid()``
Quaternion grid of shape ``(n_grid, 4)``. Quaternions are interpreted
as scalar-last SO(3) representatives and canonicalized to the upper
S3 hemisphere.
orientation_increment : array_like
Either a tangent-vector increment of shape ``(3,)`` at the identity or
a scalar-last quaternion increment of shape ``(4,)``.
kappa : float
Positive concentration parameter. Larger values place more mass on the
grid point nearest to ``q_current * delta_q``.

Returns
-------
SdHalfCondSdHalfGridDistribution
Normalized conditional density on the same canonicalized quaternion
grid.

Notes
-----
The unnormalized score is proportional to

``exp(kappa * |<q_next, q_current * delta_q>|**2)``.

The columns are normalized by the S3 upper-hemisphere grid quadrature rule,
so ``mean(grid_values[:, j]) * surface(S3+) == 1`` for every column.
"""

if kappa <= 0.0:
raise ValueError("kappa must be positive.")

quaternion_grid = _as_quaternion_grid(grid)
delta_quaternion = _as_so3_increment(orientation_increment)

targets = quaternion_multiply(quaternion_grid, delta_quaternion)
inner_products = clip(abs(quaternion_grid @ transpose(targets)), 0.0, 1.0)

# Subtracting the per-column maximum keeps the normalization stable for
# large kappa without changing the normalized conditional density.
exponents = kappa * inner_products**2
scores = exp(exponents - amax(exponents, axis=0, keepdims=True))
density_values = scores / sum(scores, axis=0, keepdims=True)

manifold_size = (
0.5
* AbstractHypersphereSubsetDistribution.compute_unit_hypersphere_surface(
quaternion_grid.shape[1] - 1
)
)
density_values = density_values * (quaternion_grid.shape[0] / manifold_size)

return SdHalfCondSdHalfGridDistribution(
quaternion_grid,
density_values,
enforce_pdf_nonnegative=True,
)


def quaternion_grid_transition_density(
grid,
orientation_increment,
kappa,
) -> SdHalfCondSdHalfGridDistribution:
"""Alias for :func:`so3_right_multiplication_grid_transition`."""

return so3_right_multiplication_grid_transition(
grid,
orientation_increment,
kappa,
)


def _as_quaternion_grid(grid):
if hasattr(grid, "get_grid"):
grid = grid.get_grid()

quaternion_grid = array(grid, dtype=float)
if ndim(quaternion_grid) != 2 or quaternion_grid.shape[1] != 4:
raise ValueError("grid must have shape (n_grid, 4) with scalar-last quaternions.")
if quaternion_grid.shape[0] == 0:
raise ValueError("grid must contain at least one quaternion.")
if not all(linalg.norm(quaternion_grid, axis=1) > 0.0):
raise ValueError("grid quaternions must be nonzero.")

return normalize_quaternions(quaternion_grid)


def _as_so3_increment(orientation_increment):
values = array(orientation_increment, dtype=float)
if ndim(values) == 1:
if values.shape[0] == 3:
return exp_map_identity(values)[0]
if values.shape[0] == 4:
return normalize_quaternions(values)[0]
elif ndim(values) == 2 and values.shape[0] == 1:
if values.shape[1] == 3:
return exp_map_identity(values[0])[0]
if values.shape[1] == 4:
return normalize_quaternions(values)[0]

raise ValueError(
"orientation_increment must have shape (3,) tangent or (4,) quaternion."
)


__all__ = [
"quaternion_grid_transition_density",
"so3_right_multiplication_grid_transition",
]
172 changes: 172 additions & 0 deletions tests/filters/test_so3_grid_transition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import unittest

import pyrecest.backend
from pyrecest.backend import (
abs as backend_abs,
allclose,
arange,
argmax,
array,
array_equal,
diagonal,
max as backend_max,
mean,
sum as backend_sum,
transpose,
)
from pyrecest.distributions import SdHalfCondSdHalfGridDistribution
from pyrecest.distributions._so3_helpers import exp_map_identity, quaternion_multiply
from pyrecest.filters import (
HyperhemisphericalGridFilter,
quaternion_grid_transition_density,
so3_right_multiplication_grid_transition,
)


@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member
reason="Not supported on JAX backend",
)
class TestSO3GridTransition(unittest.TestCase):
def setUp(self):
self.filter_ = HyperhemisphericalGridFilter(32, 3)
self.grid = self.filter_.filter_state.get_grid()
self.manifold_size = self.filter_.filter_state.get_manifold_size()

def test_returns_normalized_conditional_density(self):
transition = so3_right_multiplication_grid_transition(
self.grid,
array([0.0, 0.0, 0.0]),
24.0,
)

self.assertIsInstance(transition, SdHalfCondSdHalfGridDistribution)
self.assertTrue(allclose(transition.get_grid(), self.grid, atol=1e-12))
column_integrals = mean(transition.grid_values, axis=0) * self.manifold_size
self.assertTrue(allclose(column_integrals, 1.0, atol=1e-10))

def test_identity_increment_peaks_on_current_grid_cell(self):
transition = so3_right_multiplication_grid_transition(
self.grid,
array([0.0, 0.0, 0.0]),
80.0,
)

column_maxima = backend_max(transition.grid_values, axis=0)
self.assertTrue(
allclose(diagonal(transition.grid_values), column_maxima, atol=1e-12)
)

def test_nonzero_tangent_increment_peaks_at_rotated_grid_cell(self):
tangent_increment = array([0.7, 0.2, -0.1])
transition = so3_right_multiplication_grid_transition(
self.grid,
tangent_increment,
40.0,
)

delta_quaternion = exp_map_identity(tangent_increment)[0]
targets = quaternion_multiply(self.grid, delta_quaternion)
expected_indices = argmax(backend_abs(self.grid @ transpose(targets)), axis=0)
actual_indices = argmax(transition.grid_values, axis=0)

self.assertTrue(array_equal(actual_indices, expected_indices))
self.assertGreater(
int(backend_sum(actual_indices != arange(self.grid.shape[0]))),
0,
)

def test_accepts_tangent_and_quaternion_increments(self):
tangent_increment = array([0.1, -0.2, 0.3])
delta_quaternion = exp_map_identity(tangent_increment)[0]

transition_from_tangent = so3_right_multiplication_grid_transition(
self.grid,
tangent_increment,
18.0,
)
transition_from_quaternion = so3_right_multiplication_grid_transition(
self.grid,
delta_quaternion,
18.0,
)
transition_from_alias = quaternion_grid_transition_density(
self.grid,
tangent_increment,
18.0,
)

self.assertTrue(
allclose(
transition_from_tangent.grid_values,
transition_from_quaternion.grid_values,
atol=1e-12,
)
)
self.assertTrue(
allclose(
transition_from_tangent.grid_values,
transition_from_alias.grid_values,
atol=1e-12,
)
)

def test_antipodal_increment_and_grid_representatives_are_invariant(self):
delta_quaternion = exp_map_identity(array([0.2, -0.1, 0.3]))[0]
transition = so3_right_multiplication_grid_transition(
self.grid,
delta_quaternion,
20.0,
)
transition_from_antipodal_increment = so3_right_multiplication_grid_transition(
self.grid,
-delta_quaternion,
20.0,
)
transition_from_antipodal_grid = so3_right_multiplication_grid_transition(
-self.grid,
delta_quaternion,
20.0,
)

self.assertTrue(
allclose(
transition.grid_values,
transition_from_antipodal_increment.grid_values,
atol=1e-12,
)
)
self.assertTrue(
allclose(
transition.grid_values,
transition_from_antipodal_grid.grid_values,
atol=1e-12,
)
)
self.assertTrue(
allclose(transition_from_antipodal_grid.get_grid(), self.grid, atol=1e-12)
)

def test_rejects_invalid_inputs(self):
with self.assertRaises(ValueError):
so3_right_multiplication_grid_transition(
self.grid,
array([0.0, 0.0, 0.0]),
0.0,
)
with self.assertRaises(ValueError):
so3_right_multiplication_grid_transition(
self.grid[:, :3],
array([0.0, 0.0, 0.0]),
1.0,
)
with self.assertRaises(ValueError):
so3_right_multiplication_grid_transition(
self.grid,
array([0.0, 0.0]),
1.0,
)


if __name__ == "__main__":
unittest.main()
Loading