Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/aspire/basis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# isort: off

from .basis import Basis, Coef, ComplexCoef
from .dirac import DiracBasis, DiracBasis2D, DiracBasis3D
from .steerable import SteerableBasis2D
from .fb import FBBasisMixin

Expand Down
105 changes: 105 additions & 0 deletions src/aspire/basis/dirac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import logging

import numpy as np

from aspire.basis import Basis

logger = logging.getLogger(__name__)


class DiracBasis(Basis):
"""
Dirac basis in 1D.

Supports subclassing higher dimensions.
"""

dim = 1

def __init__(self, size, mask=None, dtype=np.float32):
"""
Initialize Dirac basis.

:param size: The shape defining the basis. May be a tuple
or an integer, in which case, a uniformly sized basis is assumed.
:param mask: A boolean mask matching `size` indicating which
coordinates to include in the basis. Default `None`
implies all via `np.full((size,)*dimension, True)`.
:return: DiracBasis2D instance.
"""

# Size
if isinstance(size, int):
size = (size,) * self.dim

# Masking
if mask is None:
mask = np.full(size, True)
if mask.shape != size:
raise ValueError(f"Invalid mask size. Should match {size} or `None`.")
# Ensure boolean mask
self.mask = np.array(mask, dtype=bool)

super().__init__(size, dtype=dtype)

def _build(self):
"""Private method building basis internals."""
self.count = np.count_nonzero(self.mask)

def _evaluate(self, v):
"""
Evaluate stack of standard coordinate coefficients from Dirac basis.

:param v: Dirac basis coefficents. [..., self.count]
:return: Standard basis coefficients. [..., *self.sz]
"""

# Initialize zeros array of standard basis size.
x = np.zeros((v.shape[0], *self.sz), dtype=self.dtype)

# Assign basis coefficient values
x[..., self.mask] = v

return x

def expand(self, x):
"""
See _evaluate.
"""
return self.evaluate_t(x)

def _evaluate_t(self, x):
"""
Evaluate stack of Dirac basis coefficients from standard basis.

:param x: Standard basis coefficients. [..., *self.sz]
:return: Dirac basis coefficents. [..., self.count]
"""

# Initialize zeros array of dirac basis (mask) count.
v = np.zeros((x.shape[0], self.count), dtype=self.dtype)

# Assign basis coefficient values
v = x[..., self.mask]

return v


class DiracBasis2D(DiracBasis):
"""
Dirac basis in 2D.

See `DiracBasis` documentation.
"""

dim = 2


class DiracBasis3D(DiracBasis):
"""
Dirac basis in 3D.

See `DiracBasis` documentation.
"""

dim = 3
109 changes: 109 additions & 0 deletions tests/test_dirac_basis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import logging

import numpy as np
import pytest

from aspire.basis import DiracBasis2D, DiracBasis3D
from aspire.reconstruction import MeanEstimator
from aspire.source import Simulation
from aspire.utils import grid_2d, grid_3d

logger = logging.getLogger(__name__)


SIZES = [31, 32]
DTYPES = [np.float32, np.float64]
MASKS = [None, 16] # will be created in `mask` fixture
DIMS = [2, 3]


@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module")
def dtype(request):
return request.param


@pytest.fixture(params=SIZES, ids=lambda x: f"size={x}", scope="module")
def size(request):
return request.param


@pytest.fixture(params=DIMS, ids=lambda x: f"dim={x}", scope="module")
def dim(request):
return request.param


@pytest.fixture(params=MASKS, ids=lambda x: f"mask={x}", scope="module")
def mask(request, size, dim):
mask = request.param
# When provided a mask radius ...
if mask is not None:
if dim == 2:
grid = grid_2d
elif dim == 3:
grid = grid_3d

# ... compute mask of `size` < radius
mask = grid(size, normalized=False)["r"] < mask

return mask


@pytest.fixture
def basis(size, dtype, mask, dim):
# get size, agnostic to 1d, 2d `size`
s = np.atleast_1d(size)[0]

if mask is not None:
mask = np.pad(mask, s - mask.shape[-1]) # standard basis implicitly square
mask = mask[:s, :s] # crop to `size` (mainly for odd)

if dim == 2:
dirac_basis = DiracBasis2D
elif dim == 3:
dirac_basis = DiracBasis3D

basis = dirac_basis(size, mask=mask, dtype=dtype)

return basis


def test_roundtrip(basis, mask):

# basis._cls is `Image` or `Volume`
x = basis._cls(np.random.random(basis.sz).astype(basis.dtype))

coef = basis.evaluate_t(x)
_x = basis.evaluate(coef)

if mask is not None:
# Mask case
ref = x * mask
np.testing.assert_allclose(_x, ref)
# Negated mask joined with outer values should all be zero
assert np.all((ref * ~mask).asnumpy() == 0)
else:
np.testing.assert_allclose(_x, x)

_coef = basis.expand(_x)
np.testing.assert_allclose(_coef, coef)


def test_dirac_mean_vol_est(size, dtype):
"""
Test the DiracBasis3D passes through MeanEstimator.
"""

basis = DiracBasis3D(size, dtype=dtype)

src = Simulation(
n=300,
L=size,
C=1,
dtype=dtype,
offsets=0,
amplitudes=1,
)

est_vol = MeanEstimator(src, basis=basis).estimate()

np.testing.assert_array_less(np.mean(src.vols - est_vol), 1e-5)
Loading