diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 8c21eff633..a7a044b60e 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -4,7 +4,7 @@ import numpy as np import scipy.sparse as sparse -from aspire.basis import PolarBasis2D +from aspire.operators import PolarFT from aspire.utils import common_line_from_rots from aspire.utils.random import choice @@ -69,27 +69,16 @@ def _build(self): imgs = self.src.images[:] - # Obtain coefficients in polar Fourier basis for input 2D images - self.basis = PolarBasis2D( + # Obtain coefficients of polar Fourier transform for input 2D images + self.pft = PolarFT( (self.n_res, self.n_res), self.n_rad, self.n_theta, dtype=self.dtype ) - self.pf = self.basis.evaluate_t(imgs) - self.pf = self.pf.reshape(self.n_img, self.n_theta, self.n_rad) + self.pf = self.pft.transform(imgs) - n_theta_half = self.n_theta // 2 - - # The last two dimension of pf is of size n_theta x n_rad. We will convert pf - # into an array of size (n_theta/2) x (n_rad-1), that is, take half of each ray - # through the origin except the DC part, and also take the angles only up to PI. - # This is due to the fact that the original images are real, and thus each ray - # is conjugate symmetric. We therefore gain nothing by taking longer correlations - # (of length 2*n_rad-1 instead of n_rad-1). In the Matlab version, pf is converted to - # the size of (n_theta/2) x (2*n_rad-1) but most of the calculations of build_clmatrix - # and estimate_shifts below only use the size of (n_theta/2) x (n_rad-1). In the - # Python version we will use the size of (n_theta/2) x (n_rad-1) directly and make - # sure every part is using it. By taking shorter correlations we can speed the - # computation by a factor of two. - self.pf = np.flip(self.pf[:, n_theta_half:, 1:], 2) + # We remove the DC the component. pf has size (n_img) x (n_theta/2) x (n_rad-1), + # with pf[:, :, 0] containing low frequency content and pf[:, :, -1] containing + # high frequency content. + self.pf = self.pf[:, :, 1:] def estimate_rotations(self): """ @@ -241,7 +230,7 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000): show = True # Negative sign comes from using -i conversion of Fourier transformation est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0] - est_shifts = est_shifts.reshape((2, self.n_img), order="F") + est_shifts = est_shifts.reshape((self.n_img, 2)) return est_shifts @@ -273,6 +262,10 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): n_theta_half = self.n_theta // 2 n_img = self.n_img + + # `estimate_shifts()` requires that rotations have already been estimated. + if self.rotations is None: + self.estimate_rotations() rotations = self.rotations pf = self.pf.copy() @@ -444,14 +437,15 @@ def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): # Number of shifts to try n_shifts = int(np.ceil(2 * max_shift / shift_step + 1)) - # only half of ray - rk = np.arange(-r_max, 0) + # only half of ray, excluding the DC component. + rk = np.arange(1, r_max + 1) # Generate all shift phases shifts = -max_shift + shift_step * np.arange(n_shifts) shift_phases = np.exp(np.outer(shifts, -2 * np.pi * 1j * rk / (2 * r_max + 1))) # Set filter for common-line detection h = np.sqrt(np.abs(rk)) * np.exp(-np.square(rk) / (2 * (r_max / 4) ** 2)) + return shifts, shift_phases, h def _generate_index_pairs(self, n_equations): @@ -512,7 +506,10 @@ def _apply_filter_and_norm(self, subscripts, pf, r_max, h): # Note if we'd rather not have the dtype and casting args, # we can control h.dtype instead. np.einsum(subscripts, pf, h, out=pf, dtype=pf.dtype, casting="same_kind") - pf[..., r_max - 1 : r_max + 2] = 0 + + # This is a high pass filter, cutting out the lowest frequency + # (DC has already been removed). + pf[..., 0] = 0 pf /= np.linalg.norm(pf, axis=-1)[..., np.newaxis] return pf diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index 59f2599b07..7ddfbde7cd 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -4,6 +4,7 @@ from numpy.linalg import eigh, norm, svd from aspire.abinitio import CLOrient3D, SyncVotingMixin +from aspire.operators import PolarFT from aspire.utils import ( J_conjugate, Rotation, @@ -294,7 +295,7 @@ def _estimate_inplane_rotations(self, vis): # Reconstruct the full polar Fourier for use in correlation. self.pf only consists of # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). - pf = np.concatenate((pf, np.conj(pf)), axis=1) + pf = PolarFT.half_to_full(pf) # Normalize rays. pf /= norm(pf, axis=-1)[..., np.newaxis] @@ -444,7 +445,7 @@ def _self_clmatrix_c3_c4(self): # Reconstruct the full polar Fourier for use in correlation. self.pf only consists of # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). - pf_full = np.concatenate((pf, np.conj(pf)), axis=1) + pf_full = PolarFT.half_to_full(pf) # The self-common-lines matrix holds two indices per image that represent # the two self common-lines in the image. diff --git a/src/aspire/abinitio/commonline_cn.py b/src/aspire/abinitio/commonline_cn.py index abdfd1e62d..d267a367c1 100644 --- a/src/aspire/abinitio/commonline_cn.py +++ b/src/aspire/abinitio/commonline_cn.py @@ -4,6 +4,7 @@ from numpy.linalg import norm from aspire.abinitio import CLSymmetryC3C4 +from aspire.operators import PolarFT from aspire.utils import ( J_conjugate, Rotation, @@ -123,7 +124,7 @@ def _estimate_relative_viewing_directions(self): # Reconstruct full polar Fourier for use in correlation. pf /= norm(pf, axis=2)[..., np.newaxis] # Normalize each ray. - pf_full = np.concatenate((pf, np.conj(pf)), axis=1) + pf_full = PolarFT.half_to_full(pf) # Pre-compute shifted pf's. pf_shifted = (pf * shift_phases[:, None, None]).swapaxes(0, 1) diff --git a/src/aspire/basis/__init__.py b/src/aspire/basis/__init__.py index a482127338..7d57b57295 100644 --- a/src/aspire/basis/__init__.py +++ b/src/aspire/basis/__init__.py @@ -15,6 +15,5 @@ from .fpswf_2d import FPSWFBasis2D from .fpswf_3d import FPSWFBasis3D from .fspca import FSPCABasis -from .polar_2d import PolarBasis2D from .pswf_2d import PSWFBasis2D from .pswf_3d import PSWFBasis3D diff --git a/src/aspire/operators/__init__.py b/src/aspire/operators/__init__.py index fc6e2e09bc..97e47a4d1d 100644 --- a/src/aspire/operators/__init__.py +++ b/src/aspire/operators/__init__.py @@ -18,4 +18,5 @@ ZeroFilter, evaluate_src_filters_on_grid, ) +from .polar_ft import PolarFT from .wemd import wemd_embed, wemd_norm diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/operators/polar_ft.py similarity index 51% rename from src/aspire/basis/polar_2d.py rename to src/aspire/operators/polar_ft.py index 5f29689030..48b08c4c0c 100644 --- a/src/aspire/basis/polar_2d.py +++ b/src/aspire/operators/polar_ft.py @@ -2,27 +2,29 @@ import numpy as np -from aspire.basis import Basis -from aspire.nufft import anufft, nufft +from aspire.image import Image +from aspire.nufft import nufft from aspire.utils import complex_type logger = logging.getLogger(__name__) -class PolarBasis2D(Basis): +class PolarFT: """ Define a derived class for polar Fourier representation for 2D images """ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): """ - Initialize an object for the 2D polar Fourier grid class + Initialize an object for the polar Fourier transform class. `PolarFT` expects that + images are real and uses only half of the `ntheta` values. - :param size: The shape of the vectors for which to define the grid. + :param size: The shape of the vectors for which to define the transform. May be a 2-tuple or an integer, in which case a square basis is assumed. Currently only square images are supported. :param nrad: The number of points in the radial dimension. Default is resolution // 2. :param ntheta: The number of points in the angular dimension. Default is 8 * nrad. + :param dtype: dtype used to compute a polar frequency grid for evaluating the transform.. """ if isinstance(size, int): size = (size, size) @@ -30,10 +32,13 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): assert ndim == 2, "Only two-dimensional grids are supported." assert len(set(size)) == 1, "Only square domains are supported." + self.ndim = ndim + self.sz = size self.nrad = nrad self.ntheta = ntheta + self.dtype = dtype - super().__init__(size, dtype=dtype) + self._build() # this basis has complex coefficients self.coefficient_dtype = complex_type(self.dtype) @@ -45,7 +50,7 @@ def _build(self): logger.info("Represent 2D image in a polar Fourier grid") if self.nrad is None: - self.nrad = self.nres // 2 + self.nrad = self.sz[0] // 2 if self.ntheta is None: # try to use the same number as Fast FB basis @@ -56,7 +61,7 @@ def _build(self): logger.error(msg) raise NotImplementedError(msg) - self.count = self.nrad * self.ntheta + self.count = self.nrad * (self.ntheta // 2) self._sz_prod = self.sz[0] * self.sz[1] # precompute the basis functions in 2D grids @@ -64,49 +69,22 @@ def _build(self): def _precomp(self): """ - Precomute the polar Fourier grid + Precompute the polar Fourier grid. """ omega0 = 2 * np.pi / (2 * self.nrad - 1) dtheta = 2 * np.pi / self.ntheta # only need half size of ntheta - freqs = np.zeros((2, self.nrad * self.ntheta // 2), dtype=self.dtype) + freqs = np.zeros((2, self.ntheta // 2, self.nrad), dtype=self.dtype) for i in range(self.ntheta // 2): - freqs[0, i * self.nrad : (i + 1) * self.nrad] = np.arange( - self.nrad - ) * np.cos(i * dtheta) - freqs[1, i * self.nrad : (i + 1) * self.nrad] = np.arange( - self.nrad - ) * np.sin(i * dtheta) + freqs[0, i] = np.cos(i * dtheta) + freqs[1, i] = np.sin(i * dtheta) - freqs *= omega0 - return freqs + freqs *= omega0 * np.arange(self.nrad) - def _evaluate(self, v): - """ - Evaluate coefficients in standard 2D coordinate basis from those in polar Fourier basis + return freqs.reshape(2, -1) - :param v: A coefficient vector (or an array of coefficient vectors) - in polar Fourier basis to be evaluated. The last dimension must equal to - `self.count`. - :return x: Image instance in standard 2D coordinate basis with - resolution of `self.sz`. - """ - v = v.reshape(-1, self.ntheta, self.nrad) - - nimgs = v.shape[0] - - half_size = self.ntheta // 2 - - v = v[:, :half_size, :] + v[:, half_size:, :].conj() - - v = v.reshape(nimgs, self.nrad * half_size) - - x = anufft(v, self.freqs, self.sz, real=True) - - return x - - def _evaluate_t(self, x): + def transform(self, x): """ Evaluate coefficient in polar Fourier grid from those in standard 2D coordinate basis @@ -116,15 +94,46 @@ def _evaluate_t(self, x): Fourier grid. This is an array of vectors whose first dimension corresponds to `x.shape[0]`, and last dimension equals `self.count`. """ - nimgs = x.shape[0] - - half_size = self.ntheta // 2 - - pf = nufft(x, self.freqs) + if x.dtype != self.dtype: + raise TypeError( + f"{self.__class__.__name__}.transform" + f" Inconsistent dtypes x: {x.dtype} self: {self.dtype}" + ) + + if not isinstance(x, Image): + raise TypeError( + f"{self.__class__.__name__}.transform" + f" passed numpy array instead of {Image}." + ) + else: + x = x.asnumpy() + + # Flatten stack + stack_shape = x.shape[: -self.ndim] + x = x.reshape(-1, *x.shape[-self.ndim :]) + + # We expect the Image `x` to be real in order to take advantage of the conjugate + # symmetry of the Fourier transform of a real valued image. + if not np.isreal(x).all(): + raise TypeError( + f"The Image `x` must be real valued. Found dtype {x.dtype}." + ) + + resolution = x.shape[-1] + + pf = nufft(x, self.freqs) / resolution**2 + + return pf.reshape(*stack_shape, self.ntheta // 2, self.nrad) + + @staticmethod + def half_to_full(pf): + """ + Use the conjugate symmetry of pf to construct the full polar Fourier transform + over all rays in [0, 360). - pf = pf.reshape((nimgs, self.nrad, half_size)) - v = np.concatenate((pf, pf.conj()), axis=1) + :param pf: The precomputed half polar Fourier transform + with shape (*stack_shape, ntheta//2, nrad) + :return: The full polar Fourier transform with shape (*stack_shape, ntheta, nrad) + """ - # return v coefficients with the last dimension size of self.count - v = v.reshape(nimgs, -1) - return v + return np.concatenate((pf, np.conj(pf)), axis=-2) diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index 8a074bd007..6aa71fb8f5 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -9,6 +9,7 @@ grid_2d, grid_3d, register_rotations, + rots_to_clmatrix, uniform_random_angles, ) diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index b9259545fd..a29e041966 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -308,6 +308,40 @@ def common_line_from_rots(r1, r2, ell): return ell_ij, ell_ji +def rots_to_clmatrix(rots, n_theta): + """ + Compute the common lines matrix induced by all pairs of rotation + matrices, `rots`, provided. + + :param rots: n_rotsx3x3 array of rotation matrices. + :param n_theta: Number of theta values for common lines indices. + + :return: n_rots x n_rots common lines matrix. + """ + n_rots = len(rots) + cl_matrix = -np.ones((n_rots, n_rots)) + for i in range(n_rots): + Ri = rots[i] + Ri3 = Ri[:, 2] + for j in range(i + 1, n_rots): + Rj = rots[j] + Rj3 = Rj[:, 2] + common_axis = np.cross(Ri3, Rj3) / np.linalg.norm(np.cross(Ri3, Rj3)) + xij = Ri.T @ common_axis + xji = Rj.T @ common_axis + theta_ij = np.rad2deg(np.arctan2(xij[1], xij[0])) % 360 + theta_ji = np.rad2deg(np.arctan2(xji[1], xji[0])) % 360 + + if theta_ij > 180: + theta_ij -= 180 + theta_ji -= 180 + + cl_matrix[i, j] = round(theta_ij * n_theta / 360) + cl_matrix[j, i] = round((theta_ji % 360) * n_theta / 360) + + return cl_matrix + + def crop_pad_2d(im, size, fill_value=0): """ :param im: A 2-dimensional numpy array diff --git a/src/aspire/utils/rotation.py b/src/aspire/utils/rotation.py index 4dcd05cec2..08bec4ca3d 100644 --- a/src/aspire/utils/rotation.py +++ b/src/aspire/utils/rotation.py @@ -369,25 +369,65 @@ def generate_random_rotations( @staticmethod def angle_dist(r1, r2, dtype=None): """ - Find the angular distance between two rotation matrices. We first compute - the rotation between the two rotation matrices, r = r1 @ r2.T. Then using - the axis-angle representation of r we find the angle between r1 and r2. - - :param r1: A 3x3 rotation matrix - :param r2: A 3x3 rotation matrix + Find the angular distance between two sets of rotation matrices. We first compute + the rotation between the two sets of rotation matrices, r = r1 @ r2.T. Then using + the axis-angle representation of r we find the angle between r1 and r2. Computations + will broadcast in the case of a singleton rotation and a set of rotations. + + :param r1: An nx3x3 array of rotation matrices, a singleton rotation, + or a `Rotation` object. + :param r2: An nx3x3 array of rotation matrices, a singleton rotation, + or a `Rotation` object. :param dtype: Computation datatype. Default `None` infers from `r1`. - :return: The angular distance between r1 and r2 in radians. + :return: The element-wise angular distance between r1 and r2 in radians. """ + if isinstance(r1, Rotation): + r1 = r1.matrices + if isinstance(r2, Rotation): + r2 = r2.matrices + + # Handle singletons. + r1 = r1.reshape(-1, 3, 3) + r2 = r2.reshape(-1, 3, 3) + + # Check that shapes are compatible. + if (r1.shape[0] != 1) and (r2.shape[0] != 1): + if r1.shape[0] != r2.shape[0]: + raise ValueError(r"r1 and r2 are not broadcastable.*") dtype = np.dtype(dtype or r1.dtype) - r = r1 @ r2.T - tr_r = np.trace(r, dtype=dtype) - if abs(tr_r - 3.0) <= np.finfo(dtype).resolution: - dist = 0 - else: - theta = (tr_r - 1) / 2 - theta = max(min(theta, 1), -1) # Clamp theta in [-1,1] - dist = np.arccos(theta, dtype=dtype) + r2_T = np.transpose(r2, axes=(0, 2, 1)) + r = r1 @ r2_T + tr_r = np.trace(r, axis1=1, axis2=2, dtype=dtype) + + dist = np.zeros(max(len(r1), len(r2)), dtype=dtype) + non_zero_dist_ind = np.where(abs(tr_r - 3.0) > np.finfo(tr_r.dtype).resolution) + + theta = (tr_r[non_zero_dist_ind] - 1) / 2 + theta = np.maximum(np.minimum(theta, 1), -1) # Clamp theta in [-1,1] + dist[non_zero_dist_ind] = np.arccos(theta, dtype=dtype) return dist + + @staticmethod + def mean_angular_distance(rots_1, rots_2, dtype=None): + """ + Find the mean angular distance between two sets of rotation matrices. Computations + will broadcast in the case of a singleton rotations and a set of rotations. + + :param rots_1: An nx3x3 array of rotation matrices, a singleton rotation, + or a `Rotation` object. + :param rots_2: An nx3x3 array of rotation matrices, a singleton rotation, + or a `Rotation` object. + :param dtype: Data type for computation. Default infers dtype from `rots_1`. + + :return: The mean angular distance between rotations in radians. + """ + + dtype = np.dtype(dtype or rots_1.dtype) + + # Shapes will be checked by `angle_dist`. + ang_dist = Rotation.angle_dist(rots_1, rots_2, dtype=dtype) + + return np.mean(ang_dist) diff --git a/tests/test_PolarBasis2D.py b/tests/test_PolarBasis2D.py deleted file mode 100644 index 78c10c8596..0000000000 --- a/tests/test_PolarBasis2D.py +++ /dev/null @@ -1,538 +0,0 @@ -import logging -from unittest import TestCase - -import numpy as np -from pytest import raises - -from aspire.basis import PolarBasis2D -from aspire.image import Image -from aspire.utils import complex_type, utest_tolerance -from aspire.utils.matlab_compat import m_reshape -from aspire.utils.random import randn - -from ._basis_util import UniversalBasisMixin - -logger = logging.getLogger(__name__) - - -class PolarBasis2DTestCase(TestCase, UniversalBasisMixin): - def setUp(self): - self.dtype = np.float32 - self.L = 8 - self.basis = PolarBasis2D((self.L, self.L), 4, 32, dtype=self.dtype) - # Note, in practice we got a degenerate random array around 1% - # of the time, so we fix a seed for the randn calls. - self.seed = 8675309 - - def tearDown(self): - pass - - def testPolarBasis2DThetaError(self): - """ - Test that PolarBasis2D when instantiated with odd value for `ntheta` - gives appropriate error. - """ - - # Test we raise with expected error. - with raises(NotImplementedError, match=r"Only even values for ntheta*"): - _ = PolarBasis2D(size=self.L, ntheta=143, dtype=self.dtype) - - def testPolarBasis2DEvaluate_t(self): - x = Image( - np.array( - [ - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - -1.08106869e-17, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - ], - [ - 0.00000000e00, - 0.00000000e00, - -6.40456062e-03, - -3.32961020e-03, - -1.36887927e-02, - -5.42770488e-03, - 7.63680861e-03, - 0.00000000e00, - ], - [ - 0.00000000e00, - 3.16377602e-03, - -9.31273350e-03, - 9.46128404e-03, - 1.93239220e-02, - 3.79891953e-02, - 1.06841173e-02, - -2.36467925e-03, - ], - [ - 0.00000000e00, - 1.72736955e-03, - -1.00710814e-02, - 4.93520304e-02, - 3.77702656e-02, - 6.57365438e-02, - 3.94739462e-03, - -4.41228496e-03, - ], - [ - 4.01551066e-18, - -3.08071647e-03, - -1.61670565e-02, - 8.66886286e-02, - 5.09898409e-02, - 7.19313349e-02, - 1.68313715e-02, - 5.19180892e-03, - ], - [ - 0.00000000e00, - 2.87262215e-03, - -3.37732956e-02, - 4.51706505e-02, - 5.72215879e-02, - 4.63553081e-02, - 1.86552175e-03, - 1.12608805e-02, - ], - [ - 0.00000000e00, - 2.77905016e-03, - -2.77499404e-02, - -4.02645374e-02, - -1.54969139e-02, - -1.66229153e-02, - -2.07389259e-02, - 6.64060546e-03, - ], - [ - 0.00000000e00, - 0.00000000e00, - 5.20080934e-03, - -1.06788196e-02, - -1.14761672e-02, - -1.27443126e-02, - -1.15563484e-02, - 0.00000000e00, - ], - ], - dtype=self.dtype, - ).T - ) # RCOPT - - pf = self.basis.evaluate_t(x) - result = np.array( - [ - 0.38243133 + 6.66608316e-18j, - 0.3249317 - 1.47839074e-01j, - 0.14819172 + 3.78171168e-03j, - -0.22808599 + 5.29338933e-02j, - 0.38243133 + 6.66608316e-18j, - 0.34595014 - 1.06355385e-01j, - 0.15519289 - 4.75602164e-02j, - -0.22401193 + 4.33128746e-03j, - 0.38243133 + 6.66608316e-18j, - 0.36957165 - 5.69575709e-02j, - 0.17389327 - 5.53498385e-02j, - -0.11601473 - 1.35405676e-02j, - 0.38243133 + 6.66608316e-18j, - 0.39045046 - 2.17911945e-03j, - 0.18146449 - 1.37089189e-02j, - -0.02110144 + 6.65071497e-03j, - 0.38243133 + 6.66608316e-18j, - 0.4063995 + 5.21354967e-02j, - 0.15674204 + 3.85815662e-02j, - -0.02886296 + 3.91489615e-02j, - 0.38243133 + 6.66608316e-18j, - 0.41872477 + 9.98946906e-02j, - 0.11862477 + 5.15231952e-02j, - -0.05298751 + 1.95319478e-02j, - 0.38243133 + 6.66608316e-18j, - 0.43013599 + 1.38307796e-01j, - 0.10075763 + 1.25689289e-02j, - -0.04052728 - 5.66863498e-02j, - 0.38243133 + 6.66608316e-18j, - 0.44144497 + 1.68826980e-01j, - 0.11446016 - 4.53003874e-02j, - -0.03546515 - 1.13544145e-01j, - 0.38243133 + 6.66608316e-18j, - 0.44960099 + 1.94794929e-01j, - 0.15053714 - 8.11915305e-02j, - -0.04800556 - 1.15828804e-01j, - 0.38243133 + 6.66608316e-18j, - 0.44872328 + 2.17957567e-01j, - 0.19116871 - 7.99536373e-02j, - -0.05683092 - 9.72225058e-02j, - 0.38243133 + 6.66608316e-18j, - 0.43379428 + 2.36681249e-01j, - 0.21025378 - 5.48466438e-02j, - -0.05318826 - 8.54948014e-02j, - 0.38243133 + 6.66608316e-18j, - 0.40485577 + 2.47073481e-01j, - 0.18680217 - 3.31766116e-02j, - -0.06674163 - 7.94216591e-02j, - 0.38243133 + 6.66608316e-18j, - 0.36865853 + 2.45913767e-01j, - 0.13660805 - 3.68947359e-02j, - -0.11467046 - 8.49198927e-02j, - 0.38243133 + 6.66608316e-18j, - 0.33597018 + 2.32971425e-01j, - 0.1072859 - 6.24686168e-02j, - -0.12932565 - 1.06139634e-01j, - 0.38243133 + 6.66608316e-18j, - 0.31616666 + 2.10791785e-01j, - 0.11876919 - 7.93812474e-02j, - -0.1094488 - 1.20159845e-01j, - 0.38243133 + 6.66608316e-18j, - 0.31313975 + 1.82190396e-01j, - 0.14075481 - 5.85637416e-02j, - -0.15198775 - 1.02156797e-01j, - 0.38243133 - 6.66608316e-18j, - 0.3249317 + 1.47839074e-01j, - 0.14819172 - 3.78171168e-03j, - -0.22808599 - 5.29338933e-02j, - 0.38243133 - 6.66608316e-18j, - 0.34595014 + 1.06355385e-01j, - 0.15519289 + 4.75602164e-02j, - -0.22401193 - 4.33128746e-03j, - 0.38243133 - 6.66608316e-18j, - 0.36957165 + 5.69575709e-02j, - 0.17389327 + 5.53498385e-02j, - -0.11601473 + 1.35405676e-02j, - 0.38243133 - 6.66608316e-18j, - 0.39045046 + 2.17911945e-03j, - 0.18146449 + 1.37089189e-02j, - -0.02110144 - 6.65071497e-03j, - 0.38243133 - 6.66608316e-18j, - 0.4063995 - 5.21354967e-02j, - 0.15674204 - 3.85815662e-02j, - -0.02886296 - 3.91489615e-02j, - 0.38243133 - 6.66608316e-18j, - 0.41872477 - 9.98946906e-02j, - 0.11862477 - 5.15231952e-02j, - -0.05298751 - 1.95319478e-02j, - 0.38243133 - 6.66608316e-18j, - 0.43013599 - 1.38307796e-01j, - 0.10075763 - 1.25689289e-02j, - -0.04052728 + 5.66863498e-02j, - 0.38243133 - 6.66608316e-18j, - 0.44144497 - 1.68826980e-01j, - 0.11446016 + 4.53003874e-02j, - -0.03546515 + 1.13544145e-01j, - 0.38243133 - 6.66608316e-18j, - 0.44960099 - 1.94794929e-01j, - 0.15053714 + 8.11915305e-02j, - -0.04800556 + 1.15828804e-01j, - 0.38243133 - 6.66608316e-18j, - 0.44872328 - 2.17957567e-01j, - 0.19116871 + 7.99536373e-02j, - -0.05683092 + 9.72225058e-02j, - 0.38243133 - 6.66608316e-18j, - 0.43379428 - 2.36681249e-01j, - 0.21025378 + 5.48466438e-02j, - -0.05318826 + 8.54948014e-02j, - 0.38243133 - 6.66608316e-18j, - 0.40485577 - 2.47073481e-01j, - 0.18680217 + 3.31766116e-02j, - -0.06674163 + 7.94216591e-02j, - 0.38243133 - 6.66608316e-18j, - 0.36865853 - 2.45913767e-01j, - 0.13660805 + 3.68947359e-02j, - -0.11467046 + 8.49198927e-02j, - 0.38243133 - 6.66608316e-18j, - 0.33597018 - 2.32971425e-01j, - 0.1072859 + 6.24686168e-02j, - -0.12932565 + 1.06139634e-01j, - 0.38243133 - 6.66608316e-18j, - 0.31616666 - 2.10791785e-01j, - 0.11876919 + 7.93812474e-02j, - -0.1094488 + 1.20159845e-01j, - 0.38243133 - 6.66608316e-18j, - 0.31313975 - 1.82190396e-01j, - 0.14075481 + 5.85637416e-02j, - -0.15198775 + 1.02156797e-01j, - ], - dtype=complex_type(self.dtype), - ) - self.assertTrue(np.allclose(pf, result)) - - def testPolarBasis2DEvaluate(self): - v = np.array( - [ - 0.38243133 - 6.66608316e-18j, - 0.3249317 + 1.47839074e-01j, - 0.14819172 - 3.78171168e-03j, - -0.22808599 - 5.29338933e-02j, - 0.38243133 - 6.66608316e-18j, - 0.34595014 + 1.06355385e-01j, - 0.15519289 + 4.75602164e-02j, - -0.22401193 - 4.33128746e-03j, - 0.38243133 - 6.66608316e-18j, - 0.36957165 + 5.69575709e-02j, - 0.17389327 + 5.53498385e-02j, - -0.11601473 + 1.35405676e-02j, - 0.38243133 - 6.66608316e-18j, - 0.39045046 + 2.17911945e-03j, - 0.18146449 + 1.37089189e-02j, - -0.02110144 - 6.65071497e-03j, - 0.38243133 - 6.66608316e-18j, - 0.4063995 - 5.21354967e-02j, - 0.15674204 - 3.85815662e-02j, - -0.02886296 - 3.91489615e-02j, - 0.38243133 - 6.66608316e-18j, - 0.41872477 - 9.98946906e-02j, - 0.11862477 - 5.15231952e-02j, - -0.05298751 - 1.95319478e-02j, - 0.38243133 - 6.66608316e-18j, - 0.43013599 - 1.38307796e-01j, - 0.10075763 - 1.25689289e-02j, - -0.04052728 + 5.66863498e-02j, - 0.38243133 - 6.66608316e-18j, - 0.44144497 - 1.68826980e-01j, - 0.11446016 + 4.53003874e-02j, - -0.03546515 + 1.13544145e-01j, - 0.38243133 - 6.66608316e-18j, - 0.44960099 - 1.94794929e-01j, - 0.15053714 + 8.11915305e-02j, - -0.04800556 + 1.15828804e-01j, - 0.38243133 - 6.66608316e-18j, - 0.44872328 - 2.17957567e-01j, - 0.19116871 + 7.99536373e-02j, - -0.05683092 + 9.72225058e-02j, - 0.38243133 - 6.66608316e-18j, - 0.43379428 - 2.36681249e-01j, - 0.21025378 + 5.48466438e-02j, - -0.05318826 + 8.54948014e-02j, - 0.38243133 - 6.66608316e-18j, - 0.40485577 - 2.47073481e-01j, - 0.18680217 + 3.31766116e-02j, - -0.06674163 + 7.94216591e-02j, - 0.38243133 - 6.66608316e-18j, - 0.36865853 - 2.45913767e-01j, - 0.13660805 + 3.68947359e-02j, - -0.11467046 + 8.49198927e-02j, - 0.38243133 - 6.66608316e-18j, - 0.33597018 - 2.32971425e-01j, - 0.1072859 + 6.24686168e-02j, - -0.12932565 + 1.06139634e-01j, - 0.38243133 - 6.66608316e-18j, - 0.31616666 - 2.10791785e-01j, - 0.11876919 + 7.93812474e-02j, - -0.1094488 + 1.20159845e-01j, - 0.38243133 - 6.66608316e-18j, - 0.31313975 - 1.82190396e-01j, - 0.14075481 + 5.85637416e-02j, - -0.15198775 + 1.02156797e-01j, - 0.38243133 + 6.66608316e-18j, - 0.3249317 - 1.47839074e-01j, - 0.14819172 + 3.78171168e-03j, - -0.22808599 + 5.29338933e-02j, - 0.38243133 + 6.66608316e-18j, - 0.34595014 - 1.06355385e-01j, - 0.15519289 - 4.75602164e-02j, - -0.22401193 + 4.33128746e-03j, - 0.38243133 + 6.66608316e-18j, - 0.36957165 - 5.69575709e-02j, - 0.17389327 - 5.53498385e-02j, - -0.11601473 - 1.35405676e-02j, - 0.38243133 + 6.66608316e-18j, - 0.39045046 - 2.17911945e-03j, - 0.18146449 - 1.37089189e-02j, - -0.02110144 + 6.65071497e-03j, - 0.38243133 + 6.66608316e-18j, - 0.4063995 + 5.21354967e-02j, - 0.15674204 + 3.85815662e-02j, - -0.02886296 + 3.91489615e-02j, - 0.38243133 + 6.66608316e-18j, - 0.41872477 + 9.98946906e-02j, - 0.11862477 + 5.15231952e-02j, - -0.05298751 + 1.95319478e-02j, - 0.38243133 + 6.66608316e-18j, - 0.43013599 + 1.38307796e-01j, - 0.10075763 + 1.25689289e-02j, - -0.04052728 - 5.66863498e-02j, - 0.38243133 + 6.66608316e-18j, - 0.44144497 + 1.68826980e-01j, - 0.11446016 - 4.53003874e-02j, - -0.03546515 - 1.13544145e-01j, - 0.38243133 + 6.66608316e-18j, - 0.44960099 + 1.94794929e-01j, - 0.15053714 - 8.11915305e-02j, - -0.04800556 - 1.15828804e-01j, - 0.38243133 + 6.66608316e-18j, - 0.44872328 + 2.17957567e-01j, - 0.19116871 - 7.99536373e-02j, - -0.05683092 - 9.72225058e-02j, - 0.38243133 + 6.66608316e-18j, - 0.43379428 + 2.36681249e-01j, - 0.21025378 - 5.48466438e-02j, - -0.05318826 - 8.54948014e-02j, - 0.38243133 + 6.66608316e-18j, - 0.40485577 + 2.47073481e-01j, - 0.18680217 - 3.31766116e-02j, - -0.06674163 - 7.94216591e-02j, - 0.38243133 + 6.66608316e-18j, - 0.36865853 + 2.45913767e-01j, - 0.13660805 - 3.68947359e-02j, - -0.11467046 - 8.49198927e-02j, - 0.38243133 + 6.66608316e-18j, - 0.33597018 + 2.32971425e-01j, - 0.1072859 - 6.24686168e-02j, - -0.12932565 - 1.06139634e-01j, - 0.38243133 + 6.66608316e-18j, - 0.31616666 + 2.10791785e-01j, - 0.11876919 - 7.93812474e-02j, - -0.1094488 - 1.20159845e-01j, - 0.38243133 + 6.66608316e-18j, - 0.31313975 + 1.82190396e-01j, - 0.14075481 - 5.85637416e-02j, - -0.15198775 - 1.02156797e-01j, - ], - dtype=complex_type(self.dtype), - ) - - x = self.basis.evaluate(v) - result = np.array( - [ - [ - 9.8593804, - 7.94242903, - 7.23336975, - 7.33314303, - 7.41260132, - 7.59483694, - 7.94830958, - 9.47324547, - ], - [ - 7.27801941, - 8.29797686, - 7.17234599, - 7.31082685, - 7.04347376, - 6.91956664, - 8.12234596, - 8.36258646, - ], - [ - 8.76188511, - 10.69546884, - 8.37029969, - 9.87512737, - 9.73946157, - 6.56646752, - 5.69555713, - 8.77758976, - ], - [ - 10.42069436, - 12.3649092, - 14.23951952, - 20.41736454, - 22.32664939, - 18.11535113, - 7.95059873, - 8.79515046, - ], - [ - 11.23152882, - 12.61468396, - 17.92585027, - 25.82097043, - 26.4633412, - 25.11167661, - 11.90634511, - 9.05131389, - ], - [ - 10.7048523, - 11.73534566, - 16.53838035, - 25.13242621, - 23.58037996, - 21.37129485, - 12.1024389, - 10.26313743, - ], - [ - 8.24162377, - 11.90490143, - 14.82292441, - 19.50174891, - 17.69291969, - 15.06781768, - 10.4669263, - 10.2082326, - ], - [ - 5.26532858, - 9.60999648, - 12.68642275, - 12.42354237, - 10.87648517, - 10.60647963, - 9.11026567, - 8.53250276, - ], - ], - dtype=complex_type(self.dtype), - ).T # RCOPT - - self.assertTrue(np.allclose(x.asnumpy(), result)) - - def testPolarBasis2DAdjoint(self): - # The evaluate function should be the adjoint operator of evaluate_t. - # Namely, if A = evaluate, B = evaluate_t, and B=A^t, we will have - # (y, A*x) = (A^t*y, x) = (B*y, x) - x = randn(self.basis.count, seed=self.seed).astype(self.dtype) - - x = m_reshape(x, (self.basis.nrad, self.basis.ntheta)) - - x = ( - 1 / 2 * x[:, : self.basis.ntheta // 2] - + 1 / 2 * x[:, : self.basis.ntheta // 2].conj() - ) - - x = np.concatenate((x, x.conj()), axis=1) - - x = m_reshape(x, (self.basis.nrad * self.basis.ntheta,)) - - x_t = self.basis.evaluate(x).asnumpy() - y = randn(np.prod(self.basis.sz), seed=self.seed).astype(self.dtype) - y_t = self.basis.evaluate_t( - Image(m_reshape(y, self.basis.sz)[np.newaxis, :]) - ) # RCOPT - - lhs = np.dot(y, m_reshape(x_t, (np.prod(self.basis.sz),))) - rhs = np.real(np.dot(y_t, x)) - logging.debug( - f"lhs: {lhs} rhs: {rhs} absdiff: {np.abs(lhs-rhs)} atol: {utest_tolerance(self.dtype)}" - ) - - self.assertTrue(np.isclose(lhs, rhs, atol=utest_tolerance(self.dtype))) - - # The following functions of UniversalBasisMixin expect a `basis` - # arg to be passed in. When PolarBasis2D tests are parametrized - # over size and dtype, this will be possible by passing in a basis - # automatically via @pytest.mark.parametrize() decorator on the test class - # - # See: test_FBBasis2D and test_FFBBasis2D - # - # for now, pass in the basis we are using - def testEvaluate(self): - super().testEvaluate(self.basis) - - def testEvaluate_t(self): - super().testEvaluate_t(self.basis) - - def testExpand(self): - super().testExpand(self.basis) - - def testInitWithIntSize(self): - super().testInitWithIntSize(self.basis) diff --git a/tests/test_averager2d.py b/tests/test_averager2d.py index e5aa2ed15c..c09da2c42e 100644 --- a/tests/test_averager2d.py +++ b/tests/test_averager2d.py @@ -9,7 +9,7 @@ import pytest from packaging.version import parse as parse_version -from aspire.basis import FFBBasis2D, PolarBasis2D +from aspire.basis import FFBBasis2D from aspire.classification import ( AligningAverager2D, Averager2D, @@ -18,6 +18,7 @@ BFSReddyChatterjiAverager2D, ReddyChatterjiAverager2D, ) +from aspire.operators import PolarFT from aspire.source import Simulation from aspire.utils import Rotation, num_procs_suggestion from aspire.volume import Volume @@ -209,7 +210,7 @@ def testNoRot(self): Test we raise an error when our basis does not provide `rotate` method. """ # DiracBasis does not provide `rotate`, - basis = PolarBasis2D((self.resolution, self.resolution), dtype=self.dtype) + basis = PolarFT((self.resolution, self.resolution), dtype=self.dtype) # and that should raise an error during instantiation. with pytest.raises(RuntimeError, match=r".* must provide a `rotate` method."): diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index 9f0c62f6cf..746213913c 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -46,7 +46,7 @@ def source_orientation_objs(n_img, L, order, dtype): vol = CnSymmetricVolume( L=L, C=1, - K=25, + K=50, order=order, seed=65, dtype=dtype, @@ -91,6 +91,7 @@ def source_orientation_objs(n_img, L, order, dtype): cl_kwargs["symmetry"] = f"C{order}" elif order == 2: cl_class = CLSymmetryC2 + cl_kwargs["min_dist_cls"] = 15 else: cl_class = CLSymmetryCn cl_kwargs["symmetry"] = f"C{order}" @@ -119,25 +120,10 @@ def test_estimate_rotations(n_img, L, order, dtype): # angular distance between them (in degrees). Q_mat, flag = register_rotations(rots_est, rots_gt_sync) regrot = get_aligned_rotations(rots_est, Q_mat, flag) - ang_dist = np.zeros(n_img, dtype=dtype) - for i in range(n_img): - ang_dist[i] = ( - Rotation.angle_dist( - regrot[i], - rots_gt_sync[i], - dtype=dtype, - ) - * 180 - / np.pi - ) + mean_ang_dist = Rotation.mean_angular_distance(regrot, rots_gt_sync) * 180 / np.pi # Assert mean angular distance is reasonable. - if order == 2: - assert np.mean(ang_dist) < 4 - elif order == 3 or order == 4: - assert np.mean(ang_dist) < 2 - else: - assert np.mean(ang_dist) < 5 + assert mean_ang_dist < 3 @pytest.mark.parametrize("n_img, L, order, dtype", param_list_c3_c4) @@ -173,7 +159,7 @@ def test_relative_rotations(n_img, L, order, dtype): Rotation.angle_dist(Rij_J, Rij_s_gt), ) angular_distance[idx] = np.min(dist) - mean_angular_distance = np.mean(angular_distance) + mean_angular_distance = np.mean(angular_distance) * 180 / np.pi # Assert that the mean_angular_distance is less than 5 degrees. assert mean_angular_distance < 5 @@ -206,8 +192,8 @@ def test_self_relative_rotations(n_img, L, order, dtype): cases = np.array([Rii, Rii.T, J_conjugate(Rii), J_conjugate(Rii.T)]) for i, estimate in enumerate(cases): dist[i] = Rotation.angle_dist(estimate, Rii_gt) - angular_distance[i] = dist[np.argmin(dist)] - mean_angular_distance = np.mean(angular_distance) + angular_distance[i] = min(dist) + mean_angular_distance = np.mean(angular_distance) * 180 / np.pi # Check that mean_angular_distance is less than 5 degrees. assert mean_angular_distance < 5 @@ -367,7 +353,8 @@ def test_commonlines_c2(n_img, L, order, dtype): cl_gt = (cl_gt * 360 / n_theta) % 180 pairs = all_pairs(n_img) - within_5 = 0 + angle_tol = 2 # degrees + within_tol = 0 for i, j in pairs: # For each pair of images the two sets of mutual common-lines in cl, (cl[0,i,j], cl[0,j,i]) # and (cl[1,i,j], cl[1,j,i]), should each match one of the two sets in the ground truth cl_gt. @@ -385,11 +372,11 @@ def test_commonlines_c2(n_img, L, order, dtype): + abs(cl[1, j, i] - cl_gt[0, j, i]) ) min_err = min(err_1, err_2) - if min_err <= 5: - within_5 += 1 + if min_err <= angle_tol: + within_tol += 1 - # Check that at least 90% of estimates are within 5 degrees. - assert within_5 / len(pairs) > 0.90 + # Check that at least 90% of estimates are within `angle_tol` degrees. + assert within_tol / len(pairs) > 0.90 @pytest.mark.parametrize("n_img, L, order, dtype", param_list_c3_c4) diff --git a/tests/test_orient_sync.py b/tests/test_orient_sync.py deleted file mode 100644 index 3053143287..0000000000 --- a/tests/test_orient_sync.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -import os.path -import tempfile -from unittest import TestCase - -import numpy as np -from click.testing import CliRunner -from pytest import raises - -from aspire.abinitio import CLOrient3D, CLSyncVoting -from aspire.commands.orient3d import orient3d -from aspire.operators import RadialCTFFilter -from aspire.source.simulation import Simulation -from aspire.utils import utest_tolerance -from aspire.utils.random import Random -from aspire.volume import Volume - -DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") - - -class OrientSyncTestCase(TestCase): - def setUp(self): - L = 32 - n = 64 - pixel_size = 5 - voltage = 200 - defocus_min = 1.5e4 - defocus_max = 2.5e4 - defocus_ct = 7 - Cs = 2.0 - alpha = 0.1 - self.dtype = np.float32 - - filters = [ - RadialCTFFilter(pixel_size, voltage, defocus=d, Cs=Cs, alpha=alpha) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) - ] - - vols = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype( - self.dtype - ) - ) - vols = vols.downsample(L) - - self.sim = Simulation( - L=L, n=n, vols=vols, unique_filters=filters, dtype=self.dtype - ) - - self.orient_est = CLSyncVoting(self.sim, L // 2, 36) - - def tearDown(self): - pass - - def testBuildCLmatrix(self): - self.orient_est.build_clmatrix() - results = np.load(os.path.join(DATA_DIR, "orient_est_clmatrix.npy")) - self.assertTrue(np.allclose(results, self.orient_est.clmatrix)) - - def testSyncMatrixVote(self): - self.orient_est.syncmatrix_vote() - results = np.load(os.path.join(DATA_DIR, "orient_est_smatrix.npy")) - self.assertTrue( - np.allclose( - results, - self.orient_est.syncmatrix, - atol=1e-5 if self.dtype == np.float32 else 1e-8, - ) - ) - - def testEstRotations(self): - self.orient_est.estimate_rotations() - results = np.load(os.path.join(DATA_DIR, "orient_est_rots.npy")) - # Check the dtype passthrough is preserved - self.assertTrue(self.orient_est.rotations.dtype == self.dtype) - # Check the values match reference rotation - self.assertTrue( - np.allclose( - results, - self.orient_est.rotations, - atol=1e-5 if self.dtype == np.float32 else 1e-8, - ) - ) - - def testEstShifts(self): - # need to rerun explicitly the estimation of rotations - self.orient_est.estimate_rotations() - with Random(0): - self.est_shifts = self.orient_est.estimate_shifts() - results = np.load(os.path.join(DATA_DIR, "orient_est_shifts.npy")) - self.assertTrue( - np.allclose(results, self.est_shifts, atol=utest_tolerance(self.dtype)) - ) - - def testThetaError(self): - """ - Test that CLSyncVoting when instantiated with odd value for `n_theta` - gives appropriate error. - """ - - # Test we raise with expected error. - with raises(NotImplementedError, match=r"n_theta must be even*"): - _ = CLSyncVoting(self.sim, 16, 35) - - def testNCheckError(self): - """Test we get expected error when n_check is out of range.""" - with raises(NotImplementedError, match=r"n_check must be in*"): - _ = CLOrient3D(self.sim, n_check=-2) - with raises(NotImplementedError, match=r"n_check must be in*"): - _ = CLOrient3D(self.sim, n_check=self.sim.n + 1) - - def testCommandLine(self): - # Ensure that the command line tool works as expected - runner = CliRunner() - with tempfile.TemporaryDirectory() as tmpdir: - # Save the simulation object into STAR and MRCS files - starfile_out = os.path.join(tmpdir, "save_test.star") - starfile_in = os.path.join(DATA_DIR, "sample_particles_relion31.star") - result = runner.invoke( - orient3d, - [ - f"--starfile_in={starfile_in}", - "--n_rad=10", - "--n_theta=60", - "--max_shift=0.15", - "--shift_step=1", - f"--starfile_out={starfile_out}", - ], - ) - # check that the command completed successfully - self.assertTrue(result.exit_code == 0) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py new file mode 100644 index 0000000000..8b427984ef --- /dev/null +++ b/tests/test_orient_sync_voting.py @@ -0,0 +1,163 @@ +import os +import os.path +import tempfile + +import numpy as np +import pytest +from click.testing import CliRunner + +from aspire.abinitio import CLOrient3D, CLSyncVoting +from aspire.commands.orient3d import orient3d +from aspire.source import Simulation +from aspire.utils import ( + Rotation, + get_aligned_rotations, + register_rotations, + rots_to_clmatrix, +) +from aspire.volume import AsymmetricVolume + +DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") + +RESOLUTION = [ + 40, + 41, +] + +# `None` defaults to random offsets. +OFFSETS = [ + 0, + pytest.param(None, marks=pytest.mark.expensive), +] + +DTYPES = [ + np.float32, + pytest.param(np.float64, marks=pytest.mark.expensive), +] + + +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}") +def resolution(request): + return request.param + + +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}") +def offsets(request): + return request.param + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") +def dtype(request): + return request.param + + +@pytest.fixture +def source_orientation_objs(resolution, offsets, dtype): + src = Simulation( + n=50, + L=resolution, + vols=AsymmetricVolume(L=resolution, C=1, K=100, seed=0).generate(), + offsets=offsets, + amplitudes=1, + seed=0, + ) + + # Search for common lines over less shifts for 0 offsets. + max_shift = 1 / resolution + shift_step = 1 + if src.offsets.all() != 0: + max_shift = 0.20 + shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. + orient_est = CLSyncVoting(src, max_shift=max_shift, shift_step=shift_step) + + return src, orient_est + + +def test_build_clmatrix(source_orientation_objs): + src, orient_est = source_orientation_objs + + # Build clmatrix estimate. + orient_est.build_clmatrix() + + gt_clmatrix = rots_to_clmatrix(src.rotations, orient_est.n_theta) + + angle_diffs = abs(orient_est.clmatrix - gt_clmatrix) * 360 / orient_est.n_theta + + # Count number of estimates within 5 degrees of ground truth. + within_5 = np.count_nonzero(angle_diffs < 5) + within_5 += np.count_nonzero(angle_diffs > 355) + + # Check that at least 98% of estimates are within 5 degrees. + tol = 0.98 + assert within_5 / angle_diffs.size > tol + + +def test_estimate_rotations(source_orientation_objs): + src, orient_est = source_orientation_objs + + orient_est.estimate_rotations() + + # Register estimates to ground truth rotations and compute the + # angular distance between them (in degrees). + Q_mat, flag = register_rotations(orient_est.rotations, src.rotations) + regrot = get_aligned_rotations(orient_est.rotations, Q_mat, flag) + mean_ang_dist = Rotation.mean_angular_distance(regrot, src.rotations) * 180 / np.pi + + # Assert that mean angular distance is less than 1 degree (5 degrees with shifts). + degree_tol = 1 + assert mean_ang_dist < degree_tol + + +def test_estimate_shifts(source_orientation_objs): + src, orient_est = source_orientation_objs + if src.offsets.all() != 0: + pytest.xfail("Currently failing under non-zero offsets.") + + est_shifts = orient_est.estimate_shifts() + + # Assert that estimated shifts are close to src.offsets + assert np.allclose(est_shifts, src.offsets) + + +def test_theta_error(): + """ + Test that CLSyncVoting when instantiated with odd value for `n_theta` + gives appropriate error. + """ + sim = Simulation() + + # Test we raise with expected error. + with pytest.raises(NotImplementedError, match=r"n_theta must be even*"): + _ = CLSyncVoting(sim, 16, 35) + + +def test_n_check_error(): + """Test we get expected error when n_check is out of range.""" + sim = Simulation() + + with pytest.raises(NotImplementedError, match=r"n_check must be in*"): + _ = CLOrient3D(sim, n_check=-2) + with pytest.raises(NotImplementedError, match=r"n_check must be in*"): + _ = CLOrient3D(sim, n_check=sim.n + 1) + + +def test_command_line(): + # Ensure that the command line tool works as expected + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + # Save the simulation object into STAR and MRCS files + starfile_out = os.path.join(tmpdir, "save_test.star") + starfile_in = os.path.join(DATA_DIR, "sample_particles_relion31.star") + result = runner.invoke( + orient3d, + [ + f"--starfile_in={starfile_in}", + "--n_rad=10", + "--n_theta=60", + "--max_shift=0.15", + "--shift_step=1", + f"--starfile_out={starfile_out}", + ], + ) + # check that the command completed successfully + assert result.exit_code == 0 diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py new file mode 100644 index 0000000000..425d5e14bb --- /dev/null +++ b/tests/test_polar_ft.py @@ -0,0 +1,222 @@ +import numpy as np +import pytest + +from aspire.image import Image +from aspire.operators import PolarFT +from aspire.utils import gaussian_2d, grid_2d +from aspire.volume import AsymmetricVolume, CnSymmetricVolume + +# ========== +# Parameters +# ========== + +IMG_SIZES = [ + 64, + 65, +] +DTYPES = [ + np.float64, + np.float32, +] + +RADIAL_MODES = [ + 2, + 3, + 4, + 5, + 8, + 9, + 16, + 17, +] + + +# ================== +# Parameter Fixtures +# ================== + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") +def dtype(request): + return request.param + + +@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}") +def img_size(request): + return request.param + + +@pytest.fixture(params=RADIAL_MODES, ids=lambda x: f"radial_mode={x}") +def radial_mode(request): + return request.param + + +# ===================== +# Image and PF Fixtures +# ===================== + + +@pytest.fixture +def gaussian(img_size, dtype): + """Radially symmetric image.""" + gauss = Image( + gaussian_2d(img_size, sigma=(img_size // 10, img_size // 10), dtype=dtype) + ) + pf = pf_transform(gauss) + + return pf + + +@pytest.fixture +def symmetric_image(img_size, dtype): + """Cyclically (C4) symmetric image.""" + symmetric_vol = CnSymmetricVolume( + img_size, C=1, order=4, K=25, seed=10, dtype=dtype + ).generate() + symmetric_image = symmetric_vol.project(np.eye(3, dtype=dtype)) + pf = pf_transform(symmetric_image) + + return pf + + +@pytest.fixture +def asymmetric_image(img_size, dtype): + """Asymetric image.""" + asymmetric_vol = AsymmetricVolume(img_size, C=1, dtype=dtype).generate() + asymmetric_image = asymmetric_vol.project(np.eye(3, dtype=dtype)) + pf = pf_transform(asymmetric_image) + + return asymmetric_image, pf + + +@pytest.fixture +def radial_mode_image(img_size, dtype, radial_mode): + g = grid_2d(img_size, dtype=dtype) + image = Image(np.sin(radial_mode * np.pi * g["r"])) + pf = pf_transform(image) + + return pf, radial_mode + + +# Helper function +def pf_transform(image): + """Take polar Fourier transform of image.""" + img_size = image.resolution + nrad = img_size // 2 + ntheta = 360 + pft = PolarFT(img_size, nrad=nrad, ntheta=ntheta, dtype=image.dtype) + pf = pft.transform(image)[0] + + return pf + + +# ============= +# Testing Suite +# ============= + + +def test_dc_component(asymmetric_image): + """Test that the DC component equals the mean of the signal.""" + image, pf = asymmetric_image + signal_mean = np.mean(image) + dc_components = abs(pf[:, 0]) + + assert np.allclose(dc_components, signal_mean) + + +def test_radially_symmetric_image(gaussian): + """Test that all polar Fourier rays are equal for a radially symmetric image.""" + pf = gaussian + + assert np.allclose(pf, pf[0]) + + +def test_cyclically_symmetric_image(symmetric_image): + """Test that a symmetric image produces repeated sets of polar Fourier rays.""" + pf = symmetric_image + + # For C4 symmetry any two sets of rays seperated by 90 degrees should be equal. + ntheta = pf.shape[0] # ntheta is the number of rays in 180 degrees. + + assert np.allclose(abs(pf[: ntheta // 2]), abs(pf[ntheta // 2 :]), atol=1e-7) + + +def test_radial_modes(radial_mode_image): + pf, mode = radial_mode_image + + # Set DC component to zero. + pf[:, 0] = 0 + + # Check that all rays are close. + assert abs(abs(pf) - abs(pf[0])).all() < 1e-4 + + # Check that correct mode is most prominent. + # Mode could be off by a pixel depending on resolution and mode. + # Since all rays are close will just test one. + mode_window = [mode - 1, mode, mode + 1] + ray = 3 + assert np.argmax(abs(pf[ray])) in mode_window + + +def test_complex_image_error(): + """Test that we raise for complex images.""" + img_size = 5 + complex_image = Image(np.ones((img_size, img_size), dtype=np.complex64)) + 2j + pft = PolarFT(size=img_size, dtype=np.complex64) + with pytest.raises(TypeError, match=r"The Image `x` must be real valued*"): + _ = pft.transform(complex_image) + + +def test_numpy_array_error(): + """Test that we raise when passed numpy array.""" + img_size = 5 + image_np = np.ones((img_size, img_size), dtype=np.float32) + pft = PolarFT(size=img_size, dtype=np.float32) + with pytest.raises(TypeError, match=r"passed numpy array*"): + _ = pft.transform(image_np) + + +def test_inconsistent_dtypes_error(): + """Test that we raise for complex images.""" + img_size = 5 + image = Image(np.ones((img_size, img_size), dtype=np.float32)) + pft = PolarFT(size=img_size, dtype=np.float64) + with pytest.raises(TypeError, match=r"Inconsistent dtypes*"): + _ = pft.transform(image) + + +def test_theta_error(): + """ + Test that `PolarFT`, when instantiated with odd value for `ntheta`, + gives appropriate error. + """ + + # Test we raise with expected error. + with pytest.raises(NotImplementedError, match=r"Only even values for ntheta*"): + _ = PolarFT(size=42, ntheta=143, dtype=np.float32) + + +@pytest.mark.parametrize("stack_shape", [(5,), (2, 3)]) +def test_half_to_full_transform(stack_shape): + """ + Test conjugate symmetry and shape of the full polar Fourier transform. + """ + img_size = 32 + image = Image( + np.random.rand(*stack_shape, img_size, img_size).astype(np.float32, copy=False) + ) + pft = PolarFT(size=img_size) + pf = pft.transform(image) + full_pf = pft.half_to_full(pf) + + # Check shape. + assert full_pf.shape == (*stack_shape, pft.ntheta, pft.nrad) + + # Check conjugate symmetry against pf. + assert np.allclose(np.conj(pf), full_pf[..., pft.ntheta // 2 :, :]) + + # Check conjugate symmetry against self. + for ray in range(pft.ntheta // 2): + np.testing.assert_allclose( + full_pf[..., ray, :], np.conj(full_pf[..., ray + pft.ntheta // 2, :]) + ) diff --git a/tests/test_rotation.py b/tests/test_rotation.py index 7d7e430097..caffb17f0c 100644 --- a/tests/test_rotation.py +++ b/tests/test_rotation.py @@ -2,6 +2,7 @@ from unittest import TestCase import numpy as np +import pytest from scipy.spatial.transform import Rotation as sp_rot from aspire.utils import Rotation, utest_tolerance @@ -108,21 +109,6 @@ def testSetterGetter(self): def testDtype(self): self.assertTrue(self.dtype == self.rot_obj.dtype) - def testAngleDist(self): - angles = np.array([i * np.pi / 360 for i in range(360)], dtype=self.dtype) - rots = Rotation.about_axis("x", angles, dtype=self.dtype) - - # Calculate the angular distance between the identity, rots[0], - # and rotations by multiples of pi/360 about the x-axis. - # These should be equal to `angles`. - angular_dist = np.zeros(360, dtype=self.dtype) - for i, rot in enumerate(rots): - angular_dist[i] = Rotation.angle_dist(rots[0], rot, self.dtype) - - self.assertTrue( - np.allclose(angles, angular_dist, atol=utest_tolerance(self.dtype)) - ) - def testFromRotvec(self): # Build random rotation vectors. axis = np.array([1, 0, 0], dtype=self.dtype) @@ -136,3 +122,30 @@ def testFromRotvec(self): self.assertTrue(isinstance(rotations, Rotation)) self.assertTrue(rotations.matrices.dtype == self.dtype) self.assertTrue(np.allclose(rotations.matrices, ref_rots.matrices)) + + +def test_angle_dist(): + dtype = np.float32 + angles = np.array([i * np.pi / 360 for i in range(360)], dtype=dtype) + rots = Rotation.about_axis("x", angles, dtype=dtype) + + # Calculate the angular distance between the identity, rots[0], + # and rotations by multiples of pi/360 about the x-axis. + # These should be equal to `angles`. + angular_dist = Rotation.angle_dist(rots[0], rots, dtype) + assert np.allclose(angles, angular_dist, atol=utest_tolerance(dtype)) + + # Test incompatible shape error. + with pytest.raises(ValueError, match=r"r1 and r2 are not broadcastable*"): + _ = Rotation.angle_dist(rots[:3], rots[:5]) + + +def test_mean_angular_distance(): + rots_z = Rotation.about_axis( + "z", [0, np.pi / 4, np.pi / 2], dtype=np.float32 + ).matrices + rots_id = Rotation.about_axis("z", [0, 0, 0], dtype=np.float32).matrices + + mean_ang_dist = Rotation.mean_angular_distance(rots_z, rots_id) + + assert np.allclose(mean_ang_dist, np.pi / 4)