diff --git a/src/aspire/basis/__init__.py b/src/aspire/basis/__init__.py index 8f292a31bc..48e645ecc4 100644 --- a/src/aspire/basis/__init__.py +++ b/src/aspire/basis/__init__.py @@ -2,6 +2,7 @@ # isort: off from .basis import Basis, Coef, ComplexCoef +from .dirac import DiracBasis, DiracBasis2D, DiracBasis3D from .steerable import SteerableBasis2D from .fb import FBBasisMixin diff --git a/src/aspire/basis/dirac.py b/src/aspire/basis/dirac.py new file mode 100644 index 0000000000..7a40304ea9 --- /dev/null +++ b/src/aspire/basis/dirac.py @@ -0,0 +1,105 @@ +import logging + +import numpy as np + +from aspire.basis import Basis + +logger = logging.getLogger(__name__) + + +class DiracBasis(Basis): + """ + Dirac basis in 1D. + + Supports subclassing higher dimensions. + """ + + dim = 1 + + def __init__(self, size, mask=None, dtype=np.float32): + """ + Initialize Dirac basis. + + :param size: The shape defining the basis. May be a tuple + or an integer, in which case, a uniformly sized basis is assumed. + :param mask: A boolean mask matching `size` indicating which + coordinates to include in the basis. Default `None` + implies all via `np.full((size,)*dimension, True)`. + :return: DiracBasis2D instance. + """ + + # Size + if isinstance(size, int): + size = (size,) * self.dim + + # Masking + if mask is None: + mask = np.full(size, True) + if mask.shape != size: + raise ValueError(f"Invalid mask size. Should match {size} or `None`.") + # Ensure boolean mask + self.mask = np.array(mask, dtype=bool) + + super().__init__(size, dtype=dtype) + + def _build(self): + """Private method building basis internals.""" + self.count = np.count_nonzero(self.mask) + + def _evaluate(self, v): + """ + Evaluate stack of standard coordinate coefficients from Dirac basis. + + :param v: Dirac basis coefficents. [..., self.count] + :return: Standard basis coefficients. [..., *self.sz] + """ + + # Initialize zeros array of standard basis size. + x = np.zeros((v.shape[0], *self.sz), dtype=self.dtype) + + # Assign basis coefficient values + x[..., self.mask] = v + + return x + + def expand(self, x): + """ + See _evaluate. + """ + return self.evaluate_t(x) + + def _evaluate_t(self, x): + """ + Evaluate stack of Dirac basis coefficients from standard basis. + + :param x: Standard basis coefficients. [..., *self.sz] + :return: Dirac basis coefficents. [..., self.count] + """ + + # Initialize zeros array of dirac basis (mask) count. + v = np.zeros((x.shape[0], self.count), dtype=self.dtype) + + # Assign basis coefficient values + v = x[..., self.mask] + + return v + + +class DiracBasis2D(DiracBasis): + """ + Dirac basis in 2D. + + See `DiracBasis` documentation. + """ + + dim = 2 + + +class DiracBasis3D(DiracBasis): + """ + Dirac basis in 3D. + + See `DiracBasis` documentation. + """ + + dim = 3 diff --git a/tests/test_dirac_basis.py b/tests/test_dirac_basis.py new file mode 100644 index 0000000000..c0210eae44 --- /dev/null +++ b/tests/test_dirac_basis.py @@ -0,0 +1,109 @@ +import logging + +import numpy as np +import pytest + +from aspire.basis import DiracBasis2D, DiracBasis3D +from aspire.reconstruction import MeanEstimator +from aspire.source import Simulation +from aspire.utils import grid_2d, grid_3d + +logger = logging.getLogger(__name__) + + +SIZES = [31, 32] +DTYPES = [np.float32, np.float64] +MASKS = [None, 16] # will be created in `mask` fixture +DIMS = [2, 3] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(params=SIZES, ids=lambda x: f"size={x}", scope="module") +def size(request): + return request.param + + +@pytest.fixture(params=DIMS, ids=lambda x: f"dim={x}", scope="module") +def dim(request): + return request.param + + +@pytest.fixture(params=MASKS, ids=lambda x: f"mask={x}", scope="module") +def mask(request, size, dim): + mask = request.param + # When provided a mask radius ... + if mask is not None: + if dim == 2: + grid = grid_2d + elif dim == 3: + grid = grid_3d + + # ... compute mask of `size` < radius + mask = grid(size, normalized=False)["r"] < mask + + return mask + + +@pytest.fixture +def basis(size, dtype, mask, dim): + # get size, agnostic to 1d, 2d `size` + s = np.atleast_1d(size)[0] + + if mask is not None: + mask = np.pad(mask, s - mask.shape[-1]) # standard basis implicitly square + mask = mask[:s, :s] # crop to `size` (mainly for odd) + + if dim == 2: + dirac_basis = DiracBasis2D + elif dim == 3: + dirac_basis = DiracBasis3D + + basis = dirac_basis(size, mask=mask, dtype=dtype) + + return basis + + +def test_roundtrip(basis, mask): + + # basis._cls is `Image` or `Volume` + x = basis._cls(np.random.random(basis.sz).astype(basis.dtype)) + + coef = basis.evaluate_t(x) + _x = basis.evaluate(coef) + + if mask is not None: + # Mask case + ref = x * mask + np.testing.assert_allclose(_x, ref) + # Negated mask joined with outer values should all be zero + assert np.all((ref * ~mask).asnumpy() == 0) + else: + np.testing.assert_allclose(_x, x) + + _coef = basis.expand(_x) + np.testing.assert_allclose(_coef, coef) + + +def test_dirac_mean_vol_est(size, dtype): + """ + Test the DiracBasis3D passes through MeanEstimator. + """ + + basis = DiracBasis3D(size, dtype=dtype) + + src = Simulation( + n=300, + L=size, + C=1, + dtype=dtype, + offsets=0, + amplitudes=1, + ) + + est_vol = MeanEstimator(src, basis=basis).estimate() + + np.testing.assert_array_less(np.mean(src.vols - est_vol), 1e-5)