diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 185b9f798c..63f6a8b45e 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -591,17 +591,21 @@ def filter(self, filter): # `xp.asarray` because all of the subsequent calls until # `asnumpy` are GPU when xp and fft in `cupy` mode. # - # Second note, filter dtype may not match image dtype. + # Second note, filter and grid dtype may not match image dtype, + # upcast both here for most accurate convolution. filter_values = xp.asarray( - filter.evaluate_grid(self.resolution), dtype=self.dtype + filter.evaluate_grid(self.resolution, dtype=np.float64), dtype=np.float64 ) # Convolve - im_f = fft.centered_fft2(xp.asarray(im._data)) + _im = xp.asarray(im._data, dtype=np.float64) + im_f = fft.centered_fft2(_im) im_f = filter_values * im_f im = fft.centered_ifft2(im_f) - im = xp.asnumpy(im.real) + im = xp.asnumpy(im.real).astype( + self.dtype, copy=False + ) # restore to original dtype return self.__class__(im, pixel_size=self.pixel_size).stack_reshape( original_stack_shape diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index b20247861b..d3788d0d1f 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -6,7 +6,7 @@ from scipy.interpolate import RegularGridInterpolator from aspire import config -from aspire.utils import grid_2d, voltage_to_wavelength +from aspire.utils import cart2pol, grid_2d, voltage_to_wavelength logger = logging.getLogger(__name__) @@ -406,6 +406,13 @@ def __init__(self, dim=None): class CTFFilter(Filter): + """ + Reproduce MATLAB's cryo_CTF_relion CTF (Contrast Transfer Function) Filter + + Note if comparing to legacy MATLAB cryo_CTF_Relion, + take care regarding defocus unit conversion to/from nm. + """ + def __init__( self, pixel_size=1, @@ -448,39 +455,48 @@ def __init__( self._defocus_diff_nm = 0.05 * (self.defocus_u - self.defocus_v) def _evaluate(self, omega): - # Note the grid is wrt nm. - om_y, om_x = np.vsplit(omega / (2 * np.pi * self.pixel_size / 10), 2) - - eps = np.finfo(np.pi).eps - ind_nz = (np.abs(om_x) > eps) | (np.abs(om_y) > eps) - angles_nz = np.arctan2(om_y[ind_nz], om_x[ind_nz]) - angles_nz -= self.defocus_ang - - defocus = np.zeros_like(om_x) - # Note the division by 2 for _defocus_diff_nm is in `__init__`. - defocus[ind_nz] = self._defocus_mean_nm + self._defocus_diff_nm * np.cos( - 2 * angles_nz - ) - - # Note lambda must be in nm, and `Cs` must be converted from mm to nm. - lambda_nm = self.wavelength / 10 - c2 = -np.pi * lambda_nm * defocus - c4 = 0.5 * np.pi * (self.Cs * 1e6) * lambda_nm**3 - - r2 = om_x**2 + om_y**2 - r4 = r2**2 - gamma = c2 * r2 + c4 * r4 - h = np.sqrt(1 - self.alpha**2) * np.sin(gamma) - self.alpha * np.cos(gamma) - - # For historical reference, below is a translated formula from the legacy MATLAB code. - # The two implementations seem to agree for odd images, but the original MATLAB code - # behaves differently for even image sizes. - # h = np.sin(c2*r2 + c4*r2*r2 - self.alpha) + # Reference MATLAB code, includes reference to paper + # Mindell, J. A.; Grigorieff, N. (2003). + # https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/projections/cryo_CTF_Relion.m#L34 + # + # s, theta should match MATLAB's RadiusNorm up to a transpose + # To accomplish this given ASPIRE-Python's default `omega` grid, + # we unpack and remove the pi scaling, + # and further rescale the radii `s` by half below. + # + # Additionally we upcast so downstream computations remain in doubles. + x, y = omega.astype(np.float64, copy=False) / np.pi + + # Returns radii such that when multiplied by the + # bandwidth of the signal, we get the correct radial frequencies + # corresponding to each pixel in our nxn grid. + theta, s = cart2pol(x, y) + s = s / 2 + + # Wavelength in nm. + lamb = 1.22639 / np.sqrt(self.voltage * 1000 + 0.97845 * self.voltage**2) + + # Divide by 10 to make pixel size in nm. BW is the + # bandwidth of the signal corresponding to the given pixel size. + BW = 1 / (self.pixel_size / 10) + + s = s * BW + DFavg = self._defocus_mean_nm # (DefocusU+DefocusV)/2 + DFdiff = self._defocus_diff_nm # (DefocusU-DefocusV) + # Note division by 2 is pre-computed in _defocus_diff_nm + df = DFavg + DFdiff * np.cos(2 * (theta - self.defocus_ang)) + + k2 = np.pi * lamb * df + # 10*6 converts Cs from mm to nm. + k4 = np.pi / 2 * 10**6 * self.Cs * lamb**3 + chi = k4 * s**4 - k2 * s**2 + + h = np.sqrt(1 - self.alpha**2) * np.sin(chi) - self.alpha * np.cos(chi) if self.B: - h *= np.exp(-self.B * r2) + h *= np.exp(-self.B * s**2) - return h.squeeze() + return h def scale(self, c=1): return CTFFilter( diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index b7924f0fb6..ac6c4d2721 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -1,6 +1,7 @@ from .types import complex_type, real_type, utest_tolerance # isort:skip from .coor_trans import ( # isort:skip mean_aligned_angular_distance, + cart2pol, crop_pad_2d, crop_pad_3d, grid_1d, diff --git a/tests/test_covar2d.py b/tests/test_covar2d.py index 6ec5bf0b14..f27bc61855 100644 --- a/tests/test_covar2d.py +++ b/tests/test_covar2d.py @@ -292,7 +292,8 @@ def test_get_covar_ctf(cov2d_fixture, ctf_enabled): covar_coef_ctf = cov2d.get_covar(coef, h_ctf_fb, h_idx, noise_var=NOISE_VAR) for im, mat in enumerate(results.tolist()): - np.testing.assert_allclose(mat, covar_coef_ctf[im], rtol=1e-05, atol=1e-08) + # These tolerances were adjusted slightly (1e-8 to 3e-8) to accomodate MATLAB CTF repro changes + np.testing.assert_allclose(mat, covar_coef_ctf[im], rtol=3e-05, atol=3e-08) def test_get_covar_ctf_shrink(cov2d_fixture, ctf_enabled): diff --git a/tests/test_filters.py b/tests/test_filters.py index bf0bbbc2cb..6b84e29be4 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -22,6 +22,8 @@ class SimTestCase(TestCase): + test_filter = ArrayFilter(np.random.randn(8, 8)) + def setUp(self): self.dtype = np.float32 # A 2 x 256 ndarray of spatial frequencies @@ -93,249 +95,39 @@ def testCTFFilter(self): self.assertEqual(result.shape, (256,)) def testScaledFilter(self): - filt1 = CTFFilter(defocus_u=1.5e4, defocus_v=1.5e4) scale_value = 2.5 - result1 = filt1.evaluate(self.omega) + result1 = self.test_filter.evaluate(self.omega) # ScaledFilter scales the pixel size which cancels out # a corresponding scaling in omega - filt2 = ScaledFilter(filt1, scale_value) + filt2 = ScaledFilter(self.test_filter, scale_value) result2 = filt2.evaluate(self.omega * scale_value) self.assertTrue(np.allclose(result1, result2, atol=utest_tolerance(self.dtype))) - def testCTFScale(self): - filt = CTFFilter(defocus_u=1.5e4, defocus_v=1.5e4) - result1 = filt.evaluate(self.omega) - scale_value = 2.5 - filt = filt.scale(scale_value) - # scaling a CTFFilter scales the pixel size which cancels out - # a corresponding scaling in omega - result2 = filt.evaluate(self.omega * scale_value) - self.assertTrue(np.allclose(result1, result2, atol=utest_tolerance(self.dtype))) - def testRadialCTFFilter(self): filter = RadialCTFFilter(defocus=2.5e4) result = filter.evaluate(self.omega) self.assertEqual(result.shape, (256,)) - def testRadialCTFFilterGrid(self): - # Set legacy pixel size - filter = RadialCTFFilter(pixel_size=10, defocus=2.5e4) - result = filter.evaluate_grid(8, dtype=self.dtype) - - self.assertEqual(result.shape, (8, 8)) - - # Setting tolerence to 1e-4. - # After precision was improved on `voltage_to_wavelength` method this reference array - # is no longer within utest_tolerance: np.max(abs(result - reference)) = 5.2729227306036464e-05 - self.assertTrue( - np.allclose( - result, - np.array( - [ - [ - 0.461755701877834, - -0.995184514498978, - 0.063120922443392, - 0.833250206225063, - 0.961464660252150, - 0.833250206225063, - 0.063120922443392, - -0.995184514498978, - ], - [ - -0.995184514498978, - 0.626977423649552, - 0.799934516166400, - 0.004814348317439, - -0.298096205735759, - 0.004814348317439, - 0.799934516166400, - 0.626977423649552, - ], - [ - 0.063120922443392, - 0.799934516166400, - -0.573061561512667, - -0.999286510416273, - -0.963805291282899, - -0.999286510416273, - -0.573061561512667, - 0.799934516166400, - ], - [ - 0.833250206225063, - 0.004814348317439, - -0.999286510416273, - -0.633095739808868, - -0.368890743119366, - -0.633095739808868, - -0.999286510416273, - 0.004814348317439, - ], - [ - 0.961464660252150, - -0.298096205735759, - -0.963805291282899, - -0.368890743119366, - -0.070000000000000, - -0.368890743119366, - -0.963805291282899, - -0.298096205735759, - ], - [ - 0.833250206225063, - 0.004814348317439, - -0.999286510416273, - -0.633095739808868, - -0.368890743119366, - -0.633095739808868, - -0.999286510416273, - 0.004814348317439, - ], - [ - 0.063120922443392, - 0.799934516166400, - -0.573061561512667, - -0.999286510416273, - -0.963805291282899, - -0.999286510416273, - -0.573061561512667, - 0.799934516166400, - ], - [ - -0.995184514498978, - 0.626977423649552, - 0.799934516166400, - 0.004814348317439, - -0.298096205735759, - 0.004814348317439, - 0.799934516166400, - 0.626977423649552, - ], - ] - ), - atol=1e-4, - ) - ) - - def testRadialCTFFilterMultiplierGrid(self): - # Set legacy pixel size - filter = RadialCTFFilter(pixel_size=10, defocus=2.5e4) * RadialCTFFilter( - pixel_size=10, defocus=2.5e4 - ) - result = filter.evaluate_grid(8, dtype=self.dtype) - - self.assertEqual(result.shape, (8, 8)) - - # Setting tolerence to 1e-4. - # After precision was improved on `voltage_to_wavelength` method this reference array - # is no longer within utest_tolerance: np.max(abs(result - reference)) = 4.869387449749074e-05 - self.assertTrue( - np.allclose( - result, - np.array( - [ - [ - 0.461755701877834, - -0.995184514498978, - 0.063120922443392, - 0.833250206225063, - 0.961464660252150, - 0.833250206225063, - 0.063120922443392, - -0.995184514498978, - ], - [ - -0.995184514498978, - 0.626977423649552, - 0.799934516166400, - 0.004814348317439, - -0.298096205735759, - 0.004814348317439, - 0.799934516166400, - 0.626977423649552, - ], - [ - 0.063120922443392, - 0.799934516166400, - -0.573061561512667, - -0.999286510416273, - -0.963805291282899, - -0.999286510416273, - -0.573061561512667, - 0.799934516166400, - ], - [ - 0.833250206225063, - 0.004814348317439, - -0.999286510416273, - -0.633095739808868, - -0.368890743119366, - -0.633095739808868, - -0.999286510416273, - 0.004814348317439, - ], - [ - 0.961464660252150, - -0.298096205735759, - -0.963805291282899, - -0.368890743119366, - -0.070000000000000, - -0.368890743119366, - -0.963805291282899, - -0.298096205735759, - ], - [ - 0.833250206225063, - 0.004814348317439, - -0.999286510416273, - -0.633095739808868, - -0.368890743119366, - -0.633095739808868, - -0.999286510416273, - 0.004814348317439, - ], - [ - 0.063120922443392, - 0.799934516166400, - -0.573061561512667, - -0.999286510416273, - -0.963805291282899, - -0.999286510416273, - -0.573061561512667, - 0.799934516166400, - ], - [ - -0.995184514498978, - 0.626977423649552, - 0.799934516166400, - 0.004814348317439, - -0.298096205735759, - 0.004814348317439, - 0.799934516166400, - 0.626977423649552, - ], - ] - ) - ** 2, - atol=1e-4, - ) - ) - def testDualFilter(self): - ctf_filter = CTFFilter(defocus_u=1.5e4, defocus_v=1.5e4) - result = ctf_filter.evaluate(-self.omega) - dual_filter = ctf_filter.dual() + result = self.test_filter.evaluate(-self.omega) + dual_filter = self.test_filter.dual() dual_result = dual_filter.evaluate(self.omega) self.assertTrue(np.allclose(result, dual_result)) def testFilterSigns(self): - ctf_filter = CTFFilter(defocus_u=1.5e4, defocus_v=1.5e4) - signs = np.sign(ctf_filter.evaluate(self.omega)) - sign_filter = ctf_filter.sign + signs = np.sign(self.test_filter.evaluate(self.omega)) + sign_filter = self.test_filter.sign self.assertTrue(np.allclose(sign_filter.evaluate(self.omega), signs)) +class SimTestCaseCTFFilter(SimTestCase): + """ + Covers same tests as SimTestCase, but use CTFFilter in place of ArrayFilter. + """ + + test_filter = CTFFilter() + + DTYPES = [np.float32, np.float64] EPS = [None, 0.01] @@ -419,6 +211,12 @@ def test_ctf_reference(): # Compare with MATLAB. Note DF converted to nm # >> n=5; V=200; DF1=1000; DF2=1500; theta=1.23; Cs=2.0; A=0.1; pxA=4.56; # >> ref_h=cryo_CTF_Relion(n,V,DF1,DF2,theta,Cs,pxA,A) + # + # Note we transpose the reference array. + # Python keeps the filter C order because the images we will convolve with are C order. + # MATLAB is F and F respectively. + # + # The floating point values were truncated to four decimal digits. ref_h = np.array( [ [-0.6152, 0.0299, -0.5638, 0.9327, 0.9736], @@ -427,8 +225,7 @@ def test_ctf_reference(): [0.1733, 0.9383, -0.7543, 0.2598, -0.9865], [0.9736, 0.9327, -0.5638, 0.0299, -0.6152], ] - ) + ).T - # Test we're within 1%. - # There are minor differences in the formulas for wavelength and grids. - np.testing.assert_allclose(h, ref_h, rtol=0.01) + # Test match all significant digits above + np.testing.assert_allclose(h, ref_h, atol=5e-5)