diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 787725df83..81589b1472 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -12,6 +12,7 @@ from aspire.nufft import anufft from aspire.numeric import fft, xp from aspire.utils import FourierRingCorrelation, anorm, crop_pad_2d, grid_2d +from aspire.volume import SymmetryGroup logger = logging.getLogger(__name__) @@ -489,13 +490,16 @@ def size(self): # probably not needed, transition return np.size(self._data) - def backproject(self, rot_matrices): + def backproject(self, rot_matrices, symmetry_group=None): """ - Backproject images along rotation + Backproject images along rotations. If a symmetry group is provided, images + used in back-projection are duplicated (boosted) for symmetric viewing directions. + Note, it is assumed that a main axis of symmetry aligns with the z-axis. - :param im: An Image (stack) to backproject. - :param rot_matrices: An n-by-3-by-3 array of rotation matrices \ - corresponding to viewing directions. + :param rot_matrices: An n-by-3-by-3 array of rotation matrices + corresponding to viewing directions. + :param symmetry_group: A SymmetryGroup instance or string indicating symmetry, ie. "C3". + If supplied, uses symmetry to increase number of images used in back-projection. :return: Volume instance corresonding to the backprojected images. """ @@ -511,12 +515,21 @@ def backproject(self, rot_matrices): self.n_images == rot_matrices.shape[0] ), "Number of rotation matrices must match the number of images" - # TODO: rotated_grids might as well give us correctly shaped array in the first place - pts_rot = aspire.volume.rotated_grids(L, rot_matrices).astype( - self.dtype, copy=False - ) - pts_rot = pts_rot.reshape((3, -1)) + # Get symmetry rotations from SymmetryGroup. 
+ if symmetry_group is None: + symmetry_rots = np.eye(3, dtype=self.dtype)[None] + else: + if isinstance(symmetry_group, str): + symmetry_group = SymmetryGroup.from_string( + symmetry_group, dtype=self.dtype + ) + if not isinstance(symmetry_group, SymmetryGroup): + raise TypeError( + f"`symmetry_group` must be a `SymmetryGroup` instance. Found {type(symmetry_group)}." + ) + symmetry_rots = symmetry_group.matrices + # Compute Fourier transform of images. im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self._data))) / (L**2) if L % 2 == 0: im_f[:, 0, :] = 0 @@ -524,9 +537,22 @@ def backproject(self, rot_matrices): im_f = im_f.flatten() - vol = anufft(im_f, pts_rot[::-1], (L, L, L), real=True) / L + # Backproject. Apply boosting by looping over symmetry rotations. + vol = np.zeros((L, L, L), dtype=self.dtype) + for sym_rot in symmetry_rots: + rotations = sym_rot @ rot_matrices + + # TODO: rotated_grids might as well give us correctly shaped array in the first place + pts_rot = aspire.volume.rotated_grids(L, rotations).astype( + self.dtype, copy=False + ) + pts_rot = pts_rot.reshape((3, -1)) + + vol += anufft(im_f, pts_rot[::-1], (L, L, L), real=True) + + vol /= L - return aspire.volume.Volume(vol) + return aspire.volume.Volume(vol, symmetry_group=symmetry_group) def show(self, columns=5, figsize=(20, 10), colorbar=True): """ diff --git a/src/aspire/reconstruction/estimator.py b/src/aspire/reconstruction/estimator.py index 9e0e561229..b0832be174 100644 --- a/src/aspire/reconstruction/estimator.py +++ b/src/aspire/reconstruction/estimator.py @@ -18,6 +18,7 @@ def __init__( checkpoint_iterations=10, checkpoint_prefix="volume_checkpoint", maxiter=100, + boost=True, ): """ An object representing a 2*L-by-2*L-by-2*L array containing the non-centered Fourier transform of the mean @@ -44,6 +45,8 @@ def __init__( before returning. This should be used in conjunction with `checkpoint_iterations` to prevent excessive disk usage. `None` disables. 
+ :param boost: Option to use `src` symmetry to boost number of images used for mean estimation (Boolean). + Default of `True` employs symmetry boosting. """ self.src = src @@ -51,6 +54,7 @@ def __init__( self.dtype = self.src.dtype self.batch_size = batch_size self.preconditioner = preconditioner + self.boost = boost # dtype configuration if not self.dtype == self.basis.dtype: diff --git a/src/aspire/reconstruction/mean.py b/src/aspire/reconstruction/mean.py index d25d914276..760bca3e3d 100644 --- a/src/aspire/reconstruction/mean.py +++ b/src/aspire/reconstruction/mean.py @@ -90,15 +90,14 @@ def _compute_kernel(self): # Note, because we're iteratively summing it is critical we zero this array. kernel = np.zeros((self.r, self.r, _2L, _2L, _2L), dtype=self.dtype) + # Handle symmetry boosting. + sym_rots = np.eye(3, dtype=self.dtype)[None] + if self.boost: + sym_rots = self.src.symmetry_group.matrices + for i in range(0, self.src.n, self.batch_size): _range = np.arange(i, min(self.src.n, i + self.batch_size), dtype=int) - - pts_rot = rotated_grids(self.src.L, self.src.rotations[_range, :, :]) - pts_rot = pts_rot.reshape((3, -1)) - assert pts_rot.dtype == self.dtype - sq_filters_f = evaluate_src_filters_on_grid(self.src, _range) ** 2 - amplitudes_sq = (self.src.amplitudes[_range] ** 2).astype( self.dtype, copy=False ) @@ -117,11 +116,19 @@ def _compute_kernel(self): weights = np.transpose(weights, (2, 0, 1)).flatten() - batch_kernel = ( - 1 - / (self.r * self.src.L**4) - * anufft(weights, pts_rot[::-1], (_2L, _2L, _2L), real=True) - ) + # Apply boosting. 
+ batch_kernel = np.zeros((_2L, _2L, _2L), dtype=self.dtype) + for sym_rot in sym_rots: + rotations = sym_rot @ self.src.rotations[_range] + pts_rot = rotated_grids(self.src.L, rotations) + pts_rot = pts_rot.reshape((3, -1)) + + batch_kernel += ( + 1 + / (self.r * self.src.L**4) + * anufft(weights, pts_rot[::-1], (_2L, _2L, _2L), real=True) + ) + kernel[k, j] += batch_kernel # r x r symmetric @@ -153,6 +160,12 @@ def src_backward(self): :return: The adjoint mapping applied to the images, averaged over the whole dataset and expressed as coefficients of `basis`. """ + # Handle symmetry boosting. + symmetry_group = None + sym_order = 1 + if self.boost: + symmetry_group = self.src.symmetry_group + sym_order = len(symmetry_group.matrices) # src_vols_wt_backward vol_rhs = Volume( @@ -163,12 +176,15 @@ def src_backward(self): for k in range(self.r): im = self.src.images[i : i + self.batch_size] - batch_vol_rhs = ( - self.src.im_backward(im, i, self.weights[:, k]) / self.src.n - ) + batch_vol_rhs = self.src.im_backward( + im, + i, + self.weights[:, k], + symmetry_group=symmetry_group, + ) / (self.src.n * sym_order) vol_rhs[k] += batch_vol_rhs.astype(self.dtype) - res = np.sqrt(self.src.n) * self.basis.evaluate_t(vol_rhs) + res = np.sqrt(self.src.n * sym_order) * self.basis.evaluate_t(vol_rhs) logger.info(f"Determined weighted adjoint mappings. Shape = {res.shape}") return res @@ -281,7 +297,10 @@ class MeanEstimator(WeightedVolumesEstimator): """ def __init__(self, src, basis, **kwargs): - weights = np.ones((src.n, 1)) / np.sqrt(src.n) + # Note, Handle boosting by adjusting weights based on symmetric order. 
+ weights = np.ones((src.n, 1)) / np.sqrt( + src.n * len(src.symmetry_group.matrices) + ) super().__init__(weights, src, basis, **kwargs) def __getattr__(self, name): diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index aa127bfb66..dfed0825f7 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -940,14 +940,16 @@ def normalize_background(self, bg_radius=1.0, do_ramp=True): LambdaXform(normalize_bg, bg_radius=bg_radius, do_ramp=do_ramp) ) - def im_backward(self, im, start, weights=None): + def im_backward(self, im, start, weights=None, symmetry_group=None): """ Apply adjoint mapping to set of images :param im: An Image instance to which we wish to apply the adjoint of the forward model. :param start: Start index of image to consider :param weights: Optional vector of weights to apply to images. - Weights should be length `self.n`. + Weights should be length `self.n`. + :param symmetry_group: A SymmetryGroup instance. If supplied, uses symmetry to increase + number of images used in back-projectioon. :return: An L-by-L-by-L volume containing the sum of the adjoint mappings applied to the start+num-1 images. 
""" num = im.n_images @@ -960,7 +962,9 @@ def im_backward(self, im, start, weights=None): if weights is not None: im *= weights[all_idx, np.newaxis, np.newaxis] - vol = im.backproject(self.rotations[start : start + num, :, :])[0] + vol = im.backproject( + self.rotations[start : start + num, :, :], symmetry_group=symmetry_group + )[0] return vol diff --git a/tests/test_image.py b/tests/test_image.py index 75e511cb3e..ab616d04e6 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -10,7 +10,8 @@ from scipy.datasets import face from aspire.image import Image -from aspire.utils import powerset, utest_tolerance +from aspire.utils import Rotation, powerset, utest_tolerance +from aspire.volume import CnSymmetryGroup from .test_utils import matplotlib_dry_run @@ -296,6 +297,30 @@ def testShow(): im.show() +def test_backproject_symmetry_group(): + """ + Test backproject SymmetryGroup pass through and error message. + """ + ary = np.random.random((5, 8, 8)) + im = Image(ary) + rots = Rotation.generate_random_rotations(5).matrices + + # Attempt backproject with bad symmetry group. + not_a_symmetry_group = [] + with raises(TypeError, match=r"`symmetry_group` must be a `SymmetryGroup`"): + _ = im.backproject(rots, symmetry_group=not_a_symmetry_group) + + # Symmetry from string. + vol = im.backproject(rots, symmetry_group="C3") + assert isinstance(vol.symmetry_group, CnSymmetryGroup) + + # Symmetry from instance. + vol = im.backproject( + rots, symmetry_group=CnSymmetryGroup(order=3, dtype=np.float32) + ) + assert isinstance(vol.symmetry_group, CnSymmetryGroup) + + def test_asnumpy_readonly(): """ Attempting assignment should raise an error. 
diff --git a/tests/test_mean_estimator_boosting.py b/tests/test_mean_estimator_boosting.py
new file mode 100644
index 0000000000..005c0597af
--- /dev/null
+++ b/tests/test_mean_estimator_boosting.py
@@ -0,0 +1,220 @@
+import numpy as np
+import pytest
+
+from aspire.basis import FFBBasis3D
+from aspire.reconstruction import MeanEstimator, WeightedVolumesEstimator
+from aspire.source import ArrayImageSource, Simulation
+from aspire.utils import Rotation, utest_tolerance
+from aspire.volume import (
+    AsymmetricVolume,
+    CnSymmetricVolume,
+    DnSymmetricVolume,
+    OSymmetricVolume,
+    TSymmetricVolume,
+)
+
+SEED = 23
+
+RESOLUTION = [
+    32,
+    pytest.param(33, marks=pytest.mark.expensive),
+]
+
+DTYPE = [
+    np.float32,
+    pytest.param(np.float64, marks=pytest.mark.expensive),
+]
+
+# Symmetric volume parameters, (volume_type, symmetric_order).
+# An order of `None` means no `order` argument is passed to the volume type.
+VOL_PARAMS = [
+    (AsymmetricVolume, None),
+    (CnSymmetricVolume, 4),
+    (CnSymmetricVolume, 5),
+    (DnSymmetricVolume, 2),
+    pytest.param((TSymmetricVolume, None), marks=pytest.mark.expensive),
+    pytest.param((OSymmetricVolume, None), marks=pytest.mark.expensive),
+]
+
+
+def _generate_volume(volume_type, order, resolution, dtype, n_vols):
+    """
+    Build and generate synthetic volume(s) for one `VOL_PARAMS` entry.
+
+    Shared by the `volume` and `weighted_volume` fixtures so both construct
+    their volumes identically.
+
+    :param volume_type: Synthetic volume class (first item of a `VOL_PARAMS` pair).
+    :param order: Symmetric order forwarded as `order`; omitted when `None`.
+    :param resolution: Forwarded to the volume class as `L`.
+    :param dtype: Forwarded to the volume class as `dtype`.
+    :param n_vols: Forwarded to the volume class as `C`.
+    :return: The result of `generate()` on the configured volume class.
+    """
+    vol_kwargs = dict(
+        L=resolution,
+        C=n_vols,
+        seed=SEED,
+        dtype=dtype,
+    )
+    if order is not None:
+        vol_kwargs["order"] = order
+
+    return volume_type(**vol_kwargs).generate()
+
+
+# Fixtures.
+@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module")
+def resolution(request):
+    return request.param
+
+
+@pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}", scope="module")
+def dtype(request):
+    return request.param
+
+
+@pytest.fixture(
+    params=VOL_PARAMS, ids=lambda x: f"volume={x[0]}, order={x[1]}", scope="module"
+)
+def volume(request, resolution, dtype):
+    """Single synthetic volume for each symmetry in `VOL_PARAMS`."""
+    volume_type, order = request.param
+    return _generate_volume(volume_type, order, resolution, dtype, n_vols=1)
+
+
+@pytest.fixture(scope="module")
+def source(volume):
+    """Simulation of 200 images of `volume` with zero offsets and unit amplitudes."""
+    src = Simulation(
+        n=200,
+        vols=volume,
+        offsets=0,
+        amplitudes=1,
+        seed=SEED,
+        dtype=volume.dtype,
+    )
+    src = src.cache()  # precompute images
+
+    return src
+
+
+@pytest.fixture(scope="module")
+def estimated_volume(source):
+    """Mean volume estimated from `source` with default estimator settings."""
+    basis = FFBBasis3D(source.L, dtype=source.dtype)
+    estimator = MeanEstimator(source, basis)
+    estimated_volume = estimator.estimate()
+
+    return estimated_volume
+
+
+# Weighted volume fixture. Only testing C1, C4, and C5.
+@pytest.fixture(
+    params=VOL_PARAMS[:3], ids=lambda x: f"volume={x[0]}, order={x[1]}", scope="module"
+)
+def weighted_volume(request, resolution, dtype):
+    """Pair of synthetic volumes (`C=2`) for the weighted estimator test."""
+    volume_type, order = request.param
+    return _generate_volume(volume_type, order, resolution, dtype, n_vols=2)
+
+
+@pytest.fixture(scope="module")
+def weighted_source(weighted_volume):
+    """Simulation of 400 images drawn from the volumes in `weighted_volume`."""
+    src = Simulation(
+        n=400,
+        vols=weighted_volume,
+        offsets=0,
+        amplitudes=1,
+        seed=SEED,
+        dtype=weighted_volume.dtype,
+    )
+
+    return src
+
+
+# MeanEstimator Tests.
+def test_fsc(source, estimated_volume):
+    """Compare estimated volume to source volume with FSC."""
+    # Fourier Shell Correlation
+    fsc_resolution, fsc = source.vols.fsc(estimated_volume, pixel_size=1, cutoff=0.5)
+
+    # Check that resolution is less than 2.1 pixels.
+    np.testing.assert_array_less(fsc_resolution, 2.1)
+
+    # Check that second to last correlation value is high (>.90).
+    np.testing.assert_array_less(0.90, fsc[0, -2])
+
+
+def test_mse(source, estimated_volume):
+    """Check the mean-squared error between source and estimated volumes."""
+    mse = np.mean((source.vols.asnumpy() - estimated_volume.asnumpy()) ** 2)
+    np.testing.assert_allclose(mse, 0, atol=1e-3)
+
+
+def test_total_energy(source, estimated_volume):
+    """Test that the sum of voxel values is preserved in the reconstructed volume."""
+    # Reduce explicit NumPy arrays, as `test_mse` does, rather than relying on
+    # `np.sum` dispatching on the Volume objects.
+    og_total_energy = np.sum(source.vols.asnumpy())
+    recon_total_energy = np.sum(estimated_volume.asnumpy())
+
+    # `assert_allclose(actual, desired)`: the tolerance is relative to `desired`,
+    # which should be the ground truth.
+    np.testing.assert_allclose(recon_total_energy, og_total_energy, rtol=1e-3)
+
+
+def test_boost_flag(source, estimated_volume):
+    """Manually boost a source and reconstruct without boosting."""
+    ims = source.projections[:]
+    rots = source.rotations
+    sym_rots = source.symmetry_group.matrices
+    sym_order = len(sym_rots)
+
+    # Manually boosted images and rotations.
+    ims_boosted = np.tile(ims, (sym_order, 1, 1))
+    rots_boosted = np.zeros((sym_order * source.n, 3, 3), dtype=source.dtype)
+    for i, sym_rot in enumerate(sym_rots):
+        rots_boosted[i * source.n : (i + 1) * source.n] = sym_rot @ rots
+    rots_boosted = Rotation(rots_boosted)
+
+    # Manually boosted source.
+    boosted_source = ArrayImageSource(ims_boosted, angles=rots_boosted.angles)
+
+    # Estimate volume with boosting OFF.
+    basis = FFBBasis3D(boosted_source.L, dtype=boosted_source.dtype)
+    estimator = MeanEstimator(boosted_source, basis, boost=False)
+    est_vol = estimator.estimate()
+
+    # Check reconstructions are equal.
+    mse = np.mean((estimated_volume.asnumpy() - est_vol.asnumpy()) ** 2)
+    np.testing.assert_allclose(mse, 0, atol=utest_tolerance(source.dtype))
+
+
+# WeightedVolumesEstimator Tests.
+def test_weighted_volumes(weighted_source):
+    """
+    Test WeightedVolumesEstimator reconstructs multiple volumes using symmetry boosting.
+    """
+    src = weighted_source
+
+    # Use source states to assign weights to volumes.
+    weights = np.zeros((src.n, src.C), dtype=src.dtype)
+    weights[:, 0] = abs(src.states - 1.99)  # sends states [1, 2] to weights [.99, .01]
+    weights[:, 1] = 1 - weights[:, 0]  # sets weights for states [1, 2] as [.01, .99]
+
+    # Scale weights
+    n0 = np.count_nonzero(src.states == 1)  # number of images from vol[0]
+    n1 = np.count_nonzero(src.states == 2)  # number of images from vol[1]
+    weights[:, 0] = weights[:, 0] / weights[:, 0].sum() * np.sqrt(n0)
+    weights[:, 1] = weights[:, 1] / weights[:, 1].sum() * np.sqrt(n1)
+
+    # Initialize estimator.
+    basis = FFBBasis3D(src.L, dtype=src.dtype)
+    estimator = WeightedVolumesEstimator(src=src, basis=basis, weights=weights)
+    est_vols = estimator.estimate()
+
+    # Check FSC (scaling may not be close enough to match mse)
+    _, corr = src.vols.fsc(est_vols)
+    np.testing.assert_array_less(0.91, corr[:, -2])