From 4a6976ed20af11d574d323c419c83279028b03a0 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 31 Jul 2023 11:22:23 -0400 Subject: [PATCH 01/53] Use only half_theta in PolarBasis2D. --- src/aspire/basis/polar_2d.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/basis/polar_2d.py index 5f29689030..5077446b41 100644 --- a/src/aspire/basis/polar_2d.py +++ b/src/aspire/basis/polar_2d.py @@ -16,7 +16,10 @@ class PolarBasis2D(Basis): def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): """ - Initialize an object for the 2D polar Fourier grid class + Initialize an object for the 2D polar Fourier grid class. `PolarBasis2D` expects that + images are real and uses only half of the `ntheta` values. Downstream algorithms should + take advantage of the conjugate symmetry the polar Fourier coefficients if the full + set is needed. :param size: The shape of the vectors for which to define the grid. May be a 2-tuple or an integer, in which case a square basis is assumed. @@ -56,7 +59,7 @@ def _build(self): logger.error(msg) raise NotImplementedError(msg) - self.count = self.nrad * self.ntheta + self.count = self.nrad * (self.ntheta // 2) self._sz_prod = self.sz[0] * self.sz[1] # precompute the basis functions in 2D grids @@ -70,7 +73,7 @@ def _precomp(self): dtheta = 2 * np.pi / self.ntheta # only need half size of ntheta - freqs = np.zeros((2, self.nrad * self.ntheta // 2), dtype=self.dtype) + freqs = np.zeros((2, self.nrad * (self.ntheta // 2)), dtype=self.dtype) for i in range(self.ntheta // 2): freqs[0, i * self.nrad : (i + 1) * self.nrad] = np.arange( self.nrad @@ -92,13 +95,12 @@ def _evaluate(self, v): :return x: Image instance in standard 2D coordinate basis with resolution of `self.sz`. """ - v = v.reshape(-1, self.ntheta, self.nrad) + half_size = self.ntheta // 2 + v = v.reshape(-1, half_size, self.nrad) nimgs = v.shape[0] - half_size = self.ntheta // 2 - - v = v[:, :half_size, :] + v[:, half_size:, :].conj() + v = v + v.conj() v = v.reshape(nimgs, self.nrad * half_size) @@ -122,8 +124,7 @@ def _evaluate_t(self, x): pf = nufft(x, self.freqs) - pf = pf.reshape((nimgs, self.nrad, half_size)) - v = np.concatenate((pf, pf.conj()), axis=1) + pf = pf.reshape((nimgs, half_size, self.nrad)) # return v coefficients with the last dimension size of self.count v = v.reshape(nimgs, -1) From 8bef8437d5c1dd27430eb2f46914855a20e9f0a2 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 31 Jul 2023 13:52:53 -0400 Subject: [PATCH 02/53] refactor _evaluate and testPolarBasis2DAdjoint. --- src/aspire/basis/basis.py | 2 ++ src/aspire/basis/polar_2d.py | 14 +------------- tests/test_PolarBasis2D.py | 18 +++++++++--------- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/src/aspire/basis/basis.py b/src/aspire/basis/basis.py index 9a78b55b9d..c9057f96fc 100644 --- a/src/aspire/basis/basis.py +++ b/src/aspire/basis/basis.py @@ -99,8 +99,10 @@ def evaluate(self, v): stack_shape = v.shape[:-1] v = v.reshape(-1, self.count) # Compute the transform + x = self._evaluate(v) # Restore stack shape + x = x.reshape(*stack_shape, *self.sz) # Return the appropriate class diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/basis/polar_2d.py index 5077446b41..9724194491 100644 --- a/src/aspire/basis/polar_2d.py +++ b/src/aspire/basis/polar_2d.py @@ -95,14 +95,6 @@ def _evaluate(self, v): :return x: Image instance in standard 2D coordinate basis with resolution of `self.sz`. """ - half_size = self.ntheta // 2 - v = v.reshape(-1, half_size, self.nrad) - - nimgs = v.shape[0] - - v = v + v.conj() - - v = v.reshape(nimgs, self.nrad * half_size) x = anufft(v, self.freqs, self.sz, real=True) @@ -124,8 +116,4 @@ def _evaluate_t(self, x): pf = nufft(x, self.freqs) - pf = pf.reshape((nimgs, half_size, self.nrad)) - - # return v coefficients with the last dimension size of self.count - v = v.reshape(nimgs, -1) - return v + return pf.reshape(nimgs, -1) diff --git a/tests/test_PolarBasis2D.py b/tests/test_PolarBasis2D.py index 78c10c8596..eb42796708 100644 --- a/tests/test_PolarBasis2D.py +++ b/tests/test_PolarBasis2D.py @@ -492,24 +492,24 @@ def testPolarBasis2DAdjoint(self): # (y, A*x) = (A^t*y, x) = (B*y, x) x = randn(self.basis.count, seed=self.seed).astype(self.dtype) - x = m_reshape(x, (self.basis.nrad, self.basis.ntheta)) + #x = np.reshape(x, (self.basis.ntheta // 2, self.basis.nrad)) - x = ( - 1 / 2 * x[:, : self.basis.ntheta // 2] - + 1 / 2 * x[:, : self.basis.ntheta // 2].conj() - ) + # x = ( + # 1 / 2 * x + # + 1 / 2 * x.conj() + # ) - x = np.concatenate((x, x.conj()), axis=1) + # x = np.concatenate((x, x.conj()), axis=1) - x = m_reshape(x, (self.basis.nrad * self.basis.ntheta,)) + # x = m_reshape(x, (self.basis.nrad * self.basis.ntheta,)) x_t = self.basis.evaluate(x).asnumpy() y = randn(np.prod(self.basis.sz), seed=self.seed).astype(self.dtype) y_t = self.basis.evaluate_t( - Image(m_reshape(y, self.basis.sz)[np.newaxis, :]) + Image(np.reshape(y, self.basis.sz)) ) # RCOPT - lhs = np.dot(y, m_reshape(x_t, (np.prod(self.basis.sz),))) + lhs = np.dot(y, np.reshape(x_t, (np.prod(self.basis.sz),))) rhs = np.real(np.dot(y_t, x)) logging.debug( f"lhs: {lhs} rhs: {rhs} absdiff: {np.abs(lhs-rhs)} atol: {utest_tolerance(self.dtype)}" From e176624a5826aa4f4b36ec1e04620aa822cdf12f Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 31 Jul 2023 14:02:39 -0400 Subject: [PATCH 03/53] Check that x is real-valued input of _evaluate_t. --- src/aspire/basis/polar_2d.py | 7 +++++++ tests/test_PolarBasis2D.py | 15 +-------------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/basis/polar_2d.py index 9724194491..d2c9780195 100644 --- a/src/aspire/basis/polar_2d.py +++ b/src/aspire/basis/polar_2d.py @@ -110,6 +110,13 @@ def _evaluate_t(self, x): Fourier grid. This is an array of vectors whose first dimension corresponds to `x.shape[0]`, and last dimension equals `self.count`. """ + # We expect the Iamge `x` to be real in order to take advantage of the conjugate + # symmetry of Fourier transform of a real valued image. + if not np.isreal(x).all(): + raise TypeError( + f"The Image `x` must be real valued. Found dtype {x.dtype}." + ) + nimgs = x.shape[0] half_size = self.ntheta // 2 diff --git a/tests/test_PolarBasis2D.py b/tests/test_PolarBasis2D.py index eb42796708..8d1232f742 100644 --- a/tests/test_PolarBasis2D.py +++ b/tests/test_PolarBasis2D.py @@ -492,22 +492,9 @@ def testPolarBasis2DAdjoint(self): # (y, A*x) = (A^t*y, x) = (B*y, x) x = randn(self.basis.count, seed=self.seed).astype(self.dtype) - #x = np.reshape(x, (self.basis.ntheta // 2, self.basis.nrad)) - - # x = ( - # 1 / 2 * x - # + 1 / 2 * x.conj() - # ) - - # x = np.concatenate((x, x.conj()), axis=1) - - # x = m_reshape(x, (self.basis.nrad * self.basis.ntheta,)) - x_t = self.basis.evaluate(x).asnumpy() y = randn(np.prod(self.basis.sz), seed=self.seed).astype(self.dtype) - y_t = self.basis.evaluate_t( - Image(np.reshape(y, self.basis.sz)) - ) # RCOPT + y_t = self.basis.evaluate_t(Image(np.reshape(y, self.basis.sz))) lhs = np.dot(y, np.reshape(x_t, (np.prod(self.basis.sz),))) rhs = np.real(np.dot(y_t, x)) From 525f47e2bc59f6ca30f2a1273c6309d763b359e7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 2 Aug 2023 10:23:18 -0400 Subject: [PATCH 04/53] Initial refactor from PolarBasis2D ~~> PolarFT --- src/aspire/abinitio/commonline_base.py | 13 +- src/aspire/basis/__init__.py | 2 +- src/aspire/basis/polar_2d.py | 49 ++++-- tests/test_PolarBasis2D.py | 199 +++++++++---------------- tests/test_averager2d.py | 4 +- tests/test_orient_symmetric.py | 9 +- 6 files changed, 118 insertions(+), 158 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 8c21eff633..83e619a37d 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -4,7 +4,7 @@ import numpy as np import scipy.sparse as sparse -from aspire.basis import PolarBasis2D +from aspire.basis import PolarFT from aspire.utils import common_line_from_rots from aspire.utils.random import choice @@ -69,14 +69,13 @@ def _build(self): imgs = self.src.images[:] - # Obtain coefficients in polar Fourier basis for input 2D images - self.basis = PolarBasis2D( + # Obtain coefficients of polar Fourier transform for input 2D images + self.pft = PolarFT( (self.n_res, self.n_res), self.n_rad, self.n_theta, dtype=self.dtype ) - self.pf = self.basis.evaluate_t(imgs) - self.pf = self.pf.reshape(self.n_img, self.n_theta, self.n_rad) - + self.pf = self.pft.evaluate_t(imgs) n_theta_half = self.n_theta // 2 + self.pf = self.pf.reshape(self.n_img, n_theta_half, self.n_rad) # The last two dimension of pf is of size n_theta x n_rad. We will convert pf # into an array of size (n_theta/2) x (n_rad-1), that is, take half of each ray @@ -89,7 +88,7 @@ def _build(self): # Python version we will use the size of (n_theta/2) x (n_rad-1) directly and make # sure every part is using it. By taking shorter correlations we can speed the # computation by a factor of two. - self.pf = np.flip(self.pf[:, n_theta_half:, 1:], 2) + self.pf = self.pf[:, :, 1:] def estimate_rotations(self): """ diff --git a/src/aspire/basis/__init__.py b/src/aspire/basis/__init__.py index a482127338..5fc8ce9c2a 100644 --- a/src/aspire/basis/__init__.py +++ b/src/aspire/basis/__init__.py @@ -15,6 +15,6 @@ from .fpswf_2d import FPSWFBasis2D from .fpswf_3d import FPSWFBasis3D from .fspca import FSPCABasis -from .polar_2d import PolarBasis2D +from .polar_2d import PolarFT from .pswf_2d import PSWFBasis2D from .pswf_3d import PSWFBasis3D diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/basis/polar_2d.py index d2c9780195..99bc59b3d2 100644 --- a/src/aspire/basis/polar_2d.py +++ b/src/aspire/basis/polar_2d.py @@ -2,23 +2,23 @@ import numpy as np -from aspire.basis import Basis +from aspire.image import Image from aspire.nufft import anufft, nufft from aspire.utils import complex_type logger = logging.getLogger(__name__) -class PolarBasis2D(Basis): +class PolarFT: """ Define a derived class for polar Fourier representation for 2D images """ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): """ - Initialize an object for the 2D polar Fourier grid class. `PolarBasis2D` expects that + Initialize an object for the 2D polar Fourier grid class. `PolarFT` expects that images are real and uses only half of the `ntheta` values. Downstream algorithms should - take advantage of the conjugate symmetry the polar Fourier coefficients if the full + take advantage of the conjugate symmetry of the polar Fourier coefficients if the full set is needed. :param size: The shape of the vectors for which to define the grid. @@ -26,6 +26,7 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): Currently only square images are supported. :param nrad: The number of points in the radial dimension. Default is resolution // 2. :param ntheta: The number of points in the angular dimension. Default is 8 * nrad. + :param dtype: dtype of polar Fourier grid. """ if isinstance(size, int): size = (size, size) @@ -33,10 +34,15 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): assert ndim == 2, "Only two-dimensional grids are supported." assert len(set(size)) == 1, "Only square domains are supported." + self.ndim = ndim + self.sz = size self.nrad = nrad self.ntheta = ntheta + self.dtype = dtype - super().__init__(size, dtype=dtype) + self._build() + + self.count = self.nrad * (self.ntheta // 2) # this basis has complex coefficients self.coefficient_dtype = complex_type(self.dtype) @@ -48,7 +54,7 @@ def _build(self): logger.info("Represent 2D image in a polar Fourier grid") if self.nrad is None: - self.nrad = self.nres // 2 + self.nrad = self.sz[0] // 2 if self.ntheta is None: # try to use the same number as Fast FB basis @@ -96,11 +102,11 @@ def _evaluate(self, v): resolution of `self.sz`. """ - x = anufft(v, self.freqs, self.sz, real=True) + x = anufft(v, self.freqs, self.sz, real=True) / self.count return x - def _evaluate_t(self, x): + def evaluate_t(self, x): """ Evaluate coefficient in polar Fourier grid from those in standard 2D coordinate basis @@ -110,17 +116,34 @@ def _evaluate_t(self, x): Fourier grid. This is an array of vectors whose first dimension corresponds to `x.shape[0]`, and last dimension equals `self.count`. """ - # We expect the Iamge `x` to be real in order to take advantage of the conjugate - # symmetry of Fourier transform of a real valued image. + if x.dtype != self.dtype: + logger.warning( + f"{self.__class__.__name__}::evaluate_t" + f" Inconsistent dtypes x: {x.dtype} self: {self.dtype}" + ) + + if not isinstance(x, Image): + logger.warning( + f"{self.__class__.__name__}::evaluate_t" + f" passed numpy array instead of {Image}." + ) + else: + x = x.asnumpy() + + # Flatten stack + stack_shape = x.shape[: -self.ndim] + x = x.reshape(-1, *x.shape[-self.ndim :]) + + # We expect the Image `x` to be real in order to take advantage of the conjugate + # symmetry of the Fourier transform of a real valued image. if not np.isreal(x).all(): raise TypeError( f"The Image `x` must be real valued. Found dtype {x.dtype}." ) nimgs = x.shape[0] + resolution = x.shape[-1] - half_size = self.ntheta // 2 - - pf = nufft(x, self.freqs) + pf = nufft(x, self.freqs) / resolution**2 return pf.reshape(nimgs, -1) diff --git a/tests/test_PolarBasis2D.py b/tests/test_PolarBasis2D.py index 8d1232f742..e60dd9af7e 100644 --- a/tests/test_PolarBasis2D.py +++ b/tests/test_PolarBasis2D.py @@ -123,76 +123,76 @@ def testPolarBasis2DEvaluate_t(self): ], ], dtype=self.dtype, - ).T + ) ) # RCOPT pf = self.basis.evaluate_t(x) result = np.array( [ - 0.38243133 + 6.66608316e-18j, - 0.3249317 - 1.47839074e-01j, - 0.14819172 + 3.78171168e-03j, - -0.22808599 + 5.29338933e-02j, - 0.38243133 + 6.66608316e-18j, - 0.34595014 - 1.06355385e-01j, - 0.15519289 - 4.75602164e-02j, - -0.22401193 + 4.33128746e-03j, - 0.38243133 + 6.66608316e-18j, - 0.36957165 - 5.69575709e-02j, - 0.17389327 - 5.53498385e-02j, - -0.11601473 - 1.35405676e-02j, - 0.38243133 + 6.66608316e-18j, - 0.39045046 - 2.17911945e-03j, - 0.18146449 - 1.37089189e-02j, - -0.02110144 + 6.65071497e-03j, - 0.38243133 + 6.66608316e-18j, - 0.4063995 + 5.21354967e-02j, - 0.15674204 + 3.85815662e-02j, - -0.02886296 + 3.91489615e-02j, - 0.38243133 + 6.66608316e-18j, - 0.41872477 + 9.98946906e-02j, - 0.11862477 + 5.15231952e-02j, - -0.05298751 + 1.95319478e-02j, - 0.38243133 + 6.66608316e-18j, - 0.43013599 + 1.38307796e-01j, - 0.10075763 + 1.25689289e-02j, - -0.04052728 - 5.66863498e-02j, - 0.38243133 + 6.66608316e-18j, - 0.44144497 + 1.68826980e-01j, - 0.11446016 - 4.53003874e-02j, - -0.03546515 - 1.13544145e-01j, - 0.38243133 + 6.66608316e-18j, - 0.44960099 + 1.94794929e-01j, - 0.15053714 - 8.11915305e-02j, - -0.04800556 - 1.15828804e-01j, - 0.38243133 + 6.66608316e-18j, - 0.44872328 + 2.17957567e-01j, - 0.19116871 - 7.99536373e-02j, - -0.05683092 - 9.72225058e-02j, - 0.38243133 + 6.66608316e-18j, - 0.43379428 + 2.36681249e-01j, - 0.21025378 - 5.48466438e-02j, - -0.05318826 - 8.54948014e-02j, - 0.38243133 + 6.66608316e-18j, - 0.40485577 + 2.47073481e-01j, - 0.18680217 - 3.31766116e-02j, - -0.06674163 - 7.94216591e-02j, - 0.38243133 + 6.66608316e-18j, - 0.36865853 + 2.45913767e-01j, - 0.13660805 - 3.68947359e-02j, - -0.11467046 - 8.49198927e-02j, - 0.38243133 + 6.66608316e-18j, - 0.33597018 + 2.32971425e-01j, - 0.1072859 - 6.24686168e-02j, - -0.12932565 - 1.06139634e-01j, - 0.38243133 + 6.66608316e-18j, - 0.31616666 + 2.10791785e-01j, - 0.11876919 - 7.93812474e-02j, - -0.1094488 - 1.20159845e-01j, - 0.38243133 + 6.66608316e-18j, - 0.31313975 + 1.82190396e-01j, - 0.14075481 - 5.85637416e-02j, - -0.15198775 - 1.02156797e-01j, + # 0.38243133 + 6.66608316e-18j, + # 0.3249317 - 1.47839074e-01j, + # 0.14819172 + 3.78171168e-03j, + # -0.22808599 + 5.29338933e-02j, + # 0.38243133 + 6.66608316e-18j, + # 0.34595014 - 1.06355385e-01j, + # 0.15519289 - 4.75602164e-02j, + # -0.22401193 + 4.33128746e-03j, + # 0.38243133 + 6.66608316e-18j, + # 0.36957165 - 5.69575709e-02j, + # 0.17389327 - 5.53498385e-02j, + # -0.11601473 - 1.35405676e-02j, + # 0.38243133 + 6.66608316e-18j, + # 0.39045046 - 2.17911945e-03j, + # 0.18146449 - 1.37089189e-02j, + # -0.02110144 + 6.65071497e-03j, + # 0.38243133 + 6.66608316e-18j, + # 0.4063995 + 5.21354967e-02j, + # 0.15674204 + 3.85815662e-02j, + # -0.02886296 + 3.91489615e-02j, + # 0.38243133 + 6.66608316e-18j, + # 0.41872477 + 9.98946906e-02j, + # 0.11862477 + 5.15231952e-02j, + # -0.05298751 + 1.95319478e-02j, + # 0.38243133 + 6.66608316e-18j, + # 0.43013599 + 1.38307796e-01j, + # 0.10075763 + 1.25689289e-02j, + # -0.04052728 - 5.66863498e-02j, + # 0.38243133 + 6.66608316e-18j, + # 0.44144497 + 1.68826980e-01j, + # 0.11446016 - 4.53003874e-02j, + # -0.03546515 - 1.13544145e-01j, + # 0.38243133 + 6.66608316e-18j, + # 0.44960099 + 1.94794929e-01j, + # 0.15053714 - 8.11915305e-02j, + # -0.04800556 - 1.15828804e-01j, + # 0.38243133 + 6.66608316e-18j, + # 0.44872328 + 2.17957567e-01j, + # 0.19116871 - 7.99536373e-02j, + # -0.05683092 - 9.72225058e-02j, + # 0.38243133 + 6.66608316e-18j, + # 0.43379428 + 2.36681249e-01j, + # 0.21025378 - 5.48466438e-02j, + # -0.05318826 - 8.54948014e-02j, + # 0.38243133 + 6.66608316e-18j, + # 0.40485577 + 2.47073481e-01j, + # 0.18680217 - 3.31766116e-02j, + # -0.06674163 - 7.94216591e-02j, + # 0.38243133 + 6.66608316e-18j, + # 0.36865853 + 2.45913767e-01j, + # 0.13660805 - 3.68947359e-02j, + # -0.11467046 - 8.49198927e-02j, + # 0.38243133 + 6.66608316e-18j, + # 0.33597018 + 2.32971425e-01j, + # 0.1072859 - 6.24686168e-02j, + # -0.12932565 - 1.06139634e-01j, + # 0.38243133 + 6.66608316e-18j, + # 0.31616666 + 2.10791785e-01j, + # 0.11876919 - 7.93812474e-02j, + # -0.1094488 - 1.20159845e-01j, + # 0.38243133 + 6.66608316e-18j, + # 0.31313975 + 1.82190396e-01j, + # 0.14075481 - 5.85637416e-02j, + # -0.15198775 - 1.02156797e-01j, 0.38243133 - 6.66608316e-18j, 0.3249317 + 1.47839074e-01j, 0.14819172 - 3.78171168e-03j, @@ -260,6 +260,7 @@ def testPolarBasis2DEvaluate_t(self): ], dtype=complex_type(self.dtype), ) + self.assertTrue(np.allclose(pf, result)) def testPolarBasis2DEvaluate(self): @@ -329,70 +330,6 @@ def testPolarBasis2DEvaluate(self): 0.31313975 - 1.82190396e-01j, 0.14075481 + 5.85637416e-02j, -0.15198775 + 1.02156797e-01j, - 0.38243133 + 6.66608316e-18j, - 0.3249317 - 1.47839074e-01j, - 0.14819172 + 3.78171168e-03j, - -0.22808599 + 5.29338933e-02j, - 0.38243133 + 6.66608316e-18j, - 0.34595014 - 1.06355385e-01j, - 0.15519289 - 4.75602164e-02j, - -0.22401193 + 4.33128746e-03j, - 0.38243133 + 6.66608316e-18j, - 0.36957165 - 5.69575709e-02j, - 0.17389327 - 5.53498385e-02j, - -0.11601473 - 1.35405676e-02j, - 0.38243133 + 6.66608316e-18j, - 0.39045046 - 2.17911945e-03j, - 0.18146449 - 1.37089189e-02j, - -0.02110144 + 6.65071497e-03j, - 0.38243133 + 6.66608316e-18j, - 0.4063995 + 5.21354967e-02j, - 0.15674204 + 3.85815662e-02j, - -0.02886296 + 3.91489615e-02j, - 0.38243133 + 6.66608316e-18j, - 0.41872477 + 9.98946906e-02j, - 0.11862477 + 5.15231952e-02j, - -0.05298751 + 1.95319478e-02j, - 0.38243133 + 6.66608316e-18j, - 0.43013599 + 1.38307796e-01j, - 0.10075763 + 1.25689289e-02j, - -0.04052728 - 5.66863498e-02j, - 0.38243133 + 6.66608316e-18j, - 0.44144497 + 1.68826980e-01j, - 0.11446016 - 4.53003874e-02j, - -0.03546515 - 1.13544145e-01j, - 0.38243133 + 6.66608316e-18j, - 0.44960099 + 1.94794929e-01j, - 0.15053714 - 8.11915305e-02j, - -0.04800556 - 1.15828804e-01j, - 0.38243133 + 6.66608316e-18j, - 0.44872328 + 2.17957567e-01j, - 0.19116871 - 7.99536373e-02j, - -0.05683092 - 9.72225058e-02j, - 0.38243133 + 6.66608316e-18j, - 0.43379428 + 2.36681249e-01j, - 0.21025378 - 5.48466438e-02j, - -0.05318826 - 8.54948014e-02j, - 0.38243133 + 6.66608316e-18j, - 0.40485577 + 2.47073481e-01j, - 0.18680217 - 3.31766116e-02j, - -0.06674163 - 7.94216591e-02j, - 0.38243133 + 6.66608316e-18j, - 0.36865853 + 2.45913767e-01j, - 0.13660805 - 3.68947359e-02j, - -0.11467046 - 8.49198927e-02j, - 0.38243133 + 6.66608316e-18j, - 0.33597018 + 2.32971425e-01j, - 0.1072859 - 6.24686168e-02j, - -0.12932565 - 1.06139634e-01j, - 0.38243133 + 6.66608316e-18j, - 0.31616666 + 2.10791785e-01j, - 0.11876919 - 7.93812474e-02j, - -0.1094488 - 1.20159845e-01j, - 0.38243133 + 6.66608316e-18j, - 0.31313975 + 1.82190396e-01j, - 0.14075481 - 5.85637416e-02j, - -0.15198775 - 1.02156797e-01j, ], dtype=complex_type(self.dtype), ) @@ -481,7 +418,7 @@ def testPolarBasis2DEvaluate(self): 8.53250276, ], ], - dtype=complex_type(self.dtype), + dtype=self.dtype, ).T # RCOPT self.assertTrue(np.allclose(x.asnumpy(), result)) @@ -501,7 +438,7 @@ def testPolarBasis2DAdjoint(self): logging.debug( f"lhs: {lhs} rhs: {rhs} absdiff: {np.abs(lhs-rhs)} atol: {utest_tolerance(self.dtype)}" ) - + breakpoint() self.assertTrue(np.isclose(lhs, rhs, atol=utest_tolerance(self.dtype))) # The following functions of UniversalBasisMixin expect a `basis` diff --git a/tests/test_averager2d.py b/tests/test_averager2d.py index e5aa2ed15c..204e8e3722 100644 --- a/tests/test_averager2d.py +++ b/tests/test_averager2d.py @@ -9,7 +9,7 @@ import pytest from packaging.version import parse as parse_version -from aspire.basis import FFBBasis2D, PolarBasis2D +from aspire.basis import FFBBasis2D, PolarFT from aspire.classification import ( AligningAverager2D, Averager2D, @@ -209,7 +209,7 @@ def testNoRot(self): Test we raise an error when our basis does not provide `rotate` method. """ # DiracBasis does not provide `rotate`, - basis = PolarBasis2D((self.resolution, self.resolution), dtype=self.dtype) + basis = PolarFT((self.resolution, self.resolution), dtype=self.dtype) # and that should raise an error during instantiation. with pytest.raises(RuntimeError, match=r".* must provide a `rotate` method."): diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index 9f0c62f6cf..a5e4c1492b 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -91,6 +91,7 @@ def source_orientation_objs(n_img, L, order, dtype): cl_kwargs["symmetry"] = f"C{order}" elif order == 2: cl_class = CLSymmetryC2 + cl_kwargs["min_dist_cls"] = 15 else: cl_class = CLSymmetryCn cl_kwargs["symmetry"] = f"C{order}" @@ -367,7 +368,7 @@ def test_commonlines_c2(n_img, L, order, dtype): cl_gt = (cl_gt * 360 / n_theta) % 180 pairs = all_pairs(n_img) - within_5 = 0 + within_2 = 0 for i, j in pairs: # For each pair of images the two sets of mutual common-lines in cl, (cl[0,i,j], cl[0,j,i]) # and (cl[1,i,j], cl[1,j,i]), should each match one of the two sets in the ground truth cl_gt. @@ -385,11 +386,11 @@ def test_commonlines_c2(n_img, L, order, dtype): + abs(cl[1, j, i] - cl_gt[0, j, i]) ) min_err = min(err_1, err_2) - if min_err <= 5: - within_5 += 1 + if min_err <= 2: + within_2 += 1 # Check that at least 90% of estimates are within 5 degrees. - assert within_5 / len(pairs) > 0.90 + assert within_2 / len(pairs) > 0.90 @pytest.mark.parametrize("n_img, L, order, dtype", param_list_c3_c4) From 871b3070b6854d7ffd65ff33276ad740e07f5f90 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 2 Aug 2023 10:34:40 -0400 Subject: [PATCH 05/53] Add rots_to_clmatrix method. --- src/aspire/utils/coor_trans.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index b9259545fd..219036d5fc 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -308,6 +308,40 @@ def common_line_from_rots(r1, r2, ell): return ell_ij, ell_ji +def rots_to_clmatrix(rots, n_theta, match_cl=True): + """ + Compute the common lines matrix induced by all pairs of rotation + matrices, `rots`, provided. + + :param rots: n_rotsx3x3 array of rotation matrices. + :param n_theta: Number of theta values fo common lines indices. + :param match_cl: Option to match cl_matrix convention from CLOrient3d + + :return: n_rots x n_rots common lines matrix. + """ + n_rots = len(rots) + cl_matrix = -np.ones((n_rots, n_rots)) + for i in range(n_rots): + for j in range(i + 1, n_rots): + Ri = rots[i] + Rj = rots[j] + Ri3 = Ri[:, 2] + Rj3 = Rj[:, 2] + xij = Ri.T @ (np.cross(Ri3, Rj3) / np.linalg.norm(np.cross(Ri3, Rj3))) + xji = Rj.T @ (np.cross(Ri3, Rj3) / np.linalg.norm(np.cross(Ri3, Rj3))) + theta_ij = np.rad2deg(np.arctan2(xij[1], xij[0])) % 360 + theta_ji = np.rad2deg(np.arctan2(xji[1], xji[0])) % 360 + + if theta_ij > 180 and match_cl: + theta_ij -= 180 + theta_ji -= 180 + + cl_matrix[i, j] = round((theta_ij % 360) * n_theta / 360) + cl_matrix[j, i] = round((theta_ji % 360) * n_theta / 360) + + return cl_matrix + + def crop_pad_2d(im, size, fill_value=0): """ :param im: A 2-dimensional numpy array From df93e8a536f501b3b13330db93560468797a93a3 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 2 Aug 2023 11:36:25 -0400 Subject: [PATCH 06/53] Add test_orient_sync_voting.py. Add test_build_clmatrix. --- src/aspire/utils/__init__.py | 1 + tests/test_orient_sync_voting.py | 48 ++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 tests/test_orient_sync_voting.py diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index 8a074bd007..6aa71fb8f5 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -9,6 +9,7 @@ grid_2d, grid_3d, register_rotations, + rots_to_clmatrix, uniform_random_angles, ) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py new file mode 100644 index 0000000000..ab39f3236d --- /dev/null +++ b/tests/test_orient_sync_voting.py @@ -0,0 +1,48 @@ +import numpy as np +import pytest + +from aspire.abinitio import CLOrient3D, CLSyncVoting +from aspire.source import Simulation +from aspire.utils import rots_to_clmatrix +from aspire.volume import AsymmetricVolume + +# Parametrize over (n_img, L, dtype) +PARAMS = [ + (50, 32, np.float32), + (51, 32, np.float64), + (50, 33, np.float32), + (51, 33, np.float64), +] + + +def source_orientation_objs(n_img, L, dtype): + src = Simulation( + n=n_img, + L=L, + vols=AsymmetricVolume(L=L, C=1, K=100).generate(), + offsets=0, + amplitudes=1, + seed=123, + ) + + orient_est = CLSyncVoting(src) + return src, orient_est + + +@pytest.mark.parametrize("n_img, L, dtype", PARAMS) +def test_build_clmatrix(n_img, L, dtype): + src, orient_est = source_orientation_objs(n_img, L, dtype) + + # Build clmatrix estimate. + orient_est.build_clmatrix() + + gt_clmatrix = rots_to_clmatrix(src.rotations, orient_est.n_theta) + + angle_diffs = abs(orient_est.clmatrix - gt_clmatrix) * 360 / orient_est.n_theta + + # Count number of estimates within 5 degrees of ground truth. + within_5 = np.count_nonzero(angle_diffs < 5) + within_5 += np.count_nonzero(angle_diffs > 355) + + # Check that at least 99% of estimates are within 5 degrees. + assert within_5 / angle_diffs.size > 0.99 From f5110ec131c6ac5d4c2405748bc91b0191d2bff0 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 2 Aug 2023 12:08:59 -0400 Subject: [PATCH 07/53] test_estimated_rotations --- tests/test_orient_sync_voting.py | 35 +++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index ab39f3236d..d565459f75 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -3,14 +3,17 @@ from aspire.abinitio import CLOrient3D, CLSyncVoting from aspire.source import Simulation -from aspire.utils import rots_to_clmatrix +from aspire.utils import ( + Rotation, + get_aligned_rotations, + register_rotations, + rots_to_clmatrix, +) from aspire.volume import AsymmetricVolume # Parametrize over (n_img, L, dtype) PARAMS = [ (50, 32, np.float32), - (51, 32, np.float64), - (50, 33, np.float32), (51, 33, np.float64), ] @@ -46,3 +49,29 @@ def test_build_clmatrix(n_img, L, dtype): # Check that at least 99% of estimates are within 5 degrees. assert within_5 / angle_diffs.size > 0.99 + + +@pytest.mark.parametrize("n_img, L, dtype", PARAMS) +def test_estimated_rotations(n_img, L, dtype): + src, orient_est = source_orientation_objs(n_img, L, dtype) + + orient_est.estimate_rotations() + + # Register estimates to ground truth rotations and compute the + # angular distance between them (in degrees). + Q_mat, flag = register_rotations(orient_est.rotations, src.rotations) + regrot = get_aligned_rotations(orient_est.rotations, Q_mat, flag) + ang_dist = np.zeros(n_img, dtype=dtype) + for i in range(n_img): + ang_dist[i] = ( + Rotation.angle_dist( + regrot[i], + src.rotations[i], + dtype=dtype, + ) + * 180 + / np.pi + ) + + # Assert that mean angular distance is less than 1 degree. + assert np.mean(ang_dist) < 1 From be3f4c8aea8e412b8fe2d3d36cab6c5a34719c37 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 2 Aug 2023 14:15:54 -0400 Subject: [PATCH 08/53] test_estimate_shifts --- tests/test_orient_sync_voting.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index d565459f75..5a39b58ea1 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -18,12 +18,12 @@ ] -def source_orientation_objs(n_img, L, dtype): +def source_orientation_objs(n_img, L, offsets, dtype): src = Simulation( n=n_img, L=L, vols=AsymmetricVolume(L=L, C=1, K=100).generate(), - offsets=0, + offsets=offsets, amplitudes=1, seed=123, ) @@ -34,7 +34,8 @@ def source_orientation_objs(n_img, L, dtype): @pytest.mark.parametrize("n_img, L, dtype", PARAMS) def test_build_clmatrix(n_img, L, dtype): - src, orient_est = source_orientation_objs(n_img, L, dtype) + offsets = 0 + src, orient_est = source_orientation_objs(n_img, L, offsets, dtype) # Build clmatrix estimate. orient_est.build_clmatrix() @@ -52,8 +53,9 @@ def test_build_clmatrix(n_img, L, dtype): @pytest.mark.parametrize("n_img, L, dtype", PARAMS) -def test_estimated_rotations(n_img, L, dtype): - src, orient_est = source_orientation_objs(n_img, L, dtype) +def test_estimate_rotations(n_img, L, dtype): + offsets = 0 + src, orient_est = source_orientation_objs(n_img, L, offsets, dtype) orient_est.estimate_rotations() @@ -75,3 +77,15 @@ def test_estimated_rotations(n_img, L, dtype): # Assert that mean angular distance is less than 1 degree. assert np.mean(ang_dist) < 1 + + +@pytest.mark.xfail(reason="estimate_shifts bug.") +@pytest.mark.parametrize("n_img, L, dtype", PARAMS) +def test_estimate_shifts(n_img, L, dtype): + offests = None # Use default random offsets. + src, orient_est = source_orientation_objs(n_img, L, offsets, dtype) + + est_shifts = orient_est.estimate_shifts().T + + # Assert that estimated shifts are close to src.offsets + assert np.allclose(est_shifts, src.offsets) From b3cecaef20e63a327a1bd0b383a9154e7f91bb9a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 2 Aug 2023 14:40:36 -0400 Subject: [PATCH 09/53] additional sync_voting tests --- tests/test_orient_sync_voting.py | 55 +++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 5a39b58ea1..954e1a7205 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -1,7 +1,13 @@ +import os +import os.path +import tempfile + import numpy as np import pytest +from click.testing import CliRunner from aspire.abinitio import CLOrient3D, CLSyncVoting +from aspire.commands.orient3d import orient3d from aspire.source import Simulation from aspire.utils import ( Rotation, @@ -11,6 +17,9 @@ ) from aspire.volume import AsymmetricVolume +DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") + + # Parametrize over (n_img, L, dtype) PARAMS = [ (50, 32, np.float32), @@ -79,7 +88,7 @@ def test_estimate_rotations(n_img, L, dtype): assert np.mean(ang_dist) < 1 -@pytest.mark.xfail(reason="estimate_shifts bug.") +@pytest.mark.xfail(reason="Fails due to estimate_shifts bug.") @pytest.mark.parametrize("n_img, L, dtype", PARAMS) def test_estimate_shifts(n_img, L, dtype): offests = None # Use default random offsets. @@ -89,3 +98,47 @@ def test_estimate_shifts(n_img, L, dtype): # Assert that estimated shifts are close to src.offsets assert np.allclose(est_shifts, src.offsets) + + +def test_theta_error(): + """ + Test that CLSyncVoting when instantiated with odd value for `n_theta` + gives appropriate error. + """ + sim = Simulation() + + # Test we raise with expected error. + with pytest.raises(NotImplementedError, match=r"n_theta must be even*"): + _ = CLSyncVoting(sim, 16, 35) + + +def test_n_check_error(): + """Test we get expected error when n_check is out of range.""" + sim = Simulation() + + with pytest.raises(NotImplementedError, match=r"n_check must be in*"): + _ = CLOrient3D(sim, n_check=-2) + with pytest.raises(NotImplementedError, match=r"n_check must be in*"): + _ = CLOrient3D(sim, n_check=sim.n + 1) + + +def test_command_line(): + # Ensure that the command line tool works as expected + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + # Save the simulation object into STAR and MRCS files + starfile_out = os.path.join(tmpdir, "save_test.star") + starfile_in = os.path.join(DATA_DIR, "sample_particles_relion31.star") + result = runner.invoke( + orient3d, + [ + f"--starfile_in={starfile_in}", + "--n_rad=10", + "--n_theta=60", + "--max_shift=0.15", + "--shift_step=1", + f"--starfile_out={starfile_out}", + ], + ) + # check that the command completed successfully + assert result.exit_code == 0 From 1e7e6f22b0ff7bf8939bea77601a85263c2b847d Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 3 Aug 2023 16:02:41 -0400 Subject: [PATCH 10/53] testing for PolarFT. --- tests/test_PolarBasis2D.py | 8 +-- tests/test_orient_sync.py | 131 ------------------------------------- tests/test_polar_ft.py | 126 +++++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 135 deletions(-) delete mode 100644 tests/test_orient_sync.py create mode 100644 tests/test_polar_ft.py diff --git a/tests/test_PolarBasis2D.py b/tests/test_PolarBasis2D.py index e60dd9af7e..2ca97eccb3 100644 --- a/tests/test_PolarBasis2D.py +++ b/tests/test_PolarBasis2D.py @@ -1,10 +1,10 @@ -import logging +import loggingA from unittest import TestCase import numpy as np from pytest import raises -from aspire.basis import PolarBasis2D +from aspire.basis import PolarFT from aspire.image import Image from aspire.utils import complex_type, utest_tolerance from aspire.utils.matlab_compat import m_reshape @@ -19,7 +19,7 @@ class PolarBasis2DTestCase(TestCase, UniversalBasisMixin): def setUp(self): self.dtype = np.float32 self.L = 8 - self.basis = PolarBasis2D((self.L, self.L), 4, 32, dtype=self.dtype) + self.basis = PolarFT((self.L, self.L), 4, 32, dtype=self.dtype) # Note, in practice we got a degenerate random array around 1% # of the time, so we fix a seed for the randn calls. self.seed = 8675309 @@ -429,7 +429,7 @@ def testPolarBasis2DAdjoint(self): # (y, A*x) = (A^t*y, x) = (B*y, x) x = randn(self.basis.count, seed=self.seed).astype(self.dtype) - x_t = self.basis.evaluate(x).asnumpy() + x_t = self.basis._evaluate(x).asnumpy() y = randn(np.prod(self.basis.sz), seed=self.seed).astype(self.dtype) y_t = self.basis.evaluate_t(Image(np.reshape(y, self.basis.sz))) diff --git a/tests/test_orient_sync.py b/tests/test_orient_sync.py deleted file mode 100644 index 3053143287..0000000000 --- a/tests/test_orient_sync.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -import os.path -import tempfile -from unittest import TestCase - -import numpy as np -from click.testing import CliRunner -from pytest import raises - -from aspire.abinitio import CLOrient3D, CLSyncVoting -from aspire.commands.orient3d import orient3d -from aspire.operators import RadialCTFFilter -from aspire.source.simulation import Simulation -from aspire.utils import utest_tolerance -from aspire.utils.random import Random -from aspire.volume import Volume - -DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") - - -class OrientSyncTestCase(TestCase): - def setUp(self): - L = 32 - n = 64 - pixel_size = 5 - voltage = 200 - defocus_min = 1.5e4 - defocus_max = 2.5e4 - defocus_ct = 7 - Cs = 2.0 - alpha = 0.1 - self.dtype = np.float32 - - filters = [ - RadialCTFFilter(pixel_size, voltage, defocus=d, Cs=Cs, alpha=alpha) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) - ] - - vols = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype( - self.dtype - ) - ) - vols = vols.downsample(L) - - self.sim = Simulation( - L=L, n=n, vols=vols, unique_filters=filters, dtype=self.dtype - ) - - self.orient_est = CLSyncVoting(self.sim, L // 2, 36) - - def tearDown(self): - pass - - def testBuildCLmatrix(self): - self.orient_est.build_clmatrix() - results = np.load(os.path.join(DATA_DIR, "orient_est_clmatrix.npy")) - self.assertTrue(np.allclose(results, self.orient_est.clmatrix)) - - def testSyncMatrixVote(self): - self.orient_est.syncmatrix_vote() - results = np.load(os.path.join(DATA_DIR, "orient_est_smatrix.npy")) - self.assertTrue( - np.allclose( - results, - self.orient_est.syncmatrix, - atol=1e-5 if self.dtype == np.float32 else 1e-8, - ) - ) - - def testEstRotations(self): - self.orient_est.estimate_rotations() - results = np.load(os.path.join(DATA_DIR, "orient_est_rots.npy")) - # Check the dtype passthrough is preserved - self.assertTrue(self.orient_est.rotations.dtype == self.dtype) - # Check the values match reference rotation - self.assertTrue( - np.allclose( - results, - self.orient_est.rotations, - atol=1e-5 if self.dtype == np.float32 else 1e-8, - ) - ) - - def testEstShifts(self): - # need to rerun explicitly the estimation of rotations - self.orient_est.estimate_rotations() - with Random(0): - self.est_shifts = self.orient_est.estimate_shifts() - results = np.load(os.path.join(DATA_DIR, "orient_est_shifts.npy")) - self.assertTrue( - np.allclose(results, self.est_shifts, atol=utest_tolerance(self.dtype)) - ) - - def testThetaError(self): - """ - Test that CLSyncVoting when instantiated with odd value for `n_theta` - gives appropriate error. - """ - - # Test we raise with expected error. - with raises(NotImplementedError, match=r"n_theta must be even*"): - _ = CLSyncVoting(self.sim, 16, 35) - - def testNCheckError(self): - """Test we get expected error when n_check is out of range.""" - with raises(NotImplementedError, match=r"n_check must be in*"): - _ = CLOrient3D(self.sim, n_check=-2) - with raises(NotImplementedError, match=r"n_check must be in*"): - _ = CLOrient3D(self.sim, n_check=self.sim.n + 1) - - def testCommandLine(self): - # Ensure that the command line tool works as expected - runner = CliRunner() - with tempfile.TemporaryDirectory() as tmpdir: - # Save the simulation object into STAR and MRCS files - starfile_out = os.path.join(tmpdir, "save_test.star") - starfile_in = os.path.join(DATA_DIR, "sample_particles_relion31.star") - result = runner.invoke( - orient3d, - [ - f"--starfile_in={starfile_in}", - "--n_rad=10", - "--n_theta=60", - "--max_shift=0.15", - "--shift_step=1", - f"--starfile_out={starfile_out}", - ], - ) - # check that the command completed successfully - self.assertTrue(result.exit_code == 0) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py new file mode 100644 index 0000000000..f5c4bab2e4 --- /dev/null +++ b/tests/test_polar_ft.py @@ -0,0 +1,126 @@ +import numpy as np +import pytest + +from aspire.basis import PolarFT +from aspire.image import Image +from aspire.utils import gaussian_2d, randn +from aspire.volume import AsymmetricVolume, CnSymmetricVolume + +# Parametrize over (resolution, dtype) +IMG_SIZES = [ + 64, + 65, +] +DTYPES = [ + np.float64, + np.float32, +] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") +def dtype(request): + return request.param + + +@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}") +def img_size(request): + return request.param + + +@pytest.fixture +def gaussian(img_size, dtype): + """Radially symmetric image.""" + gauss = Image( + gaussian_2d(img_size, sigma=(img_size // 10, img_size // 10), dtype=dtype) + ) + pf, _ = pf_transform(gauss) + + return gauss, pf + + +@pytest.fixture +def symmetric_image(img_size, dtype): + """Cyclically (C4) symmetric image.""" + symmetric_vol = CnSymmetricVolume( + img_size, C=1, order=4, K=25, seed=10, dtype=dtype + ).generate() + symmetric_image = symmetric_vol.project(np.eye(3, dtype=dtype)) + pf, pft = pf_transform(symmetric_image) + pf_inverse = pft._evaluate(pf.reshape(-1)) + + return symmetric_image, pf, pf_inverse + + +@pytest.fixture +def asymmetric_image(img_size, dtype): + """Asymetric image.""" + asymmetric_vol = AsymmetricVolume(img_size, C=1, dtype=dtype).generate() + asymmetric_image = asymmetric_vol.project(np.eye(3, dtype=dtype)) + pf, _ = pf_transform(asymmetric_image) + + return asymmetric_image, pf + + +def pf_transform(image): + """Take polar Fourier transform of image.""" + img_size = image.resolution + nrad = img_size // 2 + ntheta = 8 * nrad + pft = PolarFT(img_size, nrad=nrad, ntheta=ntheta, dtype=image.dtype) + pf = pft.evaluate_t(image) + pf = pf.reshape(ntheta // 2, nrad) + + return pf, pft + + +def test_dc_component(asymmetric_image): + """Test that the DC component equals the mean of the signal.""" + image, pf = asymmetric_image + signal_mean = np.mean(image) + dc_components = abs(pf[:, 0]) + + assert np.allclose(dc_components, signal_mean) + + +def test_radially_symmetric_image(gaussian): + """Test that all polar Fourier rays are equal for a radially symmetric image.""" + _, pf = gaussian + + assert np.allclose(pf, pf[0]) + + +def test_cyclically_symmetric_image(symmetric_image): + """Test that a symmetric image produces repeated sets of polar Fourier rays.""" + _, pf, _ = symmetric_image + + # For C4 symmetry any two sets of rays seperated by 90 degrees should be equal. + ntheta = pf.shape[0] # ntheta is the number of rays in 180 degrees. + + assert np.allclose(abs(pf[: ntheta // 2]), abs(pf[ntheta // 2 :]), atol=1e-7) + + +def test_adjoint_property(asymmetric_image, symmetric_image): + # The evaluate function should be the adjoint operator of evaluate_t. + # Namely, if A = evaluate, B = evaluate_t, and B=A^t, we will have + # (y, A*x) = (A^t*y, x) = (B*y, x) + y, By = asymmetric_image + _, x, Ax = symmetric_image + + lhs = y.asnumpy().reshape(-1) @ Ax.reshape(-1) + rhs = np.real(By.reshape(-1) @ x.reshape(-1)) + + if y.resolution % 2 == 1: + pytest.skip("Currently failling for odd resolution.") + + assert np.allclose(lhs, rhs) + + +def test_theta_error(): + """ + Test that `PolarFT`, when instantiated with odd value for `ntheta`, + gives appropriate error. + """ + + # Test we raise with expected error. + with pytest.raises(NotImplementedError, match=r"Only even values for ntheta*"): + _ = PolarFT(size=42, ntheta=143, dtype=np.float32) From 2dae2f7acee1bf9e687effb2f1041d50e22fdecd Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 3 Aug 2023 16:13:21 -0400 Subject: [PATCH 11/53] tox --- src/aspire/basis/polar_2d.py | 3 +- tests/test_PolarBasis2D.py | 462 ------------------------------- tests/test_orient_sync_voting.py | 12 +- tests/test_polar_ft.py | 2 +- 4 files changed, 7 insertions(+), 472 deletions(-) delete mode 100644 tests/test_PolarBasis2D.py diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/basis/polar_2d.py index 99bc59b3d2..36187610e1 100644 --- a/src/aspire/basis/polar_2d.py +++ b/src/aspire/basis/polar_2d.py @@ -141,9 +141,8 @@ def evaluate_t(self, x): f"The Image `x` must be real valued. Found dtype {x.dtype}." ) - nimgs = x.shape[0] resolution = x.shape[-1] pf = nufft(x, self.freqs) / resolution**2 - return pf.reshape(nimgs, -1) + return pf.reshape(*stack_shape, -1) diff --git a/tests/test_PolarBasis2D.py b/tests/test_PolarBasis2D.py deleted file mode 100644 index 2ca97eccb3..0000000000 --- a/tests/test_PolarBasis2D.py +++ /dev/null @@ -1,462 +0,0 @@ -import loggingA -from unittest import TestCase - -import numpy as np -from pytest import raises - -from aspire.basis import PolarFT -from aspire.image import Image -from aspire.utils import complex_type, utest_tolerance -from aspire.utils.matlab_compat import m_reshape -from aspire.utils.random import randn - -from ._basis_util import UniversalBasisMixin - -logger = logging.getLogger(__name__) - - -class PolarBasis2DTestCase(TestCase, UniversalBasisMixin): - def setUp(self): - self.dtype = np.float32 - self.L = 8 - self.basis = PolarFT((self.L, self.L), 4, 32, dtype=self.dtype) - # Note, in practice we got a degenerate random array around 1% - # of the time, so we fix a seed for the randn calls. - self.seed = 8675309 - - def tearDown(self): - pass - - def testPolarBasis2DThetaError(self): - """ - Test that PolarBasis2D when instantiated with odd value for `ntheta` - gives appropriate error. - """ - - # Test we raise with expected error. - with raises(NotImplementedError, match=r"Only even values for ntheta*"): - _ = PolarBasis2D(size=self.L, ntheta=143, dtype=self.dtype) - - def testPolarBasis2DEvaluate_t(self): - x = Image( - np.array( - [ - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - -1.08106869e-17, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - ], - [ - 0.00000000e00, - 0.00000000e00, - -6.40456062e-03, - -3.32961020e-03, - -1.36887927e-02, - -5.42770488e-03, - 7.63680861e-03, - 0.00000000e00, - ], - [ - 0.00000000e00, - 3.16377602e-03, - -9.31273350e-03, - 9.46128404e-03, - 1.93239220e-02, - 3.79891953e-02, - 1.06841173e-02, - -2.36467925e-03, - ], - [ - 0.00000000e00, - 1.72736955e-03, - -1.00710814e-02, - 4.93520304e-02, - 3.77702656e-02, - 6.57365438e-02, - 3.94739462e-03, - -4.41228496e-03, - ], - [ - 4.01551066e-18, - -3.08071647e-03, - -1.61670565e-02, - 8.66886286e-02, - 5.09898409e-02, - 7.19313349e-02, - 1.68313715e-02, - 5.19180892e-03, - ], - [ - 0.00000000e00, - 2.87262215e-03, - -3.37732956e-02, - 4.51706505e-02, - 5.72215879e-02, - 4.63553081e-02, - 1.86552175e-03, - 1.12608805e-02, - ], - [ - 0.00000000e00, - 2.77905016e-03, - -2.77499404e-02, - -4.02645374e-02, - -1.54969139e-02, - -1.66229153e-02, - -2.07389259e-02, - 6.64060546e-03, - ], - [ - 0.00000000e00, - 0.00000000e00, - 5.20080934e-03, - -1.06788196e-02, - -1.14761672e-02, - -1.27443126e-02, - -1.15563484e-02, - 0.00000000e00, - ], - ], - dtype=self.dtype, - ) - ) # RCOPT - - pf = self.basis.evaluate_t(x) - result = np.array( - [ - # 0.38243133 + 6.66608316e-18j, - # 0.3249317 - 1.47839074e-01j, - # 0.14819172 + 3.78171168e-03j, - # -0.22808599 + 5.29338933e-02j, - # 0.38243133 + 6.66608316e-18j, - # 0.34595014 - 1.06355385e-01j, - # 0.15519289 - 4.75602164e-02j, - # -0.22401193 + 4.33128746e-03j, - # 0.38243133 + 6.66608316e-18j, - # 0.36957165 - 5.69575709e-02j, - # 0.17389327 - 5.53498385e-02j, - # -0.11601473 - 1.35405676e-02j, - # 0.38243133 + 6.66608316e-18j, - # 0.39045046 - 2.17911945e-03j, - # 0.18146449 - 1.37089189e-02j, - # -0.02110144 + 6.65071497e-03j, - # 0.38243133 + 6.66608316e-18j, - # 0.4063995 + 5.21354967e-02j, - # 0.15674204 + 3.85815662e-02j, - # -0.02886296 + 3.91489615e-02j, - # 0.38243133 + 6.66608316e-18j, - # 0.41872477 + 9.98946906e-02j, - # 0.11862477 + 5.15231952e-02j, - # -0.05298751 + 1.95319478e-02j, - # 0.38243133 + 6.66608316e-18j, - # 0.43013599 + 1.38307796e-01j, - # 0.10075763 + 1.25689289e-02j, - # -0.04052728 - 5.66863498e-02j, - # 0.38243133 + 6.66608316e-18j, - # 0.44144497 + 1.68826980e-01j, - # 0.11446016 - 4.53003874e-02j, - # -0.03546515 - 1.13544145e-01j, - # 0.38243133 + 6.66608316e-18j, - # 0.44960099 + 1.94794929e-01j, - # 0.15053714 - 8.11915305e-02j, - # -0.04800556 - 1.15828804e-01j, - # 0.38243133 + 6.66608316e-18j, - # 0.44872328 + 2.17957567e-01j, - # 0.19116871 - 7.99536373e-02j, - # -0.05683092 - 9.72225058e-02j, - # 0.38243133 + 6.66608316e-18j, - # 0.43379428 + 2.36681249e-01j, - # 0.21025378 - 5.48466438e-02j, - # -0.05318826 - 8.54948014e-02j, - # 0.38243133 + 6.66608316e-18j, - # 0.40485577 + 2.47073481e-01j, - # 0.18680217 - 3.31766116e-02j, - # -0.06674163 - 7.94216591e-02j, - # 0.38243133 + 6.66608316e-18j, - # 0.36865853 + 2.45913767e-01j, - # 0.13660805 - 3.68947359e-02j, - # -0.11467046 - 8.49198927e-02j, - # 0.38243133 + 6.66608316e-18j, - # 0.33597018 + 2.32971425e-01j, - # 0.1072859 - 6.24686168e-02j, - # -0.12932565 - 1.06139634e-01j, - # 0.38243133 + 6.66608316e-18j, - # 0.31616666 + 2.10791785e-01j, - # 0.11876919 - 7.93812474e-02j, - # -0.1094488 - 1.20159845e-01j, - # 0.38243133 + 6.66608316e-18j, - # 0.31313975 + 1.82190396e-01j, - # 0.14075481 - 5.85637416e-02j, - # -0.15198775 - 1.02156797e-01j, - 0.38243133 - 6.66608316e-18j, - 0.3249317 + 1.47839074e-01j, - 0.14819172 - 3.78171168e-03j, - -0.22808599 - 5.29338933e-02j, - 0.38243133 - 6.66608316e-18j, - 0.34595014 + 1.06355385e-01j, - 0.15519289 + 4.75602164e-02j, - -0.22401193 - 4.33128746e-03j, - 0.38243133 - 6.66608316e-18j, - 0.36957165 + 5.69575709e-02j, - 0.17389327 + 5.53498385e-02j, - -0.11601473 + 1.35405676e-02j, - 0.38243133 - 6.66608316e-18j, - 0.39045046 + 2.17911945e-03j, - 0.18146449 + 1.37089189e-02j, - -0.02110144 - 6.65071497e-03j, - 0.38243133 - 6.66608316e-18j, - 0.4063995 - 5.21354967e-02j, - 0.15674204 - 3.85815662e-02j, - -0.02886296 - 3.91489615e-02j, - 0.38243133 - 6.66608316e-18j, - 0.41872477 - 9.98946906e-02j, - 0.11862477 - 5.15231952e-02j, - -0.05298751 - 1.95319478e-02j, - 0.38243133 - 6.66608316e-18j, - 0.43013599 - 1.38307796e-01j, - 0.10075763 - 1.25689289e-02j, - -0.04052728 + 5.66863498e-02j, - 0.38243133 - 6.66608316e-18j, - 0.44144497 - 1.68826980e-01j, - 0.11446016 + 4.53003874e-02j, - -0.03546515 + 1.13544145e-01j, - 0.38243133 - 6.66608316e-18j, - 0.44960099 - 1.94794929e-01j, - 0.15053714 + 8.11915305e-02j, - -0.04800556 + 1.15828804e-01j, - 0.38243133 - 6.66608316e-18j, - 0.44872328 - 2.17957567e-01j, - 0.19116871 + 7.99536373e-02j, - -0.05683092 + 9.72225058e-02j, - 0.38243133 - 6.66608316e-18j, - 0.43379428 - 2.36681249e-01j, - 0.21025378 + 5.48466438e-02j, - -0.05318826 + 8.54948014e-02j, - 0.38243133 - 6.66608316e-18j, - 0.40485577 - 2.47073481e-01j, - 0.18680217 + 3.31766116e-02j, - -0.06674163 + 7.94216591e-02j, - 0.38243133 - 6.66608316e-18j, - 0.36865853 - 2.45913767e-01j, - 0.13660805 + 3.68947359e-02j, - -0.11467046 + 8.49198927e-02j, - 0.38243133 - 6.66608316e-18j, - 0.33597018 - 2.32971425e-01j, - 0.1072859 + 6.24686168e-02j, - -0.12932565 + 1.06139634e-01j, - 0.38243133 - 6.66608316e-18j, - 0.31616666 - 2.10791785e-01j, - 0.11876919 + 7.93812474e-02j, - -0.1094488 + 1.20159845e-01j, - 0.38243133 - 6.66608316e-18j, - 0.31313975 - 1.82190396e-01j, - 0.14075481 + 5.85637416e-02j, - -0.15198775 + 1.02156797e-01j, - ], - dtype=complex_type(self.dtype), - ) - - self.assertTrue(np.allclose(pf, result)) - - def testPolarBasis2DEvaluate(self): - v = np.array( - [ - 0.38243133 - 6.66608316e-18j, - 0.3249317 + 1.47839074e-01j, - 0.14819172 - 3.78171168e-03j, - -0.22808599 - 5.29338933e-02j, - 0.38243133 - 6.66608316e-18j, - 0.34595014 + 1.06355385e-01j, - 0.15519289 + 4.75602164e-02j, - -0.22401193 - 4.33128746e-03j, - 0.38243133 - 6.66608316e-18j, - 0.36957165 + 5.69575709e-02j, - 0.17389327 + 5.53498385e-02j, - -0.11601473 + 1.35405676e-02j, - 0.38243133 - 6.66608316e-18j, - 0.39045046 + 2.17911945e-03j, - 0.18146449 + 1.37089189e-02j, - -0.02110144 - 6.65071497e-03j, - 0.38243133 - 6.66608316e-18j, - 0.4063995 - 5.21354967e-02j, - 0.15674204 - 3.85815662e-02j, - -0.02886296 - 3.91489615e-02j, - 0.38243133 - 6.66608316e-18j, - 0.41872477 - 9.98946906e-02j, - 0.11862477 - 5.15231952e-02j, - -0.05298751 - 1.95319478e-02j, - 0.38243133 - 6.66608316e-18j, - 0.43013599 - 1.38307796e-01j, - 0.10075763 - 1.25689289e-02j, - -0.04052728 + 5.66863498e-02j, - 0.38243133 - 6.66608316e-18j, - 0.44144497 - 1.68826980e-01j, - 0.11446016 + 4.53003874e-02j, - -0.03546515 + 1.13544145e-01j, - 0.38243133 - 6.66608316e-18j, - 0.44960099 - 1.94794929e-01j, - 0.15053714 + 8.11915305e-02j, - -0.04800556 + 1.15828804e-01j, - 0.38243133 - 6.66608316e-18j, - 0.44872328 - 2.17957567e-01j, - 0.19116871 + 7.99536373e-02j, - -0.05683092 + 9.72225058e-02j, - 0.38243133 - 6.66608316e-18j, - 0.43379428 - 2.36681249e-01j, - 0.21025378 + 5.48466438e-02j, - -0.05318826 + 8.54948014e-02j, - 0.38243133 - 6.66608316e-18j, - 0.40485577 - 2.47073481e-01j, - 0.18680217 + 3.31766116e-02j, - -0.06674163 + 7.94216591e-02j, - 0.38243133 - 6.66608316e-18j, - 0.36865853 - 2.45913767e-01j, - 0.13660805 + 3.68947359e-02j, - -0.11467046 + 8.49198927e-02j, - 0.38243133 - 6.66608316e-18j, - 0.33597018 - 2.32971425e-01j, - 0.1072859 + 6.24686168e-02j, - -0.12932565 + 1.06139634e-01j, - 0.38243133 - 6.66608316e-18j, - 0.31616666 - 2.10791785e-01j, - 0.11876919 + 7.93812474e-02j, - -0.1094488 + 1.20159845e-01j, - 0.38243133 - 6.66608316e-18j, - 0.31313975 - 1.82190396e-01j, - 0.14075481 + 5.85637416e-02j, - -0.15198775 + 1.02156797e-01j, - ], - dtype=complex_type(self.dtype), - ) - - x = self.basis.evaluate(v) - result = np.array( - [ - [ - 9.8593804, - 7.94242903, - 7.23336975, - 7.33314303, - 7.41260132, - 7.59483694, - 7.94830958, - 9.47324547, - ], - [ - 7.27801941, - 8.29797686, - 7.17234599, - 7.31082685, - 7.04347376, - 6.91956664, - 8.12234596, - 8.36258646, - ], - [ - 8.76188511, - 10.69546884, - 8.37029969, - 9.87512737, - 9.73946157, - 6.56646752, - 5.69555713, - 8.77758976, - ], - [ - 10.42069436, - 12.3649092, - 14.23951952, - 20.41736454, - 22.32664939, - 18.11535113, - 7.95059873, - 8.79515046, - ], - [ - 11.23152882, - 12.61468396, - 17.92585027, - 25.82097043, - 26.4633412, - 25.11167661, - 11.90634511, - 9.05131389, - ], - [ - 10.7048523, - 11.73534566, - 16.53838035, - 25.13242621, - 23.58037996, - 21.37129485, - 12.1024389, - 10.26313743, - ], - [ - 8.24162377, - 11.90490143, - 14.82292441, - 19.50174891, - 17.69291969, - 15.06781768, - 10.4669263, - 10.2082326, - ], - [ - 5.26532858, - 9.60999648, - 12.68642275, - 12.42354237, - 10.87648517, - 10.60647963, - 9.11026567, - 8.53250276, - ], - ], - dtype=self.dtype, - ).T # RCOPT - - self.assertTrue(np.allclose(x.asnumpy(), result)) - - def testPolarBasis2DAdjoint(self): - # The evaluate function should be the adjoint operator of evaluate_t. - # Namely, if A = evaluate, B = evaluate_t, and B=A^t, we will have - # (y, A*x) = (A^t*y, x) = (B*y, x) - x = randn(self.basis.count, seed=self.seed).astype(self.dtype) - - x_t = self.basis._evaluate(x).asnumpy() - y = randn(np.prod(self.basis.sz), seed=self.seed).astype(self.dtype) - y_t = self.basis.evaluate_t(Image(np.reshape(y, self.basis.sz))) - - lhs = np.dot(y, np.reshape(x_t, (np.prod(self.basis.sz),))) - rhs = np.real(np.dot(y_t, x)) - logging.debug( - f"lhs: {lhs} rhs: {rhs} absdiff: {np.abs(lhs-rhs)} atol: {utest_tolerance(self.dtype)}" - ) - breakpoint() - self.assertTrue(np.isclose(lhs, rhs, atol=utest_tolerance(self.dtype))) - - # The following functions of UniversalBasisMixin expect a `basis` - # arg to be passed in. When PolarBasis2D tests are parametrized - # over size and dtype, this will be possible by passing in a basis - # automatically via @pytest.mark.parametrize() decorator on the test class - # - # See: test_FBBasis2D and test_FFBBasis2D - # - # for now, pass in the basis we are using - def testEvaluate(self): - super().testEvaluate(self.basis) - - def testEvaluate_t(self): - super().testEvaluate_t(self.basis) - - def testExpand(self): - super().testExpand(self.basis) - - def testInitWithIntSize(self): - super().testInitWithIntSize(self.basis) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 954e1a7205..06196136e7 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -27,7 +27,7 @@ ] -def source_orientation_objs(n_img, L, offsets, dtype): +def source_orientation_objs(n_img, L, dtype, offsets=None): src = Simulation( n=n_img, L=L, @@ -43,8 +43,7 @@ def source_orientation_objs(n_img, L, offsets, dtype): @pytest.mark.parametrize("n_img, L, dtype", PARAMS) def test_build_clmatrix(n_img, L, dtype): - offsets = 0 - src, orient_est = source_orientation_objs(n_img, L, offsets, dtype) + src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=0) # Build clmatrix estimate. orient_est.build_clmatrix() @@ -63,8 +62,7 @@ def test_build_clmatrix(n_img, L, dtype): @pytest.mark.parametrize("n_img, L, dtype", PARAMS) def test_estimate_rotations(n_img, L, dtype): - offsets = 0 - src, orient_est = source_orientation_objs(n_img, L, offsets, dtype) + src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=0) orient_est.estimate_rotations() @@ -91,8 +89,8 @@ def test_estimate_rotations(n_img, L, dtype): @pytest.mark.xfail(reason="Fails due to estimate_shifts bug.") @pytest.mark.parametrize("n_img, L, dtype", PARAMS) def test_estimate_shifts(n_img, L, dtype): - offests = None # Use default random offsets. - src, orient_est = source_orientation_objs(n_img, L, offsets, dtype) + # Use default random offsets. + src, orient_est = source_orientation_objs(n_img, L, dtype) est_shifts = orient_est.estimate_shifts().T diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index f5c4bab2e4..0b37d7f3e8 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -3,7 +3,7 @@ from aspire.basis import PolarFT from aspire.image import Image -from aspire.utils import gaussian_2d, randn +from aspire.utils import gaussian_2d from aspire.volume import AsymmetricVolume, CnSymmetricVolume # Parametrize over (resolution, dtype) From f50af6b13bf86b76ed763ea21c8be9e596d6b608 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 9 Aug 2023 13:53:11 -0400 Subject: [PATCH 12/53] Fix generate_shift_phases. Properly notate common line Low Pass Filter. Add mean_angular_distance method. --- .../tutorials/orient3d_simulation.py | 2 +- src/aspire/abinitio/commonline_base.py | 7 +++-- src/aspire/abinitio/commonline_sync.py | 1 + src/aspire/utils/rotation.py | 30 +++++++++++++++++++ tests/test_orient_sync_voting.py | 4 +-- 5 files changed, 39 insertions(+), 5 deletions(-) diff --git a/gallery/tutorials/tutorials/orient3d_simulation.py b/gallery/tutorials/tutorials/orient3d_simulation.py index 433223e6d2..9a8135c64d 100644 --- a/gallery/tutorials/tutorials/orient3d_simulation.py +++ b/gallery/tutorials/tutorials/orient3d_simulation.py @@ -106,4 +106,4 @@ ) # Basic Check -assert mse_reg < 0.06 +assert mse_reg < 0.05 diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 83e619a37d..38fedeebce 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -444,13 +444,14 @@ def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): n_shifts = int(np.ceil(2 * max_shift / shift_step + 1)) # only half of ray - rk = np.arange(-r_max, 0) + rk = np.arange(1, r_max + 1) # Generate all shift phases shifts = -max_shift + shift_step * np.arange(n_shifts) shift_phases = np.exp(np.outer(shifts, -2 * np.pi * 1j * rk / (2 * r_max + 1))) # Set filter for common-line detection h = np.sqrt(np.abs(rk)) * np.exp(-np.square(rk) / (2 * (r_max / 4) ** 2)) + return shifts, shift_phases, h def _generate_index_pairs(self, n_equations): @@ -511,7 +512,9 @@ def _apply_filter_and_norm(self, subscripts, pf, r_max, h): # Note if we'd rather not have the dtype and casting args, # we can control h.dtype instead. np.einsum(subscripts, pf, h, out=pf, dtype=pf.dtype, casting="same_kind") - pf[..., r_max - 1 : r_max + 2] = 0 + + # This is a low pass filter, cutting out the highest frequency. + pf[..., r_max - 1] = 0 pf /= np.linalg.norm(pf, axis=-1)[..., np.newaxis] return pf diff --git a/src/aspire/abinitio/commonline_sync.py b/src/aspire/abinitio/commonline_sync.py index c19640bb35..8d6cd6dbb4 100644 --- a/src/aspire/abinitio/commonline_sync.py +++ b/src/aspire/abinitio/commonline_sync.py @@ -76,6 +76,7 @@ def estimate_rotations(self): # Only need the top 3 eigen-vectors. v = v[:, sort_idx[:3]] + # According to the structure of W^{T} above, the odd rows of V, denoted V1, # are a linear combination of the vectors R_{i}^{1}, i=1,...,K, that is of # column 1 of all rotation matrices. Similarly, the even rows of V, diff --git a/src/aspire/utils/rotation.py b/src/aspire/utils/rotation.py index 4dcd05cec2..5291913b98 100644 --- a/src/aspire/utils/rotation.py +++ b/src/aspire/utils/rotation.py @@ -391,3 +391,33 @@ def angle_dist(r1, r2, dtype=None): theta = max(min(theta, 1), -1) # Clamp theta in [-1,1] dist = np.arccos(theta, dtype=dtype) return dist + + def mean_angular_distance(rots_1, rots_2): + """ + Find the mean angular distance between two sets of rotation matrices. + + :param rots_1: An nx3x3 array of rotation matrices. + :param rots_2: An nx3x3 array of rotation matrices. + + :return: The mean angular distance between rotations. + """ + + dtype = rots_1.dtype + + if len(rots_1) != len(rots_2): + raise ValueError("`rots_1` and `rots_2` must be of equal length.") + + n_rots = len(rots_1) + ang_dist = np.zeros(n_rots, dtype=dtype) + for i in range(n_rots): + ang_dist[i] = ( + Rotation.angle_dist( + rots_1[i], + rots_2[i], + dtype=dtype, + ) + * 180 + / np.pi + ) + + return np.mean(ang_dist) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 06196136e7..832c1970df 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -43,7 +43,7 @@ def source_orientation_objs(n_img, L, dtype, offsets=None): @pytest.mark.parametrize("n_img, L, dtype", PARAMS) def test_build_clmatrix(n_img, L, dtype): - src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=0) + src, orient_est = source_orientation_objs(n_img, L, dtype) # Build clmatrix estimate. orient_est.build_clmatrix() @@ -62,7 +62,7 @@ def test_build_clmatrix(n_img, L, dtype): @pytest.mark.parametrize("n_img, L, dtype", PARAMS) def test_estimate_rotations(n_img, L, dtype): - src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=0) + src, orient_est = source_orientation_objs(n_img, L, dtype) orient_est.estimate_rotations() From c93059b2d6ede92a957b86a94e7133e8f2483234 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 9 Aug 2023 14:04:15 -0400 Subject: [PATCH 13/53] Move offsets param into correct argument. --- tests/test_orient_sync_voting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 832c1970df..06196136e7 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -43,7 +43,7 @@ def source_orientation_objs(n_img, L, dtype, offsets=None): @pytest.mark.parametrize("n_img, L, dtype", PARAMS) def test_build_clmatrix(n_img, L, dtype): - src, orient_est = source_orientation_objs(n_img, L, dtype) + src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=0) # Build clmatrix estimate. orient_est.build_clmatrix() @@ -62,7 +62,7 @@ def test_build_clmatrix(n_img, L, dtype): @pytest.mark.parametrize("n_img, L, dtype", PARAMS) def test_estimate_rotations(n_img, L, dtype): - src, orient_est = source_orientation_objs(n_img, L, dtype) + src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=0) orient_est.estimate_rotations() From ff61545bca1ebfc794979cc123ca3a32539e3621 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 10 Aug 2023 10:05:58 -0400 Subject: [PATCH 14/53] refactor common_line_from_rots. Small edits to estimate_shifts. --- src/aspire/abinitio/commonline_base.py | 5 ++++- src/aspire/utils/coor_trans.py | 6 +++--- tests/test_orient_sync_voting.py | 1 - 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 38fedeebce..a6f640126a 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -239,7 +239,7 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000): if logging.getLogger().isEnabledFor(logging.DEBUG): show = True # Negative sign comes from using -i conversion of Fourier transformation - est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0] + est_shifts = sparse.linalg.lsqr(shift_equations, shift_b, show=show)[0] est_shifts = est_shifts.reshape((2, self.n_img), order="F") return est_shifts @@ -272,6 +272,9 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): n_theta_half = self.n_theta // 2 n_img = self.n_img + + if self.rotations is None: + self.estimate_rotations() rotations = self.rotations pf = self.pf.copy() diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index 219036d5fc..9b9f98a85f 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -295,9 +295,9 @@ def common_line_from_rots(r1, r2, ell): assert r1.dtype == r2.dtype, "Ambiguous dtypes" - ut = np.dot(r2, r1.T) - alpha_ij = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi - alpha_ji = np.arctan2(ut[0, 2], -ut[1, 2]) + np.pi + ut = np.dot(r1.T, r2) + alpha_ji = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi + alpha_ij = np.arctan2(ut[0, 2], -ut[1, 2]) + np.pi ell_ij = alpha_ij * ell / (2 * np.pi) ell_ji = alpha_ji * ell / (2 * np.pi) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 06196136e7..ac0312c421 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -91,7 +91,6 @@ def test_estimate_rotations(n_img, L, dtype): def test_estimate_shifts(n_img, L, dtype): # Use default random offsets. src, orient_est = source_orientation_objs(n_img, L, dtype) - est_shifts = orient_est.estimate_shifts().T # Assert that estimated shifts are close to src.offsets From e10948fd4778719d0bd183da3ee92edc8ff5f150 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 10 Aug 2023 10:48:12 -0400 Subject: [PATCH 15/53] Add offsets to tests. Revert common_line_from_rots. --- src/aspire/abinitio/commonline_base.py | 2 +- src/aspire/utils/coor_trans.py | 6 ++-- tests/test_orient_sync_voting.py | 50 +++++++++++++++----------- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index a6f640126a..f9c0ca91ee 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -239,7 +239,7 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000): if logging.getLogger().isEnabledFor(logging.DEBUG): show = True # Negative sign comes from using -i conversion of Fourier transformation - est_shifts = sparse.linalg.lsqr(shift_equations, shift_b, show=show)[0] + est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0] est_shifts = est_shifts.reshape((2, self.n_img), order="F") return est_shifts diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index 9b9f98a85f..219036d5fc 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -295,9 +295,9 @@ def common_line_from_rots(r1, r2, ell): assert r1.dtype == r2.dtype, "Ambiguous dtypes" - ut = np.dot(r1.T, r2) - alpha_ji = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi - alpha_ij = np.arctan2(ut[0, 2], -ut[1, 2]) + np.pi + ut = np.dot(r2, r1.T) + alpha_ij = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi + alpha_ji = np.arctan2(ut[0, 2], -ut[1, 2]) + np.pi ell_ij = alpha_ij * ell / (2 * np.pi) ell_ji = alpha_ji * ell / (2 * np.pi) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index ac0312c421..1b5f2ada3e 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -20,10 +20,13 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") -# Parametrize over (n_img, L, dtype) +# Parametrize over (n_img, L, dtype, offsets). +# Note, offsets=None uses default random offsets. PARAMS = [ - (50, 32, np.float32), - (51, 33, np.float64), + (50, 32, np.float32, None), + (51, 33, np.float64, None), + (50, 32, np.float32, 0), + (51, 33, np.float64, 0), ] @@ -41,9 +44,9 @@ def source_orientation_objs(n_img, L, dtype, offsets=None): return src, orient_est -@pytest.mark.parametrize("n_img, L, dtype", PARAMS) -def test_build_clmatrix(n_img, L, dtype): - src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=0) +@pytest.mark.parametrize("n_img, L, dtype, offsets", PARAMS) +def test_build_clmatrix(n_img, L, dtype, offsets): + src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=offsets) # Build clmatrix estimate. orient_est.build_clmatrix() @@ -53,16 +56,19 @@ def test_build_clmatrix(n_img, L, dtype): angle_diffs = abs(orient_est.clmatrix - gt_clmatrix) * 360 / orient_est.n_theta # Count number of estimates within 5 degrees of ground truth. - within_5 = np.count_nonzero(angle_diffs < 5) - within_5 += np.count_nonzero(angle_diffs > 355) + within_5 = np.count_nonzero(angle_diffs < 10) + within_5 += np.count_nonzero(angle_diffs > 350) - # Check that at least 99% of estimates are within 5 degrees. - assert within_5 / angle_diffs.size > 0.99 + # Check that at least 99% (70% with shifts) of estimates are within 5 degrees. + tol = 0.99 + if offsets is None: + tol = 0.70 + assert within_5 / angle_diffs.size > tol -@pytest.mark.parametrize("n_img, L, dtype", PARAMS) -def test_estimate_rotations(n_img, L, dtype): - src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=0) +@pytest.mark.parametrize("n_img, L, dtype, offsets", PARAMS) +def test_estimate_rotations(n_img, L, dtype, offsets): + src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=offsets) orient_est.estimate_rotations() @@ -82,15 +88,19 @@ def test_estimate_rotations(n_img, L, dtype): / np.pi ) - # Assert that mean angular distance is less than 1 degree. - assert np.mean(ang_dist) < 1 + # Assert that mean angular distance is less than 1 degree (6 degrees with shifts). + degree_tol = 1 + if offsets is None: + degree_tol = 6 + assert np.mean(ang_dist) < degree_tol -@pytest.mark.xfail(reason="Fails due to estimate_shifts bug.") -@pytest.mark.parametrize("n_img, L, dtype", PARAMS) -def test_estimate_shifts(n_img, L, dtype): - # Use default random offsets. - src, orient_est = source_orientation_objs(n_img, L, dtype) +@pytest.mark.parametrize("n_img, L, dtype, offsets", PARAMS) +def test_estimate_shifts(n_img, L, dtype, offsets): + if offsets is None: + pytest.xfail("Currently failing under non-zero offsets.") + + src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=offsets) est_shifts = orient_est.estimate_shifts().T # Assert that estimated shifts are close to src.offsets From 7a88b8208f194f307b368b1bd9b0f82044a71c87 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 10 Aug 2023 11:08:19 -0400 Subject: [PATCH 16/53] Small cleanup. --- src/aspire/abinitio/commonline_base.py | 1 - src/aspire/abinitio/commonline_sync.py | 1 - src/aspire/basis/basis.py | 2 -- src/aspire/basis/polar_2d.py | 2 -- tests/test_polar_ft.py | 5 ++++- 5 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index f9c0ca91ee..ce1cc0a7b9 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -454,7 +454,6 @@ def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): shift_phases = np.exp(np.outer(shifts, -2 * np.pi * 1j * rk / (2 * r_max + 1))) # Set filter for common-line detection h = np.sqrt(np.abs(rk)) * np.exp(-np.square(rk) / (2 * (r_max / 4) ** 2)) - return shifts, shift_phases, h def _generate_index_pairs(self, n_equations): diff --git a/src/aspire/abinitio/commonline_sync.py b/src/aspire/abinitio/commonline_sync.py index 8d6cd6dbb4..c19640bb35 100644 --- a/src/aspire/abinitio/commonline_sync.py +++ b/src/aspire/abinitio/commonline_sync.py @@ -76,7 +76,6 @@ def estimate_rotations(self): # Only need the top 3 eigen-vectors. v = v[:, sort_idx[:3]] - # According to the structure of W^{T} above, the odd rows of V, denoted V1, # are a linear combination of the vectors R_{i}^{1}, i=1,...,K, that is of # column 1 of all rotation matrices. Similarly, the even rows of V, diff --git a/src/aspire/basis/basis.py b/src/aspire/basis/basis.py index c9057f96fc..9a78b55b9d 100644 --- a/src/aspire/basis/basis.py +++ b/src/aspire/basis/basis.py @@ -99,10 +99,8 @@ def evaluate(self, v): stack_shape = v.shape[:-1] v = v.reshape(-1, self.count) # Compute the transform - x = self._evaluate(v) # Restore stack shape - x = x.reshape(*stack_shape, *self.sz) # Return the appropriate class diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/basis/polar_2d.py index 36187610e1..79cebc0e86 100644 --- a/src/aspire/basis/polar_2d.py +++ b/src/aspire/basis/polar_2d.py @@ -42,8 +42,6 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): self._build() - self.count = self.nrad * (self.ntheta // 2) - # this basis has complex coefficients self.coefficient_dtype = complex_type(self.dtype) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 0b37d7f3e8..038e4a1b67 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -100,9 +100,12 @@ def test_cyclically_symmetric_image(symmetric_image): def test_adjoint_property(asymmetric_image, symmetric_image): + """Test the adjoint property.""" # The evaluate function should be the adjoint operator of evaluate_t. # Namely, if A = evaluate, B = evaluate_t, and B=A^t, we will have - # (y, A*x) = (A^t*y, x) = (B*y, x) + # (y, A*x) = (A^t*y, x) = (B*y, x). + # There is no significance to using asymmetric_image and symmetric_image + # below, other than that they are different images. y, By = asymmetric_image _, x, Ax = symmetric_image From ffb3b23aefbd5898897cf8ac82a5ba90d4abcdaf Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 14 Aug 2023 09:13:17 -0400 Subject: [PATCH 17/53] test for mean_angular_distance. --- src/aspire/utils/rotation.py | 24 +++++++++++------------- tests/test_orient_symmetric.py | 4 ++-- tests/test_rotation.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/aspire/utils/rotation.py b/src/aspire/utils/rotation.py index 5291913b98..bed7347de2 100644 --- a/src/aspire/utils/rotation.py +++ b/src/aspire/utils/rotation.py @@ -392,32 +392,30 @@ def angle_dist(r1, r2, dtype=None): dist = np.arccos(theta, dtype=dtype) return dist - def mean_angular_distance(rots_1, rots_2): + @staticmethod + def mean_angular_distance(rots_1, rots_2, dtype=None): """ Find the mean angular distance between two sets of rotation matrices. :param rots_1: An nx3x3 array of rotation matrices. :param rots_2: An nx3x3 array of rotation matrices. + :param dtype: Data type for computation. Default infers dtype from `rots_1`. - :return: The mean angular distance between rotations. + :return: The mean angular distance between rotations in radians. """ - dtype = rots_1.dtype + dtype = np.dtype(dtype or rots_1.dtype) - if len(rots_1) != len(rots_2): - raise ValueError("`rots_1` and `rots_2` must be of equal length.") + if rots_1.shape != rots_2.shape: + raise ValueError("`rots_1` and `rots_2` must have the same shape.") n_rots = len(rots_1) ang_dist = np.zeros(n_rots, dtype=dtype) for i in range(n_rots): - ang_dist[i] = ( - Rotation.angle_dist( - rots_1[i], - rots_2[i], - dtype=dtype, - ) - * 180 - / np.pi + ang_dist[i] = Rotation.angle_dist( + rots_1[i], + rots_2[i], + dtype=dtype, ) return np.mean(ang_dist) diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index a5e4c1492b..9c01df2921 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -174,7 +174,7 @@ def test_relative_rotations(n_img, L, order, dtype): Rotation.angle_dist(Rij_J, Rij_s_gt), ) angular_distance[idx] = np.min(dist) - mean_angular_distance = np.mean(angular_distance) + mean_angular_distance = np.mean(angular_distance) * 180 / np.pi # Assert that the mean_angular_distance is less than 5 degrees. assert mean_angular_distance < 5 @@ -208,7 +208,7 @@ def test_self_relative_rotations(n_img, L, order, dtype): for i, estimate in enumerate(cases): dist[i] = Rotation.angle_dist(estimate, Rii_gt) angular_distance[i] = dist[np.argmin(dist)] - mean_angular_distance = np.mean(angular_distance) + mean_angular_distance = np.mean(angular_distance) * 180 / np.pi # Check that mean_angular_distance is less than 5 degrees. assert mean_angular_distance < 5 diff --git a/tests/test_rotation.py b/tests/test_rotation.py index 7d7e430097..6abdc4d5fc 100644 --- a/tests/test_rotation.py +++ b/tests/test_rotation.py @@ -136,3 +136,14 @@ def testFromRotvec(self): self.assertTrue(isinstance(rotations, Rotation)) self.assertTrue(rotations.matrices.dtype == self.dtype) self.assertTrue(np.allclose(rotations.matrices, ref_rots.matrices)) + + +def test_mean_angular_distance(): + rots_z = Rotation.about_axis( + "z", [0, np.pi / 4, np.pi / 2], dtype=np.float32 + ).matrices + rots_id = Rotation.about_axis("z", [0, 0, 0], dtype=np.float32).matrices + + mean_ang_dist = Rotation.mean_angular_distance(rots_z, rots_id) + + assert np.allclose(mean_ang_dist, np.pi / 4) From 418040d9607cbdc755d8a1d37b90b0a7806d3107 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 14 Aug 2023 11:39:17 -0400 Subject: [PATCH 18/53] refactor test parametrization. --- tests/test_orient_sync_voting.py | 70 ++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 1b5f2ada3e..a563d2f143 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -19,22 +19,43 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") +RESOLUTION = [ + 32, + 33, +] + +OFFSETS = [ + None, # Defaults to random offsets. + 0, +] -# Parametrize over (n_img, L, dtype, offsets). -# Note, offsets=None uses default random offsets. -PARAMS = [ - (50, 32, np.float32, None), - (51, 33, np.float64, None), - (50, 32, np.float32, 0), - (51, 33, np.float64, 0), +DTYPES = [ + np.float32, + np.float64, ] -def source_orientation_objs(n_img, L, dtype, offsets=None): +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}") +def resolution(request): + return request.param + + +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}") +def offsets(request): + return request.param + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") +def dtype(request): + return request.param + + +@pytest.fixture +def source_orientation_objs(resolution, offsets, dtype): src = Simulation( - n=n_img, - L=L, - vols=AsymmetricVolume(L=L, C=1, K=100).generate(), + n=50, + L=resolution, + vols=AsymmetricVolume(L=resolution, C=1, K=100).generate(), offsets=offsets, amplitudes=1, seed=123, @@ -44,9 +65,8 @@ def source_orientation_objs(n_img, L, dtype, offsets=None): return src, orient_est -@pytest.mark.parametrize("n_img, L, dtype, offsets", PARAMS) -def test_build_clmatrix(n_img, L, dtype, offsets): - src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=offsets) +def test_build_clmatrix(source_orientation_objs): + src, orient_est = source_orientation_objs # Build clmatrix estimate. orient_est.build_clmatrix() @@ -61,14 +81,13 @@ def test_build_clmatrix(n_img, L, dtype, offsets): # Check that at least 99% (70% with shifts) of estimates are within 5 degrees. tol = 0.99 - if offsets is None: + if src.offsets.all() != 0: tol = 0.70 assert within_5 / angle_diffs.size > tol -@pytest.mark.parametrize("n_img, L, dtype, offsets", PARAMS) -def test_estimate_rotations(n_img, L, dtype, offsets): - src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=offsets) +def test_estimate_rotations(source_orientation_objs): + src, orient_est = source_orientation_objs orient_est.estimate_rotations() @@ -76,13 +95,13 @@ def test_estimate_rotations(n_img, L, dtype, offsets): # angular distance between them (in degrees). Q_mat, flag = register_rotations(orient_est.rotations, src.rotations) regrot = get_aligned_rotations(orient_est.rotations, Q_mat, flag) - ang_dist = np.zeros(n_img, dtype=dtype) - for i in range(n_img): + ang_dist = np.zeros(src.n, dtype=src.dtype) + for i in range(src.n): ang_dist[i] = ( Rotation.angle_dist( regrot[i], src.rotations[i], - dtype=dtype, + dtype=src.dtype, ) * 180 / np.pi @@ -90,17 +109,16 @@ def test_estimate_rotations(n_img, L, dtype, offsets): # Assert that mean angular distance is less than 1 degree (6 degrees with shifts). degree_tol = 1 - if offsets is None: + if src.offsets.all() != 0: degree_tol = 6 assert np.mean(ang_dist) < degree_tol -@pytest.mark.parametrize("n_img, L, dtype, offsets", PARAMS) -def test_estimate_shifts(n_img, L, dtype, offsets): - if offsets is None: +def test_estimate_shifts(source_orientation_objs): + src, orient_est = source_orientation_objs + if src.offsets.all() != 0: pytest.xfail("Currently failing under non-zero offsets.") - src, orient_est = source_orientation_objs(n_img, L, dtype, offsets=offsets) est_shifts = orient_est.estimate_shifts().T # Assert that estimated shifts are close to src.offsets From 9c33ac751776869f54f897dfdd02ffc06603ef36 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 14 Aug 2023 12:01:39 -0400 Subject: [PATCH 19/53] typo --- tests/test_polar_ft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 038e4a1b67..ac0cfd845b 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -113,7 +113,7 @@ def test_adjoint_property(asymmetric_image, symmetric_image): rhs = np.real(By.reshape(-1) @ x.reshape(-1)) if y.resolution % 2 == 1: - pytest.skip("Currently failling for odd resolution.") + pytest.skip("Currently failing for odd resolution.") assert np.allclose(lhs, rhs) From 97a015577f1fc8d9f3b0624a7bddf843e5bb763a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 14 Aug 2023 13:49:12 -0400 Subject: [PATCH 20/53] update test tolerance. --- tests/test_orient_sync_voting.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index a563d2f143..47d3b43cdf 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -58,7 +58,6 @@ def source_orientation_objs(resolution, offsets, dtype): vols=AsymmetricVolume(L=resolution, C=1, K=100).generate(), offsets=offsets, amplitudes=1, - seed=123, ) orient_est = CLSyncVoting(src) @@ -110,7 +109,7 @@ def test_estimate_rotations(source_orientation_objs): # Assert that mean angular distance is less than 1 degree (6 degrees with shifts). degree_tol = 1 if src.offsets.all() != 0: - degree_tol = 6 + degree_tol = 7 assert np.mean(ang_dist) < degree_tol From f0fa11cf8841befac5df4373f44f18b6b05b065b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 14 Aug 2023 13:50:47 -0400 Subject: [PATCH 21/53] update code comments. --- src/aspire/abinitio/commonline_base.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index ce1cc0a7b9..d21179760b 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -77,17 +77,9 @@ def _build(self): n_theta_half = self.n_theta // 2 self.pf = self.pf.reshape(self.n_img, n_theta_half, self.n_rad) - # The last two dimension of pf is of size n_theta x n_rad. We will convert pf - # into an array of size (n_theta/2) x (n_rad-1), that is, take half of each ray - # through the origin except the DC part, and also take the angles only up to PI. - # This is due to the fact that the original images are real, and thus each ray - # is conjugate symmetric. We therefore gain nothing by taking longer correlations - # (of length 2*n_rad-1 instead of n_rad-1). In the Matlab version, pf is converted to - # the size of (n_theta/2) x (2*n_rad-1) but most of the calculations of build_clmatrix - # and estimate_shifts below only use the size of (n_theta/2) x (n_rad-1). In the - # Python version we will use the size of (n_theta/2) x (n_rad-1) directly and make - # sure every part is using it. By taking shorter correlations we can speed the - # computation by a factor of two. + # We remove the DC the component. pf now has size (n_img) x (n_theta/2) x (n_rad-1), + # with pf[:, :, 0] containing low frequency content and pf[:, :, -1] containing + # high frequency content. self.pf = self.pf[:, :, 1:] def estimate_rotations(self): From 0b7316c53c15438bd42c109ac07ba5aeeabdeaae Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 15 Aug 2023 16:04:14 -0400 Subject: [PATCH 22/53] switch to high pass filter as in matlab. --- gallery/tutorials/tutorials/orient3d_simulation.py | 2 +- src/aspire/abinitio/commonline_base.py | 7 ++++--- tests/test_orient_sync_voting.py | 5 +++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/gallery/tutorials/tutorials/orient3d_simulation.py b/gallery/tutorials/tutorials/orient3d_simulation.py index 9a8135c64d..433223e6d2 100644 --- a/gallery/tutorials/tutorials/orient3d_simulation.py +++ b/gallery/tutorials/tutorials/orient3d_simulation.py @@ -106,4 +106,4 @@ ) # Basic Check -assert mse_reg < 0.05 +assert mse_reg < 0.06 diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index d21179760b..9636d84572 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -438,7 +438,7 @@ def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): # Number of shifts to try n_shifts = int(np.ceil(2 * max_shift / shift_step + 1)) - # only half of ray + # only half of ray, excluding the DC component. rk = np.arange(1, r_max + 1) # Generate all shift phases @@ -507,8 +507,9 @@ def _apply_filter_and_norm(self, subscripts, pf, r_max, h): # we can control h.dtype instead. np.einsum(subscripts, pf, h, out=pf, dtype=pf.dtype, casting="same_kind") - # This is a low pass filter, cutting out the highest frequency. - pf[..., r_max - 1] = 0 + # This is a high pass filter, cutting out the lowest frequency + # (DC has already been removed). + pf[..., 0] = 0 pf /= np.linalg.norm(pf, axis=-1)[..., np.newaxis] return pf diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 47d3b43cdf..c97488cfdf 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -58,6 +58,7 @@ def source_orientation_objs(resolution, offsets, dtype): vols=AsymmetricVolume(L=resolution, C=1, K=100).generate(), offsets=offsets, amplitudes=1, + seed=0, ) orient_est = CLSyncVoting(src) @@ -106,10 +107,10 @@ def test_estimate_rotations(source_orientation_objs): / np.pi ) - # Assert that mean angular distance is less than 1 degree (6 degrees with shifts). + # Assert that mean angular distance is less than 1 degree (8 degrees with shifts). degree_tol = 1 if src.offsets.all() != 0: - degree_tol = 7 + degree_tol = 8 assert np.mean(ang_dist) < degree_tol From d234cf2d238dbd9c83fb2913ffe4d44f0dd7a0e2 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 16 Aug 2023 09:23:15 -0400 Subject: [PATCH 23/53] adjust max_shift for 0/random offsets. --- tests/test_orient_sync_voting.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index c97488cfdf..d24a82b4cb 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -61,7 +61,13 @@ def source_orientation_objs(resolution, offsets, dtype): seed=0, ) - orient_est = CLSyncVoting(src) + # Search for common lines over less shifts for 0 offsets. + max_shift = 1 / resolution + if src.offsets.all() != 0: + max_shift = 0.25 + + orient_est = CLSyncVoting(src, max_shift=max_shift) + return src, orient_est From 4a90c90b2455dc55b64a56cf00acbad88fa2d243 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 16 Aug 2023 11:02:02 -0400 Subject: [PATCH 24/53] Increase resolution for tighter test tols. --- src/aspire/abinitio/commonline_base.py | 1 + tests/test_orient_sync_voting.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 9636d84572..d1739f51dd 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -446,6 +446,7 @@ def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): shift_phases = np.exp(np.outer(shifts, -2 * np.pi * 1j * rk / (2 * r_max + 1))) # Set filter for common-line detection h = np.sqrt(np.abs(rk)) * np.exp(-np.square(rk) / (2 * (r_max / 4) ** 2)) + return shifts, shift_phases, h def _generate_index_pairs(self, n_equations): diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index d24a82b4cb..f77d3e40f3 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -20,8 +20,8 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") RESOLUTION = [ - 32, - 33, + 40, + 41, ] OFFSETS = [ @@ -82,8 +82,8 @@ def test_build_clmatrix(source_orientation_objs): angle_diffs = abs(orient_est.clmatrix - gt_clmatrix) * 360 / orient_est.n_theta # Count number of estimates within 5 degrees of ground truth. - within_5 = np.count_nonzero(angle_diffs < 10) - within_5 += np.count_nonzero(angle_diffs > 350) + within_5 = np.count_nonzero(angle_diffs < 5) + within_5 += np.count_nonzero(angle_diffs > 355) # Check that at least 99% (70% with shifts) of estimates are within 5 degrees. tol = 0.99 @@ -113,10 +113,10 @@ def test_estimate_rotations(source_orientation_objs): / np.pi ) - # Assert that mean angular distance is less than 1 degree (8 degrees with shifts). + # Assert that mean angular distance is less than 1 degree (5 degrees with shifts). degree_tol = 1 if src.offsets.all() != 0: - degree_tol = 8 + degree_tol = 5 assert np.mean(ang_dist) < degree_tol From 0722f122042659dafed878bdecf65f0d961c2c22 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 16 Aug 2023 14:27:01 -0400 Subject: [PATCH 25/53] Seed the Volume AAAAHHH. --- tests/test_orient_sync_voting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index f77d3e40f3..f6cab24e6d 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -55,7 +55,7 @@ def source_orientation_objs(resolution, offsets, dtype): src = Simulation( n=50, L=resolution, - vols=AsymmetricVolume(L=resolution, C=1, K=100).generate(), + vols=AsymmetricVolume(L=resolution, C=1, K=100, seed=0).generate(), offsets=offsets, amplitudes=1, seed=0, @@ -85,10 +85,10 @@ def test_build_clmatrix(source_orientation_objs): within_5 = np.count_nonzero(angle_diffs < 5) within_5 += np.count_nonzero(angle_diffs > 355) - # Check that at least 99% (70% with shifts) of estimates are within 5 degrees. + # Check that at least 99% (75% with shifts) of estimates are within 5 degrees. tol = 0.99 if src.offsets.all() != 0: - tol = 0.70 + tol = 0.75 assert within_5 / angle_diffs.size > tol From 3a0b103cfed7714e1b965396cdb8bb30f25cec5e Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 16 Aug 2023 14:30:33 -0400 Subject: [PATCH 26/53] Use mean_angular_distance method. --- tests/test_orient_sync_voting.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index f6cab24e6d..1454df0bda 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -101,23 +101,13 @@ def test_estimate_rotations(source_orientation_objs): # angular distance between them (in degrees). Q_mat, flag = register_rotations(orient_est.rotations, src.rotations) regrot = get_aligned_rotations(orient_est.rotations, Q_mat, flag) - ang_dist = np.zeros(src.n, dtype=src.dtype) - for i in range(src.n): - ang_dist[i] = ( - Rotation.angle_dist( - regrot[i], - src.rotations[i], - dtype=src.dtype, - ) - * 180 - / np.pi - ) + mean_ang_dist = Rotation.mean_angular_distance(regrot, src.rotations) * 180 / np.pi # Assert that mean angular distance is less than 1 degree (5 degrees with shifts). degree_tol = 1 if src.offsets.all() != 0: degree_tol = 5 - assert np.mean(ang_dist) < degree_tol + assert mean_ang_dist < degree_tol def test_estimate_shifts(source_orientation_objs): From 08676e22fb6de8ad26230ac745fd0eaf6d9115b3 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 17 Aug 2023 11:51:13 -0400 Subject: [PATCH 27/53] test radial modes of pft. --- tests/test_polar_ft.py | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index ac0cfd845b..dfc2a53543 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -3,7 +3,7 @@ from aspire.basis import PolarFT from aspire.image import Image -from aspire.utils import gaussian_2d +from aspire.utils import gaussian_2d, grid_2d from aspire.volume import AsymmetricVolume, CnSymmetricVolume # Parametrize over (resolution, dtype) @@ -16,6 +16,17 @@ np.float32, ] +RADIAL_MODES = [ + 2, + 3, + 4, + 5, + 8, + 9, + 16, + 17, +] + @pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") def dtype(request): @@ -27,6 +38,11 @@ def img_size(request): return request.param +@pytest.fixture(params=RADIAL_MODES, ids=lambda x: f"radial_mode={x}") +def radial_mode(request): + return request.param + + @pytest.fixture def gaussian(img_size, dtype): """Radially symmetric image.""" @@ -61,6 +77,15 @@ def asymmetric_image(img_size, dtype): return asymmetric_image, pf +@pytest.fixture +def radial_mode_image(img_size, dtype, radial_mode): + g = grid_2d(img_size, dtype=dtype) + image = Image(np.sin(radial_mode * np.pi * g["r"])) + pf, _ = pf_transform(image) + + return pf, radial_mode + + def pf_transform(image): """Take polar Fourier transform of image.""" img_size = image.resolution @@ -99,6 +124,21 @@ def test_cyclically_symmetric_image(symmetric_image): assert np.allclose(abs(pf[: ntheta // 2]), abs(pf[ntheta // 2 :]), atol=1e-7) +def test_radial_modes(radial_mode_image): + pf, mode = radial_mode_image + + # Set DC compenent to zero. + pf[:, 0] = 0 + + # Check that all rays are close. + assert abs(np.real(pf) - np.real(pf[0])).all() < 1e-4 + + # Check that correct mode is most prominent. + # Mode could be off by a pixel depending on resolution and mode. + mode_window = [mode - 1, mode, mode + 1] + assert np.argmax(np.real(pf[3])) in mode_window + + def test_adjoint_property(asymmetric_image, symmetric_image): """Test the adjoint property.""" # The evaluate function should be the adjoint operator of evaluate_t. From df0b9c239750dfbe870cabbca9a78220b7ad47e4 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 17 Aug 2023 12:00:29 -0400 Subject: [PATCH 28/53] evaluate_t ~~> transform --- src/aspire/abinitio/commonline_base.py | 2 +- src/aspire/basis/polar_2d.py | 2 +- tests/test_polar_ft.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index d1739f51dd..5a162e96e9 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -73,7 +73,7 @@ def _build(self): self.pft = PolarFT( (self.n_res, self.n_res), self.n_rad, self.n_theta, dtype=self.dtype ) - self.pf = self.pft.evaluate_t(imgs) + self.pf = self.pft.transform(imgs) n_theta_half = self.n_theta // 2 self.pf = self.pf.reshape(self.n_img, n_theta_half, self.n_rad) diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/basis/polar_2d.py index 79cebc0e86..b7be7a794a 100644 --- a/src/aspire/basis/polar_2d.py +++ b/src/aspire/basis/polar_2d.py @@ -104,7 +104,7 @@ def _evaluate(self, v): return x - def evaluate_t(self, x): + def transform(self, x): """ Evaluate coefficient in polar Fourier grid from those in standard 2D coordinate basis diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index dfc2a53543..428d90ee69 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -92,7 +92,7 @@ def pf_transform(image): nrad = img_size // 2 ntheta = 8 * nrad pft = PolarFT(img_size, nrad=nrad, ntheta=ntheta, dtype=image.dtype) - pf = pft.evaluate_t(image) + pf = pft.transform(image) pf = pf.reshape(ntheta // 2, nrad) return pf, pft From 21a17dc70c600f6c3365018f511b0de792db320b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 17 Aug 2023 12:07:54 -0400 Subject: [PATCH 29/53] remove evaluate --- src/aspire/basis/polar_2d.py | 15 -------------- tests/test_polar_ft.py | 38 +++++++++--------------------------- 2 files changed, 9 insertions(+), 44 deletions(-) diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/basis/polar_2d.py index b7be7a794a..09024dbcec 100644 --- a/src/aspire/basis/polar_2d.py +++ b/src/aspire/basis/polar_2d.py @@ -89,21 +89,6 @@ def _precomp(self): freqs *= omega0 return freqs - def _evaluate(self, v): - """ - Evaluate coefficients in standard 2D coordinate basis from those in polar Fourier basis - - :param v: A coefficient vector (or an array of coefficient vectors) - in polar Fourier basis to be evaluated. The last dimension must equal to - `self.count`. - :return x: Image instance in standard 2D coordinate basis with - resolution of `self.sz`. - """ - - x = anufft(v, self.freqs, self.sz, real=True) / self.count - - return x - def transform(self, x): """ Evaluate coefficient in polar Fourier grid from those in standard 2D coordinate basis diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 428d90ee69..9d0f69d4cc 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -49,9 +49,9 @@ def gaussian(img_size, dtype): gauss = Image( gaussian_2d(img_size, sigma=(img_size // 10, img_size // 10), dtype=dtype) ) - pf, _ = pf_transform(gauss) + pf = pf_transform(gauss) - return gauss, pf + return pf @pytest.fixture @@ -61,10 +61,9 @@ def symmetric_image(img_size, dtype): img_size, C=1, order=4, K=25, seed=10, dtype=dtype ).generate() symmetric_image = symmetric_vol.project(np.eye(3, dtype=dtype)) - pf, pft = pf_transform(symmetric_image) - pf_inverse = pft._evaluate(pf.reshape(-1)) + pf = pf_transform(symmetric_image) - return symmetric_image, pf, pf_inverse + return pf @pytest.fixture @@ -72,7 +71,7 @@ def asymmetric_image(img_size, dtype): """Asymetric image.""" asymmetric_vol = AsymmetricVolume(img_size, C=1, dtype=dtype).generate() asymmetric_image = asymmetric_vol.project(np.eye(3, dtype=dtype)) - pf, _ = pf_transform(asymmetric_image) + pf = pf_transform(asymmetric_image) return asymmetric_image, pf @@ -81,7 +80,7 @@ def asymmetric_image(img_size, dtype): def radial_mode_image(img_size, dtype, radial_mode): g = grid_2d(img_size, dtype=dtype) image = Image(np.sin(radial_mode * np.pi * g["r"])) - pf, _ = pf_transform(image) + pf = pf_transform(image) return pf, radial_mode @@ -95,7 +94,7 @@ def pf_transform(image): pf = pft.transform(image) pf = pf.reshape(ntheta // 2, nrad) - return pf, pft + return pf def test_dc_component(asymmetric_image): @@ -109,14 +108,14 @@ def test_dc_component(asymmetric_image): def test_radially_symmetric_image(gaussian): """Test that all polar Fourier rays are equal for a radially symmetric image.""" - _, pf = gaussian + pf = gaussian assert np.allclose(pf, pf[0]) def test_cyclically_symmetric_image(symmetric_image): """Test that a symmetric image produces repeated sets of polar Fourier rays.""" - _, pf, _ = symmetric_image + pf = symmetric_image # For C4 symmetry any two sets of rays seperated by 90 degrees should be equal. ntheta = pf.shape[0] # ntheta is the number of rays in 180 degrees. @@ -139,25 +138,6 @@ def test_radial_modes(radial_mode_image): assert np.argmax(np.real(pf[3])) in mode_window -def test_adjoint_property(asymmetric_image, symmetric_image): - """Test the adjoint property.""" - # The evaluate function should be the adjoint operator of evaluate_t. - # Namely, if A = evaluate, B = evaluate_t, and B=A^t, we will have - # (y, A*x) = (A^t*y, x) = (B*y, x). - # There is no significance to using asymmetric_image and symmetric_image - # below, other than that they are different images. - y, By = asymmetric_image - _, x, Ax = symmetric_image - - lhs = y.asnumpy().reshape(-1) @ Ax.reshape(-1) - rhs = np.real(By.reshape(-1) @ x.reshape(-1)) - - if y.resolution % 2 == 1: - pytest.skip("Currently failing for odd resolution.") - - assert np.allclose(lhs, rhs) - - def test_theta_error(): """ Test that `PolarFT`, when instantiated with odd value for `ntheta`, From 168d7f242dc2b8a395d67cd1a65ee5bc4b54e4a1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 17 Aug 2023 12:18:37 -0400 Subject: [PATCH 30/53] Move polar_ft.py to operators dir. --- src/aspire/abinitio/commonline_base.py | 2 +- src/aspire/basis/__init__.py | 1 - src/aspire/operators/__init__.py | 1 + src/aspire/{basis/polar_2d.py => operators/polar_ft.py} | 0 tests/test_polar_ft.py | 2 +- 5 files changed, 3 insertions(+), 3 deletions(-) rename src/aspire/{basis/polar_2d.py => operators/polar_ft.py} (100%) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 5a162e96e9..80069f7895 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -4,7 +4,7 @@ import numpy as np import scipy.sparse as sparse -from aspire.basis import PolarFT +from aspire.operators import PolarFT from aspire.utils import common_line_from_rots from aspire.utils.random import choice diff --git a/src/aspire/basis/__init__.py b/src/aspire/basis/__init__.py index 5fc8ce9c2a..7d57b57295 100644 --- a/src/aspire/basis/__init__.py +++ b/src/aspire/basis/__init__.py @@ -15,6 +15,5 @@ from .fpswf_2d import FPSWFBasis2D from .fpswf_3d import FPSWFBasis3D from .fspca import FSPCABasis -from .polar_2d import PolarFT from .pswf_2d import PSWFBasis2D from .pswf_3d import PSWFBasis3D diff --git a/src/aspire/operators/__init__.py b/src/aspire/operators/__init__.py index fc6e2e09bc..97e47a4d1d 100644 --- a/src/aspire/operators/__init__.py +++ b/src/aspire/operators/__init__.py @@ -18,4 +18,5 @@ ZeroFilter, evaluate_src_filters_on_grid, ) +from .polar_ft import PolarFT from .wemd import wemd_embed, wemd_norm diff --git a/src/aspire/basis/polar_2d.py b/src/aspire/operators/polar_ft.py similarity index 100% rename from src/aspire/basis/polar_2d.py rename to src/aspire/operators/polar_ft.py diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 9d0f69d4cc..53f4924e9b 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from aspire.basis import PolarFT +from aspire.operators import PolarFT from aspire.image import Image from aspire.utils import gaussian_2d, grid_2d from aspire.volume import AsymmetricVolume, CnSymmetricVolume From b255fa1f8501bc28a2a2b90676df36311c484c5e Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 17 Aug 2023 12:20:35 -0400 Subject: [PATCH 31/53] Fix imports. --- src/aspire/operators/polar_ft.py | 2 +- tests/test_polar_ft.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index 09024dbcec..fd968fa4c7 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -3,7 +3,7 @@ import numpy as np from aspire.image import Image -from aspire.nufft import anufft, nufft +from aspire.nufft import nufft from aspire.utils import complex_type logger = logging.getLogger(__name__) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 53f4924e9b..5002384d9b 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -1,8 +1,8 @@ import numpy as np import pytest -from aspire.operators import PolarFT from aspire.image import Image +from aspire.operators import PolarFT from aspire.utils import gaussian_2d, grid_2d from aspire.volume import AsymmetricVolume, CnSymmetricVolume From b911610e06832dc6206f3bc46fb66eaf485fcc4e Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 17 Aug 2023 12:28:58 -0400 Subject: [PATCH 32/53] Missed import. --- tests/test_averager2d.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_averager2d.py b/tests/test_averager2d.py index 204e8e3722..c09da2c42e 100644 --- a/tests/test_averager2d.py +++ b/tests/test_averager2d.py @@ -9,7 +9,7 @@ import pytest from packaging.version import parse as parse_version -from aspire.basis import FFBBasis2D, PolarFT +from aspire.basis import FFBBasis2D from aspire.classification import ( AligningAverager2D, Averager2D, @@ -18,6 +18,7 @@ BFSReddyChatterjiAverager2D, ReddyChatterjiAverager2D, ) +from aspire.operators import PolarFT from aspire.source import Simulation from aspire.utils import Rotation, num_procs_suggestion from aspire.volume import Volume From d6db07fc3e8b0dde48eaa8c2c05acc965b158809 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 18 Aug 2023 12:22:57 -0400 Subject: [PATCH 33/53] angular_radial_mode test. --- tests/test_polar_ft.py | 57 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 5002384d9b..bac98e4241 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -6,7 +6,7 @@ from aspire.utils import gaussian_2d, grid_2d from aspire.volume import AsymmetricVolume, CnSymmetricVolume -# Parametrize over (resolution, dtype) +# Parameters IMG_SIZES = [ 64, 65, @@ -27,7 +27,15 @@ 17, ] +ANGULAR_RADIAL_MODES = [ + (4, 5), + (4, 10), + (8, 5), + (8, 10), +] + +# Parameter Fixtures @pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") def dtype(request): return request.param @@ -43,6 +51,12 @@ def radial_mode(request): return request.param +@pytest.fixture(params=ANGULAR_RADIAL_MODES, ids=lambda x: f"angular_radial_mode={x}") +def angular_radial_mode(request): + return request.param + + +# Image and PF fixtures @pytest.fixture def gaussian(img_size, dtype): """Radially symmetric image.""" @@ -85,11 +99,24 @@ def radial_mode_image(img_size, dtype, radial_mode): return pf, radial_mode +@pytest.fixture +def angular_radial_mode_image(dtype, angular_radial_mode): + g = grid_2d(64, dtype=dtype) + angular_mode, radial_mode = angular_radial_mode + image = Image( + np.sin(radial_mode * np.pi * g["r"]) * np.cos(angular_mode * g["phi"]) + ) + pf = pf_transform(image) + + return pf, angular_mode, radial_mode + + +# Helper function def pf_transform(image): """Take polar Fourier transform of image.""" img_size = image.resolution nrad = img_size // 2 - ntheta = 8 * nrad + ntheta = 360 pft = PolarFT(img_size, nrad=nrad, ntheta=ntheta, dtype=image.dtype) pf = pft.transform(image) pf = pf.reshape(ntheta // 2, nrad) @@ -97,6 +124,7 @@ def pf_transform(image): return pf +# Testing suite def test_dc_component(asymmetric_image): """Test that the DC component equals the mean of the signal.""" image, pf = asymmetric_image @@ -130,12 +158,33 @@ def test_radial_modes(radial_mode_image): pf[:, 0] = 0 # Check that all rays are close. - assert abs(np.real(pf) - np.real(pf[0])).all() < 1e-4 + assert abs(abs(pf) - abs(pf[0])).all() < 1e-4 # Check that correct mode is most prominent. # Mode could be off by a pixel depending on resolution and mode. + # Since all rays are close will just test one. mode_window = [mode - 1, mode, mode + 1] - assert np.argmax(np.real(pf[3])) in mode_window + ray = 3 + assert np.argmax(abs(pf[ray])) in mode_window + + +def test_angular_radial_modes(angular_radial_mode_image): + pf, angular_mode, radial_mode = angular_radial_mode_image + + # Exclude first few rays by setting to zero. + pf[:5] = 0 + + # Check that correct mode is most prominent. + # Mode could be off by a pixel depending on resolution and mode. + radial_mode_window = [radial_mode - 1, radial_mode, radial_mode + 1] + + angular_mode = 2 * len(pf) // angular_mode + angular_mode_window = [angular_mode - 1, angular_mode, angular_mode + 1] + + # Find max pixel and check that it lands in window + angular_hit, radial_hit = np.unravel_index(np.argmax(abs(pf)), pf.shape) + assert angular_hit in angular_mode_window + assert radial_hit in radial_mode_window def test_theta_error(): From 69905f14ae3d1b6e13449019d47cb57249af615d Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 18 Aug 2023 15:49:38 -0400 Subject: [PATCH 34/53] Set CLSyncVoting shift_step to less than 1. Improves commonline detection. --- src/aspire/abinitio/commonline_base.py | 2 +- tests/test_orient_sync_voting.py | 22 ++++++++++------------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 80069f7895..a8f33f553b 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -232,7 +232,7 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000): show = True # Negative sign comes from using -i conversion of Fourier transformation est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0] - est_shifts = est_shifts.reshape((2, self.n_img), order="F") + est_shifts = est_shifts.reshape((2, self.n_img), order="F").T return est_shifts diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 1454df0bda..5442ac5145 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -24,14 +24,15 @@ 41, ] +# `None` defaults to random offsets. OFFSETS = [ - None, # Defaults to random offsets. 0, + None, ] DTYPES = [ np.float32, - np.float64, + pytest.param(np.float64, marks=pytest.mark.expensive), ] @@ -63,10 +64,11 @@ def source_orientation_objs(resolution, offsets, dtype): # Search for common lines over less shifts for 0 offsets. max_shift = 1 / resolution + shift_step = 1 if src.offsets.all() != 0: - max_shift = 0.25 - - orient_est = CLSyncVoting(src, max_shift=max_shift) + max_shift = 0.20 + shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. + orient_est = CLSyncVoting(src, max_shift=max_shift, shift_step=shift_step) return src, orient_est @@ -85,10 +87,8 @@ def test_build_clmatrix(source_orientation_objs): within_5 = np.count_nonzero(angle_diffs < 5) within_5 += np.count_nonzero(angle_diffs > 355) - # Check that at least 99% (75% with shifts) of estimates are within 5 degrees. - tol = 0.99 - if src.offsets.all() != 0: - tol = 0.75 + # Check that at least 98% of estimates are within 5 degrees. + tol = 0.98 assert within_5 / angle_diffs.size > tol @@ -105,8 +105,6 @@ def test_estimate_rotations(source_orientation_objs): # Assert that mean angular distance is less than 1 degree (5 degrees with shifts). degree_tol = 1 - if src.offsets.all() != 0: - degree_tol = 5 assert mean_ang_dist < degree_tol @@ -115,7 +113,7 @@ def test_estimate_shifts(source_orientation_objs): if src.offsets.all() != 0: pytest.xfail("Currently failing under non-zero offsets.") - est_shifts = orient_est.estimate_shifts().T + est_shifts = orient_est.estimate_shifts() # Assert that estimated shifts are close to src.offsets assert np.allclose(est_shifts, src.offsets) From ab9cc6c9dfcc6697598045869743b81c3bcfb54f Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 23 Aug 2023 08:38:56 -0400 Subject: [PATCH 35/53] Use proper reshape instead of order F with transpose. --- src/aspire/abinitio/commonline_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index a8f33f553b..6f7c722479 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -232,7 +232,7 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000): show = True # Negative sign comes from using -i conversion of Fourier transformation est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0] - est_shifts = est_shifts.reshape((2, self.n_img), order="F").T + est_shifts = est_shifts.reshape((self.n_img, 2)) return est_shifts From fce7d7b8b18f77dd79f2062ed7be6290dfd5a55f Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 23 Aug 2023 08:49:39 -0400 Subject: [PATCH 36/53] clarifying comment. --- src/aspire/abinitio/commonline_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 6f7c722479..59707dc341 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -265,6 +265,7 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): n_theta_half = self.n_theta // 2 n_img = self.n_img + # `estimate_shifts()` requires that rotations have already been estimated. if self.rotations is None: self.estimate_rotations() rotations = self.rotations From ce158c7c54c67415a8bce6a1c10a969add6ebf5b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 23 Aug 2023 09:15:57 -0400 Subject: [PATCH 37/53] evaluate_t ~~> transform in warning message. --- src/aspire/operators/polar_ft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index fd968fa4c7..a0b0bfd688 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -101,13 +101,13 @@ def transform(self, x): """ if x.dtype != self.dtype: logger.warning( - f"{self.__class__.__name__}::evaluate_t" + f"{self.__class__.__name__}::transform" f" Inconsistent dtypes x: {x.dtype} self: {self.dtype}" ) if not isinstance(x, Image): logger.warning( - f"{self.__class__.__name__}::evaluate_t" + f"{self.__class__.__name__}::transform" f" passed numpy array instead of {Image}." ) else: From d21ce7117a64a612f035ac8c4e4870b0b1188140 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 23 Aug 2023 09:34:32 -0400 Subject: [PATCH 38/53] Test complex image error. --- tests/test_polar_ft.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index bac98e4241..8f2eb32269 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -187,6 +187,15 @@ def test_angular_radial_modes(angular_radial_mode_image): assert radial_hit in radial_mode_window +def test_complex_image_error(): + """Test that we raise for complex images.""" + img_size = 5 + complex_image = Image(np.ones((img_size, img_size), dtype=np.complex64)) + 2j + pft = PolarFT(size=img_size) + with pytest.raises(TypeError, match=r"The Image `x` must be real valued*"): + _ = pft.transform(complex_image) + + def test_theta_error(): """ Test that `PolarFT`, when instantiated with odd value for `ntheta`, From 8899b365df521767e7dfa26a283233210febe590 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 23 Aug 2023 09:50:31 -0400 Subject: [PATCH 39/53] Clean up rots_to_clmatrix. --- src/aspire/utils/coor_trans.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index 219036d5fc..a29e041966 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -308,35 +308,35 @@ def common_line_from_rots(r1, r2, ell): return ell_ij, ell_ji -def rots_to_clmatrix(rots, n_theta, match_cl=True): +def rots_to_clmatrix(rots, n_theta): """ Compute the common lines matrix induced by all pairs of rotation matrices, `rots`, provided. :param rots: n_rotsx3x3 array of rotation matrices. - :param n_theta: Number of theta values fo common lines indices. - :param match_cl: Option to match cl_matrix convention from CLOrient3d + :param n_theta: Number of theta values for common lines indices. :return: n_rots x n_rots common lines matrix. """ n_rots = len(rots) cl_matrix = -np.ones((n_rots, n_rots)) for i in range(n_rots): + Ri = rots[i] + Ri3 = Ri[:, 2] for j in range(i + 1, n_rots): - Ri = rots[i] Rj = rots[j] - Ri3 = Ri[:, 2] Rj3 = Rj[:, 2] - xij = Ri.T @ (np.cross(Ri3, Rj3) / np.linalg.norm(np.cross(Ri3, Rj3))) - xji = Rj.T @ (np.cross(Ri3, Rj3) / np.linalg.norm(np.cross(Ri3, Rj3))) + common_axis = np.cross(Ri3, Rj3) / np.linalg.norm(np.cross(Ri3, Rj3)) + xij = Ri.T @ common_axis + xji = Rj.T @ common_axis theta_ij = np.rad2deg(np.arctan2(xij[1], xij[0])) % 360 theta_ji = np.rad2deg(np.arctan2(xji[1], xji[0])) % 360 - if theta_ij > 180 and match_cl: + if theta_ij > 180: theta_ij -= 180 theta_ji -= 180 - cl_matrix[i, j] = round((theta_ij % 360) * n_theta / 360) + cl_matrix[i, j] = round(theta_ij * n_theta / 360) cl_matrix[j, i] = round((theta_ji % 360) * n_theta / 360) return cl_matrix From 33abd7e00972702fb443f74c3bf6d6437efe5e81 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 23 Aug 2023 14:46:29 -0400 Subject: [PATCH 40/53] Clean up test_orient_symmetric.py. --- tests/test_orient_symmetric.py | 35 +++++++++++----------------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index 9c01df2921..a3c5bee598 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -46,7 +46,7 @@ def source_orientation_objs(n_img, L, order, dtype): vol = CnSymmetricVolume( L=L, C=1, - K=25, + K=50, order=order, seed=65, dtype=dtype, @@ -120,25 +120,10 @@ def test_estimate_rotations(n_img, L, order, dtype): # angular distance between them (in degrees). Q_mat, flag = register_rotations(rots_est, rots_gt_sync) regrot = get_aligned_rotations(rots_est, Q_mat, flag) - ang_dist = np.zeros(n_img, dtype=dtype) - for i in range(n_img): - ang_dist[i] = ( - Rotation.angle_dist( - regrot[i], - rots_gt_sync[i], - dtype=dtype, - ) - * 180 - / np.pi - ) + mean_ang_dist = Rotation.mean_angular_distance(regrot, rots_gt_sync) * 180 / np.pi # Assert mean angular distance is reasonable. - if order == 2: - assert np.mean(ang_dist) < 4 - elif order == 3 or order == 4: - assert np.mean(ang_dist) < 2 - else: - assert np.mean(ang_dist) < 5 + assert mean_ang_dist < 3 @pytest.mark.parametrize("n_img, L, order, dtype", param_list_c3_c4) @@ -207,7 +192,7 @@ def test_self_relative_rotations(n_img, L, order, dtype): cases = np.array([Rii, Rii.T, J_conjugate(Rii), J_conjugate(Rii.T)]) for i, estimate in enumerate(cases): dist[i] = Rotation.angle_dist(estimate, Rii_gt) - angular_distance[i] = dist[np.argmin(dist)] + angular_distance[i] = min(dist) mean_angular_distance = np.mean(angular_distance) * 180 / np.pi # Check that mean_angular_distance is less than 5 degrees. @@ -303,6 +288,7 @@ def test_relative_viewing_directions(n_img, L, order, dtype): # Check that the mean angular difference is within 2 degrees. angle_tol = 2 * np.pi / 180 if order > 4: + breakpoint() angle_tol = 4 * np.pi / 180 assert angular_dist_vijs < angle_tol @@ -368,7 +354,8 @@ def test_commonlines_c2(n_img, L, order, dtype): cl_gt = (cl_gt * 360 / n_theta) % 180 pairs = all_pairs(n_img) - within_2 = 0 + angle_tol = 2 # degrees + within_tol = 0 for i, j in pairs: # For each pair of images the two sets of mutual common-lines in cl, (cl[0,i,j], cl[0,j,i]) # and (cl[1,i,j], cl[1,j,i]), should each match one of the two sets in the ground truth cl_gt. @@ -386,11 +373,11 @@ def test_commonlines_c2(n_img, L, order, dtype): + abs(cl[1, j, i] - cl_gt[0, j, i]) ) min_err = min(err_1, err_2) - if min_err <= 2: - within_2 += 1 + if min_err <= angle_tol: + within_tol += 1 - # Check that at least 90% of estimates are within 5 degrees. - assert within_2 / len(pairs) > 0.90 + # Check that at least 90% of estimates are within `angle_tol` degrees. + assert within_tol / len(pairs) > 0.90 @pytest.mark.parametrize("n_img, L, order, dtype", param_list_c3_c4) From 1d7da2e18f2e49b0c50c5138b95fa4d02b1c7376 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 23 Aug 2023 14:55:44 -0400 Subject: [PATCH 41/53] typo --- tests/test_polar_ft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 8f2eb32269..76c14c281e 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -154,7 +154,7 @@ def test_cyclically_symmetric_image(symmetric_image): def test_radial_modes(radial_mode_image): pf, mode = radial_mode_image - # Set DC compenent to zero. + # Set DC component to zero. pf[:, 0] = 0 # Check that all rays are close. From 66bb935483cb10bf08f135416ada895c752abc2c Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 24 Aug 2023 10:06:10 -0400 Subject: [PATCH 42/53] remove angular_radial_mode test. --- tests/test_polar_ft.py | 43 ------------------------------------------ 1 file changed, 43 deletions(-) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 76c14c281e..63160e1012 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -27,13 +27,6 @@ 17, ] -ANGULAR_RADIAL_MODES = [ - (4, 5), - (4, 10), - (8, 5), - (8, 10), -] - # Parameter Fixtures @pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") @@ -51,11 +44,6 @@ def radial_mode(request): return request.param -@pytest.fixture(params=ANGULAR_RADIAL_MODES, ids=lambda x: f"angular_radial_mode={x}") -def angular_radial_mode(request): - return request.param - - # Image and PF fixtures @pytest.fixture def gaussian(img_size, dtype): @@ -99,18 +87,6 @@ def radial_mode_image(img_size, dtype, radial_mode): return pf, radial_mode -@pytest.fixture -def angular_radial_mode_image(dtype, angular_radial_mode): - g = grid_2d(64, dtype=dtype) - angular_mode, radial_mode = angular_radial_mode - image = Image( - np.sin(radial_mode * np.pi * g["r"]) * np.cos(angular_mode * g["phi"]) - ) - pf = pf_transform(image) - - return pf, angular_mode, radial_mode - - # Helper function def pf_transform(image): """Take polar Fourier transform of image.""" @@ -168,25 +144,6 @@ def test_radial_modes(radial_mode_image): assert np.argmax(abs(pf[ray])) in mode_window -def test_angular_radial_modes(angular_radial_mode_image): - pf, angular_mode, radial_mode = angular_radial_mode_image - - # Exclude first few rays by setting to zero. - pf[:5] = 0 - - # Check that correct mode is most prominent. - # Mode could be off by a pixel depending on resolution and mode. - radial_mode_window = [radial_mode - 1, radial_mode, radial_mode + 1] - - angular_mode = 2 * len(pf) // angular_mode - angular_mode_window = [angular_mode - 1, angular_mode, angular_mode + 1] - - # Find max pixel and check that it lands in window - angular_hit, radial_hit = np.unravel_index(np.argmax(abs(pf)), pf.shape) - assert angular_hit in angular_mode_window - assert radial_hit in radial_mode_window - - def test_complex_image_error(): """Test that we raise for complex images.""" img_size = 5 From 18eaf922a78c8b6315b7c22121ad0c9dd65a44d7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 24 Aug 2023 10:44:14 -0400 Subject: [PATCH 43/53] remove breakpoint() --- tests/test_orient_symmetric.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index a3c5bee598..746213913c 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -288,7 +288,6 @@ def test_relative_viewing_directions(n_img, L, order, dtype): # Check that the mean angular difference is within 2 degrees. angle_tol = 2 * np.pi / 180 if order > 4: - breakpoint() angle_tol = 4 * np.pi / 180 assert angular_dist_vijs < angle_tol From 8a260cecde589de9a6ed05907fcef6b686dc58d7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 24 Aug 2023 14:18:08 -0400 Subject: [PATCH 44/53] Clean up _precomp() --- src/aspire/operators/polar_ft.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index a0b0bfd688..7bb56bae33 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -77,17 +77,14 @@ def _precomp(self): dtheta = 2 * np.pi / self.ntheta # only need half size of ntheta - freqs = np.zeros((2, self.nrad * (self.ntheta // 2)), dtype=self.dtype) + freqs = np.zeros((2, self.ntheta // 2, self.nrad), dtype=self.dtype) for i in range(self.ntheta // 2): - freqs[0, i * self.nrad : (i + 1) * self.nrad] = np.arange( - self.nrad - ) * np.cos(i * dtheta) - freqs[1, i * self.nrad : (i + 1) * self.nrad] = np.arange( - self.nrad - ) * np.sin(i * dtheta) - - freqs *= omega0 - return freqs + freqs[0, i] = np.cos(i * dtheta) + freqs[1, i] = np.sin(i * dtheta) + + freqs *= omega0 * np.arange(self.nrad) + + return freqs.reshape(2, -1) def transform(self, x): """ From 6cf829523e1513d6be7cd3670ac733b21f1e1318 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 24 Aug 2023 16:12:49 -0400 Subject: [PATCH 45/53] Add PolarFT.full() method and tests. --- src/aspire/abinitio/commonline_base.py | 4 +-- src/aspire/operators/polar_ft.py | 18 ++++++++++- tests/test_polar_ft.py | 45 ++++++++++++++++++++++++-- 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 59707dc341..a7a044b60e 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -74,10 +74,8 @@ def _build(self): (self.n_res, self.n_res), self.n_rad, self.n_theta, dtype=self.dtype ) self.pf = self.pft.transform(imgs) - n_theta_half = self.n_theta // 2 - self.pf = self.pf.reshape(self.n_img, n_theta_half, self.n_rad) - # We remove the DC the component. pf now has size (n_img) x (n_theta/2) x (n_rad-1), + # We remove the DC the component. pf has size (n_img) x (n_theta/2) x (n_rad-1), # with pf[:, :, 0] containing low frequency content and pf[:, :, -1] containing # high frequency content. self.pf = self.pf[:, :, 1:] diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index 7bb56bae33..0e9f190826 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -45,6 +45,9 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): # this basis has complex coefficients self.coefficient_dtype = complex_type(self.dtype) + # Store the "half" transform for use in building the full transform. + self._transform = None + def _build(self): """ Build the internal data structure to 2D polar Fourier grid @@ -125,4 +128,17 @@ def transform(self, x): pf = nufft(x, self.freqs) / resolution**2 - return pf.reshape(*stack_shape, -1) + return pf.reshape(*stack_shape, self.ntheta // 2, self.nrad) + + @staticmethod + def full(pf): + """ + Use the conjugate symmetry of pf to construct the full polar Fourier transform + over all rays in [0, 360). + + :param pf: The precomputed half polar Fourier transform + with shape (*stack_shape, ntheta//2, nrad) + :return: The full polar Fourier transform with shape (*stack_shape, ntheta, nrad) + """ + + return np.concatenate((pf, np.conj(pf)), axis=-2) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 63160e1012..a4fa26d169 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -6,7 +6,10 @@ from aspire.utils import gaussian_2d, grid_2d from aspire.volume import AsymmetricVolume, CnSymmetricVolume +# ========== # Parameters +# ========== + IMG_SIZES = [ 64, 65, @@ -28,7 +31,11 @@ ] +# ================== # Parameter Fixtures +# ================== + + @pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") def dtype(request): return request.param @@ -44,7 +51,11 @@ def radial_mode(request): return request.param -# Image and PF fixtures +# ===================== +# Image and PF Fixtures +# ===================== + + @pytest.fixture def gaussian(img_size, dtype): """Radially symmetric image.""" @@ -100,7 +111,11 @@ def pf_transform(image): return pf -# Testing suite +# ============= +# Testing Suite +# ============= + + def test_dc_component(asymmetric_image): """Test that the DC component equals the mean of the signal.""" image, pf = asymmetric_image @@ -162,3 +177,29 @@ def test_theta_error(): # Test we raise with expected error. with pytest.raises(NotImplementedError, match=r"Only even values for ntheta*"): _ = PolarFT(size=42, ntheta=143, dtype=np.float32) + + +@pytest.mark.parametrize("stack_shape", [(5,), (2, 3)]) +def test_full_transform(stack_shape): + """ + Test conjugate symmetry and shape of the full polar Fourier transform. + """ + img_size = 32 + image = Image( + np.random.rand(*stack_shape, img_size, img_size).astype(np.float32, copy=False) + ) + pft = PolarFT(size=img_size) + pf = pft.transform(image) + full_pf = pft.full(pf) + + # Check shape. + assert full_pf.shape == (*stack_shape, pft.ntheta, pft.nrad) + + # Check conjugate symmetry against pf. + assert np.allclose(np.conj(pf), full_pf[..., pft.ntheta // 2 :, :]) + + # Check conjugate symmetry against self. + for ray in range(pft.ntheta // 2): + np.testing.assert_allclose( + full_pf[..., ray, :], np.conj(full_pf[..., ray + pft.ntheta // 2, :]) + ) From c0ad31a23eca5c4501f6dad2210bb537b2b19757 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 24 Aug 2023 16:27:23 -0400 Subject: [PATCH 46/53] change warnings to raises. Add tests. --- src/aspire/operators/polar_ft.py | 4 ++-- tests/test_polar_ft.py | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index 0e9f190826..d5886f8c87 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -100,13 +100,13 @@ def transform(self, x): corresponds to `x.shape[0]`, and last dimension equals `self.count`. """ if x.dtype != self.dtype: - logger.warning( + raise TypeError( f"{self.__class__.__name__}::transform" f" Inconsistent dtypes x: {x.dtype} self: {self.dtype}" ) if not isinstance(x, Image): - logger.warning( + raise TypeError( f"{self.__class__.__name__}::transform" f" passed numpy array instead of {Image}." ) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index a4fa26d169..6e0f4e0dd8 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -163,11 +163,29 @@ def test_complex_image_error(): """Test that we raise for complex images.""" img_size = 5 complex_image = Image(np.ones((img_size, img_size), dtype=np.complex64)) + 2j - pft = PolarFT(size=img_size) + pft = PolarFT(size=img_size, dtype=np.complex64) with pytest.raises(TypeError, match=r"The Image `x` must be real valued*"): _ = pft.transform(complex_image) +def test_inconsistent_dtypes_error(): + """Test that we raise for complex images.""" + img_size = 5 + image = np.ones((img_size, img_size), dtype=np.float32) + pft = PolarFT(size=img_size, dtype=np.float32) + with pytest.raises(TypeError, match=r"passed numpy array*"): + _ = pft.transform(image) + + +def test_inconsistent_dtypes_error(): + """Test that we raise for complex images.""" + img_size = 5 + image_np = np.ones((img_size, img_size), dtype=np.float32) + pft = PolarFT(size=img_size, dtype=np.float64) + with pytest.raises(TypeError, match=r"Inconsistent dtypes*"): + _ = pft.transform(image_np) + + def test_theta_error(): """ Test that `PolarFT`, when instantiated with odd value for `ntheta`, From 7babb9db1a035cf5c513c805c97163fef4be64dd Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 24 Aug 2023 16:37:38 -0400 Subject: [PATCH 47/53] mark random offsets test case as expensive. --- tests/test_orient_sync_voting.py | 2 +- tests/test_polar_ft.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 5442ac5145..8b427984ef 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -27,7 +27,7 @@ # `None` defaults to random offsets. OFFSETS = [ 0, - None, + pytest.param(None, marks=pytest.mark.expensive), ] DTYPES = [ diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 6e0f4e0dd8..31abbaa993 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -168,22 +168,22 @@ def test_complex_image_error(): _ = pft.transform(complex_image) -def test_inconsistent_dtypes_error(): - """Test that we raise for complex images.""" +def test_numpy_array_error(): + """Test that we raise when passed numpy array.""" img_size = 5 - image = np.ones((img_size, img_size), dtype=np.float32) + image_np = np.ones((img_size, img_size), dtype=np.float32) pft = PolarFT(size=img_size, dtype=np.float32) with pytest.raises(TypeError, match=r"passed numpy array*"): - _ = pft.transform(image) + _ = pft.transform(image_np) def test_inconsistent_dtypes_error(): """Test that we raise for complex images.""" img_size = 5 - image_np = np.ones((img_size, img_size), dtype=np.float32) + image = Image(np.ones((img_size, img_size), dtype=np.float32)) pft = PolarFT(size=img_size, dtype=np.float64) with pytest.raises(TypeError, match=r"Inconsistent dtypes*"): - _ = pft.transform(image_np) + _ = pft.transform(image) def test_theta_error(): From cd562a94d73352c57b53fd8a4babe8314cc9e5b3 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 25 Aug 2023 08:40:00 -0400 Subject: [PATCH 48/53] remove unnecessary reshape of pf. --- tests/test_polar_ft.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index 31abbaa993..c58df20136 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -105,8 +105,7 @@ def pf_transform(image): nrad = img_size // 2 ntheta = 360 pft = PolarFT(img_size, nrad=nrad, ntheta=ntheta, dtype=image.dtype) - pf = pft.transform(image) - pf = pf.reshape(ntheta // 2, nrad) + pf = pft.transform(image)[0] return pf From 978ed449fe5204c2fb68da64cb6879e6b8f41f4c Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 28 Aug 2023 10:56:44 -0400 Subject: [PATCH 49/53] address review comments for polar_ft --- src/aspire/operators/polar_ft.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index d5886f8c87..2da62b00b6 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -16,17 +16,15 @@ class PolarFT: def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): """ - Initialize an object for the 2D polar Fourier grid class. `PolarFT` expects that - images are real and uses only half of the `ntheta` values. Downstream algorithms should - take advantage of the conjugate symmetry of the polar Fourier coefficients if the full - set is needed. + Initialize an object for the polar Fourier transform class. `PolarFT` expects that + images are real and uses only half of the `ntheta` values. - :param size: The shape of the vectors for which to define the grid. + :param size: The shape of the vectors for which to define the transform. May be a 2-tuple or an integer, in which case a square basis is assumed. Currently only square images are supported. :param nrad: The number of points in the radial dimension. Default is resolution // 2. :param ntheta: The number of points in the angular dimension. Default is 8 * nrad. - :param dtype: dtype of polar Fourier grid. + :param dtype: dtype used to compute a polar frequency grid for evaluating the transform.. """ if isinstance(size, int): size = (size, size) @@ -45,9 +43,6 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32): # this basis has complex coefficients self.coefficient_dtype = complex_type(self.dtype) - # Store the "half" transform for use in building the full transform. - self._transform = None - def _build(self): """ Build the internal data structure to 2D polar Fourier grid @@ -74,7 +69,7 @@ def _build(self): def _precomp(self): """ - Precomute the polar Fourier grid + Precompute the polar Fourier grid. """ omega0 = 2 * np.pi / (2 * self.nrad - 1) dtheta = 2 * np.pi / self.ntheta @@ -101,13 +96,13 @@ def transform(self, x): """ if x.dtype != self.dtype: raise TypeError( - f"{self.__class__.__name__}::transform" + f"{self.__class__.__name__}.transform" f" Inconsistent dtypes x: {x.dtype} self: {self.dtype}" ) if not isinstance(x, Image): raise TypeError( - f"{self.__class__.__name__}::transform" + f"{self.__class__.__name__}.transform" f" passed numpy array instead of {Image}." ) else: From c8ef88c344703b429264e0542527e613a49e834a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 28 Aug 2023 15:25:07 -0400 Subject: [PATCH 50/53] Vectorize Rotation.angle_dist() to accept sets of rotations. --- src/aspire/utils/rotation.py | 70 +++++++++++++++++++++--------------- tests/test_rotation.py | 32 +++++++++-------- 2 files changed, 58 insertions(+), 44 deletions(-) diff --git a/src/aspire/utils/rotation.py b/src/aspire/utils/rotation.py index bed7347de2..11d5f71e29 100644 --- a/src/aspire/utils/rotation.py +++ b/src/aspire/utils/rotation.py @@ -369,36 +369,57 @@ def generate_random_rotations( @staticmethod def angle_dist(r1, r2, dtype=None): """ - Find the angular distance between two rotation matrices. We first compute - the rotation between the two rotation matrices, r = r1 @ r2.T. Then using - the axis-angle representation of r we find the angle between r1 and r2. - - :param r1: A 3x3 rotation matrix - :param r2: A 3x3 rotation matrix + Find the angular distance between two sets of rotation matrices. We first compute + the rotation between the two sets of rotation matrices, r = r1 @ r2.T. Then using + the axis-angle representation of r we find the angle between r1 and r2. Computations + will broadcast in the case of a singleton rotation and a set of rotations. + + :param r1: An nx3x3 array of rotation matrices, a singleton rotation, + or a `Rotation` object. + :param r2: An nx3x3 array of rotation matrices, a singleton rotation, + or a `Rotation` object. :param dtype: Computation datatype. Default `None` infers from `r1`. - :return: The angular distance between r1 and r2 in radians. + :return: The element-wise angular distance between r1 and r2 in radians. """ + if isinstance(r1, Rotation): + r1 = r1.matrices + if isinstance(r2, Rotation): + r2 = r2.matrices + + # Handle singletons. + r1 = r1.reshape(-1, 3, 3) + r2 = r2.reshape(-1, 3, 3) + + # Check that shapes are compatible. + if (r1.shape[0] != 1) and (r2.shape[0] != 1): + if r1.shape[0] != r2.shape[0]: + raise ValueError(f"r1 and r2 are not broadcastable.") dtype = np.dtype(dtype or r1.dtype) - r = r1 @ r2.T - tr_r = np.trace(r, dtype=dtype) - if abs(tr_r - 3.0) <= np.finfo(dtype).resolution: - dist = 0 - else: - theta = (tr_r - 1) / 2 - theta = max(min(theta, 1), -1) # Clamp theta in [-1,1] - dist = np.arccos(theta, dtype=dtype) + r2_T = np.transpose(r2, axes=(0, 2, 1)) + r = r1 @ r2_T + tr_r = np.trace(r, axis1=1, axis2=2, dtype=dtype) + + dist = np.zeros(max(len(r1), len(r2)), dtype=dtype) + non_zero_dist_ind = np.where(abs(tr_r - 3.0) > np.finfo(tr_r.dtype).resolution) + + theta = (tr_r[non_zero_dist_ind] - 1) / 2 + theta = np.maximum(np.minimum(theta, 1), -1) # Clamp theta in [-1,1] + dist[non_zero_dist_ind] = np.arccos(theta, dtype=dtype) return dist @staticmethod def mean_angular_distance(rots_1, rots_2, dtype=None): """ - Find the mean angular distance between two sets of rotation matrices. + Find the mean angular distance between two sets of rotation matrices. Computations + will broadcast in the case of a singleton rotations and a set of rotations. - :param rots_1: An nx3x3 array of rotation matrices. - :param rots_2: An nx3x3 array of rotation matrices. + :param rots_1: An nx3x3 array of rotation matrices, a singleton rotation, + or a `Rotation` object. + :param rots_2: An nx3x3 array of rotation matrices, a singleton rotation, + or a `Rotation` object. :param dtype: Data type for computation. Default infers dtype from `rots_1`. :return: The mean angular distance between rotations in radians. @@ -406,16 +427,7 @@ def mean_angular_distance(rots_1, rots_2, dtype=None): dtype = np.dtype(dtype or rots_1.dtype) - if rots_1.shape != rots_2.shape: - raise ValueError("`rots_1` and `rots_2` must have the same shape.") - - n_rots = len(rots_1) - ang_dist = np.zeros(n_rots, dtype=dtype) - for i in range(n_rots): - ang_dist[i] = Rotation.angle_dist( - rots_1[i], - rots_2[i], - dtype=dtype, - ) + # Shapes will be checked by `angle_dist`. + ang_dist = Rotation.angle_dist(rots_1, rots_2, dtype=dtype) return np.mean(ang_dist) diff --git a/tests/test_rotation.py b/tests/test_rotation.py index 6abdc4d5fc..caffb17f0c 100644 --- a/tests/test_rotation.py +++ b/tests/test_rotation.py @@ -2,6 +2,7 @@ from unittest import TestCase import numpy as np +import pytest from scipy.spatial.transform import Rotation as sp_rot from aspire.utils import Rotation, utest_tolerance @@ -108,21 +109,6 @@ def testSetterGetter(self): def testDtype(self): self.assertTrue(self.dtype == self.rot_obj.dtype) - def testAngleDist(self): - angles = np.array([i * np.pi / 360 for i in range(360)], dtype=self.dtype) - rots = Rotation.about_axis("x", angles, dtype=self.dtype) - - # Calculate the angular distance between the identity, rots[0], - # and rotations by multiples of pi/360 about the x-axis. - # These should be equal to `angles`. - angular_dist = np.zeros(360, dtype=self.dtype) - for i, rot in enumerate(rots): - angular_dist[i] = Rotation.angle_dist(rots[0], rot, self.dtype) - - self.assertTrue( - np.allclose(angles, angular_dist, atol=utest_tolerance(self.dtype)) - ) - def testFromRotvec(self): # Build random rotation vectors. axis = np.array([1, 0, 0], dtype=self.dtype) @@ -138,6 +124,22 @@ def testFromRotvec(self): self.assertTrue(np.allclose(rotations.matrices, ref_rots.matrices)) +def test_angle_dist(): + dtype = np.float32 + angles = np.array([i * np.pi / 360 for i in range(360)], dtype=dtype) + rots = Rotation.about_axis("x", angles, dtype=dtype) + + # Calculate the angular distance between the identity, rots[0], + # and rotations by multiples of pi/360 about the x-axis. + # These should be equal to `angles`. + angular_dist = Rotation.angle_dist(rots[0], rots, dtype) + assert np.allclose(angles, angular_dist, atol=utest_tolerance(dtype)) + + # Test incompatible shape error. + with pytest.raises(ValueError, match=r"r1 and r2 are not broadcastable*"): + _ = Rotation.angle_dist(rots[:3], rots[:5]) + + def test_mean_angular_distance(): rots_z = Rotation.about_axis( "z", [0, np.pi / 4, np.pi / 2], dtype=np.float32 From 507c337c9a7a4aa37d5eb847b02b4f618de78066 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 28 Aug 2023 15:28:16 -0400 Subject: [PATCH 51/53] tox --- src/aspire/utils/rotation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/utils/rotation.py b/src/aspire/utils/rotation.py index 11d5f71e29..08bec4ca3d 100644 --- a/src/aspire/utils/rotation.py +++ b/src/aspire/utils/rotation.py @@ -394,7 +394,7 @@ def angle_dist(r1, r2, dtype=None): # Check that shapes are compatible. if (r1.shape[0] != 1) and (r2.shape[0] != 1): if r1.shape[0] != r2.shape[0]: - raise ValueError(f"r1 and r2 are not broadcastable.") + raise ValueError(r"r1 and r2 are not broadcastable.*") dtype = np.dtype(dtype or r1.dtype) From 6148daab79faaa49e10ffac88f530e9a956a2846 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 29 Aug 2023 09:10:15 -0400 Subject: [PATCH 52/53] full ~~> half_to_full. --- src/aspire/abinitio/commonline_c3_c4.py | 5 +++-- src/aspire/abinitio/commonline_cn.py | 3 ++- src/aspire/operators/polar_ft.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index 59f2599b07..7ddfbde7cd 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -4,6 +4,7 @@ from numpy.linalg import eigh, norm, svd from aspire.abinitio import CLOrient3D, SyncVotingMixin +from aspire.operators import PolarFT from aspire.utils import ( J_conjugate, Rotation, @@ -294,7 +295,7 @@ def _estimate_inplane_rotations(self, vis): # Reconstruct the full polar Fourier for use in correlation. self.pf only consists of # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). - pf = np.concatenate((pf, np.conj(pf)), axis=1) + pf = PolarFT.half_to_full(pf) # Normalize rays. pf /= norm(pf, axis=-1)[..., np.newaxis] @@ -444,7 +445,7 @@ def _self_clmatrix_c3_c4(self): # Reconstruct the full polar Fourier for use in correlation. self.pf only consists of # rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1). - pf_full = np.concatenate((pf, np.conj(pf)), axis=1) + pf_full = PolarFT.half_to_full(pf) # The self-common-lines matrix holds two indices per image that represent # the two self common-lines in the image. diff --git a/src/aspire/abinitio/commonline_cn.py b/src/aspire/abinitio/commonline_cn.py index abdfd1e62d..d267a367c1 100644 --- a/src/aspire/abinitio/commonline_cn.py +++ b/src/aspire/abinitio/commonline_cn.py @@ -4,6 +4,7 @@ from numpy.linalg import norm from aspire.abinitio import CLSymmetryC3C4 +from aspire.operators import PolarFT from aspire.utils import ( J_conjugate, Rotation, @@ -123,7 +124,7 @@ def _estimate_relative_viewing_directions(self): # Reconstruct full polar Fourier for use in correlation. pf /= norm(pf, axis=2)[..., np.newaxis] # Normalize each ray. - pf_full = np.concatenate((pf, np.conj(pf)), axis=1) + pf_full = PolarFT.half_to_full(pf) # Pre-compute shifted pf's. pf_shifted = (pf * shift_phases[:, None, None]).swapaxes(0, 1) diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index 2da62b00b6..48b08c4c0c 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -126,7 +126,7 @@ def transform(self, x): return pf.reshape(*stack_shape, self.ntheta // 2, self.nrad) @staticmethod - def full(pf): + def half_to_full(pf): """ Use the conjugate symmetry of pf to construct the full polar Fourier transform over all rays in [0, 360). From 6bfb1d0faab6e5777443696b7ea12e5e774fb248 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 29 Aug 2023 09:32:58 -0400 Subject: [PATCH 53/53] Oooops forgot to change the test. --- tests/test_polar_ft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_polar_ft.py b/tests/test_polar_ft.py index c58df20136..425d5e14bb 100644 --- a/tests/test_polar_ft.py +++ b/tests/test_polar_ft.py @@ -197,7 +197,7 @@ def test_theta_error(): @pytest.mark.parametrize("stack_shape", [(5,), (2, 3)]) -def test_full_transform(stack_shape): +def test_half_to_full_transform(stack_shape): """ Test conjugate symmetry and shape of the full polar Fourier transform. """ @@ -207,7 +207,7 @@ def test_full_transform(stack_shape): ) pft = PolarFT(size=img_size) pf = pft.transform(image) - full_pf = pft.full(pf) + full_pf = pft.half_to_full(pf) # Check shape. assert full_pf.shape == (*stack_shape, pft.ntheta, pft.nrad)