diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index 9d4b0f483c..e8115ea185 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -8,5 +8,6 @@ from .commonline_c3_c4 import CLSymmetryC3C4 from .commonline_cn import CLSymmetryCn from .commonline_c2 import CLSymmetryC2 +from .commonline_d2 import CLSymmetryD2 # isort: on diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py new file mode 100644 index 0000000000..a8e951c642 --- /dev/null +++ b/src/aspire/abinitio/commonline_d2.py @@ -0,0 +1,1975 @@ +import logging + +import numpy as np +import scipy.sparse.linalg as la +from numpy.linalg import norm + +from aspire.abinitio import CLOrient3D +from aspire.operators import PolarFT +from aspire.utils import J_conjugate, Rotation, all_pairs, all_triplets, tqdm, trange +from aspire.utils.random import randn +from aspire.volume import DnSymmetryGroup + +logger = logging.getLogger(__name__) + + +class CLSymmetryD2(CLOrient3D): + """ + Define a class to estimate 3D orientations using common lines methods for + molecules with D2 (dihedral) symmetry. + + Corresponding publication: + E. Rosen and Y. Shkolnisky, + Common lines ab-initio reconstruction of D2-symmetric molecules, + SIAM Journal on Imaging Sciences, volume 13-4, p. 1898-1994, 2020 + """ + + def __init__( + self, + src, + n_rad=None, + n_theta=None, + max_shift=0.15, + shift_step=1, + grid_res=1200, + inplane_res=5, + eq_min_dist=7, + epsilon=0.01, + seed=None, + mask=True, + ): + """ + Initialize object for estimating 3D orientations for molecules with D2 symmetry. + + :param src: The source object of 2D denoised or class-averaged images with metadata + :param n_rad: The number of points in the radial direction of Fourier image. + :param n_theta: The number of points in the theta direction of Fourier image. + :param max_shift: Maximum range for shifts as a proportion of resolution. Default = 0.15. + :param shift_step: Resolution of shift estimation in pixels. Default = 1 pixel. + :param grid_res: Number of sampling points on sphere for projetion directions. + These are generated using the Saaf-Kuijlaars algorithm. Default value is 1200. + :param inplane_res: The sampling resolution of in-plane rotations for each + projection direction. Default value is 5 degrees. + :param eq_min_dist: Width of strip around equator projection directions from + which we do not sample directions. Default value is 7 degrees. + :param epsilon: Tolerance for J-synchronization power method. + :param seed: Optional seed for RNG. + :param mask: Option to mask `src.images` with a fuzzy mask (boolean). + Default, `True`, applies a mask. + """ + + super().__init__( + src, + n_rad=n_rad, + n_theta=n_theta, + max_shift=max_shift, + shift_step=shift_step, + mask=mask, + ) + + self.grid_res = grid_res + self.inplane_res = inplane_res + self.n_inplane_rots = int(360 / self.inplane_res) + self.eq_min_dist = eq_min_dist + self.seed = seed + self.epsilon = epsilon + + self.triplets = all_triplets(self.n_img) + self.pairs, self.pairs_to_linear = all_pairs(self.n_img, return_map=True) + self.n_pairs = len(self.pairs) + + # D2 symmetry group. + # Rearrange in order Identity, about_x, about_y, about_z. + # This ordering is necessary for reproducing MATLAB code results. + self.gs = DnSymmetryGroup(order=2, dtype=self.dtype).matrices[[0, 3, 2, 1]] + + def estimate_rotations(self): + """ + Estimate rotation matrices for molecules with D2 symmetry. Sets the attribute + self.rotations with an array of estimated rotation matrices, size src.nx3x3. + """ + # Pre-compute phase-shifted polar Fourier. + self._compute_shifted_pf() + + # Generate lookup data + self._generate_lookup_data() + self._generate_scl_lookup_data() + + # Compute self common-line scores. + self._compute_scl_scores() + + # Compute common-lines and estimate relative rotations Rijs. + self._compute_cl_scores() + + # Perform handedness synchronization. + self.Rijs_sync = self._global_J_sync(self.Rijs_est) + + # Synchronize colors. + self.colors, self.Rijs_rows = self._sync_colors(self.Rijs_sync) + + # Synchronize signs. + Ris = self._sync_signs(self.Rijs_rows, self.colors) + + # Assign rotations. + self.rotations = Ris + + ######################### + # Prepare Polar Fourier # + ######################### + + def _compute_shifted_pf(self): + """ + Pre-compute shifted and full polar Fourier transforms. + """ + logger.info("Preparing polar Fourier transform.") + pf = self.pf + + # Generate shift phases. + r_max = pf.shape[-1] + max_shift_1d = np.ceil(2 * np.sqrt(2) * self.max_shift) + shifts, shift_phases, _ = self._generate_shift_phase_and_filter( + r_max, max_shift_1d, self.shift_step + ) + self.n_shifts = len(shifts) + + # Reconstruct full polar Fourier for use in correlation. + pf[:, :, 0] = 0 # Matching matlab convention to zero out the lowest frequency. + pf /= norm(pf, axis=2)[..., np.newaxis] # Normalize each ray. + self.pf_full = PolarFT.half_to_full(pf) + + # Pre-compute shifted pf's. + pf_shifted = pf[:, None] * shift_phases[None, :, None] + self.pf_shifted = pf_shifted.reshape( + (self.n_img, self.n_shifts * (self.n_theta // 2), r_max) + ) + + ################################### + # Generate Commonline Lookup Data # + ################################### + + def _generate_lookup_data(self): + """ + Generate candidate relative rotations and corresponding common line indices. + """ + logger.info("Generating commonline lookup data.") + # Generate uniform grid on sphere with Saff-Kuijlaars and take one quarter + # of sphere because of D2 symmetry redundancy. + sphere_grid = self._saff_kuijlaars(self.grid_res) + octant1_mask = np.all(sphere_grid > 0, axis=1) + octant2_mask = ( + (sphere_grid[:, 0] > 0) & (sphere_grid[:, 1] > 0) & (sphere_grid[:, 2] < 0) + ) + sphere_grid1 = sphere_grid[octant1_mask] + sphere_grid2 = sphere_grid[octant2_mask] + + # Mark Equator Directions. + # Common lines between projection directions which are perpendicular to + # symmetry axes (equator images) have common line degeneracies. Two images + # taken from directions on the same great circle which is perpendicular to + # some symmetry axis only have 2 common lines instead of 4, and must be + # treated separately. + # We detect such directions by taking a strip of radius + # `eq_min_dist` about the 3 great circles perpendicular to the symmetry + # axes of D2 (i.e to X,Y and Z axes). + eq_class1 = self._mark_equators(sphere_grid1, self.eq_min_dist) + eq_class2 = self._mark_equators(sphere_grid2, self.eq_min_dist) + + # Mark Top View Directions. + # A Top view projection image is taken from the direction of one of the + # symmetry axes. Since all symmetry axes of D2 molecules are perpendicular + # this means that such an image is an equator with repect to both symmetry + # axes which are perpendicular to the direction of the symmetry axis from + # which the image was made, e.g. if the image was formed by projecting in + # the direction of the X (symmetry) axis, then it is an equator with + # respect to both Y and Z symmetry axes (it's direction is the + # interesection of 2 great circles perpendicular to Y and Z axes). + # Such images have severe degeneracies. A pair of Top View images (taken + # from different directions or a Top View and equator image only have a + # single common line. A top view and a regular non-equator image only have + # two common lines. + + # Remove top views from sphere grids and update equator indices and classes. + self.sphere_grid1 = sphere_grid1[eq_class1 < 4] + self.sphere_grid2 = sphere_grid2[eq_class2 < 4] + self.eq_class1 = eq_class1[eq_class1 < 4] + self.eq_class2 = eq_class2[eq_class2 < 4] + + # Generate in-plane rotations for each grid point on the sphere. + self.inplane_rotated_grid1 = self._generate_inplane_rots( + self.sphere_grid1, self.inplane_res + ) + self.inplane_rotated_grid2 = self._generate_inplane_rots( + self.sphere_grid2, self.inplane_res + ) + + # Generate commmonline angles induced by all relative rotation candidates. + cl_angles1, self.eq2eq_Rij_table_11 = self._generate_commonline_angles( + self.inplane_rotated_grid1, + self.inplane_rotated_grid1, + self.eq_class1, + self.eq_class1, + ) + cl_angles2, self.eq2eq_Rij_table_12 = self._generate_commonline_angles( + self.inplane_rotated_grid1, + self.inplane_rotated_grid2, + self.eq_class1, + self.eq_class2, + same_octant=False, + ) + + # Generate commonline indices. + self.cl_idx_1 = self._generate_commonline_indices(cl_angles1) + self.cl_idx_2 = self._generate_commonline_indices(cl_angles2) + self.cl_idx = np.hstack((self.cl_idx_1, self.cl_idx_2)) + + def _generate_commonline_angles( + self, + Ris, + Rjs, + Ri_eq_class, + Rj_eq_class, + same_octant=True, + ): + """ + Compute commonline angles induced by the 4 sets of relative rotations + Rij = Ri.T @ g_m @ Rj, m = 0,1,2,3, where g_m is the identity and rotations + about the three axes of symmetry of a D2 symmetric molecule. Note, we only + compute commonline angles between pairs of images which are not equator + images with respect to the same axis of symmetry. To do this we build a + table, `eq2eq_Rij_table`, which is `False` for pairs of images that are + equator images with respect to the same axis of symmetry and `True` otherwise. + + :param Ris: First set of candidate rotations. + :param Rjs: Second set of candidate rotation. + :param Ri_eq_class: Equator classification for Ris. + :param Rj_eq_class: Equator classification for Rjs. + :param same_octant: True if both sets of candidates are in the same octant. + + :return: Commonline angles induced by relative rotation candidates. + """ + n_rots_i = len(Ris) + n_theta = Ris.shape[1] # Same for Rjs, TODO: Don't call this n_theta + + # Generate upper triangular table of indicators of all pairs which are not + # equators with respect to the same symmetry axis (named unique_pairs). + eq_table = np.outer(Ri_eq_class > 0, Rj_eq_class > 0) + in_same_class = (Ri_eq_class[:, None] - Rj_eq_class.T[None]) == 0 + eq2eq_Rij_table = ~(eq_table * in_same_class) + + # For candidates in the same octant only need upper triangle of table. + if same_octant: + eq2eq_Rij_table = np.triu(eq2eq_Rij_table, 1) + + n_pairs = np.count_nonzero(eq2eq_Rij_table) + idx = 0 + cl_angles = np.zeros((2, n_pairs, n_theta, n_theta // 2, 4, 2)) + + for i in range(n_rots_i): + unique_pairs_i = np.nonzero(eq2eq_Rij_table[i])[0] + if len(unique_pairs_i) == 0: + continue + Ri = Ris[i] + for j in unique_pairs_i: + Rj = Rjs[j, : n_theta // 2] + + # Compute relative rotations candidates Rij = Ri.T @ gs @ Rj + Rijs = ( + np.transpose(Ri, axes=(0, 2, 1))[:, None, None] + @ self.gs + @ Rj[:, None] + ) + + # Common line indices induced by Rijs + cl_angles[0, idx, :, :, :, 0] = np.arctan2( + -Rijs[..., 0, 2], Rijs[..., 1, 2] + ) + cl_angles[0, idx, :, :, :, 1] = np.arctan2( + Rijs[..., 2, 0], -Rijs[..., 2, 1] + ) + cl_angles[1, idx, :, :, :, 0] = np.arctan2( + -Rijs[..., 2, 0], Rijs[..., 2, 1] + ) + cl_angles[1, idx, :, :, :, 1] = np.arctan2( + Rijs[..., 0, 2], -Rijs[..., 1, 2] + ) + + idx += 1 + + # Make all angles non-negative and convert to degrees. + cl_angles = (cl_angles + 2 * np.pi) % (2 * np.pi) + cl_angles = cl_angles * 180 / np.pi + + return cl_angles, eq2eq_Rij_table + + ######################################## + # Generate Self-Commonline Lookup Data # + ######################################## + + def _generate_scl_lookup_data(self): + """ + Generate lookup data for self-commonlines. + """ + logger.info("Generating self-commonline lookup data.") + # Get self-commonline angles. + self.scl_angles1 = self._generate_scl_angles( + self.inplane_rotated_grid1, + self.eq_class1, + ) + self.scl_angles2 = self._generate_scl_angles( + self.inplane_rotated_grid2, + self.eq_class2, + ) + + # Get self-commonline indices. + self.scl_idx_1, self.scl_eq_lin_idx_lists_1 = self._generate_scl_indices( + self.scl_angles1, self.eq_class1 + ) + self.scl_idx_2, self.scl_eq_lin_idx_lists_2 = self._generate_scl_indices( + self.scl_angles2, self.eq_class2 + ) + self.scl_idx_lists = np.concatenate( + (self.scl_eq_lin_idx_lists_1, self.scl_eq_lin_idx_lists_2), axis=1 + ) + + # Compute non-equator indices. + # Register non equator indices. Denote by C_ij the j'th in-plane rotation of + # the i'th ML candidate, and arrange all candidates in a list with their in-plane + # rotations in the order: C_11,...,C_1r,...,C_m1,...,C_mr where m is the + # number of candidates and r is the number of in plane rotations. Here we + # create a sub-list of only non equator candidates, i.e., if i_1,...,i_p are + # non equators then we have the sub list is + # C_(i_1)1,...,C(i_1)r,...C_(i_p)1,...,C_(i_p)r. + n_non_eq = np.count_nonzero(self.eq_class1 == 0) + np.count_nonzero( + self.eq_class2 == 0 + ) + non_eq_idx = np.zeros((n_non_eq, self.n_inplane_rots), dtype=int) + non_eq_idx[:, 0] = ( + np.hstack( + ( + np.nonzero(self.eq_class1 == 0)[0], + len(self.eq_class1) + np.nonzero(self.eq_class2 == 0)[0], + ) + ) + * self.n_inplane_rots + ) + non_eq_idx[:, 1:] = non_eq_idx[:, [0]] + np.arange(1, self.n_inplane_rots) + + self.non_eq_idx = non_eq_idx + + # Non-topview equator indices. + self.non_tv_eq_idx = np.concatenate( + ( + np.nonzero(self.eq_class1 > 0)[0], + len(self.eq_class1) + np.nonzero(self.eq_class2 > 0)[0], + ) + ) + + # Generate maps from scl indices to relative rotations. + self._generate_scl_scores_idx_map() + + def _generate_scl_angles(self, Ris, eq_class): + """ + Generate self-commonline angles. For each candidate rotation a pair of self-commonline + angles are generated for each of the 3 self-commonlines induced by D2 symmetry. + + :param Ris: Candidate rotation matrices, (n_sphere_grid, n_inplane_rots, 3, 3). + :param eq_idx: Equator index mask for Ris. + :param eq_class: Equator classification for Ris. + + :return: `scl_angles` of shape (n_sphere_grid, n_inplane_rots, 3, 2). + """ + + # For each candidate rotation Ri we generate the set of 3 self-commonlines. + scl_angles = np.zeros((*Ris.shape[:2], 3, 2), dtype=Ris.dtype) + n_rots = len(Ris) + for i in range(n_rots): + Ri = Ris[i] + for k, g in enumerate(self.gs[1:]): + g_Ri = g @ Ri + Riis = np.transpose(Ri, axes=(0, 2, 1)) @ g_Ri + + scl_angles[i, :, k, 0] = np.arctan2(Riis[:, 2, 0], -Riis[:, 2, 1]) + scl_angles[i, :, k, 1] = np.arctan2(-Riis[:, 0, 2], Riis[:, 1, 2]) + + # Prepare self commonline coordinates. + scl_angles = scl_angles % (2 * np.pi) + + # Deal with non top view equators + # A non-TV equator has only one self common line. However, we clasify an + # equator as an image whose projection direction is at radial distance < + # `eq_min_dist` from the great circle perpendicular to a symmetry axis, + # and not strictly zero distance. Thus in most cases we get 2 common lines + # differing by a small difference in degrees. Actually the calculation above + # gives us two NEARLY antipodal lines, so we first flip one of them by + # adding 180 degrees to it. Then we aggregate all the rays within the range + # between these two resulting lines to compute the score of this self common + # line for this candidate. The scoring part is done in the ML function itself. + # Furthermore, the line perpendicular to the self common line, though not + # really a self common line, has the property that all its values are real + # and both halves of the line (rays differing by pi, emanating from the + # origin) have the same values, and so it 'behaves like' a self common + # line which we also register here and exploit in the ML function. + # We put the 'real' self common line at 2 first coordinates, the + # candidate for perpendicular line is in 3rd coordinate. + + # If this is a self common line with respect to x-equator then the actual self + # common line(s) is given by the self relative rotations given by the y and z + # rotation (by 180 degrees) group members, i.e. Ri^TgyRj and Ri^TgzRj + scl_angles[eq_class == 1] = scl_angles[eq_class == 1][:, :, [1, 2, 0]] + scl_angles[eq_class == 1, :, 0] = scl_angles[eq_class == 1][:, :, 0, [1, 0]] + + # If this is a self common line with respect to y-equator then the actual self + # common line(s) is given by the self relative rotations given by the x and z + # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgzRj + scl_angles[eq_class == 2] = scl_angles[eq_class == 2][:, :, [0, 2, 1]] + scl_angles[eq_class == 2, :, 0] = scl_angles[eq_class == 2][:, :, 0, [1, 0]] + + # If this is a self common line with respect to z-equator then the actual self + # common line(s) is given by the self relative rotations given by the x and y + # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgyRj + # No need to rearrange entries, the "real" common lines are already in + # indices 1 and 2, but flip one common line to antipodal. + scl_angles[eq_class == 3, :, 0] = scl_angles[eq_class == 3][:, :, 0, [1, 0]] + + # Make sure angle range is < 180 degrees. + # p1 marks "equator" self-commonlines where both entries of the first + # scl are greater than both entries of the second scl. + p1 = scl_angles[eq_class > 0, :, 0] > scl_angles[eq_class > 0, :, 1] + p1 = p1[:, :, 0] & p1[:, :, 1] + # p2 marks "equator" self-commonlines where the angle range between the + # first and second sets of self-commonlines is greater than 180. + p2 = scl_angles[eq_class > 0, :, 0] - scl_angles[eq_class > 0, :, 1] < -np.pi + p2 = p2[:, :, 0] | p2[:, :, 1] + p = p1 | p2 + + # Swap entries satisfying either of the above conditions. + scl_angles[eq_class > 0] = ( + scl_angles[eq_class > 0][:, :, [1, 0, 2]] * p[:, :, None, None] + + scl_angles[eq_class > 0] * ~p[:, :, None, None] + ) + + # Convert from radians [0,2*pi) to degrees [0, 360). + return np.round(scl_angles * 180 / np.pi) % 360 + + def _generate_scl_indices(self, scl_angles, eq_class): + """ + Generate self-commonline indices. This includes a set of linear indices for + all candidate rotations as well as lists of self-commonline index ranges for + equator candidates. + + :param scl_angles: Self-commonline angles, shape (n_sphere_grid, n_inplane_rots, 3, 2). + :param eq_class: Equator classification for the sphere_grid points represented + by the first axis of `scl_angles`. + + :returns: + - scl_indices, self-commonline linear indices. + - eq_lin_idx_lists, a list containing a range of self-commonline + indices for each equator candidate. + """ + L = self.n_theta + + # Convert from angles to indices. + scl_indices = self._generate_commonline_indices(scl_angles) + scl_angles = np.mod(np.round(scl_angles / (2 * np.pi) * L), L).astype(int) + + # Create candidate common line linear indices lists for equators. + # As indicated above for equator candidate, for each self common line we + # don't get a single coordinate but a range of them. Here we register a + # list of coordinates for each such self common line candidate. + non_top_view_eq_idx = np.nonzero(eq_class > 0)[0] + n_eq = len(non_top_view_eq_idx) + n_inplane_rots = scl_angles.shape[1] + count_eq = 0 + + # eq_lin_idx_lists[0,i,j] registers a list of linear indices of the j'th + # in-plane rotation of the range for the (only) self common line of the i'th + # candidate. eq_lin_idx_lists[1,i,j] registers the actual (integer) angle + # of the self common line in the 2D Fourier space. Note that we need only + # one number since each self common line has radial coordinates of the form + # (theta, theta+180). + eq_lin_idx_lists = np.empty((2, n_eq, n_inplane_rots), dtype=object) + for i in non_top_view_eq_idx.tolist(): + for j in range(n_inplane_rots): + idx1 = self._circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) + idx2 = self._circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) + + # Ensure idx1 and idx2 have same number of elements. + # Might be off by one due to n_theta discretization. + end = np.minimum(len(idx1), len(idx2)) + idx1, idx2 = idx1[:end], idx2[:end] + + # Adjust so idx1 is in [0, 180) range. + is_geq_than_pi = idx1 >= L // 2 + idx1[is_geq_than_pi] = idx1[is_geq_than_pi] - L // 2 + idx2[is_geq_than_pi] = (idx2[is_geq_than_pi] + L // 2) % L + + # register indices in list. + eq_lin_idx_lists[0, count_eq, j] = np.ravel_multi_index( + (idx1, idx2), (L // 2, L) + ) + eq_lin_idx_lists[1, count_eq, j] = idx1 + count_eq += 1 + + return scl_indices, eq_lin_idx_lists + + def _generate_scl_scores_idx_map(self): + """ + Generates lookup tables for maximum likelihood scheme to estimate commonlines + between images. + + This method creates two lookup tables (`oct1_ij_map` and `oct2_ij_map`) + for pairs of candidate rotations (i, j) under the following conditions: + + 1. Both rotations Ri and Rj are in octant 1. + 2. Ri is in octant 1 and Rj is in octant 2. + + For each pair of candidate rotations the tables give a map into the set of + self-commonlines induced by those rotations. This table will be used later + to incorporate a likelihood score for self-commonlines into the likelihood + score for common lines for each pair of images. + """ + # Calculate number of rotations in each octant. + n_rot_1 = len(self.scl_idx_1) // (3 * self.n_inplane_rots) + n_rot_2 = len(self.scl_idx_2) // (3 * self.n_inplane_rots) + + # First the map for i 0] + if len(unique_pairs_i) == 0: + continue + i_idx_plus_offset = i_idx + (i * self.n_inplane_rots) + + for j in unique_pairs_i: + j_idx_plus_offset = j_idx + (j * self.n_inplane_rots) + oct2_ij_map[idx] = np.column_stack( + (i_idx_plus_offset, j_idx_plus_offset) + ) + idx += 1 + + tmp1 = oct1_ij_map[:, :, 0].flatten() + tmp2 = oct1_ij_map[:, :, 1].flatten() + self.oct1_ij_map = np.column_stack((tmp1, tmp2)) + + tmp1 = oct2_ij_map[:, :, 0].flatten() + tmp2 = oct2_ij_map[:, :, 1].flatten() + self.oct2_ij_map = np.column_stack((tmp1, tmp2)) + + ############################################## + # Compute Self-Commonline Correlation Scores # + ############################################## + + def _compute_scl_scores(self): + """ + Compute correlations for self-commonline candidates. For each image i + we compute an auto-correlation table between all polar Fourier rays. + We then use that table to apply a score to each non-topview candidate + rotation which gives the likelihood that the self-commonlines induced + by that candidate belong to the image i.. + """ + logger.info("Computing self-commonline correlation scores.") + n_img = self.n_img + n_theta = self.n_theta + n_eq = len(self.non_tv_eq_idx) + n_inplane = self.n_inplane_rots + + # Prepare self-commonline indices. + scl_matrix = np.concatenate((self.scl_idx_1, self.scl_idx_2)) + M = len(scl_matrix) // 3 + scl_idx = scl_matrix.reshape(M, 3) + + # Get non-equator indices to use with corrs matrix. + non_eq_lin_idx = self.non_eq_idx.flatten() + n_non_eq = len(non_eq_lin_idx) + non_eq_idx = np.unravel_index( + scl_idx[non_eq_lin_idx].flatten(), (n_theta // 2, n_theta) + ) + + # Compute max correlation over all shifts. + corrs = np.real( + self.pf_shifted @ np.transpose(np.conj(self.pf_full), (0, 2, 1)) + ) + corrs = np.reshape(corrs, (self.n_img, self.n_shifts, n_theta // 2, n_theta)) + corrs = np.max(corrs, axis=1) + + # Map correlations to probabilities (in the spirit of Maximum Likelihood). + corrs = 0.5 * (corrs + 1) + + # Compute equator measures. + eq_measures = np.zeros((self.n_img, n_theta // 2), dtype=self.dtype) + for i in range(self.n_img): + eq_measures[i] = self._all_eq_measures(corrs[i]) + + # Handle the cases: Non-equator, Non-top-view equator images. + # 1. Non-equators: just take product of probabilities. + corrs_out = np.zeros((n_img, M), dtype=self.dtype) + prod_corrs = np.prod( + corrs[:, non_eq_idx[0], non_eq_idx[1]].reshape(self.n_img, n_non_eq, 3), + axis=2, + ) + corrs_out[:, non_eq_lin_idx] = prod_corrs + + # 2. Non-topview equators: adjust scores by eq_measures + for eq_idx in range(n_eq): + for j in range(n_inplane): + # Take the correlations for the self common line candidate of the + # "equator rotation" `eq_idx` with respect to image i, and + # multiply by all scores from the function eq_measures (see + # documentation inside the function ). Then take maximum over + # all the scores. + scl_idx_list = np.unravel_index( + self.scl_idx_lists[0, eq_idx, j], (n_theta // 2, n_theta) + ) + true_scls_corrs = corrs[:, scl_idx_list[0], scl_idx_list[1]] + scls_cand_idx = self.scl_idx_lists[1, eq_idx, j] + eq_measures_j = eq_measures[:, scls_cand_idx] + measures_agg = true_scls_corrs[:, :, None] * eq_measures_j[:, None, :] + k = self.non_tv_eq_idx[eq_idx] + corrs_out[:, k * n_inplane + j] = np.max(measures_agg, axis=(-2, -1)) + + self.scls_scores = corrs_out + + def _all_eq_measures(self, corrs): + """ + Compute a measure indicating how likely an image is an equator image. + + :param corrs: Correlation table of shape (n_theta // 2, n_theta). + + :return: (n_theta // 2) likelihood scores. + """ + # First compute the eq measure (corrs(scl-k,scl+k) for k=1:n_theta // 4) + # An equator image of a D2 molecule has the following property: If t_i is + # the angle of one of the rays of the self common line then all the pairs of + # rays of the form (t_i-k,t_i+k) for k=1:n_theta // 4 are identical. For each t_i we + # average over correlations between the lines (t_i-k,t_i+k) for k=1:n_theta // 4 + # to measure the likelihood that the image is an equator and the ray (line) + # with angle t_i is a self common line. + # (This first loop can be done once outside this function and then pass + # idx as an argument). + L = self.n_theta + L_half = L // 2 + + # Generate indices using broadcasting. + t_i = np.arange(L_half)[:, None, None] + k_vals = np.arange(1, L // 4 + 1)[None, :, None] + neg_pos_k = np.array([-1, 1])[None, None, :] + + # Calculate indices, shape: (L//2, L//4, 2). + idx = np.mod(t_i + k_vals * neg_pos_k, L) + + # Convert to Fourier ray indices. + idx_1 = idx[:, :, 0].flatten() + idx_2 = idx[:, :, 1].flatten() + + # Adjust idx_1 to be within [0, 180) and adjust idx_2 accordingly. + is_geq_than_pi = idx_1 >= L_half + idx_1[is_geq_than_pi] -= L_half + idx_2[is_geq_than_pi] = (idx_2[is_geq_than_pi] + L_half) % L + + # Compute correlations + eq_corrs = corrs[idx_1, idx_2].reshape(L_half, L // 4) + corrs_mean = np.mean(eq_corrs, axis=1) + + # Now compute correlations for normals to scls. + # An eqautor image of a D2 molecule has the additional following property: + # The normal line to a self common line in 2D Fourier plane is real valued + # and both of its rays have identical values. We use the correlation + # between one Fourier ray of the normal to a self common line candidate t_i + # with its anti-podal as an additional way to measure if the image is an + # equator and t_i+0.5*pi is the normal to its self common line. + r = np.ceil(2 * L / 360).astype( + int + ) # Search radius within 2 degrees of normal ray. + + # Generate indices for normal to scl index. + normal_2_scl_idx_0 = ( + L_half - np.arange(L_half // 2 - r, L_half // 2 + r + 1) + ) % L + normal_2_scl_idx = (normal_2_scl_idx_0 + np.arange(L_half).reshape(-1, 1)) % L + + # Adjust indices to be within [0, 180) range. + normal_2_scl_idx = np.where( + normal_2_scl_idx >= L_half, normal_2_scl_idx - L_half, normal_2_scl_idx + ) + + # Compute correlations for normals. + normal_corrs = corrs[normal_2_scl_idx, normal_2_scl_idx + L_half] + normal_corrs_max = np.max(normal_corrs, axis=1) + + return corrs_mean * normal_corrs_max + + ######################################### + # Compute Commonline Correlation Scores # + ######################################### + + def _compute_cl_scores(self): + """ + Run common lines Maximum likelihood procedure for a D2 molecule, to find + the set of rotations Ri^TgkRj, k=1,2,3,4 for each pair of images i and j. + """ + logger.info("Computing commonline correlation scores.") + L = self.n_theta + n_pairs = self.n_img * (self.n_img - 1) // 2 + + # Map the self common line scores of each 2 candidate rotations R_i, R_j + n_lookup_1 = len(self.scl_idx_1) // 3 + oct1_ij_map = np.vstack((self.oct1_ij_map, self.oct1_ij_map[:, [1, 0]])) + oct2_ij_map = self.oct2_ij_map + oct2_ij_map[:, 1] += n_lookup_1 + oct2_ij_map = np.vstack((oct2_ij_map, oct2_ij_map[:, [1, 0]])) + ij_map = np.vstack((oct1_ij_map, oct2_ij_map)) + + # Gather commonline indices and unravel to index into correlations. + cl_idx = np.unravel_index(self.cl_idx, (L // 2, L)) + + # Allocate output variables + corrs_idx = np.zeros(n_pairs, dtype=np.int64) + corrs_out = np.zeros(n_pairs, dtype=self.dtype) + + ij_idx = 0 + pbar = tqdm( + desc="Searching for commonlines between pairs of images", total=n_pairs + ) + + # For each i'th image compute the correlation with all j'th images, j > i. + for i in range(self.n_img - 1): + pf_i = self.pf_shifted[i] + scores_i = self.scls_scores[i] + + # Gather all pf_j in one array for vectorized computation + pf_js = self.pf_full[i + 1 : self.n_img] + n_pf_js = pf_js.shape[0] + + # Compute maximum correlation over all shifts for all pf_j + corrs = np.real(pf_i @ np.conj(pf_js.transpose(0, 2, 1))) + corrs = corrs.reshape(n_pf_js, self.n_shifts, L // 2, L) + corrs = np.max(corrs, axis=1) # Max over shifts + + # Take the product over symmetrically induced candidates. Eq. 4.5 in paper. + prod_corrs = corrs[:, cl_idx[0], cl_idx[1]] + prod_corrs = prod_corrs.reshape(n_pf_js, len(prod_corrs[0]) // 4, 4) + prod_corrs = np.prod(prod_corrs, axis=2) + + # Incorporate scores of individual rotations from self-commonlines + scores_js = self.scls_scores[i + 1 : self.n_img] + scores_ij = scores_i[ij_map[:, 0]] * scores_js[:, ij_map[:, 1]] + + # Find maximum correlations and update results + prod_corrs = prod_corrs * scores_ij + max_indices = np.argmax(prod_corrs, axis=1) + corrs_idx[ij_idx : ij_idx + len(max_indices)] = max_indices + corrs_out[ij_idx : ij_idx + len(max_indices)] = prod_corrs[ + np.arange(len(max_indices)), max_indices + ] + + ij_idx += len(max_indices) + pbar.update(len(max_indices)) + + pbar.close() + + # Get estimated relative viewing directions + self.corrs_idx = corrs_idx + self.Rijs_est = self._get_Rijs_from_lin_idx(corrs_idx) + + def _get_Rijs_from_lin_idx(self, lin_idx): + """ + Restore map results from maximum-likelihood over commonlines to corresponding + relative rotations. + + :param lin_idx: Set of linear indices corresponding to best estimate of Rijs. + + :return: Estimated Rijs. + """ + Rijs_est = np.zeros((len(lin_idx), 4, 3, 3), dtype=self.dtype) + n_cand_per_oct = len(self.cl_idx_1) // 4 + oct1_idx = lin_idx < n_cand_per_oct + n_est_in_oct1 = np.count_nonzero(oct1_idx) + if n_est_in_oct1 > 0: + Rijs_est[oct1_idx] = self._get_Rijs_from_oct(lin_idx[oct1_idx], octant=1) + if n_est_in_oct1 <= len(lin_idx): + Rijs_est[~oct1_idx] = self._get_Rijs_from_oct( + lin_idx[~oct1_idx] - n_cand_per_oct, octant=2 + ) + + return Rijs_est + + def _get_Rijs_from_oct(self, lin_idx, octant=1): + """ + Calculate estimated relative rotations Rijs from the linear indices of + common-lines estimates from the search table. Rijs are generated from the + rotation grids from which the common-lines table was generated. + + :param lin_idx: Set of linear indices corresponding to best estimate of Rijs. + :param octant: Octant of rotation grid from which the Rj rotation was selected + when generating the common-lines table. + :return: Estimated Rijs. + """ + if octant not in [1, 2]: + raise ValueError("`octant` must be 1 or 2.") + + # Get pairs lookup table. + if octant == 1: + unique_pairs = self.eq2eq_Rij_table_11 + else: + unique_pairs = self.eq2eq_Rij_table_12 + + n_theta = self.n_inplane_rots + n_lookup_pairs = np.count_nonzero(unique_pairs) + n_rots = len(self.sphere_grid1) + if octant == 1: + n_rots2 = n_rots + else: + n_rots2 = len(self.sphere_grid2) + + # Map linear indices of chosen pairs of rotation candidates from ML to regular indices. + p_idx, inplane_i, inplane_j = np.unravel_index( + lin_idx, (2 * n_lookup_pairs, n_theta, n_theta // 2) + ) + transpose_idx = p_idx >= n_lookup_pairs + p_idx[transpose_idx] -= n_lookup_pairs + s = self.inplane_rotated_grid1.shape + inplane_rotated_grid = np.reshape( + self.inplane_rotated_grid1, (np.prod(s[0:2]), 3, 3) + ) + if octant == 1: + s2 = s + inplane_rotated_grid2 = inplane_rotated_grid + else: + s2 = self.inplane_rotated_grid2.shape + inplane_rotated_grid2 = np.reshape( + self.inplane_rotated_grid2, (np.prod(s2[0:2]), 3, 3) + ) + + # Convert linear indices of unique table to linear indices of index pairs table. + idx_vec = np.arange(np.prod(unique_pairs.shape)) + unique_lin_idx = idx_vec[unique_pairs.flatten()] + I, J = np.unravel_index(unique_lin_idx, (n_rots, n_rots2)) + est_idx = np.vstack((I[p_idx], J[p_idx])) + + # Assemble relative rotations Ri^TgRj using linear indices, where g is a group member of D2. + Ris_lin_idx = np.ravel_multi_index((est_idx[0], inplane_i), s[:2]) + Rjs_lin_idx = np.ravel_multi_index((est_idx[1], inplane_j), s2[:2]) + Ris_t = np.transpose(inplane_rotated_grid[Ris_lin_idx], (0, 2, 1)) + Rjs = inplane_rotated_grid2[Rjs_lin_idx] + Rijs_est = Ris_t[:, None] @ self.gs @ Rjs[:, None] + + Rijs_est[transpose_idx] = np.transpose(Rijs_est[transpose_idx], (0, 1, 3, 2)) + + return Rijs_est + + #################################### + # Perform Global J Synchronization # + #################################### + + def _global_J_sync(self, Rijs): + """ + Global J-synchronization of all third row outer products. Given n_pairsx4x3x3 + matrices Rijs, each of which might contain a spurious J, ie. + Rij = J @ Ri.T @ gs @ Rj @ J instead of Rij = Ri.T @ gs @ Rj, we return Rijs + that all have either a spurious J or not. + + :param Rijs: An (n-choose-2)x4 x3x3 array where each 3x3 slice holds an estimate + for the corresponding outer-product Ri.T @ Rj. Each estimate might have a + spurious J independently of other estimates. + + :return: Rijs, all of which have a spurious J or not. + """ + logger.info("Performing global handedness synchronization.") + # Find best J_configuration. + J_list = self._J_configuration(Rijs) + + # Determine relative handedness of Rijs. + sign_ij_J = self._J_sync_power_method(J_list) + + # Synchronize Rijs + logger.info("Applying global handedness synchronization.") + mask = sign_ij_J == 1 + Rijs[mask] = J_conjugate(Rijs[mask]) + + return Rijs + + def _J_configuration(self, Rijs): + """ + For each triplet of indices (i, j, k), consider the relative rotations + tuples {Ri^TgmRj}, {Ri^TglRk} and {Rj^TgrRk}. Compute norms of the form + ||Ri^TgmRj*Rj^TglRk-Ri^TglRk||, ||J*Ri^TgmRj*J*Rj^TglRk-Ri^TglRk||, + ||Ri^TgmRj*J*Rj^TglRk*J-Ri^TglRk| and ||Ri^TgmRj*Rj^TglRk-J*Ri^TglRk*J|| + where gm,gl,gr are the varipus gorup members of Dn and J=diag([1,1-1]). + The correct "J-configuration" is given for the smallest of these 4 norms. + + :param Rijs: (n-choose-2)x3x3 array of relative rotations. + :return: List of n-choose-3 indices in {0,1,2,3} indicating + which J-configuration for each triplet of Rijs, i epsilon: + itr += 1 + vec_new = self._signs_times_v(J_list, vec) + vec_new = vec_new / norm(vec_new) + residual = norm(vec_new - vec) + vec = vec_new + logger.info( + f"Iteration {itr}, residual {round(residual, 5)} (target {epsilon})" + ) + + # We need only the signs of the eigenvector + J_sync = np.sign(vec) + J_sync = np.sign(J_sync[0]) * J_sync # Stabilize J_sync + + return J_sync + + def _signs_times_v(self, J_list, vec): + """ + Multiplication of the J-synchronization matrix by a candidate eigenvector. + + The J-synchronization matrix is a matrix representation of the handedness graph, + Gamma, whose set of nodes consists of the estimates Rijs and whose set of edges + consists of the undirected edges between all triplets of estimates Rij, Rjk, + and Rik, where i rel_perm[2] + ) + trip_idx += 1 + + colors_i = np.sum(colors_i, axis=1) + + return colors_i + + def _mult_cmat_by_vec(self, c_perms, v): + """ + Multiply color matrix by vector v "on the fly". + + :param c_perms: An (N over 3) vector. Each corresponds to a triplet of + indices i0 and -1->1 + + return sync_signs2 + + def _estimate_rows(self, sync_signs2, c_mat_5d): + """ + Construct 3N x 3N matrix of rank-1 3x3 blocks of sij*vi_m.T @ vj_m, + the leading eigenvectors of which correspond to estimates for the rows + of the rotations Ri, up to signs. + """ + c_mat_5d_mp = np.concatenate((c_mat_5d, -c_mat_5d), axis=1) + rows_arr = np.zeros((3, self.n_img, 3 * self.n_img), dtype=self.dtype) + svals = np.zeros((3, 2, self.n_img), dtype=self.dtype) + + logger.info("Constructing and decomposing N sign synchronization matrices...") + for c in range(3): + for r in range(self.n_img): + # Image r used for signs. + c_mat_eff = self._fill_sign_sync_matrix_c( + c_mat_5d_mp, sync_signs2, c, r + ) + + # Construct (3*N)x(3*N) rank 1 matrices from Qik + c_mat_for_svd = np.zeros( + (3 * self.n_img, 3 * self.n_img), dtype=self.dtype + ) + for i in range(self.n_img): + row_3Nx3 = c_mat_eff[i] + row_3Nx3 = row_3Nx3.reshape(3 * self.n_img, 3) + c_mat_for_svd[:, 3 * i : 3 * i + 3] = row_3Nx3 + + c_mat_for_svd = c_mat_for_svd + c_mat_for_svd.T + + # Extract leading eigenvector of rank 1 matrix. For each r and c + # this gives an estimate for the c'th row of the rotation Rr, up + # to sign +/-. + for i in range(self.n_img): + c_mat_for_svd[3 * i : 3 * i + 3, 3 * i : 3 * i + 3] = c_mat_eff[ + i, i + ] + U, S, _ = np.linalg.svd(c_mat_for_svd) + svals[c, :, r] = S[:2] + rows_arr[c, r] = U[:, 0] + + return rows_arr + + def _compute_signs_adjustment(self, rows_arr): + """ + Compute signs adjustment vector. + """ + # Sync signs according to results for each image. Dot products between + # signed row estimates are used to construct an (N over 2)x(N over 2) + # sign synchronization matrix S. If (v_i)k and (v_j)k are the i'th and + # j'th estimates for the c'th row of Rk, then the entry (i,k),(k,j) entry + # of S is <(v_i)k,(v_j)k>, where the rows and columns of S are indexed by + # double indexes (i,j), 1<=i 0: + ij_signs[zeros_idx] = 1 + + return np.sign(ij_signs) + + def _mult_smat_by_vec(self, v, sign_mat, pairs_map): + """ + Multiplies the signs sync matrix by a vector. + """ + v_out = np.zeros_like(v) + for i in range(self.n_img): + for j in range(i + 1, self.n_img): + ij = self.pairs_to_linear[i, j] + v_out[ij] = sign_mat[ij] @ v[pairs_map[ij]] + return v_out + + #################### + # Helper Functions # + #################### + + @staticmethod + def _circ_seq(n1, n2, L): + """ + For integers 0 <= n1, n2 < L, make a circular sequence of integers between + n1 and n2 modulo L. + + :param n1: First integer in sequence. + :param n2: Last integer in sequence. + :param L: Modulus of values in sequence. + :return: Circular sequence modulo L. + """ + if min(n1, n2) < 0 or max(n1, n2) >= L: + raise ValueError( + f"n1 and n2 must both be in [0, {L}). Found n1={n1}, n2={n2}." + ) + if n2 < n1: + n2 += L + if n1 == n2: + return np.array([n1]).astype(int) % L + + seq = np.arange(n1, n2 + 1).astype(int) % L + + return seq + + @staticmethod + def _saff_kuijlaars(N): + """ + Generates N vertices on the unit sphere that are approximately evenly distributed. + + This implements the recommended algorithm in spherical coordinates + (theta, phi) according to "Distributing many points on a sphere" + by E.B. Saff and A.B.J. Kuijlaars, Mathematical Intelligencer 19.1 + (1997) 5--11. + + :param N: Number of vertices to generate. + + :return: Nx3 array of vertices in cartesian coordinates. + """ + k = np.arange(1, N + 1) + h = -1 + 2 * (k - 1) / (N - 1) + theta = np.arccos(h) + phi = np.zeros(N) + + for i in range(1, N - 1): + phi[i] = (phi[i - 1] + 3.6 / (np.sqrt(N * (1 - h[i] ** 2)))) % (2 * np.pi) + + # Spherical coordinates + x = np.sin(theta) * np.cos(phi) + y = np.sin(theta) * np.sin(phi) + z = np.cos(theta) + + mesh = np.column_stack((x, y, z)) + + return mesh + + @staticmethod + def _mark_equators(sphere_grid, eq_filter_angle): + """ + This method categorizes a set of 3D unit vectors into equator and non-equator + vectors determined by the parameter `eq_filter_angle`, returned as `eq_idx`. + It further categorizes the vectors into the classes non_equator, z-equator, + y-equator, x-equator, z-top_view, y-top_view, and x-top_view, which are labeled + respectively with the values 0 - 6 and returned as `eq_class`. + + :param sphere_grid: Nx3 array of vertices in cartesian coordinates. + :param eq_filter_angle: Angular distance from equator to be marked as + an equator point. + + :return: eq_class, n_rots length array of values indicating equator class. + """ + # Project each vector onto xy, xz, yz planes and measure angular distance + # from each plane. + n_rots = len(sphere_grid) + angular_dists = np.zeros((n_rots, 3), dtype=sphere_grid.dtype) + + # For each grid point get the distance from the z, y, and x-axis equators. + for i in range(3): + proj_along_axis = sphere_grid.copy() + proj_along_axis[:, 2 - i] = 0 + proj_along_axis /= np.linalg.norm(proj_along_axis, axis=1)[:, None] + angular_dists[:, i] = np.sum(sphere_grid * proj_along_axis, axis=-1) + + # Mark all views close to an equator. + eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) + n_eqs = np.count_nonzero(angular_dists > eq_min_dist, axis=1) + + # Classify equators. + # 0 -> non-equator view + # 1 -> z equator + # 2 -> y equator + # 3 -> x equator + # 4 -> z top view + # 5 -> y top view + # 6 -> x top view + eq_class = np.zeros(n_rots) + + # Grid points which are equator points with respect to 2 equators are considered top views. + # For example, a grid point that is close to both the x and y equator is a z top view. + top_view_idx = n_eqs > 1 + top_view_class = np.argmin(angular_dists[top_view_idx] > eq_min_dist, axis=1) + eq_class[top_view_idx] = top_view_class + 4 + + # Assign grid points which are equator points with respect to only 1 equator. + eq_view_idx = n_eqs == 1 + eq_view_class = np.argmax(angular_dists[eq_view_idx] > eq_min_dist, axis=1) + eq_class[eq_view_idx] = eq_view_class + 1 + + return eq_class + + @staticmethod + def _generate_inplane_rots(sphere_grid, d_theta): + """ + This function takes projection directions (points on the 2-sphere) and + generates rotation matrices in SO(3). The projection direction + is the 3rd column and columns 1 and 2 span the perpendicular plane. + To properly discretize SO(3), for each projection direction we generate + [2*pi/dtheta] "in-plane" rotations, of the plane + perpendicular to this direction. This is done by generating one rotation + for each direction and then multiplying on the right by a rotation about + the Z-axis by k*dtheta degrees, k=0...2*pi/dtheta-1. + + :param sphere_grid: A set of points on the 2-sphere. + :param d_theta: Resolution for in-plane rotations (in degrees) + :returns: 4D array of rotations of size len(sphere_grid) x n_inplane_rots x 3 x 3. + """ + dtype = sphere_grid.dtype + # Generate one rotation for each point on the sphere. + n_rots = len(sphere_grid) + Ri2 = np.column_stack((-sphere_grid[:, 1], sphere_grid[:, 0], np.zeros(n_rots))) + Ri2 /= np.linalg.norm(Ri2, axis=1)[:, None] + Ri1 = np.cross(Ri2, sphere_grid) + Ri1 /= np.linalg.norm(Ri1, axis=1)[:, None] + + rots_grid = np.zeros((n_rots, 3, 3), dtype=dtype) + rots_grid[:, :, 0] = Ri1 + rots_grid[:, :, 1] = Ri2 + rots_grid[:, :, 2] = sphere_grid + + # Generate in-plane rotations. + d_theta *= np.pi / 180 + # Negative signs to match matlab. + inplane_rots = Rotation.about_axis( + "z", np.arange(0, -2 * np.pi, -d_theta), dtype=dtype + ).matrices + n_inplane_rots = len(inplane_rots) + + # Generate in-plane rotations of rots_grid. + inplane_rotated_grid = np.zeros((n_rots, n_inplane_rots, 3, 3), dtype=dtype) + for i in range(n_rots): + inplane_rotated_grid[i] = rots_grid[i] @ inplane_rots + + return inplane_rotated_grid + + def _generate_commonline_indices(self, cl_angles): + """ + Converts a multi-dimensional stack of pairs of commonline angles in [0, 360) degrees + into a flattened stack of polar Fourier linear indices, with the convention that + each linear index corresponds to an unraveled index in [0, n_theta // 2) x [0, n_theta). + + :param cl_angles: A multi-dimensional stack of commonline angles in degrees, shape (..., 2). + :return: cl_idx, a 1D array of linear indices. + """ + L = self.n_theta + + # Flatten the stack + og_shape = cl_angles.shape + cl_angles = np.reshape(cl_angles, (np.prod(og_shape[:-1]), 2)) + + # Fourier ray index + row_sub = np.round(cl_angles[:, 0] * L / 360).astype("int") % L + col_sub = np.round(cl_angles[:, 1] * L / 360).astype("int") % L + + # Restrict Ri in-plane coordinates to <180 degrees. + is_geq_than_pi = row_sub >= L // 2 + row_sub[is_geq_than_pi] = row_sub[is_geq_than_pi] - L // 2 + col_sub[is_geq_than_pi] = (col_sub[is_geq_than_pi] + (L // 2)) % L + + # Convert to linear indices in 180x360 correlation matrix. + cl_idx = np.ravel_multi_index((row_sub, col_sub), dims=(L // 2, L)) + + return cl_idx diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py new file mode 100644 index 0000000000..ca972bcf7c --- /dev/null +++ b/tests/test_orient_d2.py @@ -0,0 +1,478 @@ +import numpy as np +import pytest + +from aspire.abinitio import CLSymmetryD2 +from aspire.source import Simulation +from aspire.utils import ( + J_conjugate, + Random, + Rotation, + all_pairs, + mean_aligned_angular_distance, + utest_tolerance, +) +from aspire.volume import DnSymmetricVolume, DnSymmetryGroup + +############## +# Parameters # +############## + +DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] +RESOLUTION = [48, 49] +N_IMG = [10] +OFFSETS = [0, pytest.param(None, marks=pytest.mark.expensive)] + +# Since these tests are optimized for runtime, detuned parameters cause +# the algorithm to be fickle, especially for small problem sizes. +# In particular, the parameters `grid_res`, inplane_res`, and `eq_min_dist` +# which control the number of candidate rotations used in the D2 algorithm +# will produce bad estimates if the candidates do not align closely with the +# ground truth rotations. +# This seed is chosen so the tests pass CI on github's envs as well +# as our self-hosted runner. +SEED = 3 + + +@pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module") +def resolution(request): + return request.param + + +@pytest.fixture(params=N_IMG, ids=lambda x: f"n images={x}", scope="module") +def n_img(request): + return request.param + + +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}", scope="module") +def offsets(request): + return request.param + + +############ +# Fixtures # +############ + + +@pytest.fixture(scope="module") +def source(n_img, resolution, dtype, offsets): + vol = DnSymmetricVolume( + L=resolution, order=2, C=1, K=100, dtype=dtype, seed=SEED + ).generate() + + src = Simulation( + n=n_img, + L=resolution, + vols=vol, + offsets=offsets, + amplitudes=1, + seed=SEED, + ) + src = src.cache() # Precompute image stack + + return src + + +@pytest.fixture(scope="module") +def orient_est(source): + return build_cl_from_source(source) + + +######### +# Tests # +######### + + +def test_estimate_rotations(orient_est): + """ + This test runs through the complete D2 algorithm and compares the + estimated rotations to the ground truth rotations. In particular, + we check that the estimates are close to the ground truth up to + a local rotation by a D2 symmetry group member, a global J-conjugation, + and a globally aligning rotation. + """ + # Estimate rotations. + orient_est.estimate_rotations() + rots_est = orient_est.rotations + + # Ground truth rotations. + rots_gt = orient_est.src.rotations + + # g-sync ground truth rotations. + rots_gt_sync = g_sync_d2(rots_est, rots_gt) + + # Register estimates to ground truth rotations and check that the mean angular + # distance between them is less than 5 degrees. + mean_aligned_angular_distance(rots_est, rots_gt_sync, degree_tol=5) + + # Check dtype pass-through. + assert rots_est.dtype == orient_est.dtype + + +def test_scl_scores(orient_est): + """ + This test uses a Simulation generated with rotations taken directly + from the D2 algorithm `sphere_grid` of candidate rotations. It is + these candidates which should produce maximum correlation scores since + they match perfectly the Simulation rotations. + """ + # Generate lookup data and extract rotations from the candidate `sphere_grid`. + # In this case, we take first 10 candidates from a non-equator viewing direction. + orient_est._generate_lookup_data() + cand_rots = orient_est.inplane_rotated_grid1 + non_eq_idx = int(np.argwhere(orient_est.eq_class1 == 0)[0][0]) + rots = cand_rots[non_eq_idx, :10] + angles = Rotation(rots).angles + + # Create a Simulation using those first 10 candidate rotations. + src = Simulation( + n=orient_est.src.n, + L=orient_est.src.L, + vols=orient_est.src.vols, + angles=angles, + offsets=orient_est.src.offsets, + amplitudes=1, + seed=SEED, + ) + + # Initialize CL instance with new source. + cl = build_cl_from_source(src) + + # Generate lookup data. + cl._compute_shifted_pf() + cl._generate_lookup_data() + cl._generate_scl_lookup_data() + + # Compute self-commonline scores. + cl._compute_scl_scores() + + # cl.scls_scores is shape (n_img, n_cand_rots). Since we used the first + # 10 candidate rotations of the first non-equator viewing direction as our + # Simulation rotations, the maximum correlation for image i should occur at + # candidate rotation index (non_eq_idx * cl.n_inplane_rots + i). + max_corr_idx = np.argmax(cl.scls_scores, axis=1) + gt_idx = non_eq_idx * cl.n_inplane_rots + np.arange(10) + + # Check that self-commonline indices match ground truth. + n_match = np.sum(max_corr_idx == gt_idx) + match_tol = 0.99 # match at least 99%. + if not (src.offsets == 0.0).all(): + match_tol = 0.89 # match at least 89% with offsets. + np.testing.assert_array_less(match_tol, n_match / src.n) + + # Check dtype pass-through. + assert cl.scls_scores.dtype == orient_est.dtype + + +def test_global_J_sync(orient_est): + """ + For this test we build a set of relative rotations, Rijs, of shape + (npairs, order(D2), 3, 3) and randomly J_conjugate them. We then test + that the J-configuration is correctly detected and that J-synchronization + is correct up to conjugation of the entire set. + """ + # Grab set of rotations and generate a set of relative rotations, Rijs. + rots = orient_est.src.rotations + Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) + for p, (i, j) in enumerate(orient_est.pairs): + Rij = rots[i].T @ orient_est.gs @ rots[j] + np.random.shuffle(Rij) # Mix up the ordering of Rijs + Rijs[p] = Rij + + # J-conjugate a random set of Rijs. + Rijs_conj = Rijs.copy() + inds = np.random.choice( + orient_est.n_pairs, size=orient_est.n_pairs // 2, replace=False + ) + Rijs_conj[inds] = J_conjugate(Rijs[inds]) + + # Create J-configuration conditions for the triplet Rij, Rjk, Rik. + J_conds = { + (False, False, False): 0, + (True, True, True): 0, + (True, False, False): 1, + (False, True, True): 1, + (False, True, False): 2, + (True, False, True): 2, + (False, False, True): 3, + (True, True, False): 3, + } + + # Construct ground truth J-configuration list based on `inds` of Rijs + # that have been conjugated above. + J_list_gt = np.zeros(len(orient_est.triplets), dtype=int) + idx = 0 + for i, j, k in orient_est.triplets: + ij = orient_est.pairs_to_linear[i, j] + jk = orient_est.pairs_to_linear[j, k] + ik = orient_est.pairs_to_linear[i, k] + + J_conf = (ij in inds, jk in inds, ik in inds) + J_list_gt[idx] = J_conds[J_conf] + idx += 1 + + # Perform J-configuration and compare to ground truth. + J_list = orient_est._J_configuration(Rijs_conj) + np.testing.assert_equal(J_list, J_list_gt) + + # Perform global J-synchronization and check that + # Rijs_sync is equal to either Rijs or J_conjugate(Rijs). + Rijs_sync = orient_est._global_J_sync(Rijs_conj) + need_to_conj_Rijs = not np.allclose(Rijs_sync[inds][0], Rijs[inds][0]) + if need_to_conj_Rijs: + np.testing.assert_allclose(Rijs_sync, J_conjugate(Rijs)) + else: + np.testing.assert_allclose(Rijs_sync, Rijs) + + # Check dtype pass-through. + assert Rijs_sync.dtype == orient_est.dtype + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_global_J_sync_single_triplet(dtype): + """ + This exercises the J-synchronization algorithm using the smallest + possible problem size, a single triplets of relative rotations Rijs. + """ + # Generate 3 image source and orientation object. + src = Simulation(n=3, L=10, dtype=dtype, seed=SEED) + orient_est = build_cl_from_source(src) + + # Grab set of rotations and generate a set of relative rotations, Rijs. + rots = orient_est.src.rotations + Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) + for p, (i, j) in enumerate(orient_est.pairs): + Rij = rots[i].T @ orient_est.gs @ rots[j] + np.random.shuffle(Rij) # Mix up the ordering of Rijs + Rijs[p] = Rij + + # J-conjugate a random Rij. + Rijs_conj = Rijs.copy() + inds = np.random.choice(orient_est.n_pairs, size=1, replace=False) + Rijs_conj[inds] = J_conjugate(Rijs[inds]) + + # Perform global J-synchronization and check that + # Rijs_sync is equal to either Rijs or J_conjugate(Rijs). + Rijs_sync = orient_est._global_J_sync(Rijs_conj) + need_to_conj_Rijs = not np.allclose(Rijs_sync[inds][0], Rijs[inds][0]) + if need_to_conj_Rijs: + np.testing.assert_allclose(Rijs_sync, J_conjugate(Rijs)) + else: + np.testing.assert_allclose(Rijs_sync, Rijs) + + +def test_sync_colors(orient_est): + """ + A set of estimated relative rotations, Rijs, have the shape (n_pairs, 4, 3, 3), + where each 4-tuple Rij is given by Rij = Ri.T @ g_m @ Rj, for m in [0, 1, 2, 3], + where each g_m is an element of the D2 symmetry group. The ordering of the symmetry + group elements, g_m, is unknown and independent between Rijs. The `_sync_colors` + algorithm forms the set of vijs of shape (n_pairs, 3, 3, 3), where each vij, given + by vij = (Rij[0] + Rij[m]) / 2 with m = 1, 2, 3, is some permutation of the outer + products of the k'th rows of the rotation matrices Ri and Rj, for k = 0, 1, 2. + + The 'sync_colors` algorithm uses a colored graph to partition the set of vijs + based on k'th row outer products and returns those outer products along with + a color mapping encoding a permutation for each vij. + + In this test we form a set of Rijs with randomly ordered symmetry group elements + and extract the ground truth color permutations based on that ordering. We then + construct a set of ground truth vijs adjusted by the ground truth color permuations. + We then compare estimated vijs and color permutations to ground truth. + """ + # Grab set of rotations and generate a set of relative rotations, Rijs. + rots = orient_est.src.rotations + Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) + gt_colors = np.zeros((orient_est.n_pairs, 3), dtype=int) + + with Random(123): + for p, (i, j) in enumerate(orient_est.pairs): + gs = orient_est.gs + if p > 0: + np.random.shuffle(gs) # Mix up the ordering of all but 1st Rijs. + + # Compute the rotation row permutation created by the ordering of gs. + # See Proposition 5.1 in the related publication for details. + for m in range(3): + gt_colors[p, m] = np.argmax( + np.sum(abs(0.5 * (gs[0] + gs[m + 1])), axis=0) + ) + + # Compute Rijs with shuffled gs. + Rij = rots[i].T @ gs @ rots[j] + Rijs[p] = Rij + + # Compute ground truth m'th row outer products. + vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype) + for p, (i, j) in enumerate(orient_est.pairs): + for m in range(3): + row = gt_colors[p, m] + vijs[p, m] = np.outer(rots[i][row], rots[j][row]) + + # Perform color synchronization. + # `est_vijs` is shape (n_pairs, 3, 3, 3) where est_vijs[ij, m] corresponds + # to the outer product vij_m = rots[i, m].T @ rots[j, m] where m is the m'th row + # of the rotations matrices Ri and Rj. `est_colors` partitions the set of `est_vijs` + # such that the indices of `est_colors` corresponds to the row index m. + est_colors, est_vijs = orient_est._sync_colors(Rijs) + + # Reshape `est_colors` to shape (n_pairs, 3) and use to index est_vijs into the + # correctly order 3rd row outer products vijs. + est_colors = est_colors.reshape(orient_est.n_pairs, 3) + + # `est_colors` is an arbitrary permutation (but globally consistent), and we know + # that est_colors[0] should correspond to the ordering [0, 1, 2] due to the construction + # of Rijs[0] using the symmetric rotations g0, g1, g2, g3 in non-permuted order. + # So we sort the columns such that est_colors[0] = [0,1,2]. + + # Create a mapping array + perm = est_colors[0] + mapping = np.zeros_like(perm) + mapping[perm] = np.arange(3) + + # Apply this mapping to all rows of the est_colors array + est_colors_mapped = mapping[est_colors] + + # Check that remapped color permutations match ground truth. + np.testing.assert_allclose(est_colors_mapped, gt_colors) + + # est_vijs_synced should match the ground truth vijs up to the sign of each row. + # So we multiply by the sign of the first column of the last two axes to sync signs. + vijs = vijs * np.sign(vijs[..., 0])[..., None] + est_vijs = est_vijs * np.sign(est_vijs[..., 0])[..., None] + np.testing.assert_allclose(vijs, est_vijs, atol=utest_tolerance(orient_est.dtype)) + + # Check dtype pass-through. + assert est_vijs.dtype == orient_est.dtype + + +def test_sync_signs(orient_est): + """ + Sign synchronization consumes a set of m'th row outer products along with + a color synchronizing vector and returns a set of rotation matrices + that are the result of synchronizing the signs of the rows of the outer + products and factoring the outer products to form the rows of the rotations. + + In this test we provide a color-synchronized set of m'th row outer products + with a corresponding color vector and test that the output rotations + equivalent to the ground truth rotations up to a global alignment. + """ + rots = orient_est.src.rotations + + # Compute ground truth m'th row outer products. + vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype) + for p, (i, j) in enumerate(orient_est.pairs): + for m in range(3): + vijs[p, m] = np.outer(rots[i][m], rots[j][m]) + + # We will pass in m'th row outer products that are color synchronized, + # ie. colors = [0, 1, 2, 0, 1, 2, ...] + perm = np.array([0, 1, 2]) + colors = np.tile(perm, orient_est.n_pairs) + + # Estimate rotations and check against ground truth. + rots_est = orient_est._sync_signs(vijs, colors) + mean_aligned_angular_distance(rots, rots_est, degree_tol=1e-5) + + # Check dtype pass-through. + assert rots_est.dtype == orient_est.dtype + + +#################### +# Helper Functions # +#################### + + +def g_sync_d2(rots, rots_gt): + """ + Every estimated rotation might be a version of the ground truth rotation + rotated by g^{s_i}, where s_i = 0, 1, ..., order. This method synchronizes the + ground truth rotations so that only a single global rotation need be applied + to all estimates for error analysis. + + :param rots: Estimated rotation matrices + :param rots_gt: Ground truth rotation matrices. + + :return: g-synchronized ground truth rotations. + """ + assert len(rots) == len( + rots_gt + ), "Number of estimates not equal to number of references." + n_img = len(rots) + dtype = rots.dtype + + rots_symm = DnSymmetryGroup(2, dtype).matrices + order = len(rots_symm) + + A_g = np.zeros((n_img, n_img), dtype=complex) + + pairs = all_pairs(n_img) + + for i, j in pairs: + Ri = rots[i] + Rj = rots[j] + Rij = Ri.T @ Rj + + Ri_gt = rots_gt[i] + Rj_gt = rots_gt[j] + + diffs = np.zeros(order) + for s, g_s in enumerate(rots_symm): + Rij_gt = Ri_gt.T @ g_s @ Rj_gt + diffs[s] = min( + [ + np.linalg.norm(Rij - Rij_gt), + np.linalg.norm(Rij - J_conjugate(Rij_gt)), + ] + ) + + idx = np.argmin(diffs) + + A_g[i, j] = np.exp(-1j * 2 * np.pi / order * idx) + + # A_g(k,l) is exp(-j(-theta_k+theta_l)) + # Diagonal elements correspond to exp(-i*0) so put 1. + # This is important only for verification purposes that spectrum is (K,0,0,0...,0). + A_g += np.conj(A_g).T + np.eye(n_img) + + _, eig_vecs = np.linalg.eigh(A_g) + leading_eig_vec = eig_vecs[:, -1] + + angles = np.exp(1j * 2 * np.pi / order * np.arange(order)) + rots_gt_sync = np.zeros((n_img, 3, 3), dtype=dtype) + + for i, rot_gt in enumerate(rots_gt): + # Since the closest ccw or cw rotation are just as good, + # we take the absolute value of the angle differences. + angle_dists = np.abs(np.angle(leading_eig_vec[i] / angles)) + power_g_Ri = np.argmin(angle_dists) + rots_gt_sync[i] = rots_symm[power_g_Ri] @ rot_gt + + return rots_gt_sync + + +def build_cl_from_source(source): + # Search for common lines over less shifts for 0 offsets. + max_shift = 0 + shift_step = 1 + if source.offsets.all() != 0: + max_shift = 0.2 + shift_step = 0.02 # Reduce shift steps for non-integer offsets of Simulation. + + orient_est = CLSymmetryD2( + source, + max_shift=max_shift, + shift_step=shift_step, + n_theta=180, + n_rad=source.L, + grid_res=350, # Tuned for speed + inplane_res=12, # Tuned for speed + eq_min_dist=10, # Tuned for speed + epsilon=0.001, + seed=SEED, + ) + return orient_est