diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 33f2a600e9..c0c3718803 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -274,8 +274,9 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000): show = False if logging.getLogger().isEnabledFor(logging.DEBUG): show = True - # Negative sign comes from using -i conversion of Fourier transformation - est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0] + + # Estimate shifts. + est_shifts = sparse.linalg.lsqr(shift_equations, shift_b, show=show)[0] est_shifts = est_shifts.reshape((self.n_img, 2)) return est_shifts @@ -320,15 +321,16 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): n_equations = self._estimate_num_shift_equations( n_img, equations_factor, max_memory ) + # Allocate local variables for estimating 2D shifts based on the estimated number # of equations. The shift equations are represented using a sparse matrix, # since each row in the system contains four non-zeros (as it involves # exactly four unknowns). The variables below are used to construct # this sparse system. The k'th non-zero element of the equations matrix # is stored at index (shift_i(k),shift_j(k)). - shift_i = np.zeros(4 * n_equations, dtype=self.dtype) - shift_j = np.zeros(4 * n_equations, dtype=self.dtype) - shift_eq = np.zeros(4 * n_equations, dtype=self.dtype) + shift_i = np.zeros((n_equations, 4), dtype=self.dtype) + shift_j = np.zeros((n_equations, 4), dtype=self.dtype) + shift_eq = np.zeros((n_equations, 4), dtype=self.dtype) shift_b = np.zeros(n_equations, dtype=self.dtype) # Prepare the shift phases to try and generate filter for common-line detection @@ -388,33 +390,33 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): sidx = sidx1 if c1[sidx1] > c2[sidx2] else sidx2 dx = -max_shift + sidx * shift_step - # Create a shift equation for the image pair [i,j] - idx = np.arange(4 * shift_eq_idx, 4 * shift_eq_idx + 4) # angle of common ray in image i shift_alpha = c_ij * d_theta # Angle of common ray in image j. shift_beta = c_ji * d_theta # Row index to construct the sparse equations - shift_i[idx] = shift_eq_idx + shift_i[shift_eq_idx] = shift_eq_idx # Columns of the shift variables that correspond to the current pair [i, j] - shift_j[idx] = [2 * i, 2 * i + 1, 2 * j, 2 * j + 1] + shift_j[shift_eq_idx] = [2 * i, 2 * i + 1, 2 * j, 2 * j + 1] # Right hand side of the current equation shift_b[shift_eq_idx] = dx # Compute the coefficients of the current equation coefs = np.array( [ - np.sin(shift_alpha), np.cos(shift_alpha), - -np.sin(shift_beta), + np.sin(shift_alpha), -np.cos(shift_beta), + -np.sin(shift_beta), ] ) - shift_eq[idx] = -1 * coefs if is_pf_j_flipped else coefs + shift_eq[shift_eq_idx] = ( + [-1, -1, 0, 0] * coefs if is_pf_j_flipped else coefs + ) # create sparse matrix object only containing non-zero elements shift_equations = sparse.csr_matrix( - (shift_eq, (shift_i, shift_j)), + (shift_eq.flatten(), (shift_i.flatten(), shift_j.flatten())), shape=(n_equations, 2 * n_img), dtype=self.dtype, ) diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index f17c23a0f0..b33098d3b4 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -326,7 +326,7 @@ def common_line_from_rots(r1, r2, ell): ut = np.dot(r2, r1.T) alpha_ij = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi - alpha_ji = np.arctan2(ut[0, 2], -ut[1, 2]) + np.pi + alpha_ji = np.arctan2(-ut[0, 2], ut[1, 2]) + np.pi ell_ij = alpha_ij * ell / (2 * np.pi) ell_ji = alpha_ji * ell / (2 * np.pi) diff --git a/tests/test_coef.py b/tests/test_coef.py index ab546d9bac..3ace0ddec5 100644 --- a/tests/test_coef.py +++ b/tests/test_coef.py @@ -48,7 +48,7 @@ def dtype(request): return request.param -@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +@pytest.fixture(params=DTYPES, ids=lambda x: f"basis_dtype={x}", scope="module") def basis_dtype(request): """ Dtypes for basis @@ -416,11 +416,16 @@ def test_shifts(coef_fixture, basis, rots): shifts = np.column_stack((rots, rots[::-1])) # Compare + min_dtype = ( + np.float32 + if (basis.dtype == np.float32 or coef_fixture.dtype == np.float32) + else np.float64 + ) np.testing.assert_allclose( coef_fixture.shift(shifts), basis.shift(coef_fixture, shifts), rtol=1e-05, - atol=utest_tolerance(basis.dtype), + atol=utest_tolerance(min_dtype), ) diff --git a/tests/test_commonline_sync3n.py b/tests/test_commonline_sync3n.py index 6640fa871f..600f883c2d 100644 --- a/tests/test_commonline_sync3n.py +++ b/tests/test_commonline_sync3n.py @@ -1,3 +1,4 @@ +import copy import os import numpy as np @@ -17,6 +18,7 @@ OFFSETS = [ 0, + pytest.param(None, marks=pytest.mark.expensive), ] DTYPES = [ @@ -53,7 +55,17 @@ def source_orientation_objs(resolution, offsets, dtype): seed=456, ).cache() - orient_est = CLSync3N(src, S_weighting=True, seed=789) + # Search for common lines over less shifts for 0 offsets. + max_shift = 1 / resolution + shift_step = 1 + if src.offsets.all() != 0: + max_shift = 0.20 + shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. + + orient_est = CLSync3N(src, max_shift=max_shift, shift_step=shift_step, seed=789) + + # Estimate rotations once for all tests. + orient_est.estimate_rotations() return src, orient_est @@ -61,9 +73,6 @@ def source_orientation_objs(resolution, offsets, dtype): def test_build_clmatrix(source_orientation_objs): src, orient_est = source_orientation_objs - # Build clmatrix estimate. - orient_est.build_clmatrix() - gt_clmatrix = rots_to_clmatrix(src.rotations, orient_est.n_theta) angle_diffs = abs(orient_est.clmatrix - gt_clmatrix) * 360 / orient_est.n_theta @@ -74,17 +83,56 @@ def test_build_clmatrix(source_orientation_objs): # Check that at least 98% of estimates are within 5 degrees. tol = 0.98 if src.offsets.all() != 0: - # Set tolerance to 95% when using nonzero offsets. - tol = 0.95 + # Set tolerance to 75% when using nonzero offsets. + tol = 0.75 assert within_5 / angle_diffs.size > tol -def test_estimate_rotations(source_orientation_objs): +def test_estimate_shifts_with_gt_rots(source_orientation_objs): src, orient_est = source_orientation_objs - orient_est.estimate_rotations() + # Assign ground truth rotations. + # Deep copy to prevent altering for other tests. + orient_est = copy.deepcopy(orient_est) + orient_est.rotations = src.rotations + + # Estimate shifts using ground truth rotations. + est_shifts = orient_est.estimate_shifts() + + # Calculate the mean 2D distance between estimates and ground truth. + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() + + # Assert that on average estimated shifts are close (within 0.8 pix) to src.offsets + if src.offsets.all() != 0: + np.testing.assert_array_less(mean_dist, 0.8) + else: + np.testing.assert_allclose(mean_dist, 0) + + +def test_estimate_shifts_with_est_rots(source_orientation_objs): + src, orient_est = source_orientation_objs + # Estimate shifts using estimated rotations. + est_shifts = orient_est.estimate_shifts() + + # Calculate the mean 2D distance between estimates and ground truth. + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() + + # Assert that on average estimated shifts are close (within 0.8 pix) to src.offsets + if src.offsets.all() != 0: + np.testing.assert_array_less(mean_dist, 0.8) + else: + np.testing.assert_allclose(mean_dist, 0) + + +def test_estimate_rotations(source_orientation_objs): + src, orient_est = source_orientation_objs # Register estimates to ground truth rotations and compute the # mean angular distance between them (in degrees). - # Assert that mean angular distance is less than 1 degree. - mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1) + # Assert that mean angular distance is less than 1 degree (4 with offsets). + tol = 1 + if src.offsets.all() != 0: + tol = 4 + mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=tol) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 3e875e9467..4bea8df6c2 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -1,3 +1,4 @@ +import copy import os import os.path import tempfile @@ -32,22 +33,22 @@ ] -@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}") +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module") def resolution(request): return request.param -@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}") +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}", scope="module") def offsets(request): return request.param -@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") def dtype(request): return request.param -@pytest.fixture +@pytest.fixture(scope="module") def source_orientation_objs(resolution, offsets, dtype): src = Simulation( n=50, @@ -68,6 +69,9 @@ def source_orientation_objs(resolution, offsets, dtype): src, max_shift=max_shift, shift_step=shift_step, mask=False ) + # Estimate rotations once for all tests. + orient_est.estimate_rotations() + return src, orient_est @@ -96,23 +100,49 @@ def test_build_clmatrix(source_orientation_objs): def test_estimate_rotations(source_orientation_objs): src, orient_est = source_orientation_objs - orient_est.estimate_rotations() - # Register estimates to ground truth rotations and compute the # mean angular distance between them (in degrees). # Assert that mean angular distance is less than 1 degree. mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1) -def test_estimate_shifts(source_orientation_objs): +def test_estimate_shifts_with_gt_rots(source_orientation_objs): src, orient_est = source_orientation_objs + + # Assign ground truth rotations. + # Deep copy to prevent altering for other tests. + orient_est = copy.deepcopy(orient_est) + orient_est.rotations = src.rotations + + # Estimate shifts using ground truth rotations. + est_shifts = orient_est.estimate_shifts() + + # Calculate the mean 2D distance between estimates and ground truth. + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() + + # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets if src.offsets.all() != 0: - pytest.xfail("Currently failing under non-zero offsets.") + np.testing.assert_array_less(mean_dist, 0.5) + else: + np.testing.assert_allclose(mean_dist, 0) + +def test_estimate_shifts_with_est_rots(source_orientation_objs): + src, orient_est = source_orientation_objs + + # Estimate shifts using estimated rotations. est_shifts = orient_est.estimate_shifts() - # Assert that estimated shifts are close to src.offsets - assert np.allclose(est_shifts, src.offsets) + # Calculate the mean 2D distance between estimates and ground truth. + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() + + # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets + if src.offsets.all() != 0: + np.testing.assert_array_less(mean_dist, 0.5) + else: + np.testing.assert_allclose(mean_dist, 0) def test_estimate_rotations_fuzzy_mask():