Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/aspire/basis/fspca.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np

from aspire.basis import Coef, ComplexCoef, FFBBasis2D, SteerableBasis2D
from aspire.covariance import BatchedRotCov2D
from aspire.operators import BlkDiagMatrix
from aspire.utils import complex_type, fix_signs, real_type

Expand Down Expand Up @@ -147,6 +146,9 @@ def build(self):
self.noise_var = WhiteNoiseEstimator(self.src).estimate()
logger.info(f"Setting noise_var={self.noise_var}")

# Import BatchedRotCov2D here to prevent circular imports.
from aspire.covariance import BatchedRotCov2D

cov2d = BatchedRotCov2D(
src=self.src, basis=self.basis, batch_size=self.batch_size
)
Expand Down
9 changes: 4 additions & 5 deletions src/aspire/utils/resolution_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import matplotlib.pyplot as plt
import numpy as np

from aspire.nufft import nufft
from aspire.numeric import fft
from aspire import nufft, numeric
from aspire.utils import grid_2d, grid_3d

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -152,8 +151,8 @@ def _fft_correlations(self):
]

# Compute centered Fourier transforms.
f1 = fft.centered_fftn(self.a, axes=self._fourier_axes)
f2 = fft.centered_fftn(self.b, axes=self._fourier_axes)
f1 = numeric.fft.centered_fftn(self.a, axes=self._fourier_axes)
f2 = numeric.fft.centered_fftn(self.b, axes=self._fourier_axes)

# Construct an output table of correlations
correlations = np.zeros(
Expand Down Expand Up @@ -232,7 +231,7 @@ def _nufft_correlations(self):
# Note, we want a complex result.
signal = np.vstack((self._a, self._b))
# Compute one large NUFFT for all the signal frames,
f = nufft(signal, fourier_pts, real=False)
f = nufft.nufft(signal, fourier_pts, real=False)
# then unpack as two 1D stacks of the polar grid points, one for _a and _b.
f = f.reshape(self._a.shape[0] + self._b.shape[0], len(r), -1)
f1, f2 = np.vsplit(f, [self._a.shape[0]])
Expand Down