Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
36a3dbc
Initial boosting implementation.
j-c-c Jan 11, 2024
0a0b86c
Add boosting flag
j-c-c Jan 12, 2024
5f1d159
Boost the kernel.
j-c-c Jan 12, 2024
27c98cb
Add IdentityFilter for Simulation. Fixes issue 978.
j-c-c Jan 17, 2024
9059bc8
Adjust weights in MeanEstimator.
j-c-c Jan 17, 2024
93907a9
Add testing for boosted MeanEstimator.
j-c-c Jan 17, 2024
8ff66f0
Add to backproject docstring.
j-c-c Jan 17, 2024
4bdd0eb
SymmetryGroup Check in backproject. With test.
j-c-c Jan 17, 2024
19e139f
revert Simulation unique_filters.
j-c-c Jan 19, 2024
f94aec7
Temp fix for evaluate_src_filters_on_grid.
j-c-c Jan 19, 2024
15801e1
accept symmetry_group strings in backproject.
j-c-c Jan 19, 2024
6f54f61
Add mse test. Add docstring param.
j-c-c Jan 19, 2024
4f915fe
refactor mse
j-c-c Jan 19, 2024
fdd1462
backproject return vol with symmetry group.
j-c-c Jan 19, 2024
78c8d8d
Use tile instead of concatenate.
j-c-c Jan 22, 2024
27a3608
Add boost flag test.
j-c-c Jan 23, 2024
a5c50b8
Add WeightedVolumesEstimator boosting test.
j-c-c Jan 23, 2024
cbad44d
Apply backproject boosting in loop.
j-c-c Jan 26, 2024
2a7d9a7
Fix manually boosted rots in test_boost_flag.
j-c-c Jan 26, 2024
93e64fb
Boost in loop in compute_kernel.
j-c-c Jan 26, 2024
865c4d7
remove IdentityFilter in boosting test.
j-c-c Jan 26, 2024
c4a8c35
remove import.
j-c-c Jan 26, 2024
fadf481
reword docstring
j-c-c Jan 29, 2024
7835347
Add testing docstring
j-c-c Jan 29, 2024
3878cc5
Add missing docstring.
j-c-c Feb 1, 2024
9695418
remove unnecesary wildcard
j-c-c Feb 1, 2024
6a70ee0
Cache testing source.
j-c-c Feb 1, 2024
9f421fe
Use count instead of confusing math
j-c-c Feb 1, 2024
3a7e103
typo
j-c-c Feb 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aspire.nufft import anufft
from aspire.numeric import fft, xp
from aspire.utils import FourierRingCorrelation, anorm, crop_pad_2d, grid_2d
from aspire.volume import SymmetryGroup

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -489,13 +490,16 @@ def size(self):
# probably not needed, transition
return np.size(self._data)

def backproject(self, rot_matrices):
def backproject(self, rot_matrices, symmetry_group=None):
"""
Backproject images along rotation
Backproject images along rotations. If a symmetry group is provided, images
used in back-projection are duplicated (boosted) for symmetric viewing directions.
Note, it is assumed that a main axis of symmetry aligns with the z-axis.

:param im: An Image (stack) to backproject.
:param rot_matrices: An n-by-3-by-3 array of rotation matrices \
corresponding to viewing directions.
:param rot_matrices: An n-by-3-by-3 array of rotation matrices
corresponding to viewing directions.
:param symmetry_group: A SymmetryGroup instance or string indicating symmetry, ie. "C3".
If supplied, uses symmetry to increase number of images used in back-projection.

:return: Volume instance corresonding to the backprojected images.
"""
Expand All @@ -511,22 +515,44 @@ def backproject(self, rot_matrices):
self.n_images == rot_matrices.shape[0]
), "Number of rotation matrices must match the number of images"

# TODO: rotated_grids might as well give us correctly shaped array in the first place
pts_rot = aspire.volume.rotated_grids(L, rot_matrices).astype(
self.dtype, copy=False
)
pts_rot = pts_rot.reshape((3, -1))
# Get symmetry rotations from SymmetryGroup.
if symmetry_group is None:
symmetry_rots = np.eye(3, dtype=self.dtype)[None]
else:
if isinstance(symmetry_group, str):
symmetry_group = SymmetryGroup.from_string(
symmetry_group, dtype=self.dtype
)
if not isinstance(symmetry_group, SymmetryGroup):
raise TypeError(
f"`symmetry_group` must be a `SymmetryGroup` instance. Found {type(symmetry_group)}."
)
symmetry_rots = symmetry_group.matrices

# Compute Fourier transform of images.
im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self._data))) / (L**2)
if L % 2 == 0:
im_f[:, 0, :] = 0
im_f[:, :, 0] = 0

im_f = im_f.flatten()

vol = anufft(im_f, pts_rot[::-1], (L, L, L), real=True) / L
# Backproject. Apply boosting by looping over symmetry rotations.
vol = np.zeros((L, L, L), dtype=self.dtype)
for sym_rot in symmetry_rots:
rotations = sym_rot @ rot_matrices

# TODO: rotated_grids might as well give us correctly shaped array in the first place
pts_rot = aspire.volume.rotated_grids(L, rotations).astype(
self.dtype, copy=False
)
pts_rot = pts_rot.reshape((3, -1))

vol += anufft(im_f, pts_rot[::-1], (L, L, L), real=True)

vol /= L

return aspire.volume.Volume(vol)
return aspire.volume.Volume(vol, symmetry_group=symmetry_group)

def show(self, columns=5, figsize=(20, 10), colorbar=True):
"""
Expand Down
4 changes: 4 additions & 0 deletions src/aspire/reconstruction/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
checkpoint_iterations=10,
checkpoint_prefix="volume_checkpoint",
maxiter=100,
boost=True,
):
"""
An object representing a 2*L-by-2*L-by-2*L array containing the non-centered Fourier transform of the mean
Expand All @@ -44,13 +45,16 @@ def __init__(
before returning. This should be used in conjunction with
`checkpoint_iterations` to prevent excessive disk usage.
`None` disables.
:param boost: Option to use `src` symmetry to boost number of images used for mean estimation (Boolean).
Default of `True` employs symmetry boosting.
"""

self.src = src
self.basis = basis
self.dtype = self.src.dtype
self.batch_size = batch_size
self.preconditioner = preconditioner
self.boost = boost

# dtype configuration
if not self.dtype == self.basis.dtype:
Expand Down
51 changes: 35 additions & 16 deletions src/aspire/reconstruction/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,14 @@ def _compute_kernel(self):
# Note, because we're iteratively summing it is critical we zero this array.
kernel = np.zeros((self.r, self.r, _2L, _2L, _2L), dtype=self.dtype)

# Handle symmetry boosting.
sym_rots = np.eye(3, dtype=self.dtype)[None]
if self.boost:
sym_rots = self.src.symmetry_group.matrices

for i in range(0, self.src.n, self.batch_size):
_range = np.arange(i, min(self.src.n, i + self.batch_size), dtype=int)

pts_rot = rotated_grids(self.src.L, self.src.rotations[_range, :, :])
pts_rot = pts_rot.reshape((3, -1))
assert pts_rot.dtype == self.dtype

sq_filters_f = evaluate_src_filters_on_grid(self.src, _range) ** 2

amplitudes_sq = (self.src.amplitudes[_range] ** 2).astype(
self.dtype, copy=False
)
Expand All @@ -117,11 +116,19 @@ def _compute_kernel(self):

weights = np.transpose(weights, (2, 0, 1)).flatten()

batch_kernel = (
1
/ (self.r * self.src.L**4)
* anufft(weights, pts_rot[::-1], (_2L, _2L, _2L), real=True)
)
# Apply boosting.
batch_kernel = np.zeros((_2L, _2L, _2L), dtype=self.dtype)
for sym_rot in sym_rots:
rotations = sym_rot @ self.src.rotations[_range]
pts_rot = rotated_grids(self.src.L, rotations)
pts_rot = pts_rot.reshape((3, -1))

batch_kernel += (
1
/ (self.r * self.src.L**4)
* anufft(weights, pts_rot[::-1], (_2L, _2L, _2L), real=True)
)

kernel[k, j] += batch_kernel

# r x r symmetric
Expand Down Expand Up @@ -153,6 +160,12 @@ def src_backward(self):
:return: The adjoint mapping applied to the images, averaged over the whole dataset and expressed
as coefficients of `basis`.
"""
# Handle symmetry boosting.
symmetry_group = None
sym_order = 1
if self.boost:
symmetry_group = self.src.symmetry_group
sym_order = len(symmetry_group.matrices)

# src_vols_wt_backward
vol_rhs = Volume(
Expand All @@ -163,12 +176,15 @@ def src_backward(self):
for k in range(self.r):
im = self.src.images[i : i + self.batch_size]

batch_vol_rhs = (
self.src.im_backward(im, i, self.weights[:, k]) / self.src.n
)
batch_vol_rhs = self.src.im_backward(
im,
i,
self.weights[:, k],
symmetry_group=symmetry_group,
) / (self.src.n * sym_order)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why apply 1/sym_order here instead of in backproject?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No reason other than it made sense to me at the time. This is where we are scaling by number of images and boosting increases the number of images. It also pairs nicely with scaling on line 187.

vol_rhs[k] += batch_vol_rhs.astype(self.dtype)

res = np.sqrt(self.src.n) * self.basis.evaluate_t(vol_rhs)
res = np.sqrt(self.src.n * sym_order) * self.basis.evaluate_t(vol_rhs)
logger.info(f"Determined weighted adjoint mappings. Shape = {res.shape}")

return res
Expand Down Expand Up @@ -281,7 +297,10 @@ class MeanEstimator(WeightedVolumesEstimator):
"""

def __init__(self, src, basis, **kwargs):
weights = np.ones((src.n, 1)) / np.sqrt(src.n)
# Note, Handle boosting by adjusting weights based on symmetric order.
weights = np.ones((src.n, 1)) / np.sqrt(
src.n * len(src.symmetry_group.matrices)
)
super().__init__(weights, src, basis, **kwargs)

def __getattr__(self, name):
Expand Down
10 changes: 7 additions & 3 deletions src/aspire/source/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,14 +940,16 @@ def normalize_background(self, bg_radius=1.0, do_ramp=True):
LambdaXform(normalize_bg, bg_radius=bg_radius, do_ramp=do_ramp)
)

def im_backward(self, im, start, weights=None):
def im_backward(self, im, start, weights=None, symmetry_group=None):
"""
Apply adjoint mapping to set of images

:param im: An Image instance to which we wish to apply the adjoint of the forward model.
:param start: Start index of image to consider
:param weights: Optional vector of weights to apply to images.
Weights should be length `self.n`.
Weights should be length `self.n`.
:param symmetry_group: A SymmetryGroup instance. If supplied, uses symmetry to increase
number of images used in back-projectioon.
:return: An L-by-L-by-L volume containing the sum of the adjoint mappings applied to the start+num-1 images.
"""
num = im.n_images
Expand All @@ -960,7 +962,9 @@ def im_backward(self, im, start, weights=None):
if weights is not None:
im *= weights[all_idx, np.newaxis, np.newaxis]

vol = im.backproject(self.rotations[start : start + num, :, :])[0]
vol = im.backproject(
self.rotations[start : start + num, :, :], symmetry_group=symmetry_group
)[0]

return vol

Expand Down
27 changes: 26 additions & 1 deletion tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from scipy.datasets import face

from aspire.image import Image
from aspire.utils import powerset, utest_tolerance
from aspire.utils import Rotation, powerset, utest_tolerance
from aspire.volume import CnSymmetryGroup

from .test_utils import matplotlib_dry_run

Expand Down Expand Up @@ -296,6 +297,30 @@ def testShow():
im.show()


def test_backproject_symmetry_group():
"""
Test backproject SymmetryGroup pass through and error message.
"""
ary = np.random.random((5, 8, 8))
im = Image(ary)
rots = Rotation.generate_random_rotations(5).matrices

# Attempt backproject with bad symmetry group.
not_a_symmetry_group = []
with raises(TypeError, match=r"`symmetry_group` must be a `SymmetryGroup`"):
_ = im.backproject(rots, symmetry_group=not_a_symmetry_group)

# Symmetry from string.
vol = im.backproject(rots, symmetry_group="C3")
assert isinstance(vol.symmetry_group, CnSymmetryGroup)

# Symmetry from instance.
vol = im.backproject(
rots, symmetry_group=CnSymmetryGroup(order=3, dtype=np.float32)
)
assert isinstance(vol.symmetry_group, CnSymmetryGroup)


def test_asnumpy_readonly():
"""
Attempting assignment should raise an error.
Expand Down
Loading