Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/aspire/abinitio/commonline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,9 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000):
show = False
if logging.getLogger().isEnabledFor(logging.DEBUG):
show = True
# Negative sign comes from using -i conversion of Fourier transformation
est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0]

# Estimate shifts.
est_shifts = sparse.linalg.lsqr(shift_equations, shift_b, show=show)[0]
est_shifts = est_shifts.reshape((self.n_img, 2))

return est_shifts
Expand Down Expand Up @@ -320,15 +321,16 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000):
n_equations = self._estimate_num_shift_equations(
n_img, equations_factor, max_memory
)

# Allocate local variables for estimating 2D shifts based on the estimated number
# of equations. The shift equations are represented using a sparse matrix,
# since each row in the system contains four non-zeros (as it involves
# exactly four unknowns). The variables below are used to construct
# this sparse system. The k'th non-zero element of the equations matrix
# is stored at index (shift_i(k),shift_j(k)).
shift_i = np.zeros(4 * n_equations, dtype=self.dtype)
shift_j = np.zeros(4 * n_equations, dtype=self.dtype)
shift_eq = np.zeros(4 * n_equations, dtype=self.dtype)
shift_i = np.zeros((n_equations, 4), dtype=self.dtype)
shift_j = np.zeros((n_equations, 4), dtype=self.dtype)
shift_eq = np.zeros((n_equations, 4), dtype=self.dtype)
shift_b = np.zeros(n_equations, dtype=self.dtype)

# Prepare the shift phases to try and generate filter for common-line detection
Expand Down Expand Up @@ -388,33 +390,33 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000):
sidx = sidx1 if c1[sidx1] > c2[sidx2] else sidx2
dx = -max_shift + sidx * shift_step

# Create a shift equation for the image pair [i,j]
idx = np.arange(4 * shift_eq_idx, 4 * shift_eq_idx + 4)
# angle of common ray in image i
shift_alpha = c_ij * d_theta
# Angle of common ray in image j.
shift_beta = c_ji * d_theta
# Row index to construct the sparse equations
shift_i[idx] = shift_eq_idx
shift_i[shift_eq_idx] = shift_eq_idx
# Columns of the shift variables that correspond to the current pair [i, j]
shift_j[idx] = [2 * i, 2 * i + 1, 2 * j, 2 * j + 1]
shift_j[shift_eq_idx] = [2 * i, 2 * i + 1, 2 * j, 2 * j + 1]
# Right hand side of the current equation
shift_b[shift_eq_idx] = dx

# Compute the coefficients of the current equation
coefs = np.array(
[
np.sin(shift_alpha),
np.cos(shift_alpha),
-np.sin(shift_beta),
np.sin(shift_alpha),
-np.cos(shift_beta),
-np.sin(shift_beta),
]
)
shift_eq[idx] = -1 * coefs if is_pf_j_flipped else coefs
shift_eq[shift_eq_idx] = (
[-1, -1, 0, 0] * coefs if is_pf_j_flipped else coefs
)

# create sparse matrix object only containing non-zero elements
shift_equations = sparse.csr_matrix(
(shift_eq, (shift_i, shift_j)),
(shift_eq.flatten(), (shift_i.flatten(), shift_j.flatten())),
shape=(n_equations, 2 * n_img),
dtype=self.dtype,
)
Expand Down
2 changes: 1 addition & 1 deletion src/aspire/utils/coor_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def common_line_from_rots(r1, r2, ell):

ut = np.dot(r2, r1.T)
alpha_ij = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi
alpha_ji = np.arctan2(ut[0, 2], -ut[1, 2]) + np.pi
alpha_ji = np.arctan2(-ut[0, 2], ut[1, 2]) + np.pi

ell_ij = alpha_ij * ell / (2 * np.pi)
ell_ji = alpha_ji * ell / (2 * np.pi)
Expand Down
9 changes: 7 additions & 2 deletions tests/test_coef.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def dtype(request):
return request.param


@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module")
@pytest.fixture(params=DTYPES, ids=lambda x: f"basis_dtype={x}", scope="module")
def basis_dtype(request):
"""
Dtypes for basis
Expand Down Expand Up @@ -416,11 +416,16 @@ def test_shifts(coef_fixture, basis, rots):
shifts = np.column_stack((rots, rots[::-1]))

# Compare
min_dtype = (
np.float32
if (basis.dtype == np.float32 or coef_fixture.dtype == np.float32)
else np.float64
)
np.testing.assert_allclose(
coef_fixture.shift(shifts),
basis.shift(coef_fixture, shifts),
rtol=1e-05,
atol=utest_tolerance(basis.dtype),
atol=utest_tolerance(min_dtype),
)


Expand Down
68 changes: 58 additions & 10 deletions tests/test_commonline_sync3n.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import os

import numpy as np
Expand All @@ -17,6 +18,7 @@

OFFSETS = [
0,
pytest.param(None, marks=pytest.mark.expensive),
]

DTYPES = [
Expand Down Expand Up @@ -53,17 +55,24 @@ def source_orientation_objs(resolution, offsets, dtype):
seed=456,
).cache()

orient_est = CLSync3N(src, S_weighting=True, seed=789)
# Search for common lines over less shifts for 0 offsets.
max_shift = 1 / resolution
shift_step = 1
if src.offsets.all() != 0:
max_shift = 0.20
shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation.

orient_est = CLSync3N(src, max_shift=max_shift, shift_step=shift_step, seed=789)

# Estimate rotations once for all tests.
orient_est.estimate_rotations()

return src, orient_est


def test_build_clmatrix(source_orientation_objs):
src, orient_est = source_orientation_objs

# Build clmatrix estimate.
orient_est.build_clmatrix()

gt_clmatrix = rots_to_clmatrix(src.rotations, orient_est.n_theta)

angle_diffs = abs(orient_est.clmatrix - gt_clmatrix) * 360 / orient_est.n_theta
Expand All @@ -74,17 +83,56 @@ def test_build_clmatrix(source_orientation_objs):
# Check that at least 98% of estimates are within 5 degrees.
tol = 0.98
if src.offsets.all() != 0:
# Set tolerance to 95% when using nonzero offsets.
tol = 0.95
# Set tolerance to 75% when using nonzero offsets.
tol = 0.75
assert within_5 / angle_diffs.size > tol


def test_estimate_rotations(source_orientation_objs):
def test_estimate_shifts_with_gt_rots(source_orientation_objs):
src, orient_est = source_orientation_objs

orient_est.estimate_rotations()
# Assign ground truth rotations.
# Deep copy to prevent altering for other tests.
orient_est = copy.deepcopy(orient_est)
orient_est.rotations = src.rotations

# Estimate shifts using ground truth rotations.
est_shifts = orient_est.estimate_shifts()

# Calculate the mean 2D distance between estimates and ground truth.
error = src.offsets - est_shifts
mean_dist = np.hypot(error[:, 0], error[:, 1]).mean()

# Assert that on average estimated shifts are close (within 0.8 pix) to src.offsets
if src.offsets.all() != 0:
np.testing.assert_array_less(mean_dist, 0.8)
else:
np.testing.assert_allclose(mean_dist, 0)


def test_estimate_shifts_with_est_rots(source_orientation_objs):
src, orient_est = source_orientation_objs
# Estimate shifts using estimated rotations.
est_shifts = orient_est.estimate_shifts()

# Calculate the mean 2D distance between estimates and ground truth.
error = src.offsets - est_shifts
mean_dist = np.hypot(error[:, 0], error[:, 1]).mean()

# Assert that on average estimated shifts are close (within 0.8 pix) to src.offsets
if src.offsets.all() != 0:
np.testing.assert_array_less(mean_dist, 0.8)
else:
np.testing.assert_allclose(mean_dist, 0)


def test_estimate_rotations(source_orientation_objs):
src, orient_est = source_orientation_objs

# Register estimates to ground truth rotations and compute the
# mean angular distance between them (in degrees).
# Assert that mean angular distance is less than 1 degree.
mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1)
# Assert that mean angular distance is less than 1 degree (4 with offsets).
tol = 1
if src.offsets.all() != 0:
tol = 4
mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=tol)
50 changes: 40 additions & 10 deletions tests/test_orient_sync_voting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import os
import os.path
import tempfile
Expand Down Expand Up @@ -32,22 +33,22 @@
]


@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}")
@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module")
def resolution(request):
return request.param


@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}")
@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}", scope="module")
def offsets(request):
return request.param


@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}")
@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module")
def dtype(request):
return request.param


@pytest.fixture
@pytest.fixture(scope="module")
def source_orientation_objs(resolution, offsets, dtype):
src = Simulation(
n=50,
Expand All @@ -68,6 +69,9 @@ def source_orientation_objs(resolution, offsets, dtype):
src, max_shift=max_shift, shift_step=shift_step, mask=False
)

# Estimate rotations once for all tests.
orient_est.estimate_rotations()

return src, orient_est


Expand Down Expand Up @@ -96,23 +100,49 @@ def test_build_clmatrix(source_orientation_objs):
def test_estimate_rotations(source_orientation_objs):
src, orient_est = source_orientation_objs

orient_est.estimate_rotations()

# Register estimates to ground truth rotations and compute the
# mean angular distance between them (in degrees).
# Assert that mean angular distance is less than 1 degree.
mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1)


def test_estimate_shifts(source_orientation_objs):
def test_estimate_shifts_with_gt_rots(source_orientation_objs):
src, orient_est = source_orientation_objs

# Assign ground truth rotations.
# Deep copy to prevent altering for other tests.
orient_est = copy.deepcopy(orient_est)
orient_est.rotations = src.rotations

# Estimate shifts using ground truth rotations.
est_shifts = orient_est.estimate_shifts()

# Calculate the mean 2D distance between estimates and ground truth.
error = src.offsets - est_shifts
mean_dist = np.hypot(error[:, 0], error[:, 1]).mean()

# Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets
if src.offsets.all() != 0:
pytest.xfail("Currently failing under non-zero offsets.")
np.testing.assert_array_less(mean_dist, 0.5)
else:
np.testing.assert_allclose(mean_dist, 0)


def test_estimate_shifts_with_est_rots(source_orientation_objs):
src, orient_est = source_orientation_objs

# Estimate shifts using estimated rotations.
est_shifts = orient_est.estimate_shifts()

# Assert that estimated shifts are close to src.offsets
assert np.allclose(est_shifts, src.offsets)
# Calculate the mean 2D distance between estimates and ground truth.
error = src.offsets - est_shifts
mean_dist = np.hypot(error[:, 0], error[:, 1]).mean()

# Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets
if src.offsets.all() != 0:
np.testing.assert_array_less(mean_dist, 0.5)
else:
np.testing.assert_allclose(mean_dist, 0)


def test_estimate_rotations_fuzzy_mask():
Expand Down
Loading