Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,17 +591,21 @@ def filter(self, filter):
# `xp.asarray` because all of the subsequent calls until
# `asnumpy` are GPU when xp and fft in `cupy` mode.
#
# Second note, filter dtype may not match image dtype.
# Second note, filter and grid dtype may not match image dtype,
# upcast both here for most accurate convolution.
filter_values = xp.asarray(
filter.evaluate_grid(self.resolution), dtype=self.dtype
filter.evaluate_grid(self.resolution, dtype=np.float64), dtype=np.float64
)

# Convolve
im_f = fft.centered_fft2(xp.asarray(im._data))
_im = xp.asarray(im._data, dtype=np.float64)
im_f = fft.centered_fft2(_im)
im_f = filter_values * im_f
im = fft.centered_ifft2(im_f)

im = xp.asnumpy(im.real)
im = xp.asnumpy(im.real).astype(
self.dtype, copy=False
) # restore to original dtype

return self.__class__(im, pixel_size=self.pixel_size).stack_reshape(
original_stack_shape
Expand Down
78 changes: 47 additions & 31 deletions src/aspire/operators/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scipy.interpolate import RegularGridInterpolator

from aspire import config
from aspire.utils import grid_2d, voltage_to_wavelength
from aspire.utils import cart2pol, grid_2d, voltage_to_wavelength

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -406,6 +406,13 @@ def __init__(self, dim=None):


class CTFFilter(Filter):
"""
Reproduce MATLAB's cryo_CTF_relion CTF (Contrast Transfer Function) Filter

Note if comparing to legacy MATLAB cryo_CTF_Relion,
take care regarding defocus unit conversion to/from nm.
"""

def __init__(
self,
pixel_size=1,
Expand Down Expand Up @@ -448,39 +455,48 @@ def __init__(
self._defocus_diff_nm = 0.05 * (self.defocus_u - self.defocus_v)

def _evaluate(self, omega):
# Note the grid is wrt nm.
om_y, om_x = np.vsplit(omega / (2 * np.pi * self.pixel_size / 10), 2)

eps = np.finfo(np.pi).eps
ind_nz = (np.abs(om_x) > eps) | (np.abs(om_y) > eps)
angles_nz = np.arctan2(om_y[ind_nz], om_x[ind_nz])
angles_nz -= self.defocus_ang

defocus = np.zeros_like(om_x)
# Note the division by 2 for _defocus_diff_nm is in `__init__`.
defocus[ind_nz] = self._defocus_mean_nm + self._defocus_diff_nm * np.cos(
2 * angles_nz
)

# Note lambda must be in nm, and `Cs` must be converted from mm to nm.
lambda_nm = self.wavelength / 10
c2 = -np.pi * lambda_nm * defocus
c4 = 0.5 * np.pi * (self.Cs * 1e6) * lambda_nm**3

r2 = om_x**2 + om_y**2
r4 = r2**2
gamma = c2 * r2 + c4 * r4
h = np.sqrt(1 - self.alpha**2) * np.sin(gamma) - self.alpha * np.cos(gamma)

# For historical reference, below is a translated formula from the legacy MATLAB code.
# The two implementations seem to agree for odd images, but the original MATLAB code
# behaves differently for even image sizes.
# h = np.sin(c2*r2 + c4*r2*r2 - self.alpha)
# Reference MATLAB code, includes reference to paper
# Mindell, J. A.; Grigorieff, N. (2003).
# https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/projections/cryo_CTF_Relion.m#L34
#
# s, theta should match MATLAB's RadiusNorm up to a transpose
# To accomplish this given ASPIRE-Python's default `omega` grid,
# we unpack and remove the pi scaling,
# and further rescale the radii `s` by half below.
#
# Additionally we upcast so downstream computations remain in doubles.
x, y = omega.astype(np.float64, copy=False) / np.pi

# Returns radii such that when multiplied by the
# bandwidth of the signal, we get the correct radial frequencies
# corresponding to each pixel in our nxn grid.
theta, s = cart2pol(x, y)
s = s / 2

# Wavelength in nm.
lamb = 1.22639 / np.sqrt(self.voltage * 1000 + 0.97845 * self.voltage**2)

# Divide by 10 to make pixel size in nm. BW is the
# bandwidth of the signal corresponding to the given pixel size.
BW = 1 / (self.pixel_size / 10)

s = s * BW
DFavg = self._defocus_mean_nm # (DefocusU+DefocusV)/2
DFdiff = self._defocus_diff_nm # (DefocusU-DefocusV)
# Note division by 2 is pre-computed in _defocus_diff_nm
df = DFavg + DFdiff * np.cos(2 * (theta - self.defocus_ang))

k2 = np.pi * lamb * df
# 10*6 converts Cs from mm to nm.
k4 = np.pi / 2 * 10**6 * self.Cs * lamb**3
chi = k4 * s**4 - k2 * s**2

h = np.sqrt(1 - self.alpha**2) * np.sin(chi) - self.alpha * np.cos(chi)

if self.B:
h *= np.exp(-self.B * r2)
h *= np.exp(-self.B * s**2)

return h.squeeze()
return h

def scale(self, c=1):
return CTFFilter(
Expand Down
1 change: 1 addition & 0 deletions src/aspire/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .types import complex_type, real_type, utest_tolerance # isort:skip
from .coor_trans import ( # isort:skip
mean_aligned_angular_distance,
cart2pol,
crop_pad_2d,
crop_pad_3d,
grid_1d,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_covar2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ def test_get_covar_ctf(cov2d_fixture, ctf_enabled):

covar_coef_ctf = cov2d.get_covar(coef, h_ctf_fb, h_idx, noise_var=NOISE_VAR)
for im, mat in enumerate(results.tolist()):
np.testing.assert_allclose(mat, covar_coef_ctf[im], rtol=1e-05, atol=1e-08)
# These tolerances were adjusted slightly (1e-8 to 3e-8) to accomodate MATLAB CTF repro changes
np.testing.assert_allclose(mat, covar_coef_ctf[im], rtol=3e-05, atol=3e-08)


def test_get_covar_ctf_shrink(cov2d_fixture, ctf_enabled):
Expand Down
Loading
Loading