diff --git a/src/aspire/basis/fspca.py b/src/aspire/basis/fspca.py index 16500918af..c627a64e94 100644 --- a/src/aspire/basis/fspca.py +++ b/src/aspire/basis/fspca.py @@ -4,7 +4,6 @@ import numpy as np from aspire.basis import Coef, ComplexCoef, FFBBasis2D, SteerableBasis2D -from aspire.covariance import BatchedRotCov2D from aspire.operators import BlkDiagMatrix from aspire.utils import complex_type, fix_signs, real_type @@ -147,6 +146,9 @@ def build(self): self.noise_var = WhiteNoiseEstimator(self.src).estimate() logger.info(f"Setting noise_var={self.noise_var}") + # Import BatchedRotCov2D here to prevent circular imports. + from aspire.covariance import BatchedRotCov2D + cov2d = BatchedRotCov2D( src=self.src, basis=self.basis, batch_size=self.batch_size ) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 4f97c22e7d..14c142bdc5 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -8,8 +8,7 @@ import matplotlib.pyplot as plt import numpy as np -from aspire.nufft import nufft -from aspire.numeric import fft +from aspire import nufft, numeric from aspire.utils import grid_2d, grid_3d logger = logging.getLogger(__name__) @@ -152,8 +151,8 @@ def _fft_correlations(self): ] # Compute centered Fourier transforms. - f1 = fft.centered_fftn(self.a, axes=self._fourier_axes) - f2 = fft.centered_fftn(self.b, axes=self._fourier_axes) + f1 = numeric.fft.centered_fftn(self.a, axes=self._fourier_axes) + f2 = numeric.fft.centered_fftn(self.b, axes=self._fourier_axes) # Construct an output table of correlations correlations = np.zeros( @@ -232,7 +231,7 @@ def _nufft_correlations(self): # Note, we want a complex result. signal = np.vstack((self._a, self._b)) # Compute one large NUFFT for all the signal frames, - f = nufft(signal, fourier_pts, real=False) + f = nufft.nufft(signal, fourier_pts, real=False) # then unpack as two 1D stacks of the polar grid points, one for _a and _b. f = f.reshape(self._a.shape[0] + self._b.shape[0], len(r), -1) f1, f2 = np.vsplit(f, [self._a.shape[0]])