From c4cc8c5a26eef968fa6f2e69e3e8010695cf3609 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 23 Sep 2024 15:48:04 -0400 Subject: [PATCH 1/7] minimal changes to estimate_shifts. --- src/aspire/abinitio/commonline_base.py | 12 +++++++----- src/aspire/utils/coor_trans.py | 2 +- tests/test_orient_sync_voting.py | 16 ++++++++++++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 33f2a600e9..0f7d675610 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -274,8 +274,9 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000): show = False if logging.getLogger().isEnabledFor(logging.DEBUG): show = True - # Negative sign comes from using -i conversion of Fourier transformation - est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0] + + # Estimate shifts. + est_shifts = sparse.linalg.lsqr(shift_equations, shift_b, show=show)[0] est_shifts = est_shifts.reshape((self.n_img, 2)) return est_shifts @@ -320,6 +321,7 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): n_equations = self._estimate_num_shift_equations( n_img, equations_factor, max_memory ) + # Allocate local variables for estimating 2D shifts based on the estimated number # of equations. The shift equations are represented using a sparse matrix, # since each row in the system contains four non-zeros (as it involves @@ -404,13 +406,13 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): # Compute the coefficients of the current equation coefs = np.array( [ - np.sin(shift_alpha), np.cos(shift_alpha), - -np.sin(shift_beta), + np.sin(shift_alpha), -np.cos(shift_beta), + -np.sin(shift_beta), ] ) - shift_eq[idx] = -1 * coefs if is_pf_j_flipped else coefs + shift_eq[idx] = [-1, -1, 0, 0] * coefs if is_pf_j_flipped else coefs # create sparse matrix object only containing non-zero elements shift_equations = sparse.csr_matrix( diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index f17c23a0f0..b33098d3b4 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -326,7 +326,7 @@ def common_line_from_rots(r1, r2, ell): ut = np.dot(r2, r1.T) alpha_ij = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi - alpha_ji = np.arctan2(ut[0, 2], -ut[1, 2]) + np.pi + alpha_ji = np.arctan2(-ut[0, 2], ut[1, 2]) + np.pi ell_ij = alpha_ij * ell / (2 * np.pi) ell_ji = alpha_ji * ell / (2 * np.pi) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 3e875e9467..49ffa04904 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -106,13 +106,21 @@ def test_estimate_rotations(source_orientation_objs): def test_estimate_shifts(source_orientation_objs): src, orient_est = source_orientation_objs - if src.offsets.all() != 0: - pytest.xfail("Currently failing under non-zero offsets.") + # Assign ground truth rotations. + orient_est.rotations = src.rotations + + # Estimate shifts using ground truth rotations. est_shifts = orient_est.estimate_shifts() - # Assert that estimated shifts are close to src.offsets - assert np.allclose(est_shifts, src.offsets) + # Calculate the mean absolute difference in pixels. + mean_abs_diff = np.mean(abs(src.offsets - est_shifts)) + + # Assert that on average estimated shifts are close to src.offsets + if src.offsets.all() != 0: + np.testing.assert_array_less(mean_abs_diff, 0.35) + else: + np.testing.assert_allclose(mean_abs_diff, 0) def test_estimate_rotations_fuzzy_mask(): From e53a8ec1bbd7a38a06ab333aad69182da4c7c01d Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 24 Sep 2024 10:52:39 -0400 Subject: [PATCH 2/7] clean up indexing --- src/aspire/abinitio/commonline_base.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index 0f7d675610..c0c3718803 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -328,9 +328,9 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): # exactly four unknowns). The variables below are used to construct # this sparse system. The k'th non-zero element of the equations matrix # is stored at index (shift_i(k),shift_j(k)). - shift_i = np.zeros(4 * n_equations, dtype=self.dtype) - shift_j = np.zeros(4 * n_equations, dtype=self.dtype) - shift_eq = np.zeros(4 * n_equations, dtype=self.dtype) + shift_i = np.zeros((n_equations, 4), dtype=self.dtype) + shift_j = np.zeros((n_equations, 4), dtype=self.dtype) + shift_eq = np.zeros((n_equations, 4), dtype=self.dtype) shift_b = np.zeros(n_equations, dtype=self.dtype) # Prepare the shift phases to try and generate filter for common-line detection @@ -390,16 +390,14 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): sidx = sidx1 if c1[sidx1] > c2[sidx2] else sidx2 dx = -max_shift + sidx * shift_step - # Create a shift equation for the image pair [i,j] - idx = np.arange(4 * shift_eq_idx, 4 * shift_eq_idx + 4) # angle of common ray in image i shift_alpha = c_ij * d_theta # Angle of common ray in image j. shift_beta = c_ji * d_theta # Row index to construct the sparse equations - shift_i[idx] = shift_eq_idx + shift_i[shift_eq_idx] = shift_eq_idx # Columns of the shift variables that correspond to the current pair [i, j] - shift_j[idx] = [2 * i, 2 * i + 1, 2 * j, 2 * j + 1] + shift_j[shift_eq_idx] = [2 * i, 2 * i + 1, 2 * j, 2 * j + 1] # Right hand side of the current equation shift_b[shift_eq_idx] = dx @@ -412,11 +410,13 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): -np.sin(shift_beta), ] ) - shift_eq[idx] = [-1, -1, 0, 0] * coefs if is_pf_j_flipped else coefs + shift_eq[shift_eq_idx] = ( + [-1, -1, 0, 0] * coefs if is_pf_j_flipped else coefs + ) # create sparse matrix object only containing non-zero elements shift_equations = sparse.csr_matrix( - (shift_eq, (shift_i, shift_j)), + (shift_eq.flatten(), (shift_i.flatten(), shift_j.flatten())), shape=(n_equations, 2 * n_img), dtype=self.dtype, ) From 70e2f829435193bf69239529553f4c9f034bb6f3 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 24 Sep 2024 11:16:44 -0400 Subject: [PATCH 3/7] use mean 2D distance in test. --- tests/test_orient_sync_voting.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 49ffa04904..576ca1cee1 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -113,14 +113,14 @@ def test_estimate_shifts(source_orientation_objs): # Estimate shifts using ground truth rotations. est_shifts = orient_est.estimate_shifts() - # Calculate the mean absolute difference in pixels. - mean_abs_diff = np.mean(abs(src.offsets - est_shifts)) + # Calculate the mean 2D distance between estimates and ground truth. + mean_dist = np.mean(np.sqrt(np.sum((src.offsets - est_shifts) ** 2, axis=1))) - # Assert that on average estimated shifts are close to src.offsets + # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets if src.offsets.all() != 0: - np.testing.assert_array_less(mean_abs_diff, 0.35) + np.testing.assert_array_less(mean_dist, 0.5) else: - np.testing.assert_allclose(mean_abs_diff, 0) + np.testing.assert_allclose(mean_dist, 0) def test_estimate_rotations_fuzzy_mask(): From ecdc979cfd063fb121a4782820747f84bff1e683 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 30 Sep 2024 13:06:56 -0400 Subject: [PATCH 4/7] add more estimate shifts tests. --- tests/test_commonline_sync3n.py | 53 ++++++++++++++++++++++++++++++-- tests/test_orient_sync_voting.py | 22 +++++++++++-- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/tests/test_commonline_sync3n.py b/tests/test_commonline_sync3n.py index 6640fa871f..59876df7b3 100644 --- a/tests/test_commonline_sync3n.py +++ b/tests/test_commonline_sync3n.py @@ -17,6 +17,7 @@ OFFSETS = [ 0, + pytest.param(None, marks=pytest.mark.expensive), ] DTYPES = [ @@ -53,7 +54,16 @@ def source_orientation_objs(resolution, offsets, dtype): seed=456, ).cache() - orient_est = CLSync3N(src, S_weighting=True, seed=789) + # Search for common lines over less shifts for 0 offsets. + max_shift = 1 / resolution + shift_step = 1 + if src.offsets.all() != 0: + max_shift = 0.20 + shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. + + orient_est = CLSync3N( + src, max_shift=max_shift, shift_step=shift_step, S_weighting=True, seed=789 + ) return src, orient_est @@ -74,11 +84,48 @@ def test_build_clmatrix(source_orientation_objs): # Check that at least 98% of estimates are within 5 degrees. tol = 0.98 if src.offsets.all() != 0: - # Set tolerance to 95% when using nonzero offsets. - tol = 0.95 + # Set tolerance to 75% when using nonzero offsets. + tol = 0.75 assert within_5 / angle_diffs.size > tol +def test_estimate_shifts_with_gt_rots(source_orientation_objs): + src, orient_est = source_orientation_objs + + # Assign ground truth rotations. + orient_est.rotations = src.rotations + + # Estimate shifts using ground truth rotations. + est_shifts = orient_est.estimate_shifts() + + # Calculate the mean 2D distance between estimates and ground truth. + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() + + # Assert that on average estimated shifts are close (within 0.8 pix) to src.offsets + if src.offsets.all() != 0: + np.testing.assert_array_less(mean_dist, 0.8) + else: + np.testing.assert_allclose(mean_dist, 0) + + +def test_estimate_shifts_with_est_rots(source_orientation_objs): + src, orient_est = source_orientation_objs + + # Estimate shifts using estimated rotations. + est_shifts = orient_est.estimate_shifts() + + # Calculate the mean 2D distance between estimates and ground truth. + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() + + # Assert that on average estimated shifts are close (within 0.8 pix) to src.offsets + if src.offsets.all() != 0: + np.testing.assert_array_less(mean_dist, 0.8) + else: + np.testing.assert_allclose(mean_dist, 0) + + def test_estimate_rotations(source_orientation_objs): src, orient_est = source_orientation_objs diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 576ca1cee1..24986ce888 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -104,7 +104,7 @@ def test_estimate_rotations(source_orientation_objs): mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1) -def test_estimate_shifts(source_orientation_objs): +def test_estimate_shifts_with_gt_rots(source_orientation_objs): src, orient_est = source_orientation_objs # Assign ground truth rotations. @@ -114,7 +114,25 @@ def test_estimate_shifts(source_orientation_objs): est_shifts = orient_est.estimate_shifts() # Calculate the mean 2D distance between estimates and ground truth. - mean_dist = np.mean(np.sqrt(np.sum((src.offsets - est_shifts) ** 2, axis=1))) + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() + + # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets + if src.offsets.all() != 0: + np.testing.assert_array_less(mean_dist, 0.5) + else: + np.testing.assert_allclose(mean_dist, 0) + + +def test_estimate_shifts_with_est_rots(source_orientation_objs): + src, orient_est = source_orientation_objs + + # Estimate shifts using estimated rotations. + est_shifts = orient_est.estimate_shifts() + + # Calculate the mean 2D distance between estimates and ground truth. + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets if src.offsets.all() != 0: From 5de34af41dc5130b41bf9f1e3eaaff83fa3a580a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 30 Sep 2024 13:54:20 -0400 Subject: [PATCH 5/7] Module scope. Estimate rotations once per module. Adjust tolerance. --- tests/test_commonline_sync3n.py | 23 ++++++++++++----------- tests/test_orient_sync_voting.py | 15 +++++++++------ 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/tests/test_commonline_sync3n.py b/tests/test_commonline_sync3n.py index 59876df7b3..600f883c2d 100644 --- a/tests/test_commonline_sync3n.py +++ b/tests/test_commonline_sync3n.py @@ -1,3 +1,4 @@ +import copy import os import numpy as np @@ -61,9 +62,10 @@ def source_orientation_objs(resolution, offsets, dtype): max_shift = 0.20 shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. - orient_est = CLSync3N( - src, max_shift=max_shift, shift_step=shift_step, S_weighting=True, seed=789 - ) + orient_est = CLSync3N(src, max_shift=max_shift, shift_step=shift_step, seed=789) + + # Estimate rotations once for all tests. + orient_est.estimate_rotations() return src, orient_est @@ -71,9 +73,6 @@ def source_orientation_objs(resolution, offsets, dtype): def test_build_clmatrix(source_orientation_objs): src, orient_est = source_orientation_objs - # Build clmatrix estimate. - orient_est.build_clmatrix() - gt_clmatrix = rots_to_clmatrix(src.rotations, orient_est.n_theta) angle_diffs = abs(orient_est.clmatrix - gt_clmatrix) * 360 / orient_est.n_theta @@ -93,6 +92,8 @@ def test_estimate_shifts_with_gt_rots(source_orientation_objs): src, orient_est = source_orientation_objs # Assign ground truth rotations. + # Deep copy to prevent altering for other tests. + orient_est = copy.deepcopy(orient_est) orient_est.rotations = src.rotations # Estimate shifts using ground truth rotations. @@ -111,7 +112,6 @@ def test_estimate_shifts_with_gt_rots(source_orientation_objs): def test_estimate_shifts_with_est_rots(source_orientation_objs): src, orient_est = source_orientation_objs - # Estimate shifts using estimated rotations. est_shifts = orient_est.estimate_shifts() @@ -129,9 +129,10 @@ def test_estimate_shifts_with_est_rots(source_orientation_objs): def test_estimate_rotations(source_orientation_objs): src, orient_est = source_orientation_objs - orient_est.estimate_rotations() - # Register estimates to ground truth rotations and compute the # mean angular distance between them (in degrees). - # Assert that mean angular distance is less than 1 degree. - mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1) + # Assert that mean angular distance is less than 1 degree (4 with offsets). + tol = 1 + if src.offsets.all() != 0: + tol = 4 + mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=tol) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 24986ce888..85478816cb 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -1,3 +1,4 @@ +import copy import os import os.path import tempfile @@ -32,22 +33,22 @@ ] -@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}") +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module") def resolution(request): return request.param -@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}") +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}", scope="module") def offsets(request): return request.param -@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") def dtype(request): return request.param -@pytest.fixture +@pytest.fixture(scope="module") def source_orientation_objs(resolution, offsets, dtype): src = Simulation( n=50, @@ -68,6 +69,9 @@ def source_orientation_objs(resolution, offsets, dtype): src, max_shift=max_shift, shift_step=shift_step, mask=False ) + # Estimate rotations once for all tests. + orient_est.estimate_rotations() + return src, orient_est @@ -96,8 +100,6 @@ def test_build_clmatrix(source_orientation_objs): def test_estimate_rotations(source_orientation_objs): src, orient_est = source_orientation_objs - orient_est.estimate_rotations() - # Register estimates to ground truth rotations and compute the # mean angular distance between them (in degrees). # Assert that mean angular distance is less than 1 degree. @@ -108,6 +110,7 @@ def test_estimate_shifts_with_gt_rots(source_orientation_objs): src, orient_est = source_orientation_objs # Assign ground truth rotations. + # Deep copy to prevent altering for other tests. orient_est.rotations = src.rotations # Estimate shifts using ground truth rotations. From 6548c0723741d81398cc29592527c8f0aa99d206 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 30 Sep 2024 14:01:05 -0400 Subject: [PATCH 6/7] One more deepcopy. --- tests/test_orient_sync_voting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 85478816cb..4bea8df6c2 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -111,6 +111,7 @@ def test_estimate_shifts_with_gt_rots(source_orientation_objs): # Assign ground truth rotations. # Deep copy to prevent altering for other tests. + orient_est = copy.deepcopy(orient_est) orient_est.rotations = src.rotations # Estimate shifts using ground truth rotations. From b80cbc2c9a448c63d8fece9aec92fdc891f09a22 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 30 Sep 2024 15:56:30 -0400 Subject: [PATCH 7/7] test_coef dtype patch for osx. --- tests/test_coef.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_coef.py b/tests/test_coef.py index ab546d9bac..3ace0ddec5 100644 --- a/tests/test_coef.py +++ b/tests/test_coef.py @@ -48,7 +48,7 @@ def dtype(request): return request.param -@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +@pytest.fixture(params=DTYPES, ids=lambda x: f"basis_dtype={x}", scope="module") def basis_dtype(request): """ Dtypes for basis @@ -416,11 +416,16 @@ def test_shifts(coef_fixture, basis, rots): shifts = np.column_stack((rots, rots[::-1])) # Compare + min_dtype = ( + np.float32 + if (basis.dtype == np.float32 or coef_fixture.dtype == np.float32) + else np.float64 + ) np.testing.assert_allclose( coef_fixture.shift(shifts), basis.shift(coef_fixture, shifts), rtol=1e-05, - atol=utest_tolerance(basis.dtype), + atol=utest_tolerance(min_dtype), )